未验证 提交 630d14f5 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Add more infer var type (#52818)

* add more infer var type

* fix split error

* fix ut

* fix top_k infer vartype

* fix top_k infer vartype
上级 1ab7e77a
...@@ -683,6 +683,13 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { ...@@ -683,6 +683,13 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
} }
}; };
class Reshape2InferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer, DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
...@@ -735,6 +742,7 @@ REGISTER_OPERATOR(reshape2, ...@@ -735,6 +742,7 @@ REGISTER_OPERATOR(reshape2,
ops::Reshape2OpMaker, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>, ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>, ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::Reshape2InferVarType,
ops::Reshape2CompositeGradOpMaker, ops::Reshape2CompositeGradOpMaker,
ops::ReshapeOpInplaceInferer); ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, REGISTER_OPERATOR(reshape2_grad,
......
...@@ -237,6 +237,13 @@ class SplitCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { ...@@ -237,6 +237,13 @@ class SplitCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} }
}; };
class SplitInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -246,5 +253,6 @@ REGISTER_OPERATOR(split, ...@@ -246,5 +253,6 @@ REGISTER_OPERATOR(split,
ops::SplitOp, ops::SplitOp,
ops::SplitOpMaker, ops::SplitOpMaker,
ops::SplitCompositeGradOpMaker, ops::SplitCompositeGradOpMaker,
ops::SplitInferVarType,
ops::SplitGradMaker<paddle::framework::OpDesc>, ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>); ops::SplitGradMaker<paddle::imperative::OpBase>);
...@@ -146,6 +146,18 @@ class TopkGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -146,6 +146,18 @@ class TopkGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
class TopkInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
if (ctx->HasInput("K")) {
ctx->SyncTypeAndDataType("K", "Indices");
} else {
ctx->SetOutputDataType("Indices", framework::proto::VarType::INT32);
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -153,6 +165,7 @@ namespace ops = paddle::operators; ...@@ -153,6 +165,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k, REGISTER_OPERATOR(top_k,
ops::TopkOp, ops::TopkOp,
ops::TopkOpMaker, ops::TopkOpMaker,
ops::TopkInferVarType,
ops::TopkGradOpMaker<paddle::framework::OpDesc>, ops::TopkGradOpMaker<paddle::framework::OpDesc>,
ops::TopkGradOpMaker<paddle::imperative::OpBase>); ops::TopkGradOpMaker<paddle::imperative::OpBase>);
......
...@@ -745,6 +745,10 @@ def _lower_composite( ...@@ -745,6 +745,10 @@ def _lower_composite(
del block.vars[var_name] del block.vars[var_name]
block._sync_with_cpp() block._sync_with_cpp()
for op in block.ops:
if op._has_kernel(op.desc.type()):
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
# composite ops may contain other composite ops, thus, call _lower_composite again. # composite ops may contain other composite ops, thus, call _lower_composite again.
if change: if change:
_lower_composite( _lower_composite(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册