未验证 提交 b94cf842 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move mean infershape into phi (#40922)

* move mean infershape into phi

* try to run ci

* share layout for mkldnn

* revert grad infershape

* revert grad infershape
上级 afe2fdd1
...@@ -16,7 +16,10 @@ limitations under the License. */ ...@@ -16,7 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,12 +27,6 @@ namespace operators { ...@@ -24,12 +27,6 @@ namespace operators {
class MeanOp : public framework::OperatorWithKernel { class MeanOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mean");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mean");
ctx->SetOutputDim("Out", {1});
}
}; };
class MeanOpMaker : public framework::OpProtoAndCheckerMaker { class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -90,8 +87,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X"); ...@@ -90,8 +87,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(mean, MeanInferShapeFunctor,
PD_INFER_META(phi::MeanAllInferMeta));
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType, REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker<paddle::framework::OpDesc>, ops::MeanGradMaker<paddle::framework::OpDesc>,
ops::MeanGradMaker<paddle::imperative::OpBase>); ops::MeanGradMaker<paddle::imperative::OpBase>,
MeanInferShapeFunctor);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp, REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInferer); ops::MeanGradNoNeedBufferVarsInferer);
...@@ -823,6 +823,12 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, ...@@ -823,6 +823,12 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
mask->set_dtype(paddle::experimental::CppTypeToDataType<int>::Type()); mask->set_dtype(paddle::experimental::CppTypeToDataType<int>::Type());
} }
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
void ModeInferMeta(const MetaTensor& x, void ModeInferMeta(const MetaTensor& x,
int axis, int axis,
bool keepdim, bool keepdim,
......
...@@ -144,6 +144,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, ...@@ -144,6 +144,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
MetaTensor* mask, MetaTensor* mask,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out);
void ModeInferMeta(const MetaTensor& x, void ModeInferMeta(const MetaTensor& x,
int axis, int axis,
bool keepdim, bool keepdim,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册