diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d942604bf05998ab9e1ee147b28586e7e4e9777d..8780b5077386a0c25ddbe6f00d73f9ab8148165e 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").AsNoGradient(); - AddInput("b", "Bias of Add").AsNoGradient(); - AddOutput("Out", "Out of Add").AsNoGradient(); + AddInput("X", "Input X of Add").NotInGradient(); + AddInput("b", "Bias of Add").NotInGradient(); + AddOutput("Out", "Out of Add").NotInGradient(); AddComment("Add Op"); } }; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 7077e8aa2c77c24efdbb34ed3a13821fe7678455..ae44a1ffd45dacdc44a72edc630e771e7a2f2990 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -60,7 +60,7 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; - optional bool no_gradient = 5 [ default = false ]; + optional bool not_in_gradient = 5 [ default = false ]; } // AttrProto describes the C++ type Attribute. diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index b73dac22d029876de9a012de533647db3dd74cbb..0a2a41f6b62658ac8633a6e384d099f8d6641f33 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_arg_list = src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { - if (arg.no_gradient() && !is_grad) continue; + if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; dst_inout[dst_name].reserve(src_inout.at(src_name).size()); diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 0c26293fd29d24a7a40c47bdf055d2758846709b..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); AddInput("In3_mult", "another multiple input").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").AsNoGradient(); + AddOutput("Out2", "a single output").NotInGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2c8620a7ce007ede4e2bef089e2fc8902bf0c9f4..dbe205976cf3b0efdce02fdaf53f5abc4ca9b6db 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -164,11 +164,8 @@ class OpProtoAndCheckerMaker { return *this; } - // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it - // means that input/output is not needed when calculate gradient. It does - // not mean no gradient when backward. It should be changed soon. - VariableBuilder& AsNoGradient() { - var_->set_no_gradient(true); + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); return *this; } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 49d0f43508b1ee3df0c6b5987942970e1649e310..d3d0e55a674587fb04f43f24d0790de4358f035a 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").AsNoGradient(); + AddOutput("Out", "The output of mean op").NotInGradient(); AddComment("Mean Operator"); } };