提交 a13ec343 编写于 作者: K Kexin Zhao

fix test error

上级 e4de5dc3
...@@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace()); platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
T alpha = static_cast<T>(1.0f); typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
T beta = static_cast<T>(0.0f); beta = 0.0f;
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
...@@ -321,7 +321,7 @@ namespace plat = paddle::platform; ...@@ -321,7 +321,7 @@ namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>, paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>, paddle::operators::CUDNNConvOpKernel<double>,
paddle::operators::CUDNNConvOpKernel < plat::float16); paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>, paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>); paddle::operators::CUDNNConvGradOpKernel<double>);
......
...@@ -85,13 +85,14 @@ template <> ...@@ -85,13 +85,14 @@ template <>
class CudnnDataType<float16> { class CudnnDataType<float16> {
public: public:
static const cudnnDataType_t type = CUDNN_DATA_HALF; static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const float16 ScalingParamType; // The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
static ScalingParamType* kOne() { static ScalingParamType* kOne() {
static ScalingParamType v = static_cast<float16>(1.0); static ScalingParamType v = 1.0;
return &v; return &v;
} }
static ScalingParamType* kZero() { static ScalingParamType* kZero() {
static ScalingParamType v = static_cast<float16>(0.0); static ScalingParamType v = 0.0;
return &v; return &v;
} }
}; };
......
...@@ -79,7 +79,7 @@ class TestConv2dOp(OpTest): ...@@ -79,7 +79,7 @@ class TestConv2dOp(OpTest):
input = np.random.random(self.input_size).astype(self.dtype) input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv2d_forward_naive(self.input, self.filter, self.groups, output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype(self.dtype) conv2d_param).astype(self.dtype)
# numpy float16 is binded to paddle::platform::float16 # numpy float16 is binded to paddle::platform::float16
...@@ -88,9 +88,12 @@ class TestConv2dOp(OpTest): ...@@ -88,9 +88,12 @@ class TestConv2dOp(OpTest):
# uint16_t in paddle or np.uint16 in numpy, which are # uint16_t in paddle or np.uint16 in numpy, which are
# themselves binded together. # themselves binded together.
self.inputs = { self.inputs = {
'Input': input.view(np.uint16) #'Input': (input.view(np.uint16)
if self.dtype == np.float16 else input, # if self.dtype == np.float16 else input),
'Filter': create_view(filter) #'Filter': (filter.view(np.uint16)
# if self.dtype == np.float16 else filter)
'Input': OpTest.create_view(input),
'Filter': OpTest.create_view(filter)
} }
self.attrs = { self.attrs = {
'strides': self.stride, 'strides': self.stride,
...@@ -254,7 +257,7 @@ class TestFP16CUDNN(TestCUDNN): ...@@ -254,7 +257,7 @@ class TestFP16CUDNN(TestCUDNN):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-1) self.check_output_with_place(place, atol=2e-2)
def test_check_grad(self): def test_check_grad(self):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册