From 1faa06f07ad042b293740aeffa4854cdeb7f8fd9 Mon Sep 17 00:00:00 2001 From: lzydev Date: Thu, 30 Mar 2023 10:37:23 +0800 Subject: [PATCH] Change some op with xpu control (#52067) * change op with xpu * change range yaml * fix bug in generate_op.py --- .../operators/amp/update_loss_scaling_op.cc | 140 ------------------ .../fluid/operators/generator/generate_op.py | 22 ++- .../generator/get_expected_kernel_func.cc | 10 ++ .../generator/get_expected_kernel_func.h | 4 + paddle/fluid/operators/linspace_op.cc | 91 ------------ paddle/fluid/operators/range_op.cc | 74 --------- paddle/phi/api/yaml/legacy_ops.yaml | 11 -- paddle/phi/api/yaml/op_compat.yaml | 26 +++- paddle/phi/api/yaml/op_version.yaml | 8 + paddle/phi/api/yaml/ops.yaml | 13 ++ paddle/phi/api/yaml/static_ops.yaml | 23 +++ paddle/phi/ops/compat/range_sig.cc | 17 --- .../phi/ops/compat/update_loss_scaling_sig.cc | 47 ------ 13 files changed, 101 insertions(+), 385 deletions(-) delete mode 100644 paddle/fluid/operators/amp/update_loss_scaling_op.cc delete mode 100644 paddle/fluid/operators/linspace_op.cc delete mode 100644 paddle/fluid/operators/range_op.cc delete mode 100644 paddle/phi/ops/compat/range_sig.cc delete mode 100644 paddle/phi/ops/compat/update_loss_scaling_sig.cc diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc deleted file mode 100644 index 7f9b7da62f4..00000000000 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/multiary.h" - -namespace paddle { -namespace operators { - -class UpdateLossScalingOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto dtype = framework::proto::VarType::FP32; - if (ctx.MultiInputVar("X").size() >= 1) { - dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - } - - return phi::KernelKey(dtype, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { -#ifndef PADDLE_WITH_XPU - if (var_name == "FoundInfinite" || var_name == "StopUpdate") { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } -#endif - return framework::OperatorWithKernel::GetKernelTypeForVar( - var_name, tensor, expected_kernel_type); - } -}; - -class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensors) The input tensors of update_loss_scaling operator.") - .AsDuplicable(); - AddInput("FoundInfinite", - "(Tensor) 1-dim tensor, contains a bool scalar, which indicates " - "whether there is any infinite gradient."); - AddInput("PrevLossScaling", - "(Tensor) 1-dim tensor, previous loss scaling."); - AddInput("InGoodSteps", - "(Tensor) 1-dim tensor, accumulates good steps in which all " - "gradients are finite."); - AddInput("InBadSteps", - "(Tensor) 1-dim tensor, accumulates bad steps in which some " - "gradients are infinite."); - AddOutput("Out", - "(Tensors) The output tensor of update_loss_scaling operator.") - .AsDuplicable(); - AddOutput("LossScaling", "(Tensor) 1-dim tensor, updated loss scaling."); - AddOutput("OutGoodSteps", "(Tensor) 1-dim tensor, pdated good steps."); - AddOutput("OutBadSteps", "(Tensor) 1-dim tensor, updated bad steps."); - AddInput("StopUpdate", - "(Tensor) 1-dim tensor. Stop updating loss scaling, and just " - "zero inputs. It has higher priority than Attr(stop_update).") - .AsDispensable(); - AddAttr("incr_every_n_steps", - "A value represents increasing loss scaling every n " - "consecutive steps with finite gradients."); - AddAttr("decr_every_n_nan_or_inf", - "A value represents decreasing loss scaling every n " - "accumulated steps with nan or inf gradients."); - AddAttr("incr_ratio", - "The multiplier to use when increasing the loss scaling.") - .AddCustomChecker([](float incr_ratio) { - PADDLE_ENFORCE_EQ(incr_ratio > 1.0f, - true, - platform::errors::InvalidArgument( - "'incr_ratio' should be greater than 1, but " - "the received is %f", - incr_ratio)); - }); - AddAttr( - "decr_ratio", - "The less-than-one-multiplier to use when decreasing loss scaling.") - .AddCustomChecker([](float decr_ratio) { - PADDLE_ENFORCE_EQ(decr_ratio > 0.0f && decr_ratio < 1.0f, - true, - platform::errors::InvalidArgument( - "'decr_ratio' should be between 0 and 1, but " - "the received is %f", - decr_ratio)); - }); - AddAttr("stop_update", - "Stop updating loss scaling, and just zero inputs.") - .SetDefault(false); - AddComment(R"DOC( -Update loss scaling according to overall gradients. If all gradients is -finite after incr_every_n_steps, loss scaling will increase by incr_ratio. -Otherwise, loss scaling will decrease by decr_ratio after -decr_every_n_nan_or_inf steps and each step some gradients are infinite. - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CPU = phi::CPUContext; - -DECLARE_INFER_SHAPE_FUNCTOR(update_loss_scaling, - UpdateLossScalingInferShapeFunctor, - PD_INFER_META(phi::UpdateLossScalingInferMeta)); -REGISTER_OPERATOR( - update_loss_scaling, - ops::UpdateLossScalingOp, - ops::UpdateLossScalingOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - UpdateLossScalingInferShapeFunctor); diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index ac788f5f931..6f1ea3d8b3c 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -479,6 +479,15 @@ def parse_get_expected_kerneltype( for op_comp_map in op_fluid_list: if 'get_expected_kernel_type' in op_comp_map: fw_name = op_comp_map['op'].split('(')[0].strip() + # deal the last underline of function name in op_comp_map['get_expected_kernel_type'] + new_get_expected_kernel_type_func_map = {} + for (key, value) in op_comp_map['get_expected_kernel_type'].items(): + new_get_expected_kernel_type_func_map[ + delete_last_underline(key) + ] = value + op_comp_map[ + 'get_expected_kernel_type' + ] = new_get_expected_kernel_type_func_map if fw_name in op_comp_map['get_expected_kernel_type']: # static_ops.yaml and ops.yaml use the common op_compat.yaml if fw_name in fw_op_dict: @@ -507,10 +516,15 @@ def parse_keep_signature( for op_comp_map in op_fluid_list: if 'manual_signature' in op_comp_map: for op_name in op_comp_map['manual_signature']: - if op_name in fw_op_dict: - fw_op_dict[op_name]["manual_signature"] = True - elif op_name in bw_op_dict: - bw_op_dict[op_name]["manual_signature"] = True + op_name_without_last_underline = delete_last_underline(op_name) + if op_name_without_last_underline in fw_op_dict: + fw_op_dict[op_name_without_last_underline][ + "manual_signature" + ] = True + elif op_name_without_last_underline in bw_op_dict: + bw_op_dict[op_name_without_last_underline][ + "manual_signature" + ] = True def split_ops_list(ops, backward_op_dict, split_num): diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc index 79a17940280..5c0e53e2437 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc @@ -141,5 +141,15 @@ phi::KernelKey GetSgdExpectedKernelType( return phi::KernelKey(data_type, ctx.GetPlace()); } +phi::KernelKey GetUpdateLossScalingExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr) { + auto dtype = framework::proto::VarType::FP32; + if (ctx.MultiInputVar("X").size() >= 1) { + dtype = op_ptr->IndicateVarDataType(ctx, "X"); + } + return phi::KernelKey(dtype, ctx.GetPlace()); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h index f360c0d4b08..9b5be6feac3 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.h +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h @@ -36,5 +36,9 @@ phi::KernelKey GetSgdExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr); +phi::KernelKey GetUpdateLossScalingExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr); + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc deleted file mode 100644 index e3fade6d612..00000000000 --- a/paddle/fluid/operators/linspace_op.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/ternary.h" - -namespace paddle { -namespace operators { - -class LinspaceOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return phi::KernelKey( - framework::proto::VarType::Type(ctx.Attr("dtype")), - ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const override { - if (platform::is_xpu_place(tensor.place())) { - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } -}; - -class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Start", - "First entry in the sequence. It is a tensor of shape [1], should " - "be of type float32 or float64."); - AddInput("Stop", - "Last entry in the sequence. It is a tensor of shape [1], should " - "be of type float32 or float64."); - AddInput("Num", - "Number of entry in the sequence. It is a tensor of shape [1], " - "should be of type int32."); - AddAttr("dtype", "The output data type."); - AddOutput("Out", "A sequence of numbers."); - AddComment(R"DOC( - Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(linspace, - LinspaceInferShapeFunctor, - PD_INFER_META(phi::LinspaceRawInferMeta)); -REGISTER_OPERATOR( - linspace, - ops::LinspaceOp, - ops::LinspaceOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - LinspaceInferShapeFunctor); - -REGISTER_OP_VERSION(linspace).AddCheckpoint( - R"ROC( - Upgrade linspace to add a new attribute [dtype]. - )ROC", - paddle::framework::compatible::OpVersionDesc().NewAttr( - "dtype", "In order to change output data type ", 5)); diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc deleted file mode 100644 index 08706bc7052..00000000000 --- a/paddle/fluid/operators/range_op.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/range_op.h" - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/ternary.h" - -namespace paddle { -namespace operators { - -class RangeOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetKernelTypeForVar( - const std::string &var_name, - const phi::DenseTensor &tensor, - const phi::KernelKey &expected_kernel_type) const override { - if (platform::is_xpu_place(tensor.place())) { - return phi::KernelKey( - tensor.place(), tensor.layout(), expected_kernel_type.dtype()); - } - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } -}; - -class RangeOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Start", - "Start of interval. The interval includes this value. It is a " - "tensor with shape=[1]."); - AddInput("End", - "End of interval. The interval does not include this value, " - "except in some cases where step is not an integer and floating " - "point round-off affects the length of out. It is a tensor with " - "shape=[1]."); - AddInput("Step", "Spacing between values. It is a tensor with shape=[1]."); - AddOutput("Out", "A sequence of numbers."); - AddComment(R"DOC( - Return evenly spaced values within a given interval. Values are generated within the half-open interval [start, stop) (in other words, the interval including start but excluding stop). Like arange function of numpy. -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(range, - RangeInferMetaFunctor, - PD_INFER_META(phi::ArangeInferMeta)); -REGISTER_OP_WITHOUT_GRADIENT(range, - ops::RangeOp, - ops::RangeOpMaker, - RangeInferMetaFunctor); diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 69485286601..abc364f9cc5 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1658,17 +1658,6 @@ data_type: x backward: unpool3d_grad -- op : update_loss_scaling_ - args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update) - output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps) - infer_meta : - func : UpdateLossScalingInferMeta - param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps] - kernel : - func : update_loss_scaling - data_type : x - inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps) - - op : warpctc args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) output : Tensor(loss), Tensor(warpctcgrad) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 8a521138665..dc48ead7f57 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -97,6 +97,12 @@ extra : attrs : [bool use_mkldnn = false] +- op : arange(range) + inputs : + {start : Start, end : End, step : Step} + outputs : + out : Out + - op : argsort inputs : x : X @@ -1081,6 +1087,12 @@ extra : attrs : [bool use_mkldnn = false] +- op : linspace + inputs : + {start : Start, stop : Stop, number : Num} + outputs : + out : Out + - op : log backward : log_grad, log_double_grad (log_grad_grad) inputs : @@ -1652,7 +1664,7 @@ outputs : {param_out : ParamOut, master_param_out : MasterParamOut} get_expected_kernel_type : - sgd : GetSgdExpectedKernelType #"sgd_" becomes "sgd" + sgd_ : GetSgdExpectedKernelType extra : attrs : [bool use_mkldnn=false] @@ -1993,6 +2005,18 @@ outputs : out : Y +- op : update_loss_scaling_ + inputs : + {x : X, found_infinite : FoundInfinite, prev_loss_scaling : PrevLossScaling, in_good_steps : InGoodSteps, in_bad_steps : InBadSteps} + outputs : + {out : Out, loss_scaling : LossScaling, out_good_steps : OutGoodSteps, out_bad_steps : OutBadSteps} + scalar : + stop_update : + data_type : bool + tensor_name : StopUpdate + get_expected_kernel_type : + update_loss_scaling_ : GetUpdateLossScalingExpectedKernelType + - op : viterbi_decode inputs : {potentials : Input, transition_params : Transition, lengths : Length} diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 83860ce5c8f..df5b56edb3e 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -103,6 +103,14 @@ comment : In order to force fill output variable to gpu memory. default : "false" +- op : linspace + version : + - checkpoint : Upgrade linspace to add a new attribute [dtype] + action : + - add_attr : dtype + comment : In order to change output data type + default : 5 + - op : not_equal version : - checkpoint : Upgrade compare ops, add a new attribute [force_cpu] diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 30d9bbe716c..37dce69cb58 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1663,6 +1663,19 @@ func : unstack backward : unstack_grad +- op : update_loss_scaling_ + args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update=false) + output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps) + infer_meta : + func : UpdateLossScalingInferMeta + param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps] + kernel : + func : update_loss_scaling + data_type : x + data_transform : + skip_transform : found_infinite + inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps) + - op : viterbi_decode args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag = true) output : Tensor(scores), Tensor(path) diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 350b5c2848d..a2ecdc8790a 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -26,6 +26,16 @@ func : all_reduce param: [x, reduce_type] +- op : arange + args : (Tensor start, Tensor end, Tensor step) + output : Tensor(out) + infer_meta : + func : ArangeInferMeta + kernel : + func : arange + data_transform : + skip_transform : start, end, step + - op : assign args : (Tensor x) output : Tensor @@ -117,6 +127,19 @@ backend : x force_backend : force_cpu +- op : linspace + args : (Tensor start, Tensor stop, Tensor number, DataType dtype) + output : Tensor(out) + infer_meta : + func : LinspaceInferMeta + param: [start, stop, number, dtype] + kernel : + func : linspace + param: [start, stop, number, dtype] + data_type : dtype + data_transform : + skip_transform : start, stop, number + - op : not_equal args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false) output : Tensor(out) diff --git a/paddle/phi/ops/compat/range_sig.cc b/paddle/phi/ops/compat/range_sig.cc deleted file mode 100644 index d48898bd848..00000000000 --- a/paddle/phi/ops/compat/range_sig.cc +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -PD_REGISTER_BASE_KERNEL_NAME(range, arange); diff --git a/paddle/phi/ops/compat/update_loss_scaling_sig.cc b/paddle/phi/ops/compat/update_loss_scaling_sig.cc deleted file mode 100644 index 8223d0c7dfd..00000000000 --- a/paddle/phi/ops/compat/update_loss_scaling_sig.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature UpdateLossScalingOpArgumentMapping( - const ArgumentMappingContext& ctx) { - if (ctx.HasInput("StopUpdate")) { - return KernelSignature( - "update_loss_scaling", - {"X", "FoundInfinite", "PrevLossScaling", "InGoodSteps", "InBadSteps"}, - {"incr_every_n_steps", - "decr_every_n_nan_or_inf", - "incr_ratio", - "decr_ratio", - "StopUpdate"}, - {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}); - } else { - return KernelSignature( - "update_loss_scaling", - {"X", "FoundInfinite", "PrevLossScaling", "InGoodSteps", "InBadSteps"}, - {"incr_every_n_steps", - "decr_every_n_nan_or_inf", - "incr_ratio", - "decr_ratio", - "stop_update"}, - {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}); - } -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(update_loss_scaling, - phi::UpdateLossScalingOpArgumentMapping); -- GitLab