Unverified commit 3705b12c, authored by Jacek Czaja, committed by GitHub

Added caching of scales for bias in conv2d int8 (#36980)

* - Cached bias scales

* - Fix

* - fixes after review

* - second round of fixes after internal review
Parent db6c00c4
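The diff below moves the int8 bias-scale computation into the conv handler and caches the result in the device context's blob map under the handler key plus an "@bs" suffix, so the per-channel scales (Scale_in * Scale_weights[i]) are computed once and reused on every later iteration instead of being recomputed. The following is a minimal, self-contained sketch of that compute-once-and-cache pattern only: the plain std::unordered_map blob store, the free-function signature, and the simplified channel count are illustrative stand-ins, not Paddle's MKLDNNDeviceContext::GetBlob/SetBlob API.

#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

// Stand-in for the device-context blob cache (hypothetical, for illustration only).
std::unordered_map<std::string, std::shared_ptr<void>> blob_store;

using BiasScaleTuple = std::tuple<float, std::vector<float>>;

std::shared_ptr<BiasScaleTuple> get_int8_bias_scales(
    const std::string& handler_key, float scale_in,
    const std::vector<float>& scale_weights) {
  const std::string key_bs = handler_key + "@bs";

  // Return the cached tuple if a previous iteration already computed it.
  auto it = blob_store.find(key_bs);
  if (it != blob_store.end())
    return std::static_pointer_cast<BiasScaleTuple>(it->second);

  // Otherwise compute the bias scales once: one scale per output channel
  // (or a single scale), each equal to scale_in * scale_weights[i].
  const bool is_multi_channel = scale_weights.size() > 1;
  const float mask = is_multi_channel ? 1 << 0 : 1;
  auto result = std::make_shared<BiasScaleTuple>(
      mask, std::vector<float>(scale_weights.size()));
  for (size_t i = 0; i < scale_weights.size(); ++i)
    std::get<1>(*result)[i] = scale_in * scale_weights[i];

  blob_store[key_bs] = result;  // cache for subsequent iterations
  return result;
}

int main() {
  auto first = get_int8_bias_scales("conv2d_key", 0.5f, {2.0f, 4.0f});
  auto second = get_int8_bias_scales("conv2d_key", 0.5f, {2.0f, 4.0f});
  // Prints "2 cached: 1": the second call returns the same cached object.
  std::cout << std::get<1>(*first)[1] << " cached: " << (first == second) << "\n";
}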
@@ -389,6 +389,49 @@ class ConvMKLDNNHandlerT
     }
   }
 
+  std::shared_ptr<std::tuple<float, std::vector<float>>> get_int8_bias_scales(
+      const framework::ExecutionContext& ctx) {
+    // Get scales int8 bias key
+    const std::string key_bs = this->key_ + "@bs";
+
+    // Scales for int8 bias are to be cached to avoid
+    // computing them each iteration
+    auto bias_scale_tuple =
+        std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
+            this->dev_ctx_.GetBlob(key_bs));
+    if (bias_scale_tuple) return bias_scale_tuple;
+
+    const auto* filter = ctx.Input<Tensor>("Filter");
+    const auto& weights_tz = framework::vectorize(filter->dims());
+    const int groups = std::max(ctx.Attr<int>("groups"), 1);
+
+    const auto& scale_weights_data =
+        ctx.Attr<std::vector<float>>("Scale_weights");
+    const auto& scale_in_data = ctx.Attr<float>("Scale_in");
+
+    bool is_multi_channel = scale_weights_data.size() > 1;
+    int mask_reorder = is_multi_channel ? 1 << 0 : 1;
+
+    int count = 1;
+    if (is_multi_channel) {
+      count *= weights_tz[0];
+      if (groups > 1) {
+        count *= weights_tz[1];
+      }
+    }
+
+    bias_scale_tuple =
+        std::make_shared<std::tuple<float, std::vector<float>>>(std::make_tuple(
+            static_cast<float>(mask_reorder), std::vector<float>(count)));
+    for (int i = 0; i < count; i++) {
+      std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i];
+    }
+
+    this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple);
+
+    return bias_scale_tuple;
+  }
+
   std::tuple<float, std::vector<float>> get_int8_scales(
       const framework::ExecutionContext& ctx) const {
     const auto* filter = ctx.Input<Tensor>("Filter");
@@ -428,32 +471,6 @@ class ConvMKLDNNHandlerT
     return std::make_tuple(sum_scale, output_shift_scale);
   }
 
-  std::tuple<float, std::vector<float>> get_int8_bias_scales(
-      const framework::ExecutionContext& ctx) const {
-    const auto* filter = ctx.Input<Tensor>("Filter");
-    const auto& weights_tz = framework::vectorize(filter->dims());
-    const int groups = std::max(ctx.Attr<int>("groups"), 1);
-
-    const auto& scale_weights_data =
-        ctx.Attr<std::vector<float>>("Scale_weights");
-    const auto& scale_in_data = ctx.Attr<float>("Scale_in");
-
-    bool is_multi_channel = scale_weights_data.size() > 1;
-    int mask_reorder = is_multi_channel ? 1 << 0 : 1;
-    int count =
-        is_multi_channel
-            ? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
-            : 1;
-    std::vector<float> scale_bias_data(count);
-
-#pragma omp parallel for if (count > 50)
-    for (int i = 0; i < count; i++) {
-      scale_bias_data[i] = scale_in_data * scale_weights_data[i];
-    }
-    return std::make_tuple(mask_reorder, scale_bias_data);
-  }
-
   mkldnn::primitive_attr CreatePostOps(
       std::string fuse_activation, float fuse_alpha, float fuse_beta,
       bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
@@ -818,13 +835,11 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         {MKLDNN_ARG_DST, *dst_memory_p}};
 
     if (bias) {
-      float mask_reorder;
-      std::vector<float> scale_bias_data;
-      std::tie(mask_reorder, scale_bias_data) =
-          handler.get_int8_bias_scales(ctx);
+      auto p_scales_tuple = handler.get_int8_bias_scales(ctx);
 
       auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
-          bias, is_test, scale_bias_data, mask_reorder);
+          bias, is_test, std::get<1>(*p_scales_tuple),
+          std::get<0>(*p_scales_tuple));
 
       args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
     }
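At the call site in the last hunk, the kernel no longer unpacks a by-value tuple with std::tie; it keeps the shared_ptr returned by get_int8_bias_scales() and reads the reorder mask and the scale vector with std::get. Below is a minimal sketch of that consumer side under the same assumption; AcquireBiasMemoryWithReorder is reduced to a hypothetical placeholder and is not the real handler method.

#include <memory>
#include <tuple>
#include <vector>

// Placeholder: the real kernel reorders the bias into an int32 oneDNN memory
// using the given scales and mask; here it only fixes the call shape.
void AcquireBiasMemoryWithReorder(const std::vector<float>& scales, int mask) {
  (void)scales;
  (void)mask;
}

void use_cached_bias_scales(
    const std::shared_ptr<std::tuple<float, std::vector<float>>>& p) {
  // The cached object is shared, so the fields are read in place via std::get
  // instead of copying them out with std::tie.
  AcquireBiasMemoryWithReorder(std::get<1>(*p),
                               static_cast<int>(std::get<0>(*p)));
}

int main() {
  auto p = std::make_shared<std::tuple<float, std::vector<float>>>(
      1.0f, std::vector<float>{0.25f, 0.5f});
  use_cached_bias_scales(p);
}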