Unverified commit 8e731014, authored by Kexin Zhao, committed by GitHub

Merge pull request #9143 from kexinzhao/numpy_conv2d_pool2d_fp16

Add float16 support for cudnn conv2d
paddle/fluid/operators/conv_cudnn_op.cu.cc

@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/platform/assert.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -133,7 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv forward ---------------------
-    T alpha = 1.0f, beta = 0.0f;
+    typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
+                                                          beta = 0.0f;
     for (int i = 0; i < groups; i++) {
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
           handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
@@ -280,7 +282,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
-    T alpha = 1.0f, beta = 0.0f;
+    typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
+                                                          beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
@@ -315,16 +318,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
+namespace plat = paddle::platform;
+REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
-                   paddle::operators::CUDNNConvOpKernel<double>);
-REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   paddle::operators::CUDNNConvOpKernel<double>,
+                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
+REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
                    paddle::operators::CUDNNConvGradOpKernel<double>);
-REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
                    paddle::operators::CUDNNConvOpKernel<double>);
-REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
                    paddle::operators::CUDNNConvGradOpKernel<double>);
paddle/fluid/operators/conv_op.cc

@@ -83,12 +83,23 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif
 
+  auto input_data_type =
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type());
+  auto filter_data_type =
+      framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
+  PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
+                    "input and filter data type should be consistent");
+  if (input_data_type == framework::proto::VarType::FP16) {
+    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
+                      "float16 can only be used when CUDNN is used");
+  }
+
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
+                                 library_);
 }
 
 Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
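The hunk above makes kernel selection dtype-aware: input and filter must share a dtype, and float16 is only dispatched to the cuDNN kernel. A toy Python sketch of that dispatch rule (hypothetical names; the real check lives in the C++ above):

def expected_kernel_type(input_dtype, filter_dtype, library):
    # Mirrors the two PADDLE_ENFORCE_EQ checks in GetExpectedKernelType.
    if input_dtype != filter_dtype:
        raise ValueError("input and filter data type should be consistent")
    if input_dtype == "float16" and library != "CUDNN":
        raise ValueError("float16 can only be used when CUDNN is used")
    return (input_dtype, library)

# expected_kernel_type("float16", "float16", "CUDNN")  -> ok
# expected_kernel_type("float16", "float16", "PLAIN")  -> raises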
paddle/fluid/platform/cudnn_helper.h

@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/macros.h"
 
 namespace paddle {
@@ -80,6 +81,22 @@ enum class PoolingMode {
 template <typename T>
 class CudnnDataType;
 
+template <>
+class CudnnDataType<float16> {
+ public:
+  static const cudnnDataType_t type = CUDNN_DATA_HALF;
+  // The scaling param type is float for HALF and FLOAT tensors
+  typedef const float ScalingParamType;
+  static ScalingParamType* kOne() {
+    static ScalingParamType v = 1.0;
+    return &v;
+  }
+  static ScalingParamType* kZero() {
+    static ScalingParamType v = 0.0;
+    return &v;
+  }
+};
+
 template <>
 class CudnnDataType<float> {
  public:
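The comment in the new specialization notes that cuDNN's alpha/beta scaling parameters stay float even for HALF tensors. A quick numpy illustration of why scaling in half precision would be lossy (a rough motivation, not the cuDNN internals):

import numpy as np

# float16 has a 10-bit mantissa, so at magnitude 2048 its spacing is 2.0:
# scaling or accumulating in half can silently drop small contributions.
x = np.float16(2048.0)
print(x + np.float16(1.0))               # 2048.0 -- the +1 is lost in fp16
print(np.float32(x) + np.float32(1.0))   # 2049.0 -- preserved in fp32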
python/paddle/fluid/tests/unittests/op_test.py

@@ -469,6 +469,28 @@ class OpTest(unittest.TestCase):
         tensor.set_lod(lod)
         return tensor
 
+    @staticmethod
+    def np_dtype_to_fluid_dtype(input):
+        """Change the dtype of a float16 numpy array.
+
+        numpy float16 is bound to paddle::platform::float16
+        in tensor_py.h through the uint16 data type, since
+        the internal memory representation of float16 is
+        uint16_t in Paddle and np.uint16 in numpy, which are
+        themselves bound together by pybind.
+
+        Args:
+            input: input numpy array
+
+        Returns:
+            input: if the dtype of input is np.float16, its dtype is
+                changed to np.uint16 so that the underlying memory is
+                reinterpreted as uint16.
+        """
+        if input.dtype == np.float16:
+            input.dtype = np.uint16
+        return input
+
     def _get_gradient(self, input_to_check, place, output_names, no_grad_set):
         prog = Program()
         block = prog.global_block()
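The reinterpretation the docstring describes can be seen in plain numpy: reassigning the dtype attribute relabels the same buffer without touching its bytes. A minimal sketch:

import numpy as np

a = np.array([1.0, 2.0], dtype=np.float16)
raw = a.tobytes()          # bit pattern of the float16 values
a.dtype = np.uint16        # reinterpret the same buffer: no copy, no conversion
assert a.tobytes() == raw  # bytes unchanged; only the dtype label differs
print(a)                   # [15360 16384], the raw half bit patterns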
python/paddle/fluid/tests/unittests/test_conv2d_op.py

@@ -68,6 +68,7 @@ class TestConv2dOp(OpTest):
         self.init_op_type()
         self.init_group()
         self.init_dilation()
+        self.init_data_type()
         self.init_test_case()
 
         conv2d_param = {
@@ -75,12 +76,16 @@ class TestConv2dOp(OpTest):
             'pad': self.pad,
             'dilation': self.dilations
         }
-        input = np.random.random(self.input_size).astype("float32")
-        filter = np.random.random(self.filter_size).astype("float32")
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.random(self.filter_size).astype(self.dtype)
         output = conv2d_forward_naive(input, filter, self.groups,
-                                      conv2d_param).astype('float32')
+                                      conv2d_param).astype(self.dtype)
 
-        self.inputs = {'Input': input, 'Filter': filter}
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
@@ -99,6 +104,8 @@ class TestConv2dOp(OpTest):
             self.check_output()
 
     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         if self.use_cudnn:
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -111,6 +118,8 @@ class TestConv2dOp(OpTest):
                 set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
 
     def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
         if self.use_cudnn:
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -126,6 +135,8 @@ class TestConv2dOp(OpTest):
                 no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
         if self.use_cudnn:
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -148,6 +159,9 @@ class TestConv2dOp(OpTest):
         f_c = self.input_size[1] / self.groups
         self.filter_size = [6, f_c, 3, 3]
 
+    def init_data_type(self):
+        self.dtype = np.float32
+
     def init_dilation(self):
         self.dilations = [1, 1]
 
@@ -232,36 +246,102 @@ class TestCUDNN(TestConv2dOp):
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNN(TestCUDNN):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestCUDNNWithPad(TestWithPad):
     def init_op_type(self):
         self.use_cudnn = True
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNNWithPad(TestCUDNNWithPad):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestCUDNNWithStride(TestWithStride):
     def init_op_type(self):
         self.use_cudnn = True
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNNWithStride(TestCUDNNWithStride):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestCUDNNWithGroup(TestWithGroup):
     def init_op_type(self):
         self.use_cudnn = True
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNNWithGroup(TestCUDNNWithGroup):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestCUDNNWith1x1(TestWith1x1):
     def init_op_type(self):
         self.use_cudnn = True
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNNWith1x1(TestCUDNNWith1x1):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
     def init_op_type(self):
         self.use_cudnn = True
         self.op_type = "conv2d"
 
 
+class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1):
+    def init_data_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+
 class TestDepthwiseConv(TestConv2dOp):
     def init_test_case(self):
         self.pad = [1, 1]
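The six TestFP16* classes above repeat the same pair of overrides. A condensed sketch of a factory (hypothetical helper, not part of this PR) that would generate the same cases:

import numpy as np
import paddle.fluid.core as core

def create_cudnn_fp16_case(parent, atol=2e-2):
    """Hypothetical factory mirroring the repeated TestFP16* pattern above."""
    class TestFP16Case(parent):
        def init_data_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            # FP16 conv2d only runs on CUDA devices whose cuDNN supports half,
            # and only the forward pass is checked (gradients are skipped).
            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    self.check_output_with_place(place, atol=atol)
    return TestFP16Case

# e.g. TestFP16CUDNNWithPad = create_cudnn_fp16_case(TestCUDNNWithPad)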