提交 597ac215 编写于 作者: D dongzhihong

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mul_op

...@@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { ...@@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").AsNoGradient(); AddInput("X", "Input X of Add").NotInGradient();
AddInput("b", "Bias of Add").AsNoGradient(); AddInput("b", "Bias of Add").NotInGradient();
AddOutput("Out", "Out of Add").AsNoGradient(); AddOutput("Out", "Out of Add").NotInGradient();
AddComment("Add Op"); AddComment("Add Op");
} }
}; };
......
...@@ -60,7 +60,7 @@ message OpProto { ...@@ -60,7 +60,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool no_gradient = 5 [ default = false ]; optional bool not_in_gradient = 5 [ default = false ];
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, ...@@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue; if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
std::string dst_name = is_grad ? GradVarName(src_name) : src_name; std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_inout[dst_name].reserve(src_inout.at(src_name).size()); dst_inout[dst_name].reserve(src_inout.at(src_name).size());
......
...@@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { ...@@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable(); AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient(); AddOutput("Out2", "a single output").NotInGradient();
AddComment("op with inputs and outputs ignored in gradient calculating"); AddComment("op with inputs and outputs ignored in gradient calculating");
} }
}; };
......
...@@ -184,11 +184,8 @@ class OpProtoAndCheckerMaker { ...@@ -184,11 +184,8 @@ class OpProtoAndCheckerMaker {
return *this; return *this;
} }
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it VariableBuilder& NotInGradient() {
// means that input/output is not needed when calculate gradient. It does var_->set_not_in_gradient(true);
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this; return *this;
} }
}; };
......
...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").AsNoGradient(); AddOutput("Out", "The output of mean op").NotInGradient();
AddComment("Mean Operator"); AddComment("Mean Operator");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册