提交 96845d21 编写于 作者: S Sylwester Fraczek 提交者: Tao Luo

add Concat quantization (#17448)

* add Concat quantization
add unit test for quantizing concat
fix for wrong value when the input is not in map of calculated scales
add use_quantizer to concat_op.cc
add scale_algo rules for concat

test=develop

* missing fix for multiple inputs quantize-squash

* wojtuss review fix: adding comment

test=develop
上级 432ac701
...@@ -1214,6 +1214,17 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { ...@@ -1214,6 +1214,17 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var; return out_var;
} }
PDNode *patterns::Concat::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto output_var = pattern->NewNode(concat_out_repr())
->AsOutput()
->assert_is_op_output("concat", "Out");
concat_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ConcatReLU::operator()() { PDNode *patterns::ConcatReLU::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat"); auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu"); auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
......
...@@ -747,6 +747,19 @@ struct ElementwiseAdd : public PatternBase { ...@@ -747,6 +747,19 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_add_out);
}; };
// Concat op
// Forward pass for concat.
// concat_out is a result of the operator.
struct Concat : public PatternBase {
Concat(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "concat") {}
PDNode* operator()();
PATTERN_DECL_NODE(concat_op);
PATTERN_DECL_NODE(concat_out);
};
// Concat + ReLU // Concat + ReLU
// named nodes: // named nodes:
// concat_op, concat_out, relu_op, relu_out // concat_op, concat_out, relu_op, relu_out
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h" #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <limits>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
...@@ -72,6 +73,53 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, ...@@ -72,6 +73,53 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
} }
void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
VarQuantScale* scales, bool are_unsigned,
std::string scale_attr_name) const {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1);
// create a quantize op desc prototype
OpDesc q_desc;
q_desc.SetType("quantize");
std::vector<Node*> quantize_out_nodes(inputs.size());
std::vector<std::string> quantize_out_node_names(inputs.size());
double scale_min = std::numeric_limits<double>::max();
for (const auto& input : inputs) {
double scale = (*scales)[input->Name()].second.data<double>()[0];
if (scale < scale_min) scale_min = scale;
}
unsigned max = are_unsigned ? U8_MAX : S8_MAX;
float scale = scale_min * max;
for (size_t i = 0; i < inputs.size(); i++) {
// Create quantize output variable
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
q_desc.SetAttr("Scale", scale);
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetAttr("is_negative_input", !are_unsigned);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
// link quantize op
UnlinkNodes(inputs[i], op);
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
}
// update op's input
op->Op()->SetInput(input_name, quantize_out_node_names);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}
void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name, std::string output_name,
double scale_to_one, bool is_unsigned, double scale_to_one, bool is_unsigned,
...@@ -216,6 +264,48 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { ...@@ -216,6 +264,48 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count); PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
} }
void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Concat concat_pattern{pattern, name_scope_};
concat_pattern();
int quantize_concat_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize concat op";
GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);
auto* concat_op_desc = concat_op->Op();
// skip if should not be quantized
if (!concat_op_desc->HasAttr("use_quantizer") ||
!boost::get<bool>(concat_op_desc->GetAttr("use_quantizer")))
return;
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
// get scales calculated after warmup, they scale variables to MAX=1.0
auto scales = Get<VarQuantScale>("quant_var_scales");
// if all inputs were unsigned, then the output was set to unsigned
// during the scale calculation step
bool are_all_inputs_unsigned = scales[concat_out->Name()].first;
QuantizeInputs(g, concat_op, "X", &scales, are_all_inputs_unsigned);
auto output_scale = scales[concat_out->Name()].second.data<double>()[0];
DequantizeOutput(g, concat_op, concat_out, "Out", output_scale,
are_all_inputs_unsigned);
++quantize_concat_count;
};
gpd(graph, handler);
AddStatis(quantize_concat_count);
PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
...@@ -226,6 +316,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -226,6 +316,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeConv(graph, false /* with_residual_data */); QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph); QuantizePool(graph);
QuantizeConcat(graph);
} }
} // namespace ir } // namespace ir
......
...@@ -48,10 +48,17 @@ class CPUQuantizePass : public FusePassBase { ...@@ -48,10 +48,17 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePool(Graph* graph) const; void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned, double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const; std::string scale_attr_name = "") const;
// quantize all inputs of given name with the same (minimum) scale
void QuantizeInputs(Graph* g, Node* op, std::string input_name,
VarQuantScale* scales, bool are_unsigned,
std::string scale_attr_name = "") const;
void DequantizeOutput(Graph* g, Node* op, Node* output, void DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name, double scale_to_one, std::string output_name, double scale_to_one,
bool is_unsigned, bool is_unsigned,
......
...@@ -60,9 +60,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -60,9 +60,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (inputs.size() > 1) op->SetInput("W", {inputs[1]}); if (inputs.size() > 1) op->SetInput("W", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("use_quantizer", use_quantizer);
} }
} }
namespace {
static const std::initializer_list<std::string> variable_names{ static const std::initializer_list<std::string> variable_names{
"a", "w1", "c", "d", "w2", "e", "f", "g", "a", "w1", "c", "d", "w2", "e", "f", "g",
"h", "w3", "b1", "i", "j", "w4", "b2"}; "h", "w3", "b1", "i", "j", "w4", "b2"};
...@@ -204,6 +209,101 @@ TEST(CpuQuantizePass, do_not_quantize) { ...@@ -204,6 +209,101 @@ TEST(CpuQuantizePass, do_not_quantize) {
1.0f); 1.0f);
} }
} // namespace
namespace {
static const std::initializer_list<std::string> variable_names_concat = {
"a1", "b1", "a2", "b2", "c", "d"};
// a1->Pool1->b1
// a2->Pool2->b2
// (b1,b2)->Concat->c
// c->Pool3->d
ProgramDesc BuildProgramDescConcat() {
ProgramDesc prog;
SetOp(&prog, "pool2d", "Pool1", {"a1"}, {"b1"}, true, false);
SetOp(&prog, "pool2d", "Pool2", {"a2"}, {"b2"}, true, false);
SetOp(&prog, "concat", "Concat", {"b1", "b2"}, {"c"}, true, true);
SetOp(&prog, "pool2d", "Pool3", {"c"}, {"d"}, true, false);
return prog;
}
void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count,
int quant_count, int dequant_count, int added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
auto* scales = new VarQuantScale();
for (auto& v : variable_names_concat) {
InitTensorHolder(&scope, place, v.c_str());
LoDTensor tensor;
tensor.Resize({1});
auto* ptr = tensor.mutable_data<double>(place);
ptr[0] = 2.0;
(*scales)[v] = std::make_pair(false, std::move(tensor));
}
graph->SetNotOwned(kParamScopeAttr, &scope);
auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
pass->Set("quant_var_scales", scales);
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int concat_nodes_count = 0;
int pool2d_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "concat") {
concat_nodes_count++;
} else if (op->Type() == "pool2d") {
pool2d_nodes_count++;
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(concat_nodes_count, concat_count);
EXPECT_EQ(pool2d_nodes_count, pool_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, concat) {
// a1->Pool1->b1
// a2->Pool2->b2
// (b1->QUANT1->IN1, b2->QUANT2->IN2)->Concat->c
// c->OUT1->DEQUANT1->Pool3->d
int pool_count = 3;
int concat_count = 1;
int quant_count = 2;
int dequant_count = 1;
int added_nodes_count = 6;
MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count,
quant_count, dequant_count, added_nodes_count);
}
} // namespace
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h" #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -81,15 +82,10 @@ void CPUQuantizeSquashPass::Squash( ...@@ -81,15 +82,10 @@ void CPUQuantizeSquashPass::Squash(
auto quant_out_var_name = quant_out->Name(); auto quant_out_var_name = quant_out->Name();
auto next_op_inputs = next_op_desc->InputNames(); auto next_op_inputs = next_op_desc->InputNames();
for (const auto& name : next_op_inputs) { for (const auto& name : next_op_inputs) {
if (next_op_desc->Inputs().count(name) == 0 || auto input_names = next_op_desc->Input(name);
next_op_desc->Input(name).size() == 0) std::replace(input_names.begin(), input_names.end(), quant_out_var_name,
continue; dequant_in->Name());
auto var_name = next_op_desc->Input(name)[0]; next_op_desc->SetInput(name, input_names);
if (var_name.compare(quant_out_var_name) == 0) {
next_op_desc->SetInput(
name, std::vector<std::string>({dequant_in->Name()}));
break;
}
} }
if (keep_dequant) if (keep_dequant)
......
...@@ -50,40 +50,46 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() { ...@@ -50,40 +50,46 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
auto glambda = [&](const VariableNameMap& connections, bool is_output) { auto glambda = [&](const VariableNameMap& connections, bool is_output) {
for (auto const& conn : connections) { for (auto const& conn : connections) {
if (conn.second.size() == 0) continue; for (const auto& var_name : conn.second) {
auto& var_name = conn.second[0]; // skip if scale already computed
if (scales_.find(var_name) != scales_.end()) return;
// skip if scale already computed
if (scales_.find(var_name) != scales_.end()) return; auto* var = predictor_.sub_scope_->FindVar(var_name);
PADDLE_ENFORCE(var, "%s is not in the scope", var_name);
auto* var = predictor_.sub_scope_->FindVar(var_name); PADDLE_ENFORCE(var->IsType<LoDTensor>(),
PADDLE_ENFORCE(var, "%s is not in the scope", var_name); "Only support lod tensor now.");
PADDLE_ENFORCE(var->IsType<LoDTensor>(), LoDTensor* var_tensor = var->GetMutable<LoDTensor>();
"Only support lod tensor now.");
LoDTensor* var_tensor = var->GetMutable<LoDTensor>(); // force unsigned type if already know it
bool is_unsigned = false;
// force unsigned type if already know it if (is_output && op->Type() == "conv2d") {
bool is_unsigned = false; // output of conv2d with relu must be unsigned
if (is_output && op->Type() == "conv2d") { is_unsigned = op->HasAttr("fuse_relu") &&
// output of conv2d with relu must be unsigned boost::get<bool>(op->GetAttr("fuse_relu"));
is_unsigned = op->HasAttr("fuse_relu") && } else if (is_output && op->Type() == "relu") {
boost::get<bool>(op->GetAttr("fuse_relu")); is_unsigned = true;
} else if (is_output && op->Type() == "pool2d") { } else if (is_output &&
// output of pool2d with unsigned input must be unsigned (op->Type() == "pool2d" || op->Type() == "transpose2" ||
auto input_var_name = op->Input("X")[0]; op->Type() == "reshape2" || op->Type() == "concat")) {
if (scales_.find(input_var_name) != scales_.end()) { // output of ops with unsigned input must be unsigned
is_unsigned = scales_[input_var_name].first; is_unsigned = true;
for (auto input_var_name : op->Input("X")) {
PADDLE_ENFORCE(scales_.find(input_var_name) != scales_.end(),
"Input scales must be calculated before the "
"output scales to infer if output is unsigned.");
is_unsigned = is_unsigned && scales_[input_var_name].first;
}
} }
}
CalculateSingleScale(op->Type(), conn.first, var_name, *var_tensor, CalculateSingleScale(op->Type(), conn.first, var_name, *var_tensor,
is_unsigned); is_unsigned);
}
} }
}; };
// handle outputs first so unsigned outputs could be inferred // handle inputs first to let is_unsigned be inferred for the outputs
glambda(connections_out, true /* is_output */);
glambda(connections_in, false /* is_output */); glambda(connections_in, false /* is_output */);
glambda(connections_out, true /* is_output */);
} }
} }
......
...@@ -22,10 +22,13 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -22,10 +22,13 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["conv2d"]["Filter"] = ScaleAlgo::MAX_CH; rules_["conv2d"]["Filter"] = ScaleAlgo::MAX_CH;
rules_["conv2d"]["Bias"] = ScaleAlgo::NONE; // do not compute scale rules_["conv2d"]["Bias"] = ScaleAlgo::NONE; // do not compute scale
rules_["conv2d"]["ResidualData"] = ScaleAlgo::KL; rules_["conv2d"]["ResidualData"] = ScaleAlgo::KL;
rules_["conv2d"]["Output"] = ScaleAlgo::KL; // do not compute scale rules_["conv2d"]["Output"] = ScaleAlgo::KL;
rules_["pool2d"]["X"] = ScaleAlgo::KL; rules_["pool2d"]["X"] = ScaleAlgo::KL;
rules_["pool2d"]["Out"] = ScaleAlgo::KL; // do not compute scale rules_["pool2d"]["Out"] = ScaleAlgo::KL;
rules_["concat"]["X"] = ScaleAlgo::KL;
rules_["concat"]["Out"] = ScaleAlgo::KL;
} }
ScaleAlgo MkldnnQuantizerConfig::scale_algo( ScaleAlgo MkldnnQuantizerConfig::scale_algo(
......
...@@ -117,6 +117,12 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -117,6 +117,12 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("axis", AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.") "The axis along which the input tensors will be concatenated.")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
"Set to true for operators that should be quantized and use "
"int8 kernel. "
"Only used on CPU.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Concat Operator. Concat Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册