Commit 302046aa authored by: D dongzhihong

"fix return net error"

Parent 11974205
@@ -60,6 +60,16 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
   }
 };
+class NoGradOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "X input");
+    AddOutput("Y", "Y output");
+    AddComment("NoGradOp, same input output. no Grad");
+  }
+};
+
 class FcOp : public NetOp {
  public:
   void Init() override {
@@ -139,6 +149,7 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
 REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
 REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
+REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
 REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
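
Worth noting in this hunk: every differentiable op in the test fixture pairs REGISTER_OP with a matching REGISTER_GRADIENT_OP, and the new nograd registration deliberately omits that second half. The contrast, using registrations already present in this file:

    // `mul` is registered together with its gradient op:
    REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
    REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);

    // `nograd` is forward-only: no REGISTER_GRADIENT_OP follows, so the
    // registry holds no gradient entry for it.
    REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
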
@@ -266,9 +277,11 @@ TEST(Backward, net_shared_weight) {
 }

 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
-                              {{"temporary_index", std::vector<int>{1}}});
+  //  auto fwd =
+  //      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+  //                              {{"temporary_index", std::vector<int>{1}}});
+  auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
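
The test previously built an fc op, but FcOp derives from NetOp (see the first hunk); the commit swaps it for the new forward-only nograd op. Since nograd has no registered gradient, f::OpRegistry::CreateGradOp has nothing to construct and is expected to throw EnforceNotMet. A self-contained sketch of that lookup-and-throw shape, with hypothetical names (not Paddle's actual registry code):

    #include <map>
    #include <stdexcept>
    #include <string>

    // Toy stand-in for the table REGISTER_GRADIENT_OP populates:
    // forward op type -> gradient op type.
    static std::map<std::string, std::string> grad_op_types = {
        {"mul", "mul_grad"}, {"sigmoid", "sigmoid_grad"}, {"add", "add_grad"}};

    std::string LookupGradOpType(const std::string& op_type) {
      auto it = grad_op_types.find(op_type);
      if (it == grad_op_types.end()) {
        // Paddle surfaces this condition as EnforceNotMet.
        throw std::runtime_error("operator '" + op_type + "' has no gradient op");
      }
      return it->second;  // e.g. LookupGradOpType("nograd") throws above
    }
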
@@ -316,11 +329,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_FALSE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 1UL);
-  auto &grad_mul = *net->ops_[0];
+  auto &grad_mul = *backward;
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
...
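
The final hunk is the one the commit title points at: the old test unwrapped Backward's result through a NetOp cast and net->ops_[0], while the updated test expects the single mul_grad op to come back directly. A self-contained toy model (invented types, not Paddle's implementation) of the return-path behavior the test now asserts:

    #include <memory>
    #include <utility>
    #include <vector>

    // Minimal stand-ins for OperatorBase and NetOp.
    struct Op {
      virtual bool IsNetOp() const { return false; }
      virtual ~Op() = default;
    };
    struct Net : Op {
      std::vector<std::shared_ptr<Op>> ops_;
      bool IsNetOp() const override { return true; }
    };

    // If the backward pass produced exactly one op, return it directly
    // instead of wrapping it in a one-element Net.
    std::shared_ptr<Op> MakeBackward(std::vector<std::shared_ptr<Op>> grad_ops) {
      if (grad_ops.size() == 1) return grad_ops.front();
      auto net = std::make_shared<Net>();
      net->ops_ = std::move(grad_ops);
      return net;
    }
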