diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 25960383904b6a8199b1603eaabb7b48706a9c9f..141499d005deb906cd0338273f7af3286eb91595 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -181,6 +181,7 @@ class GradOpDescMakerBase { } std::string ForwardOpType() const { return this->fwd_op_.Type(); } + const BlockDesc* GetForwardOpBlock() const { return fwd_op_.Block(); } protected: bool HasInput(const std::string& name) const { diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index bdc07efbc0f8c45532ea67dfaaa3f812679018bf..3a4ee516d08f99e8d87656b70a72290fd3007b9a 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -312,6 +312,37 @@ class ConditionalBlockGradOp : public ConditionalOp { } }; +template +struct FilterNoGradInput {}; + +template <> +struct FilterNoGradInput { + static void filter(const framework::BlockDesc *desc, + std::vector *vec) { + auto f = [desc](const std::string &name) -> std::string { + if (name == framework::kEmptyVarName) { + // don't drop empty var name, you can use Input(name, true) to drop it. + return framework::kEmptyVarName; + } + auto var_desc = + desc->FindVarRecursive(framework::GradOriginalVarName(name)); + std::set not_support_backward_dtype = { + framework::proto::VarType::BOOL, + framework::proto::VarType::INT8, + framework::proto::VarType::UINT8, + framework::proto::VarType::INT16, + framework::proto::VarType::INT32, + framework::proto::VarType::INT64, + }; + if (!var_desc || + not_support_backward_dtype.count(var_desc->GetDataType())) + return framework::kEmptyVarName; + return name; + }; + std::transform(vec->begin(), vec->end(), vec->begin(), f); + } +}; + class ConditionalBlockGradInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { @@ -369,8 +400,11 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker { this->OutputGrad(ConditionalOp::kOutputs)); grad_op->SetInput(ConditionalOp::kScope, this->Output(ConditionalOp::kScope)); + + auto fwd_inputs = this->InputGrad(ConditionalOp::kInputs, false); + FilterNoGradInput::filter(this->GetForwardOpBlock(), &fwd_inputs); grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs), - this->InputGrad(ConditionalOp::kInputs, false)); + fwd_inputs); grad_op->SetBlockAttr("sub_block", this->grad_block_[0]); grad_op->SetAttr("is_scalar_condition", this->GetAttr("is_scalar_condition"));