Unverified commit b94cf842 authored by Chen Weihang, committed by GitHub

[Phi] Move mean infershape into phi (#40922)

* move mean infershape into phi

* try to run ci

* share layout for mkldnn

* revert grad infershape

* revert grad infershape
Parent afe2fdd1
paddle/fluid/operators/mean_op.cc
@@ -16,7 +16,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -24,12 +27,6 @@ namespace operators {
class MeanOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mean");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mean");
ctx->SetOutputDim("Out", {1});
}
};
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -90,8 +87,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(mean, MeanInferShapeFunctor,
PD_INFER_META(phi::MeanAllInferMeta));
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker<paddle::framework::OpDesc>,
ops::MeanGradMaker<paddle::imperative::OpBase>);
ops::MeanGradMaker<paddle::imperative::OpBase>,
MeanInferShapeFunctor);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInferer);
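
For reference, the migration pattern applied in the hunks above can be sketched in isolation. The snippet below uses a hypothetical operator name "foo" and assumes a matching phi::FooInferMeta already exists; it only restates the macros visible in this diff and is not additional code from the commit.

// Hypothetical example of the same registration pattern for an op "foo":
// the fluid operator drops its InferShape override, and a functor generated
// from the phi InferMeta function is passed to REGISTER_OPERATOR instead.
DECLARE_INFER_SHAPE_FUNCTOR(foo, FooInferShapeFunctor,
                            PD_INFER_META(phi::FooInferMeta));
REGISTER_OPERATOR(foo, ops::FooOp, ops::FooOpMaker,
                  FooInferShapeFunctor);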
paddle/phi/infermeta/unary.cc
@@ -823,6 +823,12 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
mask->set_dtype(paddle::experimental::CppTypeToDataType<int>::Type());
}
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
void ModeInferMeta(const MetaTensor& x,
int axis,
bool keepdim,
......
paddle/phi/infermeta/unary.h
@@ -144,6 +144,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
MetaTensor* mask,
MetaConfig config = MetaConfig());
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out);
void ModeInferMeta(const MetaTensor& x,
int axis,
bool keepdim,
......
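
To make the InferMeta contract introduced above concrete, here is a minimal standalone sketch. The MetaTensor, DataType, and DataLayout types below are simplified stand-ins invented for illustration, not Paddle's real classes; the function body mirrors what phi::MeanAllInferMeta in this diff does: the mean over all elements always yields a 1-element output that inherits the input's dtype and layout (the layout propagation is presumably what the "share layout for mkldnn" note in the commit message refers to).

// Toy sketch of the InferMeta idea: fill output metadata (dims/dtype/layout)
// from the input without running any kernel. Types here are illustrative only.
#include <cstdint>
#include <iostream>
#include <vector>

enum class DataType { FLOAT32, FLOAT64 };
enum class DataLayout { NCHW, ONEDNN };

// Hypothetical stand-in for phi::MetaTensor: carries only metadata.
struct MetaTensor {
  std::vector<int64_t> dims;
  DataType dtype = DataType::FLOAT32;
  DataLayout layout = DataLayout::NCHW;
};

// Same contract as phi::MeanAllInferMeta: output is a 1-element tensor
// that inherits the input's dtype and layout, regardless of input shape.
void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->dims = {1};
  out->dtype = x.dtype;
  out->layout = x.layout;
}

int main() {
  MetaTensor x{{4, 8, 16}, DataType::FLOAT64, DataLayout::ONEDNN};
  MetaTensor out;
  MeanAllInferMeta(x, &out);
  std::cout << "out rank: " << out.dims.size()
            << ", out dims[0]: " << out.dims[0]
            << ", dtype preserved: " << (out.dtype == x.dtype ? "yes" : "no")
            << '\n';
  return 0;
}

Compiling and running this prints a rank-1, single-element output whose dtype matches the input, independent of the input's shape.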