diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index ff0fbf21f86269885df5491afab7443df813f13f..0ddbfdb4aa9e844adbb291e1c5612e96681831d6 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -133,7 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv forward --------------------- - T alpha = 1.0f, beta = 0.0f; + typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, + beta = 0.0f; for (int i = 0; i < groups; i++) { PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, @@ -280,7 +282,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- - T alpha = 1.0f, beta = 0.0f; + typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, + beta = 0.0f; if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -315,16 +318,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace, +namespace plat = paddle::platform; +REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, - paddle::operators::CUDNNConvOpKernel); -REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace, + paddle::operators::CUDNNConvOpKernel, + paddle::operators::CUDNNConvOpKernel); +REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel); -REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvOpKernel, paddle::operators::CUDNNConvOpKernel); -REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, paddle::operators::CUDNNConvGradOpKernel, paddle::operators::CUDNNConvGradOpKernel); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 4b02b80d7772fa15d2333692551da5e59d93765f..e3fc21c90f95469d646139a4454501d1c30bd51c 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -83,12 +83,23 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } #endif + auto input_data_type = + framework::ToDataType(ctx.Input("Input")->type()); + auto filter_data_type = + framework::ToDataType(ctx.Input("Filter")->type()); + PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, + "input and filter data type should be consistent"); + + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, + "float16 can only be used when CUDNN is used"); + } + std::string data_format = ctx.Attr("data_format"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, + library_); } Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 9a2ac3ff33df3f8b9e24203f9dba2130e1d16510..7e001ecc56173db76e8c576e7efd66f41192f292 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/macros.h" namespace paddle { @@ -80,6 +81,22 @@ enum class PoolingMode { template class CudnnDataType; +template <> +class CudnnDataType { + public: + static const cudnnDataType_t type = CUDNN_DATA_HALF; + // The scaling param type is float for HALF and FLOAT tensors + typedef const float ScalingParamType; + static ScalingParamType* kOne() { + static ScalingParamType v = 1.0; + return &v; + } + static ScalingParamType* kZero() { + static ScalingParamType v = 0.0; + return &v; + } +}; + template <> class CudnnDataType { public: diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index f7e02595ec3b41ae7bb32353c258736968ca78d4..6a42f763a6c436f5d33569dc65f711c2930a8b2e 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -469,6 +469,28 @@ class OpTest(unittest.TestCase): tensor.set_lod(lod) return tensor + @staticmethod + def np_dtype_to_fluid_dtype(input): + """Change the dtype of float16 numpy array + + numpy float16 is binded to paddle::platform::float16 + in tensor_py.h via the help of uint16 data type since + the internal memory representation of float16 is + uint16_t in paddle and np.uint16 in numpy, which are + themselves binded together by pybind. + + Args: + input: input numpy array + + Returns: + input: if the dtype of input is np.float16, its dtype will be + changed to np.uint16 so that the internal memory will be + reinterpreted input as of dtype np.uint16. + """ + if input.dtype == np.float16: + input.dtype = np.uint16 + return input + def _get_gradient(self, input_to_check, place, output_names, no_grad_set): prog = Program() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index a49fecf09509f7b1d9f758eebcf90bf9fbf7669f..7913b98240fbd0aaa8d94911e68f61a0e416e7c8 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -68,6 +68,7 @@ class TestConv2dOp(OpTest): self.init_op_type() self.init_group() self.init_dilation() + self.init_data_type() self.init_test_case() conv2d_param = { @@ -75,12 +76,16 @@ class TestConv2dOp(OpTest): 'pad': self.pad, 'dilation': self.dilations } - input = np.random.random(self.input_size).astype("float32") - filter = np.random.random(self.filter_size).astype("float32") + + input = np.random.random(self.input_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) output = conv2d_forward_naive(input, filter, self.groups, - conv2d_param).astype('float32') + conv2d_param).astype(self.dtype) - self.inputs = {'Input': input, 'Filter': filter} + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } self.attrs = { 'strides': self.stride, 'paddings': self.pad, @@ -99,6 +104,8 @@ class TestConv2dOp(OpTest): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( @@ -111,6 +118,8 @@ class TestConv2dOp(OpTest): set(['Input', 'Filter']), 'Output', max_relative_error=0.02) def test_check_grad_no_filter(self): + if self.dtype == np.float16: + return if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( @@ -126,6 +135,8 @@ class TestConv2dOp(OpTest): no_grad_set=set(['Filter'])) def test_check_grad_no_input(self): + if self.dtype == np.float16: + return if self.use_cudnn: place = core.CUDAPlace(0) self.check_grad_with_place( @@ -148,6 +159,9 @@ class TestConv2dOp(OpTest): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] + def init_data_type(self): + self.dtype = np.float32 + def init_dilation(self): self.dilations = [1, 1] @@ -232,36 +246,102 @@ class TestCUDNN(TestConv2dOp): self.op_type = "conv2d" +class TestFP16CUDNN(TestCUDNN): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestCUDNNWithPad(TestWithPad): def init_op_type(self): self.use_cudnn = True self.op_type = "conv2d" +class TestFP16CUDNNWithPad(TestCUDNNWithPad): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestCUDNNWithStride(TestWithStride): def init_op_type(self): self.use_cudnn = True self.op_type = "conv2d" +class TestFP16CUDNNWithStride(TestCUDNNWithStride): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestCUDNNWithGroup(TestWithGroup): def init_op_type(self): self.use_cudnn = True self.op_type = "conv2d" +class TestFP16CUDNNWithGroup(TestCUDNNWithGroup): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestCUDNNWith1x1(TestWith1x1): def init_op_type(self): self.use_cudnn = True self.op_type = "conv2d" +class TestFP16CUDNNWith1x1(TestCUDNNWith1x1): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): def init_op_type(self): self.use_cudnn = True self.op_type = "conv2d" +class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1): + def init_data_type(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-2) + + class TestDepthwiseConv(TestConv2dOp): def init_test_case(self): self.pad = [1, 1]