未验证 提交 42a75145 编写于 作者: A Aurelius84 提交者: GitHub

Fix inferMefer in transpose2_grad (#50388)

* Fix inferMefer in transpose2_grad

* fix infershape

* fix unittest
上级 ca520280
......@@ -16,6 +16,10 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -179,19 +183,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TransposeOpGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"TransposeOpGrad");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -320,21 +311,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("XShape"), "Input", "XShape", "Transpose2OpGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"Transpose2OpGrad");
if (ctx->HasOutput(framework::GradVarName("X"))) {
auto xshape_dim = ctx->GetInputDim("XShape");
auto x_shape_dim = phi::slice_ddim(xshape_dim, 1, xshape_dim.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_shape_dim);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -359,6 +335,13 @@ class TransposeGradInferVarType : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad,
TransposeGradInferShapeFunctor,
PD_INFER_META(phi::TransposeGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(transpose2_grad,
Transpose2GradInferShapeFunctor,
PD_INFER_META(phi::TransposeGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
transpose,
......@@ -368,7 +351,8 @@ REGISTER_OPERATOR(
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(transpose_grad,
ops::TransposeOpGrad,
ops::TransposeGradInferVarType);
ops::TransposeGradInferVarType,
TransposeGradInferShapeFunctor);
REGISTER_OPERATOR(transpose2,
ops::Transpose2Op,
......@@ -379,4 +363,5 @@ REGISTER_OPERATOR(transpose2_grad,
ops::Transpose2OpGrad,
ops::TransposeGradInferVarType,
ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>,
Transpose2GradInferShapeFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册