Unverified · Commit 0f38bb45 authored by Kexin Zhao, committed by GitHub

add fp16 support to activation op (#9769)

Parent 22b9d4e6
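This commit registers float16 CUDA kernels for the activation operators, adds the float16 math overloads they need (Eigen numext functions and float16 specializations of Sine/Cosine), and adds FP16 variants of the Python op tests. The FP16 tests only check forward outputs on a CUDA place (mostly with atol=1e-3) and skip gradient checks. As a rough, framework-independent sketch of the tolerance those tests rely on, the following NumPy snippet (illustration only, not part of the patch) compares a float16 activation against its float32 reference:

import numpy as np

# Illustration only: evaluate an activation on float16 inputs and compare
# with the float32 reference at atol=1e-3, the same atol the FP16 tests
# below pass to check_output_with_place.
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float16)
out_fp16 = np.tanh(x)                    # computed and rounded in float16
out_ref = np.tanh(x.astype(np.float32))  # float32 reference
np.testing.assert_allclose(out_fp16.astype(np.float32), out_ref, atol=1e-3)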
......@@ -662,14 +662,3 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -17,31 +14,19 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
ops::ActivationKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<double>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<paddle::platform::float16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -15,9 +12,11 @@ limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -338,11 +337,25 @@ struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};
template <>
struct Sine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sin(static_cast<float>(val)));
}
};
template <typename T>
struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};
template <>
struct Cosine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(cos(static_cast<float>(val)));
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
......@@ -826,6 +839,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
......
......@@ -1003,6 +1003,46 @@ HOSTDEVICE inline float16 exp(const float16& a) {
return float16(::expf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 log(const float16& a) {
return float16(::logf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
return float16(::tanhf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
return float16(::sqrtf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
return float16(::ceilf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 floor(const float16& a) {
return float16(::floorf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 round(const float16& a) {
return float16(::roundf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
return float16(::fabs(static_cast<float>(a)));
}
} // namespace numext
} // namespace Eigen
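The float16 math functions added above (and the Sine/Cosine float16 specializations earlier in activation_op.h) all follow one pattern: widen the half-precision value to float, call the single-precision routine, and narrow the result back to float16. A NumPy sketch of that cast-compute-cast pattern, for illustration only (not Paddle code):

import numpy as np

# Illustration only: compute in single precision, then round the result
# back to half precision, mirroring the float16 overloads above.
def float16_tanh(a):
    return np.float16(np.tanh(np.float32(a)))

print(float16_tanh(np.float16(0.5)))  # ~0.4622, the float16-rounded tanh(0.5)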
......@@ -22,221 +22,504 @@ from scipy.special import expit
class TestExp(OpTest):
def setUp(self):
self.op_type = "exp"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.exp(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.exp(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Exp(TestExp):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSigmoid(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': 1 / (1 + np.exp(-self.inputs['X']))}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = 1 / (1 + np.exp(-x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.008)
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.01)
def init_dtype(self):
pass
class TestFP16Sigmoid(TestSigmoid):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestLogSigmoid(OpTest):
def setUp(self):
self.op_type = "logsigmoid"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = np.log(1 / (1 + np.exp(-x)))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.008)
def init_dtype(self):
pass
class TestFP16LogSigmoid(TestLogSigmoid):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestTanh(OpTest):
def setUp(self):
self.op_type = "tanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Tanh(TestTanh):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestTanhShrink(OpTest):
def setUp(self):
self.op_type = "tanh_shrink"
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32")
}
self.outputs = {'Out': self.inputs['X'] - np.tanh(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [10, 17]).astype(self.dtype)
out = x - np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.008)
def init_dtype(self):
pass
class TestFP16TanhShrink(TestTanhShrink):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestHardShrink(OpTest):
def setUp(self):
self.op_type = "hard_shrink"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.dtype = np.float32
self.init_dtype()
threshold = 0.5
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.copy(x)
out[(out >= -threshold) & (out <= threshold)] = 0
self.inputs = {'X': x}
self.attrs = {'lambda': threshold}
t = np.copy(x)
t[(t >= -threshold) & (t <= threshold)] = 0
self.outputs = {'Out': t}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.005)
def init_dtype(self):
pass
class TestFP16HardShrink(TestHardShrink):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSoftShrink(OpTest):
def setUp(self):
self.op_type = "softshrink"
self.dtype = np.float32
self.init_dtype()
lambda_val = 0.1
x = np.random.uniform(0.25, 10, [4, 4]).astype(self.dtype)
out = np.copy(x)
out = (out < -lambda_val) * (out + lambda_val) + (out > lambda_val) * (
out - lambda_val)
self.attrs = {'lambda': lambda_val}
self.inputs = {
'X': np.random.uniform(0.25, 10, [4, 4]).astype("float32")
}
y = np.copy(self.inputs['X'])
y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * (
y - lambda_val)
self.outputs = {'Out': y}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16SoftShrink(TestSoftShrink):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Sqrt(TestSqrt):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestAbs(OpTest):
def setUp(self):
self.op_type = "abs"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
# Because we set delta = 0.005 in calculating the numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003 and
# x_pos will be 0.007, so the numeric gradient is inaccurate.
# We should avoid this.
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.abs(self.inputs['X'])}
out = np.abs(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Abs(TestAbs):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCeil(OpTest):
def setUp(self):
self.op_type = "ceil"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.ceil(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.ceil(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Ceil(TestCeil):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestFloor(OpTest):
def setUp(self):
self.op_type = "floor"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.floor(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.floor(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Floor(TestFloor):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestCos(OpTest):
def setUp(self):
self.op_type = "cos"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.cos(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.cos(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Cos(TestCos):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSin(OpTest):
def setUp(self):
self.op_type = "sin"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.sin(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.sin(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Sin(TestSin):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestRound(OpTest):
def setUp(self):
self.op_type = "round"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.round(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
out = np.round(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Round(TestRound):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestRelu(OpTest):
def setUp(self):
......@@ -278,222 +561,463 @@ class TestFP16Relu(TestRelu):
class TestBRelu(OpTest):
def setUp(self):
self.op_type = "brelu"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
t_min = 1.0
t_max = 4.0
# Same reason as in TestAbs
x[np.abs(x - t_min) < 0.005] = t_min + 0.02
x[np.abs(x - t_max) < 0.005] = t_max + 0.02
self.inputs = {'X': x}
self.attrs = {'t_min': t_min, 't_max': t_max}
t = np.copy(x)
t[t < t_min] = t_min
t[t > t_max] = t_max
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'t_min': t_min, 't_max': t_max}
self.outputs = {'Out': t}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
def init_dtype(self):
pass
class TestFP16BRelu(TestBRelu):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestRelu6(OpTest):
def setUp(self):
self.op_type = "relu6"
x = np.random.uniform(-1, 1, [4, 10]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [4, 10]).astype(self.dtype)
threshold = 6.0
# Same reason as in TestAbs
x[np.abs(x) < 0.005] = 0.02
x[np.abs(x - threshold) < 0.005] = threshold + 0.02
out = np.minimum(np.maximum(x, 0), threshold)
self.inputs = {'X': x}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': threshold}
self.outputs = {
'Out': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
def init_dtype(self):
pass
class TestFP16Relu6(TestRelu6):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSoftRelu(OpTest):
def setUp(self):
self.op_type = "soft_relu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-3, 3, [4, 4]).astype(self.dtype)
threshold = 2.0
# Same reason as in TestAbs
x[np.abs(x - threshold) < 0.005] = threshold + 0.02
x[np.abs(x + threshold) < 0.005] = -threshold + 0.02
self.inputs = {'X': x}
self.attrs = {'threshold': threshold}
t = np.copy(x)
t[t < -threshold] = -threshold
t[t > threshold] = threshold
self.outputs = {'Out': np.log((np.exp(t) + 1))}
out = np.log((np.exp(t) + 1))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': threshold}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
def init_dtype(self):
pass
class TestFP16SoftRelu(TestSoftRelu):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestELU(OpTest):
def setUp(self):
self.op_type = "elu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-3, 3, [4, 4]).astype(self.dtype)
alpha = 1.
out = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
# Note: unlike other Relu extensions, the standard ELU function (i.e. alpha = 1)
# is differentiable at 0, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
self.inputs = {'X': x}
self.attrs = {'alpha': alpha}
self.outputs = {
'Out': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
def init_dtype(self):
pass
class TestFP16ELU(TestELU):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestReciprocal(OpTest):
def setUp(self):
self.op_type = "reciprocal"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.outputs = {'Out': np.reciprocal(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.reciprocal(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.01)
def init_dtype(self):
pass
class TestFP16Reciprocal(TestReciprocal):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestLog(OpTest):
def setUp(self):
self.op_type = "log"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.log(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.log(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Log(TestLog):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSquare(OpTest):
def setUp(self):
self.op_type = "square"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Out': np.square(self.inputs['X'])}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.square(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Square(TestSquare):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestPow(OpTest):
def setUp(self):
self.op_type = "pow"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.power(x, 3)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'factor': 3.0}
self.outputs = {'Out': np.power(self.inputs['X'], 3)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.02)
def init_dtype(self):
pass
class TestFP16Pow(TestPow):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=5e-2)
class TestSTanh(OpTest):
def setUp(self):
self.op_type = "stanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
scale_a = 2.0 / 3.0
scale_b = 1.7159
out = scale_b * np.tanh(x * scale_a)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
self.outputs = {'Out': scale_b * np.tanh(self.inputs['X'] * scale_a)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16STanh(TestSTanh):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSoftplus(OpTest):
def setUp(self):
self.op_type = "softplus"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float64")
}
self.outputs = {'Out': np.log(1 + np.exp(self.inputs['X']))}
self.dtype = np.float64
self.init_dtype()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = np.log(1 + np.exp(x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Softplus(TestSoftplus):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSoftsign(OpTest):
def setUp(self):
self.op_type = "softsign"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {
'Out': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
}
self.dtype = np.float32
self.init_dtype()
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = np.divide(x, 1 + np.abs(x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.007)
def init_dtype(self):
pass
class TestFP16Softsign(TestSoftsign):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestThresholdedRelu(OpTest):
def setUp(self):
self.op_type = "thresholded_relu"
self.dtype = np.float32
self.init_dtype()
threshold = 0.25
self.relative_error = 0.005
X = np.random.uniform(-1, 1, [11, 17]).astype("float32")
X = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# Same reason as TestAbs
X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2
out = (X > threshold) * X
self.inputs = {'X': X}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
self.attrs = {'threshold': threshold}
self.outputs = {'Out': (X > threshold) * X}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=self.relative_error)
def init_dtype(self):
pass
class TestFP16ThresholdedRelu(TestThresholdedRelu):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestHardSigmoid(OpTest):
def setUp(self):
self.op_type = "hard_sigmoid"
self.dtype = np.float32
self.init_dtype()
self.relative_error = 0.002
X = np.random.uniform(-5, 5, [2, 2]).astype("float32")
......@@ -502,7 +1026,6 @@ class TestHardSigmoid(OpTest):
lower_threshold = -offset / slope
upper_threshold = (1 - offset) / slope
self.inputs = {'X': X}
# Same reason as TestAbs
X[np.abs(X - lower_threshold) < self.relative_error] = \
lower_threshold + 0.2
......@@ -510,29 +1033,70 @@ class TestHardSigmoid(OpTest):
upper_threshold - 0.2
temp = X * slope + offset
self.outputs = {'Out': np.maximum(0.0, np.minimum(1.0, temp))}
out = np.maximum(0.0, np.minimum(1.0, temp))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.002)
def init_dtype(self):
pass
class TestFP16HardSigmoid(TestHardSigmoid):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
class TestSwish(OpTest):
def setUp(self):
self.op_type = "swish"
X = np.random.uniform(0.1, 1, [11, 17]).astype("float32")
self.inputs = {'X': X}
self.attrs = {'beta': 2.3}
self.outputs = {'Out': X * expit(self.attrs['beta'] * X)}
self.dtype = np.float32
self.init_dtype()
X = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
beta = 2.3
out = X * expit(beta * X)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(X)}
self.attrs = {'beta': beta}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.008)
def init_dtype(self):
pass
class TestFP16Swish(TestSwish):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
#--------------------test MKLDNN--------------------
class TestMKLDNNReluDim2(TestRelu):
......