Unverified · Commit c0652972 authored by Zuza Gawrysiak, committed by GitHub

Move weights and biases scale computing into pass (#42241)

* Add int8 scales gathering pass for convolution

* Fix typo

* Add unittest

* Add corrected unit test

* Change test name

* Remove enabling mkldnn in test

* Speed up test

* Change max examples

* Add functional test

* Change test name

* Add new test case

* Rename pass
Parent c16345cb
@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(int8_scale_calculation_mkldnn_pass inference DIR mkldnn)
pass_library(fc_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
@@ -209,6 +210,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_int8_scale_calculation_mkldnn_pass SRCS mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc DEPS int8_scale_calculation_mkldnn_pass pass_test_util)
cc_test(test_fc_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
......
@@ -1057,7 +1057,7 @@ struct Pool : public PatternBase {
// Elementwise ops
// Forward pass for element-wise operators (add, mul)
// elementwise_mul_out is the result of the operator
// elementwise_out is the result of the operator
struct Elementwise : public PatternBase {
Elementwise(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace framework {
namespace ir {
Int8ScaleCalculationMkldnnPass::Int8ScaleCalculationMkldnnPass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
}
void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init("int8_scale_calculation_mkldnn_pass", graph);
GraphPatternDetector gpd;
patterns::Conv conv_pattern(gpd.mutable_pattern(),
"int8_scale_calculation_mkldnn_pass");
conv_pattern();
int found_int8_scales_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (!platform::HasOpINT8DataType(conv_op->Op()) ||
conv_op->Op()->HasAttr("Sum_scale")) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
auto input_names = conv_op->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
std::vector<int64_t> weights_tz = conv_filter->Var()->GetShape();
const int groups =
std::max(conv_op->Op()->GetAttrIfExists<int>("groups"), 1);
const auto& scale_weights_data =
conv_op->Op()->GetAttrIfExists<std::vector<float>>("Scale_weights");
const auto& scale_in_data =
conv_op->Op()->GetAttrIfExists<float>("Scale_in");
bool is_multi_channel = scale_weights_data.size() > 1;
int count = 1;
if (is_multi_channel) {
count *= weights_tz[0];
if (groups > 1) {
count *= weights_tz[1];
}
}
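// The int32 bias is quantized with the product of the input scale and the
// matching weight scale: bias_scales[i] = scale_in * scale_weights[i].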
if (has_bias && conv_op->Op()->Input("Bias").size() > 0) {
auto bias_scales = std::vector<float>(count);
for (int i = 0; i < count; i++) {
bias_scales[i] = scale_in_data * scale_weights_data[i];
}
conv_op->Op()->SetAttr("Bias_scales", bias_scales);
}
const bool& force_fp32_output =
conv_op->Op()->GetAttrIfExists<bool>("force_fp32_output");
const bool& fuse_residual_conn =
conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection");
const auto& scale_in_eltwise_data =
conv_op->Op()->GetAttrIfExists<float>("Scale_in_eltwise");
bool has_activation =
!conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation").empty();
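// If an activation is fused, the requested output scale is applied after
// the activation (Activation_scale) and the accumulator is rescaled to 1.0;
// otherwise Scale_out is folded into Output_shift_scale directly.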
float activation_scale =
force_fp32_output
? 1.0f
: has_activation
? conv_op->Op()->GetAttrIfExists<float>("Scale_out")
: 1.0f;
auto scale_out_data =
force_fp32_output
? 1.0f
: has_activation
? 1.0f
: conv_op->Op()->GetAttrIfExists<float>("Scale_out");
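// A fused residual input arrives in its own quantization scale, so it is
// rescaled to the convolution's output scale before the summation.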
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
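// Output_shift_scale maps the int32 accumulator, whose values are in
// scale_in * scale_weights units, back to the requested output scale.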
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 50)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
// weights data will contain 0 in some models, then weights
// scale couldn't be calculated
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
conv_op->Op()->SetAttr("Sum_scale", sum_scale);
conv_op->Op()->SetAttr("Output_shift_scale", output_shift_scale);
conv_op->Op()->SetAttr("Activation_scale", activation_scale);
found_int8_scales_count++;
};
gpd(graph, handler);
AddStatis(found_int8_scales_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(int8_scale_calculation_mkldnn_pass,
paddle::framework::ir::Int8ScaleCalculationMkldnnPass);
REGISTER_PASS_CAPABILITY(int8_scale_calculation_mkldnn_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"conv2d", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
/*
 * Computes quantization scales for convolution biases and weights and caches
 * them as attributes on the conv2d op.
 */
class Int8ScaleCalculationMkldnnPass : public FusePassBase {
public:
Int8ScaleCalculationMkldnnPass();
virtual ~Int8ScaleCalculationMkldnnPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
std::vector<float> scale_weights = {1.5f}) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", name);
op->SetAttr("strides", std::vector<int>({1, 1}));
op->SetAttr("groups", 1);
op->SetAttr("paddings", std::vector<int>({0, 0}));
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("dilations", std::vector<int>({1, 1}));
op->SetAttr("data_format", std::string("NCHW"));
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2)
op->SetInput("Bias", {inputs[2]});
else
op->SetInput("Bias", {});
op->SetOutput("Output", outputs);
op->SetAttr("Scale_in", 1.0f);
op->SetAttr("Scale_out", 1.0f);
op->SetAttr("Scale_weights", scale_weights);
op->SetAttr("use_mkldnn", true);
op->SetAttr("mkldnn_data_type", std::string("int8"));
} else {
FAIL() << "Unexpected operator type.";
}
}
ProgramDesc BuildProgramDesc(bool convWithExistingBias,
std::vector<float> scale_weights = {1.5f}) {
ProgramDesc prog;
std::vector<std::string> nodes{"c", "weights", "f"};
if (convWithExistingBias) nodes.push_back("conv_bias");
for (auto& v : nodes) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights") {
var->SetPersistable(true);
var->SetShape({1, static_cast<int>(scale_weights.size()), 1, 1});
}
}
// The bias and multi-channel branches were identical, so they are folded.
if (convWithExistingBias || scale_weights.size() > 1) {
SetOp(&prog, "conv2d", "conv",
std::vector<std::string>({"c", "weights", "conv_bias"}),
std::vector<std::string>({"f"}), scale_weights);
} else {
SetOp(&prog, "conv2d", "conv", std::vector<std::string>({"c", "weights"}),
std::vector<std::string>({"f"}));
}
return prog;
}
void MainTest(bool convWithExistingBias, int removed_nodes_count, float scale,
std::vector<float> scale_weights = {1.5f}) {
auto prog = BuildProgramDesc(convWithExistingBias, scale_weights);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass =
PassRegistry::Instance().Get("int8_scale_calculation_mkldnn_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_EQ(original_nodes_num, current_nodes_num);
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights"),
scale_weights);
EXPECT_EQ(op->GetAttrIfExists<float>("Scale_in"), scale);
EXPECT_EQ(op->GetAttrIfExists<float>("Scale_out"), scale);
EXPECT_EQ(op->GetAttrIfExists<float>("Sum_scale"), scale);
EXPECT_EQ(
op->GetAttrIfExists<std::vector<float>>("Output_shift_scale")[0],
scale / scale_weights[0]);
EXPECT_EQ(op->GetAttrIfExists<float>("Activation_scale"), scale);
if (convWithExistingBias) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Bias_scales")[0],
scale * scale_weights[0]);
}
}
}
EXPECT_EQ(original_nodes_num - removed_nodes_count, current_nodes_num);
}
TEST(Int8ScaleCalculationMkldnnPass, int8_scale_calculation_with_no_bias) {
auto scale = 1.0f;
int removed_nodes_count = 0;
auto scale_weights = {1.5f};
MainTest(false, removed_nodes_count, scale, scale_weights);
}
TEST(Int8ScaleCalculationMkldnnPass, int8_scale_calculation_with_bias) {
auto scale = 1.0f;
int removed_nodes_count = 0;
auto scale_weights = {1.5f};
MainTest(true, removed_nodes_count, scale, scale_weights);
}
TEST(Int8ScaleCalculationMkldnnPass,
int8_scale_calculation_with_bias_scale_weights) {
auto scale = 1.0f;
int removed_nodes_count = 0;
std::vector<float> scale_weights = {1.5f, 2.3f};
MainTest(true, removed_nodes_count, scale, scale_weights);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(int8_scale_calculation_mkldnn_pass);
@@ -571,6 +571,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({
"cpu_quantize_pass", "cpu_quantize_squash_pass",
"int8_scale_calculation_mkldnn_pass",
});
if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses();
......
@@ -223,9 +223,17 @@ class ConvMKLDNNHandlerT
float sum_scale = 1.0f;
float activation_scale = 1.0f;
std::vector<float> output_shift_scale;
if (platform::is_int8<T>())
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
if (platform::is_int8<T>()) {
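// Scales precomputed by int8_scale_calculation_mkldnn_pass arrive as op
// attributes; computing them here remains only as a fallback path.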
if (ctx.HasAttr("Sum_scale")) {
sum_scale = ctx.Attr<float>("Sum_scale");
activation_scale = ctx.Attr<float>("Activation_scale");
output_shift_scale =
ctx.Attr<std::vector<float>>("Output_shift_scale");
} else {
std::tie(sum_scale, output_shift_scale, activation_scale) =
get_int8_scales(ctx);
}
}
const dnnl::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
@@ -872,8 +880,18 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
auto p_scales_tuple = handler.get_int8_bias_scales(ctx);
std::vector<float> bias_scales;
auto p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
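// Prefer bias scales cached by the pass; otherwise derive them from the
// execution context as before.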
if (ctx.HasAttr("Bias_scales")) {
bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
p_scales_tuple =
std::make_shared<std::tuple<float, std::vector<float>>>(
std::make_tuple(static_cast<float>(mask_reorder), bias_scales));
} else {
p_scales_tuple = handler.get_int8_bias_scales(ctx);
}
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias, true, std::get<1>(*p_scales_tuple),
std::get<0>(*p_scales_tuple));
......
@@ -668,4 +668,5 @@ class Quant2Int8MkldnnPass(object):
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout(graph)])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
graph = self._apply_pass(graph, 'int8_scale_calculation_mkldnn_pass')
return graph
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import unittest
import hypothesis.strategies as st
class TestInt8ScaleCalculationMkldnnPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=False)
config.pass_builder().append_pass("int8_scale_calculation_mkldnn_pass")
yield config, ["conv2d"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
paddings = prog_config.ops[0].attrs["paddings"]
strides = prog_config.ops[0].attrs["strides"]
groups = prog_config.ops[0].attrs["groups"]
padding_algorithm = prog_config.ops[0].attrs["padding_algorithm"]
dilations = prog_config.ops[0].attrs["dilations"]
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
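# Reject configs whose output spatial size would be degenerate (<= 1),
# using out = (in + pad_total - (dilation * (kernel - 1) + 1)) / stride + 1.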
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if padding_algorithm == "EXPLICIT":
if ((input_shape[2] + paddings[0] + paddings[1] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] + paddings[2] + paddings[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
return False
if data_format == "NCHW":
if input_shape[1] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
else:
if input_shape[3] != filter_shape[1] * groups:
return False
if filter_shape[0] % groups != 0:
return False
return True
def sample_program_config(self, draw):
x_shape = draw(
st.lists(
st.integers(
min_value=5, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=5, max_value=10))
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
f_shape[1] = x_shape[3]
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=2, max_size=2))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=4, max_size=4))
groups = draw(st.integers(min_value=1, max_value=3))
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=2, max_size=2))
bias_shape = [f_shape[0]]
inputs = dict()
weights = dict()
use_mkldnn = True
has_bias = draw(st.booleans())
if has_bias:
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
}
weights = {
"filter": TensorConfig(shape=f_shape),
"bias": TensorConfig(shape=bias_shape),
}
else:
inputs = {
"Input": ["input_x"],
"Filter": ["filter"],
}
weights = {"filter": TensorConfig(shape=f_shape), }
conv2d_op = OpConfig(
"conv2d",
inputs=inputs,
outputs={"Output": ["conv2d_out"]},
strides=strides,
padding_algorithm=padding_algorithm,
paddings=padding,
groups=groups,
dilations=dilations,
data_format=data_format,
use_mkldnn=use_mkldnn,
mkldnn_data_type="int8")
ops = [conv2d_op]
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs={"input_x": TensorConfig(shape=x_shape)},
outputs=["conv2d_out"])
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["int8_scale_calculation_mkldnn_pass"])
if __name__ == "__main__":
unittest.main()
@@ -655,6 +655,7 @@ STATIC_MODE_TESTING_LIST = [
'test_transpose_mkldnn_op',
'test_mkldnn_conv_activation_fuse_pass',
'test_mkldnn_conv_concat_relu_mkldnn_fuse_pass',
'test_mkldnn_int8_scale_calculation_pass',
'test_mkldnn_matmul_op_output_fuse_pass',
'test_mkldnn_matmul_transpose_reshape_fuse_pass',
'test_mkldnn_scale_matmul_fuse_pass',
......