From b8b2d6a9da3fbde6d550c0499cb0807f89089225 Mon Sep 17 00:00:00 2001
From: Sylwester Fraczek
Date: Wed, 22 Jun 2022 19:21:04 +0200
Subject: [PATCH] [external reviewing] Params to int8 pass (#42625)

* sylwek prototype params to int8 pass
* trying to make warmup work
* wip
* wip
* change test to cpp test
* review fixes, refactoring
* more refactoring
* add erasevars
* change test to fixture
* rename pass and reorder erasevars and graphsaferemovenodes
* fix
* more refactoring and fixed bug
* formatting
* remove scale count
* enforce message too short
* remove erasevars: erasevars could be a cause of memory issues; some other fixes
* add count of successful fuses to name of new nodes
* FindVar -> GetVar and use ConvResidual pattern
* use tensor->clear() instead of new variable
* Update paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc

Co-authored-by: Tomasz Socha

* Update paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc

Co-authored-by: Tomasz Socha

* Update paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc

Co-authored-by: Tomasz Socha

* add log (review fix)
* review fix (2 functions to one)
* code review: Conv->QuantizeConv
* revert
* fix formatting
* remove unused functions
* add paddle enforce

Co-authored-by: Tomasz Socha
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   5 +
 .../framework/ir/graph_pattern_detector.cc    |   1 -
 .../framework/ir/graph_pattern_detector.h     |   1 -
 .../mkldnn/params_quantization_mkldnn_pass.cc | 193 +++++++++++
 .../mkldnn/params_quantization_mkldnn_pass.h  |  43 +++
 .../params_quantization_mkldnn_pass_tester.cc | 305 ++++++++++++++++++
 .../fluid/inference/api/mkldnn_quantizer.cc   |   8 +-
 .../fluid/inference/tests/api/tester_helper.h |   3 +
 paddle/fluid/operators/conv_op.cc             |   5 +-
 paddle/fluid/operators/conv_op.h              |   1 +
 .../fluid/operators/mkldnn/conv_mkldnn_op.cc  |  23 +-
 .../quantization/quant2_int8_mkldnn_pass.py   |   1 +
 .../unittests/ir/inference/auto_scan_test.py  |  10 +-
 13 files changed, 581 insertions(+), 18 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h
 create mode 100644 paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc
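For context, the pass added below maps each fp32 parameter to an integer value using one scale per output channel: q_i = round(w_i * scale[i / (numel / num_scales)]). A minimal standalone sketch of that arithmetic (a hypothetical helper for illustration, not code from this patch):

#include <cmath>
#include <cstdint>
#include <vector>

// Per-output-channel quantization: scales.size() channels, each covering
// w.size() / scales.size() consecutive weights.
std::vector<int8_t> QuantizePerChannel(const std::vector<float>& w,
                                       const std::vector<float>& scales) {
  std::vector<int8_t> q(w.size());
  const size_t per_channel = w.size() / scales.size();
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<int8_t>(std::round(w[i] * scales[i / per_channel]));
  }
  return q;
}

// Example: w = {1.5f, 1.5f}, scales = {2.f, 4.f} -> q = {3, 6},
// which matches the expectations checked in the tester added by this patch.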
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index abc14dbd21..e74154698b 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -202,6 +202,7 @@ if(WITH_MKLDNN)
   pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(int8_scale_calculation_mkldnn_pass inference DIR mkldnn)
+  pass_library(params_quantization_mkldnn_pass inference DIR mkldnn)
   pass_library(fc_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
   pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
@@ -417,6 +418,10 @@ if(WITH_MKLDNN)
     test_int8_scale_calculation_mkldnn_pass
     SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc
     DEPS int8_scale_calculation_mkldnn_pass pass_test_util)
+  cc_test(
+    test_params_quantization_mkldnn_pass
+    SRCS mkldnn/params_quantization_mkldnn_pass_tester.cc
+    DEPS params_quantization_mkldnn_pass)
   cc_test(
     test_fc_elementwise_add_mkldnn_fuse_pass
     SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 27444eca5d..e118ef1ee6 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2013,7 +2013,6 @@ PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {

   if (!with_residual_data) {
     conv_op->assert_more([&](Node *x) {
-      auto node_names = x->Op()->InputNames();
       if (!HasInput(x, "ResidualData") ||
           x->Op()->Input("ResidualData").size() == 0)
         return true;
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 49d928c419..48041db40c 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1023,7 +1023,6 @@ struct Conv : public PatternBase {
   PATTERN_DECL_NODE(conv_op);
   PATTERN_DECL_NODE(conv_input);
   PATTERN_DECL_NODE(conv_filter);
-  PATTERN_DECL_NODE(conv_residual_data);
   PATTERN_DECL_NODE(conv_output);
 };
diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc
new file mode 100644
index 0000000000..34c888b23d
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h"
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+namespace {
+
+template <typename T_out>
+void QuantizeParams(LoDTensor* param_tensor,
+                    const std::vector<float>& scales) {
+  // One scale per output channel; each channel covers `length` elements.
+  std::vector<T_out> tmp_data(param_tensor->numel());
+
+  auto length = param_tensor->numel() / scales.size();
+
+  const float* param_data = param_tensor->data<float>();
+  for (int64_t i = 0; i < param_tensor->numel(); ++i) {
+    tmp_data[i] =
+        static_cast<T_out>(std::round(param_data[i] * scales[i / length]));
+  }
+
+  auto dims = param_tensor->dims();
+  param_tensor->clear();
+  param_tensor->Resize(dims);
+
+  auto int_param_data = param_tensor->mutable_data<T_out>(CPUPlace());
+  std::copy_n(tmp_data.data(), param_tensor->numel(), int_param_data);
+}
+
+bool HasBias(ir::Node* conv_op) {
+  auto input_names = conv_op->Op()->InputNames();
+  return std::find(input_names.begin(), input_names.end(), "Bias") !=
+             input_names.end() &&
+         conv_op->Op()->Input("Bias").size() > 0;
+}
+
+bool ShouldSkipConv(ir::Node* conv_op, Scope* scope, ir::Node* conv_filter) {
+  if (!platform::HasOpINT8DataType(conv_op->Op())) {
+    VLOG(4) << "Skipping non-int8 convolution (id: " << conv_op->id() << ").";
+    return true;
+  }
+
+  auto filter_var = scope->GetVar(conv_filter->Name());
+  if (filter_var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
+    VLOG(4) << "Skipping convolution (id: " << conv_op->id()
+            << ") because its filter was already quantized and the pattern "
+               "matched it a second time.";
+    return true;
+  }
+
+  VLOG(4) << "Not skipping convolution (id: " << conv_op->id() << ")";
+  return false;
+}
+
+template <typename T>
+void QuantizeConvInput(Scope* scope, ir::Graph* g, ir::Node* conv_op,
+                       const std::string& input_name,
+                       const std::string& scales_attr_name) {
+  const auto scales =
+      conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);
+
+  auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
+  QuantizeParams<T>(tensor, scales);
+
+  // The parameter is now pre-quantized, so reset its scales to 1.
+  conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
+}
+
+}  // namespace
+
+ParamsQuantizationMkldnnPass::ParamsQuantizationMkldnnPass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "AnyLayout"})
+      .End();
+}
+
+void ParamsQuantizationMkldnnPass::QuantizeConv(
+    ir::Graph* graph, bool with_residual_data) const {
+  GraphPatternDetector gpd;
+  patterns::ConvResidual conv_pattern(gpd.mutable_pattern(), name_scope_);
+  conv_pattern(with_residual_data);
+
+  int params_to_int8_conv_found = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+    VLOG(4) << "Handle convolution in params_quantization_mkldnn_pass";
+
+    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
+
+    // Get the scope to interact with the parameter tensors.
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+
+    if (ShouldSkipConv(conv_op, scope, conv_filter)) {
+      return;
+    }
+
+    // Weights become int8; biases accumulate in int32.
+    QuantizeConvInput<int8_t>(scope, g, conv_op, conv_filter->Name(),
+                              "Scale_weights");
+
+    if (HasBias(conv_op)) {
+      QuantizeConvInput<int32_t>(
+          scope, g, conv_op, conv_op->Op()->Input("Bias")[0], "Bias_scales");
+    }
+    params_to_int8_conv_found++;
+  };
+  gpd(graph, handler);
+  AddStatis(params_to_int8_conv_found);
+
+  std::stringstream msg_ss;
+  msg_ss << "Quantized params of " << params_to_int8_conv_found
+         << " conv2d ops";
+  if (with_residual_data) msg_ss << " with residual connection";
+  paddle::string::PrettyLogDetail(msg_ss.str().c_str());
+}
+
+void ParamsQuantizationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument(
+                 "Pointer to graph argument should not be NULL."));
+  FusePassBase::Init(name_scope_, graph);
+  QuantizeConv(graph, true /*with_residual_data*/);
+  QuantizeConv(graph, false /*with_residual_data*/);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(params_quantization_mkldnn_pass,
+              paddle::framework::ir::ParamsQuantizationMkldnnPass);
+REGISTER_PASS_CAPABILITY(params_quantization_mkldnn_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().LE(
+            "conv2d", 1));
diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h
new file mode 100644
index 0000000000..1168d6f10d
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+/*
+ * Quantize the persistent parameters (weights and biases) of int8 conv2d ops.
+ */
+class ParamsQuantizationMkldnnPass : public FusePassBase {
+ public:
+  ParamsQuantizationMkldnnPass();
+  virtual ~ParamsQuantizationMkldnnPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+  void QuantizeConv(Graph* graph, bool with_residual_connection) const;
+
+ private:
+  const std::string name_scope_ = "params_quantization_mkldnn_pass";
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
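Outside the full inference pipeline, a pass like this is driven through the pass registry. A minimal hedged sketch, assuming `graph` is a std::unique_ptr<Graph> whose kParamScopeAttr already points at a Scope holding the conv2d parameters (the tester below performs exactly this setup):

// Sketch only; see params_quantization_mkldnn_pass_tester.cc for the full
// construction of the program, scope, and tensors.
auto pass =
    paddle::framework::ir::PassRegistry::Instance().Get(
        "params_quantization_mkldnn_pass");
graph.reset(pass->Apply(graph.release()));
// Afterwards the Filter tensor is int8, the Bias tensor (if present) is
// int32, and the "Scale_weights"/"Bias_scales" attributes are reset to {1}.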
diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc
new file mode 100644
index 0000000000..1ad98cd68b
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc
@@ -0,0 +1,305 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.h"  // NOLINT
+#include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/platform/place.h"
+
+using LoDTensor = phi::DenseTensor;
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace {
+
+struct Data {
+  Data() = default;
+
+  Data(std::vector<int64_t>&& data_shape, std::vector<float>&& raw_data)
+      : shape(std::move(data_shape)), data(std::move(raw_data)) {
+    auto size_from_shape = std::accumulate(shape.begin(), shape.end(), 1,
+                                           std::multiplies<int64_t>());
+    PADDLE_ENFORCE_EQ(size_from_shape, data.size(),
+                      platform::errors::InvalidArgument(
+                          "Shape size doesn't match data size."));
+  }
+
+  const std::vector<int64_t>& getShape() const { return shape; }
+  const std::vector<float>& getData() const { return data; }
+
+ private:
+  const std::vector<int64_t> shape;
+  const std::vector<float> data;
+};
+
+struct TestScope {
+  void CreateTensor(const std::string& var_name, const Data& data) {
+    auto variable = scope.Var(var_name);
+    auto tensor = variable->GetMutable<LoDTensor>();
+    tensor->Resize(phi::make_ddim(data.getShape()));
+    auto dptr = tensor->mutable_data<float>(place);
+    std::copy(data.getData().begin(), data.getData().end(), dptr);
+  }
+
+  const LoDTensor& GetTensor(const std::string& input) const {
+    Variable* var = scope.FindVar(input);
+    return var->Get<LoDTensor>();
+  }
+
+  framework::Scope* Scope() { return &scope; }
+
+ private:
+  framework::Scope scope;
+  CPUPlace place;
+};
+
+struct ProgramStrategy {
+  virtual ~ProgramStrategy() {}
+
+  std::unique_ptr<Graph> CreateGraph() {
+    CreateProgram();
+    auto graph = std::make_unique<Graph>(program);
+    graph->SetNotOwned(kParamScopeAttr, test_scope.Scope());
+    return graph;
+  }
+
+  void CheckGraph(const std::unique_ptr<Graph>& graph) const {
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp()) {
+        CheckOp(*node->Op());
+      }
+    }
+  }
+
+ protected:
+  virtual void CreateProgram() = 0;
+
+  virtual void CheckOp(const OpDesc& op) const = 0;
+
+  VarDesc* AddInput(OpDesc* op, std::string input_name, const Data& data) {
+    const std::string var_name = input_name + "_var";
+    op->SetInput(input_name, {var_name});
+    auto var = program.MutableBlock(0)->Var(var_name);
+    var->SetShape(data.getShape());
+    test_scope.CreateTensor(var_name, data);
+    return var;
+  }
+
+  void AddOutput(OpDesc* op, std::string output_name, const Data& data) {
+    const std::string var_name = output_name + "_var";
+    op->SetOutput(output_name, {var_name});
+    program.MutableBlock(0)->Var(var_name);
+    test_scope.CreateTensor(var_name, data);
+  }
+
+ protected:
+  TestScope test_scope;
+  ProgramDesc program;
+};
+
+struct ConvProgramStrategy : public ProgramStrategy {
+  ConvProgramStrategy(Data&& input, Data&& filter, Data&& output,
+                      std::vector<float>&& scale_weights, int groups = 1,
+                      Data&& bias = Data(),
+                      std::vector<float>&& scale_bias = {})
+      : input(std::move(input)),
+        filter(std::move(filter)),
+        output(std::move(output)),
+        scale_weights(std::move(scale_weights)),
+        groups(std::move(groups)),
+        bias(std::move(bias)),
+        scale_bias(std::move(scale_bias)) {}
+ protected:
+  OpDesc* CreateBasicConvOp() {
+    auto op = program.MutableBlock(0)->AppendOp();
+    op->SetType("conv2d");
+    op->SetAttr("use_mkldnn", true);
+    op->SetAttr("name", std::string{"Conv1"});
+    op->SetAttr("mkldnn_data_type", std::string{"int8"});
+    op->SetAttr("data_format", std::string{"NCHW"});
+    op->SetAttr("dilations", std::vector<int>({1, 1}));
+    op->SetAttr("paddings", std::vector<int>({1, 1}));
+    op->SetAttr("strides", std::vector<int>({1, 1}));
+    return op;
+  }
+
+ protected:
+  void CreateProgram() override {
+    OpDesc* op = CreateBasicConvOp();
+    AddInput(op, "Input", input);
+    AddInput(op, "Filter", filter)->SetPersistable(true);
+    AddOutput(op, "Output", output);
+
+    op->SetAttr("Scale_weights", scale_weights);
+    op->SetAttr("Scale_in", 1.0f);
+    op->SetAttr("groups", groups);
+
+    if (HasBias()) {
+      AddInput(op, "Bias", bias);
+      op->SetAttr("Bias_scales", scale_bias);
+    }
+  }
+
+  void CheckOp(const OpDesc& op) const override {
+    CheckFilter(op);
+    if (HasBias()) {
+      CheckBias(op);
+    }
+  }
+
+  bool HasBias() const { return !bias.getData().empty(); }
+
+  void CheckFilter(const OpDesc& op) const {
+    EXPECT_EQ(op.GetAttrIfExists<std::vector<float>>("Scale_weights"),
+              std::vector<float>(1, 1));
+
+    auto filter_inputs = op.Input("Filter");
+    ASSERT_EQ(filter_inputs.size(), 1ul);
+
+    auto tensor = test_scope.GetTensor(filter_inputs[0]);
+    ASSERT_EQ(tensor.dtype(), phi::DataType::INT8);
+
+    auto filter_ptr = tensor.data<int8_t>();
+    ASSERT_NE(filter_ptr, nullptr);
+    auto length = tensor.numel() / scale_weights.size();
+    for (int64_t i = 0; i < tensor.numel(); i++) {
+      EXPECT_EQ(filter_ptr[i],
+                static_cast<int8_t>(std::round(filter.getData()[i] *
+                                               scale_weights[i / length])));
+    }
+  }
+
+  void CheckBias(const OpDesc& op) const {
+    EXPECT_EQ(op.GetAttrIfExists<std::vector<float>>("Bias_scales"),
+              std::vector<float>(1, 1));
+
+    auto bias_inputs = op.Input("Bias");
+    ASSERT_EQ(bias_inputs.size(), 1ul);
+
+    auto tensor = test_scope.GetTensor(bias_inputs[0]);
+    auto bias_ptr = tensor.data<int32_t>();
+    ASSERT_NE(bias_ptr, nullptr);
+    auto length = tensor.numel() / scale_bias.size();
+    for (int64_t i = 0; i < tensor.numel(); i++) {
+      EXPECT_EQ(bias_ptr[i],
+                static_cast<int32_t>(std::round(bias.getData()[i] *
+                                                scale_bias[i / length])));
+    }
+  }
+
+ private:
+  const Data input;
+  const Data filter;
+  const Data output;
+  const std::vector<float> scale_weights;
+  const int groups;
+
+  const Data bias;
+  const std::vector<float> scale_bias;
+};
+
+struct ParamsQuantizationMkldnnPassTestFixture : public ::testing::Test {
+  void RunPassTest(std::unique_ptr<ProgramStrategy> program) {
+    auto graph = program->CreateGraph();
+
+    auto pass =
+        PassRegistry::Instance().Get("params_quantization_mkldnn_pass");
+    graph.reset(pass->Apply(graph.release()));
+
+    program->CheckGraph(graph);
+  }
+};
+
+Data GenericInput() { return Data({1, 4, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}); }
+Data GenericOutput() { return GenericInput(); }
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_without_bias_o1i1h1w1) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({1, 1, 1, 1}, {1.5f}), GenericOutput(),
+      std::vector<float>{2.f});
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_without_bias_2o1i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({2, 1, 1, 1}, {1.5f, 1.5f}), GenericOutput(),
+      std::vector<float>{2.f, 4.f});
+  RunPassTest(std::move(program));
+}
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_without_bias_2o2i2h2w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(),
+      Data({2, 2, 2, 2}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f,
+                          1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(), std::vector<float>{2.f, 4.f});
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_without_bias_2g2o2i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(),
+      Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(), std::vector<float>{2.f, 2.f, 2.f, 2.f}, 2);
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_without_bias_2g2o1i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(), std::vector<float>{2.f, 2.f, 2.f, 2.f}, 2);
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_1o1i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({1, 1, 1, 1}, {1.5f}), GenericOutput(),
+      std::vector<float>{2.f}, 1, Data({1, 1, 1, 1}, {1.5f}),
+      std::vector<float>{2.f});
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2o1i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({2, 1, 1, 1}, {1.5f, 1.5f}), GenericOutput(),
+      std::vector<float>{2.f, 4.f}, 1, Data({2, 1, 1, 1}, {1.5f, 1.5f}),
+      std::vector<float>{2.f, 4.f});
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o1i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(), Data({4, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(), std::vector<float>{2.f, 2.f, 4.f, 4.f}, 2,
+      Data({4, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
+      std::vector<float>{2.f, 2.f, 4.f, 4.f});
+  RunPassTest(std::move(program));
+}
+
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1w) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(),
+      Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(), std::vector<float>{2.f, 2.f, 4.f, 4.f}, 2,
+      Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
+      std::vector<float>{2.f, 2.f, 4.f, 4.f});
+  RunPassTest(std::move(program));
+}
+
+}  // namespace
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(params_quantization_mkldnn_pass);
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 73096973c3..29f216e389 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -573,11 +573,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
 
   auto* builder = predictor_.config_.pass_builder();
-  builder->SetPasses({
-      "cpu_quantize_pass",
-      "cpu_quantize_squash_pass",
-      "int8_scale_calculation_mkldnn_pass",
-  });
+  builder->SetPasses({"cpu_quantize_pass", "cpu_quantize_squash_pass",
+                      "int8_scale_calculation_mkldnn_pass",
+                      "params_quantization_mkldnn_pass"});
   if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
   auto passes = builder->AllPasses();
   predictor_.argument_.SetIrAnalysisPasses(passes);
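The extended pass list above runs when the oneDNN post-training quantizer is enabled. A hedged sketch of triggering it from the C++ inference API (the model path is a placeholder, and the warm-up data setup is omitted; see analyzer_lexical_analysis_gru_tester.cc for a complete example):

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// Sketch only: build a config whose predictor will run the quantization
// passes, including the new params_quantization_mkldnn_pass.
paddle::AnalysisConfig MakeQuantConfig() {
  paddle::AnalysisConfig cfg;
  cfg.SetModel("/path/to/fp32/model");  // hypothetical path
  cfg.EnableMKLDNN();
  cfg.EnableMkldnnQuantizer();
  cfg.mkldnn_quantizer_config()->SetEnabledOpTypes({"conv2d"});
  return cfg;
}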
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index d7784a909a..8c6ee90930 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -797,6 +797,9 @@ void CompareQuantizedAndAnalysis(
     const AnalysisConfig *config, const AnalysisConfig *qconfig,
     const std::vector<std::vector<PaddleTensor>> &inputs,
     const int compared_idx = 1) {
+  PADDLE_ENFORCE_GT(
+      inputs.size(), 0,
+      platform::errors::PreconditionNotMet("There is no input data provided."));
   PADDLE_ENFORCE_EQ(
       inputs[0][0].shape[0], FLAGS_batch_size,
       platform::errors::InvalidArgument(
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 28ca2feeec..a175cce9c2 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -187,7 +187,10 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
     customized_type_value =
         (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
          input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
-            ? kConvMKLDNNINT8
+            ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
+                      framework::DataTypeTrait<int8_t>::DataType()
+                  ? kConvMKLDNNINT8WS8
+                  : kConvMKLDNNINT8
             : kConvMKLDNNFP32;
   }
 #endif
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 644a827b48..21223ed4e4 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -32,6 +32,7 @@ namespace operators {
 using Tensor = framework::Tensor;
 constexpr int kConvMKLDNNFP32 = 1;
 constexpr int kConvMKLDNNINT8 = 2;
+constexpr int kConvMKLDNNINT8WS8 = 3;
 constexpr int MaxKeyLength = 256;
 
 // Base convolution operator definitions for other conv
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index a2828b978e..92b799ca2d 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -639,14 +639,21 @@ class ConvMKLDNNHandlerT
     if (is_test && bias_mem_p) {
       return bias_mem_p;
     } else {
-      const K* bias_data = bias->data<K>();
+      // If K is int8 (weights are int8), then biases are int32.
+      using K_Bias = typename std::conditional<
+          std::is_same<K, int8_t>::value, int32_t, K>::type;
+      if (std::is_same<K_Bias, int32_t>::value &&
+          bias->dtype() != phi::DataType::INT32) {
+        LOG(ERROR) << "Bias should be of type int32 but is " << bias->dtype();
+      }
+      const K_Bias* bias_data = bias->data<K_Bias>();
 
       auto user_bias_md = platform::MKLDNNMemDesc(
-          phi::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
+          phi::vectorize(bias->dims()), platform::MKLDNNGetDataType<K_Bias>(),
           MKLDNNMemoryFormat::x);
 
       return this->AcquireMemoryWithReorder(
           user_bias_md, this->fwd_pd_->bias_desc(),
-          platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {},
+          platform::to_void_cast<K_Bias>(bias_data), "@bias_mem_p", is_test, {},
           scale_data, mask);
     }
   }
@@ -1031,11 +1038,21 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ops::kConvMKLDNNINT8,
                                     ops::ConvMKLDNNOpKernel<uint8_t, float>);
 
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
+                                    ::paddle::platform::CPUPlace, U8WS8,
+                                    ops::kConvMKLDNNINT8WS8,
+                                    ops::ConvMKLDNNOpKernel<uint8_t, int8_t>);
+
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, S8,
                                     ops::kConvMKLDNNINT8,
                                     ops::ConvMKLDNNOpKernel<int8_t, float>);
 
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
+                                    ::paddle::platform::CPUPlace, S8WS8,
+                                    ops::kConvMKLDNNINT8WS8,
+                                    ops::ConvMKLDNNOpKernel<int8_t, int8_t>);
+
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                     ::paddle::platform::CPUPlace, FP32,
                                     ops::kConvMKLDNNFP32,
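The bias-type selection in the hunk above relies on std::conditional. A self-contained restatement of the idea (illustrative only, not Paddle code) that can be compile-checked:

#include <cstdint>
#include <type_traits>

// Mirror of the K_Bias alias: int8 weights imply int32 biases; any other
// weight type keeps the bias in the same type as the weights.
template <typename K>
using BiasT = typename std::conditional<std::is_same<K, int8_t>::value,
                                        int32_t, K>::type;

static_assert(std::is_same<BiasT<int8_t>, int32_t>::value,
              "int8 weights -> int32 bias");
static_assert(std::is_same<BiasT<float>, float>::value,
              "fp32 weights -> fp32 bias");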
diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 49dcda0cca..622d54343f 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -662,4 +662,5 @@ class Quant2Int8MkldnnPass(object):
                                                self._get_data_layout(graph)])
         graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
         graph = self._apply_pass(graph, 'int8_scale_calculation_mkldnn_pass')
+        graph = self._apply_pass(graph, 'params_quantization_mkldnn_pass')
         return graph
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
index 818862e51d..50d32a9ed7 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
@@ -21,17 +21,13 @@ import time
 import logging
 import shutil
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.initializer import NumpyArrayInitializer
 from paddle.fluid.core import PassVersionChecker
-import paddle.fluid.core as core
-from paddle import compat as cpt
 import paddle.inference as paddle_infer
-from typing import Optional, List, Callable, Dict, Any, Set
-from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model
+from typing import Optional, List, Callable, Dict, Any
+from program_config import OpConfig, ProgramConfig, create_fake_model, create_quant_model
 import hypothesis
-from hypothesis import given, settings, seed, reproduce_failure
+from hypothesis import given, settings
 import hypothesis.strategies as st
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
-- 
GitLab