未验证 提交 74468bf4 编写于 作者: K Kaipeng Deng 提交者: GitHub

add mish op. (#24565)

* add mish op. test=develop
上级 b9aeb681
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mish_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
// Forward mish operator: requires input X and output Out, and makes Out
// share X's dimensions and LoD.
class MishOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies that X and Out exist, then shares X's dim and LoD with Out.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mish");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel type is built from the data type indicated by X and the
  // device context of the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto x_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_dtype, ctx.device_context());
  }
};
// Proto maker for the mish operator: declares one input "X", one output
// "Out", a float attribute "threshold" (default 20) and the operator's
// user-facing documentation.
class MishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Mish operator");
AddOutput("Out", "Output of Mish operator");
// "threshold" selects where softplus is replaced by a cheaper approximation.
// NOTE(review): the softplus helpers in mish_op.h only take the approximate
// branches when threshold > 0; the attribute text below does not mention it.
AddAttr<float>(
"threshold",
"Constant threshold of softplus in Mish operator. Approximate value "
"of softplus will be used if absolute value of input is greater than "
":attr:`threshold`")
.SetDefault(20.f);
// Operator documentation; the raw string is emitted verbatim, so its
// contents must not be re-indented.
AddComment(R"DOC(
Mish Activation Operator.
.. math::
softplus = \begin{cases}
x, \text{if } x > \text{threshold} \\
e^{x}, \text{if } x < -\text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
out = x * \tanh(softplus)
)DOC");
}
};
// Gradient operator of mish: takes X and the gradient of Out as inputs and,
// when requested, produces the gradient of X.
class MishGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies the inputs; if the X gradient output exists it is given the
  // same dimensions as X.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "mish");

    const auto dx_name = framework::GradVarName("X");
    if (!ctx->HasOutput(dx_name)) {
      return;  // Gradient of X was not requested; nothing to infer.
    }
    ctx->SetOutputDim(dx_name, ctx->GetInputDim("X"));
  }

 protected:
  // The kernel type is built from the data type indicated by X and the
  // device context of the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto x_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_dtype, ctx.device_context());
  }
};
// Describes how to build a "mish_grad" op from a forward mish op: X and the
// gradient of Out become inputs, the gradient of X becomes the output, and
// the forward op's attributes are copied over.
template <typename T>
class MishGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");

    op->SetType("mish_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(dout_name, this->OutputGrad("Out"));
    op->SetOutput(dx_name, this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};
} // namespace operators
} // namespace paddle
// Operator and CPU kernel registration for mish / mish_grad.
namespace ops = paddle::operators;
// Forward op with its proto maker and the grad-op makers for both the
// OpDesc-based and the imperative (OpBase) modes.
REGISTER_OPERATOR(mish, ops::MishOp, ops::MishOpMaker,
ops::MishGradOpMaker<paddle::framework::OpDesc>,
ops::MishGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mish_grad, ops::MishGradOp);
// float uses the dedicated FP32 kernel classes; double uses the templated
// kernel classes.
REGISTER_OP_CPU_KERNEL(
mish, ops::MishFP32CPUKernel<paddle::platform::CPUDeviceContext>,
ops::MishCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
mish_grad, ops::MishGradFP32CPUKernel<paddle::platform::CPUDeviceContext>,
ops::MishGradCPUKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mish_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward mish kernel for element type T: out[i] = in[i] * tanh(softplus(in[i])).
//
// Written as a grid-stride loop, so any 1-D launch configuration covers all
// `numel` elements. `threshold` is forwarded to CalcSoftplus<T>.
template <typename T>
__global__ void KeMishFw(const T* in, T* out, const int numel,
                         const float threshold) {
  // Index arithmetic is done in 64 bits: with an `int` index, `idx += stride`
  // overflows (undefined behaviour, in practice an out-of-bounds access) once
  // numel is within one stride of INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; idx < numel; idx += stride) {
    const T x = in[idx];
    const T sp = CalcSoftplus<T>(x, threshold);
    out[idx] = x * tanh(sp);
  }
}
// The single-precision math functions (expf, tanhf, ...) should be used for
// float, so the float kernel is implemented and registered separately.
//
// Forward mish kernel for float: out[i] = in[i] * tanhf(softplus(in[i])).
// Written as a grid-stride loop, so any 1-D launch configuration covers all
// `numel` elements. `threshold` is forwarded to CalcSoftplusFP32.
__global__ void KeMishFwFP32(const float* in, float* out, const int numel,
                             const float threshold) {
  // Index arithmetic is done in 64 bits: with an `int` index, `idx += stride`
  // overflows (undefined behaviour, in practice an out-of-bounds access) once
  // numel is within one stride of INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; idx < numel; idx += stride) {
    const float x = in[idx];
    const float sp = CalcSoftplusFP32(x, threshold);
    out[idx] = x * tanhf(sp);
  }
}
// Backward mish kernel for element type T.
//
// With sp = softplus(x) and tsp = tanh(sp):
//   d(sp)/dx  = 1 - exp(-sp)              (computed as -expm1(-sp))
//   d(tsp)/dx = (1 - tsp^2) * d(sp)/dx
//   din       = dout * (x * d(tsp)/dx + tsp)
//
// Written as a grid-stride loop, so any 1-D launch configuration covers all
// `numel` elements. `threshold` is forwarded to CalcSoftplus<T>.
template <typename T>
__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel,
                         const float threshold) {
  // Index arithmetic is done in 64 bits: with an `int` index, `idx += stride`
  // overflows (undefined behaviour, in practice an out-of-bounds access) once
  // numel is within one stride of INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; idx < numel; idx += stride) {
    const T x = in[idx];
    const T sp = CalcSoftplus<T>(x, threshold);
    const T tsp = tanh(sp);
    const T grad_sp = -expm1(-sp);
    const T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
    din[idx] = dout[idx] * (x * grad_tsp + tsp);
  }
}
// Backward mish kernel for float, using the single-precision math functions.
//
// With sp = softplus(x) and tsp = tanhf(sp):
//   d(sp)/dx  = 1 - exp(-sp)              (computed as -expm1f(-sp))
//   d(tsp)/dx = (1 - tsp^2) * d(sp)/dx
//   din       = dout * (x * d(tsp)/dx + tsp)
//
// Written as a grid-stride loop, so any 1-D launch configuration covers all
// `numel` elements. `threshold` is forwarded to CalcSoftplusFP32.
__global__ void KeMishBwFP32(const float* in, const float* dout, float* din,
                             const int numel, const float threshold) {
  // Index arithmetic is done in 64 bits: with an `int` index, `idx += stride`
  // overflows (undefined behaviour, in practice an out-of-bounds access) once
  // numel is within one stride of INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; idx < numel; idx += stride) {
    const float x = in[idx];
    const float sp = CalcSoftplusFP32(x, threshold);
    const float tsp = tanhf(sp);
    const float grad_sp = -expm1f(-sp);
    const float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
    din[idx] = dout[idx] * (x * grad_tsp + tsp);
  }
}
// CUDA forward op kernel for mish with element type T. Reads X and the
// "threshold" attribute, allocates Out on the context's place and launches
// KeMishFw<T> on the CUDA device context's stream.
template <typename DeviceContext, typename T>
class MishCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    const float threshold = ctx.Attr<float>("threshold");

    const T* src = input->data<T>();
    T* dst = output->mutable_data<T>(ctx.GetPlace());
    const int numel = input->numel();

    // Launch dimensions are derived from the element count.
    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    auto stream = ctx.cuda_device_context().stream();
    KeMishFw<T><<<config.blocks, config.threads, 0, stream>>>(src, dst, numel,
                                                              threshold);
  }
};
// CUDA forward op kernel for mish specialised to float. Reads X and the
// "threshold" attribute, allocates Out on the context's place and launches
// KeMishFwFP32 on the CUDA device context's stream.
template <typename DeviceContext>
class MishFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    const float threshold = ctx.Attr<float>("threshold");

    const float* src = input->data<float>();
    float* dst = output->mutable_data<float>(ctx.GetPlace());
    const int numel = input->numel();

    // Launch dimensions are derived from the element count.
    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    auto stream = ctx.cuda_device_context().stream();
    KeMishFwFP32<<<config.blocks, config.threads, 0, stream>>>(src, dst, numel,
                                                               threshold);
  }
};
// CUDA backward op kernel for mish with element type T. Reads X, the gradient
// of Out and the "threshold" attribute, allocates the gradient of X on the
// context's place and launches KeMishBw<T> on the CUDA device context's stream.
template <typename DeviceContext, typename T>
class MishGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    const auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    const auto threshold = ctx.Attr<float>("threshold");

    const T* src = input->data<T>();
    const T* dout_ptr = out_grad->data<T>();
    T* dx_ptr = in_grad->mutable_data<T>(ctx.GetPlace());
    const int numel = input->numel();

    // Launch dimensions are derived from the element count.
    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    auto stream = ctx.cuda_device_context().stream();
    KeMishBw<T><<<config.blocks, config.threads, 0, stream>>>(
        src, dout_ptr, dx_ptr, numel, threshold);
  }
};
// CUDA backward op kernel for mish specialised to float. Reads X, the
// gradient of Out and the "threshold" attribute, allocates the gradient of X
// on the context's place and launches KeMishBwFP32 on the CUDA device
// context's stream.
template <typename DeviceContext>
class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input = ctx.Input<Tensor>("X");
    const auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* in_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    const auto threshold = ctx.Attr<float>("threshold");

    const float* src = input->data<float>();
    const float* dout_ptr = out_grad->data<float>();
    float* dx_ptr = in_grad->mutable_data<float>(ctx.GetPlace());
    const int numel = input->numel();

    // Launch dimensions are derived from the element count.
    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    auto stream = ctx.cuda_device_context().stream();
    KeMishBwFP32<<<config.blocks, config.threads, 0, stream>>>(
        src, dout_ptr, dx_ptr, numel, threshold);
  }
};
} // namespace operators
} // namespace paddle
// CUDA kernel registration for mish / mish_grad: float uses the dedicated
// FP32 kernel classes, double uses the templated kernel classes.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
mish, ops::MishFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
ops::MishCUDAKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
mish_grad, ops::MishGradFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
ops::MishGradCUDAKernel<paddle::platform::CUDADeviceContext, double>)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <cstdint>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Softplus ln(1 + e^x) as used by mish, for element type T.
//
// When threshold > 0 two shortcuts are taken:
//   x >  threshold  -> x
//   x < -threshold  -> exp(x)
// In every other case (including threshold <= 0) log1p(exp(x)) is returned.
template <typename T>
HOSTDEVICE static T CalcSoftplus(T x, float threshold) {
  if (threshold > 0) {
    if (x > threshold) {
      return x;
    }
    if (x < -threshold) {
      return exp(x);
    }
  }
  return log1p(exp(x));
}
// The single-precision math functions (expf, log1pf) should be used for
// float, so the float version is implemented and registered separately.
//
// Softplus ln(1 + e^x) as used by mish, for float. When threshold > 0 two
// shortcuts are taken:
//   x >  threshold  -> x
//   x < -threshold  -> expf(x)
// In every other case (including threshold <= 0) log1pf(expf(x)) is returned.
HOSTDEVICE static float CalcSoftplusFP32(float x, float threshold) {
  if (threshold > 0) {
    if (x > threshold) {
      return x;
    }
    if (x < -threshold) {
      return expf(x);
    }
  }
  return log1pf(expf(x));
}
// CPU forward op kernel for mish with element type T:
//   Out[i] = X[i] * tanh(softplus(X[i]))
// where softplus is CalcSoftplus<T> driven by the "threshold" attribute.
template <typename DeviceContext, typename T>
class MishCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Keep the element count and loop index 64-bit so that tensors with more
    // than INT_MAX elements are not truncated.
    const int64_t numel = x->numel();
    for (int64_t i = 0; i < numel; ++i) {
      const T x_d = x_data[i];
      const T sp = CalcSoftplus<T>(x_d, threshold);
      out_data[i] = x_d * std::tanh(sp);
    }
  }
};
// CPU forward op kernel for mish specialised to float:
//   Out[i] = X[i] * tanh(softplus(X[i]))
// where softplus is CalcSoftplusFP32 driven by the "threshold" attribute.
template <typename DeviceContext>
class MishFP32CPUKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    float* out_data = out->mutable_data<float>(ctx.GetPlace());

    // Keep the element count and loop index 64-bit so that tensors with more
    // than INT_MAX elements are not truncated.
    const int64_t numel = x->numel();
    for (int64_t i = 0; i < numel; ++i) {
      const float x_d = x_data[i];
      const float sp = CalcSoftplusFP32(x_d, threshold);
      out_data[i] = x_d * std::tanh(sp);
    }
  }
};
// CPU backward op kernel for mish with element type T.
//
// With sp = softplus(x) and tsp = tanh(sp):
//   d(sp)/dx  = 1 - exp(-sp)              (computed as -expm1(-sp))
//   d(tsp)/dx = (1 - tsp^2) * d(sp)/dx
//   dX        = dOut * (x * d(tsp)/dx + tsp)
template <typename DeviceContext, typename T>
class MishGradCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    const T* dout_data = dout->data<T>();
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

    // Keep the element count and loop index 64-bit so that tensors with more
    // than INT_MAX elements are not truncated.
    const int64_t numel = x->numel();
    for (int64_t i = 0; i < numel; ++i) {
      const T x_d = x_data[i];
      const T sp = CalcSoftplus<T>(x_d, threshold);
      const T tsp = std::tanh(sp);
      const T grad_sp = -std::expm1(-sp);
      const T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
      dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp);
    }
  }
};
// CPU backward op kernel for mish specialised to float.
//
// With sp = softplus(x) and tsp = tanh(sp):
//   d(sp)/dx  = 1 - exp(-sp)              (computed as -expm1f(-sp))
//   d(tsp)/dx = (1 - tsp^2) * d(sp)/dx
//   dX        = dOut * (x * d(tsp)/dx + tsp)
template <typename DeviceContext>
class MishGradFP32CPUKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    const float* dout_data = dout->data<float>();
    float* dx_data = dx->mutable_data<float>(ctx.GetPlace());

    // Keep the element count and loop index 64-bit so that tensors with more
    // than INT_MAX elements are not truncated.
    const int64_t numel = x->numel();
    for (int64_t i = 0; i < numel; ++i) {
      const float x_d = x_data[i];
      const float sp = CalcSoftplusFP32(x_d, threshold);
      const float tsp = std::tanh(sp);
      const float grad_sp = -std::expm1f(-sp);
      const float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
      dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp);
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -185,6 +185,7 @@ __all__ = [
'filter_by_instag',
'shard_index',
'hard_swish',
'mish',
'gather_tree',
'uniform_random',
'unbind',
......@@ -14782,6 +14783,81 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
return out
@templatedoc()
def mish(x, threshold=20, name=None):
    """
    This operator implements the mish activation function.
    Refer to `Mish: A Self Regularized Non-Monotonic Neural
    Activation Function <https://arxiv.org/abs/1908.08681>`_

    The exact formula is:

    .. math::

        out = x * \\tanh(\\ln(1 + e^{x}))

    The operator evaluates the inner softplus term piecewise, using
    :attr:`threshold` (which must be positive) to switch to an approximate
    value for inputs of large absolute value:

    .. math::

        out = \\begin{cases}
                x \\ast \\tanh(x), \\text{if } x > \\text{threshold} \\\\
                x \\ast \\tanh(e^{x}), \\text{if } x < -\\text{threshold} \\\\
                x \\ast \\tanh(\\ln(1 + e^{x})),  \\text{otherwise}
              \\end{cases}

    Args:
        x (Variable): Input feature, multi-dimensional Tensor. The data type
                      should be float32 or float64.
        threshold (float|int): Threshold for softplus in Mish operator.
                Approximate value of softplus will be used if absolute value
                of input is greater than :attr:`threshold`. Must be greater
                than 0. Default 20.
        name (str, optional): The default value is None. Normally there is no
                need for user to set this property. For more information, please
                refer to :ref:`api_guide_Name`

    Returns:
        Variable: The output tensor with the same shape and data type as input.

    Raises:
        AssertionError: If :attr:`threshold` is not greater than 0.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid
        import numpy as np

        DATATYPE='float32'

        x_data = np.array([i for i in range(1,5)]).reshape([1,1,4]).astype(DATATYPE)

        x = fluid.data(name="x", shape=[None,1,4], dtype=DATATYPE)
        y = fluid.layers.mish(x)

        place = fluid.CPUPlace()
        # place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
        print(out)  # approximately [[[0.8651, 1.9440, 2.9865, 3.9974]]]
    """
    # Argument validation is delegated to the shared check helpers; the
    # threshold must additionally be strictly positive.
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'mish')
    check_type(threshold, 'threshold', (float, int), 'mish')
    assert threshold > 0, "threshold of mish should be greater than 0, " \
                          "but got {}".format(threshold)

    # NOTE: no extra locals may be introduced before this call, because
    # locals() is forwarded to LayerHelper as keyword arguments.
    helper = LayerHelper('mish', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    # threshold is guaranteed positive by the assert above, so it is passed
    # through unchanged.
    helper.append_op(
        type='mish',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'threshold': threshold})
    return out
