Patch summary (free text ahead of the first "diff --git" line).

What the patch does, as visible in the hunks below:
  * ConvMKLDNNHandlerT::get_int8_bias_scales no longer returns a tuple by
    value. It now returns a shared pointer to a (mask, scales) tuple that is
    looked up with dev_ctx_.GetBlob(key_ + "@bs") and, when absent, computed
    once and stored with dev_ctx_.SetBlob under the same key.
  * The method loses its `const` qualifier and the
    `#pragma omp parallel for if (count > 50)` on the scale loop is removed.
  * The call site in ConvMKLDNNOpKernel reads the scales with std::get<1> and
    the mask with std::get<0> from the returned tuple.

Review notes (observations only; the patch content is unchanged):
  * Both arms of `is_multi_channel ? 1 << 0 : 1` evaluate to 1, so
    mask_reorder is 1 regardless of is_multi_channel, in both the removed and
    the added version. NOTE(review): confirm this is intended.
  * The loop reads scale_weights_data[i] for i < count, where count is derived
    from weights_tz; no check against scale_weights_data.size() is visible in
    this patch. NOTE(review): confirm the sizes are guaranteed to agree.

Reconstruction note: this patch was recovered from a copy whose line breaks,
indentation and template argument lists had been stripped. Line structure,
indentation and the template arguments (<float>, <int>, <Tensor>,
<std::vector<float>>, <T>) were rebuilt from the surviving tokens and the hunk
line counts. The hunk-header context after "@@" appears truncated in the
recovered copy and is kept as found. Verify against the upstream commit before
applying.

diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 2f01b593135210e8a5fbbcca3e116c9e6b50dca0..fa2428458e5690dd2dde21f39ef9d8d34471a3cb 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -389,6 +389,49 @@ class ConvMKLDNNHandlerT
     }
   }
 
+  std::shared_ptr<std::tuple<float, std::vector<float>>> get_int8_bias_scales(
+      const framework::ExecutionContext& ctx) {
+    // Get scales int8 bias key
+    const std::string key_bs = this->key_ + "@bs";
+
+    // Scales for int8 bias are to be cached to avoid
+    // computing them each iteration
+    auto bias_scale_tuple =
+        std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
+            this->dev_ctx_.GetBlob(key_bs));
+    if (bias_scale_tuple) return bias_scale_tuple;
+
+    const auto* filter = ctx.Input<Tensor>("Filter");
+    const auto& weights_tz = framework::vectorize(filter->dims());
+    const int groups = std::max(ctx.Attr<int>("groups"), 1);
+
+    const auto& scale_weights_data =
+        ctx.Attr<std::vector<float>>("Scale_weights");
+    const auto& scale_in_data = ctx.Attr<float>("Scale_in");
+
+    bool is_multi_channel = scale_weights_data.size() > 1;
+    int mask_reorder = is_multi_channel ? 1 << 0 : 1;
+
+    int count = 1;
+    if (is_multi_channel) {
+      count *= weights_tz[0];
+      if (groups > 1) {
+        count *= weights_tz[1];
+      }
+    }
+
+    bias_scale_tuple =
+        std::make_shared<std::tuple<float, std::vector<float>>>(std::make_tuple(
+            static_cast<float>(mask_reorder), std::vector<float>(count)));
+    for (int i = 0; i < count; i++) {
+      std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i];
+    }
+
+    this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple);
+
+    return bias_scale_tuple;
+  }
+
   std::tuple<float, std::vector<float>> get_int8_scales(
       const framework::ExecutionContext& ctx) const {
     const auto* filter = ctx.Input<Tensor>("Filter");
@@ -428,32 +471,6 @@ class ConvMKLDNNHandlerT
     return std::make_tuple(sum_scale, output_shift_scale);
   }
 
-  std::tuple<float, std::vector<float>> get_int8_bias_scales(
-      const framework::ExecutionContext& ctx) const {
-    const auto* filter = ctx.Input<Tensor>("Filter");
-    const auto& weights_tz = framework::vectorize(filter->dims());
-    const int groups = std::max(ctx.Attr<int>("groups"), 1);
-
-    const auto& scale_weights_data =
-        ctx.Attr<std::vector<float>>("Scale_weights");
-    const auto& scale_in_data = ctx.Attr<float>("Scale_in");
-
-    bool is_multi_channel = scale_weights_data.size() > 1;
-    int mask_reorder = is_multi_channel ? 1 << 0 : 1;
-    int count =
-        is_multi_channel
-            ? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
-            : 1;
-    std::vector<float> scale_bias_data(count);
-
-#pragma omp parallel for if (count > 50)
-    for (int i = 0; i < count; i++) {
-      scale_bias_data[i] = scale_in_data * scale_weights_data[i];
-    }
-
-    return std::make_tuple(mask_reorder, scale_bias_data);
-  }
-
   mkldnn::primitive_attr CreatePostOps(
       std::string fuse_activation, float fuse_alpha, float fuse_beta,
       bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
@@ -818,13 +835,11 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
         {MKLDNN_ARG_DST, *dst_memory_p}};
 
     if (bias) {
-      float mask_reorder;
-      std::vector<float> scale_bias_data;
-      std::tie(mask_reorder, scale_bias_data) =
-          handler.get_int8_bias_scales(ctx);
+      auto p_scales_tuple = handler.get_int8_bias_scales(ctx);
 
       auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
-          bias, is_test, scale_bias_data, mask_reorder);
+          bias, is_test, std::get<1>(*p_scales_tuple),
+          std::get<0>(*p_scales_tuple));
       args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
     }
 