diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 283e79b81e7c678e8a4fcc3becefa5279eacb38b..d000dc70853659d27885721e2d1c1863f49d3067 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -123,6 +123,7 @@ if(WITH_MKLDNN) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) + pass_library(int8_scale_calculation_mkldnn_pass inference DIR mkldnn) pass_library(fc_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(scale_matmul_fuse_pass inference DIR mkldnn) pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn) @@ -209,6 +210,7 @@ if (WITH_MKLDNN) cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass) cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util) + cc_test(test_int8_scale_calculation_mkldnn_pass SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc DEPS int8_scale_calculation_mkldnn_pass pass_test_util) cc_test(test_fc_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util) cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util) cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 
96a1e5c0719dc71993728cc862e3e6cbd661365e..c9fea057d444d7946a404f31746f9e73086ba30d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1057,7 +1057,7 @@ struct Pool : public PatternBase { // Elementwise ops // Forward pass for element-wise operators (add, mul) -// elementwise_mul_out is the result of the operator +// elementwise_out is the result of the operator struct Elementwise : public PatternBase { Elementwise(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise") {} diff --git a/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..678a8fb4a6955626f153b104a926ea7e5e66ff51 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Registers the conv2d inputs, outputs and attributes that a matched operator
+// has to satisfy; ApplyImpl() consults these constraints through IsCompat()
+// before it modifies an operator.
+Int8ScaleCalculationMkldnnPass::Int8ScaleCalculationMkldnnPass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "AnyLayout"})
+      .End();
+}
+
+// For every conv2d matched by patterns::Conv that is an INT8 operator and does
+// not carry "Sum_scale" yet, precomputes and stores the attributes
+// "Bias_scales" (only when a Bias input is connected), "Sum_scale",
+// "Output_shift_scale" and "Activation_scale" from the operator's Scale_in,
+// Scale_weights, Scale_out and Scale_in_eltwise attributes.
+//
+// An operator is left untouched when "Scale_weights" does not provide one
+// value for each scale that has to be computed; indexing it anyway would read
+// past the end of the vector.
+void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
+  FusePassBase::Init("int8_scale_calculation_mkldnn_pass", graph);
+  GraphPatternDetector gpd;
+  patterns::Conv conv_pattern(gpd.mutable_pattern(),
+                              "int8_scale_calculation_mkldnn_pass");
+  conv_pattern();
+
+  int found_int8_scales_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
+
+    // Only INT8 operators are processed, and each of them only once: the
+    // presence of "Sum_scale" marks an operator that already has its scales.
+    if (!platform::HasOpINT8DataType(conv_op->Op()) ||
+        conv_op->Op()->HasAttr("Sum_scale")) {
+      return;
+    }
+
+    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
+
+    auto* op_desc = conv_op->Op();
+    const std::vector<int64_t> weights_tz = conv_filter->Var()->GetShape();
+    const int groups = std::max(op_desc->GetAttrIfExists<int>("groups"), 1);
+    const auto scale_weights_data =
+        op_desc->GetAttrIfExists<std::vector<float>>("Scale_weights");
+    const float scale_in_data = op_desc->GetAttrIfExists<float>("Scale_in");
+
+    const bool is_multi_channel = scale_weights_data.size() > 1;
+
+    // A single scale is used unless several weight scales are given; then
+    // weights_tz[0] scales are needed, times weights_tz[1] when groups > 1.
+    if (is_multi_channel && weights_tz.size() < 2) {
+      LOG(WARNING) << "Filter of rank " << weights_tz.size()
+                   << " cannot be used with multi-channel Scale_weights; "
+                      "skipping int8 scale calculation.";
+      return;
+    }
+    int64_t count = 1;
+    if (is_multi_channel) {
+      count *= weights_tz[0];
+      if (groups > 1) {
+        count *= weights_tz[1];
+      }
+    }
+
+    // Guards every scale_weights_data[i] access below (i < count). This also
+    // covers a missing Scale_weights attribute (empty vector) and filter
+    // dimensions that are not positive.
+    if (count < 1 ||
+        scale_weights_data.size() < static_cast<std::size_t>(count)) {
+      LOG(WARNING) << "Scale_weights holds " << scale_weights_data.size()
+                   << " value(s) but " << count
+                   << " are required; skipping int8 scale calculation.";
+      return;
+    }
+
+    const auto input_names = op_desc->InputNames();
+    const bool has_bias =
+        std::find(input_names.begin(), input_names.end(), "Bias") !=
+            input_names.end() &&
+        !op_desc->Input("Bias").empty();
+    if (has_bias) {
+      std::vector<float> bias_scales(count);
+      for (int64_t i = 0; i < count; ++i) {
+        bias_scales[i] = scale_in_data * scale_weights_data[i];
+      }
+      op_desc->SetAttr("Bias_scales", bias_scales);
+    }
+
+    const bool force_fp32_output =
+        op_desc->GetAttrIfExists<bool>("force_fp32_output");
+    const bool fuse_residual_conn =
+        op_desc->GetAttrIfExists<bool>("fuse_residual_connection");
+    const float scale_in_eltwise_data =
+        op_desc->GetAttrIfExists<float>("Scale_in_eltwise");
+    const bool has_activation =
+        !op_desc->GetAttrIfExists<std::string>("fuse_activation").empty();
+
+    // Scale_out goes to activation_scale when an activation is fused and to
+    // scale_out_data otherwise; with forced FP32 output both stay at 1.0.
+    float activation_scale = 1.0f;
+    float scale_out_data = 1.0f;
+    if (!force_fp32_output) {
+      const float scale_out = op_desc->GetAttrIfExists<float>("Scale_out");
+      if (has_activation) {
+        activation_scale = scale_out;
+      } else {
+        scale_out_data = scale_out;
+      }
+    }
+    const float sum_scale =
+        fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
+
+    std::vector<float> output_shift_scale(count);
+
+#pragma omp parallel for if (count > 50)
+    for (int64_t i = 0; i < count; ++i) {
+      if (scale_weights_data[i] == 0.0f) {
+        // weights data will contain 0 in some models, then weights
+        // scale couldn't be calculated
+        output_shift_scale[i] = scale_out_data;
+      } else {
+        // The quotient is evaluated in double and narrowed once at the end.
+        output_shift_scale[i] = static_cast<float>(
+            static_cast<double>(scale_out_data) /
+            (static_cast<double>(scale_in_data) *
+             static_cast<double>(scale_weights_data[i])));
+      }
+    }
+
+    op_desc->SetAttr("Sum_scale", sum_scale);
+    op_desc->SetAttr("Output_shift_scale", output_shift_scale);
+    op_desc->SetAttr("Activation_scale", activation_scale);
+    found_int8_scales_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_int8_scales_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(int8_scale_calculation_mkldnn_pass,
+              paddle::framework::ir::Int8ScaleCalculationMkldnnPass);
+REGISTER_PASS_CAPABILITY(int8_scale_calculation_mkldnn_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().LE(
+            "conv2d", 1));
diff --git a/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..9233650a2db3c2dca88391243abd33252b87e777
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+/*
+ * Graph pass that computes quantization scales for conv2d biases and
+ * weights ahead of execution and stores them as operator attributes.
+ */
+class Int8ScaleCalculationMkldnnPass : public FusePassBase {
+ public:
+  // Declares the conv2d compatibility constraints used by the pass.
+  Int8ScaleCalculationMkldnnPass();
+  virtual ~Int8ScaleCalculationMkldnnPass() = default;
+
+ protected:
+  // Pass entry point; processes the matching operators of `graph`.
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..804d04e35f6909d070db3e9310aa5a006ee2f7c2
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc
@@ -0,0 +1,149 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Appends one operator of `type` to block 0 of `prog`. Only "conv2d" is
+// supported (any other type fails the running test). `inputs` is
+// {Input, Filter[, Bias]}; without a third entry the Bias slot is set to an
+// empty list. The operator is configured as an INT8 mkldnn convolution with
+// Scale_in = Scale_out = 1.0 and the given `scale_weights`.
+void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs,
+           std::vector<float> scale_weights = {1.5f}) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+
+  op->SetType(type);
+  if (type == "conv2d") {
+    op->SetAttr("use_mkldnn", true);
+    op->SetAttr("name", name);
+    op->SetAttr("strides", std::vector<int>({1, 1}));
+    op->SetAttr("groups", 1);
+    op->SetAttr("paddings", std::vector<int>({0, 0}));
+    op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+    op->SetAttr("dilations", std::vector<int>({1, 1}));
+    op->SetAttr("data_format", std::string("NCHW"));
+    op->SetInput("Input", {inputs[0]});
+    op->SetInput("Filter", {inputs[1]});
+    if (inputs.size() > 2)
+      op->SetInput("Bias", {inputs[2]});
+    else
+      op->SetInput("Bias", {});
+
+    op->SetOutput("Output", outputs);
+    op->SetAttr("Scale_in", 1.0f);
+    op->SetAttr("Scale_out", 1.0f);
+    op->SetAttr("Scale_weights", scale_weights);
+    op->SetAttr("mkldnn_data_type", std::string("int8"));
+  } else {
+    FAIL() << "Unexpected operator type.";
+  }
+}
+
+// Builds a program with a single INT8 conv2d: c (input), weights (filter,
+// persistable, shape {1, scale_weights.size(), 1, 1}) -> f (output). The
+// "conv_bias" variable is declared and connected only when
+// `convWithExistingBias` is true, and `scale_weights` is always forwarded to
+// the operator so it matches what MainTest() checks against.
+ProgramDesc BuildProgramDesc(bool convWithExistingBias,
+                             std::vector<float> scale_weights = {1.5}) {
+  ProgramDesc prog;
+  std::vector<std::string> nodes{"c", "weights", "f"};
+  if (convWithExistingBias) nodes.push_back("conv_bias");
+  for (auto& v : nodes) {
+    auto* var = prog.MutableBlock(0)->Var(v);
+    var->SetType(proto::VarType::LOD_TENSOR);
+    if (v == "weights") {
+      var->SetPersistable(true);
+      var->SetShape({1, static_cast<int64_t>(scale_weights.size()), 1, 1});
+    }
+  }
+
+  std::vector<std::string> conv_inputs{"c", "weights"};
+  if (convWithExistingBias) conv_inputs.push_back("conv_bias");
+  SetOp(&prog, "conv2d", "conv", conv_inputs, std::vector<std::string>({"f"}),
+        scale_weights);
+
+  return prog;
+}
+
+// Runs int8_scale_calculation_mkldnn_pass on the program built by
+// BuildProgramDesc() and checks that
+//  * the input scale attributes of conv2d are unchanged,
+//  * Sum_scale, Output_shift_scale[0] and Activation_scale (and
+//    Bias_scales[0] when a bias is connected) hold the expected values,
+//  * exactly `removed_nodes_count` nodes disappeared from the graph.
+// `scale` is the value expected for Scale_in, Scale_out, Sum_scale and
+// Activation_scale.
+void MainTest(bool convWithExistingBias, int removed_nodes_count, float scale,
+              std::vector<float> scale_weights = {1.5f}) {
+  ASSERT_FALSE(scale_weights.empty());
+
+  auto prog = BuildProgramDesc(convWithExistingBias, scale_weights);
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  auto pass =
+      PassRegistry::Instance().Get("int8_scale_calculation_mkldnn_pass");
+  int original_nodes_num = graph->Nodes().size();
+  graph.reset(pass->Apply(graph.release()));
+  int current_nodes_num = graph->Nodes().size();
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+
+      EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights"),
+                scale_weights);
+      EXPECT_FLOAT_EQ(op->GetAttrIfExists<float>("Scale_in"), scale);
+      EXPECT_FLOAT_EQ(op->GetAttrIfExists<float>("Scale_out"), scale);
+
+      EXPECT_FLOAT_EQ(op->GetAttrIfExists<float>("Sum_scale"), scale);
+      // The attribute is missing (empty vector) if the pass skipped the op;
+      // fail cleanly instead of indexing an empty vector.
+      const auto output_shift_scale =
+          op->GetAttrIfExists<std::vector<float>>("Output_shift_scale");
+      ASSERT_FALSE(output_shift_scale.empty());
+      EXPECT_FLOAT_EQ(output_shift_scale[0], scale / scale_weights[0]);
+      EXPECT_FLOAT_EQ(op->GetAttrIfExists<float>("Activation_scale"), scale);
+
+      if (convWithExistingBias) {
+        const auto bias_scales =
+            op->GetAttrIfExists<std::vector<float>>("Bias_scales");
+        ASSERT_FALSE(bias_scales.empty());
+        EXPECT_FLOAT_EQ(bias_scales[0], scale * scale_weights[0]);
+      }
+    }
+  }
+  EXPECT_EQ(original_nodes_num - removed_nodes_count, current_nodes_num);
+}
+
+TEST(Int8ScaleCalculationMkldnnPass, int8_scale_calculation_with_no_bias) {
+  const float scale = 1.0f;
+  const int removed_nodes_count = 0;
+  const std::vector<float> scale_weights = {1.5f};
+  MainTest(false, removed_nodes_count, scale, scale_weights);
+}
+
+TEST(Int8ScaleCalculationMkldnnPass, int8_scale_calculation_with_bias) {
+  const float scale = 1.0f;
+  const int removed_nodes_count = 0;
+  const std::vector<float> scale_weights = {1.5f};
+  MainTest(true, removed_nodes_count, scale, scale_weights);
+}
+
+TEST(Int8ScaleCalculationMkldnnPass,
+     int8_scale_calculation_with_bias_scale_weights) {
+  const float scale = 1.0f;
+  const int removed_nodes_count = 0;
+  const std::vector<float> scale_weights = {1.5f, 2.3f};
+  MainTest(true, removed_nodes_count, scale, scale_weights);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(int8_scale_calculation_mkldnn_pass);
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 3a3e6a0908ea1b4b38bb106587f05ac4e25ce77b..4dc80a1d75390a0c6f353c8c9a20428d49d4a94f 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -571,6 +571,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { auto* builder = predictor_.config_.pass_builder(); builder->SetPasses({ "cpu_quantize_pass", "cpu_quantize_squash_pass", + "int8_scale_calculation_mkldnn_pass", }); if (predictor_.config_.ir_debug_) builder->TurnOnDebug(); auto passes = builder->AllPasses(); diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 67d1aaa4baf52396bd02f079c95cbd0cd9662b2f..fba17d303f282e155ffd123c6178e1b7e21bd72f 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -223,9 +223,17 @@ class ConvMKLDNNHandlerT float sum_scale = 1.0f; float activation_scale = 1.0f; std::vector output_shift_scale; - if (platform::is_int8()) - std::tie(sum_scale, output_shift_scale, activation_scale) = - get_int8_scales(ctx); + if (platform::is_int8()) { + if (ctx.HasAttr("Sum_scale")) { + sum_scale = ctx.Attr("Sum_scale"); + activation_scale = ctx.Attr("Activation_scale"); + output_shift_scale = + ctx.Attr>("Output_shift_scale"); + } else { + std::tie(sum_scale, output_shift_scale, activation_scale) = + get_int8_scales(ctx); + } + } const dnnl::primitive_attr conv_attr = CreatePostOps( fuse_activation, fuse_alpha, fuse_beta, 
fuse_residual_conn, @@ -872,8 +880,18 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { {DNNL_ARG_DST, *dst_memory_p}}; if (bias) { - auto p_scales_tuple = handler.get_int8_bias_scales(ctx); - + std::vector bias_scales; + auto p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), bias_scales)); + if (ctx.HasAttr("Bias_scales")) { + bias_scales = ctx.Attr>("Bias_scales"); + p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), bias_scales)); + } else { + p_scales_tuple = handler.get_int8_bias_scales(ctx); + } auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( bias, true, std::get<1>(*p_scales_tuple), std::get<0>(*p_scales_tuple)); diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index e8a9300635e2ca3fce8f240ad4e72b26b84313a0..e543bc1e17b2cce12bcbfbe78956732570de94b2 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -668,4 +668,5 @@ class Quant2Int8MkldnnPass(object): graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], [self._var_quant_scales, self._get_data_layout(graph)]) graph = self._apply_pass(graph, 'cpu_quantize_squash_pass') + graph = self._apply_pass(graph, 'int8_scale_calculation_mkldnn_pass') return graph diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_int8_scale_calculation_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_int8_scale_calculation_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..31415f64725879dfe66d191660dabd08c1964873 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_int8_scale_calculation_pass.py @@ -0,0 +1,146 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import unittest
+
+import hypothesis.strategies as st
+
+
+class TestInt8ScaleCalculationMkldnnPass(PassAutoScanTest):
+    """Auto-scan test that runs int8_scale_calculation_mkldnn_pass on a
+    program consisting of a single INT8 mkldnn conv2d operator."""
+
+    def sample_predictor_configs(self, program_config):
+        # One CPU configuration with the pass appended; the op list expected
+        # after the pass is a single conv2d.
+        config = self.create_inference_config(use_gpu=False)
+        config.pass_builder().append_pass("int8_scale_calculation_mkldnn_pass")
+        yield config, ["conv2d"], (1e-4, 1e-5)
+
+    def is_program_valid(self, prog_config):
+        """Rejects sampled programs whose convolution geometry is unusable.
+
+        The spatial extents are always read from axes 2 and 3 of the input
+        and filter shapes; only the channel axis depends on data_format.
+        """
+        attrs = prog_config.ops[0].attrs
+        paddings = attrs["paddings"]
+        strides = attrs["strides"]
+        groups = attrs["groups"]
+        padding_algorithm = attrs["padding_algorithm"]
+        dilations = attrs["dilations"]
+        data_format = attrs["data_format"]
+        filter_shape = prog_config.weights["filter"].shape
+        input_shape = prog_config.inputs["input_x"].shape
+
+        def output_extent(axis, total_padding):
+            # axis 0 / 1 selects dimension 2 / 3 of the input and the filter.
+            dim = axis + 2
+            window = dilations[axis] * (filter_shape[dim] - 1) + 1
+            padded = input_shape[dim] + total_padding
+            return (padded - window) / strides[axis] + 1
+
+        if padding_algorithm == "VALID":
+            if output_extent(0, 0) <= 1 or output_extent(1, 0) <= 1:
+                return False
+        if padding_algorithm == "EXPLICIT":
+            if output_extent(0, paddings[0] + paddings[1]) <= 1 or \
+                    output_extent(1, paddings[2] + paddings[3]) <= 1:
+                return False
+
+        # Channel count has to equal filter_shape[1] * groups, and the number
+        # of filters has to be divisible by groups.
+        channel_axis = 1 if data_format == "NCHW" else 3
+        if input_shape[channel_axis] != filter_shape[1] * groups:
+            return False
+        return filter_shape[0] % groups == 0
+
+    def sample_program_config(self, draw):
+        """Draws one conv2d program; the order of the draws is significant."""
+
+        def int_list(low, high, size):
+            return draw(
+                st.lists(
+                    st.integers(
+                        min_value=low, max_value=high),
+                    min_size=size,
+                    max_size=size))
+
+        x_shape = int_list(5, 100, 4)
+        x_shape[1] = draw(st.integers(min_value=5, max_value=10))
+
+        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
+
+        f_shape = int_list(1, 4, 4)
+        # Filter input channels follow the channel axis of the input.
+        f_shape[1] = x_shape[1] if data_format == "NCHW" else x_shape[3]
+
+        strides = int_list(1, 4, 2)
+        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
+        padding = int_list(1, 4, 4)
+        groups = draw(st.integers(min_value=1, max_value=3))
+        dilations = int_list(1, 4, 2)
+
+        has_bias = draw(st.booleans())
+
+        # NOTE(review): with has_bias a "bias" weight is created, but no
+        # "Bias" input is connected to conv2d, so the operator itself is the
+        # same in both cases -- confirm whether this is intended.
+        weights = {"filter": TensorConfig(shape=f_shape)}
+        if has_bias:
+            weights["bias"] = TensorConfig(shape=[f_shape[0]])
+
+        conv2d_op = OpConfig(
+            "conv2d",
+            inputs={
+                "Input": ["input_x"],
+                "Filter": ["filter"],
+            },
+            outputs={"Output": ["conv2d_out"]},
+            strides=strides,
+            padding_algorithm=padding_algorithm,
+            paddings=padding,
+            groups=groups,
+            dilations=dilations,
+            data_format=data_format,
+            use_mkldnn=True,
+            mkldnn_data_type="int8")
+
+        return ProgramConfig(
+            ops=[conv2d_op],
+            weights=weights,
+            inputs={"input_x": TensorConfig(shape=x_shape)},
+            outputs=["conv2d_out"])
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=100,
+            passes=["int8_scale_calculation_mkldnn_pass"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 5070ea2ef06a3e06efc2bf4f2a4d9e79aff6cfbf..6067b40f0a7c1a6fd4eacdfe9f942f5f08ba46b6 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -655,6 +655,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_transpose_mkldnn_op', 'test_mkldnn_conv_activation_fuse_pass', 'test_mkldnn_conv_concat_relu_mkldnn_fuse_pass', + 'test_mkldnn_int8_scale_calculation_pass', 'test_mkldnn_matmul_op_output_fuse_pass', 'test_mkldnn_matmul_transpose_reshape_fuse_pass', 'test_mkldnn_scale_matmul_fuse_pass',