提交 c8f101e5 编写于 作者: X xiaolil1 提交者: Tao Luo

Conv int8 relu (#15130)

* Enable basic MKL-DNN INT8 Conv OP
test=develop

* Modify test case
test=develop

* Clean unittest code
test=develop

* Fix test
test=develop

* Modify test
test=develop

* Enable MKL-DNN INT8 Conv with Relu Fusion OP
test=develop

* Modify basic INT8 Conv
test=develop

* fix type
test=develop

* Modify test
test=develop
上级 f3a13512
......@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool is_conv3d = strides.size() == 3U;
......@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
const T* input_data = input->data<T>();
......@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<float>::DataType);
}
// Get unique name for storing MKLDNN primitives
std::string key;
key.reserve(MaxKeyLength);
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), ctx.op().Output("Output"));
input->format(), dst_dt, ctx.op().Output("Output"));
const std::string key_conv_pd = key + "@conv_pd";
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
......@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_dt = force_fp32_output
? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<float>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
......@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
output_shift_scale, is_test);
fuse_relu, output_shift_scale, is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, output_shift_scale, is_test);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu,
output_shift_scale, is_test);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mask_reorder);
if (!force_fp32_output) {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
if (fuse_relu) {
dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
} else {
dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
}
} else {
dst_memory_p = platform::SetDstMemory<float>(ctx, output, handler);
}
......@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn_engine, key));
}
if (!force_fp32_output) {
dst_memory_p =
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler);
if (fuse_relu) {
dst_memory_p =
platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler);
} else {
dst_memory_p =
platform::SetDstMemoryHandler<int8_t>(ctx, output, handler);
}
} else {
dst_memory_p =
platform::SetDstMemoryHandler<float>(ctx, output, handler);
......@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
mkldnn::primitive_attr CreatePostOps(
const std::vector<float> output_shift_scale) const {
bool fuse_relu, const std::vector<float> output_shift_scale) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_relu) {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
......@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine,
const mkldnn::engine& engine, const bool fuse_relu,
const std::vector<float> output_shift_scale,
bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
......@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, output_shift_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine,
const mkldnn::engine& engine, const bool fuse_relu,
const std::vector<float> output_shift_scale,
bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
......@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, output_shift_scale);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......
......@@ -214,16 +214,18 @@ class MKLDNNHandler {
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& type,
const mkldnn::memory::format& format, const std::string& suffix) {
const int& groups, const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format,
const mkldnn::memory::data_type& dstdt, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides);
AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(type));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(dstdt));
AppendKey(key, suffix);
}
......
......@@ -47,7 +47,8 @@ class TestConv2dInt8Op(TestConv2dOp):
self.init_group()
self.init_dilation()
self.init_test_case()
self.init_dtype()
self.init_fuse_relu()
self.init_data_type()
conv2d_param = {
'stride': self.stride,
......@@ -78,7 +79,11 @@ class TestConv2dInt8Op(TestConv2dOp):
np.round((input_shift) * self.scale_in).astype(np.int32),
filter_int, self.groups,
conv2d_param).astype(np.float32) * scale_output_shift
output = np.round(output1 - output2).astype(self.dsttype)
if self.fuse_relu:
output = np.maximum(np.round(output1 - output2),
0).astype(self.dsttype)
else:
output = np.round(output1 - output2).astype(self.dsttype)
else:
filter_int = np.round(filter *
self.scale_weights[0]).astype(np.int32)
......@@ -87,7 +92,15 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer(
input.astype(np.int32), filter_int, self.groups,
conv2d_param).astype(np.float32)
output = np.round(output1 * scale_output_shift).astype(self.dsttype)
if self.fuse_relu:
output = np.maximum(
np.round(output1 * (self.scale_out / (
self.scale_in * self.scale_weights[0]))),
0).astype(self.dsttype)
else:
output = np.round(output1 * (self.scale_out / (
self.scale_in *
self.scale_weights[0]))).astype(self.dsttype)
self.inputs = {
'Input':
......@@ -106,6 +119,7 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_in': self.scale_in,
'Scale_out': self.scale_out,
'Scale_weights': self.scale_weights,
'fuse_relu': self.fuse_relu
}
self.outputs = {'Output': output}
......@@ -129,12 +143,15 @@ class TestConv2dInt8Op(TestConv2dOp):
self.scale_out = 0.5
self.scale_weights = [10.0]
def init_dtype(self):
def init_data_type(self):
self.srctype = np.uint8
self.dsttype = np.int8
def init_fuse_relu(self):
self.fuse_relu = True
#--------------------test conv2d u8 in and s8 out--------------------
#--------------------test conv2d u8 in and u8 out--------------------
class TestConv2d(TestConv2dInt8Op):
......@@ -203,18 +220,43 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
self.groups = 3
#--------------------test conv2d s8 in and s8 out--------------------
def init_data_type_with_fusion(self, input_dt, fuse_relu):
self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8
def init_fuse_relu(self):
self.fuse_relu = fuse_relu
def create_test_int8_class(parent):
class TestInt8Case(parent):
def init_dtype(self):
self.srctype = np.int8
self.dsttype = np.int8
cls_name = "{0}_{1}".format(parent.__name__, "s8s8")
TestInt8Case.__name__ = cls_name
globals()[cls_name] = TestInt8Case
#--------------------test conv2d s8 in and u8 out--------------------
class TestS8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True)
#--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False)
#--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, False)
cls_name_s8u8 = "{0}_relu_{1}".format(parent.__name__, "1")
cls_name_s8s8 = "{0}_relu_{1}".format(parent.__name__, "0")
cls_name_u8s8 = "{0}_relu_{1}".format(parent.__name__, "0")
TestS8U8Case.__name__ = cls_name_s8u8
TestS8S8Case.__name__ = cls_name_s8s8
TestU8S8Case.__name__ = cls_name_u8s8
globals()[cls_name_s8u8] = TestS8U8Case
globals()[cls_name_s8s8] = TestS8S8Case
globals()[cls_name_u8s8] = TestU8S8Case
create_test_int8_class(TestConv2dInt8Op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册