未验证 提交 75280d36 编写于 作者: C chentianyu03 提交者: GitHub

remove dot infershape (#39945)

上级 4149cabe
......@@ -14,6 +14,10 @@
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -21,51 +25,6 @@ class DotOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(true, ctx->HasInput("X"),
platform::errors::PreconditionNotMet(
"Input(X) of DotOp should not be null."));
PADDLE_ENFORCE_EQ(true, ctx->HasInput("Y"),
platform::errors::PreconditionNotMet(
"Input(Y) of DotOp should not be null."));
PADDLE_ENFORCE_EQ(true, ctx->HasOutput("Out"),
platform::errors::PreconditionNotMet(
"Output(Out) of DotOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto x_rank = static_cast<size_t>(x_dims.size());
PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
platform::errors::PreconditionNotMet(
"ShapeError: The dimensions of input tensor X (%s) "
"should be 1 or 2",
x_dims.to_str()));
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(
true, x_rank == (size_t)y_dims.size(),
platform::errors::PreconditionNotMet(
"ShapeError: The shape of input tensor Y: %s should match with "
"input tenosr X: %s",
y_dims.to_str(), x_dims.to_str()));
bool shape_match = true;
for (size_t i = 0; i < x_rank; ++i) {
if (x_dims[i] != y_dims[i]) {
shape_match = false;
break;
}
}
PADDLE_ENFORCE_EQ(true, shape_match,
platform::errors::PreconditionNotMet(
"ShapeError: The shape of input tensor X: %s should "
"be exactly the same "
"with input tensor Y: %s",
x_dims.to_str(), y_dims.to_str()));
auto dims = vectorize(x_dims);
dims[dims.size() - 1] = 1;
ctx->SetOutputDim("Out", phi::make_ddim(dims));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -142,9 +101,13 @@ class DotOpGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(dot, DotInferShapeFunctor,
PT_INFER_META(phi::DotInferMeta));
REGISTER_OPERATOR(dot, ops::DotOp, ops::DotOpMaker,
ops::DotOpGradMaker<paddle::framework::OpDesc>,
ops::DotOpGradMaker<paddle::imperative::OpBase>);
ops::DotOpGradMaker<paddle::imperative::OpBase>,
DotInferShapeFunctor);
REGISTER_OPERATOR(dot_grad, ops::DotGradOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册