From b7bcd0f643b90e87da749251011e364e3681e5d7 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 23 Feb 2022 11:14:23 +0800 Subject: [PATCH] [Phi] Migrate lable_smooth_op into Phi (#39796) * [Phi] Migrate lable_smooth_op into Phi * fix PT->PD --- paddle/fluid/framework/operator.cc | 2 +- paddle/fluid/operators/label_smooth_op.cc | 11 +- paddle/fluid/operators/label_smooth_op.cu | 125 ------------------ paddle/fluid/operators/label_smooth_op.h | 70 ---------- paddle/fluid/operators/label_smooth_op_npu.cc | 2 +- paddle/fluid/operators/label_smooth_op_xpu.cc | 1 - .../kernels/cpu/label_smooth_grad_kernel.cc | 45 +++++++ paddle/phi/kernels/cpu/label_smooth_kernel.cc | 50 +++++++ .../kernels/gpu/label_smooth_grad_kernel.cu | 55 ++++++++ paddle/phi/kernels/gpu/label_smooth_kernel.cu | 86 ++++++++++++ paddle/phi/kernels/label_smooth_grad_kernel.h | 28 ++++ paddle/phi/kernels/label_smooth_kernel.h | 30 +++++ paddle/phi/ops/compat/label_smooth_sig.cc | 37 ++++++ 13 files changed, 334 insertions(+), 208 deletions(-) delete mode 100644 paddle/fluid/operators/label_smooth_op.cu delete mode 100644 paddle/fluid/operators/label_smooth_op.h create mode 100644 paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/label_smooth_kernel.cc create mode 100644 paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/label_smooth_kernel.cu create mode 100644 paddle/phi/kernels/label_smooth_grad_kernel.h create mode 100644 paddle/phi/kernels/label_smooth_kernel.h create mode 100644 paddle/phi/ops/compat/label_smooth_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e589f059f5..701fc7de69 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2040,7 +2040,7 @@ void OperatorWithKernel::BuildPtenKernelContext( (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second); // deal with optional here - if ((it == ctx.inputs.end()) && + if ((it == ctx.inputs.end() || it->second.size() == 0) && (input_defs[i].type_index == std::type_index(typeid(paddle::optional)))) { pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr); diff --git a/paddle/fluid/operators/label_smooth_op.cc b/paddle/fluid/operators/label_smooth_op.cc index 5ae9fd7a61..7e07610db2 100644 --- a/paddle/fluid/operators/label_smooth_op.cc +++ b/paddle/fluid/operators/label_smooth_op.cc @@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/label_smooth_op.h" - #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace framework { @@ -152,11 +151,3 @@ REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker, ops::LabelSmoothGradMaker, ops::LabelSmoothGradMaker); REGISTER_OPERATOR(label_smooth_grad, ops::LabelSmoothGradOp); -REGISTER_OP_CPU_KERNEL( - label_smooth, - ops::LabelSmoothKernel, - ops::LabelSmoothKernel); -REGISTER_OP_CPU_KERNEL( - label_smooth_grad, - ops::LabelSmoothGradKernel, - ops::LabelSmoothGradKernel); diff --git a/paddle/fluid/operators/label_smooth_op.cu b/paddle/fluid/operators/label_smooth_op.cu deleted file mode 100644 index f149e104ef..0000000000 --- a/paddle/fluid/operators/label_smooth_op.cu +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/label_smooth_op.h" -namespace paddle { -namespace operators { - -template -struct LabelSmoothFunctor { - T epsilon; - T label_dim; - - __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) { - epsilon = static_cast(epsilon_data); - label_dim = static_cast(label_dim_data); - } - - __device__ __forceinline__ T operator()(const T x) const { - return (static_cast(1 - epsilon) * x + - static_cast(epsilon / label_dim)); - } -}; - -template -struct LabelSmoothGradFunctor { - T epsilon; - - __forceinline__ LabelSmoothGradFunctor(float epsilon_data) { - epsilon = static_cast(epsilon_data); - } - - __device__ __forceinline__ T operator()(const T x) const { - return static_cast(1 - epsilon) * x; - } -}; - -template -__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon, - const int dist_numel, const T* src, - const T* dist_data, T* dst) { - CUDA_KERNEL_LOOP(idx, N) { - int dist_idx = idx % dist_numel; - dst[idx] = static_cast(1 - epsilon) * src[idx] + - static_cast(epsilon) * dist_data[dist_idx]; - } -} - -template -class LabelSmoothGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_t = ctx.Output("Out"); - auto* in_t = ctx.Input("X"); - auto* dist_t = ctx.Input("PriorDist"); - auto label_dim = in_t->dims()[in_t->dims().size() - 1]; - auto epsilon = ctx.Attr("epsilon"); - auto& dev = *ctx.template device_context().eigen_device(); - auto size_prob = in_t->numel(); - const T* in_data = in_t->data(); - T* out_data = out_t->mutable_data(ctx.GetPlace()); - int threads = 512; - int grid = (size_prob + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - if (dist_t) { - auto dist_numel = dist_t->numel(); - const T* dist_data = dist_t->data(); - LabelSmoothRunDistKernel<<>>( - size_prob, epsilon, dist_numel, in_data, dist_data, out_data); - - } else { - auto& dev_ctx = - ctx.template device_context(); - - std::vector ins = {in_t}; - std::vector outs = {out_t}; - auto functor = LabelSmoothFunctor(epsilon, label_dim); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } - } -}; - -template -class LabelSmoothGradGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* d_out_t = ctx.Input(framework::GradVarName("Out")); - auto* d_in_t = ctx.Output(framework::GradVarName("X")); - d_in_t->mutable_data(ctx.GetPlace()); - - auto epsilon = ctx.Attr("epsilon"); - auto& dev_ctx = ctx.template device_context(); - - std::vector ins = {d_out_t}; - std::vector outs = {d_in_t}; - auto functor = LabelSmoothGradFunctor(epsilon); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } -}; -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - label_smooth, - ops::LabelSmoothGPUKernel, - ops::LabelSmoothGPUKernel); -REGISTER_OP_CUDA_KERNEL( - label_smooth_grad, - ops::LabelSmoothGradGPUKernel, - ops::LabelSmoothGradGPUKernel); diff --git a/paddle/fluid/operators/label_smooth_op.h b/paddle/fluid/operators/label_smooth_op.h deleted file mode 100644 index 6b509eb64c..0000000000 --- a/paddle/fluid/operators/label_smooth_op.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class LabelSmoothKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_t = ctx.Output("Out"); - auto* in_t = ctx.Input("X"); - auto* dist_t = ctx.Input("PriorDist"); - auto label_dim = in_t->dims()[in_t->dims().size() - 1]; - out_t->mutable_data(ctx.GetPlace()); - if (label_dim != 0) { - auto epsilon = ctx.Attr("epsilon"); - auto out = framework::EigenVector::Flatten(*out_t); - auto in = framework::EigenVector::Flatten(*in_t); - auto& dev = *ctx.template device_context().eigen_device(); - if (dist_t) { - auto dist = framework::EigenVector::Flatten(*dist_t); - out.device(dev) = static_cast(1 - epsilon) * in + - static_cast(epsilon) * - dist.broadcast(Eigen::DSizes( - in_t->numel() / label_dim)); - } else { - out.device(dev) = static_cast(1 - epsilon) * in + - static_cast(epsilon / label_dim); - } - } - } -}; - -template -class LabelSmoothGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* d_out_t = ctx.Input(framework::GradVarName("Out")); - auto* d_in_t = ctx.Output(framework::GradVarName("X")); - d_in_t->mutable_data(ctx.GetPlace()); - auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1]; - if (d_out_dim != 0) { - auto d_out = framework::EigenVector::Flatten(*d_out_t); - auto d_in = framework::EigenVector::Flatten(*d_in_t); - - auto epsilon = ctx.Attr("epsilon"); - auto& dev = *ctx.template device_context().eigen_device(); - d_in.device(dev) = static_cast(1 - epsilon) * d_out; - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/label_smooth_op_npu.cc b/paddle/fluid/operators/label_smooth_op_npu.cc index af519cc909..c24b896e0a 100644 --- a/paddle/fluid/operators/label_smooth_op_npu.cc +++ b/paddle/fluid/operators/label_smooth_op_npu.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/label_smooth_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/label_smooth_op_xpu.cc b/paddle/fluid/operators/label_smooth_op_xpu.cc index 6b63507539..dd8d0c721c 100644 --- a/paddle/fluid/operators/label_smooth_op_xpu.cc +++ b/paddle/fluid/operators/label_smooth_op_xpu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/label_smooth_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { diff --git a/paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc b/paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc new file mode 100644 index 0000000000..74664fb270 --- /dev/null +++ b/paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/label_smooth_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void LabelSmoothGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float epsilon, + DenseTensor* label_grad) { + ctx.template Alloc(label_grad); + auto d_out_dim = out_grad.dims()[out_grad.dims().size() - 1]; + if (d_out_dim != 0) { + auto d_out = EigenVector::Flatten(out_grad); + auto d_in = EigenVector::Flatten(*label_grad); + + auto& dev = *ctx.eigen_device(); + d_in.device(dev) = static_cast(1 - epsilon) * d_out; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(label_smooth_grad, + CPU, + ALL_LAYOUT, + phi::LabelSmoothGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/label_smooth_kernel.cc b/paddle/phi/kernels/cpu/label_smooth_kernel.cc new file mode 100644 index 0000000000..c76fb826cd --- /dev/null +++ b/paddle/phi/kernels/cpu/label_smooth_kernel.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/label_smooth_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void LabelSmoothKernel(const Context& ctx, + const DenseTensor& label, + paddle::optional prior_dist, + float epsilon, + DenseTensor* out) { + auto label_dim = label.dims()[label.dims().size() - 1]; + ctx.template Alloc(out); + auto& dev = *ctx.eigen_device(); + if (label_dim != 0) { + auto eigen_out = EigenVector::Flatten(*out); + auto eigen_in = EigenVector::Flatten(label); + if (prior_dist.is_initialized()) { + auto dist = EigenVector::Flatten(*prior_dist.get_ptr()); + eigen_out.device(dev) = + static_cast(1 - epsilon) * eigen_in + + static_cast(epsilon) * + dist.broadcast(Eigen::DSizes(label.numel() / label_dim)); + } else { + eigen_out.device(dev) = static_cast(1 - epsilon) * eigen_in + + static_cast(epsilon / label_dim); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + label_smooth, CPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu new file mode 100644 index 0000000000..f30e8c3cdc --- /dev/null +++ b/paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/label_smooth_grad_kernel.h" + +namespace phi { +template +struct LabelSmoothGradFunctor { + T epsilon; + + __forceinline__ LabelSmoothGradFunctor(float epsilon_data) { + epsilon = static_cast(epsilon_data); + } + + __device__ __forceinline__ T operator()(const T x) const { + return static_cast(1 - epsilon) * x; + } +}; + +template +void LabelSmoothGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float epsilon, + DenseTensor* label_grad) { + ctx.template Alloc(label_grad); + + std::vector ins = {&out_grad}; + std::vector outs = {label_grad}; + auto functor = LabelSmoothGradFunctor(epsilon); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + ctx, ins, &outs, functor); +} + +} // namespace phi + +PD_REGISTER_KERNEL(label_smooth_grad, + GPU, + ALL_LAYOUT, + phi::LabelSmoothGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/label_smooth_kernel.cu b/paddle/phi/kernels/gpu/label_smooth_kernel.cu new file mode 100644 index 0000000000..50f7548450 --- /dev/null +++ b/paddle/phi/kernels/gpu/label_smooth_kernel.cu @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/label_smooth_kernel.h" + +namespace phi { + +template +struct LabelSmoothFunctor { + T epsilon; + T label_dim; + + __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) { + epsilon = static_cast(epsilon_data); + label_dim = static_cast(label_dim_data); + } + + __device__ __forceinline__ T operator()(const T x) const { + return (static_cast(1 - epsilon) * x + + static_cast(epsilon / label_dim)); + } +}; + +template +__global__ void LabelSmoothRunDistKernel(const int N, + const float epsilon, + const int dist_numel, + const T* src, + const T* dist_data, + T* dst) { + CUDA_KERNEL_LOOP(idx, N) { + int dist_idx = idx % dist_numel; + dst[idx] = static_cast(1 - epsilon) * src[idx] + + static_cast(epsilon) * dist_data[dist_idx]; + } +} + +template +void LabelSmoothKernel(const Context& ctx, + const DenseTensor& label, + paddle::optional prior_dist, + float epsilon, + DenseTensor* out) { + auto label_dim = label.dims()[label.dims().size() - 1]; + auto size_prob = label.numel(); + const T* in_data = label.data(); + T* out_data = ctx.template Alloc(out); + + if (prior_dist.get_ptr()) { + int threads = 512; + int grid = (size_prob + threads - 1) / threads; + auto stream = ctx.stream(); + const auto* dist_t = prior_dist.get_ptr(); + auto dist_numel = dist_t->numel(); + const T* dist_data = dist_t->data(); + LabelSmoothRunDistKernel<<>>( + size_prob, epsilon, dist_numel, in_data, dist_data, out_data); + + } else { + std::vector ins = {&label}; + std::vector outs = {out}; + auto functor = LabelSmoothFunctor(epsilon, label_dim); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + ctx, ins, &outs, functor); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {} diff --git a/paddle/phi/kernels/label_smooth_grad_kernel.h b/paddle/phi/kernels/label_smooth_grad_kernel.h new file mode 100644 index 0000000000..993e967814 --- /dev/null +++ b/paddle/phi/kernels/label_smooth_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void LabelSmoothGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float epsilon, + DenseTensor* label_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/label_smooth_kernel.h b/paddle/phi/kernels/label_smooth_kernel.h new file mode 100644 index 0000000000..b7e1f27088 --- /dev/null +++ b/paddle/phi/kernels/label_smooth_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void LabelSmoothKernel(const Context& ctx, + const DenseTensor& label, + paddle::optional prior_dist, + float epsilon, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/label_smooth_sig.cc b/paddle/phi/ops/compat/label_smooth_sig.cc new file mode 100644 index 0000000000..4fb62a8ca2 --- /dev/null +++ b/paddle/phi/ops/compat/label_smooth_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LabelSmoothOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "label_smooth", {"X", "PriorDist"}, {"epsilon"}, {"Out"}); +} + +KernelSignature LabelSmoothGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("label_smooth_grad", + {GradVarName("Out")}, + {"epsilon"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(label_smooth, phi::LabelSmoothOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(label_smooth_grad, + phi::LabelSmoothGradOpArgumentMapping); -- GitLab