/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mish_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Forward Mish, element-wise: out[i] = x * tanh(softplus(x)) with x = in[i].
// Launch layout: 1-D grid and 1-D blocks of any size. The grid-stride loop
// makes every thread walk the range [0, numel), so correctness does not
// depend on the exact grid size. `threshold` is forwarded unchanged to
// CalcSoftplus (declared in mish_op.h). Indices are 64-bit so that tensors
// with more than INT_MAX elements are neither truncated nor overflow the
// loop counter.
template <typename T>
__global__ void KeMishFw(const T* in, T* out, const int64_t numel,
                         const float threshold) {
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; tid < numel; tid += stride) {
    T x = in[tid];
    T sp = CalcSoftplus(x, threshold);
    out[tid] = x * tanh(sp);
  }
}

// Float specialization of KeMishFw. It uses the single-precision math
// functions (tanhf, and CalcSoftplusFP32 from mish_op.h) instead of the
// generic double-precision ones. For that reason the float kernel is written
// and registered separately.
__global__ void KeMishFwFP32(const float* in, float* out, const int64_t numel,
                             const float threshold) {
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; tid < numel; tid += stride) {
    float x = in[tid];
    float sp = CalcSoftplusFP32(x, threshold);
    out[tid] = x * tanhf(sp);
  }
}

// Backward Mish, element-wise: din[i] = dout[i] * d(mish)/dx at x = in[i].
// With sp = softplus(x) and tsp = tanh(sp):
//   d(mish)/dx = tsp + x * (1 - tsp^2) * d(sp)/dx
// d(sp)/dx is evaluated as 1 - exp(-sp), via -expm1(-sp) for accuracy near 0.
// This equals sigmoid(x) when sp = log(1 + exp(x)).
// Launch layout and indexing are the same as KeMishFw (grid-stride, 64-bit).
template <typename T>
__global__ void KeMishBw(const T* in, const T* dout, T* din,
                         const int64_t numel, const float threshold) {
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; tid < numel; tid += stride) {
    T x = in[tid];
    T sp = CalcSoftplus(x, threshold);
    T tsp = tanh(sp);
    T grad_sp = -expm1(-sp);
    T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
    din[tid] = dout[tid] * (x * grad_tsp + tsp);
  }
}

// Float specialization of KeMishBw. It uses the single-precision math
// functions (tanhf, expm1f, CalcSoftplusFP32).
__global__ void KeMishBwFP32(const float* in, const float* dout, float* din,
                             const int64_t numel, const float threshold) {
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; tid < numel; tid += stride) {
    float x = in[tid];
    float sp = CalcSoftplusFP32(x, threshold);
    float tsp = tanhf(sp);
    float grad_sp = -expm1f(-sp);
    float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
    din[tid] = dout[tid] * (x * grad_tsp + tsp);
  }
}

// Host-side forward op kernel for the generic element type T.
// Reads input "X" and attribute "threshold", allocates "Out" on the op's
// place, and launches KeMishFw on the CUDA device context's stream.
template <typename DeviceContext, typename T>
class MishCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    const int64_t numel = x->numel();
    // Empty input: the (empty) output is already allocated, and a launch
    // with zero blocks is not a valid configuration, so stop here.
    if (numel == 0) return;

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
    KeMishFw<T><<<config.block_per_grid, config.thread_per_block, 0,
                  ctx.cuda_device_context().stream()>>>(x_data, out_data,
                                                        numel, threshold);
  }
};

// Host-side forward op kernel for float. Same contract as MishCUDAKernel,
// but launches the single-precision KeMishFwFP32.
template <typename DeviceContext>
class MishFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    float* out_data = out->mutable_data<float>(ctx.GetPlace());

    const int64_t numel = x->numel();
    // Empty input: nothing to launch (a zero-block grid is invalid).
    if (numel == 0) return;

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
    KeMishFwFP32<<<config.block_per_grid, config.thread_per_block, 0,
                   ctx.cuda_device_context().stream()>>>(x_data, out_data,
                                                         numel, threshold);
  }
};

// Host-side backward op kernel for the generic element type T.
// Reads "X", the gradient of "Out" and attribute "threshold". Allocates the
// gradient of "X" and launches KeMishBw on the CUDA device context's stream.
template <typename DeviceContext, typename T>
class MishGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    const float threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    const T* dout_data = dout->data<T>();
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

    const int64_t numel = x->numel();
    // Empty input: the (empty) gradient is already allocated; a launch with
    // zero blocks is not a valid configuration.
    if (numel == 0) return;

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
    KeMishBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                  ctx.cuda_device_context().stream()>>>(
        x_data, dout_data, dx_data, numel, threshold);
  }
};

// Host-side backward op kernel for float. Same contract as
// MishGradCUDAKernel, but launches the single-precision KeMishBwFP32.
template <typename DeviceContext>
class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    const float threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    const float* dout_data = dout->data<float>();
    float* dx_data = dx->mutable_data<float>(ctx.GetPlace());

    const int64_t numel = x->numel();
    // Empty input: nothing to launch (a zero-block grid is invalid).
    if (numel == 0) return;

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
    KeMishBwFP32<<<config.block_per_grid, config.thread_per_block, 0,
                   ctx.cuda_device_context().stream()>>>(
        x_data, dout_data, dx_data, numel, threshold);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// float is served by the dedicated FP32 kernels; double uses the generic
// templated kernels.
REGISTER_OP_CUDA_KERNEL(
    mish, ops::MishFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
    ops::MishCUDAKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
    mish_grad,
    ops::MishGradFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
    ops::MishGradCUDAKernel<paddle::platform::CUDADeviceContext, double>)