Unverified commit 8f800dc0, authored by Sylwester Fraczek, committed by GitHub

add map_matmul and fc_act_fuse passes to quant2_int8_mkldnn_pass (#38023)

* add map_matmul passes to quant2_int8_mkldnn_pass

* fix fc+act fuse (activation scale)

* CI fix: C++17 structured bindings are not available, so unpack the tuple with std::tie

* fix ci static check
Parent f8202941
@@ -403,26 +403,36 @@ class FCPrimitiveFactory {
   // scaled with its own scales, this data needs to be divided by
   // those scales to normalise them back to what their floating-point range
   // was. Then we multiply them by desired output scale we want on the output.
-  std::vector<float> ComputeOutputShiftScale(const ExecutionContext& ctx) {
+  std::tuple<std::vector<float>, float> ComputeOutputShiftScale(
+      const ExecutionContext& ctx) {
     auto scale_in_data = ctx.Attr<float>("Scale_in");
     auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
     // If the output will be in floats, we don't multiply by scale_out.
-    auto scale_out_data = ctx.Attr<bool>("force_fp32_output")
-                              ? 1.0f
-                              : ctx.Attr<float>("Scale_out");
+    float activation_scale = 1.0f;
+    float inner_scale = 1.0f;
+    if (!ctx.Attr<bool>("force_fp32_output")) {
+      // If there is a fused activation, use its scale; otherwise use the
+      // inner scale.
+      if (!ctx.Attr<std::string>("activation_type").empty()) {
+        activation_scale = ctx.Attr<float>("Scale_out");
+      } else {
+        inner_scale = ctx.Attr<float>("Scale_out");
+      }
+    }
     const size_t weight_scales_num = scale_weights_data.size();
     std::vector<float> output_shift_scale(weight_scales_num);
 #pragma omp parallel for
     for (size_t i = 0; i < weight_scales_num; i++) {
       if (scale_weights_data[i] == 0.0)
-        output_shift_scale[i] = scale_out_data;
+        output_shift_scale[i] = inner_scale;
       else
         output_shift_scale[i] =
-            scale_out_data / (scale_in_data * scale_weights_data[i]);
+            inner_scale / (scale_in_data * scale_weights_data[i]);
     }
-    return output_shift_scale;
+    return make_tuple(output_shift_scale, activation_scale);
   }
 
   // Computing MKL-DNN's scaling mask which determines along which dimension
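The effect of this hunk: when an activation is going to be fused, `Scale_out` is no longer folded into the per-channel output scales; it is returned separately so it can be attached to the eltwise post-op instead. Below is a minimal standalone sketch of that arithmetic, with hypothetical names and plain inputs rather than Paddle's `ExecutionContext`:

```cpp
#include <cstdio>
#include <tuple>
#include <vector>

// Simplified model of the patched ComputeOutputShiftScale: Scale_out goes
// either into the per-channel scales (no activation) or into the separate
// activation scale (activation fused), never into both.
std::tuple<std::vector<float>, float> SplitOutputScales(
    float scale_in, const std::vector<float>& scale_weights, float scale_out,
    bool has_activation, bool force_fp32_output) {
  float activation_scale = 1.0f;
  float inner_scale = 1.0f;
  if (!force_fp32_output) {
    if (has_activation) {
      activation_scale = scale_out;
    } else {
      inner_scale = scale_out;
    }
  }
  std::vector<float> output_shift_scale(scale_weights.size());
  for (size_t i = 0; i < scale_weights.size(); ++i) {
    output_shift_scale[i] = scale_weights[i] == 0.0f
                                ? inner_scale
                                : inner_scale / (scale_in * scale_weights[i]);
  }
  return std::make_tuple(output_shift_scale, activation_scale);
}

int main() {
  auto result = SplitOutputScales(/*scale_in=*/127.0f, {63.5f, 127.0f},
                                  /*scale_out=*/255.0f,
                                  /*has_activation=*/true,
                                  /*force_fp32_output=*/false);
  // With an activation fused, the per-channel scales stay "inner" and the
  // whole Scale_out lands on the activation post-op.
  for (float s : std::get<0>(result)) std::printf("%g ", s);
  std::printf("| activation scale: %g\n", std::get<1>(result));
  return 0;
}
```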
@@ -449,48 +459,43 @@ class FCPrimitiveFactory {
     dnnl::primitive_attr attributes;
     dnnl::post_ops post_operations;
-    auto output_shift_scale = ComputeOutputShiftScale(ctx);
+    std::vector<float> output_shift_scale;
+    float scale;
+    std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
     int mask = CreateMask(1, output_shift_scale.size() > 1);
     attributes.set_output_scales(mask, output_shift_scale);
     if (ctx.Attr<std::string>("activation_type") == "relu") {
-      constexpr float scale = 1.0f;
       constexpr float negative_slope = 0.0f;
       constexpr float placeholder = 1.0f;  // beta
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
                                      negative_slope, placeholder);
     } else if (ctx.Attr<std::string>("activation_type") == "gelu") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu,
                                      alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
                                      alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
                                      alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "tanh") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_tanh,
                                      alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_logistic,
                                      alpha, beta);
     } else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
-      constexpr float scale = 1.0f;
       constexpr float alpha = 0.0f;
       constexpr float beta = 0.0f;
       post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_hardswish,
...
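The `std::tie` dance above (declare `output_shift_scale` and `scale` first, then unpack) is the workaround mentioned in the commit message: C++17 structured bindings are not available on the CI toolchain. A small self-contained comparison of the two styles (generic example, not Paddle code):

```cpp
#include <cstdio>
#include <tuple>
#include <vector>

// Stand-in for ComputeOutputShiftScale: returns per-channel scales plus a
// separate activation scale.
std::tuple<std::vector<float>, float> MakeScales() {
  return std::make_tuple(std::vector<float>{0.5f, 0.25f}, 2.0f);
}

int main() {
  // C++17: structured bindings declare and unpack in a single statement.
  // auto [output_shift_scale, scale] = MakeScales();

  // Pre-C++17 (what the patch uses): declare the variables up front and
  // unpack the returned tuple into them with std::tie.
  std::vector<float> output_shift_scale;
  float scale;
  std::tie(output_shift_scale, scale) = MakeScales();

  std::printf("%zu per-channel scales, post-op scale %g\n",
              output_shift_scale.size(), scale);
  return 0;
}
```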
@@ -410,6 +410,9 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
         graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
         graph = self._apply_pass(graph, 'is_test_pass')
+        graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
+        graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
+        graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
         graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
@@ -426,7 +429,9 @@ class Quant2Int8MkldnnPass(object):
                                  ['use_gpu', 'use_fc_padding'], [False, False])
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
         if self._is_fc_quantized(graph):
+            # Disabled due to topology-dependent speed-up
             graph = self._apply_pass(graph, 'fc_mkldnn_pass')
+            graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass')
         # the following pass should be the last one since it will work on all fused ops.
...
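Each string above names a registered IR pass that `_apply_pass` looks up and runs on the underlying C++ graph. As a rough, hedged sketch of how this pipeline is typically driven end to end (constructor arguments, the op set, and the model path are illustrative and may differ between Paddle versions; the `save_quant_model.py` tooling shows the exact invocation):

```python
import paddle
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.global_scope()

# Load a quant-aware-trained ("quant2") inference model; the path is a placeholder.
program, feed_names, fetch_targets = paddle.fluid.io.load_inference_model(
    dirname='quant2_model_dir', executor=exe)

# Wrap the program in an IrGraph so the C++ IR passes can run on it.
graph = IrGraph(core.Graph(program.desc), for_test=True)

# apply() runs the pass list edited above: the map_matmul_* mapping passes,
# fc_mkldnn_pass and the new fc_act_mkldnn_fuse_pass, and the rest.
quant2_pass = Quant2Int8MkldnnPass(
    {'conv2d', 'fc', 'matmul'},  # example set of ops to quantize
    _scope=scope,
    _place=place,
    _core=core,
    _debug=False)
graph = quant2_pass.apply(graph)

int8_program = graph.to_program()  # ready to save or run as an INT8 model
```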