From 3484874278a1e1377af37677d29609f95fff2325 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 16 Aug 2017 14:44:51 -0700 Subject: [PATCH] Rename `AsNoGradient` of VariableBuilder to `NotInGradient` --- paddle/framework/backward_test.cc | 6 +++--- paddle/framework/framework.proto | 2 +- paddle/framework/grad_op_builder.cc | 2 +- paddle/framework/grad_op_builder_test.cc | 4 ++-- paddle/framework/operator.h | 7 ++----- paddle/operators/mean_op.cc | 2 +- 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d942604bf0..8780b50773 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").AsNoGradient(); - AddInput("b", "Bias of Add").AsNoGradient(); - AddOutput("Out", "Out of Add").AsNoGradient(); + AddInput("X", "Input X of Add").NotInGradient(); + AddInput("b", "Bias of Add").NotInGradient(); + AddOutput("Out", "Out of Add").NotInGradient(); AddComment("Add Op"); } }; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 7077e8aa2c..ae44a1ffd4 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -60,7 +60,7 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; - optional bool no_gradient = 5 [ default = false ]; + optional bool not_in_gradient = 5 [ default = false ]; } // AttrProto describes the C++ type Attribute. diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index b73dac22d0..0a2a41f6b6 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_arg_list = src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { - if (arg.no_gradient() && !is_grad) continue; + if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; dst_inout[dst_name].reserve(src_inout.at(src_name).size()); diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 0c26293fd2..902c2655e9 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); AddInput("In3_mult", "another multiple input").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").AsNoGradient(); + AddOutput("Out2", "a single output").NotInGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2c8620a7ce..dbe205976c 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -164,11 +164,8 @@ class OpProtoAndCheckerMaker { return *this; } - // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it - // means that input/output is not needed when calculate gradient. It does - // not mean no gradient when backward. It should be changed soon. - VariableBuilder& AsNoGradient() { - var_->set_no_gradient(true); + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); return *this; } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 49d0f43508..d3d0e55a67 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").AsNoGradient(); + AddOutput("Out", "The output of mean op").NotInGradient(); AddComment("Mean Operator"); } }; -- GitLab