From 312f018730c31fd02ef291539ee59d08aea7ee86 Mon Sep 17 00:00:00 2001
From: Wang Xin
Date: Tue, 16 May 2023 14:20:39 +0800
Subject: [PATCH] static graph autogen code support for softmax op (#53581)

* static graph autogen code support for softmax op

* bug fixed

* fix PR-CI-Windows error

* fix CI error

* bug fixed

* fix conflicts
---
 .../generator/get_expected_kernel_func.cc     |  39 ++++
 .../generator/get_expected_kernel_func.h      |   8 +
 paddle/fluid/operators/softmax_op.cc          | 199 ------------------
 paddle/fluid/operators/unity_build_rule.cmake |   3 +-
 paddle/phi/api/yaml/op_compat.yaml            |   7 +-
 paddle/phi/api/yaml/static_backward.yaml      |  11 +
 paddle/phi/api/yaml/static_ops.yaml           |  10 +
 paddle/phi/ops/compat/softmax_sig.cc          |  32 ---
 test/cpp/fluid/CMakeLists.txt                 |   2 +-
 test/cpp/fluid/mkldnn/CMakeLists.txt          |   9 +-
 10 files changed, 81 insertions(+), 239 deletions(-)
 delete mode 100644 paddle/fluid/operators/softmax_op.cc
 delete mode 100644 paddle/phi/ops/compat/softmax_sig.cc

diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
index 49697a48a17..026d57bba4f 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -139,6 +139,45 @@ phi::KernelKey GetSgdExpectedKernelType(
   return phi::KernelKey(data_type, ctx.GetPlace());
 }
 
+phi::KernelKey GetSoftmaxExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr) {
+  // choose cudnn kernel if the runtime supported.
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
+  auto input_data_type = op_ptr->IndicateVarDataType(ctx, "X");
+  if (input_data_type == framework::proto::VarType::FP16) {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()) ||
+            platform::is_xpu_place(ctx.GetPlace()) ||
+            platform::is_custom_place(ctx.GetPlace()),
+        true,
+        platform::errors::InvalidArgument(
+            "float16 can only be used on GPU/XPU and custom place"));
+  }
+  return phi::KernelKey(
+      ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type));
+}
+
+phi::KernelKey GetSoftmaxGradExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr) {
+  // choose cudnn kernel if the runtime supported.
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
+  auto input_data_type =
+      op_ptr->IndicateVarDataType(ctx, framework::GradVarName("Out"));
+  if (input_data_type == framework::proto::VarType::FP16) {
+    if (!(platform::is_gpu_place(ctx.GetPlace()) ||
+          platform::is_xpu_place(ctx.GetPlace()) ||
+          platform::is_custom_place(ctx.GetPlace())))
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "float16 can only be used on GPU/XPU and custom place"));
+  }
+  return phi::KernelKey(
+      ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type));
+}
+
 phi::KernelKey GetUpdateLossScalingExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr) {
diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h
index 4ef88909984..7923c8d79fb 100644
--- a/paddle/fluid/operators/generator/get_expected_kernel_func.h
+++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h
@@ -36,6 +36,14 @@ phi::KernelKey GetSgdExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
 
+phi::KernelKey GetSoftmaxExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr);
+
+phi::KernelKey GetSoftmaxGradExpectedKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel* op_ptr);
+
 phi::KernelKey GetUpdateLossScalingExpectedKernelType(
     const framework::ExecutionContext& ctx,
     const framework::OperatorWithKernel* op_ptr);
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
deleted file mode 100644
index 2fb7883cb3f..00000000000
--- a/paddle/fluid/operators/softmax_op.cc
+++ /dev/null
@@ -1,199 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
-#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
-#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
-#include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/backward.h"
-#include "paddle/phi/infermeta/unary.h"
-
-namespace paddle {
-namespace operators {
-
-class SoftmaxOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    // choose cudnn kernel if the runtime supported.
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(
-          platform::is_gpu_place(ctx.GetPlace()) ||
-              platform::is_xpu_place(ctx.GetPlace()) ||
-              platform::is_custom_place(ctx.GetPlace()),
-          true,
-          platform::errors::InvalidArgument(
-              "float16 can only be used on GPU/XPU and custom place"));
-    }
-    return phi::KernelKey(
-        ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type));
-  }
-};
-
-class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "The input tensor of softmax, "
-             "whose dimension :attr:`axis` is the input_feature_dimensions.");
-    AddOutput("Out", "The normalized values with the same shape as X.");
-    AddAttr<int>("axis",
-                 "The dimension index of Input(x) to perform softmax,"
-                 "default -1 for last dimension")
-        .SetDefault(-1);
-    AddAttr<std::string>(
-        "data_format",
-        "(string, default NCHW) Only used in "
-        "An optional string from: \"NHWC\", \"NCHW\". "
-        "Defaults to \"NHWC\". Specify the data format of the output data, "
-        "the input will be transformed automatically. ")
-        .SetDefault("AnyLayout");
-    AddAttr<bool>(
-        "use_cudnn",
-        "(bool, default false) Only used in cudnn kernel, need install cudnn")
-        .SetDefault(false)
-        .AsExtra();
-    AddComment(R"DOC(
-Softmax Operator.
-
-The input of the softmax operator is a tensor of any rank. The output tensor
-has the same shape as the input.
-
-The dimension :attr:`axis` of the input tensor will be permuted to the last.
-Then the input tensor will be logically flattened to a 2-D matrix. The matrix's
-second dimension(row length) is the same as the dimension :attr:`axis` of the input
-tensor, and the first dimension(column length) is the product of all other
-dimensions of the input tensor. For each row of the matrix, the softmax operator
-squashes the K-dimensional(K is the width of the matrix, which is also the size
-of the input tensor's dimension :attr:`axis`) vector of arbitrary real values to a
-K-dimensional vector of real values in the range [0, 1] that add up to 1.
-It computes the exponential of the given dimension and the sum of exponential
-values of all the other dimensions in the K-dimensional vector input.
-Then the ratio of the exponential of the given dimension and the sum of
-exponential values of all the other dimensions is the output of the softmax
-operator.
-
-For each row $i$ and each column $j$ in the matrix, we have:
-    $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
-
-)DOC");
-  }
-};
-
-class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
-      const override {
-    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
-    return m;
-  }
-};
-
-class SoftmaxOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    // choose cudnn kernel if the runtime supported.
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
-        ctx, framework::GradVarName("Out"));
-    if (input_data_type == framework::proto::VarType::FP16) {
-      if (!(platform::is_gpu_place(ctx.GetPlace()) ||
-            platform::is_xpu_place(ctx.GetPlace()) ||
-            platform::is_custom_place(ctx.GetPlace())))
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "float16 can only be used on GPU/XPU and custom place"));
-    }
-    return phi::KernelKey(
-        ctx.GetPlace(), layout_, phi::TransToPhiDataType(input_data_type));
-  }
-};
-
-template <typename T>
-class SoftmaxOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("softmax_grad");
-
-    op->SetInput("Out", this->Output("Out"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-
-    op->SetAttrMap(this->Attrs());
-
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-  }
-};
-
-class SoftmaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
-
- public:
-  void Apply() override {
-    paddle::Tensor out = this->GetSingleForwardOutput("Out");
-    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
-    paddle::Tensor dx = this->GetSingleInputGrad("X");
-    auto* dx_ptr = this->GetOutputPtr(&dx);
-    std::string dx_name = this->GetOutputName(dx);
-    int axis = static_cast<int>(this->Attr<int>("axis"));
-    VLOG(6) << "Running softmax_grad composite func";
-    prim::softmax_grad<prim::DescTensor>(out, out_grad, axis, dx_ptr);
-    this->RecoverOutputName(dx, dx_name);
-  }
-};
-
-DECLARE_INPLACE_OP_INFERER(SoftmaxInplaceInferer, {"X", "Out"});
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-DECLARE_INFER_SHAPE_FUNCTOR(softmax,
-                            SoftmaxInferShapeFunctor,
-                            PD_INFER_META(phi::SoftmaxInferMeta));
-REGISTER_OPERATOR(softmax,
-                  ops::SoftmaxOp,
-                  ops::SoftmaxOpMaker,
-                  ops::SoftmaxOpInferVarType,
-                  ops::SoftmaxOpGradMaker<paddle::framework::OpDesc>,
-                  ops::SoftmaxOpGradMaker<paddle::imperative::OpBase>,
-                  ops::SoftmaxCompositeGradOpMaker,
-                  ops::SoftmaxInplaceInferer,
-                  SoftmaxInferShapeFunctor);
-DECLARE_INFER_SHAPE_FUNCTOR(softmax_grad,
-                            SoftmaxGradInferShapeFunctor,
-                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));
-REGISTER_OPERATOR(softmax_grad,
-                  ops::SoftmaxOpGrad,
-                  SoftmaxGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake
index 6eab7b3f44e..740402b155f 100644
--- a/paddle/fluid/operators/unity_build_rule.cmake
+++ b/paddle/fluid/operators/unity_build_rule.cmake
@@ -259,8 +259,7 @@ register_unity_group(
   sign_op.cc
   similarity_focus_op.cc
   size_op.cc
-  slice_op.cc
-  softmax_op.cc)
+  slice_op.cc)
 register_unity_group(
   cc
   space_to_depth_op.cc
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index ab8274ac601..7087e913e94 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2147,8 +2147,13 @@
   backward : softmax_grad
   inputs :
     x : X
+  outputs :
+    out : Out
+  get_expected_kernel_type :
+    softmax : GetSoftmaxExpectedKernelType
+    softmax_grad : GetSoftmaxGradExpectedKernelType
   extra :
-    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
+    attrs : [str data_format = "AnyLayout", bool use_cudnn = false, bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
 
 - op : softplus
   backward : softplus_grad, softplus_double_grad
diff --git a/paddle/phi/api/yaml/static_backward.yaml b/paddle/phi/api/yaml/static_backward.yaml
index 3bda010909c..5d48977ebab 100755
--- a/paddle/phi/api/yaml/static_backward.yaml
+++ b/paddle/phi/api/yaml/static_backward.yaml
@@ -53,3 +53,14 @@
   kernel :
     func : rnn_grad
     data_type: out_grad
+
+- backward_op : softmax_grad
+  forward : softmax (Tensor x, int axis=-1) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad, int axis)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out]
+  kernel :
+    func : softmax_grad
+  composite : softmax_grad(out, out_grad, axis, x_grad)
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index d1d07681ed1..f481cedbc1c 100755
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -335,6 +335,16 @@
   kernel :
     func : share_buffer
 
+- op : softmax
+  args : (Tensor x, int axis = -1)
+  output : Tensor(out)
+  infer_meta :
+    func : SoftmaxInferMeta
+  kernel :
+    func : softmax
+  inplace : (x -> out)
+  backward : softmax_grad
+
 - op : tril_indices
   args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64)
   output : Tensor(out)
diff --git a/paddle/phi/ops/compat/softmax_sig.cc b/paddle/phi/ops/compat/softmax_sig.cc
deleted file mode 100644
index a30a2a2b06f..00000000000
--- a/paddle/phi/ops/compat/softmax_sig.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature SoftmaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("softmax", {"X"}, {"axis"}, {"Out"});
-}
-
-KernelSignature SoftmaxGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(softmax, phi::SoftmaxOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(softmax_grad, phi::SoftmaxGradOpArgumentMapping);
diff --git a/test/cpp/fluid/CMakeLists.txt b/test/cpp/fluid/CMakeLists.txt
index 7bc33dfda33..590816d1b5e 100644
--- a/test/cpp/fluid/CMakeLists.txt
+++ b/test/cpp/fluid/CMakeLists.txt
@@ -42,7 +42,7 @@ cc_test(
   test_common_infer_shape_functions
   SRCS test_common_infer_shape_functions.cc
   DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op
-       elementwise_add_op softmax_op softmax)
+       elementwise_add_op softmax generated_static_op)
 cc_test(
   gather_test
   SRCS gather_test.cc
diff --git a/test/cpp/fluid/mkldnn/CMakeLists.txt b/test/cpp/fluid/mkldnn/CMakeLists.txt
index 34f3ce8959c..dae56ea5eb6 100644
--- a/test/cpp/fluid/mkldnn/CMakeLists.txt
+++ b/test/cpp/fluid/mkldnn/CMakeLists.txt
@@ -4,26 +4,27 @@ cc_test(
   DEPS op_registry
        elementwise_add_op
        activation_op
-       softmax_op
        softmax
        scope
        device_context
        enforce
-       executor)
+       executor
+       generated_static_op)
 set(TEST_MKLDNN_CACHING_DEPS
     op_registry
     elementwise_mul_op
     elementwise_add_op
     activation_op
-    softmax_op
     conv_op
     im2col
     vol2col
     softmax
     scope
    device_context
-    enforce)
+    enforce
+    generated_static_op)
+
 if(WITH_GPU OR WITH_ROCM)
   set(TEST_MKLDNN_CACHING_DEPS ${TEST_MKLDNN_CACHING_DEPS} depthwise_conv)
 endif()
-- 
GitLab
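
Reviewer aside (not part of the patch): the move to autogenerated static-graph code only re-wires registration; the kernel math is untouched. As a sanity reference for the signature the YAML preserves, below is a minimal, self-contained C++ sketch of the computation described in the deleted SoftmaxOpMaker comment and of the Jacobian-vector product that softmax_grad performs. SoftmaxForwardRef and SoftmaxGradRef are hypothetical reference helpers written for this note, not Paddle kernels or APIs, and the sketch assumes a row-major [rows x cols] buffer with softmax over the last axis.

// Minimal reference sketch; hypothetical names, plain C++, not Paddle code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Forward: row-wise softmax over the last axis, i.e. the "flatten to 2-D
// and normalize each row" computation in the deleted DOC comment.
// Subtracting the per-row max keeps std::exp from overflowing.
std::vector<double> SoftmaxForwardRef(const std::vector<double>& x,
                                      int rows, int cols) {
  std::vector<double> out(x.size());
  for (int i = 0; i < rows; ++i) {
    const double* xi = &x[i * cols];
    double* oi = &out[i * cols];
    double mx = *std::max_element(xi, xi + cols);
    double sum = 0.0;
    for (int j = 0; j < cols; ++j) {
      oi[j] = std::exp(xi[j] - mx);
      sum += oi[j];
    }
    for (int j = 0; j < cols; ++j) oi[j] /= sum;
  }
  return out;
}

// Backward: dx = out * (dout - sum(dout * out)) per row. This is the
// Jacobian-vector product that softmax_grad computes. Note it needs only
// Out and Out@GRAD, never X, which matches the (out, out_grad, axis)
// argument list in the new static_backward.yaml entry.
std::vector<double> SoftmaxGradRef(const std::vector<double>& out,
                                   const std::vector<double>& dout,
                                   int rows, int cols) {
  std::vector<double> dx(out.size());
  for (int i = 0; i < rows; ++i) {
    double dot = 0.0;
    for (int j = 0; j < cols; ++j)
      dot += dout[i * cols + j] * out[i * cols + j];
    for (int j = 0; j < cols; ++j)
      dx[i * cols + j] = out[i * cols + j] * (dout[i * cols + j] - dot);
  }
  return dx;
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};      // one row, three classes
  std::vector<double> g = {0.0, 0.0, 1.0};      // upstream gradient
  auto y = SoftmaxForwardRef(x, 1, 3);          // ~0.0900 0.2447 0.6652
  auto dx = SoftmaxGradRef(y, g, 1, 3);
  for (double v : y) std::printf("%.4f ", v);
  std::printf("\n");
  for (double v : dx) std::printf("%.4f ", v);  // ~-0.0599 -0.1628 0.2227
  std::printf("\n");
  return 0;
}

Because the gradient depends only on Out and Out@GRAD, the grad op can drop X from its inputs entirely; both the deleted softmax_sig.cc mapping ({"Out", "Out@GRAD"} -> {"X@GRAD"}) and the new YAML entry encode exactly that.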