Unverified commit 27996fd1, authored by YuanRisheng, committed by GitHub

[Phi] Move backward infershape of Reshape Op (#40914)

* perfect reshape kernel

* fix bugs of sig

* add unittest for reshape_sig

* fix bugs when running coverage
Parent 287cbde8
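
Background for the hunks below: the forward reshape2 op records an auxiliary XShape output whose dims are the input's dims prefixed with a leading 0, and the backward infershape recovers dX's dims by dropping that leading entry (this is what the newly added phi::KernelWithXShapeInferMeta does via phi::slice_ddim). The following is a minimal standalone sketch of that slicing step only, using plain std::vector instead of Paddle's DDim/MetaTensor types; the function and variable names here are illustrative, not the Paddle API.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the xshape -> dx dims slicing performed by
// phi::KernelWithXShapeInferMeta (not the real Paddle API).
std::vector<int64_t> DxDimsFromXShape(const std::vector<int64_t>& xshape_dims) {
  // Equivalent in spirit to phi::slice_ddim(xshape_dims, 1, xshape_dims.size()).
  return std::vector<int64_t>(xshape_dims.begin() + 1, xshape_dims.end());
}

int main() {
  // XShape recorded by a forward reshape2 whose input had shape [2, 3, 4].
  std::vector<int64_t> xshape = {0, 2, 3, 4};
  assert((DxDimsFromXShape(xshape) == std::vector<int64_t>{2, 3, 4}));
  return 0;
}
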
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
@@ -365,10 +366,14 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                    "FlattenContiguousRangeGrad");
     OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
                    framework::GradVarName("Out"), "FlattenContiguousRangeGrad");
-    auto xshape_dims = context->GetInputDim("XShape");
-    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    context->SetOutputDim(framework::GradVarName("X"), x_dims);
-    context->ShareLoD("XShape", framework::GradVarName("X"));
+    // Construct MetaTensor for InferMeta Func
+    using CompatMetaTensor = framework::CompatMetaTensor;
+    CompatMetaTensor xshape(context->GetInputVarPtrs("XShape")[0],
+                            context->IsRuntime());
+    CompatMetaTensor dx(
+        context->GetOutputVarPtrs(framework::GradVarName("X"))[0],
+        context->IsRuntime());
+    phi::KernelWithXShapeInferMeta(xshape, &dx);
   }
 
  protected:
...
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/scalar_array.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/kernels/reshape_grad_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"
@@ -558,10 +559,14 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                       platform::errors::InvalidArgument(
                           "Input(Out@GRAD) shouldn't be null."));
-    auto xshape_dims = ctx->GetInputDim("XShape");
-    auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("XShape", framework::GradVarName("X"));
+    // Construct MetaTensor for InferMeta Func
+    using CompatMetaTensor = framework::CompatMetaTensor;
+    CompatMetaTensor xshape(ctx->GetInputVarPtrs("XShape")[0],
+                            ctx->IsRuntime());
+    CompatMetaTensor dx(ctx->GetOutputVarPtrs(framework::GradVarName("X"))[0],
+                        ctx->IsRuntime());
+    phi::KernelWithXShapeInferMeta(xshape, &dx);
   }
 
  protected:
@@ -592,15 +597,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
                       const framework::AttributeMap &attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("DDX"), true,
-                      platform::errors::InvalidArgument(
-                          "Input(X@GRAD_GRAD) shouldn't be null."));
-    if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
-      ctx->ShareDim("DOut", "DDOut");
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -659,9 +655,15 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
                   ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
                   ops::ReshapeGradInplaceInferer);
+
+DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
+                            Reshape2DoubleGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));
+
 REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
                   ops::ReshapeDoubleGradInplaceInferer,
-                  ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
+                  ops::ReshapeDoubleGradOpNoNeedBufferVarInferer,
+                  Reshape2DoubleGradInferShapeFunctor);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
...
@@ -161,6 +161,13 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
   dx->share_meta(dout);
 }
 
+void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
+  auto xshape_dims = xshape.dims();
+  auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
+  dx->set_dims(x_dims);
+  dx->share_lod(xshape);
+}
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
...
@@ -92,6 +92,8 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 int axis,
                                 MetaTensor* dx);
 
+void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
+
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                                    const MetaTensor& mask,
                                    const MetaTensor& dout,
...
@@ -28,8 +28,15 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
       return KernelSignature(
           "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
     }
+  } else {
+    if (ctx.InputSize("ShapeTensor") > 0) {
+      return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+    } else {
+      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+    }
   }
-  return KernelSignature("unregistered", {}, {}, {});
 }
 
 KernelSignature ReshapeGradOpArgumentMapping(
...
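
As the reshape_sig hunk above shows, when reshape2 is used without an XShape output the mapping now falls back to the plain "reshape" kernel and chooses the shape argument with the precedence ShapeTensor (tensor list input), then Shape (single tensor input), then the "shape" attribute. Below is a tiny standalone sketch of that precedence only; it is not the Paddle API, and the inputs are modeled as simple flags instead of ArgumentMappingContext queries.

#include <iostream>
#include <string>

// Hypothetical helper mirroring the branch order in ReshapeOpArgumentMapping:
// prefer the ShapeTensor list, then the Shape tensor, then the shape attribute.
std::string PickReshapeShapeArg(bool has_shape_tensor_list, bool has_shape_input) {
  if (has_shape_tensor_list) return "ShapeTensor";
  if (has_shape_input) return "Shape";
  return "shape";
}

int main() {
  std::cout << PickReshapeShapeArg(true, false) << "\n";   // ShapeTensor
  std::cout << PickReshapeShapeArg(false, true) << "\n";   // Shape
  std::cout << PickReshapeShapeArg(false, false) << "\n";  // shape (attribute)
  return 0;
}
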
@@ -577,5 +577,23 @@ TEST(ARG_MAP, allclose) {
   ASSERT_EQ(attr_names2[1], "Atol");
 }
 
+TEST(ARG_MAP, reshape) {
+  TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
+  auto signature1 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1);
+  ASSERT_EQ(signature1.name, "reshape");
+
+  TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
+  auto signature2 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2);
+  ASSERT_EQ(signature2.name, "reshape");
+
+  TestArgumentMappingContext arg_case3(
+      {"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
+  auto signature3 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3);
+  ASSERT_EQ(signature3.name, "reshape");
+}
+
 }  // namespace tests
 }  // namespace phi
@@ -57,7 +57,7 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
   }
 
   size_t InputSize(const std::string& name) const override {
-    return dense_tensor_inputs.size() + selected_rows_inputs.size();
+    return dense_tensor_inputs.count(name) + selected_rows_inputs.count(name);
   }
 
   size_t OutputSize(const std::string& name) const override {
...