未验证 提交 8c92337c 编写于 作者: W wangxinxin08 提交者: GitHub

modify mish op and add mish api (#38734)

* add mish operator and api

* remove redundant code and modify grad_atol of mish unittest

* modify mish code to be consistent with other activation implementation
上级 fb3313e9
......@@ -806,6 +806,36 @@ $$out = \\frac{x}{1 + e^{- \beta \ x}}$$
}
};
class MishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Mish operator");
AddOutput("Out", "Output of Mish operator");
AddAttr<float>(
"threshold",
"Constant threshold of softplus in Mish operator. Approximate value "
"of softplus will be used if absolute value of input is greater than "
":attr:`threshold`")
.SetDefault(20.f);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC(
Mish Activation Operator.
.. math::
softplus(x) = \begin{cases}
x, \text{if } x > \text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
out = x * \tanh(softplus(x))
)DOC");
}
};
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -1901,4 +1931,11 @@ REGISTER_OP_VERSION(softplus)
.NewAttr("threshold", "The threshold value of the new formula",
20.0f));
REGISTER_OP_VERSION(mish)
.AddCheckpoint(
R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_mkldnn", "(bool, default false) Only used in mkldnn kernel",
false));
/* ========================================================================== */
......@@ -1145,6 +1145,55 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
return static_cast<T>(x * tanh(sp));
}
};
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
MPType gsp =
(x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
MPType tsp = tanh(sp);
return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......@@ -1808,6 +1857,7 @@ REGISTER_OP_CUDA_KERNEL(
__macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \
CudaHardSigmoidGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor, \
CudaThresholdedReluGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
......
......@@ -1412,6 +1412,46 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
out.device(d) = x * sp.tanh();
}
};
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
auto gsp = static_cast<T>(1) - (-sp).exp();
auto tsp = sp.tanh();
dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
......@@ -2841,4 +2881,5 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \
ThresholdedReluGradFunctor); \
__macro(mish, Mish, MishFunctor, MishGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mish_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
class MishOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mish");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class MishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Mish operator");
AddOutput("Out", "Output of Mish operator");
AddAttr<float>(
"threshold",
"Constant threshold of softplus in Mish operator. Approximate value "
"of softplus will be used if absolute value of input is greater than "
":attr:`threshold`")
.SetDefault(20.f);
AddComment(R"DOC(
Mish Activation Operator.
.. math::
softplus = \begin{cases}
x, \text{if } x > \text{threshold} \\
e^{x}, \text{if } x < -\text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
out = x * \tanh(softplus)
)DOC");
}
};
// The operator to calculate gradients of a prelu operator.
class MishGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "mish");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
template <typename T>
class MishGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("mish_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mish, ops::MishOp, ops::MishOpMaker,
ops::MishGradOpMaker<paddle::framework::OpDesc>,
ops::MishGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mish_grad, ops::MishGradOp);
REGISTER_OP_CPU_KERNEL(
mish, ops::MishFP32CPUKernel<paddle::platform::CPUDeviceContext>,
ops::MishCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
mish_grad, ops::MishGradFP32CPUKernel<paddle::platform::CPUDeviceContext>,
ops::MishGradCPUKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mish_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeMishFw(const T* in, T* out, const int numel,
const float threshold) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < numel; tid += stride) {
T x = in[tid];
T sp = CalcSoftplus<T>(x, threshold);
out[tid] = x * tanh(sp);
}
}
// expf instead of exp should be used for float type, complement
// and register float kernel separatelly
__global__ void KeMishFwFP32(const float* in, float* out, const int numel,
const float threshold) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < numel; tid += stride) {
float x = in[tid];
float sp = CalcSoftplusFP32(x, threshold);
out[tid] = x * tanhf(sp);
}
}
template <typename T>
__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel,
const float threshold) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < numel; tid += stride) {
T x = in[tid];
T sp = CalcSoftplus<T>(x, threshold);
T tsp = tanh(sp);
T grad_sp = -expm1(-sp);
T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
din[tid] = dout[tid] * (x * grad_tsp + tsp);
}
}
__global__ void KeMishBwFP32(const float* in, const float* dout, float* din,
const int numel, const float threshold) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < numel; tid += stride) {
float x = in[tid];
float sp = CalcSoftplusFP32(x, threshold);
float tsp = tanhf(sp);
float grad_sp = -expm1f(-sp);
float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
din[tid] = dout[tid] * (x * grad_tsp + tsp);
}
}
template <typename DeviceContext, typename T>
class MishCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
const float threshold = ctx.Attr<float>("threshold");
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
const int numel = x->numel();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(x_data, out_data, numel,
threshold);
}
};
template <typename DeviceContext>
class MishFP32CUDAKernel : public framework::OpKernel<float> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
const float threshold = ctx.Attr<float>("threshold");
const float* x_data = x->data<float>();
float* out_data = out->mutable_data<float>(ctx.GetPlace());
const int numel = x->numel();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishFwFP32<<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(x_data, out_data,
numel, threshold);
}
};
template <typename DeviceContext, typename T>
class MishGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto threshold = ctx.Attr<float>("threshold");
const T* x_data = x->data<T>();
const T* dout_data = dout->data<T>();
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const int numel = x->numel();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
x_data, dout_data, dx_data, numel, threshold);
}
};
template <typename DeviceContext>
class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto threshold = ctx.Attr<float>("threshold");
const float* x_data = x->data<float>();
const float* dout_data = dout->data<float>();
float* dx_data = dx->mutable_data<float>(ctx.GetPlace());
const int numel = x->numel();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), numel);
KeMishBwFP32<<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
x_data, dout_data, dx_data, numel, threshold);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
mish, ops::MishFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
ops::MishCUDAKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
mish_grad, ops::MishGradFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
ops::MishGradCUDAKernel<paddle::platform::CUDADeviceContext, double>)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE static T CalcSoftplus(T x, float threshold) {
if (threshold > 0 && x > threshold) {
return x;
} else if (threshold > 0 && x < -threshold) {
return exp(x);
} else {
return log1p(exp(x));
}
}
// expf instead of exp should be used for float type, complement
// and register float kernel separatelly
HOSTDEVICE static float CalcSoftplusFP32(float x, float threshold) {
if (threshold > 0 && x > threshold) {
return x;
} else if (threshold > 0 && x < -threshold) {
return expf(x);
} else {
return log1pf(expf(x));
}
}
template <typename DeviceContext, typename T>
class MishCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
const float threshold = ctx.Attr<float>("threshold");
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
int numel = x->numel();
for (int i = 0; i < numel; i++) {
T x_d = x_data[i];
T sp = CalcSoftplus<T>(x_d, threshold);
out_data[i] = x_d * std::tanh(sp);
}
}
};
template <typename DeviceContext>
class MishFP32CPUKernel : public framework::OpKernel<float> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
const float threshold = ctx.Attr<float>("threshold");
const float* x_data = x->data<float>();
float* out_data = out->mutable_data<float>(ctx.GetPlace());
int numel = x->numel();
for (int i = 0; i < numel; i++) {
float x_d = x_data[i];
float sp = CalcSoftplusFP32(x_d, threshold);
out_data[i] = x_d * std::tanh(sp);
}
}
};
template <typename DeviceContext, typename T>
class MishGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto threshold = ctx.Attr<float>("threshold");
const T* x_data = x->data<T>();
const T* dout_data = dout->data<T>();
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
int numel = x->numel();
for (int i = 0; i < numel; i++) {
T x_d = x_data[i];
T sp = CalcSoftplus<T>(x_d, threshold);
T tsp = std::tanh(sp);
T grad_sp = -std::expm1(-sp);
T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp);
}
}
};
template <typename DeviceContext>
class MishGradFP32CPUKernel : public framework::OpKernel<float> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto threshold = ctx.Attr<float>("threshold");
const float* x_data = x->data<float>();
const float* dout_data = dout->data<float>();
float* dx_data = dx->mutable_data<float>(ctx.GetPlace());
int numel = x->numel();
for (int i = 0; i < numel; i++) {
float x_d = x_data[i];
float sp = CalcSoftplusFP32(x_d, threshold);
float tsp = std::tanh(sp);
float grad_sp = -std::expm1f(-sp);
float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
dx_data[i] = dout_data[i] * (x_d * grad_tsp + tsp);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -15191,6 +15191,9 @@ def mish(x, threshold=20, name=None):
out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
print(out) # [[0.66666667, 1.66666667, 3., 4.]]
"""
if in_dygraph_mode():
return _C_ops.mish(x, 'threshold', threshold)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'mish')
check_type(threshold, 'threshold', (float, int), 'mish')
assert threshold > 0, "threshold of mish should be greater than 0, " \
......@@ -15202,7 +15205,7 @@ def mish(x, threshold=20, name=None):
type='mish',
inputs={'X': x},
outputs={'Out': out},
attrs={'threshold': threshold or -1})
attrs={'threshold': threshold})
return out
......
......@@ -2837,6 +2837,86 @@ class TestSwishAPI(unittest.TestCase):
F.swish(x_fp16)
def ref_mish(x, threshold=20.):
softplus = np.select([x <= threshold, x > threshold],
[np.log(1 + np.exp(x)), x])
return x * np.tanh(softplus)
class TestMish(TestActivation):
def setUp(self):
self.op_type = "mish"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
out = ref_mish(x)
self.inputs = {'X': x}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestMishAPI(unittest.TestCase):
# test paddle.nn.Mish, paddle.nn.functional.mish
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.mish(x)
mish = paddle.nn.Mish()
out2 = mish(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_mish(self.x_np)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.mish(x)
mish = paddle.nn.Mish()
out2 = mish(x)
out_ref = ref_mish(self.x_np)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_fluid_api(self):
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.mish(x)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_mish(self.x_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.mish, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.mish, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.mish(x_fp16)
#------------------ Test Error Activation----------------------
def create_test_error_class(op_type):
class TestOpErrors(unittest.TestCase):
......@@ -2972,6 +3052,7 @@ create_test_act_fp16_class(TestThresholdedRelu)
create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish, grad_atol=0.85)
create_test_act_fp16_class(TestHardSwish)
create_test_act_fp16_class(TestMish, grad_atol=0.9)
def create_test_act_bf16_class(parent,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import six
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
class TestMishOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.mish, 0.1, 20)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.mish, x_int32, 20)
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.mish(x_fp16, threshold=20)
class MishTest(OpTest):
def setUp(self):
self.init_dtype()
self.init_input_shape()
self.init_input_range()
self.init_threshold()
self.op_type = "mish"
x_np = np.random.uniform(self.x_range[0], self.x_range[1],
self.x_shape).astype(self.dtype)
self.inputs = {'X': x_np}
softplus = x_np * (x_np > self.threshold) + np.exp(x_np) * \
(x_np < -self.threshold) + np.log(np.exp(x_np) + 1.) * \
(x_np >= -self.threshold) * (x_np <= self.threshold)
out_np = x_np * np.tanh(softplus)
self.outputs = {'Out': out_np}
self.attrs = {'threshold': self.threshold}
def init_dtype(self):
self.dtype = 'float32'
def init_input_shape(self):
self.x_shape = (10, 12)
def init_input_range(self):
self.x_range = [-1, 1]
def init_threshold(self):
self.threshold = 5.
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class MishTestUpperThresh(MishTest):
def init_input_range(self):
self.x_range = [6, 7]
class MishTestLowerThresh(MishTest):
def init_input_range(self):
self.x_range = [-7, -6]
# mish op contain calculation like: tanh, exp, log, while tanh
# may have diff on CPUPlace(see test_activation_op.py::TestTanh),
# especially when abs(x) is a large value, only check input value
# in range [-1, 1] for float64 here.
class MishTestFP64(MishTest):
def init_dtype(self):
self.dtype = 'float64'
def init_input_range(self):
self.x_range = [-1, 1]
if __name__ == "__main__":
unittest.main()
......@@ -46,6 +46,7 @@ from .layer.activation import Softplus # noqa: F401
from .layer.activation import Softshrink # noqa: F401
from .layer.activation import Softsign # noqa: F401
from .layer.activation import Swish # noqa: F401
from .layer.activation import Mish # noqa: F401
from .layer.activation import Tanhshrink # noqa: F401
from .layer.activation import ThresholdedReLU # noqa: F401
from .layer.activation import LogSoftmax # noqa: F401
......@@ -294,6 +295,7 @@ __all__ = [ #noqa
'LogSoftmax',
'Sigmoid',
'Swish',
'Mish',
'PixelShuffle',
'ELU',
'ReLU6',
......
......@@ -39,6 +39,7 @@ from .activation import softplus # noqa: F401
from .activation import softshrink # noqa: F401
from .activation import softsign # noqa: F401
from .activation import swish # noqa: F401
from .activation import mish # noqa: F401
from .activation import tanh # noqa: F401
from .activation import tanh_ # noqa: F401
from .activation import tanhshrink # noqa: F401
......@@ -149,6 +150,7 @@ __all__ = [ #noqa
'sigmoid',
'silu',
'swish',
'mish',
'tanh',
'tanh_',
'tanhshrink',
......
......@@ -1174,6 +1174,47 @@ def swish(x, name=None):
return out
def mish(x, name=None):
r"""
mish activation.
.. math::
softplus(x) = \begin{cases}
x, \text{if } x > \text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
mish(x) = x * \tanh(softplus(x))
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
x = paddle.to_tensor(np.array([-5., 0., 5.]))
out = F.mish(x) # [-0.03357624, 0., 4.99955208]
"""
if in_dygraph_mode():
return _C_ops.mish(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mish')
helper = LayerHelper('mish', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='mish', inputs={'X': x}, outputs={'Out': out})
return out
def tanhshrink(x, name=None):
"""
tanhshrink activation
......
......@@ -881,6 +881,52 @@ class Swish(Layer):
return name_str
class Mish(Layer):
r"""
Mish Activation.
.. math::
softplus(x) = \begin{cases}
x, \text{if } x > \text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
Mish(x) = x * \tanh(softplus(x))
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor(np.array([-5., 0., 5.]))
m = paddle.nn.Mish()
out = m(x) # [-0.03357624, 0., 4.99955208]
"""
def __init__(self, name=None):
super(Mish, self).__init__()
self._name = name
def forward(self, x):
return F.mish(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class Tanhshrink(Layer):
"""
Tanhshrink Activation
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册