/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/label_smooth_op.h"

namespace paddle {
namespace operators {

// Element-wise forward functor used when no prior distribution is supplied:
//   out = (1 - epsilon) * x + epsilon / label_dim
// i.e. a uniform prior of 1 / label_dim mixed in with weight epsilon.
template <typename T>
struct LabelSmoothFunctor {
  T epsilon;
  T label_dim;

  __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data)
      : epsilon(static_cast<T>(epsilon_data)),
        label_dim(static_cast<T>(label_dim_data)) {}

  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(1 - epsilon) * x +
           static_cast<T>(epsilon / label_dim);
  }
};

// Element-wise backward functor: d_x = (1 - epsilon) * d_out.
// The added prior term does not depend on x, so it contributes no gradient.
template <typename T>
struct LabelSmoothGradFunctor {
  T epsilon;

  __forceinline__ explicit LabelSmoothGradFunctor(float epsilon_data)
      : epsilon(static_cast<T>(epsilon_data)) {}

  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(1 - epsilon) * x;
  }
};

// Mixes the input with an explicit prior distribution:
//   dst[i] = (1 - epsilon) * src[i] + epsilon * dist_data[i % dist_numel]
// Launch: 1-D grid, no shared memory. CUDA_KERNEL_LOOP (a Paddle macro defined
// elsewhere) drives idx over [0, N).
// Preconditions:
//   - dist_numel > 0, since the modulo below divides by it.
//   - N < 2^31, since indices are int.
//   - presumably dist_numel == label_dim, so the prior repeats once per row;
//     TODO confirm this is enforced by the op's InferShape.
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
                                         const int dist_numel, const T* src,
                                         const T* dist_data, T* dst) {
  CUDA_KERNEL_LOOP(idx, N) {
    int dist_idx = idx % dist_numel;
    dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
               static_cast<T>(epsilon) * dist_data[dist_idx];
  }
}

// Forward GPU kernel for label_smooth.
// Inputs: "X", optional "PriorDist". Attr: "epsilon". Output: "Out".
// With PriorDist, launches LabelSmoothRunDistKernel on the device context's
// stream. Without it, applies LabelSmoothFunctor via the shared same-dims
// element-wise launcher.
template <typename DeviceContext, typename T>
class LabelSmoothGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_t = ctx.Output<framework::LoDTensor>("Out");
    auto* in_t = ctx.Input<framework::LoDTensor>("X");
    auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
    auto epsilon = ctx.Attr<float>("epsilon");
    T* out_data = out_t->mutable_data<T>(ctx.GetPlace());

    auto size_prob = in_t->numel();
    // Nothing to smooth for an empty input. Returning here also avoids a
    // 0-block kernel launch, which CUDA rejects as an invalid configuration.
    if (size_prob == 0) return;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    if (dist_t) {
      const T* in_data = in_t->data<T>();
      const T* dist_data = dist_t->data<T>();
      auto dist_numel = dist_t->numel();
      constexpr int kThreads = 512;
      int grid = static_cast<int>((size_prob + kThreads - 1) / kThreads);
      // The kernel indexes with int, so this narrowing assumes fewer than
      // 2^31 elements (the original code narrowed implicitly as well).
      LabelSmoothRunDistKernel<T><<<grid, kThreads, 0, dev_ctx.stream()>>>(
          static_cast<int>(size_prob), epsilon, static_cast<int>(dist_numel),
          in_data, dist_data, out_data);
    } else {
      std::vector<const framework::Tensor*> ins = {in_t};
      std::vector<framework::Tensor*> outs = {out_t};
      auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
          dev_ctx, ins, &outs, functor);
    }
  }
};

// Backward GPU kernel for label_smooth: X@GRAD = (1 - epsilon) * Out@GRAD.
// PriorDist does not affect the gradient with respect to X.
template <typename DeviceContext, typename T>
class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    d_in_t->mutable_data<T>(ctx.GetPlace());
    // Nothing to compute for an empty gradient; skip the launch entirely.
    if (d_in_t->numel() == 0) return;

    auto epsilon = ctx.Attr<float>("epsilon");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    std::vector<const framework::Tensor*> ins = {d_out_t};
    std::vector<framework::Tensor*> outs = {d_in_t};
    auto functor = LabelSmoothGradFunctor<T>(epsilon);
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
        dev_ctx, ins, &outs, functor);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    label_smooth,
    ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    label_smooth_grad,
    ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, double>);