Commit 8bf209f9, authored by dzhwinter, committed by GitHub

Merge pull request #4548 from dzhwinter/fix_backward2

add generic add operator
@@ -141,9 +141,27 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
         net->ops_[op_offset]->Rename(name, dup_outputs.back());
       }
       // collect all the offset to append `add` op for each alias
-      insert_position.push_back(
-          {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
-                                               {{"Out", {name}}}, {})});
+      //
+      // one variable is shared between multiple operators.
+      // insert add operators one by one, then add them to the output
+      for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
+           ++output_idx) {
+        auto insert_add_x = dup_outputs[output_idx + 1];
+        auto insert_add_y = dup_outputs[output_idx];
+        auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
+        // the last sum in the chain writes directly to the final output
+        if (output_idx == dup_outputs.size() - 2) {
+          insert_add_out = name;
+        }
+        if (output_idx != 0) {
+          insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
+        }
+        insert_position.push_back(
+            {dup_op.back(),
+             OpRegistry::CreateOp(
+                 "sum", {{"X", {insert_add_x, insert_add_y}}},
+                 {{"Out", {insert_add_out}}}, {})});
+      }
     }
 
     // make sure the inserted `add` ops follow the BFS order.
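For reference, the following standalone sketch (not part of the commit; the names g0, g1, g2 and x@GRAD are made up) runs the same index logic as the loop above and prints the chain of binary `sum` ops it would insert for three duplicated gradient outputs:

// Illustration only: fold three duplicated gradients into `name` through
// one intermediate @SHARED@ variable, following the loop in the hunk above.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> dup_outputs = {"g0", "g1", "g2"};
  std::string name = "x@GRAD";
  for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
       ++output_idx) {
    std::string x = dup_outputs[output_idx + 1];
    std::string y = output_idx == 0
                        ? dup_outputs[0]
                        : name + "@SHARED@" + std::to_string(output_idx - 1);
    std::string out = output_idx == dup_outputs.size() - 2
                          ? name
                          : name + "@SHARED@" + std::to_string(output_idx);
    std::cout << "sum(X = {" << x << ", " << y << "}) -> " << out << "\n";
  }
  return 0;
}
// prints:
//   sum(X = {g1, g0}) -> x@GRAD@SHARED@0
//   sum(X = {g2, x@GRAD@SHARED@0}) -> x@GRAD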
@@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
-class AddOpMaker : public OpProtoAndCheckerMaker {
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "x").AsDuplicable();
-    AddOutput("Out", "out");
+    AddInput("X", "the input tensors of sum operator.")
+        .AsDuplicable()
+        .NotInGradient();
+    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
     AddComment("");
   }
 };
 
 }  // namespace framework
 }  // namespace paddle
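Because SumOpMaker marks "X" as duplicable, a single "X" slot can carry any number of variable names. A hedged sketch of such a call, following the OpRegistry::CreateOp(type, inputs, outputs, attrs) shape used in the BackwardRecursive hunk above (the variable names are purely illustrative):

// Sketch only: build a `sum` op that adds two gradient aliases into one output.
auto sum_op = OpRegistry::CreateOp(
    "sum",
    {{"X", {"w@GRAD@0", "w@GRAD@1"}}},  // duplicable input slot "X"
    {{"Out", {"w@GRAD"}}},              // single output "Out"
    {});                                // no attributes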
@@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
-REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
+REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
@@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add", bwd_net->ops_[2]->Type());
+  ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
 TEST(Backward, op_register_grad_not_for_network) {