From 967dee45955c6c7347a385ff2aa340fd66e4e588 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Fri, 31 Mar 2023 16:38:36 +0800
Subject: [PATCH] Add Yaml config for some op (#52347)

* add yaml for some op

* fix inplace_abn

* fix test_leaky_relu_grad_grad_functor

* fix yaml

* fix typo
---
 paddle/fluid/operators/activation_op.cc      |  21 ---
 paddle/fluid/operators/activation_op.h       |   3 -
 paddle/fluid/operators/inplace_abn_op.h      |   2 +
 paddle/fluid/operators/nanmedian_op.cc       | 131 ------------------
 .../test_leaky_relu_grad_grad_functor.h      |   2 +
 paddle/phi/api/yaml/backward.yaml            |  19 +++
 paddle/phi/api/yaml/fused_backward.yaml      |   7 +-
 paddle/phi/api/yaml/fused_ops.yaml           |   7 +-
 paddle/phi/api/yaml/op_compat.yaml           |  19 +++
 paddle/phi/api/yaml/ops.yaml                 |  20 +++
 paddle/phi/ops/compat/activation_sig.cc      |   8 +-
 paddle/phi/ops/compat/nanmedian_sig.cc       |  35 -----
 python/paddle/tensor/math.py                 |   2 +-
 python/paddle/tensor/stat.py                 |   7 +-
 14 files changed, 74 insertions(+), 209 deletions(-)
 delete mode 100644 paddle/fluid/operators/nanmedian_op.cc
 delete mode 100644 paddle/phi/ops/compat/nanmedian_sig.cc

diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 81c7f054214..c7cea3122fe 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -178,26 +178,6 @@ $$out = \min(\max(0, x), threshold)$$
   }
 };
 
-class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "Input of STanh operator."
-             " A Tensor with type float32, float64.");
-    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
-    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
-        .SetDefault(0.67f);
-    AddAttr<float>("scale_b", "The scale parameter of b for the input")
-        .SetDefault(1.7159f);
-    AddComment(R"DOC(
-STanh Activation Operator.
-
-$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
-
-)DOC");
-  }
-};
-
 class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -436,7 +416,6 @@ REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
 
 REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
 REGISTER_ACTIVATION_OP(hard_swish,
                        HardSwish,
                        HardSwishFunctor,
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index a4720905ad5..609938f1e66 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -268,8 +268,6 @@ using BReluGradFunctor = phi::funcs::HardTanhGradFunctor<T>;
 
 USE_PHI_FUNCTOR(Tanh)
 USE_PHI_FUNCTOR(Relu6)
-USE_PHI_FUNCTOR(LeakyRelu)
-USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
 USE_PHI_FUNCTOR(HardShrink)
 USE_PHI_FUNCTOR(ELU)
 USE_PHI_FUNCTOR(Sigmoid)
@@ -278,7 +276,6 @@ USE_PHI_FUNCTOR(Swish)
 USE_PHI_FUNCTOR(HardSwish)
 USE_PHI_FUNCTOR(Pow)
 USE_PHI_FUNCTOR(Mish)
-USE_PHI_FUNCTOR(STanh)
 
 template <typename T>
 using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
diff --git a/paddle/fluid/operators/inplace_abn_op.h b/paddle/fluid/operators/inplace_abn_op.h
index 29253662d4d..abdb1e33aaa 100644
--- a/paddle/fluid/operators/inplace_abn_op.h
+++ b/paddle/fluid/operators/inplace_abn_op.h
@@ -22,6 +22,8 @@ namespace paddle {
 namespace operators {
 
+USE_PHI_FUNCTOR(LeakyRelu)
+
 template
diff --git a/paddle/fluid/operators/nanmedian_op.cc b/paddle/fluid/operators/nanmedian_op.cc
deleted file mode 100644
index f0bc985f3ea..00000000000
--- a/paddle/fluid/operators/nanmedian_op.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/backward.h"
-#include "paddle/phi/infermeta/unary.h"
-
-namespace paddle {
-namespace operators {
-
-class NanmedianOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-                          ctx.GetPlace());
-  }
-};
-
-class NanmedianOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor), "
-             "the input feature data of NanmedianOp, dtype should be"
-             "int32, int64, float16, float32 or float64.");
-    AddOutput(
-        "MedianIndex",
-        "Store the index position of median values, The calculation differs "
-        "in the odd or even valid elements numbers."
-        "Along the axis, two elements contributed to the median value in "
-        "each row."
- "If the amount of valid elements were even, both were the same.") - .AsIntermediate() - .AsExtra(); - AddOutput("Out", - "(Tensor)," - " the output of NanmedianOp, whose dtype is the same as X"); - AddAttr("keepdim", - "(bool, default true) " - "If true, retain the reduced axis with length 1.") - .SetDefault(true); - AddAttr>("axis", - "(std::vector). List of integers," - " indicating the dimensions to calculate medians") - .SetDefault({}); - AddComment(R"DOC( - Nanmedian operator - - This operator is considered as an extention of median operation, - which supports specifically the case of NaN values in the input. - - If all the elements in input are NaN it will also return NaN. - If no elements in input are Nan, this op is identical to thie median op. - - If the valid count of elements is a even number, the average value of - the elements in the middle is calculated as the median. - - This operator can also supports multiple axis. - )DOC"); - } -}; - -template -class NanmedianGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr op) const override { - op->SetType("nanmedian_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput("MedianIndex", this->Output("MedianIndex")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -class NanmedianGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(nanmedian, - NanmedianInferShapeFunctor, - PD_INFER_META(phi::NanmedianInferMeta)); - -REGISTER_OPERATOR(nanmedian, - ops::NanmedianOp, - ops::NanmedianOpMaker, - ops::NanmedianGradMaker, - ops::NanmedianGradMaker, - NanmedianInferShapeFunctor); - -DECLARE_INFER_SHAPE_FUNCTOR(nanmedian_grad, - NanmedianGradInferShapeFunctor, - PD_INFER_META(phi::NanmedianGradInferMeta)); - -REGISTER_OPERATOR(nanmedian_grad, - ops::NanmedianGradOp, - NanmedianGradInferShapeFunctor); diff --git a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h index 31f913cc65b..5b51d8ddb00 100644 --- a/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h +++ b/paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h @@ -24,6 +24,8 @@ namespace paddle { namespace operators { +USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) + template static void InitRandom(phi::DenseTensor *tensor, const platform::Place &place) { phi::DenseTensor cpu_tensor; diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 858b2d15609..bc85a1d0ca7 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1068,6 +1068,15 @@ kernel : func : mv_grad +- backward_op : nanmedian_grad + forward : nanmedian (Tensor x, IntArray axis, bool keepdim) -> Tensor(out), Tensor(medians) + args : (Tensor x, Tensor medians, Tensor out_grad, IntArray axis, bool keepdim) + output : Tensor(x_grad) + infer_meta : + func : NanmedianGradInferMeta + kernel : + func : nanmedian_grad + - backward_op : 
   forward : nearest_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout="NCHW", int out_d=0, int out_h=0, int out_w=0, float[] scale={}, str interp_method="bilinear", bool align_corners=true, int align_mode=1) -> Tensor(output)
   args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode)
@@ -1647,6 +1656,16 @@
     data_type : out_grad
   no_need_buffer : x
 
+- backward_op : stanh_grad
+  forward : stanh(Tensor x, float scale_a, float scale_b) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float scale_a, float scale_b)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : stanh_grad
+
 - backward_op : svd_grad
   forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
   args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml
index 2fbd12d06f6..95914828461 100644
--- a/paddle/phi/api/yaml/fused_backward.yaml
+++ b/paddle/phi/api/yaml/fused_backward.yaml
@@ -1,7 +1,8 @@
 # This file is designed for fusion C++ backward operators, which manages the
-# generated code for dynamic mode and static mode.
-# The operators in the file have extra configuration item "support_dygraph_mode".
-# If one operator have "support_dygraph_mode : True", it supports dygraph mode.
+# generated code for static mode and dynamic mode (when `support_dygraph_mode` is true).
+# "support_dygraph_mode" is an extra configuration item in this file;
+# if an operator has "support_dygraph_mode : true", it supports dygraph mode,
+# otherwise it can only be used in static mode.
 
 - backward_op : fused_dropout_add_grad
   forward : fused_dropout_add (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(seed_offset)
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 071943d3a8c..23435c210b4 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -1,7 +1,8 @@
 # This file is designed for fusion C++ forward operators, which manages the
-# generated code for dynamic mode and static mode.
-# The operators in the file have extra configuration item "support_dygraph_mode".
-# If one operator have "support_dygraph_mode : True", it supports dygraph mode.
+# generated code for static mode and dynamic mode (when `support_dygraph_mode` is true).
+# "support_dygraph_mode" is an extra configuration item in this file;
+# if an operator has "support_dygraph_mode : true", it supports dygraph mode,
+# otherwise it can only be used in static mode.
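(To make the rewritten header concrete: an entry in these files opts into dygraph
support by carrying the flag alongside its usual keys. The sketch below is purely
illustrative — `my_fused_op` and `MyFusedInferMeta` are hypothetical names, not
entries from this patch:

  - op : my_fused_op
    args : (Tensor x, Tensor y)
    output : Tensor(out)
    infer_meta :
      func : MyFusedInferMeta    # hypothetical infer_meta function
    kernel :
      func : my_fused_op         # hypothetical kernel name
    support_dygraph_mode : true  # omit this key for static-mode-only codegen
)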
 
 - op : embedding_with_eltwise_add_xpu
   args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 15ca92c78b4..e6f1d597f58 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1321,6 +1321,18 @@
   outputs :
     out : Out
 
+- op : nanmedian
+  backward : nanmedian_grad
+  inputs :
+    {x : X}
+  outputs :
+    {out : Out, medians : MedianIndex}
+  int_array:
+    axis:
+      data_type : int
+  extra:
+    outputs : [medians]
+
 - op : nce
   backward : nce_grad
   extra :
@@ -1846,6 +1858,13 @@
   attrs : [bool use_mkldnn = false]
   drop_empty_grad : [x_grad]
 
+- op : stanh
+  backward : stanh_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
+
 - op : subtract (elementwise_sub)
   backward : subtract_grad (elementwise_sub_grad)
   inputs :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 31329287d58..358c8095632 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1075,6 +1075,16 @@
     func : mv
   backward : mv_grad
 
+- op : nanmedian
+  args : (Tensor x, IntArray axis = {}, bool keepdim = true)
+  output : Tensor(out), Tensor(medians)
+  infer_meta :
+    func : NanmedianInferMeta
+  kernel :
+    func : nanmedian
+  intermediate : medians
+  backward : nanmedian_grad
+
 - op : nearest_interp
   args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout="NCHW", int out_d=0, int out_h=0, int out_w=0, float[] scale={}, str interp_method="bilinear", bool align_corners=true, int align_mode=1)
   output : Tensor(output)
@@ -1552,6 +1562,16 @@
     func : stack
   backward : stack_grad
 
+- op : stanh
+  args : (Tensor x, float scale_a=0.67f, float scale_b=1.7159f)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : stanh
+  backward : stanh_grad
+
 - op : svd
   args : (Tensor x, bool full_matrices = false)
   output : Tensor(u), Tensor(s), Tensor(vh)
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index 5106c63a9e0..804d0d63aa2 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -41,12 +41,7 @@ namespace phi {
 
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
-DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta");  // NOLINT
-
-DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(STanh,
-                               "stanh",
-                               "scale_a" comma "scale_b");  // NOLINT
-
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta");  // NOLINT
 DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold");  // NOLINT
 
 KernelSignature HardSwishGradOpArgumentMapping(
@@ -72,7 +67,6 @@ PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
 PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(relu6, phi::Relu6OpArgumentMapping);
diff --git a/paddle/phi/ops/compat/nanmedian_sig.cc b/paddle/phi/ops/compat/nanmedian_sig.cc
deleted file mode 100644
index 5ca0d450e3b..00000000000
--- a/paddle/phi/ops/compat/nanmedian_sig.cc
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature NanmedianOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "nanmedian", {"X"}, {"axis", "keepdim"}, {"Out", "MedianIndex"});
-}
-
-KernelSignature NanmedianGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature("nanmedian_grad",
-                         {"X", "MedianIndex", "Out@GRAD"},
-                         {"axis", "keepdim"},
-                         {"X@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(nanmedian, phi::NanmedianOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(nanmedian_grad, phi::NanmedianGradOpArgumentMapping);
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 0494338eb45..bb4b9646374 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -285,7 +285,7 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
     """
 
     if in_dygraph_mode():
-        return _legacy_C_ops.stanh(x, 'scale_a', scale_a, 'scale_b', scale_b)
+        return _C_ops.stanh(x, scale_a, scale_b)
     else:
         check_variable_and_dtype(
             x, 'x', ['float16', 'float32', 'float64'], 'stanh'
diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py
index c8126a78b2d..7677f71bcac 100644
--- a/python/paddle/tensor/stat.py
+++ b/python/paddle/tensor/stat.py
@@ -15,7 +15,7 @@
 # TODO: define statistical functions of a tensor
 
 import paddle
-from paddle import _C_ops, _legacy_C_ops
+from paddle import _C_ops
 from paddle.fluid.framework import in_dygraph_mode
 
 from ..common_ops_import import Variable
@@ -332,10 +332,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
             raise ValueError("Axis has duplicated elements.")
 
     if in_dygraph_mode():
-        median_index, out = _legacy_C_ops.nanmedian(
-            x, 'axis', axis, 'keepdim', keepdim
-        )
-        return out
+        return _C_ops.nanmedian(x, axis, keepdim)
     else:
         check_variable_and_dtype(
             x,
--
GitLab
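The two Python hunks above are the user-visible surface of this patch: `paddle.stanh` and `paddle.nanmedian` keep their public signatures, but their dygraph branches now dispatch to the generated `_C_ops` bindings declared in ops.yaml instead of `_legacy_C_ops`. The snippet below is a minimal dygraph smoke test of that contract — a sketch assuming a Paddle build that contains this patch; the stanh reference formula, out = scale_b * tanh(scale_a * x), is taken from the deleted STanhOpMaker doc comment.

import numpy as np

import paddle

# stanh: unchanged signature, now routed through _C_ops.stanh in dygraph mode.
x = paddle.to_tensor([-1.5, 0.0, 0.3, 2.0], dtype="float32")
out = paddle.stanh(x, scale_a=0.67, scale_b=1.7159)
ref = 1.7159 * paddle.tanh(0.67 * x)  # formula from the removed OpMaker doc
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-5)

# nanmedian: the dygraph call now returns only `out`; `medians` (MedianIndex)
# is declared `intermediate` in ops.yaml, so it never surfaces in Python.
xn = paddle.to_tensor([[1.0, float("nan"), 3.0], [4.0, 5.0, float("nan")]])
med = paddle.nanmedian(xn, axis=1, keepdim=True)
np.testing.assert_allclose(
    med.numpy(), np.nanmedian(xn.numpy(), axis=1, keepdims=True)
)

Both checks exercise exactly the dygraph branches rewritten in math.py and stat.py above.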