From 36046a89604ce085cfa717b740b95a5f816da2f4 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 8 Sep 2022 12:53:29 +0800 Subject: [PATCH] [Dy2Static] Filter int64/int32/int16/bool in conditional op (#45759) * stop pass filter int32/int16/int64/bool inputs in cond_op * fix bugs: except block 0, the backward vars and forward vars exist in different blocks. * fix code by review --- paddle/fluid/framework/grad_op_desc_maker.h | 1 + .../controlflow/conditional_block_op.cc | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 2596038390..141499d005 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -181,6 +181,7 @@ class GradOpDescMakerBase { } std::string ForwardOpType() const { return this->fwd_op_.Type(); } + const BlockDesc* GetForwardOpBlock() const { return fwd_op_.Block(); } protected: bool HasInput(const std::string& name) const { diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index bdc07efbc0..3a4ee516d0 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -312,6 +312,37 @@ class ConditionalBlockGradOp : public ConditionalOp { } }; +template +struct FilterNoGradInput {}; + +template <> +struct FilterNoGradInput { + static void filter(const framework::BlockDesc *desc, + std::vector *vec) { + auto f = [desc](const std::string &name) -> std::string { + if (name == framework::kEmptyVarName) { + // don't drop empty var name, you can use Input(name, true) to drop it. + return framework::kEmptyVarName; + } + auto var_desc = + desc->FindVarRecursive(framework::GradOriginalVarName(name)); + std::set not_support_backward_dtype = { + framework::proto::VarType::BOOL, + framework::proto::VarType::INT8, + framework::proto::VarType::UINT8, + framework::proto::VarType::INT16, + framework::proto::VarType::INT32, + framework::proto::VarType::INT64, + }; + if (!var_desc || + not_support_backward_dtype.count(var_desc->GetDataType())) + return framework::kEmptyVarName; + return name; + }; + std::transform(vec->begin(), vec->end(), vec->begin(), f); + } +}; + class ConditionalBlockGradInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { @@ -369,8 +400,11 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker { this->OutputGrad(ConditionalOp::kOutputs)); grad_op->SetInput(ConditionalOp::kScope, this->Output(ConditionalOp::kScope)); + + auto fwd_inputs = this->InputGrad(ConditionalOp::kInputs, false); + FilterNoGradInput::filter(this->GetForwardOpBlock(), &fwd_inputs); grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs), - this->InputGrad(ConditionalOp::kInputs, false)); + fwd_inputs); grad_op->SetBlockAttr("sub_block", this->grad_block_[0]); grad_op->SetAttr("is_scalar_condition", this->GetAttr("is_scalar_condition")); -- GitLab