提交 dc06eaa0 编写于 作者: F fengjiayi

Merge branch 'feature/backward' of https://github.com/reyoung/Paddle into feature/backward

......@@ -60,6 +60,16 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
}
};
class NoGradOpMaker : public OpProtoAndCheckerMaker {
public:
NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X input");
AddOutput("Y", "Y output");
AddComment("NoGradOp, same input output. no Grad");
}
};
class FcOp : public NetOp {
public:
void Init() override {
......@@ -139,6 +149,7 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
......@@ -266,9 +277,11 @@ TEST(Backward, net_shared_weight) {
}
TEST(Backward, op_register_grad_not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
// auto fwd =
// f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
// {{"temporary_index", std::vector<int>{1}}});
auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
......@@ -317,11 +330,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"});
ASSERT_FALSE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 1UL);
auto &grad_mul = *net->ops_[0];
auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册