diff --git a/paddle/fluid/operators/mish_op.cc b/paddle/fluid/operators/mish_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea754b5b1e9413fbd28b351c13fe1da549ccfafb --- /dev/null +++ b/paddle/fluid/operators/mish_op.cc @@ -0,0 +1,121 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/mish_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class MishOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mish"); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class MishOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of Mish operator"); + AddOutput("Out", "Output of Mish operator"); + AddAttr( + "threshold", + "Constant threshold of softplus in Mish operator. Approximate value " + "of softplus will be used if absolute value of input is greater than " + ":attr:`threshold`") + .SetDefault(20.f); + AddComment(R"DOC( +Mish Activation Operator. + +.. math:: + softplus = \begin{cases} + x, \text{if } x > \text{threshold} \\ + e^{x}, \text{if } x < -\text{threshold} \\ + \ln(1 + e^{x}), \text{otherwise} + \end{cases} + + out = x * \tanh(softplus) + +)DOC"); + } +}; + +// The operator to calculate gradients of a prelu operator. +class MishGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "mish"); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +template +class MishGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("mish_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(mish, ops::MishOp, ops::MishOpMaker, + ops::MishGradOpMaker, + ops::MishGradOpMaker); +REGISTER_OPERATOR(mish_grad, ops::MishGradOp); +REGISTER_OP_CPU_KERNEL( + mish, ops::MishFP32CPUKernel, + ops::MishCPUKernel); +REGISTER_OP_CPU_KERNEL( + mish_grad, ops::MishGradFP32CPUKernel, + ops::MishGradCPUKernel); diff --git a/paddle/fluid/operators/mish_op.cu b/paddle/fluid/operators/mish_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..77817e526e13d0618fbbfea313fe1d4c28cd582d --- /dev/null +++ b/paddle/fluid/operators/mish_op.cu @@ -0,0 +1,173 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mish_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void KeMishFw(const T* in, T* out, const int numel, + const float threshold) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < numel; tid += stride) { + T x = in[tid]; + T sp = CalcSoftplus(x, threshold); + out[tid] = x * tanh(sp); + } +} + +// expf instead of exp should be used for float type, complement +// and register float kernel separatelly +__global__ void KeMishFwFP32(const float* in, float* out, const int numel, + const float threshold) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < numel; tid += stride) { + float x = in[tid]; + float sp = CalcSoftplusFP32(x, threshold); + out[tid] = x * tanhf(sp); + } +} + +template +__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel, + const float threshold) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < numel; tid += stride) { + T x = in[tid]; + T sp = CalcSoftplus(x, threshold); + T tsp = tanh(sp); + T grad_sp = -expm1(-sp); + T grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; + din[tid] = dout[tid] * (x * grad_tsp + tsp); + } +} + +__global__ void KeMishBwFP32(const float* in, const float* dout, float* din, + const int numel, const float threshold) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < numel; tid += stride) { + float x = in[tid]; + float sp = CalcSoftplusFP32(x, threshold); + float tsp = tanhf(sp); + float grad_sp = -expm1f(-sp); + float grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; + din[tid] = dout[tid] * (x * grad_tsp + tsp); + } +} + +template +class MishCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + const float threshold = ctx.Attr("threshold"); + + const T* x_data = x->data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + const int numel = x->numel(); + + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); + KeMishFw<<>>(x_data, out_data, numel, + threshold); + } +}; + +template +class MishFP32CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + const float threshold = ctx.Attr("threshold"); + + const float* x_data = x->data(); + float* out_data = out->mutable_data(ctx.GetPlace()); + + const int numel = x->numel(); + + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); + KeMishFwFP32<<>>(x_data, out_data, + numel, threshold); + } +}; + +template +class MishGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto threshold = ctx.Attr("threshold"); + + const T* x_data = x->data(); + const T* dout_data = dout->data(); + T* dx_data = dx->mutable_data(ctx.GetPlace()); + + const int numel = x->numel(); + + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); + KeMishBw<<>>( + x_data, dout_data, dx_data, numel, threshold); + } +}; + +template +class MishGradFP32CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto threshold = ctx.Attr("threshold"); + + const float* x_data = x->data(); + const float* dout_data = dout->data(); + float* dx_data = dx->mutable_data(ctx.GetPlace()); + + const int numel = x->numel(); + + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); + KeMishBwFP32<<>>( + x_data, dout_data, dx_data, numel, threshold); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + mish, ops::MishFP32CUDAKernel, + ops::MishCUDAKernel) +REGISTER_OP_CUDA_KERNEL( + mish_grad, ops::MishGradFP32CUDAKernel, + ops::MishGradCUDAKernel) diff --git a/paddle/fluid/operators/mish_op.h b/paddle/fluid/operators/mish_op.h new file mode 100644 index 0000000000000000000000000000000000000000..86ccb57d929e5dec72fe67185530478109e2d7f0 --- /dev/null +++ b/paddle/fluid/operators/mish_op.h @@ -0,0 +1,137 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +HOSTDEVICE static T CalcSoftplus(T x, float threshold) { + if (threshold > 0 && x > threshold) { + return x; + } else if (threshold > 0 && x < -threshold) { + return exp(x); + } else { + return log1p(exp(x)); + } +} + +// expf instead of exp should be used for float type, complement +// and register float kernel separatelly +HOSTDEVICE static float CalcSoftplusFP32(float x, float threshold) { + if (threshold > 0 && x > threshold) { + return x; + } else if (threshold > 0 && x < -threshold) { + return expf(x); + } else { + return log1pf(expf(x)); + } +} + +template +class MishCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + const float threshold = ctx.Attr("threshold"); + + const T* x_data = x->data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int numel = x->numel(); + for (int i = 0; i < numel; i++) { + T x_d = x_data[i]; + T sp = CalcSoftplus(x_d, threshold); + out_data[i] = x_d * std::tanh(sp); + } + } +}; + +template +class MishFP32CPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + const float threshold = ctx.Attr("threshold"); + + const float* x_data = x->data(); + float* out_data = out->mutable_data(ctx.GetPlace()); + + int numel = x->numel(); + for (int i = 0; i < numel; i++) { + float x_d = x_data[i]; + float sp = CalcSoftplusFP32(x_d, threshold); + out_data[i] = x_d * std::tanh(sp); + } + } +}; + +template +class MishGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto threshold = ctx.Attr("threshold"); + + const T* x_data = x->data(); + const T* dout_data = dout->data(); + T* dx_data = dx->mutable_data(ctx.GetPlace()); + + int numel = x->numel(); + for (int i = 0; i < numel; i++) { + T x_d = x_data[i]; + T sp = CalcSoftplus(x_d, threshold); + T tsp = std::tanh(sp); + T grad_sp = -std::expm1(-sp); + T grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; + dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp); + } + } +}; + +template +class MishGradFP32CPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto threshold = ctx.Attr("threshold"); + + const float* x_data = x->data(); + const float* dout_data = dout->data(); + float* dx_data = dx->mutable_data(ctx.GetPlace()); + + int numel = x->numel(); + for (int i = 0; i < numel; i++) { + float x_d = x_data[i]; + float sp = CalcSoftplusFP32(x_d, threshold); + float tsp = std::tanh(sp); + float grad_sp = -std::expm1f(-sp); + float grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; + dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7233d27e847512e7eaa5682247b3b363159973fc..6ddb6c222d449413b1648e57453a7bd955cb52cb 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -185,6 +185,7 @@ __all__ = [ 'filter_by_instag', 'shard_index', 'hard_swish', + 'mish', 'gather_tree', 'uniform_random', 'unbind', @@ -14782,6 +14783,81 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None): return out +@templatedoc() +def mish(x, threshold=20, name=None): + """ + This operator implements the mish activation function. + Refer to `Mish: A Self Regularized Non-Monotonic Neural + Activation Function `_ + + + The formula is as follows if :attr:`threshold` is :code:`None` or negative: + + .. math:: + + out = x * \\tanh(\\ln(1 + e^{x})) + + The formula is as follows if :attr:`threshold` is set as positive value: + + .. math:: + + out = \\begin{cases} + x \\ast \\tanh(x), \\text{if } x > \\text{threshold} \\\\ + x \\ast \\tanh(e^{x}), \\text{if } x < -\\text{threshold} \\\\ + x \\ast \\tanh(\\ln(1 + e^{x})), \\text{otherwise} + \\end{cases} + + Args: + x (Variable): Input feature, multi-dimensional Tensor. The data type + should be float16, float32 or float64. + threshold (float|None): threshold for softplus in Mish operator. + Approximate value of softplus will be used if absolute value + of input is greater than :attr:threshold and :attr:threshold + is set as positive value. For none or negative threshold, + approximate value is not used. Default 20. + name (str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name` + + Returns: + Variable: The output tensor with the same shape and data type as input. + + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + DATATYPE='float32' + + x_data = np.array([i for i in range(1,5)]).reshape([1,1,4]).astype(DATATYPE) + + x = fluid.data(name="x", shape=[None,1,4], dtype=DATATYPE) + y = fluid.layers.mish(x) + + place = fluid.CPUPlace() + # place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + out, = exe.run(feed={'x':x_data}, fetch_list=[y.name]) + print(out) # [[0.66666667, 1.66666667, 3., 4.]] + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'mish') + check_type(threshold, 'threshold', (float, int), 'mish') + assert threshold > 0, "threshold of mish should be greater than 0, " \ + "but got {}".format(threshold) + + helper = LayerHelper('mish', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='mish', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'threshold': threshold or -1}) + return out + + def gather_tree(ids, parents): """ To be used after beam search. After beam search, we get selected ids at diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index ad091b7e5fee9cd868b2b911b7654d074238b585..d700397cfaf2acb797a7235730dabec79ebe6562 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2709,6 +2709,13 @@ class TestBook(LayerTest): out = layers.softsign(input, name='softsign') return (out) + def make_mish(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") + out = layers.mish(input, name='mish') + return (out) + def make_cross_entropy(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()): diff --git a/python/paddle/fluid/tests/unittests/test_mish_op.py b/python/paddle/fluid/tests/unittests/test_mish_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc785e450f0bac54f2193aac45165bc9b800b73 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_mish_op.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import six +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest, skip_check_grad_ci + + +class TestMishOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program()): + # The input type must be Variable. + self.assertRaises(TypeError, fluid.layers.mish, 0.1, 20) + # The input dtype must be float16, float32, float64. + x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, fluid.layers.mish, x_int32, 20) + # support the input dtype is float32 + x_fp16 = fluid.layers.data( + name='x_fp16', shape=[12, 10], dtype='float32') + fluid.layers.mish(x_fp16, threshold=20) + + +class MishTest(OpTest): + def setUp(self): + self.init_dtype() + self.init_input_shape() + self.init_input_range() + self.init_threshold() + self.op_type = "mish" + + x_np = np.random.uniform(self.x_range[0], self.x_range[1], + self.x_shape).astype(self.dtype) + self.inputs = {'X': x_np} + + softplus = x_np * (x_np > self.threshold) + np.exp(x_np) * \ + (x_np < -self.threshold) + np.log(np.exp(x_np) + 1.) * \ + (x_np >= -self.threshold) * (x_np <= self.threshold) + out_np = x_np * np.tanh(softplus) + + self.outputs = {'Out': out_np} + self.attrs = {'threshold': self.threshold} + + def init_dtype(self): + self.dtype = 'float32' + + def init_input_shape(self): + self.x_shape = (10, 12) + + def init_input_range(self): + self.x_range = [-1, 1] + + def init_threshold(self): + self.threshold = 5. + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class MishTestUpperThresh(MishTest): + def init_input_range(self): + self.x_range = [6, 7] + + +class MishTestLowerThresh(MishTest): + def init_input_range(self): + self.x_range = [-7, -6] + + +# mish op contain calculation like: tanh, exp, log, while tanh +# may have diff on CPUPlace(see test_activation_op.py::TestTanh), +# especially when abs(x) is a large value, only check input value +# in range [-1, 1] for float64 here. +class MishTestFP64(MishTest): + def init_dtype(self): + self.dtype = 'float64' + + def init_input_range(self): + self.x_range = [-1, 1] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py index ae99aeff557e4aa31f2868fbb8be9d038d5538ca..db5ad92ff5ead4fbc609209692268bef254d8c27 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py @@ -70,6 +70,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [ 'squared_l2_distance', \ 'squared_l2_norm', \ 'tanh', \ + 'mish', \ 'transpose2', \ 'trilinear_interp', \ 'var_conv_2d', \