From 197a4ffee970c807057aeb10df54f607987a8e21 Mon Sep 17 00:00:00 2001 From: Paulina Gacek Date: Wed, 8 Feb 2023 14:04:04 +0100 Subject: [PATCH] fuse quantize+transpose and transpose+dequantize (#49509) * QuantTranpose pattern is being found by pass * quant + transpose fuse * code style changes * UT written, reorder fixed * Dequantize + transpose2 fuse added * pass name changed * UT added & shift corrected * got rid of redundancy * review changes * AsIntermediate corrected * compat added --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 38 ++++ .../framework/ir/graph_pattern_detector.h | 23 +++ ...ant_transpose2_dequant_onednn_fuse_pass.cc | 194 ++++++++++++++++++ ...uant_transpose2_dequant_onednn_fuse_pass.h | 41 ++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/fluid/operators/ops_extra_info.h | 3 + paddle/phi/kernels/onednn/transpose_kernel.cc | 67 ++++-- .../unittests/ir/inference/program_config.py | 4 + ...nednn_quant_transpose_dequant_fuse_pass.py | 121 +++++++++++ 10 files changed, 479 insertions(+), 14 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_onednn_quant_transpose_dequant_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index cd93a71720..1bee7f9960 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -177,6 +177,7 @@ if(WITH_MKLDNN) pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(layer_norm_onednn_optimization_pass inference DIR mkldnn) pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn) + pass_library(quant_transpose2_dequant_onednn_fuse_pass inference DIR mkldnn) pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn) pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn) pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6591ede1f6..90230ae9bf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -979,6 +979,44 @@ PDNode *patterns::OperatorActivation::operator()( return activation_out; } +PDNode *patterns::QuantTranspose2::operator()() { + auto *quant_in = pattern->NewNode(quant_in_repr()) + ->AsInput() + ->assert_is_op_input("quantize", "Input"); + auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize"); + auto *quant_out = pattern->NewNode(quant_out_repr()) + ->AsOutput() + ->AsIntermediate() + ->assert_has_n_outputs(1) + ->assert_is_op_output("quantize") + ->assert_is_op_input("transpose2", "X"); + auto *transpose2_op = + pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); + + quant_op->LinksFrom({quant_in}).LinksTo({quant_out}); + transpose2_op->LinksFrom({quant_out}); + + return transpose2_op; +} + +PDNode *patterns::Transpose2Dequant::operator()() { + auto *transpose2_op = + pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); + auto dequant_in = pattern->NewNode(dequant_in_repr()) + ->AsIntermediate() + ->assert_has_n_inputs(1) + ->assert_is_op_input("dequantize", "Input"); + auto dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + transpose2_op->LinksTo({dequant_in}); + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); + return dequant_out; +} + PDNode *patterns::Squeeze2Transpose2::operator()() { auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr()) ->AsInput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 5ff8498d2d..eca263d724 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -552,6 +552,29 @@ struct OperatorActivation : public PatternBase { PATTERN_DECL_NODE(activation_out); }; +struct QuantTranspose2 : public PatternBase { + QuantTranspose2(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quant_transpose2") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(quant_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(quant_out); + PATTERN_DECL_NODE(transpose2_op); +}; + +struct Transpose2Dequant : public PatternBase { + Transpose2Dequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "transpose2_dequant") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(transpose2_op); + PATTERN_DECL_NODE(dequant_in); + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + struct Squeeze2Transpose2 : public PatternBase { Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "squeeze2_transpose2") {} diff --git a/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc new file mode 100644 index 0000000000..61f635c77f --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( + Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::QuantTranspose2 quant_transpose2_pattern(gpd.mutable_pattern(), + name_scope); + quant_transpose2_pattern(); + + int found_patterns_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, quant_transpose2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_op, transpose2_op, quant_transpose2_pattern); + + if (!transpose2_op->Op()->HasAttr("use_mkldnn") || + !(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) { + VLOG(4) + << "Only oneDNN version of transpose2 can be fused with quantize."; + return; + } + + float scale = + quant_op->Op()->HasAttr("Scale") + ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Scale")) + : 1; + float shift = + quant_op->Op()->HasAttr("Shift") + ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift")) + : 0; + + transpose2_op->Op()->SetAttr("scale", scale); + transpose2_op->Op()->SetAttr("shift", shift); + + bool is_negative_output = + quant_op->Op()->HasAttr("is_negative_input") + ? PADDLE_GET_CONST(bool, + quant_op->Op()->GetAttr("is_negative_input")) + : false; + bool is_bfloat = + quant_op->Op()->HasAttr("bfloat16") + ? PADDLE_GET_CONST(bool, quant_op->Op()->GetAttr("bfloat16")) + : false; + + std::string output_dtype; + if (is_bfloat) { + output_dtype = "bf16"; + } else if (is_negative_output) { + output_dtype = "int8"; + } else { + output_dtype = "uint8"; + } + transpose2_op->Op()->SetAttr("output_data_type", output_dtype); + transpose2_op->Op()->SetInput("X", + std::vector({quant_in->Name()})); + + IR_NODE_LINK_TO(quant_in, transpose2_op); + GraphSafeRemoveNodes(graph, {quant_op, quant_out}); + found_patterns_count++; + }; + gpd(graph, handler); + AddStatis(found_patterns_count); + if ((!Has("disable_logs") || !Get("disable_logs"))) { + paddle::string::PrettyLogDetail("--- fused %d quant with transpose2", + found_patterns_count); + } +} + +void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( + Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Transpose2Dequant transpose2_dequant_pattern(gpd.mutable_pattern(), + name_scope); + transpose2_dequant_pattern(); + + int found_patterns_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_op, transpose2_op, transpose2_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dequant_in, dequant_in, transpose2_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dequant_op, dequant_op, transpose2_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dequant_out, dequant_out, transpose2_dequant_pattern); + + if (!transpose2_op->Op()->HasAttr("use_mkldnn") || + !(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) { + VLOG(4) + << "Only oneDNN version of transpose2 can be fused with dequantize."; + return; + } + + float scale = + dequant_op->Op()->HasAttr("Scale") + ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Scale")) + : 1; + float reorder_scale = 1.0 / scale; + float shift = + dequant_op->Op()->HasAttr("Shift") + ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift")) + : 0; + + transpose2_op->Op()->SetAttr("scale", reorder_scale); + transpose2_op->Op()->SetAttr("shift", shift); + transpose2_op->Op()->SetAttr("output_data_type", std::string("fp32")); + transpose2_op->Op()->SetOutput( + "Out", std::vector({dequant_out->Name()})); + + IR_NODE_LINK_TO(transpose2_op, dequant_out); + GraphSafeRemoveNodes(graph, {dequant_in, dequant_op}); + found_patterns_count++; + }; + + gpd(graph, handler); + AddStatis(found_patterns_count); + if ((!Has("disable_logs") || !Get("disable_logs"))) { + paddle::string::PrettyLogDetail("--- fused %d transpose2 with dequant", + found_patterns_count); + } +} + +void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const { + FuseQuantizeTranspose2(graph); + FuseTranspose2Dequantize(graph); +} + +FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() { + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(quant_transpose2_dequant_onednn_fuse_pass, + paddle::framework::ir::FuseQuantTranspose2DequantOneDNNPass); +REGISTER_PASS_CAPABILITY(quant_transpose2_dequant_onednn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "transpose2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h new file mode 100644 index 0000000000..6959f3006e --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase { + public: + virtual ~FuseQuantTranspose2DequantOneDNNPass() {} + FuseQuantTranspose2DequantOneDNNPass(); + + protected: + void ApplyImpl(Graph *graph) const override; + void FuseQuantizeTranspose2(Graph *graph) const; + void FuseTranspose2Dequantize(Graph *graph) const; + + private: + std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass"; +}; + +} // namespace ir +} // namespace framework + +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 903c01f636..5f17cdde4e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -475,6 +475,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_squash_pass"); + passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass"); passes_.push_back("int8_scale_calculation_mkldnn_pass"); passes_.push_back("params_quantization_mkldnn_pass"); } diff --git a/paddle/fluid/operators/ops_extra_info.h b/paddle/fluid/operators/ops_extra_info.h index 10ee3994b5..0f7f6d8b21 100644 --- a/paddle/fluid/operators/ops_extra_info.h +++ b/paddle/fluid/operators/ops_extra_info.h @@ -122,6 +122,9 @@ const std::unordered_map {"Bias_scales", ExtraAttrProperty::ONEDNN}, {"Output_shift_scale", ExtraAttrProperty::ONEDNN}, {"Sum_scale", ExtraAttrProperty::ONEDNN}, + {"scale", ExtraAttrProperty::ONEDNN}, + {"shift", ExtraAttrProperty::ONEDNN}, + {"output_data_type", ExtraAttrProperty::ONEDNN}, // GPUDNN dedicated attributes {"exhaustive_search", ExtraAttrProperty::GPUDNN}, {"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN}, diff --git a/paddle/phi/kernels/onednn/transpose_kernel.cc b/paddle/phi/kernels/onednn/transpose_kernel.cc index a36d5e4493..ab433fc6a4 100644 --- a/paddle/phi/kernels/onednn/transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_kernel.cc @@ -86,31 +86,69 @@ void TransposeKernel(const Context& dev_ctx, auto x_vec_dims = vectorize(x.dims()); auto x_type = funcs::ToOneDNNDataType(x.dtype()); + + dnnl::primitive_attr attrs; + const int32_t mask = 0; + const auto quantization_scale = + dev_ctx.HasDnnAttr("scale") + ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("scale")) + : 1.0f; + const auto quantization_shift = + dev_ctx.HasDnnAttr("shift") + ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("shift")) + : 0.0f; + const auto output_data_type = + dev_ctx.HasDnnAttr("output_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("output_data_type")) + : ""; + const bool with_scale = quantization_scale != 1.0f; + const bool with_shift = quantization_shift != 0.0f; + + if (with_scale) { + attrs.set_output_scales(mask, {quantization_scale}); + } + + if (with_shift) { + auto dst = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST; + attrs.set_zero_points( + dst, mask, {static_cast(quantization_shift)}); + } + + DataType out_dtype; + if (output_data_type == "bf16") { + out_dtype = DataType::BFLOAT16; + } else if (output_data_type == "int8") { + out_dtype = DataType::INT8; + } else if (output_data_type == "uint8") { + out_dtype = DataType::UINT8; + } else if (output_data_type == "fp32") { + out_dtype = DataType::FLOAT32; + } else { + out_dtype = x.dtype(); + } + auto out_type = phi::funcs::ToOneDNNDataType(out_dtype); + funcs::ReorderOneDNNHandler reorder_handler( - x_vec_dims, x.dtype(), x_type, dev_ctx.GetEngine()); + x_vec_dims, x.dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine()); + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( x.mem_desc(), funcs::to_void_cast(x.data())); - auto dst_md = - dnnl::memory::desc(x_vec_dims, - x.mem_desc().data_type(), - funcs::GetPlainOneDNNFormat(x_vec_dims.size())); // a trick is used here to fake transpose of out_md, so later it will be // "untransposed", leaving output data in plain format tag std::vector fake_strides(axis.size()); - auto dims = dst_md.dims(); int total_stride = 1; - for (int i = static_cast(dims.size()) - 1; i >= 0; --i) { + for (int i = static_cast(x_vec_dims.size()) - 1; i >= 0; --i) { fake_strides[axis[i]] = total_stride; - total_stride *= dims[axis[i]]; + total_stride *= x_vec_dims[axis[i]]; } - dst_md = - dnnl::memory::desc(x_vec_dims, x.mem_desc().data_type(), fake_strides); - auto dst_data = dev_ctx.template Alloc(out); + auto dst_md = dnnl::memory::desc(x_vec_dims, out_type, fake_strides); auto reorder_dst_memory_p = - std::make_shared(dst_md, dev_ctx.GetEngine(), dst_data); - auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, - reorder_src_memory_p); + reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder( + reorder_dst_memory_p, reorder_src_memory_p, attrs); auto& astream = OneDNNContext::tls().get_stream(); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); @@ -122,6 +160,7 @@ void TransposeKernel(const Context& dev_ctx, for (size_t i = 0; i < axis.size(); ++i) { permute_axis[axis[i]] = i; } + funcs::SetOutMemDescWithLogicalLayoutFusesSupport( dev_ctx, out, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py index 1d2b442d2d..cd9edd8350 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py @@ -448,6 +448,8 @@ def create_quant_model( "pad2d", "reshape", "layer_norm", + "quantize", + "dequantize", ] op_real_in_out_name = { "conv2d": [["Input", "Filter"], ["Output"]], @@ -497,6 +499,8 @@ def create_quant_model( "pad2d": [["X"], ["Out"]], "flatten": [["X"], ["Out"]], "flatten2": [["X"], ["Out"]], + "quantize": [["Input"], ["Output"]], + "dequantize": [["Input"], ["Output"]], } def _get_op_output_var_names(op): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_quant_transpose_dequant_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_quant_transpose_dequant_fuse_pass.py new file mode 100644 index 0000000000..7272d07a5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_quant_transpose_dequant_fuse_pass.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestQuantTranspose2DequantOneDNNFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_config(self, draw): + transpose_X = draw(st.booleans()) + axis = draw(st.sampled_from([[0, 2, 1, 3]])) + batch_size = draw(st.integers(min_value=1, max_value=4)) + channel = draw(st.integers(min_value=1, max_value=64)) + input_dim = draw(st.sampled_from([32, 64])) + scale = draw(st.floats(min_value=1, max_value=16)) + shift = draw(st.integers(min_value=1, max_value=3)) + is_negative_input = draw(st.booleans()) + + def generate_input(): + if transpose_X: + shape_x = [batch_size, channel, input_dim, 32] + else: + shape_x = [batch_size, channel, 32, input_dim] + return np.random.random(shape_x).astype(np.float32) + + quantize_op = OpConfig( + type='quantize', + inputs={'Input': ['input_data']}, + outputs={'Output': ['quantize_output']}, + attrs={ + 'is_negative_input': is_negative_input, + 'Scale': scale, + 'Shift': shift, + }, + ) + + transpose2_op_1 = OpConfig( + type='transpose2', + inputs={'X': ['quantize_output']}, + outputs={ + 'Out': ['transpose2_output_1'], + 'XShape': ['transpose2_xshape'], + }, + attrs={ + 'axis': axis, + 'use_mkldnn': True, + 'mkldnn_data_type': 'int8', + }, + use_mkldnn=True, + ) + + transpose2_op_2 = OpConfig( + type='transpose2', + inputs={'X': ['transpose2_output_1']}, + outputs={ + 'Out': ['transpose2_output_2'], + 'XShape': ['transpose2_xshape'], + }, + attrs={ + 'axis': axis, + 'use_mkldnn': True, + 'mkldnn_data_type': 'int8', + }, + use_mkldnn=True, + ) + + dequantize_op = OpConfig( + type='dequantize', + inputs={'Input': ['transpose2_output_2']}, + outputs={'Output': ['dequantize_output']}, + attrs={ + 'Scale': scale, + 'Shift': shift, + }, + ) + + program_config = ProgramConfig( + ops=[quantize_op, transpose2_op_1, transpose2_op_2, dequantize_op], + weights={}, + inputs={ + 'input_data': TensorConfig(data_gen=partial(generate_input)) + }, + outputs=['dequantize_output'], + ) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config( + use_mkldnn=True, + passes=['quant_transpose2_dequant_onednn_fuse_pass'], + ) + yield config, ['transpose2', 'transpose2'], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, passes=['quant_transpose2_dequant_onednn_fuse_pass'] + ) + + +if __name__ == '__main__': + unittest.main() -- GitLab