Commit a13ec343 authored by Kexin Zhao

fix test error

Parent e4de5dc3
@@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
-T alpha = static_cast<T>(1.0f);
-T beta = static_cast<T>(0.0f);
+typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
+beta = 0.0f;
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
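This hunk is the substance of the fix: with T = float16, alpha and beta were being built as half-precision values, but for CUDN­N_DATA_HALF (and CUDNN_DATA_FLOAT) tensors cuDNN reads the scaling factors as float. A minimal standalone sketch of the trait pattern, using illustrative names (ScalingType, conv_forward_stub) rather than Paddle's actual headers:

#include <cuda_fp16.h>

// Map each tensor element type to the scaling-factor type cuDNN expects:
// HALF and FLOAT tensors take float alpha/beta, DOUBLE takes double.
template <typename T>
struct ScalingType {
  using type = T;  // float -> float, double -> double
};
template <>
struct ScalingType<__half> {
  using type = float;  // half tensors still scale with float alpha/beta
};

template <typename T>
void conv_forward_stub() {
  typename ScalingType<T>::type alpha = 1.0f, beta = 0.0f;
  // cudnnConvolutionForward(handle, &alpha, ..., &beta, ...) would be
  // called here; only the types of alpha and beta matter for this sketch.
  (void)alpha;
  (void)beta;
}

int main() {
  conv_forward_stub<float>();
  conv_forward_stub<double>();
  conv_forward_stub<__half>();
}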
@@ -321,7 +321,7 @@ namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>,
-paddle::operators::CUDNNConvOpKernel < plat::float16);
+paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
@@ -85,13 +85,14 @@ template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
-typedef const float16 ScalingParamType;
+// The scaling param type is float for HALF and FLOAT tensors
+typedef const float ScalingParamType;
static ScalingParamType* kOne() {
-static ScalingParamType v = static_cast<float16>(1.0);
+static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
-static ScalingParamType v = static_cast<float16>(0.0);
+static ScalingParamType v = 0.0;
return &v;
}
};
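The old typedef was not just a style issue. cuDNN dereferences the alpha/beta pointer with the pointee type fixed by the tensor data type, so kOne() handing over the address of a 2-byte float16 meant cuDNN read 4 bytes as a float, mostly unrelated bits. A small standalone demonstration of that misread (plain C++, not Paddle code):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // half 1.0 is the bit pattern 0x3C00; pad it into 4 bytes the way an
  // over-wide read of a 2-byte object might see them.
  uint32_t half_one_padded = 0x3C00;
  float misread;
  std::memcpy(&misread, &half_one_padded, sizeof(float));
  // Prints roughly 2.1e-41 (a denormal), nowhere near the intended 1.0.
  std::printf("half 1.0 read as float: %g\n", misread);
}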
@@ -79,7 +79,7 @@ class TestConv2dOp(OpTest):
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
-output = conv2d_forward_naive(self.input, self.filter, self.groups,
+output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype(self.dtype)
# numpy float16 is binded to paddle::platform::float16
@@ -88,9 +88,12 @@ class TestConv2dOp(OpTest):
# uint16_t in paddle or np.uint16 in numpy, which are
# themselves binded together.
self.inputs = {
-'Input': input.view(np.uint16)
-if self.dtype == np.float16 else input,
-'Filter': create_view(filter)
+#'Input': (input.view(np.uint16)
+# if self.dtype == np.float16 else input),
+#'Filter': (filter.view(np.uint16)
+# if self.dtype == np.float16 else filter)
+'Input': OpTest.create_view(input),
+'Filter': OpTest.create_view(filter)
}
self.attrs = {
'strides': self.stride,
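The move to create_view relies on paddle::platform::float16 being a plain 2-byte wrapper around the raw binary16 bits, so a numpy float16 array viewed as np.uint16 crosses the Python/C++ boundary byte-identically. A hedged sketch of that bit-level equivalence, with a minimal stand-in struct rather than Paddle's real float16:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for paddle::platform::float16: nothing but the binary16 bits.
struct float16 {
  uint16_t x;
};
static_assert(sizeof(float16) == 2, "float16 must be exactly 2 bytes");

int main() {
  // numpy stores half 1.0 as 0x3C00; a np.uint16 view ships those bits
  // through the op's input buffer unchanged.
  uint16_t numpy_view_elem = 0x3C00;
  float16 h;
  std::memcpy(&h, &numpy_view_elem, sizeof(h));
  std::printf("received bits: 0x%04X\n", h.x);  // 0x3C00 round-trips intact
}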
@@ -254,7 +257,7 @@ class TestFP16CUDNN(TestCUDNN):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
-self.check_output_with_place(place, atol=1e-1)
+self.check_output_with_place(place, atol=2e-2)
def test_check_grad(self):
pass
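Tightening atol from 1e-1 to 2e-2 matches what half precision can deliver: binary16 has a 10-bit mantissa, so one ulp at 1.0 is about 1e-3, and a conv output that accumulates hundreds of fp16 products can drift by roughly 1e-3 to 1e-2. A quick standalone sanity check of that bound (plain arithmetic, not the test framework):

#include <cmath>
#include <cstdio>

int main() {
  // One ulp at 1.0 in binary16 (10-bit mantissa) is 2^-10.
  double half_ulp = std::ldexp(1.0, -10);  // ~9.77e-4
  // A 10-20x accumulation of this stays inside atol=2e-2, while a
  // float-level tolerance such as 1e-5 would spuriously fail.
  std::printf("binary16 ulp at 1.0: %g\n", half_ulp);
}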