提交 7087a043 编写于 作者: D dongzhihong

"add unittest"

上级 74cd9a75
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/net.h" #include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
...@@ -161,12 +162,23 @@ TEST(Backward, simple_op_grad) { ...@@ -161,12 +162,23 @@ TEST(Backward, simple_op_grad) {
} }
TEST(Backward, simple_op_not_need_grad) { TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {}); auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"x"}); auto gop = f::Backward(*fwd, {"X"});
LOG(INFO) << gop->DebugString(); LOG(INFO) << "full " << gop->DebugString();
ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_NE(std::find(gop->outputs_.begin(), gop->outputs_.end(),
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
gop->outputs_.end()); gop->outputs_.end());
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
LOG(INFO) << "no input gop " << no_input_gop->DebugString();
ASSERT_NE(no_input_gop, nullptr);
ASSERT_EQ(std::vector<std::string>{}, no_input_gop->outputs_);
ASSERT_EQ(
std::vector<std::string>{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()},
no_input_gop->inputs_);
// auto no_output_gop = f::Backward(*fwd, {"Out"});
// ASSERT_EQ(std::vector<std::string>{"X" +
// f::OperatorBase::GRAD_VAR_SUFFIX(), "b"})
} }
TEST(Backward, net_fc_backward_normal) { TEST(Backward, net_fc_backward_normal) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册