未验证 提交 5434d663 编写于 作者: S Sławomir Siwek 提交者: GitHub

Matmul post-ops for fuses (#43198)

* add method for post ops

* format code

* change post-ops pattern

* code style
上级 fd40502e
......@@ -144,12 +144,6 @@ class ConvMKLDNNHandlerT
bias->dims().size()));
}
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn =
ctx.Attr<bool>("fuse_residual_connection");
const int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
......@@ -221,24 +215,7 @@ class ConvMKLDNNHandlerT
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;
float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>()) {
if (ctx.HasAttr("Sum_scale")) {
sum_scale = ctx.Attr<float>("Sum_scale");
activation_scale = ctx.Attr<float>("Activation_scale");
output_shift_scale =
ctx.Attr<std::vector<float>>("Output_shift_scale");
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
}
}
const dnnl::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
output_shift_scale, sum_scale, activation_scale); // for INT8 only!
const dnnl::primitive_attr conv_attr = CreateConvAttrs(ctx);
if (bias) {
auto bias_tz = phi::vectorize(bias->dims());
......@@ -460,11 +437,12 @@ class ConvMKLDNNHandlerT
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
bool is_multi_channel = scale_weights_data.size() > 1;
bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
float activation_scale = force_fp32_output ? 1.0f
: has_activation ? ctx.Attr<float>("Scale_out")
float activation_scale = (!force_fp32_output && has_activation)
? ctx.Attr<float>("Scale_out")
: 1.0f;
auto scale_out_data = force_fp32_output ? 1.0f
: has_activation ? 1.0f
float scale_out_data = (force_fp32_output || has_activation)
? 1.0f
: ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
......@@ -490,16 +468,34 @@ class ConvMKLDNNHandlerT
return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
}
dnnl::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f, float activation_scale = 1.0f) {
dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>()) {
if (ctx.HasAttr("Sum_scale")) {
sum_scale = ctx.Attr<float>("Sum_scale");
activation_scale = ctx.Attr<float>("Activation_scale");
output_shift_scale = ctx.Attr<std::vector<float>>("Output_shift_scale");
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
}
if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
}
}
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is
......
......@@ -139,10 +139,6 @@ class ConvTransposeMKLDNNHandlerT
* the memory format preferred for best performance
*/
const auto chosen_memory_format = MKLDNNMemoryFormat::any;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
auto data_type = dnnl::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
......@@ -156,8 +152,7 @@ class ConvTransposeMKLDNNHandlerT
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
const dnnl::primitive_attr conv_trans_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;
if (bias) {
......@@ -176,12 +171,15 @@ class ConvTransposeMKLDNNHandlerT
}
}
dnnl::primitive_attr CreatePostOps(const std::string& fuse_activation,
const float& fuse_alpha,
const float& fuse_beta) {
dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
......
......@@ -201,7 +201,7 @@ class FCPrimitiveFactory {
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
......@@ -230,7 +230,7 @@ class FCPrimitiveFactory {
auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]};
auto dst_desc =
CreateMemDescriptor<T_out>(dst_dims, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
......@@ -255,7 +255,7 @@ class FCPrimitiveFactory {
auto weights_desc = CreateMemDescriptor<T_w>(dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
......@@ -455,8 +455,7 @@ class FCPrimitiveFactory {
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
}
// Fuse relu into FC with activation type attribute has been set to 'relu'
dnnl::primitive_attr CreatePostOps(const ExecutionContext& ctx) {
dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr attributes;
dnnl::post_ops post_operations;
......@@ -465,8 +464,8 @@ class FCPrimitiveFactory {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
float sum_scale = 1.0f;
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale);
......
......@@ -147,16 +147,10 @@ class MatMulMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
// Constructor for FWD MatMul
MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx,
float scale)
MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
engine, ctx.GetPlace()) {
dnnl::primitive_attr attr;
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
constexpr unsigned tensor_wide_scale = 0;
attr.set_output_scales(tensor_wide_scale, {scale_out});
}
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
auto matmul_dims_ = GetMatmulDims(ctx);
auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(),
......@@ -165,7 +159,7 @@ class MatMulMKLDNNHandler
matmul_dims_.y_strides);
auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(),
matmul_dims_.out_strides);
this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
......@@ -429,6 +423,19 @@ class MatMulMKLDNNHandler
return std::make_tuple(x_offset_, y_offset_, out_offset_);
}
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
private:
uint32_t x_offset_;
uint32_t y_offset_;
......@@ -499,23 +506,19 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
MatMulMKLDNNHandler<XT, YT, float>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
MatMulMKLDNNHandler<XT, YT, float>(onednn_engine, ctx).Execute(x, y, out);
} else if (is_bfloat16) {
MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(dev_ctx.GetEngine(),
ctx, alpha)
MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(onednn_engine, ctx)
.Execute(x, y, out);
} else if (fuse_relu) {
MatMulMKLDNNHandler<XT, YT, uint8_t>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
MatMulMKLDNNHandler<XT, YT, uint8_t>(onednn_engine, ctx).Execute(x, y, out);
} else {
MatMulMKLDNNHandler<XT, YT, int8_t>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
MatMulMKLDNNHandler<XT, YT, int8_t>(onednn_engine, ctx).Execute(x, y, out);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册