未验证 提交 4be5448b 编写于 作者: F furnace 提交者: GitHub

[Phi] move infershape for mv (#39954)

* [Phi] move infershape for mv

* [Phi] delete extra codes for mv
上级 94f03dc2
......@@ -16,8 +16,11 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -42,33 +45,6 @@ class MVOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "mv");
OP_INOUT_CHECK(context->HasInput("Vec"), "Input", "Vec", "mv");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv");
auto dim_x = context->GetInputDim("X");
auto dim_vec = context->GetInputDim("Vec");
PADDLE_ENFORCE_EQ(
dim_x.size(), 2,
platform::errors::InvalidArgument(
"The rank of input X should be 2, but is %d", dim_x.size()));
PADDLE_ENFORCE_EQ(
dim_vec.size(), 1,
platform::errors::InvalidArgument(
"The rank of input Vec should be 1, but is %d", dim_vec.size()));
PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0],
platform::errors::InvalidArgument(
"X's second dimension is expected to be equal to "
"Vec's first dimension"
"but recieved X'shape = [%s], Vec's shape = [%s]",
dim_x, dim_vec));
framework::DDim dim_out = phi::make_ddim({dim_x[0]});
context->SetOutputDim("Out", dim_out);
context->ShareLoD("X", /*->*/ "Out");
}
};
template <typename T>
......@@ -118,7 +94,11 @@ class MVOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DELCARE_INFER_SHAPE_FUNCTOR(mv, MvInferShapeFunctor,
PT_INFER_META(phi::MvInferMeta));
REGISTER_OPERATOR(mv, ops::MVOp, ops::MVOpMaker,
ops::MVOpGradMaker<paddle::framework::OpDesc>,
ops::MVOpGradMaker<paddle::imperative::OpBase>);
ops::MVOpGradMaker<paddle::imperative::OpBase>,
MvInferShapeFunctor);
REGISTER_OPERATOR(mv_grad, ops::MVOpGrad);
......@@ -443,4 +443,34 @@ void GatherTreeMeta(const MetaTensor& ids,
out->set_dims(ids_dims);
}
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
auto dim_x = x.dims();
auto dim_vec = vec.dims();
PADDLE_ENFORCE_EQ(
dim_x.size(),
2,
phi::errors::InvalidArgument("The rank of input X should be 2, but is %d",
dim_x.size()));
PADDLE_ENFORCE_EQ(
dim_vec.size(),
1,
phi::errors::InvalidArgument(
"The rank of input Vec should be 1, but is %d", dim_vec.size()));
PADDLE_ENFORCE_EQ(dim_x[1],
dim_vec[0],
phi::errors::InvalidArgument(
"X's second dimension is expected to be equal to "
"Vec's first dimension"
"but recieved X'shape = [%s], Vec's shape = [%s]",
dim_x,
dim_vec));
auto dim_out = phi::make_ddim({dim_x[0]});
out->set_dims(dim_out);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi
......@@ -85,4 +85,7 @@ void GatherNdInferMeta(const MetaTensor& x,
void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out);
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
} // namespace phi
......@@ -16,10 +16,6 @@
namespace phi {
KernelSignature MvOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("mv", {"X", "Vec"}, {}, {"Out"});
}
KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("mv_grad",
{"X", "Vec", GradVarName("Out")},
......@@ -29,5 +25,4 @@ KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(mv, phi::MvOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mv_grad, phi::MvGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册