未验证 提交 5434d663 编写于 作者: S Sławomir Siwek 提交者: GitHub

Matmul post-ops for fuses (#43198)

* add method for post ops

* format code

* change post-ops pattern

* code style
上级 fd40502e
...@@ -144,12 +144,6 @@ class ConvMKLDNNHandlerT ...@@ -144,12 +144,6 @@ class ConvMKLDNNHandlerT
bias->dims().size())); bias->dims().size()));
} }
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn =
ctx.Attr<bool>("fuse_residual_connection");
const int groups = ctx.Attr<int>("groups"); const int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm = const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm"); ctx.Attr<std::string>("padding_algorithm");
...@@ -221,24 +215,7 @@ class ConvMKLDNNHandlerT ...@@ -221,24 +215,7 @@ class ConvMKLDNNHandlerT
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training; : dnnl::prop_kind::forward_training;
float sum_scale = 1.0f; const dnnl::primitive_attr conv_attr = CreateConvAttrs(ctx);
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>()) {
if (ctx.HasAttr("Sum_scale")) {
sum_scale = ctx.Attr<float>("Sum_scale");
activation_scale = ctx.Attr<float>("Activation_scale");
output_shift_scale =
ctx.Attr<std::vector<float>>("Output_shift_scale");
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
}
}
const dnnl::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
output_shift_scale, sum_scale, activation_scale); // for INT8 only!
if (bias) { if (bias) {
auto bias_tz = phi::vectorize(bias->dims()); auto bias_tz = phi::vectorize(bias->dims());
...@@ -460,11 +437,12 @@ class ConvMKLDNNHandlerT ...@@ -460,11 +437,12 @@ class ConvMKLDNNHandlerT
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights"); auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
bool is_multi_channel = scale_weights_data.size() > 1; bool is_multi_channel = scale_weights_data.size() > 1;
bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty(); bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
float activation_scale = force_fp32_output ? 1.0f float activation_scale = (!force_fp32_output && has_activation)
: has_activation ? ctx.Attr<float>("Scale_out") ? ctx.Attr<float>("Scale_out")
: 1.0f; : 1.0f;
auto scale_out_data = force_fp32_output ? 1.0f
: has_activation ? 1.0f float scale_out_data = (force_fp32_output || has_activation)
? 1.0f
: ctx.Attr<float>("Scale_out"); : ctx.Attr<float>("Scale_out");
float sum_scale = float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
...@@ -490,16 +468,34 @@ class ConvMKLDNNHandlerT ...@@ -490,16 +468,34 @@ class ConvMKLDNNHandlerT
return std::make_tuple(sum_scale, output_shift_scale, activation_scale); return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
} }
dnnl::primitive_attr CreatePostOps( dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f, float activation_scale = 1.0f) {
dnnl::primitive_attr conv_attr; dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>()) {
if (ctx.HasAttr("Sum_scale")) {
sum_scale = ctx.Attr<float>("Sum_scale");
activation_scale = ctx.Attr<float>("Activation_scale");
output_shift_scale = ctx.Attr<std::vector<float>>("Output_shift_scale");
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
}
if (output_shift_scale.size() > 0) { if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale); conv_attr.set_output_scales(mask, output_shift_scale);
} }
}
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is // the scale parameter. It is assumed that when fuse_residual_connection is
......
...@@ -139,10 +139,6 @@ class ConvTransposeMKLDNNHandlerT ...@@ -139,10 +139,6 @@ class ConvTransposeMKLDNNHandlerT
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
const auto chosen_memory_format = MKLDNNMemoryFormat::any; const auto chosen_memory_format = MKLDNNMemoryFormat::any;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
auto data_type = dnnl::memory::data_type::f32; auto data_type = dnnl::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" || if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
...@@ -156,8 +152,7 @@ class ConvTransposeMKLDNNHandlerT ...@@ -156,8 +152,7 @@ class ConvTransposeMKLDNNHandlerT
const auto dst_md = platform::MKLDNNMemDesc( const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
const dnnl::primitive_attr conv_trans_attr = const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training; : dnnl::prop_kind::forward_training;
if (bias) { if (bias) {
...@@ -176,12 +171,15 @@ class ConvTransposeMKLDNNHandlerT ...@@ -176,12 +171,15 @@ class ConvTransposeMKLDNNHandlerT
} }
} }
dnnl::primitive_attr CreatePostOps(const std::string& fuse_activation, dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
const float& fuse_alpha,
const float& fuse_beta) {
dnnl::primitive_attr conv_attr; dnnl::primitive_attr conv_attr;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
// Fusion with ReLU layer is executed through the PostOps feature. Create a // Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation. // PostOps object and configure it to execute an eltwise relu operation.
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") { if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
......
...@@ -201,7 +201,7 @@ class FCPrimitiveFactory { ...@@ -201,7 +201,7 @@ class FCPrimitiveFactory {
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any); CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x); auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any); auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx); const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
} }
...@@ -230,7 +230,7 @@ class FCPrimitiveFactory { ...@@ -230,7 +230,7 @@ class FCPrimitiveFactory {
auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]}; auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]};
auto dst_desc = auto dst_desc =
CreateMemDescriptor<T_out>(dst_dims, MKLDNNMemoryFormat::any); CreateMemDescriptor<T_out>(dst_dims, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx); const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
} }
...@@ -255,7 +255,7 @@ class FCPrimitiveFactory { ...@@ -255,7 +255,7 @@ class FCPrimitiveFactory {
auto weights_desc = CreateMemDescriptor<T_w>(dims, MKLDNNMemoryFormat::any); auto weights_desc = CreateMemDescriptor<T_w>(dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x); auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any); auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreatePostOps(ctx); const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs); return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
} }
...@@ -455,8 +455,7 @@ class FCPrimitiveFactory { ...@@ -455,8 +455,7 @@ class FCPrimitiveFactory {
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales); bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
} }
// Fuse relu into FC with activation type attribute has been set to 'relu' dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr CreatePostOps(const ExecutionContext& ctx) {
dnnl::primitive_attr attributes; dnnl::primitive_attr attributes;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
...@@ -465,8 +464,8 @@ class FCPrimitiveFactory { ...@@ -465,8 +464,8 @@ class FCPrimitiveFactory {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1); int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale); attributes.set_output_scales(mask, output_shift_scale);
float sum_scale = 1.0f;
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") && if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) { ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
......
...@@ -147,16 +147,10 @@ class MatMulMKLDNNHandler ...@@ -147,16 +147,10 @@ class MatMulMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
} }
// Constructor for FWD MatMul // Constructor for FWD MatMul
MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx, MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx)
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>( : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
engine, ctx.GetPlace()) { engine, ctx.GetPlace()) {
dnnl::primitive_attr attr; const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
constexpr unsigned tensor_wide_scale = 0;
attr.set_output_scales(tensor_wide_scale, {scale_out});
}
auto matmul_dims_ = GetMatmulDims(ctx); auto matmul_dims_ = GetMatmulDims(ctx);
auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(), auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(),
...@@ -165,7 +159,7 @@ class MatMulMKLDNNHandler ...@@ -165,7 +159,7 @@ class MatMulMKLDNNHandler
matmul_dims_.y_strides); matmul_dims_.y_strides);
auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(), auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(),
matmul_dims_.out_strides); matmul_dims_.out_strides);
this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
} }
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) { std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
...@@ -429,6 +423,19 @@ class MatMulMKLDNNHandler ...@@ -429,6 +423,19 @@ class MatMulMKLDNNHandler
return std::make_tuple(x_offset_, y_offset_, out_offset_); return std::make_tuple(x_offset_, y_offset_, out_offset_);
} }
// Assembles the primitive attributes used when creating the forward matmul
// primitive descriptor.
//
// The value returned by ComputeOutputScale(ctx) is registered as an output
// scale only when it differs from 1.0f. The post-ops chain is attached
// unconditionally; this function appends no post-operation to it.
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext& ctx) {
  dnnl::primitive_attr attrs;
  dnnl::post_ops post_ops_chain;

  const float output_scale = ComputeOutputScale(ctx);
  const bool needs_scaling = output_scale != 1.0f;
  if (needs_scaling) {
    // Mask 0 paired with a single-element scale vector.
    attrs.set_output_scales(0, {output_scale});
  }

  attrs.set_post_ops(post_ops_chain);
  return attrs;
}
private: private:
uint32_t x_offset_; uint32_t x_offset_;
uint32_t y_offset_; uint32_t y_offset_;
...@@ -499,23 +506,19 @@ static void ExecuteMatMul(const ExecutionContext& ctx) { ...@@ -499,23 +506,19 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
MatMulMKLDNNHandler<XT, YT, float>(dev_ctx.GetEngine(), ctx, alpha) MatMulMKLDNNHandler<XT, YT, float>(onednn_engine, ctx).Execute(x, y, out);
.Execute(x, y, out);
} else if (is_bfloat16) { } else if (is_bfloat16) {
MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(dev_ctx.GetEngine(), MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(onednn_engine, ctx)
ctx, alpha)
.Execute(x, y, out); .Execute(x, y, out);
} else if (fuse_relu) { } else if (fuse_relu) {
MatMulMKLDNNHandler<XT, YT, uint8_t>(dev_ctx.GetEngine(), ctx, alpha) MatMulMKLDNNHandler<XT, YT, uint8_t>(onednn_engine, ctx).Execute(x, y, out);
.Execute(x, y, out);
} else { } else {
MatMulMKLDNNHandler<XT, YT, int8_t>(dev_ctx.GetEngine(), ctx, alpha) MatMulMKLDNNHandler<XT, YT, int8_t>(onednn_engine, ctx).Execute(x, y, out);
.Execute(x, y, out);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册