Unverified commit 609077e9, authored by Chen Weihang, committed by GitHub

move mul op infershape (#40917)

Parent: 4ab8255a
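Summary of the change, as reflected in the diff below: the hand-written InferShape implementations of MulOp and MulGradOp are deleted, and shape-inference functors generated by DECLARE_INFER_SHAPE_FUNCTOR are registered in their place, dispatching to the phi InferMeta functions phi::MatmulWithFlattenInferMeta (forward) and phi::GeneralBinaryGradInferMeta (backward). The forward shape logic moves essentially verbatim into phi/infermeta/binary.cc, with the declaration added to phi/infermeta/binary.h; besides setting the output dims and sharing LoD, the phi version also sets the output dtype from X.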
@@ -21,6 +21,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -34,72 +38,6 @@ class MulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Mul");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Mul");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
-    int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
-
-    VLOG(3) << "mul operator x.shape=" << x_dims << " y.shape=" << y_dims
-            << " x_num_col_dims=" << x_num_col_dims
-            << " y_num_col_dims=" << y_num_col_dims;
-
-    PADDLE_ENFORCE_NE(phi::product(y_dims), 0,
-                      platform::errors::PreconditionNotMet(
-                          "The Input variable Y(%s) has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function.",
-                          ctx->Inputs("Y").front()));
-    PADDLE_ENFORCE_GT(
-        x_dims.size(), x_num_col_dims,
-        platform::errors::InvalidArgument(
-            "The input tensor X's dimensions of MulOp "
-            "should be larger than x_num_col_dims. But received X's "
-            "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
-            x_dims.size(), x_dims, x_num_col_dims));
-    PADDLE_ENFORCE_GT(
-        y_dims.size(), y_num_col_dims,
-        platform::errors::InvalidArgument(
-            "The input tensor Y's dimensions of MulOp "
-            "should be larger than y_num_col_dims. But received Y's "
-            "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
-            y_dims.size(), y_dims, y_num_col_dims));
-
-    auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims);
-    auto y_mat_dims = phi::flatten_to_2d(y_dims, y_num_col_dims);
-
-    PADDLE_ENFORCE_EQ(
-        x_mat_dims[1], y_mat_dims[0],
-        platform::errors::InvalidArgument(
-            "After flatten the input tensor X and Y to 2-D dimensions matrix "
-            "X1 and Y1, the matrix X1's width must be equal with matrix Y1's "
-            "height. But received X's shape = [%s], X1's shape = [%s], X1's "
-            "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
-            "%s.",
-            x_dims, x_mat_dims, x_mat_dims[1], y_dims, y_mat_dims,
-            y_mat_dims[0]));
-    std::vector<int64_t> output_dims;
-    output_dims.reserve(
-        static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
-
-    for (int i = 0; i < x_num_col_dims; ++i) {
-      output_dims.push_back(x_dims[i]);
-    }
-
-    for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
-      output_dims.push_back(y_dims[i]);
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     framework::LibraryType library = framework::LibraryType::kPlain;
@@ -225,25 +163,6 @@ class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mul");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "mul");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
-                   "Out@GRAD", "mul");
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
     framework::LibraryType library = framework::LibraryType::kPlain;
@@ -348,12 +267,18 @@ class MulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(mul, MulInferShapeFunctor,
+                            PD_INFER_META(phi::MatmulWithFlattenInferMeta));
 REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
                   ops::MulOpGradMaker<paddle::framework::OpDesc>,
-                  ops::MulOpGradMaker<paddle::imperative::OpBase>);
+                  ops::MulOpGradMaker<paddle::imperative::OpBase>,
+                  MulInferShapeFunctor);
 
+DECLARE_INFER_SHAPE_FUNCTOR(mul_grad, MulGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralBinaryGradInferMeta));
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp,
                   ops::MulDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::MulDoubleGradMaker<paddle::imperative::OpBase>);
+                  ops::MulDoubleGradMaker<paddle::imperative::OpBase>,
+                  MulGradInferShapeFunctor);
 
 REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);
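Note that the deleted MulGradOp::InferShape only copied the forward input shapes onto the gradient outputs (X@GRAD gets X's dims, Y@GRAD gets Y's dims). Judging by its name, phi::GeneralBinaryGradInferMeta implements exactly this generic pattern, which is why the backward op can reuse a shared InferMeta function and only the forward op needs the mul-specific MatmulWithFlattenInferMeta below.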
@@ -1267,6 +1267,81 @@ void MatmulInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void MatmulWithFlattenInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                int x_num_col_dims,
+                                int y_num_col_dims,
+                                MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+
+  VLOG(3) << "mul operator x.shape=" << x_dims << " y.shape=" << y_dims
+          << " x_num_col_dims=" << x_num_col_dims
+          << " y_num_col_dims=" << y_num_col_dims;
+
+  PADDLE_ENFORCE_NE(phi::product(y_dims),
+                    0,
+                    phi::errors::PreconditionNotMet(
+                        "The Input variable Y has not "
+                        "been initialized. You may need to confirm "
+                        "if you put exe.run(startup_program) "
+                        "after optimizer.minimize function."));
+  PADDLE_ENFORCE_GT(
+      x_dims.size(),
+      x_num_col_dims,
+      phi::errors::InvalidArgument(
+          "The input tensor X's dimensions of MulOp "
+          "should be larger than x_num_col_dims. But received X's "
+          "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
+          x_dims.size(),
+          x_dims,
+          x_num_col_dims));
+  PADDLE_ENFORCE_GT(
+      y_dims.size(),
+      y_num_col_dims,
+      phi::errors::InvalidArgument(
+          "The input tensor Y's dimensions of MulOp "
+          "should be larger than y_num_col_dims. But received Y's "
+          "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
+          y_dims.size(),
+          y_dims,
+          y_num_col_dims));
+
+  auto x_mat_dims = phi::flatten_to_2d(x_dims, x_num_col_dims);
+  auto y_mat_dims = phi::flatten_to_2d(y_dims, y_num_col_dims);
+
+  PADDLE_ENFORCE_EQ(
+      x_mat_dims[1],
+      y_mat_dims[0],
+      phi::errors::InvalidArgument(
+          "After flatten the input tensor X and Y to 2-D dimensions matrix "
+          "X1 and Y1, the matrix X1's width must be equal with matrix Y1's "
+          "height. But received X's shape = [%s], X1's shape = [%s], X1's "
+          "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = "
+          "%s.",
+          x_dims,
+          x_mat_dims,
+          x_mat_dims[1],
+          y_dims,
+          y_mat_dims,
+          y_mat_dims[0]));
+  std::vector<int64_t> output_dims;
+  output_dims.reserve(
+      static_cast<size_t>(x_num_col_dims + y_dims.size() - y_num_col_dims));
+
+  for (int i = 0; i < x_num_col_dims; ++i) {
+    output_dims.push_back(x_dims[i]);
+  }
+
+  for (int i = y_num_col_dims; i < y_dims.size(); ++i) {
+    output_dims.push_back(y_dims[i]);
+  }
+
+  out->set_dims(phi::make_ddim(output_dims));
+  out->set_dtype(x.dtype());
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
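To make the flatten semantics above concrete: x_num_col_dims splits X's dims into a leading group whose product becomes the matrix height and a trailing group whose product becomes its width (likewise y_num_col_dims for Y), and the output keeps X's leading dims followed by Y's trailing dims. Below is a minimal standalone sketch of that shape arithmetic in plain C++; FlattenTo2D is an illustrative stand-in for phi::flatten_to_2d, not a Paddle API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Collapse an N-D shape into {rows, cols} at `num_col_dims`:
// rows = product of dims [0, num_col_dims), cols = product of the rest.
static std::pair<int64_t, int64_t> FlattenTo2D(
    const std::vector<int64_t>& dims, int num_col_dims) {
  int64_t rows = std::accumulate(dims.begin(), dims.begin() + num_col_dims,
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>());
  int64_t cols = std::accumulate(dims.begin() + num_col_dims, dims.end(),
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>());
  return {rows, cols};
}

int main() {
  std::vector<int64_t> x_dims{2, 3, 4};  // X: [2, 3, 4]
  std::vector<int64_t> y_dims{4, 5};     // Y: [4, 5]
  int x_num_col_dims = 2;                // X1: [2*3, 4] = [6, 4]
  int y_num_col_dims = 1;                // Y1: [4, 5]

  auto x_mat = FlattenTo2D(x_dims, x_num_col_dims);
  auto y_mat = FlattenTo2D(y_dims, y_num_col_dims);
  assert(x_mat.second == y_mat.first);  // X1's width must equal Y1's height

  // Output = X's leading x_num_col_dims dims + Y's trailing dims: [2, 3, 5]
  std::vector<int64_t> out_dims(x_dims.begin(),
                                x_dims.begin() + x_num_col_dims);
  out_dims.insert(out_dims.end(), y_dims.begin() + y_num_col_dims,
                  y_dims.end());

  for (int64_t d : out_dims) std::cout << d << ' ';  // prints: 2 3 5
  std::cout << '\n';
  return 0;
}
```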
@@ -186,6 +186,12 @@ void MatmulInferMeta(const MetaTensor& x,
                      bool trans_y,
                      MetaTensor* out);
 
+void MatmulWithFlattenInferMeta(const MetaTensor& x,
+                                const MetaTensor& y,
+                                int x_num_col_dims,
+                                int y_num_col_dims,
+                                MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,