提交 597ac215 编写于 作者: D dongzhihong

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mul_op

......@@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").AsNoGradient();
AddInput("b", "Bias of Add").AsNoGradient();
AddOutput("Out", "Out of Add").AsNoGradient();
AddInput("X", "Input X of Add").NotInGradient();
AddInput("b", "Bias of Add").NotInGradient();
AddOutput("Out", "Out of Add").NotInGradient();
AddComment("Add Op");
}
};
......
......@@ -60,7 +60,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool no_gradient = 5 [ default = false ];
optional bool not_in_gradient = 5 [ default = false ];
}
// AttrProto describes the C++ type Attribute.
......
......@@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_arg_list =
src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue;
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_inout[dst_name].reserve(src_inout.at(src_name).size());
......
......@@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient();
AddOutput("Out2", "a single output").NotInGradient();
AddComment("op with inputs and outputs ignored in gradient calculating");
}
};
......
......@@ -184,11 +184,8 @@ class OpProtoAndCheckerMaker {
return *this;
}
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
......
......@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").AsNoGradient();
AddOutput("Out", "The output of mean op").NotInGradient();
AddComment("Mean Operator");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册