Unverified commit 59e5c49f, authored by Chen Weihang, committed by GitHub

move gather infershape (#40594)

Parent ec6b8fbd
@@ -15,9 +15,14 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -26,58 +31,6 @@ class GatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of GatherOp should not be null."));
auto index_dims = ctx->GetInputDim("Index");
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(), 1,
platform::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
auto axis = ctx->Attrs().Get<int>("axis");
auto input_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Axis") || axis == 0) {
// if HasInput("Axis"), we can not obtain correct shape of output
int batch_size = index_dims[0];
framework::DDim output_dims(input_dim);
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -100,11 +53,6 @@ class GatherGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -193,11 +141,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gather, GatherInferShapeFunctor,
PD_INFER_META(phi::GatherInferMeta));
REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>);
ops::GatherGradOpMaker<paddle::imperative::OpBase>,
GatherInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad, GatherGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInferer);
ops::GatherGradNoNeedBufferVarInferer,
GatherGradInferShapeFunctor);
REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
......
@@ -431,6 +431,55 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
MetaTensor* out) {
auto index_dims = index.dims();
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
phi::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis.FromTensor() || axis_v == 0) {
// if axis.FromTensor(), we can not obtain correct shape of output
int batch_size = index_dims[0];
phi::DDim output_dims(input_dim);
output_dims[0] = batch_size;
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis_v; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis_v + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
out->set_dims(output_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
}
void GatherNdInferMeta(const MetaTensor& x,
const MetaTensor& index,
MetaTensor* out) {
......
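For reference, a minimal standalone sketch (a hypothetical GatherOutShape helper, not part of Paddle) of the output-shape rule that the GatherInferMeta implementation above encodes: when axis == 0 (or the axis is supplied as a tensor) only the leading dimension of X is replaced by the index length; otherwise dimension `axis` of X is replaced by the index length.

#include <cstdio>
#include <vector>

// Hypothetical helper (not part of Paddle): mirrors the output-shape rule
// encoded by GatherInferMeta above.
std::vector<int> GatherOutShape(const std::vector<int>& x_dims,
                                int index_size, int axis) {
  if (axis == 0) {
    // axis == 0 branch: only the leading (batch) dimension is replaced.
    std::vector<int> out = x_dims;
    out[0] = index_size;
    return out;
  }
  std::vector<int> out;
  for (int i = 0; i < axis; ++i) out.push_back(x_dims[i]);   // dims before axis
  out.push_back(index_size);                                 // gathered dim
  for (int i = axis + 1; i < static_cast<int>(x_dims.size()); ++i)
    out.push_back(x_dims[i]);                                // dims after axis
  return out;
}

int main() {
  // x: [4, 5, 6], index: [3], axis = 1  ->  out: [4, 3, 6]
  for (int d : GatherOutShape({4, 5, 6}, /*index_size=*/3, /*axis=*/1))
    std::printf("%d ", d);
  std::printf("\n");
  return 0;
}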
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/meta_tensor.h"
namespace phi {
@@ -81,6 +82,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta,
int axis,
MetaTensor* out);
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
MetaTensor* out);
void GatherNdInferMeta(const MetaTensor& x,
const MetaTensor& index,
MetaTensor* out);
......