Commit e4de5dc3 authored by Kexin Zhao

add conv2d fp16 support

Parent 1cd700d8
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -133,7 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
T alpha = 1.0f, beta = 0.0f;
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
......@@ -315,16 +317,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvOpKernel<float>,
                   paddle::operators::CUDNNConvOpKernel<double>);
                   paddle::operators::CUDNNConvOpKernel<double>,
                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvGradOpKernel<float>,
                   paddle::operators::CUDNNConvGradOpKernel<double>);
REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
......@@ -83,12 +83,23 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Input")->type());
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
}
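The kernel-type selection above enforces two rules: input and filter must share a data type, and float16 is only ever dispatched to the cuDNN library. As an editor's illustration only, here is a minimal Python sketch of that dispatch rule (the helper name and string tags are hypothetical, not the real C++ API):

import paddle  # not required; sketch is framework-independent

# Hypothetical restatement of the checks in ConvOp::GetExpectedKernelType.
def expected_conv_kernel(input_dtype, filter_dtype, library):
    if input_dtype != filter_dtype:
        raise ValueError("input and filter data type should be consistent")
    if input_dtype == "float16" and library != "CUDNN":
        raise ValueError("float16 can only be used when CUDNN is used")
    return (input_dtype, library)

assert expected_conv_kernel("float16", "float16", "CUDNN") == ("float16", "CUDNN")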
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -80,6 +81,21 @@ enum class PoolingMode {
template <typename T>
class CudnnDataType;
template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const float16 ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = static_cast<float16>(1.0);
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = static_cast<float16>(0.0);
return &v;
}
};
template <>
class CudnnDataType<float> {
public:
......
......@@ -469,6 +469,31 @@ class OpTest(unittest.TestCase):
tensor.set_lod(lod)
return tensor
@staticmethod
def create_view(input):
"""Create a view of the input numpy array
numpy float16 is binded to paddle::platform::float16
in tensor_py.h via the help of numpy uint16 because
the internal memory representation of float16 is
uint16_t in paddle or np.uint16 in numpy, which are
themselves binded together.
Args:
input: input numpy array
Returns:
input_view: if the dtype of input is np.float16, input_view
will reinterpret input as with dtype np.uint16.
Otherwise, input_view will be input itself.
"""
if input.dtype == np.float16:
# view will only reinterpret memory without copying
input_view = input.view(np.uint16)
else:
input_view = input
return input_view
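The docstring above describes a pure reinterpretation of memory. A small self-contained numpy demo (independent of Paddle) shows that the view shares storage with the original array and preserves the raw half-precision bits:

import numpy as np

a = np.array([1.0, 2.0], dtype=np.float16)
b = a.view(np.uint16)           # reinterpret the same buffer, no copy
assert b.dtype == np.uint16
assert b[0] == 0x3C00           # IEEE 754 half-precision bit pattern of 1.0
b[0] = 0x4000                   # writing through the view...
assert a[0] == np.float16(2.0)  # ...is visible in the original float16 array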
def _get_gradient(self, input_to_check, place, output_names, no_grad_set):
prog = Program()
block = prog.global_block()
......
......@@ -68,6 +68,7 @@ class TestConv2dOp(OpTest):
self.init_op_type()
self.init_group()
self.init_dilation()
self.init_data_type()
self.init_test_case()
conv2d_param = {
......@@ -75,12 +76,22 @@ class TestConv2dOp(OpTest):
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32")
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype('float32')
self.inputs = {'Input': input, 'Filter': filter}
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype(self.dtype)
# numpy float16 is bound to paddle::platform::float16
# in tensor_py.h with the help of numpy uint16 because
# the internal memory representation of float16 is
# uint16_t in paddle and np.uint16 in numpy, and the
# two are themselves bound together.
self.inputs = {
'Input': self.create_view(input),
'Filter': self.create_view(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
......@@ -148,6 +159,9 @@ class TestConv2dOp(OpTest):
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3]
def init_data_type(self):
self.dtype = np.float32
def init_dilation(self):
self.dilations = [1, 1]
......@@ -232,6 +246,26 @@ class TestCUDNN(TestConv2dOp):
self.op_type = "conv2d"
class TestFP16CUDNN(TestCUDNN):
def init_data_type(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-1)
def test_check_grad(self):
pass
def test_check_grad_no_filter(self):
pass
def test_check_grad_no_input(self):
pass
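The relaxed atol=1e-1 in test_check_output above reflects the coarse resolution of half precision; a quick numpy check (not part of the test suite) quantifies why a tolerance this loose is reasonable:

import numpy as np

info = np.finfo(np.float16)
print(info.eps)        # 0.000977: spacing between representable values at 1.0
print(info.precision)  # 3: roughly three reliable decimal digits
# A convolution reduces over many products, so rounding error accumulates
# well past eps, which is why the fp16 output check uses atol=1e-1.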
class TestCUDNNWithPad(TestWithPad):
def init_op_type(self):
self.use_cudnn = True
......