Unverified commit 4e62af80 authored by cc, committed by GitHub

Add FP16 PRelu (#35532)

Parent afd1b372
@@ -33,7 +33,8 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
size_t channel_index = temp % channel_num;
T scale = alpha[channel_index];
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
@@ -45,7 +46,8 @@ __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
size_t element_index = index % spatial_size;
T scale = alpha[element_index];
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
@@ -55,7 +57,8 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
T scale = alpha[0];
CUDA_KERNEL_LOOP(index, numel) {
T x = input[index];
output[index] = (x > 0) ? x : scale * x;
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
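The three forward kernels above all switch from comparing against the int literal 0 to comparing against a T-typed zero. Below is a minimal standalone sketch of that pattern, not the Paddle kernel itself (the name PReluScalarSketch and the plain grid-stride loop are illustrative): producing the zero with static_cast<T>(0) keeps the comparison well-formed when T is a half-precision type that may not convert implicitly from an int literal.

// Sketch only: a scalar-alpha PRelu kernel templated on the element type T.
template <typename T>
__global__ void PReluScalarSketch(const T* x, const T* alpha, T* y,
                                  size_t numel) {
  T scale = alpha[0];
  T zero = static_cast<T>(0);  // T supplies its own zero for the comparison
  // Grid-stride loop, roughly what CUDA_KERNEL_LOOP expands to.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += static_cast<size_t>(blockDim.x) * gridDim.x) {
    T v = x[i];
    y[i] = (v > zero) ? v : scale * v;
  }
}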
@@ -88,12 +91,15 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
}
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<paddle::platform::float16>;
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<paddle::platform::float16>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<paddle::platform::float16>;
template class PreluScalarDirectCUDAFunctor<double>;
} // namespace math
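The instantiation list above now includes paddle::platform::float16 for each functor. A hypothetical sketch of why these lines are needed: when a template's definition is compiled only inside the .cu file, every element type used by callers must be explicitly instantiated there so its symbols exist at link time. ChannelScaleSketch below is an invented stand-in, not a Paddle class.

#include <cstddef>

// Sketch: the definition lives in this translation unit only.
template <typename T>
class ChannelScaleSketch {
 public:
  void operator()(const T* in, const T* alpha, T* out, size_t n) const {
    for (size_t i = 0; i < n; ++i) out[i] = alpha[0] * in[i];
  }
};

// Explicit instantiations make the float and double symbols visible to the
// linker; mirroring this commit, an FP16 build would add one more line:
template class ChannelScaleSketch<float>;
template class ChannelScaleSketch<double>;
// template class ChannelScaleSketch<paddle::platform::float16>;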
......
@@ -87,8 +87,9 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
}
T x = x_ptr[index];
T dy = dy_ptr[index];
if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy;
if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy;
T zero = static_cast<T>(0);
if (dx_ptr != nullptr) dx_ptr[index] = (x > zero) ? dy : scale * dy;
if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > zero) ? zero : x * dy;
}
}
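The gradient kernel applies the usual PRelu backward rule per element: dx is dy when x > 0 and alpha * dy otherwise, while the per-element alpha gradient is 0 when x > 0 and x * dy otherwise, to be summed over the broadcast dimensions later (the TensorReduce call further down). A minimal sketch of that rule, with a hypothetical name and no claim about the real kernel's indexing:

// Sketch only: elementwise PRelu backward for one value.
template <typename T>
__host__ __device__ void PReluGradSketch(T x, T dy, T alpha, T* dx,
                                         T* dalpha) {
  T zero = static_cast<T>(0);
  *dx = (x > zero) ? dy : alpha * dy;        // pass-through vs. scaled grad
  *dalpha = (x > zero) ? zero : x * dy;      // contribution reduced later
}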
@@ -112,9 +113,11 @@ class PreluOpGradFunctor {
}
};
template <typename T>
struct IdentityFunctor {
HOSTDEVICE inline T operator()(const T& x) const { return x; }
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
@@ -174,9 +177,9 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims.push_back(i);
}
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dalpha_tmp, dalpha, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
IdentityFunctor(), stream);
}
};
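Two related changes appear above: IdentityFunctor now templates its call operator instead of the struct, and the TensorReduce call names it as plain IdentityFunctor rather than IdentityFunctor<T>. A standalone sketch of the idea, with the HOSTDEVICE qualifier omitted and a hypothetical name:

// Sketch: a generic call operator lets the functor type be named without a
// type argument (e.g. SomeReduce<IdentityFunctorSketch>), and one object
// handles every element type.
struct IdentityFunctorSketch {
  template <typename T>
  T operator()(const T& x) const {
    return x;
  }
};

// Usage: the same functor instance works for float and double operands.
// IdentityFunctorSketch id;  float a = id(1.5f);  double b = id(2.5);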
@@ -184,10 +187,14 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
prelu, ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, float>,
ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
prelu_grad,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -153,11 +153,12 @@ class TestNNPReluAPI(unittest.TestCase):
class PReluTest(OpTest):
def setUp(self):
self.init_dtype()
self.init_input_shape()
self.init_attr()
self.op_type = "prelu"
x_np = np.random.uniform(-1, 1, self.x_shape)
x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
x_np[np.abs(x_np) < 0.005] = 0.02
@@ -168,6 +169,7 @@ class PReluTest(OpTest):
alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
alpha_np = alpha_np.astype(self.dtype)
self.inputs = {'X': x_np, 'Alpha': alpha_np}
@@ -184,6 +186,9 @@ class PReluTest(OpTest):
assert out_np is not self.inputs['X']
self.outputs = {'Out': out_np}
def init_dtype(self):
self.dtype = np.float64
def init_input_shape(self):
self.x_shape = [2, 100, 3, 4]
@@ -270,6 +275,44 @@ class TestModeElementRank6(PReluTest):
self.attrs = {'mode': "element"}
def create_test_fp16_class(parent,
check_grad=True,
atol=1e-3,
max_relative_error=0.05):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestPReluFp16Case(parent):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and check_grad:
self.check_grad_with_place(
place, ['X', 'Alpha'],
'Out',
max_relative_error=max_relative_error)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
TestPReluFp16Case.__name__ = cls_name
globals()[cls_name] = TestPReluFp16Case
create_test_fp16_class(TestModeElt)
create_test_fp16_class(TestModeAllRank3)
create_test_fp16_class(TestModeAllRank6)
create_test_fp16_class(TestModeChannelRank3)
create_test_fp16_class(TestModeChannelRank6)
create_test_fp16_class(TestModeElementRank3)
create_test_fp16_class(TestModeElementRank6)
def prelu_t(x, mode, param_attr=None, name=None):
helper = fluid.layer_helper.LayerHelper('prelu', **locals())
alpha_shape = [1, x.shape[1], 1, 1]
......