From 5b6d834dee56c84ba9e59aad646e1286bc47451d Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Mon, 10 Jan 2022 15:03:53 +0800 Subject: [PATCH] [cherry-pick] modify mish op and add mish api (#38803) * add mish operator and api * remove redundant code and modify grad_atol of mish unittest * modify mish code to be consistent with other activation implementation * modify comment of mish --- paddle/fluid/operators/activation_op.cc | 37 ++++ paddle/fluid/operators/activation_op.cu | 50 +++++ paddle/fluid/operators/activation_op.h | 41 ++++ paddle/fluid/operators/mish_op.cc | 121 ------------ paddle/fluid/operators/mish_op.cu | 177 ------------------ paddle/fluid/operators/mish_op.h | 137 -------------- python/paddle/fluid/layers/nn.py | 5 +- .../tests/unittests/test_activation_op.py | 81 ++++++++ .../fluid/tests/unittests/test_mish_op.py | 102 ---------- python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/activation.py | 40 ++++ python/paddle/nn/layer/activation.py | 45 +++++ 13 files changed, 302 insertions(+), 538 deletions(-) delete mode 100644 paddle/fluid/operators/mish_op.cc delete mode 100644 paddle/fluid/operators/mish_op.cu delete mode 100644 paddle/fluid/operators/mish_op.h delete mode 100644 python/paddle/fluid/tests/unittests/test_mish_op.py diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index cacb6dd8fe..0bf6c623f6 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -706,6 +706,36 @@ $$out = \\frac{x}{1 + e^{- \beta \ x}}$$ } }; +class MishOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of Mish operator"); + AddOutput("Out", "Output of Mish operator"); + AddAttr( + "threshold", + "Constant threshold of softplus in Mish operator. Approximate value " + "of softplus will be used if absolute value of input is greater than " + ":attr:`threshold`") + .SetDefault(20.f); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false) + .AsExtra(); + AddComment(R"DOC( +Mish Activation Operator. + +.. math:: + softplus(x) = \begin{cases} + x, \text{if } x > \text{threshold} \\ + \ln(1 + e^{x}), \text{otherwise} + \end{cases} + + out = x * \tanh(softplus(x)) + +)DOC"); + } +}; + class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -1530,4 +1560,11 @@ REGISTER_OP_VERSION(softplus) .NewAttr("threshold", "The threshold value of the new formula", 20.0f)); +REGISTER_OP_VERSION(mish) + .AddCheckpoint( + R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_mkldnn", "(bool, default false) Only used in mkldnn kernel", + false)); + /* ========================================================================== */ diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 3ff164a869..4db6fb0fc4 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -1066,6 +1066,55 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct CudaMishFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // mish(x) = x * tanh(softplus(x)) + // softplus(x) = x, if x > threshold + // = ln(1 + exp(x)), otherwise + // Inputs: args[0], the input x + __device__ __forceinline__ T operator()(const T& arg_x) const { + MPType x = static_cast(arg_x); + MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); + return static_cast(x * tanh(sp)); + } +}; + +template +struct CudaMishGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) + // sp = softplus(x) + // Inputs: args[0], the input dout + // args[1], the input x + __device__ __forceinline__ T operator()(const T& arg_dout, + const T& arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType sp = (x > static_cast(threshold)) ? x : log(one + exp(x)); + MPType gsp = + (x > static_cast(threshold)) ? one : one / (one + exp(-x)); + MPType tsp = tanh(sp); + return static_cast(dout * (tsp + x * (one - tsp * tsp) * gsp)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct CudaThresholdedReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); @@ -1627,6 +1676,7 @@ REGISTER_OP_CUDA_KERNEL( __macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \ CudaHardSigmoidGradFunctor); \ __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \ + __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \ __macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor, \ CudaThresholdedReluGradFunctor); \ __macro(hard_swish, HardSwish, CudaHardSwishFunctor, \ diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 6fc9fc5d1f..add876d9c9 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1197,6 +1197,46 @@ struct SoftplusGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +// mish(x) = x * tanh(softplus(x)) +// softplus(x) = x, if x > threshold +// = ln(1 + exp(x)), otherwise +template +struct MishFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out) { + auto sp = (x > static_cast(threshold)) + .select(x, (static_cast(1) + x.exp()).log()); + out.device(d) = x * sp.tanh(); + } +}; + +// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp))) +// sp = softplus(x) +template +struct MishGradFunctor : public BaseActivationFunctor { + float threshold; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) { + auto sp = (x > static_cast(threshold)) + .select(x, (static_cast(1) + x.exp()).log()); + auto gsp = static_cast(1) - (-sp).exp(); + auto tsp = sp.tanh(); + dx.device(d) = dout * (tsp + x * (static_cast(1) - tsp * tsp) * gsp); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { @@ -2330,4 +2370,5 @@ struct LogGradGradFunctor : public BaseActivationFunctor { __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \ ThresholdedReluGradFunctor); \ + __macro(mish, Mish, MishFunctor, MishGradFunctor); \ __macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor); diff --git a/paddle/fluid/operators/mish_op.cc b/paddle/fluid/operators/mish_op.cc deleted file mode 100644 index ea754b5b1e..0000000000 --- a/paddle/fluid/operators/mish_op.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/mish_op.h" -#include -#include - -namespace paddle { -namespace operators { - -class MishOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mish"); - - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -class MishOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "Input of Mish operator"); - AddOutput("Out", "Output of Mish operator"); - AddAttr( - "threshold", - "Constant threshold of softplus in Mish operator. Approximate value " - "of softplus will be used if absolute value of input is greater than " - ":attr:`threshold`") - .SetDefault(20.f); - AddComment(R"DOC( -Mish Activation Operator. - -.. math:: - softplus = \begin{cases} - x, \text{if } x > \text{threshold} \\ - e^{x}, \text{if } x < -\text{threshold} \\ - \ln(1 + e^{x}), \text{otherwise} - \end{cases} - - out = x * \tanh(softplus) - -)DOC"); - } -}; - -// The operator to calculate gradients of a prelu operator. -class MishGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "mish"); - - auto x_grad_name = framework::GradVarName("X"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); - } - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -template -class MishGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("mish_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(mish, ops::MishOp, ops::MishOpMaker, - ops::MishGradOpMaker, - ops::MishGradOpMaker); -REGISTER_OPERATOR(mish_grad, ops::MishGradOp); -REGISTER_OP_CPU_KERNEL( - mish, ops::MishFP32CPUKernel, - ops::MishCPUKernel); -REGISTER_OP_CPU_KERNEL( - mish_grad, ops::MishGradFP32CPUKernel, - ops::MishGradCPUKernel); diff --git a/paddle/fluid/operators/mish_op.cu b/paddle/fluid/operators/mish_op.cu deleted file mode 100644 index 6513e5d95e..0000000000 --- a/paddle/fluid/operators/mish_op.cu +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/mish_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" -#include "paddle/fluid/platform/gpu_launch_config.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__global__ void KeMishFw(const T* in, T* out, const int numel, - const float threshold) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (; tid < numel; tid += stride) { - T x = in[tid]; - T sp = CalcSoftplus(x, threshold); - out[tid] = x * tanh(sp); - } -} - -// expf instead of exp should be used for float type, complement -// and register float kernel separatelly -__global__ void KeMishFwFP32(const float* in, float* out, const int numel, - const float threshold) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (; tid < numel; tid += stride) { - float x = in[tid]; - float sp = CalcSoftplusFP32(x, threshold); - out[tid] = x * tanhf(sp); - } -} - -template -__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel, - const float threshold) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (; tid < numel; tid += stride) { - T x = in[tid]; - T sp = CalcSoftplus(x, threshold); - T tsp = tanh(sp); - T grad_sp = -expm1(-sp); - T grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; - din[tid] = dout[tid] * (x * grad_tsp + tsp); - } -} - -__global__ void KeMishBwFP32(const float* in, const float* dout, float* din, - const int numel, const float threshold) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (; tid < numel; tid += stride) { - float x = in[tid]; - float sp = CalcSoftplusFP32(x, threshold); - float tsp = tanhf(sp); - float grad_sp = -expm1f(-sp); - float grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; - din[tid] = dout[tid] * (x * grad_tsp + tsp); - } -} - -template -class MishCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - const float threshold = ctx.Attr("threshold"); - - const T* x_data = x->data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - const int numel = x->numel(); - - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel); - KeMishFw<<>>(x_data, out_data, numel, - threshold); - } -}; - -template -class MishFP32CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - const float threshold = ctx.Attr("threshold"); - - const float* x_data = x->data(); - float* out_data = out->mutable_data(ctx.GetPlace()); - - const int numel = x->numel(); - - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel); - KeMishFwFP32<<>>(x_data, out_data, - numel, threshold); - } -}; - -template -class MishGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - - auto threshold = ctx.Attr("threshold"); - - const T* x_data = x->data(); - const T* dout_data = dout->data(); - T* dx_data = dx->mutable_data(ctx.GetPlace()); - - const int numel = x->numel(); - - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel); - KeMishBw<<>>( - x_data, dout_data, dx_data, numel, threshold); - } -}; - -template -class MishGradFP32CUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - - auto threshold = ctx.Attr("threshold"); - - const float* x_data = x->data(); - const float* dout_data = dout->data(); - float* dx_data = dx->mutable_data(ctx.GetPlace()); - - const int numel = x->numel(); - - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel); - KeMishBwFP32<<>>( - x_data, dout_data, dx_data, numel, threshold); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - mish, ops::MishFP32CUDAKernel, - ops::MishCUDAKernel) -REGISTER_OP_CUDA_KERNEL( - mish_grad, ops::MishGradFP32CUDAKernel, - ops::MishGradCUDAKernel) diff --git a/paddle/fluid/operators/mish_op.h b/paddle/fluid/operators/mish_op.h deleted file mode 100644 index 86ccb57d92..0000000000 --- a/paddle/fluid/operators/mish_op.h +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -HOSTDEVICE static T CalcSoftplus(T x, float threshold) { - if (threshold > 0 && x > threshold) { - return x; - } else if (threshold > 0 && x < -threshold) { - return exp(x); - } else { - return log1p(exp(x)); - } -} - -// expf instead of exp should be used for float type, complement -// and register float kernel separatelly -HOSTDEVICE static float CalcSoftplusFP32(float x, float threshold) { - if (threshold > 0 && x > threshold) { - return x; - } else if (threshold > 0 && x < -threshold) { - return expf(x); - } else { - return log1pf(expf(x)); - } -} - -template -class MishCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - const float threshold = ctx.Attr("threshold"); - - const T* x_data = x->data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int numel = x->numel(); - for (int i = 0; i < numel; i++) { - T x_d = x_data[i]; - T sp = CalcSoftplus(x_d, threshold); - out_data[i] = x_d * std::tanh(sp); - } - } -}; - -template -class MishFP32CPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - const float threshold = ctx.Attr("threshold"); - - const float* x_data = x->data(); - float* out_data = out->mutable_data(ctx.GetPlace()); - - int numel = x->numel(); - for (int i = 0; i < numel; i++) { - float x_d = x_data[i]; - float sp = CalcSoftplusFP32(x_d, threshold); - out_data[i] = x_d * std::tanh(sp); - } - } -}; - -template -class MishGradCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - auto threshold = ctx.Attr("threshold"); - - const T* x_data = x->data(); - const T* dout_data = dout->data(); - T* dx_data = dx->mutable_data(ctx.GetPlace()); - - int numel = x->numel(); - for (int i = 0; i < numel; i++) { - T x_d = x_data[i]; - T sp = CalcSoftplus(x_d, threshold); - T tsp = std::tanh(sp); - T grad_sp = -std::expm1(-sp); - T grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; - dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp); - } - } -}; - -template -class MishGradFP32CPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - auto threshold = ctx.Attr("threshold"); - - const float* x_data = x->data(); - const float* dout_data = dout->data(); - float* dx_data = dx->mutable_data(ctx.GetPlace()); - - int numel = x->numel(); - for (int i = 0; i < numel; i++) { - float x_d = x_data[i]; - float sp = CalcSoftplusFP32(x_d, threshold); - float tsp = std::tanh(sp); - float grad_sp = -std::expm1f(-sp); - float grad_tsp = (static_cast(1) - tsp * tsp) * grad_sp; - dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8d993f20da..303eb6e3be 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15128,6 +15128,9 @@ def mish(x, threshold=20, name=None): out, = exe.run(feed={'x':x_data}, fetch_list=[y.name]) print(out) # [[0.66666667, 1.66666667, 3., 4.]] """ + if in_dygraph_mode(): + return _C_ops.mish(x, 'threshold', threshold) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'mish') check_type(threshold, 'threshold', (float, int), 'mish') assert threshold > 0, "threshold of mish should be greater than 0, " \ @@ -15139,7 +15142,7 @@ def mish(x, threshold=20, name=None): type='mish', inputs={'X': x}, outputs={'Out': out}, - attrs={'threshold': threshold or -1}) + attrs={'threshold': threshold}) return out diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 23eb42c5a9..182c44c19d 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -2695,6 +2695,86 @@ class TestSwishAPI(unittest.TestCase): F.swish(x_fp16) +def ref_mish(x, threshold=20.): + softplus = np.select([x <= threshold, x > threshold], + [np.log(1 + np.exp(x)), x]) + return x * np.tanh(softplus) + + +class TestMish(TestActivation): + def setUp(self): + self.op_type = "mish" + self.init_dtype() + + np.random.seed(1024) + x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) + out = ref_mish(x) + self.inputs = {'X': x} + self.outputs = {'Out': out} + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out') + + +class TestMishAPI(unittest.TestCase): + # test paddle.nn.Mish, paddle.nn.functional.mish + def setUp(self): + np.random.seed(1024) + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64) + self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) + out1 = F.mish(x) + mish = paddle.nn.Mish() + out2 = mish(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_mish(self.x_np) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.mish(x) + mish = paddle.nn.Mish() + out2 = mish(x) + out_ref = ref_mish(self.x_np) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_fluid_api(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + x = fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = fluid.layers.mish(x) + exe = fluid.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = ref_mish(self.x_np) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.mish, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[12, 10], dtype='int32') + self.assertRaises(TypeError, F.mish, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[12, 10], dtype='float16') + F.mish(x_fp16) + + #------------------ Test Error Activation---------------------- def create_test_error_class(op_type): class TestOpErrors(unittest.TestCase): @@ -2823,6 +2903,7 @@ create_test_act_fp16_class(TestThresholdedRelu) create_test_act_fp16_class(TestHardSigmoid) create_test_act_fp16_class(TestSwish, grad_atol=0.85) create_test_act_fp16_class(TestHardSwish) +create_test_act_fp16_class(TestMish, grad_atol=0.9) def create_test_act_bf16_class(parent, diff --git a/python/paddle/fluid/tests/unittests/test_mish_op.py b/python/paddle/fluid/tests/unittests/test_mish_op.py deleted file mode 100644 index 8cc785e450..0000000000 --- a/python/paddle/fluid/tests/unittests/test_mish_op.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -import six -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid import Program, program_guard -from op_test import OpTest, skip_check_grad_ci - - -class TestMishOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program()): - # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.mish, 0.1, 20) - # The input dtype must be float16, float32, float64. - x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.mish, x_int32, 20) - # support the input dtype is float32 - x_fp16 = fluid.layers.data( - name='x_fp16', shape=[12, 10], dtype='float32') - fluid.layers.mish(x_fp16, threshold=20) - - -class MishTest(OpTest): - def setUp(self): - self.init_dtype() - self.init_input_shape() - self.init_input_range() - self.init_threshold() - self.op_type = "mish" - - x_np = np.random.uniform(self.x_range[0], self.x_range[1], - self.x_shape).astype(self.dtype) - self.inputs = {'X': x_np} - - softplus = x_np * (x_np > self.threshold) + np.exp(x_np) * \ - (x_np < -self.threshold) + np.log(np.exp(x_np) + 1.) * \ - (x_np >= -self.threshold) * (x_np <= self.threshold) - out_np = x_np * np.tanh(softplus) - - self.outputs = {'Out': out_np} - self.attrs = {'threshold': self.threshold} - - def init_dtype(self): - self.dtype = 'float32' - - def init_input_shape(self): - self.x_shape = (10, 12) - - def init_input_range(self): - self.x_range = [-1, 1] - - def init_threshold(self): - self.threshold = 5. - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class MishTestUpperThresh(MishTest): - def init_input_range(self): - self.x_range = [6, 7] - - -class MishTestLowerThresh(MishTest): - def init_input_range(self): - self.x_range = [-7, -6] - - -# mish op contain calculation like: tanh, exp, log, while tanh -# may have diff on CPUPlace(see test_activation_op.py::TestTanh), -# especially when abs(x) is a large value, only check input value -# in range [-1, 1] for float64 here. -class MishTestFP64(MishTest): - def init_dtype(self): - self.dtype = 'float64' - - def init_input_range(self): - self.x_range = [-1, 1] - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 98444e69d0..11105b3428 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -45,6 +45,7 @@ from .layer.activation import Softplus # noqa: F401 from .layer.activation import Softshrink # noqa: F401 from .layer.activation import Softsign # noqa: F401 from .layer.activation import Swish # noqa: F401 +from .layer.activation import Mish # noqa: F401 from .layer.activation import Tanhshrink # noqa: F401 from .layer.activation import ThresholdedReLU # noqa: F401 from .layer.activation import LogSoftmax # noqa: F401 @@ -288,6 +289,7 @@ __all__ = [ #noqa 'LogSoftmax', 'Sigmoid', 'Swish', + 'Mish', 'PixelShuffle', 'ELU', 'ReLU6', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 4151f25b94..52ac779531 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -38,6 +38,7 @@ from .activation import softplus # noqa: F401 from .activation import softshrink # noqa: F401 from .activation import softsign # noqa: F401 from .activation import swish # noqa: F401 +from .activation import mish # noqa: F401 from .activation import tanh # noqa: F401 from .activation import tanh_ # noqa: F401 from .activation import tanhshrink # noqa: F401 @@ -144,6 +145,7 @@ __all__ = [ #noqa 'sigmoid', 'silu', 'swish', + 'mish', 'tanh', 'tanh_', 'tanhshrink', diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 07cc0d0d23..53c8579151 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1130,6 +1130,46 @@ def swish(x, name=None): return out +def mish(x, name=None): + r""" + mish activation. + + .. math:: + + softplus(x) = \begin{cases} + x, \text{if } x > \text{threshold} \\ + \ln(1 + e^{x}), \text{otherwise} + \end{cases} + + mish(x) = x * \tanh(softplus(x)) + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + x = paddle.to_tensor([-5., 0., 5.]) + out = F.mish(x) # [-0.03357624, 0., 4.99955208] + """ + if in_dygraph_mode(): + return _C_ops.mish(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mish') + helper = LayerHelper('mish', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='mish', inputs={'X': x}, outputs={'Out': out}) + return out + + def tanhshrink(x, name=None): """ tanhshrink activation diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 2c2df196b8..cc2b40c0ae 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -839,6 +839,51 @@ class Swish(Layer): return name_str +class Mish(Layer): + r""" + Mish Activation. + + .. math:: + + softplus(x) = \begin{cases} + x, \text{if } x > \text{threshold} \\ + \ln(1 + e^{x}), \text{otherwise} + \end{cases} + + Mish(x) = x * \tanh(softplus(x)) + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.to_tensor([-5., 0., 5.]) + m = paddle.nn.Mish() + out = m(x) # [-0.03357624, 0., 4.99955208] + + """ + + def __init__(self, name=None): + super(Mish, self).__init__() + self._name = name + + def forward(self, x): + return F.mish(x, self._name) + + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + + class Tanhshrink(Layer): """ Tanhshrink Activation -- GitLab