def gather_tree(ids, parents):
"""
To be used after beam search. After beam search, we get selected ids at
......
......@@ -2709,6 +2709,13 @@ class TestBook(LayerTest):
out = layers.softsign(input, name='softsign')
return (out)
def make_mish(self):
    """Build a mish layer on a float32 input of shape [16] and return its output."""
    main_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    with program_guard(main_prog, startup_prog):
        data = self._get_data(name="input", shape=[16], dtype="float32")
        return layers.mish(data, name='mish')
def make_cross_entropy(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import six
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
class TestMishOpError(unittest.TestCase):
    """Checks the argument validation performed by fluid.layers.mish."""

    def test_errors(self):
        with program_guard(Program()):
            # A plain Python float is not a Variable and must be rejected.
            with self.assertRaises(TypeError):
                fluid.layers.mish(0.1, 20)

            # An int32 Variable has an unsupported dtype and must be rejected.
            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
            with self.assertRaises(TypeError):
                fluid.layers.mish(x_int32, 20)

            # A float32 Variable is supported and must build without error.
            x_fp32 = fluid.layers.data(
                name='x_fp16', shape=[12, 10], dtype='float32')
            fluid.layers.mish(x_fp32, threshold=20)
class MishTest(OpTest):
    """Checks the mish op output and gradient against a NumPy reference.

    Subclasses override the ``init_*`` hooks to change the dtype, input
    shape, input range or threshold.
    """

    def setUp(self):
        self.init_dtype()
        self.init_input_shape()
        self.init_input_range()
        self.init_threshold()
        self.op_type = "mish"

        low, high = self.x_range[0], self.x_range[1]
        x_np = np.random.uniform(low, high, self.x_shape).astype(self.dtype)
        self.inputs = {'X': x_np}

        # Piecewise softplus reference built from boolean masks:
        # x above the threshold, exp(x) below -threshold, ln(1 + e^x) between.
        above = x_np * (x_np > self.threshold)
        below = np.exp(x_np) * (x_np < -self.threshold)
        between = np.log(np.exp(x_np) + 1.) * \
            (x_np >= -self.threshold) * (x_np <= self.threshold)
        softplus = above + below + between

        out_np = x_np * np.tanh(softplus)
        self.outputs = {'Out': out_np}
        self.attrs = {'threshold': self.threshold}

    def init_dtype(self):
        """Element type of the generated input."""
        self.dtype = 'float32'

    def init_input_shape(self):
        """Shape of the generated input."""
        self.x_shape = (10, 12)

    def init_input_range(self):
        """[low, high] bounds passed to np.random.uniform."""
        self.x_range = [-1, 1]

    def init_threshold(self):
        """Value used for the op's 'threshold' attribute."""
        self.threshold = 5.

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
class MishTestUpperThresh(MishTest):
    """Samples inputs from [6, 7], i.e. above MishTest's threshold of 5."""

    def init_input_range(self):
        low, high = 6, 7
        self.x_range = [low, high]
class MishTestLowerThresh(MishTest):
    """Samples inputs from [-7, -6], i.e. below minus MishTest's threshold of 5."""

    def init_input_range(self):
        low, high = -7, -6
        self.x_range = [low, high]
# The mish op chains tanh, exp and log. tanh may show differences on CPUPlace
# (see test_activation_op.py::TestTanh), especially when abs(x) is large, so
# the float64 case only checks input values in the range [-1, 1].
class MishTestFP64(MishTest):
    """float64 variant of MishTest restricted to inputs in [-1, 1]."""

    def init_dtype(self):
        self.dtype = 'float64'

    def init_input_range(self):
        low, high = -1, 1
        self.x_range = [low, high]
# Run the mish operator tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
......@@ -70,6 +70,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'squared_l2_distance', \
'squared_l2_norm', \
'tanh', \
'mish', \
'transpose2', \
'trilinear_interp', \
'var_conv_2d', \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册