提交 c8f101e5 编写于 作者: X xiaolil1 提交者: Tao Luo

Conv int8 relu (#15130)

* Enable basic MKL-DNN INT8 Conv OP
test=develop

* Modify test case
test=develop

* Clean unittest code
test=develop

* Fix test
test=develop

* Modify test
test=develop

* Enable MKL-DNN INT8 Conv with Relu Fusion OP
test=develop

* Modify basic INT8 Conv
test=develop

* fix type
test=develop

* Modify test
test=develop
上级 f3a13512
...@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
...@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dilations[2] == 1 dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet"); "dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<float>::DataType);
}
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
std::string key; std::string key;
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
platform::ConvMKLDNNHandler::AppendKey( platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), ctx.op().Output("Output")); input->format(), dst_dt, ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr; std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
...@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, chosen_memory_format); weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_dt = force_fp32_output
? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<float>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
auto dst_md = auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format); platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
...@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::format::x); memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
output_shift_scale, is_test); fuse_relu, output_shift_scale, is_test);
} else { } else {
conv_pd = conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, paddings, mkldnn_engine, fuse_relu,
mkldnn_engine, output_shift_scale, is_test); output_shift_scale, is_test);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mask_reorder); mask_reorder);
if (!force_fp32_output) { if (!force_fp32_output) {
if (fuse_relu) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
}
} else { } else {
dst_memory_p = platform::SetDstMemory<float>(ctx, output, handler); dst_memory_p = platform::SetDstMemory<float>(ctx, output, handler);
} }
...@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn_engine, key)); mkldnn_engine, key));
} }
if (!force_fp32_output) { if (!force_fp32_output) {
if (fuse_relu) {
dst_memory_p =
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler);
} else {
dst_memory_p = dst_memory_p =
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler); platform::SetDstMemoryHandler<int8_t>(ctx, output, handler);
}
} else { } else {
dst_memory_p = dst_memory_p =
platform::SetDstMemoryHandler<float>(ctx, output, handler); platform::SetDstMemoryHandler<float>(ctx, output, handler);
...@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
mkldnn::primitive_attr CreatePostOps( mkldnn::primitive_attr CreatePostOps(
const std::vector<float> output_shift_scale) const { bool fuse_relu, const std::vector<float> output_shift_scale) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale); conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
} }
...@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const mkldnn::engine& engine, const bool fuse_relu,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
bool is_test) const { bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
...@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero); padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale); mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, output_shift_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst, const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const mkldnn::engine& engine, const bool fuse_relu,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
bool is_test) const { bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
...@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, bias, dst, propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale); mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, output_shift_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
......
...@@ -214,16 +214,18 @@ class MKLDNNHandler { ...@@ -214,16 +214,18 @@ class MKLDNNHandler {
std::string* key, const mkldnn::memory::dims& input_dims, std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides, const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations, const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& type, const int& groups, const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const std::string& suffix) { const mkldnn::memory::format& format,
const mkldnn::memory::data_type& dstdt, const std::string& suffix) {
AppendKeyDims(key, input_dims); AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims); AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides); AppendKeyVec(key, strides);
AppendKeyVec(key, paddings); AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations); AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups)); AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(type)); AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format)); AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(dstdt));
AppendKey(key, suffix); AppendKey(key, suffix);
} }
......
...@@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp):
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
self.init_dtype() self.init_fuse_relu()
self.init_data_type()
conv2d_param = { conv2d_param = {
'stride': self.stride, 'stride': self.stride,
...@@ -78,6 +79,10 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -78,6 +79,10 @@ class TestConv2dInt8Op(TestConv2dOp):
np.round((input_shift) * self.scale_in).astype(np.int32), np.round((input_shift) * self.scale_in).astype(np.int32),
filter_int, self.groups, filter_int, self.groups,
conv2d_param).astype(np.float32) * scale_output_shift conv2d_param).astype(np.float32) * scale_output_shift
if self.fuse_relu:
output = np.maximum(np.round(output1 - output2),
0).astype(self.dsttype)
else:
output = np.round(output1 - output2).astype(self.dsttype) output = np.round(output1 - output2).astype(self.dsttype)
else: else:
filter_int = np.round(filter * filter_int = np.round(filter *
...@@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer( output1 = conv2d_forward_refer(
input.astype(np.int32), filter_int, self.groups, input.astype(np.int32), filter_int, self.groups,
conv2d_param).astype(np.float32) conv2d_param).astype(np.float32)
output = np.round(output1 * scale_output_shift).astype(self.dsttype) if self.fuse_relu:
output = np.maximum(
np.round(output1 * (self.scale_out / (
self.scale_in * self.scale_weights[0]))),
0).astype(self.dsttype)
else:
output = np.round(output1 * (self.scale_out / (
self.scale_in *
self.scale_weights[0]))).astype(self.dsttype)
self.inputs = { self.inputs = {
'Input': 'Input':
...@@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_in': self.scale_in, 'Scale_in': self.scale_in,
'Scale_out': self.scale_out, 'Scale_out': self.scale_out,
'Scale_weights': self.scale_weights, 'Scale_weights': self.scale_weights,
'fuse_relu': self.fuse_relu
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
...@@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp):
self.scale_out = 0.5 self.scale_out = 0.5
self.scale_weights = [10.0] self.scale_weights = [10.0]
def init_dtype(self): def init_data_type(self):
self.srctype = np.uint8 self.srctype = np.uint8
self.dsttype = np.int8 self.dsttype = np.int8
def init_fuse_relu(self):
self.fuse_relu = True
#--------------------test conv2d u8 in and s8 out-------------------- #--------------------test conv2d u8 in and u8 out--------------------
class TestConv2d(TestConv2dInt8Op): class TestConv2d(TestConv2dInt8Op):
...@@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op): ...@@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
self.groups = 3 self.groups = 3
#--------------------test conv2d s8 in and s8 out-------------------- def init_data_type_with_fusion(self, input_dt, fuse_relu):
self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8
def init_fuse_relu(self):
self.fuse_relu = fuse_relu
def create_test_int8_class(parent): def create_test_int8_class(parent):
class TestInt8Case(parent):
def init_dtype(self):
self.srctype = np.int8
self.dsttype = np.int8
cls_name = "{0}_{1}".format(parent.__name__, "s8s8") #--------------------test conv2d s8 in and u8 out--------------------
TestInt8Case.__name__ = cls_name
globals()[cls_name] = TestInt8Case class TestS8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True)
#--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False)
#--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, False)
cls_name_s8u8 = "{0}_relu_{1}".format(parent.__name__, "1")
cls_name_s8s8 = "{0}_relu_{1}".format(parent.__name__, "0")
cls_name_u8s8 = "{0}_relu_{1}".format(parent.__name__, "0")
TestS8U8Case.__name__ = cls_name_s8u8
TestS8S8Case.__name__ = cls_name_s8s8
TestU8S8Case.__name__ = cls_name_u8s8
globals()[cls_name_s8u8] = TestS8U8Case
globals()[cls_name_s8s8] = TestS8S8Case
globals()[cls_name_u8s8] = TestU8S8Case
create_test_int8_class(TestConv2dInt8Op) create_test_int8_class(TestConv2dInt8Op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册