未验证 提交 eb612a82 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #6181 from emailweixu/enforce_drop_empty_ig

Enforce drop_empty_grad=false When the input of an op is duplicable.
...@@ -22,6 +22,14 @@ ...@@ -22,6 +22,14 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
(gradient variable, corresponding input variable of fwd_op) will be added to
grad_to_var. If an input variable of fwd_op is contained in no_grad_set, its
gradient varialbe will be ignored or kEmptyVarName depending on the template
argument DropEmptyIG in the derived classes.
*/
class GradOpDescMakerBase { class GradOpDescMakerBase {
public: public:
explicit GradOpDescMakerBase( explicit GradOpDescMakerBase(
...@@ -56,6 +64,16 @@ class GradOpDescMakerBase { ...@@ -56,6 +64,16 @@ class GradOpDescMakerBase {
if (!drop_empty_grad) { if (!drop_empty_grad) {
return ret_val; return ret_val;
} }
PADDLE_ENFORCE_LE(var_names.size(), 1UL,
"BUG from operator developer:"
" for input argument with a list of variables, "
" drop_empty_grad is not allowed because it makes"
" the correspondence bewteen a variable and its gradient"
" ambiguous. Use REGISTER_OP_EX to register the op"
" or call InputGrad(?,false) in GradOpDescMaker."
" Op type %s",
fwd_op_.Type());
std::vector<std::string> dropped_ret_val; std::vector<std::string> dropped_ret_val;
dropped_ret_val.reserve(ret_val.size()); dropped_ret_val.reserve(ret_val.size());
std::copy_if(ret_val.begin(), ret_val.end(), std::copy_if(ret_val.begin(), ret_val.end(),
......
...@@ -127,7 +127,9 @@ class OpDesc { ...@@ -127,7 +127,9 @@ class OpDesc {
} }
proto::OpDesc desc_; proto::OpDesc desc_;
// input arg name => output variable names
VariableNameMap inputs_; VariableNameMap inputs_;
// output arg name => output variable names
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
......
...@@ -126,6 +126,14 @@ class OpKernelRegistrar : public Registrar { ...@@ -126,6 +126,14 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
/*
The variadic arguments should be class types derived from one of the
following classes:
OpProtoAndCheckerMaker
GradOpDescMakerBase
VarTypeInference
InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...) \ #define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \ __reg_op__##op_type, \
...@@ -144,15 +152,24 @@ class OpKernelRegistrar : public Registrar { ...@@ -144,15 +152,24 @@ class OpKernelRegistrar : public Registrar {
} }
/** /**
* Macro to register Operator. * Macro to register Operator. When the input is duplicable, you should
* use REGISTER_OP_EX with deop_empty_grad=false instead.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \ grad_op_class) \
REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, true)
// When an argument is duplicable, we need to use this version.
// Perhaps we can omit DropEmptyIG template parameter and
// only have one version of REGISTER_OP.
#define REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, drop_empty_grad) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \ REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \ class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker<true> { \ : public ::paddle::framework::DefaultGradOpDescMaker<drop_empty_grad> { \
using ::paddle::framework::DefaultGradOpDescMaker< \ using ::paddle::framework::DefaultGradOpDescMaker< \
true>::DefaultGradOpDescMaker; \ drop_empty_grad>::DefaultGradOpDescMaker; \
\ \
protected: \ protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \ virtual std::string GradOpType() const { return #grad_op_type; } \
......
...@@ -98,8 +98,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel { ...@@ -98,8 +98,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad) ops::ConcatOpGrad, false)
REGISTER_OP_CPU_KERNEL(concat, REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>) ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad, REGISTER_OP_CPU_KERNEL(concat_grad,
......
...@@ -178,8 +178,9 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { ...@@ -178,8 +178,9 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetInput("Out", Output("Out")); grad_op->SetInput("Out", Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetInput("Scope", Output("Scope")); grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params")); grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]); grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
return std::unique_ptr<framework::OpDesc>(grad_op); return std::unique_ptr<framework::OpDesc>(grad_op);
} }
......
...@@ -570,7 +570,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -570,7 +570,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &input_param : this->InputNames()) { for (auto &input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param)); grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param), grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param)); this->InputGrad(input_param, false));
} }
for (auto &output_param : this->OutputNames()) { for (auto &output_param : this->OutputNames()) {
......
...@@ -124,8 +124,9 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel { ...@@ -124,8 +124,9 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker, REGISTER_OP_EX(sequence_concat, ops::SequenceConcatOp,
sequence_concat_grad, ops::SequenceConcatGradOp); ops::SequenceConcatOpMaker, sequence_concat_grad,
ops::SequenceConcatGradOp, false);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_concat, sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -170,7 +170,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase { ...@@ -170,7 +170,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
using framework::GradOpDescMakerBase::GradOpDescMakerBase; using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override { std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
auto x_grads = InputGrad("X"); auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops; std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size()); grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out"); auto og = OutputGrad("Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册