From ec09ef260f35f11de2436edc6f40839c810b7357 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Sat, 12 Mar 2022 20:44:29 +0800
Subject: [PATCH] [Phi] Add softmax infermeta functions (#40471)

* rename softmax kernel name
* move softmax infershape
* fix failed test
---
 .../mkldnn/test_mkldnn_op_inplace.cc          |  3 +
 paddle/fluid/operators/softmax_op.cc          | 55 ++++---------------
 paddle/phi/infermeta/backward.cc              |  6 ++
 paddle/phi/infermeta/backward.h               |  2 +
 paddle/phi/infermeta/unary.cc                 | 19 +++++++
 paddle/phi/infermeta/unary.h                  |  2 +
 paddle/phi/kernels/cpu/softmax_kernel.cc      |  2 +-
 paddle/phi/kernels/gpu/softmax_kernel.cu      |  2 +-
 paddle/phi/kernels/gpudnn/softmax_kernel.cu   | 14 ++---
 paddle/phi/kernels/impl/softmax_kernel_impl.h |  8 +--
 paddle/phi/kernels/softmax_kernel.h           | 12 +---
 11 files changed, 58 insertions(+), 67 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
index e9dadd5ec9..4090d5ffca 100644
--- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
+++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
@@ -24,6 +24,7 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
@@ -32,6 +33,8 @@ USE_OP_DEVICE_KERNEL(relu, MKLDNN);
 USE_OP_ITSELF(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
 
+PD_DECLARE_KERNEL(softmax, CPU, ALL_LAYOUT);
+
 namespace paddle {
 namespace operators {
 
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 3749920966..af90baf27d 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 
@@ -23,6 +24,10 @@ limitations under the License. */
*/ #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { @@ -30,30 +35,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound("Input(X) of SoftmaxOp is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of SoftmaxOp is not found.")); - - auto dim_x = ctx->GetInputDim("X"); - auto rank_x = dim_x.size(); - auto axis = ctx->Attrs().Get("axis"); - PADDLE_ENFORCE_GE(axis, -rank_x, - platform::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], " - "R is the rank of Input(X).")); - PADDLE_ENFORCE_LT(axis, rank_x, - platform::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], " - "R is the rank of Input(X).")); - - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -168,23 +149,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Out"), true, - platform::errors::InvalidArgument("Input(Out) is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::InvalidArgument("Input(Out@GRAD) is not found.")); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("Out"), - ctx->GetInputDim(framework::GradVarName("Out")), - platform::errors::InvalidArgument("Input(Out) and its gradients " - "should have a same shape.")); - - ctx->SetOutputDim(framework::GradVarName("X"), - ctx->GetInputDim(framework::GradVarName("Out"))); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -244,9 +208,14 @@ DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"}); namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(softmax, SoftmaxInferShapeFunctor, + PD_INFER_META(phi::SoftmaxInferMeta)); REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ops::SoftmaxOpInferVarType, ops::SoftmaxOpGradMaker, ops::SoftmaxOpGradMaker, - ops::SoftmaxInplaceInferer); -REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); + ops::SoftmaxInplaceInferer, SoftmaxInferShapeFunctor); +DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad, SoftmaxGradnferShapeFunctor, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); +REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad, + SoftmaxGradnferShapeFunctor); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 0a2b4dcae5..801bd98b50 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -64,6 +64,12 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x, } } +void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) { + if (dx) { + dx->share_meta(x); + } +} + void GeneralBinaryGradInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* dx, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c4003ca1fe..9ed24ef864 100644 --- 
+++ b/paddle/phi/infermeta/backward.h
@@ -30,6 +30,8 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
                                         MetaTensor* dweight,
                                         MetaTensor* dbias);
 
+void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);
+
 void GeneralBinaryGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& y,
                                 MetaTensor* dx,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 9daad7d6aa..1b82051047 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1409,6 +1409,25 @@ void ShardIndexInferMeta(const MetaTensor& in,
   out->set_dtype(in.dtype());
 }
 
+void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out) {
+  auto dim_x = x.dims();
+  auto rank_x = dim_x.size();
+  PADDLE_ENFORCE_GE(axis,
+                    -rank_x,
+                    phi::errors::InvalidArgument(
+                        "Attr(axis) value should be in range [-R, R-1], "
+                        "R is the rank of Input(X)."));
+  PADDLE_ENFORCE_LT(axis,
+                    rank_x,
+                    phi::errors::InvalidArgument(
+                        "Attr(axis) value should be in range [-R, R-1], "
+                        "R is the rank of Input(X)."));
+
+  out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
+  out->share_lod(x);
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index e8be73e943..c7b7f8e3c1 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -203,4 +203,6 @@ void ShardIndexInferMeta(const MetaTensor& in,
                          MetaTensor* out,
                          MetaConfig config = MetaConfig());
 
+void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/softmax_kernel.cc b/paddle/phi/kernels/cpu/softmax_kernel.cc
index 537b432668..1d28669571 100644
--- a/paddle/phi/kernels/cpu/softmax_kernel.cc
+++ b/paddle/phi/kernels/cpu/softmax_kernel.cc
@@ -19,4 +19,4 @@ limitations under the License. */
 #include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
 
 PD_REGISTER_KERNEL(
-    softmax, CPU, ALL_LAYOUT, phi::SoftmaxRawKernel, float, double) {}
+    softmax, CPU, ALL_LAYOUT, phi::SoftmaxKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/softmax_kernel.cu b/paddle/phi/kernels/gpu/softmax_kernel.cu
index 03c5714b96..4a02f438c7 100644
--- a/paddle/phi/kernels/gpu/softmax_kernel.cu
+++ b/paddle/phi/kernels/gpu/softmax_kernel.cu
@@ -23,7 +23,7 @@ limitations under the License. */
 PD_REGISTER_KERNEL(softmax,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SoftmaxRawKernel,
+                   phi::SoftmaxKernel,
                    float,
                    double,
                    phi::dtype::float16,
diff --git a/paddle/phi/kernels/gpudnn/softmax_kernel.cu b/paddle/phi/kernels/gpudnn/softmax_kernel.cu
index 7685c7dbb6..37175c427f 100644
--- a/paddle/phi/kernels/gpudnn/softmax_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/softmax_kernel.cu
@@ -21,10 +21,10 @@ limitations under the License. */
 namespace phi {
 
 template <typename T, typename Context>
-void SoftmaxRawGPUDNNKernel(const Context& dev_ctx,
-                            const DenseTensor& x,
-                            int axis,
-                            DenseTensor* out) {
+void SoftmaxGPUDNNKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         int axis,
+                         DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   SoftmaxForwardCUDAKernelDriver<T>(dev_ctx, x, axis, out);
 }
@@ -35,7 +35,7 @@ void SoftmaxRawGPUDNNKernel(const Context& dev_ctx,
 PD_REGISTER_KERNEL(softmax,
                    GPUDNN,
                    ALL_LAYOUT,
-                   phi::SoftmaxRawGPUDNNKernel,
+                   phi::SoftmaxGPUDNNKernel,
                    float,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
@@ -44,7 +44,7 @@ PD_REGISTER_KERNEL(softmax,
 PD_REGISTER_KERNEL(softmax,
                    GPUDNN,
                    ALL_LAYOUT,
-                   phi::SoftmaxRawGPUDNNKernel,
+                   phi::SoftmaxGPUDNNKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -53,7 +53,7 @@ PD_REGISTER_KERNEL(softmax,
 PD_REGISTER_KERNEL(softmax,
                    GPUDNN,
                    ALL_LAYOUT,
-                   phi::SoftmaxRawGPUDNNKernel,
+                   phi::SoftmaxGPUDNNKernel,
                    float,
                    double,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/softmax_kernel_impl.h b/paddle/phi/kernels/impl/softmax_kernel_impl.h
index 6552f6ed58..7aa43fdb7f 100644
--- a/paddle/phi/kernels/impl/softmax_kernel_impl.h
+++ b/paddle/phi/kernels/impl/softmax_kernel_impl.h
@@ -22,10 +22,10 @@ limitations under the License. */
 namespace phi {
 
 template <typename T, typename Context>
-void SoftmaxRawKernel(const Context& dev_ctx,
-                      const DenseTensor& x,
-                      int axis,
-                      DenseTensor* out) {
+void SoftmaxKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   int axis,
+                   DenseTensor* out) {
   const int rank = x.dims().size();
   const int calc_axis = phi::funcs::CanonicalAxis(axis, rank);
   int axis_dim = x.dims()[calc_axis];
diff --git a/paddle/phi/kernels/softmax_kernel.h b/paddle/phi/kernels/softmax_kernel.h
index ca69d65277..4edd562ca8 100644
--- a/paddle/phi/kernels/softmax_kernel.h
+++ b/paddle/phi/kernels/softmax_kernel.h
@@ -19,20 +19,10 @@ limitations under the License. */
 
 namespace phi {
 
-template <typename T, typename Context>
-void SoftmaxRawKernel(const Context& dev_ctx,
-                      const DenseTensor& x,
-                      int axis,
-                      DenseTensor* out);
-
 template <typename T, typename Context>
 void SoftmaxKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    int axis,
-                   DataType dtype,
-                   DenseTensor* out) {
-  auto cast_x = phi::Cast<T, Context>(dev_ctx, x, dtype);
-  phi::SoftmaxRawKernel<T, Context>(dev_ctx, cast_x, axis, out);
-}
+                   DenseTensor* out);
 
 }  // namespace phi
-- 
GitLab
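
Note: the softmax_grad wiring above is the reusable pattern this patch introduces. Any fluid grad op whose dx output should simply share the meta (dims, dtype, LoD) of an input tensor can be registered against the new phi::GeneralUnaryGradInferMeta in the same way. A minimal sketch follows; the op name my_unary_grad and class ops::MyUnaryOpGrad are hypothetical placeholders, while the two macros are exactly those used in this patch:

    // Sketch only: reuse phi::GeneralUnaryGradInferMeta for another grad op.
    // DECLARE_INFER_SHAPE_FUNCTOR generates a functor that adapts the phi
    // infermeta function to fluid's InferShape interface; REGISTER_OPERATOR
    // then attaches that functor in place of a hand-written InferShape.
    DECLARE_INFER_SHAPE_FUNCTOR(my_unary_grad, MyUnaryGradInferShapeFunctor,
                                PD_INFER_META(phi::GeneralUnaryGradInferMeta));
    REGISTER_OPERATOR(my_unary_grad, ops::MyUnaryOpGrad,
                      MyUnaryGradInferShapeFunctor);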