diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index d5e41b7b7e8347effbdde4f7d0ef3083553940ff..13706f8b562a1d68fe0d603f51c2fb47b4e18164 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -59,16 +59,14 @@ std::shared_ptr BackwardRecursive(
   // If all input gradients of forwarding operator do not need to calculate,
   // just return an NOP. Not return null ptr because NOP does not take
   // too much time for calculation, but it is useful for simplifying logic.
-  if (AllInSet(forwardOp.inputs_, kGradVarSuffix,
-               no_grad_names)) {
+  if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) {
     return NOP();
   }
 
   // All output gradients of forwarding operator do not need to calculate.
   // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
-  if (AllInSet(forwardOp.outputs_, kGradVarSuffix,
-               no_grad_names)) {
+  if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
       // Mark all input is not need
       no_grad_names.insert(name + kGradVarSuffix);
@@ -134,8 +132,8 @@ std::shared_ptr BackwardRecursive(
   std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp);
   for (std::string& grad_input : grad_op->inputs_) {
     if (no_grad_names.count(grad_input)) {
-      std::string prefix = grad_input.substr(
-          0, grad_input.size() - kGradVarSuffix.size());
+      std::string prefix =
+          grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
       grad_input = prefix + kZeroVarSuffix;
 
       // If part of input gradient of that operator is not calculated, fill
@@ -168,8 +166,7 @@ std::shared_ptr Backward(
   std::unordered_set no_grad_names;
   no_grad_names.reserve(no_grad_vars.size());
 
-  no_grad_names.insert(kEmptyVarName +
-                       kGradVarSuffix);
+  no_grad_names.insert(kEmptyVarName + kGradVarSuffix);
 
   for (auto& name : no_grad_vars) {
     no_grad_names.insert(name + kGradVarSuffix);
@@ -177,5 +174,6 @@ std::shared_ptr Backward(
   size_t uid = 0;
   return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
+
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 061bf1063ff92eb25772a562efe6255851870a00..6c6e12ca254553a8fc02cadbe3a99989ee848943 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -168,8 +168,7 @@ TEST(Backward, simple_op_grad) {
   ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
   ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
 
-  ASSERT_EQ("X" + f::kGradVarSuffix,
-            gop->Output("X" + f::kGradVarSuffix));
+  ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
 }
 
 TEST(Backward, simple_op_not_need_grad) {
@@ -210,9 +209,9 @@ TEST(Backward, net_fc_backward_normal) {
 }
 
 TEST(Backward, net_fc_backward_not_have_b) {
-  std::shared_ptr fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "w", f::kEmptyVarName},
-      {"mul_result", "add_result", "tmp"}, {});
+  std::shared_ptr fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
+                              {"mul_result", "add_result", "tmp"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
@@ -245,21 +244,18 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   all_output.erase(f::kEmptyVarName);
 
   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
-    ASSERT_NE(all_output.find(out + f::kGradVarSuffix),
-              all_output.end());
+    ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
   }
 
   // Not Generated X
-  ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix),
-            all_output.end());
+  ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());
 
   ASSERT_EQ(2UL, bwd_net->ops_.size());
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(
-      f::kEmptyVarName,
-      first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
+  ASSERT_EQ(f::kEmptyVarName,
+            first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
 }
 
 TEST(Backward, net_shared_weight) {
@@ -316,10 +312,8 @@ TEST(Backward, op_part_of_output_are_not_need) {
   auto &d_many_out = *net->ops_[1];
   ASSERT_EQ("many_output_op_grad", d_many_out.type_);
   ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
-  ASSERT_EQ("Z" + f::kZeroVarSuffix,
-            d_many_out.Input("z" + f::kGradVarSuffix));
-  ASSERT_EQ("Y" + f::kGradVarSuffix,
-            d_many_out.Input("y" + f::kGradVarSuffix));
+  ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" + f::kGradVarSuffix));
+  ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" + f::kGradVarSuffix));
   ASSERT_EQ("X" + f::kGradVarSuffix,
             d_many_out.Output("x" + f::kGradVarSuffix));
 }
@@ -331,10 +325,8 @@ TEST(Backward, op_part_of_input_are_not_need) {
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
-  ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix),
-            f::kEmptyVarName);
-  ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix),
-            "b" + f::kGradVarSuffix);
+  ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
+  ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" + f::kGradVarSuffix);
   ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
             "out" + f::kGradVarSuffix);
   ASSERT_EQ(grad_mul.Input("A"), "a");
@@ -368,23 +360,4 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
-
-  /*
-  EXPECT_EQ(grad_fc.Output("X" + f::kGradVarSuffix),
-            f::kEmptyVarName);
-  EXPECT_EQ(grad_fc.Output("W" + f::kGradVarSuffix),
-            "w3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Output("b" + f::kGradVarSuffix),
-            "b3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Output("mul_result" + f::kGradVarSuffix),
-            "mul_out3" + f::kGradVarSuffix);
-
-  EXPECT_EQ(grad_fc.Input("Out" + f::kGradVarSuffix),
-            "out3" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_fc.Input("X"), "out2");
-  EXPECT_EQ(grad_fc.Input("W"), "w3");
-  EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
-  EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
-  EXPECT_EQ(grad_fc.Input("Out"), "out3");
-  */
 }
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index f34aaa28c52dccaf116b60e0804b802e6e2b2ea4..3aefbb3fffb064c6fa9f41d7296088a24758af80 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -56,8 +56,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
   for (const auto& arg : src_arg_list) {
     std::string src_name = arg.name();
-    std::string dst_name =
-        is_grad ? src_name + kGradVarSuffix : src_name;
+    std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
     (*dst_op->in_out_idxs_)[dst_name] = idx++;
     int src_arg_idx = src_op->in_out_idxs_->at(src_name);
     int src_begin =
         src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
@@ -65,10 +64,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
     int src_end = src_format == nullptr ? src_arg_idx + 1
                                         : src_format->at(src_arg_idx + 1);
     for (int i = src_begin; i < src_end; ++i) {
-      std::string s = is_grad ? src_inout[i] + kGradVarSuffix
-                              : arg.ignore_gradient()
-                                    ? kEmptyVarName
-                                    : src_inout[i];
+      std::string s =
+          is_grad ? src_inout[i] + kGradVarSuffix
+                  : (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
       dst_inout.emplace_back(s);
     }
     if (dst_format != nullptr) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 572c1d2b58f3212c01691760a04fe1f6eea96d69..c4e23c3350fc64fa10b187c39eefa5846ba29238 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -45,13 +45,12 @@ const std::string kTempVarName = "@TEMP@";
 const std::string kGradVarSuffix = "@GRAD";
 
 /// Variables with this suffix are supposed to be filled up with zeros.
-const std::string kZeroVarSuffix = "@ZERO";
+const std::string kZeroVarSuffix = "@ZERO";
 
 inline std::string GradVarName(const std::string& var_name) {
   return var_name + kGradVarSuffix;
 }
 
-
 class OperatorBase;
 class InferShapeContext;
 class ExecutionContext;
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index e0d5e16ca0c10962ab7d6b2b8dd0811e1285cc29..e8bb7032f852fed1e79eba8391aa4ddd50f8602b 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -51,7 +51,7 @@ protected:
     PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
                    "Input(Y@GRAD) should not be null");
     PADDLE_ENFORCE(ctx.Input("Y")->dims() ==
-                   ctx.Input(framework::GradVarName("Y"))->dims(),
+                       ctx.Input(framework::GradVarName("Y"))->dims(),
                    "the shape of Input(0) and Input(1) should be the same");
     ctx.Output(framework::GradVarName("X"))
         ->Resize(ctx.Input("Y")->dims());
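
Illustration (not part of the diff): the gradient-name handling above relies on the GradVarName helper and the kGradVarSuffix constant shown in the operator.h hunk. A minimal, self-contained C++ sketch of how they compose a gradient variable name; the constant and helper are copied from that hunk, and "W" is only an example variable name.

    #include <iostream>
    #include <string>

    // Copied from the operator.h hunk: gradient variables reuse the forward
    // variable's name plus this suffix.
    const std::string kGradVarSuffix = "@GRAD";

    // Same helper as declared in operator.h.
    inline std::string GradVarName(const std::string& var_name) {
      return var_name + kGradVarSuffix;
    }

    int main() {
      // "W" is an arbitrary example name; this prints "W@GRAD".
      std::cout << GradVarName("W") << std::endl;
      return 0;
    }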