Unverified · Commit 4825addd authored by Sylwester Fraczek, committed by GitHub

Fix conv act int8 scale (#38331)

* fix conv act int8 scale

* add unit test for conv+hard_swish
Parent d296456c
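Note on the fix: before this change, the INT8 output scale (`Scale_out`) was folded into the convolution's output shift scale, so a fused non-linear activation ran on already-scaled values; this patch keeps the convolution output at scale 1.0 and instead passes `Scale_out` as the scale of the activation post-op. The numpy sketch below (illustrative only; the array and `scale_out` values are made up, and `hard_swish` uses the same reference formula as the updated unit test) shows why the two orderings differ for a non-linear activation:

```python
import numpy as np

def hard_swish(x):
    # reference formula, same as in the updated unit test below
    return x / 6. * np.minimum(np.maximum(0., x + 3.), 6.)

conv_out = np.array([-2.0, 0.5, 4.0])  # hypothetical fp32 convolution results
scale_out = 8.0                        # hypothetical Scale_out attribute

old_order = hard_swish(conv_out * scale_out)  # scale applied before the activation
new_order = hard_swish(conv_out) * scale_out  # scale applied to the activation output

print(old_order)  # approx. [-0.    4.   32.  ]
print(new_order)  # approx. [-2.67  2.33 32.  ]
```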
@@ -42,4 +42,4 @@ class FuseFCActOneDNNPass : public FusePassBase {
 }  // namespace ir
 }  // namespace framework
-}  // namespace paddlea
+}  // namespace paddle
@@ -218,13 +218,15 @@ class ConvMKLDNNHandlerT
           : dnnl::prop_kind::forward_training;

     float sum_scale = 1.0f;
+    float activation_scale = 1.0f;
     std::vector<float> output_shift_scale;
     if (platform::is_int8<T>())
-      std::tie(sum_scale, output_shift_scale) = get_int8_scales(ctx);
+      std::tie(sum_scale, output_shift_scale, activation_scale) =
+          get_int8_scales(ctx);

     const dnnl::primitive_attr conv_attr = CreatePostOps(
         fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
-        output_shift_scale, sum_scale);  // for INT8 only!
+        output_shift_scale, sum_scale, activation_scale);  // for INT8 only!

     if (bias) {
       auto bias_tz = framework::vectorize(bias->dims());
@@ -432,7 +434,7 @@ class ConvMKLDNNHandlerT
     return bias_scale_tuple;
   }

-  std::tuple<float, std::vector<float>> get_int8_scales(
+  std::tuple<float, std::vector<float>, float> get_int8_scales(
       const framework::ExecutionContext& ctx) const {
     const auto* filter = ctx.Input<Tensor>("Filter");
     const auto& weights_tz = framework::vectorize(filter->dims());
@@ -445,8 +447,14 @@ class ConvMKLDNNHandlerT
     const auto& scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
     auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
     bool is_multi_channel = scale_weights_data.size() > 1;
+    bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
+    float activation_scale =
+        force_fp32_output ? 1.0f : has_activation ? ctx.Attr<float>("Scale_out")
+                                                  : 1.0f;
     auto scale_out_data =
-        force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
+        force_fp32_output ? 1.0f : has_activation
+                                       ? 1.0f
+                                       : ctx.Attr<float>("Scale_out");
     float sum_scale =
         fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
     int count =
@@ -468,13 +476,13 @@ class ConvMKLDNNHandlerT
               static_cast<double>(scale_weights_data[i])));
     }

-    return std::make_tuple(sum_scale, output_shift_scale);
+    return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
   }

   dnnl::primitive_attr CreatePostOps(
       std::string fuse_activation, float fuse_alpha, float fuse_beta,
       bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
-      float sum_scale = 1.0f) {
+      float sum_scale = 1.0f, float activation_scale = 1.0f) {
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
     if (output_shift_scale.size() > 0) {
@@ -492,30 +500,34 @@ class ConvMKLDNNHandlerT
     }
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
     // PostOps object and configure it to execute an eltwise relu operation.
-    constexpr float scale = 1.0f;
     if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
-                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_relu, fuse_alpha,
+                                     fuse_beta);
     } else if (fuse_activation == "relu6") {
-      post_operations.append_eltwise(
-          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_bounded_relu,
+                                     fuse_alpha, fuse_beta);
     } else if (fuse_activation == "swish") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_swish,
-                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_swish, fuse_alpha,
+                                     fuse_beta);
     } else if (fuse_activation == "hard_swish") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_hardswish,
-                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_hardswish,
+                                     fuse_alpha, fuse_beta);
     } else if (fuse_activation == "hard_sigmoid") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_linear,
-                                     fuse_alpha, fuse_beta);
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f,
-                                     1.0f);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_linear,
+                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
     } else if (fuse_activation == "gelu_tanh") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
-                                     0.0f, 0.0f);
+      post_operations.append_eltwise(
+          activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
     } else if (fuse_activation == "gelu_erf") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
-                                     0.0f, 0.0f);
+      post_operations.append_eltwise(
+          activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
     }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
...
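A minimal sketch of the scale selection in `get_int8_scales` above (a hypothetical Python helper, not part of the patch; `0.5` is just an example `Scale_out` value): when an activation is fused and `force_fp32_output` is off, `Scale_out` moves to the activation post-op and the convolution output keeps a scale of 1.0, and vice versa.

```python
def pick_scales(scale_out_attr, has_activation, force_fp32_output):
    # Restates the nested ternaries from get_int8_scales for readability.
    # Returns (scale_out_data, activation_scale).
    if force_fp32_output:
        return 1.0, 1.0
    if has_activation:
        return 1.0, scale_out_attr   # activation post-op carries Scale_out
    return scale_out_attr, 1.0       # no fused activation: conv output carries it

assert pick_scales(0.5, has_activation=True, force_fp32_output=False) == (1.0, 0.5)
assert pick_scales(0.5, has_activation=False, force_fp32_output=False) == (0.5, 1.0)
assert pick_scales(0.5, has_activation=True, force_fp32_output=True) == (1.0, 1.0)
```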
@@ -426,6 +426,7 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'fc_fuse_pass',
                                  ['use_gpu', 'use_fc_padding'], [False, False])
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
...
@@ -43,7 +43,7 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.init_group()
         self.init_dilation()
         self.init_test_case()
-        self.init_fuse_relu()
+        self.init_fuse_activation()
         self.init_fuse_residual()
         self.init_data_type()
@@ -54,7 +54,9 @@ class TestConv2DInt8Op(TestConv2DOp):
         }
         # This implementation of convolution quantization is based on OneDNN documentation
         # https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#doxid-dev-guide-int8-computations-1dg-i8-comp-s11
-        scale_output_shift = (self.scale_out /
-                              (self.scale_in * self.scale_weights[0]))
+        inner_scale = 1. if self.fuse_activation != "" else self.scale_out
+        activation_scale = self.scale_out if self.fuse_activation != "" else 1.
+        scale_output_shift = (inner_scale /
                               (self.scale_in * self.scale_weights[0]))
         filter = np.random.random(self.filter_size).astype(self.weighttype)
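As a worked example of the reference computation above (the scale values are arbitrary examples, not taken from the test): with a fused activation, `inner_scale` stays at 1.0, so the shift applied before the activation depends only on the input and weight scales, while `Scale_out` is deferred to `activation_scale`.

```python
scale_in, scale_weight, scale_out = 2.0, 4.0, 8.0  # arbitrary example scales
fuse_activation = "hard_swish"

inner_scale = 1. if fuse_activation != "" else scale_out
activation_scale = scale_out if fuse_activation != "" else 1.

scale_output_shift = inner_scale / (scale_in * scale_weight)
print(scale_output_shift)  # 0.125 -> applied to the conv result before the activation
print(activation_scale)    # 8.0   -> applied to the activation output
```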
@@ -78,7 +80,7 @@ class TestConv2DInt8Op(TestConv2DOp):
                 init_low, init_high,
                 self.input_residual_size).astype(self.srctype)
             return (output_ + input_residual_ *
-                    (self.scale_out / self.scale_in_eltwise)), input_residual_
+                    (inner_scale / self.scale_in_eltwise)), input_residual_

         if self.srctype == np.int8:
             init_low, init_high = (-5, 5)
@@ -101,12 +103,24 @@ class TestConv2DInt8Op(TestConv2DOp):
             output, input_residual = residual_helper(init_low, init_high,
                                                      output)

-        output = np.round(output)
-
-        if self.fuse_activation == "relu":
-            output = np.maximum(output, 0)
+        if self.fuse_activation == "":
+            pass
+        elif self.fuse_activation == "relu":
+            output = activation_scale * np.maximum(output, 0)
+        elif self.fuse_activation == "hard_swish":
+            output = activation_scale * output / 6. * np.minimum(
+                np.maximum(0, output + 3.), 6)
+        elif self.fuse_activation == "relu6":
+            output = activation_scale * np.maximum(0, np.minimum(6, output))
+        elif self.fuse_activation == "swish":
+            output = activation_scale * output / (1. + np.exp(-1. * output))
+        elif self.fuse_activation == "leaky_relu":
+            output = activation_scale * np.maximum(output, 0.02 * output)
+        else:
+            raise NotImplementedError("test for " + self.fuse_activation +
+                                      " activation not implemented")

-        output = output.astype(self.dsttype)
+        output = np.round(output).astype(self.dsttype)

         self.inputs = {
             'Input':
@@ -131,6 +145,8 @@ class TestConv2DInt8Op(TestConv2DOp):
             'Scale_weights': self.scale_weights,
             'Scale_in_eltwise': self.scale_in_eltwise,
             'fuse_activation': self.fuse_activation,
+            'fuse_alpha': self.fuse_alpha,
+            'fuse_beta': self.fuse_beta,
             'fuse_residual_connection': self.fuse_residual,
             'mkldnn_data_type': self.mkldnn_data_type
         }
@@ -165,8 +181,10 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.srctype = np.uint8
         self.dsttype = np.int8

-    def init_fuse_relu(self):
+    def init_fuse_activation(self):
         self.fuse_activation = "relu"
+        self.fuse_alpha = 0
+        self.fuse_beta = 0

     def init_fuse_residual(self):
         self.fuse_residual = True
@@ -190,6 +208,34 @@ class TestConv2D(TestConv2DInt8Op):
         self.scale_in_eltwise = 0.6


+class TestWithHardSwish(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "hard_swish"
+        self.fuse_alpha = 0
+        self.fuse_beta = 0
+
+
+class TestWithRelu6(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "relu6"
+        self.fuse_alpha = 6
+        self.fuse_beta = 0
+
+
+class TestWithSwish(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "swish"
+        self.fuse_alpha = 1
+        self.fuse_beta = 0
+
+
+class TestWithLeakyRelu(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "leaky_relu"
+        self.fuse_alpha = 0.02
+        self.fuse_beta = 0
+
+
 class TestWithPad(TestConv2D):
     def init_test_case(self):
         TestConv2D.init_test_case(self)
...