未验证 提交 b9e6364e 编写于 作者: K Kexin Zhao 提交者: GitHub

Merge pull request #9267 from kexinzhao/new_relu_fp16

Add float16 support to relu op
......@@ -613,3 +613,14 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<double>>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
......@@ -31,3 +32,16 @@ namespace ops = paddle::operators;
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<double>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<paddle::platform::float16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<double>>);
......@@ -772,7 +772,6 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit
......@@ -212,18 +213,39 @@ class TestRound(OpTest):
class TestRelu(OpTest):
def setUp(self):
self.op_type = "relu"
x = np.random.uniform(-1, 1, [11, 17]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
out = np.maximum(x, 0)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Relu(TestRelu):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestBRelu(OpTest):
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册