Unverified commit 197a4ffe authored by Paulina Gacek, committed by GitHub

fuse quantize+transpose and transpose+dequantize (#49509)

* QuantTranspose pattern is found by the pass

* quant + transpose fuse

* code style changes

* UT written, reorder fixed

* Dequantize + transpose2 fuse added

* pass name changed

* UT added & shift corrected

* got rid of redundancy

* review changes

* AsIntermediate corrected

* compat added
Parent b47923b4
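In short, the pass folds an adjacent quantize into a oneDNN transpose2, and an adjacent dequantize out of it, by turning them into reorder attributes. A minimal sketch of the rewrite (tensor names are illustrative):

// Before: x -> quantize(Scale, Shift) -> transpose2 -> y
//         x -> transpose2 -> dequantize(Scale, Shift) -> y
// After:  x -> transpose2(scale=Scale, shift=Shift, output_data_type) -> y
//         x -> transpose2(scale=1/Scale, shift=Shift, output_data_type="fp32") -> y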
......@@ -177,6 +177,7 @@ if(WITH_MKLDNN)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(layer_norm_onednn_optimization_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(quant_transpose2_dequant_onednn_fuse_pass inference DIR mkldnn)
pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
......
......@@ -979,6 +979,44 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out;
}
PDNode *patterns::QuantTranspose2::operator()() {
auto *quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
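  // quant_out is marked intermediate so the detector lets the fusion
  // remove it together with the quantize op.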
  auto *quant_out = pattern->NewNode(quant_out_repr())
                        ->AsIntermediate()
                        ->assert_has_n_outputs(1)
                        ->assert_is_op_output("quantize")
                        ->assert_is_op_input("transpose2", "X");
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
transpose2_op->LinksFrom({quant_out});
return transpose2_op;
}
PDNode *patterns::Transpose2Dequant::operator()() {
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
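  // dequant_in is marked intermediate so it can be removed together with
  // the dequantize op once the fusion is applied.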
  auto *dequant_in = pattern->NewNode(dequant_in_repr())
                         ->AsIntermediate()
                         ->assert_has_n_inputs(1)
                         ->assert_is_op_input("dequantize", "Input");
  auto *dequant_op =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
  auto *dequant_out = pattern->NewNode(dequant_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("dequantize", "Output");
transpose2_op->LinksTo({dequant_in});
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
return dequant_out;
}
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
......
......@@ -552,6 +552,29 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};
struct QuantTranspose2 : public PatternBase {
QuantTranspose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_transpose2") {}
PDNode* operator()();
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
PATTERN_DECL_NODE(transpose2_op);
};
struct Transpose2Dequant : public PatternBase {
Transpose2Dequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose2_dequant") {}
PDNode* operator()();
PATTERN_DECL_NODE(transpose2_op);
PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
};
struct Squeeze2Transpose2 : public PatternBase {
Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::QuantTranspose2 quant_transpose2_pattern(gpd.mutable_pattern(),
name_scope);
quant_transpose2_pattern();
int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, quant_transpose2_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with quantize.";
return;
}
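    // Carry the quantize op's Scale/Shift over to transpose2, where the
    // oneDNN reorder applies them as output scale and zero point.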
    float scale =
        quant_op->Op()->HasAttr("Scale")
            ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale"))
            : 1.0f;
    float shift =
        quant_op->Op()->HasAttr("Shift")
            ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift"))
            : 0.0f;
transpose2_op->Op()->SetAttr("scale", scale);
transpose2_op->Op()->SetAttr("shift", shift);
bool is_negative_output =
quant_op->Op()->HasAttr("is_negative_input")
? PADDLE_GET_CONST(bool,
quant_op->Op()->GetAttr("is_negative_input"))
: false;
bool is_bfloat =
quant_op->Op()->HasAttr("bfloat16")
? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16"))
: false;
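    // Record the data type the fused transpose2 must emit: bf16 for
    // bfloat16 quantization, int8 for signed input, uint8 otherwise.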
std::string output_dtype;
if (is_bfloat) {
output_dtype = "bf16";
} else if (is_negative_output) {
output_dtype = "int8";
} else {
output_dtype = "uint8";
}
transpose2_op->Op()->SetAttr("output_data_type", output_dtype);
transpose2_op->Op()->SetInput("X",
std::vector<std::string>({quant_in->Name()}));
IR_NODE_LINK_TO(quant_in, transpose2_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out});
found_patterns_count++;
};
gpd(graph, handler);
AddStatis(found_patterns_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
paddle::string::PrettyLogDetail("--- fused %d quant with transpose2",
found_patterns_count);
}
}
void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Transpose2Dequant transpose2_dequant_pattern(gpd.mutable_pattern(),
name_scope);
transpose2_dequant_pattern();
int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_in, dequant_in, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_op, dequant_op, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dequant_out, dequant_out, transpose2_dequant_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with dequantize.";
return;
}
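    // Dequantization multiplies by 1/Scale, so the fused reorder gets the
    // reciprocal as its output scale.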
    float scale =
        dequant_op->Op()->HasAttr("Scale")
            ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"))
            : 1.0f;
    float reorder_scale = 1.0f / scale;
    float shift =
        dequant_op->Op()->HasAttr("Shift")
            ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift"))
            : 0.0f;
transpose2_op->Op()->SetAttr("scale", reorder_scale);
transpose2_op->Op()->SetAttr("shift", shift);
transpose2_op->Op()->SetAttr("output_data_type", std::string("fp32"));
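    // Rewire transpose2 to write the dequantize op's output directly, then
    // drop the dequantize op and its now-unused input.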
transpose2_op->Op()->SetOutput(
"Out", std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(transpose2_op, dequant_out);
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
found_patterns_count++;
};
gpd(graph, handler);
AddStatis(found_patterns_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
paddle::string::PrettyLogDetail("--- fused %d transpose2 with dequant",
found_patterns_count);
}
}
void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const {
FuseQuantizeTranspose2(graph);
FuseTranspose2Dequantize(graph);
}
FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(quant_transpose2_dequant_onednn_fuse_pass,
paddle::framework::ir::FuseQuantTranspose2DequantOneDNNPass);
REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"transpose2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase {
public:
virtual ~FuseQuantTranspose2DequantOneDNNPass() {}
FuseQuantTranspose2DequantOneDNNPass();
protected:
void ApplyImpl(Graph *graph) const override;
void FuseQuantizeTranspose2(Graph *graph) const;
void FuseTranspose2Dequantize(Graph *graph) const;
private:
std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass";
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -475,6 +475,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
}
......
......@@ -122,6 +122,9 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"Bias_scales", ExtraAttrProperty::ONEDNN},
{"Output_shift_scale", ExtraAttrProperty::ONEDNN},
{"Sum_scale", ExtraAttrProperty::ONEDNN},
{"scale", ExtraAttrProperty::ONEDNN},
{"shift", ExtraAttrProperty::ONEDNN},
{"output_data_type", ExtraAttrProperty::ONEDNN},
// GPUDNN dedicated attributes
{"exhaustive_search", ExtraAttrProperty::GPUDNN},
{"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN},
......
......@@ -86,31 +86,69 @@ void TransposeKernel(const Context& dev_ctx,
auto x_vec_dims = vectorize(x.dims());
auto x_type = funcs::ToOneDNNDataType(x.dtype());
dnnl::primitive_attr attrs;
const int32_t mask = 0;
const auto quantization_scale =
dev_ctx.HasDnnAttr("scale")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("scale"))
: 1.0f;
const auto quantization_shift =
dev_ctx.HasDnnAttr("shift")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("shift"))
: 0.0f;
const auto output_data_type =
dev_ctx.HasDnnAttr("output_data_type")
? PADDLE_GET_CONST(std::string,
dev_ctx.GetDnnAttr("output_data_type"))
: "";
const bool with_scale = quantization_scale != 1.0f;
const bool with_shift = quantization_shift != 0.0f;
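  // The fused quantize/dequantize is expressed as reorder attributes:
  // an output scale plus a zero point on the quantized side.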
if (with_scale) {
attrs.set_output_scales(mask, {quantization_scale});
}
if (with_shift) {
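    // When dequantizing (fp32 output), the shift belongs to the quantized
    // source; otherwise it applies to the quantized destination.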
auto dst = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST;
attrs.set_zero_points(
dst, mask, {static_cast<int32_t>(quantization_shift)});
}
DataType out_dtype;
if (output_data_type == "bf16") {
out_dtype = DataType::BFLOAT16;
} else if (output_data_type == "int8") {
out_dtype = DataType::INT8;
} else if (output_data_type == "uint8") {
out_dtype = DataType::UINT8;
} else if (output_data_type == "fp32") {
out_dtype = DataType::FLOAT32;
} else {
out_dtype = x.dtype();
}
auto out_type = phi::funcs::ToOneDNNDataType(out_dtype);
funcs::ReorderOneDNNHandler reorder_handler(
      x_vec_dims, x.dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
  // A trick is used here to fake the transpose of dst_md: the destination
  // strides describe the already-transposed layout, so the reorder
  // "untransposes" the data, leaving the output in a plain format tag.
  std::vector<int64_t> fake_strides(axis.size());
  int total_stride = 1;
  for (int i = static_cast<int>(x_vec_dims.size()) - 1; i >= 0; --i) {
    fake_strides[axis[i]] = total_stride;
    total_stride *= x_vec_dims[axis[i]];
  }
  auto dst_md = dnnl::memory::desc(x_vec_dims, out_type, fake_strides);
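  // Worked example (illustrative shapes): for x_vec_dims = {2, 3, 4, 5} and
  // axis = {0, 2, 1, 3}, the loop yields fake_strides = {60, 5, 15, 1}, so
  // dst_md describes memory already laid out as the transposed {2, 4, 3, 5}
  // tensor and the reorder itself performs the transpose.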
  auto reorder_dst_memory_p =
      reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(
      reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
......@@ -122,6 +160,7 @@ void TransposeKernel(const Context& dev_ctx,
for (size_t i = 0; i < axis.size(); ++i) {
permute_axis[axis[i]] = i;
}
funcs::SetOutMemDescWithLogicalLayoutFusesSupport(
dev_ctx,
out,
......
......@@ -448,6 +448,8 @@ def create_quant_model(
"pad2d",
"reshape",
"layer_norm",
"quantize",
"dequantize",
]
op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
......@@ -497,6 +499,8 @@ def create_quant_model(
"pad2d": [["X"], ["Out"]],
"flatten": [["X"], ["Out"]],
"flatten2": [["X"], ["Out"]],
"quantize": [["Input"], ["Output"]],
"dequantize": [["Input"], ["Output"]],
}
def _get_op_output_var_names(op):
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestQuantTranspose2DequantOneDNNFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_config(self, draw):
transpose_X = draw(st.booleans())
axis = draw(st.sampled_from([[0, 2, 1, 3]]))
batch_size = draw(st.integers(min_value=1, max_value=4))
channel = draw(st.integers(min_value=1, max_value=64))
input_dim = draw(st.sampled_from([32, 64]))
scale = draw(st.floats(min_value=1, max_value=16))
shift = draw(st.integers(min_value=1, max_value=3))
is_negative_input = draw(st.booleans())
def generate_input():
if transpose_X:
shape_x = [batch_size, channel, input_dim, 32]
else:
shape_x = [batch_size, channel, 32, input_dim]
return np.random.random(shape_x).astype(np.float32)
quantize_op = OpConfig(
type='quantize',
inputs={'Input': ['input_data']},
outputs={'Output': ['quantize_output']},
attrs={
'is_negative_input': is_negative_input,
'Scale': scale,
'Shift': shift,
},
)
transpose2_op_1 = OpConfig(
type='transpose2',
inputs={'X': ['quantize_output']},
outputs={
'Out': ['transpose2_output_1'],
'XShape': ['transpose2_xshape_1'],
},
attrs={
'axis': axis,
'use_mkldnn': True,
'mkldnn_data_type': 'int8',
},
use_mkldnn=True,
)
transpose2_op_2 = OpConfig(
type='transpose2',
inputs={'X': ['transpose2_output_1']},
outputs={
'Out': ['transpose2_output_2'],
'XShape': ['transpose2_xshape_2'],
},
attrs={
'axis': axis,
'use_mkldnn': True,
'mkldnn_data_type': 'int8',
},
use_mkldnn=True,
)
dequantize_op = OpConfig(
type='dequantize',
inputs={'Input': ['transpose2_output_2']},
outputs={'Output': ['dequantize_output']},
attrs={
'Scale': scale,
'Shift': shift,
},
)
program_config = ProgramConfig(
ops=[quantize_op, transpose2_op_1, transpose2_op_2, dequantize_op],
weights={},
inputs={
'input_data': TensorConfig(data_gen=partial(generate_input))
},
outputs=['dequantize_output'],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=['quant_transpose2_dequant_onednn_fuse_pass'],
)
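        # After both fusions only the two transpose2 ops should remain.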
yield config, ['transpose2', 'transpose2'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False, passes=['quant_transpose2_dequant_onednn_fuse_pass']
)
if __name__ == '__main__':
unittest.main()