未验证 提交 eb612a82 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #6181 from emailweixu/enforce_drop_empty_ig

Enforce drop_empty_grad=false When the input of an op is duplicable.
......@@ -22,6 +22,14 @@
namespace paddle {
namespace framework {
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op) will be added to
grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its
gradient varialbe will be ignored or kEmptyVarName depending on the template
argument DropEmptyIG in the derived classes.
*/
class GradOpDescMakerBase {
public:
explicit GradOpDescMakerBase(
......@@ -56,6 +64,16 @@ class GradOpDescMakerBase {
if (!drop_empty_grad) {
return ret_val;
}
PADDLE_ENFORCE_LE(var_names.size(), 1UL,
"BUG from operator developer:"
" for input argument with a list of variables, "
" drop_empty_grad is not allowed because it makes"
" the correspondence bewteen a variable and its gradient"
" ambiguous. Use REGISTER_OP_EX to register the op"
" or call InputGrad(?,false) in GradOpDescMaker."
" Op type %s",
fwd_op_.Type());
std::vector<std::string> dropped_ret_val;
dropped_ret_val.reserve(ret_val.size());
std::copy_if(ret_val.begin(), ret_val.end(),
......
......@@ -127,7 +127,9 @@ class OpDesc {
}
proto::OpDesc desc_;
// input arg name => output variable names
VariableNameMap inputs_;
// output arg name => output variable names
VariableNameMap outputs_;
AttributeMap attrs_;
......
......@@ -126,6 +126,14 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
/*
The variadic arguments should be class types derived from one of the
following classes:
OpProtoAndCheckerMaker
GradOpDescMakerBase
VarTypeInference
InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
......@@ -144,15 +152,24 @@ class OpKernelRegistrar : public Registrar {
}
/**
* Macro to register Operator.
* Macro to register Operator. When the input is duplicable, you should
* use REGISTER_OP_EX with deop_empty_grad=false instead.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, true)
// When an argument is duplicable, we need to use this version.
// Perhaps we can omit DropEmptyIG template parameter and
// only have one version of REGISTER_OP.
#define REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, drop_empty_grad) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker<true> { \
: public ::paddle::framework::DefaultGradOpDescMaker<drop_empty_grad> { \
using ::paddle::framework::DefaultGradOpDescMaker< \
true>::DefaultGradOpDescMaker; \
drop_empty_grad>::DefaultGradOpDescMaker; \
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
......
......@@ -98,8 +98,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad)
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad, false)
REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad,
......
......@@ -178,8 +178,9 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetInput("Out", Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
return std::unique_ptr<framework::OpDesc>(grad_op);
}
......
......@@ -570,7 +570,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param));
this->InputGrad(input_param, false));
}
for (auto &output_param : this->OutputNames()) {
......
......@@ -124,8 +124,9 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker,
sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_EX(sequence_concat, ops::SequenceConcatOp,
ops::SequenceConcatOpMaker, sequence_concat_grad,
ops::SequenceConcatGradOp, false);
REGISTER_OP_CPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -170,7 +170,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
auto x_grads = InputGrad("X");
auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册