未验证 提交 630d14f5 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Add more infer var type (#52818)

* add more infer var type

* fix split error

* fix ut

* fix top_k infer vartype

* fix top_k infer vartype
上级 1ab7e77a
......@@ -683,6 +683,13 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
}
};
class Reshape2InferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
{framework::GradVarName("Out"),
......@@ -735,6 +742,7 @@ REGISTER_OPERATOR(reshape2,
ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::Reshape2InferVarType,
ops::Reshape2CompositeGradOpMaker,
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad,
......
......@@ -237,6 +237,13 @@ class SplitCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
}
};
class SplitInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
} // namespace operators
} // namespace paddle
......@@ -246,5 +253,6 @@ REGISTER_OPERATOR(split,
ops::SplitOp,
ops::SplitOpMaker,
ops::SplitCompositeGradOpMaker,
ops::SplitInferVarType,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
......@@ -146,6 +146,18 @@ class TopkGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class TopkInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
if (ctx->HasInput("K")) {
ctx->SyncTypeAndDataType("K", "Indices");
} else {
ctx->SetOutputDataType("Indices", framework::proto::VarType::INT32);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -153,6 +165,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k,
ops::TopkOp,
ops::TopkOpMaker,
ops::TopkInferVarType,
ops::TopkGradOpMaker<paddle::framework::OpDesc>,
ops::TopkGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -745,6 +745,10 @@ def _lower_composite(
del block.vars[var_name]
block._sync_with_cpp()
for op in block.ops:
if op._has_kernel(op.desc.type()):
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
# composite ops may contain other composite ops, thus, call _lower_composite again.
if change:
_lower_composite(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册