未验证 提交 36046a89 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] Filter int64/int32/int16/bool in conditional op (#45759)

* filter int32/int16/int64/bool inputs in cond_op so the backward pass skips them

* fix bug: except for block 0, backward vars and forward vars live in different blocks

* fix code by review
上级 bb725e3a
...@@ -181,6 +181,7 @@ class GradOpDescMakerBase { ...@@ -181,6 +181,7 @@ class GradOpDescMakerBase {
} }
std::string ForwardOpType() const { return this->fwd_op_.Type(); } std::string ForwardOpType() const { return this->fwd_op_.Type(); }
const BlockDesc* GetForwardOpBlock() const { return fwd_op_.Block(); }
protected: protected:
bool HasInput(const std::string& name) const { bool HasInput(const std::string& name) const {
......
...@@ -312,6 +312,37 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -312,6 +312,37 @@ class ConditionalBlockGradOp : public ConditionalOp {
} }
}; };
// Primary template: a no-op placeholder. Only the static-graph OpDesc path
// (specialized below) can inspect var dtypes, so other instantiations filter
// nothing.
template <class T>
struct FilterNoGradInput {};

// Specialization for framework::OpDesc: rewrite the forward-input name list
// in place, replacing names whose variable has a non-differentiable dtype
// (bool / integral types) — or cannot be found at all — with kEmptyVarName,
// so no gradient slot is emitted for them.
template <>
struct FilterNoGradInput<framework::OpDesc> {
  // desc: block of the forward op; vars may live in an enclosing block, so
  //       lookups must be recursive.
  // vec:  in/out list of forward input names to be filtered.
  static void filter(const framework::BlockDesc *desc,
                     std::vector<std::string> *vec) {
    // Dtypes that do not support backward. Built once (function-local
    // static) instead of once per input name inside the lambda.
    static const std::set<framework::proto::VarType::Type>
        not_support_backward_dtype = {
            framework::proto::VarType::BOOL,
            framework::proto::VarType::INT8,
            framework::proto::VarType::UINT8,
            framework::proto::VarType::INT16,
            framework::proto::VarType::INT32,
            framework::proto::VarType::INT64,
        };
    auto f = [desc](const std::string &name) -> std::string {
      if (name == framework::kEmptyVarName) {
        // don't drop empty var name, you can use Input(name, true) to drop it.
        return framework::kEmptyVarName;
      }
      // Strip a possible @GRAD suffix and look the forward var up recursively,
      // since (except in block 0) forward and backward vars live in
      // different blocks.
      auto var_desc =
          desc->FindVarRecursive(framework::GradOriginalVarName(name));
      // Missing var or unsupported dtype => no gradient for this input.
      if (!var_desc ||
          not_support_backward_dtype.count(var_desc->GetDataType()))
        return framework::kEmptyVarName;
      return name;
    };
    std::transform(vec->begin(), vec->end(), vec->begin(), f);
  }
};
class ConditionalBlockGradInferShape : public framework::InferShapeBase { class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
...@@ -369,8 +400,11 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -369,8 +400,11 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad(ConditionalOp::kOutputs)); this->OutputGrad(ConditionalOp::kOutputs));
grad_op->SetInput(ConditionalOp::kScope, grad_op->SetInput(ConditionalOp::kScope,
this->Output(ConditionalOp::kScope)); this->Output(ConditionalOp::kScope));
auto fwd_inputs = this->InputGrad(ConditionalOp::kInputs, false);
FilterNoGradInput<T>::filter(this->GetForwardOpBlock(), &fwd_inputs);
grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs), grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
this->InputGrad(ConditionalOp::kInputs, false)); fwd_inputs);
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]); grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", grad_op->SetAttr("is_scalar_condition",
this->GetAttr("is_scalar_condition")); this->GetAttr("is_scalar_condition"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册