Unverified commit 8f800dc0, authored by Sylwester Fraczek and committed by GitHub

add map_matmul and fc_act_fuse passes to quant2_int8_mkldnn_pass (#38023)

* add map_matmul passes to quant2_int8_mkldnn_pass

* fix fc+act fuse (activation scale)

* ci fix: C++17 structured bindings are not available (see the sketch after the commit header below)

* fix ci static check
Parent commit f8202941
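The second and third commit notes describe the core C++ change: ComputeOutputShiftScale now returns both the per-channel output shift scales and a separate activation scale, and the caller unpacks them with std::tie because C++17 structured bindings are not available in the CI build. Below is a minimal, self-contained sketch of that calling pattern; it is not PaddlePaddle code, and ComputeScalesSketch and its values are made up for illustration.

#include <iostream>
#include <tuple>
#include <vector>

// Hypothetical stand-in for ComputeOutputShiftScale: returns the per-channel
// shift scales together with the scale destined for the activation post-op.
std::tuple<std::vector<float>, float> ComputeScalesSketch() {
  return std::make_tuple(std::vector<float>{0.5f, 0.25f}, 2.0f);
}

int main() {
  // Pre-C++17 unpacking, as used in the patch: declare the variables first,
  // then fill them with std::tie. A structured binding
  // (`auto [shift_scales, act_scale] = ...`) would require C++17.
  std::vector<float> shift_scales;
  float act_scale = 1.0f;
  std::tie(shift_scales, act_scale) = ComputeScalesSketch();

  std::cout << "activation scale: " << act_scale << '\n';
  for (float s : shift_scales) std::cout << "shift scale: " << s << '\n';
  return 0;
}

std::tie needs the variables declared up front (hence the extra declarations in the patched hunk below), whereas a structured binding would declare them in place but is rejected under the older language standard.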
@@ -403,26 +403,36 @@ class FCPrimitiveFactory {
// scaled with its own scales, this data needs to be divided by
// those scales to normalise them back to what their floating-point range
// was. Then we multiply them by desired output scale we want on the output.
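// Illustrative example with made-up numbers (not from this patch): with
// Scale_in = 50, Scale_weights[i] = 100, a requested Scale_out = 25 and no
// fused activation, output_shift_scale[i] = 25 / (50 * 100) = 0.005. When an
// activation is fused (the case this patch fixes), Scale_out is carried by the
// activation post-op instead, the inner scale stays 1.0f, and
// output_shift_scale[i] = 1 / (50 * 100) = 0.0002.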
std::vector<float> ComputeOutputShiftScale(const ExecutionContext& ctx) {
std::tuple<std::vector<float>, float> ComputeOutputShiftScale(
const ExecutionContext& ctx) {
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
// If the output will be in floats, we don't multiply by scale_out.
auto scale_out_data = ctx.Attr<bool>("force_fp32_output")
? 1.0f
: ctx.Attr<float>("Scale_out");
float activation_scale = 1.0f;
float inner_scale = 1.0f;
if (!ctx.Attr<bool>("force_fp32_output")) {
// If there is a fused activation, use its scale; otherwise use the inner scale.
if (!ctx.Attr<std::string>("activation_type").empty()) {
activation_scale = ctx.Attr<float>("Scale_out");
} else {
inner_scale = ctx.Attr<float>("Scale_out");
}
}
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> output_shift_scale(weight_scales_num);
#pragma omp parallel for
for (size_t i = 0; i < weight_scales_num; i++) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] = scale_out_data;
output_shift_scale[i] = inner_scale;
else
output_shift_scale[i] =
scale_out_data / (scale_in_data * scale_weights_data[i]);
inner_scale / (scale_in_data * scale_weights_data[i]);
}
return output_shift_scale;
return make_tuple(output_shift_scale, activation_scale);
}
// Computing MKL-DNN's scaling mask which determines along which dimension
@@ -449,48 +459,43 @@ class FCPrimitiveFactory {
dnnl::primitive_attr attributes;
dnnl::post_ops post_operations;
auto output_shift_scale = ComputeOutputShiftScale(ctx);
std::vector<float> output_shift_scale;
float scale;
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
if (ctx.Attr<std::string>("activation_type") == "relu") {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
negative_slope, placeholder);
} else if (ctx.Attr<std::string>("activation_type") == "gelu") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "tanh") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_tanh,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_logistic,
alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
constexpr float scale = 1.0f;
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_hardswish,
......
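The remainder of this hunk is truncated above, but the commit note "fix fc+act fuse (activation scale)" indicates where the unpacked scale goes: the per-branch `constexpr float scale = 1.0f;` lines are dropped and the scale obtained via std::tie is passed as the first argument of append_eltwise, so a fused FC+activation carries Scale_out on the activation post-op while the output scales keep only the 1 / (Scale_in * Scale_weights) normalisation. A hedged, self-contained sketch of that pattern follows; MakeFcAttrsSketch is a made-up helper, not the factory's real method.

#include <string>
#include <vector>

#include "dnnl.hpp"

// Sketch only: builds oneDNN attributes the way the patched factory appears to,
// with the activation scale applied to the eltwise post-op instead of a
// hard-coded 1.0f.
dnnl::primitive_attr MakeFcAttrsSketch(
    const std::string& activation_type,
    const std::vector<float>& output_shift_scale, float activation_scale) {
  dnnl::primitive_attr attributes;
  dnnl::post_ops post_operations;

  // Per-output-channel output scales when more than one scale is given
  // (equivalent to CreateMask(1, output_shift_scale.size() > 1)).
  int mask = output_shift_scale.size() > 1 ? 2 : 0;
  attributes.set_output_scales(mask, output_shift_scale);

  if (activation_type == "relu") {
    constexpr float negative_slope = 0.0f;
    constexpr float placeholder = 1.0f;  // beta, unused by eltwise_relu
    post_operations.append_eltwise(activation_scale,
                                   dnnl::algorithm::eltwise_relu,
                                   negative_slope, placeholder);
  }
  attributes.set_post_ops(post_operations);
  return attributes;
}

The other activation branches shown in the hunk (gelu, gelu_tanh, gelu_erf, tanh, sigmoid, hard_swish) would follow the same pattern with their respective dnnl::algorithm values.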
@@ -410,6 +410,9 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
graph = self._apply_pass(graph, 'is_test_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
@@ -426,7 +429,9 @@ class Quant2Int8MkldnnPass(object):
['use_gpu', 'use_fc_padding'], [False, False])
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
if self._is_fc_quantized(graph):
# Disabled due to topology-dependent speed-up
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass')
# the following pass should be the last one since it will work on all fused ops.
......