未验证 提交 27996fd1 编写于 作者: Y YuanRisheng 提交者: GitHub

[Phi] Move backward infershape of Reshape Op (#40914)

* perfect reshape kernel

* fix bugs of sig

* add unittest for reshape_sig

* fix bugs when run converage
上级 287cbde8
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
......@@ -365,10 +366,14 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
"FlattenContiguousRangeGrad");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "FlattenContiguousRangeGrad");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
// Construct MetaTensor for InferMeta Func
using CompatMetaTensor = framework::CompatMetaTensor;
CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
context->IsRuntime());
CompatMetaTensor dx(
context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
context->IsRuntime());
phi::KernelWithXShapeInferMeta(xshape, &dx);
}
protected:
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/reshape_grad_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
......@@ -558,10 +559,14 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) shouldn't be null."));
auto xshape_dims = ctx->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("XShape", framework::GradVarName("X"));
// Construct MetaTensor for InferMeta Func
using CompatMetaTensor = framework::CompatMetaTensor;
CompatMetaTensor xshape(ctx->GetInputVarPtrs("XShape")[0],
ctx->IsRuntime());
CompatMetaTensor dx(ctx->GetOutputVarPtrs(framework::GradVarName("X"))[0],
ctx->IsRuntime());
phi::KernelWithXShapeInferMeta(xshape, &dx);
}
protected:
......@@ -592,15 +597,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
platform::errors::InvalidArgument(
"Input(X@GRAD_GRAD) shouldn't be null."));
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
ctx->ShareDim("DOut", "DDOut");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -659,9 +655,15 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
ops::ReshapeGradInplaceInferer);
DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
Reshape2DoubleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInferer,
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer,
Reshape2DoubleGradInferShapeFunctor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
......
......@@ -161,6 +161,13 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
dx->share_meta(dout);
}
void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
auto xshape_dims = xshape.dims();
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
dx->set_dims(x_dims);
dx->share_lod(xshape);
}
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
......@@ -92,6 +92,8 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
int axis,
MetaTensor* dx);
void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
......
......@@ -28,8 +28,15 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
}
} else {
if (ctx.InputSize("ShapeTensor") > 0) {
return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
} else if (ctx.HasInput("Shape")) {
return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
} else {
return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
}
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReshapeGradOpArgumentMapping(
......
......@@ -577,5 +577,23 @@ TEST(ARG_MAP, allclose) {
ASSERT_EQ(attr_names2[1], "Atol");
}
TEST(ARG_MAP, reshape) {
TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1);
ASSERT_EQ(signature1.name, "reshape");
TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2);
ASSERT_EQ(signature2.name, "reshape");
TestArgumentMappingContext arg_case3(
{"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3);
ASSERT_EQ(signature3.name, "reshape");
}
} // namespace tests
} // namespace phi
......@@ -57,7 +57,7 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
}
size_t InputSize(const std::string& name) const override {
return dense_tensor_inputs.size() + selected_rows_inputs.size();
return dense_tensor_inputs.count(name) + selected_rows_inputs.count(name);
}
size_t OutputSize(const std::string& name) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册