From 4440d7ced0075eb116b5e4e74658049f5968d61a Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 6 Sep 2019 21:59:51 +0800 Subject: [PATCH] test=develop cuda realization of label smooth op (#19175) --- paddle/fluid/operators/label_smooth_op.cu | 94 ++++++++++++++++++++++- 1 file changed, 90 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/label_smooth_op.cu b/paddle/fluid/operators/label_smooth_op.cu index ab259b48e3..89f1d28e99 100644 --- a/paddle/fluid/operators/label_smooth_op.cu +++ b/paddle/fluid/operators/label_smooth_op.cu @@ -12,15 +12,101 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/label_smooth_op.h" +namespace paddle { +namespace operators { +template +__global__ void LabelSmoothRunOriginKernel(const int N, const float epsilon, + const int label_dim, const T* src, + T* dst) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < N; idx += blockDim.x * gridDim.x) { + dst[idx] = static_cast(1 - epsilon) * src[idx] + + static_cast(epsilon / label_dim); + } +} + +template +__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon, + const int dist_numel, const T* src, + const T* dist_data, T* dst) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < N; idx += blockDim.x * gridDim.x) { + int dist_idx = idx - (idx / dist_numel) * dist_numel; + dst[idx] = static_cast(1 - epsilon) * src[idx] + + static_cast(epsilon) * dist_data[dist_idx]; + } +} + +template +__global__ void LabelSmoothGradRunKernel(const int N, const float epsilon, + const T* src, T* dst) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < N; idx += blockDim.x * gridDim.x) { + dst[idx] = static_cast(1 - epsilon) * src[idx]; + } +} + +template +class LabelSmoothGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* in_t = ctx.Input("X"); + auto* dist_t = ctx.Input("PriorDist"); + auto label_dim = in_t->dims()[1]; + auto epsilon = ctx.Attr("epsilon"); + auto& dev = *ctx.template device_context().eigen_device(); + auto size_prob = in_t->numel(); + const T* in_data = in_t->data(); + T* out_data = out_t->mutable_data(ctx.GetPlace()); + int threads = 512; + int grid = (size_prob + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + if (dist_t) { + auto dist_numel = dist_t->numel(); + const T* dist_data = dist_t->data(); + LabelSmoothRunDistKernel<<>>( + size_prob, epsilon, dist_numel, in_data, dist_data, out_data); + + } else { + LabelSmoothRunOriginKernel<<>>( + size_prob, epsilon, label_dim, in_data, out_data); + } + } +}; + +template +class LabelSmoothGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* d_in_t = ctx.Output(framework::GradVarName("X")); + d_in_t->mutable_data(ctx.GetPlace()); + + auto epsilon = ctx.Attr("epsilon"); + auto& dev = *ctx.template device_context().eigen_device(); + const T* in_data = d_out_t->data(); + auto size_prob = d_out_t->numel(); + T* out_data = d_in_t->mutable_data(ctx.GetPlace()); + int threads = 512; + int grid = (size_prob + threads - 1) / threads; + auto stream = ctx.cuda_device_context().stream(); + LabelSmoothGradRunKernel<<>>( + size_prob, epsilon, in_data, out_data); + } +}; +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( label_smooth, - ops::LabelSmoothKernel, - ops::LabelSmoothKernel); + ops::LabelSmoothGPUKernel, + ops::LabelSmoothGPUKernel); REGISTER_OP_CUDA_KERNEL( label_smooth_grad, - ops::LabelSmoothGradKernel, - ops::LabelSmoothGradKernel); + ops::LabelSmoothGradGPUKernel, + ops::LabelSmoothGradGPUKernel); -- GitLab