Commit ff19223e authored by Yi Wang

Reformat

Parent e4aea7fd
@@ -59,16 +59,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   // If all input gradients of forwarding operator do not need to calculate,
   // just return an NOP. Not return null ptr because NOP does not take
   // too much time for calculation, but it is useful for simplifying logic.
-  if (AllInSet(forwardOp.inputs_, kGradVarSuffix,
-               no_grad_names)) {
+  if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) {
     return NOP();
   }
   // All output gradients of forwarding operator do not need to calculate.
   // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
-  if (AllInSet(forwardOp.outputs_, kGradVarSuffix,
-               no_grad_names)) {
+  if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
       // Mark all input is not need
       no_grad_names.insert(name + kGradVarSuffix);
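A note on the check being reformatted in this hunk: AllInSet asks whether every name in the given list, with the gradient suffix appended, is already in no_grad_names. The helper below is a minimal, self-contained sketch of that idea, written under an assumed signature; the real helper in this file may differ.

#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical sketch: true only if every name, once `suffix` is appended,
// is already present in `set`.
static bool AllInSet(const std::vector<std::string>& names,
                     const std::string& suffix,
                     const std::unordered_set<std::string>& set) {
  for (const auto& name : names) {
    if (set.count(name + suffix) == 0) {
      return false;
    }
  }
  return true;
}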
@@ -134,8 +132,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
   for (std::string& grad_input : grad_op->inputs_) {
     if (no_grad_names.count(grad_input)) {
-      std::string prefix = grad_input.substr(
-          0, grad_input.size() - kGradVarSuffix.size());
+      std::string prefix =
+          grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
       grad_input = prefix + kZeroVarSuffix;
       // If part of input gradient of that operator is not calculated, fill
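To make the renaming above concrete: a gradient input listed in no_grad_names has its @GRAD suffix swapped for @ZERO, so a later fill-with-zeros step can supply the value. A standalone illustration of just the string manipulation (not the actual backward.cc code):

#include <iostream>
#include <string>

int main() {
  const std::string kGradVarSuffix = "@GRAD";
  const std::string kZeroVarSuffix = "@ZERO";
  std::string grad_input = "X" + kGradVarSuffix;  // "X@GRAD"
  // Strip the gradient suffix and append the zero suffix instead.
  std::string prefix =
      grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
  grad_input = prefix + kZeroVarSuffix;
  std::cout << grad_input << std::endl;  // prints "X@ZERO"
  return 0;
}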
@@ -168,8 +166,7 @@ std::shared_ptr<OperatorBase> Backward(
   std::unordered_set<std::string> no_grad_names;
   no_grad_names.reserve(no_grad_vars.size());
-  no_grad_names.insert(kEmptyVarName +
-                       kGradVarSuffix);
+  no_grad_names.insert(kEmptyVarName + kGradVarSuffix);
   for (auto& name : no_grad_vars) {
     no_grad_names.insert(name + kGradVarSuffix);
@@ -177,5 +174,6 @@ std::shared_ptr<OperatorBase> Backward(
   size_t uid = 0;
   return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
 }  // namespace framework
 }  // namespace paddle
@@ -168,8 +168,7 @@ TEST(Backward, simple_op_grad) {
   ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
   ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
-  ASSERT_EQ("X" + f::kGradVarSuffix,
-            gop->Output("X" + f::kGradVarSuffix));
+  ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
 }

 TEST(Backward, simple_op_not_need_grad) {
@@ -210,9 +209,9 @@ TEST(Backward, net_fc_backward_normal) {
 }

 TEST(Backward, net_fc_backward_not_have_b) {
-  std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "w", f::kEmptyVarName},
-      {"mul_result", "add_result", "tmp"}, {});
+  std::shared_ptr<f::OperatorBase> fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
+                              {"mul_result", "add_result", "tmp"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
@@ -245,21 +244,18 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   all_output.erase(f::kEmptyVarName);

   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
-    ASSERT_NE(all_output.find(out + f::kGradVarSuffix),
-              all_output.end());
+    ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
   }

   // Not Generated X
-  ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix),
-            all_output.end());
+  ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());

   ASSERT_EQ(2UL, bwd_net->ops_.size());
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(
-      f::kEmptyVarName,
-      first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
+  ASSERT_EQ(f::kEmptyVarName,
+            first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
 }

 TEST(Backward, net_shared_weight) {
@@ -316,10 +312,8 @@ TEST(Backward, op_part_of_output_are_not_need) {
   auto &d_many_out = *net->ops_[1];
   ASSERT_EQ("many_output_op_grad", d_many_out.type_);
   ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
-  ASSERT_EQ("Z" + f::kZeroVarSuffix,
-            d_many_out.Input("z" + f::kGradVarSuffix));
-  ASSERT_EQ("Y" + f::kGradVarSuffix,
-            d_many_out.Input("y" + f::kGradVarSuffix));
+  ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" + f::kGradVarSuffix));
+  ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" + f::kGradVarSuffix));
   ASSERT_EQ("X" + f::kGradVarSuffix,
             d_many_out.Output("x" + f::kGradVarSuffix));
 }
@@ -331,10 +325,8 @@ TEST(Backward, op_part_of_input_are_not_need) {
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
-  ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix),
-            f::kEmptyVarName);
-  ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix),
-            "b" + f::kGradVarSuffix);
+  ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
+  ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" + f::kGradVarSuffix);
   ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
             "out" + f::kGradVarSuffix);
   ASSERT_EQ(grad_mul.Input("A"), "a");
@@ -368,23 +360,4 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
-  /*
-  EXPECT_EQ(grad_fc.Output("X" + f::kGradVarSuffix),
-            f::kEmptyVarName);
-  EXPECT_EQ(grad_fc.Output("W" + f::kGradVarSuffix),
-            "w3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Output("b" + f::kGradVarSuffix),
-            "b3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Output("mul_result" + f::kGradVarSuffix),
-            "mul_out3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Input("Out" + f::kGradVarSuffix),
-            "out3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Input("X"), "out2");
-  EXPECT_EQ(grad_fc.Input("W"), "w3");
-  EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
-  EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
-  EXPECT_EQ(grad_fc.Input("Out"), "out3");
-  */
 }
@@ -56,8 +56,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
   for (const auto& arg : src_arg_list) {
     std::string src_name = arg.name();
-    std::string dst_name =
-        is_grad ? src_name + kGradVarSuffix : src_name;
+    std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
     (*dst_op->in_out_idxs_)[dst_name] = idx++;
     int src_arg_idx = src_op->in_out_idxs_->at(src_name);
     int src_begin =
@@ -65,10 +64,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
     int src_end = src_format == nullptr ? src_arg_idx + 1
                                         : src_format->at(src_arg_idx + 1);
     for (int i = src_begin; i < src_end; ++i) {
-      std::string s = is_grad ? src_inout[i] + kGradVarSuffix
-                              : arg.ignore_gradient()
-                                    ? kEmptyVarName
-                                    : src_inout[i];
+      std::string s =
+          is_grad ? src_inout[i] + kGradVarSuffix
+                  : (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
       dst_inout.emplace_back(s);
     }
     if (dst_format != nullptr) {
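For readers of the nested conditional reformatted above, an equivalent if/else form is sketched below. The function name and parameters are hypothetical, introduced only to make the fragment self-contained; the behavior mirrors the ternary: gradient arguments get the @GRAD suffix, and non-gradient arguments marked ignore_gradient() map to the empty-variable placeholder.

#include <string>

// Hypothetical helper, equivalent in effect to the ternary above.
std::string PickDstName(bool is_grad, bool ignore_gradient,
                        const std::string& src_name,
                        const std::string& kGradVarSuffix,
                        const std::string& kEmptyVarName) {
  if (is_grad) {
    return src_name + kGradVarSuffix;  // e.g. "X" -> "X@GRAD"
  }
  return ignore_gradient ? kEmptyVarName : src_name;
}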
...
@@ -45,13 +45,12 @@ const std::string kTempVarName = "@TEMP@";
 const std::string kGradVarSuffix = "@GRAD";

 /// Variables with this suffix are supposed to be filled up with zeros.
 const std::string kZeroVarSuffix = "@ZERO";

 inline std::string GradVarName(const std::string& var_name) {
   return var_name + kGradVarSuffix;
 }

 class OperatorBase;
 class InferShapeContext;
 class ExecutionContext;
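The constants and GradVarName above encode the naming convention the backward pass relies on: a forward variable "X" has its gradient addressed as "X@GRAD". A small self-contained usage sketch, assuming only the definitions shown in this hunk:

#include <cassert>
#include <string>

const std::string kGradVarSuffix = "@GRAD";

inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

int main() {
  // Gradient names are formed by appending the suffix to the forward name.
  assert(GradVarName("X") == "X@GRAD");
  return 0;
}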
...
@@ -51,7 +51,7 @@ protected:
   PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
                  "Input(Y@GRAD) should not be null");
   PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
                      ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
                  "the shape of Input(0) and Input(1) should be the same");
   ctx.Output<Tensor>(framework::GradVarName("X"))
       ->Resize(ctx.Input<Tensor>("Y")->dims());
...