未验证 提交 66dccd7d 编写于 作者: Y yeliang2258 提交者: GitHub

Add unsigned int8 scale propagation (#46378) (#47156)

* Add unsigned int8 propagation

* Add or modify unit tests

* Correct concat scale checking

* Apply review suggestions

* Corrections
Co-authored-by: Njoanna.wozna.intel <joanna.wozna@intel.com>
上级 5a9befea
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
...@@ -394,8 +393,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -394,8 +393,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]); auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]);
if (out_iter != var_quant_scales->end()) { if (out_iter != var_quant_scales->end()) {
std::vector<std::string> input_names = op_node->Op()->Input("X"); std::vector<std::string> input_names = op_node->Op()->Input("X");
for (auto input_name : input_names) for (auto input_name : input_names) {
auto concat_in_iter = var_quant_scales->find(input_name);
if (concat_in_iter == var_quant_scales->end())
(*var_quant_scales)[input_name] = out_iter->second; (*var_quant_scales)[input_name] = out_iter->second;
else
(*var_quant_scales)[input_name].second = out_iter->second.second;
}
} }
} else if (op_name == "scale") { } else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0]; const std::string output_name = op_node->Op()->Output("Out")[0];
...@@ -409,6 +413,40 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -409,6 +413,40 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
} }
return waiting_for_scale; return waiting_for_scale;
} }
void ComputePropagateScalesMkldnnPass::UpdateReluOutputScales(
ir::Graph* graph, StringPairMap* var_quant_scales) const {
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
auto op = op_node->Op();
bool is_unsigned = false;
std::string output_name = "Out";
std::string act_name;
if (op->Type() == "relu") {
is_unsigned = true;
} else {
if (op->Type() == "conv2d") {
act_name = "fuse_activation";
output_name = "Output";
} else if (op->Type() == "fc") {
act_name = "activation_type";
}
if (!act_name.empty()) {
auto act = op->GetAttrIfExists<std::string>(act_name);
if (act == "relu" || act == "relu6") {
is_unsigned = true;
}
}
}
if (is_unsigned) {
std::string output_var_name = op->Output(output_name)[0];
auto out_iter = var_quant_scales->find(output_var_name);
if (out_iter != var_quant_scales->end()) {
(*var_quant_scales)[output_var_name].first = true;
}
}
}
}
void ComputePropagateScalesMkldnnPass::PropagateScales( void ComputePropagateScalesMkldnnPass::PropagateScales(
ir::Graph* graph, ir::Graph* graph,
...@@ -427,21 +465,6 @@ void ComputePropagateScalesMkldnnPass::PropagateScales( ...@@ -427,21 +465,6 @@ void ComputePropagateScalesMkldnnPass::PropagateScales(
} }
} }
void ComputePropagateScalesMkldnnPass::ConvertStringPairMap(
const StringPairMap& var_quant_scales,
std::unordered_map<std::string, std::vector<float>>* info_map) const {
for (auto iter = var_quant_scales.begin(); iter != var_quant_scales.end();
iter++) {
auto* data = iter->second.second.data<float>();
std::vector<float> data_v;
for (int i = 0; i < iter->second.second.numel(); i++) {
data_v.push_back(data[i]);
}
info_map->insert(std::make_pair(iter->first, data_v));
}
}
void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Convert paddle model to mkldnn quantized model."; VLOG(3) << "Convert paddle model to mkldnn quantized model.";
const std::string pattern_name = "compute_propagate_scales_mkldnn_pass"; const std::string pattern_name = "compute_propagate_scales_mkldnn_pass";
...@@ -461,13 +484,13 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { ...@@ -461,13 +484,13 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
auto* scope = param_scope(); auto* scope = param_scope();
GetQuantInfo(graph, &var_quant_scales); GetQuantInfo(graph, &var_quant_scales);
ComputeWeightScales(graph, scope, &var_quant_scales); ComputeWeightScales(graph, scope, &var_quant_scales);
UpdateReluOutputScales(graph, &var_quant_scales);
PropagateScales(graph, &var_quant_scales, scale_immutable_ops); PropagateScales(graph, &var_quant_scales, scale_immutable_ops);
// save var_quant_scales in the first op's attr // save var_quant_scales in the first op's attr
// for cpu_quantize_pass // for cpu_quantize_pass
std::unordered_map<std::string, std::vector<float>> info_map; SaveInfoInTheFirstOp(
ConvertStringPairMap(var_quant_scales, &info_map); graph, "has_quant_info", "var_quant_scales", var_quant_scales);
SaveInfoInTheFirstOp(graph, "has_quant_info", "var_quant_scales", info_map);
} }
} // namespace ir } // namespace ir
......
...@@ -17,13 +17,12 @@ ...@@ -17,13 +17,12 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
using StringPairMap = std::unordered_map<std::string, std::pair<bool, Tensor>>;
class ComputePropagateScalesMkldnnPass : public FusePassBase { class ComputePropagateScalesMkldnnPass : public FusePassBase {
public: public:
ComputePropagateScalesMkldnnPass() = default; ComputePropagateScalesMkldnnPass() = default;
...@@ -77,6 +76,9 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { ...@@ -77,6 +76,9 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
Scope* scope, Scope* scope,
StringPairMap* var_quant_scales) const; StringPairMap* var_quant_scales) const;
void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const;
void UpdateScaleOpInScale(Node* op_node, void UpdateScaleOpInScale(Node* op_node,
const std::string& input_name, const std::string& input_name,
const std::string& output_name, const std::string& output_name,
...@@ -91,10 +93,6 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase { ...@@ -91,10 +93,6 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
ir::Graph* graph, ir::Graph* graph,
StringPairMap* var_quant_scales, StringPairMap* var_quant_scales,
const std::unordered_set<std::string>& scale_immutable_ops) const; const std::unordered_set<std::string>& scale_immutable_ops) const;
void ConvertStringPairMap(
const StringPairMap& var_quant_scales,
std::unordered_map<std::string, std::vector<float>>* info_map) const;
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h" #include "paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h"
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
...@@ -91,11 +92,16 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { ...@@ -91,11 +92,16 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
graph, scope, wx_name, wh_name, var_quant_scales); graph, scope, wx_name, wh_name, var_quant_scales);
} }
void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const {
pass->UpdateReluOutputScales(graph, var_quant_scales);
}
void InitTensorHolder(Scope* scope, void InitTensorHolder(Scope* scope,
const paddle::platform::Place& place, const paddle::platform::Place& place,
const std::string& var_name) { const std::string& var_name) {
auto x = scope->Var(var_name); auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>(); auto tensor = x->GetMutable<phi::DenseTensor>();
auto tensor_size = 1; auto tensor_size = 1;
if (var_name == "filter") { if (var_name == "filter") {
tensor_size = positive_and_negative_values.size(); tensor_size = positive_and_negative_values.size();
...@@ -124,7 +130,6 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { ...@@ -124,7 +130,6 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
} }
void ComputeRnnWeightScalesTest(const std::string& type, void ComputeRnnWeightScalesTest(const std::string& type,
const std::initializer_list<std::string>& ops,
const framework::ProgramDesc& prog, const framework::ProgramDesc& prog,
std::vector<double> scales) { std::vector<double> scales) {
ir::Graph* graph(new ir::Graph(prog)); ir::Graph* graph(new ir::Graph(prog));
...@@ -140,7 +145,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { ...@@ -140,7 +145,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
StringPairMap var_quant_scales; StringPairMap var_quant_scales;
auto* wx_var = scope.FindVar(wx_var_names); auto* wx_var = scope.FindVar(wx_var_names);
auto* wx_tensor = wx_var->GetMutable<LoDTensor>(); auto* wx_tensor = wx_var->GetMutable<phi::DenseTensor>();
wx_tensor->Resize(phi::make_dim(wx.size(), wx[0].size())); wx_tensor->Resize(phi::make_dim(wx.size(), wx[0].size()));
for (size_t i = 0; i < wx.size(); i++) for (size_t i = 0; i < wx.size(); i++)
std::copy(begin(wx[i]), std::copy(begin(wx[i]),
...@@ -149,7 +154,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { ...@@ -149,7 +154,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
i * wx[0].size()); i * wx[0].size());
auto* wh_var = scope.FindVar(wh_var_names); auto* wh_var = scope.FindVar(wh_var_names);
auto* wh_tensor = wh_var->GetMutable<LoDTensor>(); auto* wh_tensor = wh_var->GetMutable<phi::DenseTensor>();
wh_tensor->Resize(phi::make_dim(wh.size(), wh[0].size())); wh_tensor->Resize(phi::make_dim(wh.size(), wh[0].size()));
for (size_t i = 0; i < wh.size(); i++) for (size_t i = 0; i < wh.size(); i++)
std::copy(begin(wh[i]), std::copy(begin(wh[i]),
...@@ -174,6 +179,24 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { ...@@ -174,6 +179,24 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
} }
} }
void UpdateReluOutputScaleTest(
const framework::ProgramDesc& prog,
StringPairMap* var_quant_scales,
const std::initializer_list<std::string>& variable_names) {
ir::Graph* graph(new ir::Graph(prog));
Scope scope;
PrepareGraph(graph, prog, &scope, conv_variable_names);
UpdateReluOutputScales(graph, var_quant_scales);
for (auto& var_name : variable_names) {
auto iter = var_quant_scales->find(var_name);
ASSERT_NE(iter, var_quant_scales->end());
ASSERT_EQ((*var_quant_scales)[var_name].first, true);
}
}
private: private:
std::unique_ptr<ComputePropagateScalesMkldnnPass> pass; std::unique_ptr<ComputePropagateScalesMkldnnPass> pass;
}; };
...@@ -182,11 +205,15 @@ void SetOp(ProgramDesc* prog, ...@@ -182,11 +205,15 @@ void SetOp(ProgramDesc* prog,
const std::string& type, const std::string& type,
const std::string& name, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) { const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& attrs = {}) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
op->SetAttr("use_mkldnn", true); op->SetAttr("use_mkldnn", true);
op->SetAttr("name", name); op->SetAttr("name", name);
if (!attrs.empty())
for (auto& attr : attrs) op->SetAttr(attr.first, attr.second);
if (type == "conv2d") { if (type == "conv2d") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
...@@ -211,6 +238,23 @@ ProgramDesc BuildConv2dProgramDesc() { ...@@ -211,6 +238,23 @@ ProgramDesc BuildConv2dProgramDesc() {
return prog; return prog;
} }
ProgramDesc BuildConv2dReluProgramDesc() {
ProgramDesc prog;
for (auto& v : conv_variable_names) {
prog.MutableBlock(0)->Var(v);
}
std::unordered_map<std::string, std::string> attrs = {
{"fuse_activation", "relu"}};
SetOp(&prog,
"conv2d",
"Conv2d",
{"conv_in", "filter", "bias"},
{"conv_out"},
attrs);
return prog;
}
ProgramDesc BuildFusionGruProgramDesc() { ProgramDesc BuildFusionGruProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : rnn_variable_names) { for (auto& v : rnn_variable_names) {
...@@ -262,7 +306,7 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) { ...@@ -262,7 +306,7 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) {
StringPairMap var_quant_scales; StringPairMap var_quant_scales;
auto* var = scope.FindVar(weight_var_name); auto* var = scope.FindVar(weight_var_name);
auto* weight_tensor = var->GetMutable<LoDTensor>(); auto* weight_tensor = var->GetMutable<phi::DenseTensor>();
weight_tensor->Resize(phi::make_dim(1, values.size())); weight_tensor->Resize(phi::make_dim(1, values.size()));
std::copy(begin(values), std::copy(begin(values),
end(values), end(values),
...@@ -283,15 +327,24 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) { ...@@ -283,15 +327,24 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) {
} }
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_gru_weight_scales) { TEST_F(ComputePropagateScalesMkldnnPassTest, compute_gru_weight_scales) {
ComputeRnnWeightScalesTest("gru", ComputeRnnWeightScalesTest("gru", BuildFusionGruProgramDesc(), gru_scales);
{"fusion_gru", "multi_gru"},
BuildFusionGruProgramDesc(),
gru_scales);
} }
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_lstm_weight_scales) { TEST_F(ComputePropagateScalesMkldnnPassTest, compute_lstm_weight_scales) {
ComputeRnnWeightScalesTest( ComputeRnnWeightScalesTest("lstm", BuildFusionLstmProgramDesc(), lstm_scales);
"lstm", {"fusion_lstm"}, BuildFusionLstmProgramDesc(), lstm_scales); }
TEST_F(ComputePropagateScalesMkldnnPassTest, update_relu_output_scales) {
StringPairMap var_quant_scales;
for (auto& var_name : conv_variable_names) {
phi::DenseTensor tensor;
auto* data = tensor.mutable_data<float>({1}, platform::CPUPlace());
data[0] = 10;
auto pair = std::make_pair(false, tensor);
var_quant_scales.insert(std::make_pair(var_name, pair));
}
UpdateReluOutputScaleTest(
BuildConv2dReluProgramDesc(), &var_quant_scales, {"conv_out"});
} }
} // namespace ir } // namespace ir
......
...@@ -229,6 +229,7 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, ...@@ -229,6 +229,7 @@ void CPUQuantizePass::DequantizeOutput(Graph* g,
std::vector<std::string>({dequantize_in_node->Name()})); std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()})); deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
deq_desc.SetAttr("Scale", scale); deq_desc.SetAttr("Scale", scale);
deq_desc.SetAttr("is_negative_input", !is_unsigned);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
// update op's output // update op's output
...@@ -332,20 +333,8 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const { ...@@ -332,20 +333,8 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
} }
void CPUQuantizePass::GetQuantInfo(Graph* graph) const { void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
std::unordered_map<std::string, std::vector<float>> info_map{}; GetInfoFromTheFirstOp(
GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map); graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
for (auto iter = info_map.begin(); iter != info_map.end(); iter++) {
LoDTensor tensor;
const int size = static_cast<int>(iter->second.size());
auto* data = tensor.mutable_data<double>({size}, platform::CPUPlace());
for (int i = 0; i < size; i++) {
data[i] = static_cast<double>(iter->second[i]);
}
auto pair = std::make_pair(false, tensor);
var_quant_scales_->insert(std::make_pair(iter->first, pair));
}
} }
void CPUQuantizePass::QuantizeConv(Graph* graph, void CPUQuantizePass::QuantizeConv(Graph* graph,
...@@ -597,6 +586,20 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -597,6 +586,20 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
return; return;
} }
bool are_all_inputs_unsigned{true};
// if all inputs were unsigned, then the output was set to unsigned
// during the scale calculation step
auto inputs = concat_op->inputs;
for (size_t i = 0; i < inputs.size(); i++) {
if (AreScalesPresentForVarNames({inputs[i]->Name()})) {
auto scale_data = GetScaleDataByName(inputs[i]->Name());
if (scale_data.first == false) {
are_all_inputs_unsigned = false;
break;
}
}
}
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern); GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
if (!AreScalesPresentForNodes({concat_out})) { if (!AreScalesPresentForNodes({concat_out})) {
...@@ -605,17 +608,12 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -605,17 +608,12 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
return; return;
} }
// if all inputs were unsigned, then the output was set to unsigned auto output_scale = GetScaleValueForNode(concat_out);
// during the scale calculation step
bool are_all_inputs_unsigned{false};
auto output_scale =
GetScaleValueForNode(concat_out, &are_all_inputs_unsigned);
QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned); QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);
DequantizeOutput( DequantizeOutput(
g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned); g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);
++quantize_concat_count; ++quantize_concat_count;
}; };
......
...@@ -104,51 +104,24 @@ void CPUQuantizeSquashPass::FindNodesToKeep( ...@@ -104,51 +104,24 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count); AddStatis(found_count);
} }
bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
const Node* dequant_in) const {
PADDLE_ENFORCE_EQ(
dequant_in->inputs.size(),
1,
platform::errors::InvalidArgument(
"Dequantize (id: %f) should have only one input.", dequant_in->id()));
if (dequant_in->inputs[0]->IsOp()) {
auto prev_op = dequant_in->inputs[0]->Op();
std::string act_name;
if (prev_op->Type() == "relu") {
return true;
} else {
if (prev_op->Type() == "conv2d") {
act_name = "fuse_activation";
} else if (prev_op->Type() == "fc") {
act_name = "activation_type";
}
if (!act_name.empty()) {
auto act = prev_op->GetAttrIfExists<std::string>(act_name);
if (act == "relu" || act == "relu6") {
return true;
}
}
}
}
return false;
}
bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible( bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
Node* quant_op, Node* dequant_in, Node* next_op) const { Node* quant_op, Node* dequant_op, Node* next_op) const {
bool is_concat_signed = bool is_next_op_signed =
quant_op->Op()->GetAttrIfExists<bool>("is_negative_input"); quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
bool is_input_unsigned = IsDequantizeInputUint8(dequant_in); bool is_input_signed =
dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
/* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN /* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN
kernel will support two different input data types */ kernel will support two different input data types */
bool is_next_op_concat_or_elementwise = bool is_next_op_concat_or_elementwise =
next_op->Op()->Type() == "concat" || next_op->Op()->Type() == "concat" ||
next_op->Op()->Type().find("elementwise") == 0; next_op->Op()->Type().find("elementwise") == 0;
if (is_next_op_concat_or_elementwise && is_concat_signed && if (is_next_op_concat_or_elementwise &&
is_input_unsigned) { (is_next_op_signed ^ is_input_signed)) {
VLOG(4) << "Do not squash dequant-quant, because " VLOG(4) << "Do not squash dequant-quant, because "
<< "next_op is: " << next_op->Op()->Type() << "next_op is: " << next_op->Op()->Type()
<< ", is_concat_signed: " << is_concat_signed << ", is_next_op_signed: " << is_next_op_signed
<< ", is_input_unsigned: " << is_input_unsigned << "."; << ", is_input_signed: " << is_input_signed << ".";
return true; return true;
} }
return false; return false;
...@@ -173,7 +146,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( ...@@ -173,7 +146,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern); GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern); GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) { if (IsDequantizeQuantizeIncompatible(quant_op, dequant_op, next_op)) {
return; return;
} }
......
...@@ -43,11 +43,6 @@ class CPUQuantizeSquashPass : public FusePassBase { ...@@ -43,11 +43,6 @@ class CPUQuantizeSquashPass : public FusePassBase {
Graph* graph, Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const; std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
* Check if input to dequantize is uint8
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;
/* /*
* Don't squash unsigned dequantize with signed quantize. * Don't squash unsigned dequantize with signed quantize.
* This is important for concat and elementwise ops. * This is important for concat and elementwise ops.
...@@ -55,7 +50,7 @@ class CPUQuantizeSquashPass : public FusePassBase { ...@@ -55,7 +50,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
* elementwise assumes first input type. * elementwise assumes first input type.
*/ */
bool IsDequantizeQuantizeIncompatible(Node* quant_op, bool IsDequantizeQuantizeIncompatible(Node* quant_op,
Node* dequant_in, Node* dequant_op,
Node* next_op) const; Node* next_op) const;
/* /*
......
...@@ -68,15 +68,11 @@ void SetOp(ProgramDesc* prog, ...@@ -68,15 +68,11 @@ void SetOp(ProgramDesc* prog,
op->SetAttr("padding_algorithm", std::string("EXPLICIT")); op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW")); op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("force_fp32_output", false); op->SetAttr("force_fp32_output", false);
} else if (type == "quantize") { } else if (type == "quantize" || type == "dequantize") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale[0]); op->SetAttr("Scale", scale[0]);
op->SetAttr("is_negative_input", is_negative_input); op->SetAttr("is_negative_input", is_negative_input);
} else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale[0]);
} else if (type == "requantize") { } else if (type == "requantize") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
...@@ -303,31 +299,22 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, ...@@ -303,31 +299,22 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn,
return prog; return prog;
} }
/* a->relu->b->Dequant->c(u8)->Quant->d-\ /* a->relu->b->Dequant(u8)->c->Quant(u8)->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x * e->relu->f->Dequant(u8)->g->Quant(u8)->h--Concat1->i
* i->relu->j->Dequant->k(u8)->Quant->l-/
*/ */
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc BuildU8U8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names) { for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out}); SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out}); SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
SetOp(&prog, SetOp(&prog,
"quantize", "dequantize",
"Quant1", "Dequant1",
{"b"},
{"c"}, {"c"},
{"d"},
true, true,
{scale, scale_out}, {scale, scale_out},
0.0f, 0.0f,
...@@ -336,10 +323,23 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { ...@@ -336,10 +323,23 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
1, 1,
false); // is_negative_input = false false); // is_negative_input = false
SetOp(&prog, SetOp(&prog,
"quantize", "dequantize",
"Quant2", "Dequant2",
{"f"},
{"g"}, {"g"},
{"h"}, true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(&prog,
"quantize",
"Quant1",
{"c"},
{"d"},
true, true,
{scale, scale_out}, {scale, scale_out},
0.0f, 0.0f,
...@@ -349,9 +349,9 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { ...@@ -349,9 +349,9 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
false); // is_negative_input = false false); // is_negative_input = false
SetOp(&prog, SetOp(&prog,
"quantize", "quantize",
"Quant3", "Quant2",
{"k"}, {"g"},
{"l"}, {"h"},
true, true,
{scale, scale_out}, {scale, scale_out},
0.0f, 0.0f,
...@@ -360,27 +360,47 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { ...@@ -360,27 +360,47 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
1, 1,
false); // is_negative_input = false false); // is_negative_input = false
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); SetOp(&prog, "concat", "Concat1", {"d", "h"}, {"i"}, true);
return prog; return prog;
} }
/* a->relu->b->Dequant->c(u8)->Quant->d-\ /* a->relu->b->Dequant(u8)->c->Quant(s8)->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x * e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/ * i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/
*/ */
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names) { for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out}); SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out}); SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out}); SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});
SetOp( SetOp(&prog,
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out}); "dequantize",
SetOp( "Dequant1",
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out}); {"b"},
{"c"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(&prog,
"dequantize",
"Dequant2",
{"f"},
{"g"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp( SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out}); &prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
...@@ -392,9 +412,9 @@ ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) { ...@@ -392,9 +412,9 @@ ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
return prog; return prog;
} }
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\ /* a->pool2d->b->Dequant(s8)->c->Quant(s8)->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x * e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/ * i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/
*/ */
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog; ProgramDesc prog;
...@@ -407,8 +427,18 @@ ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) { ...@@ -407,8 +427,18 @@ ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
SetOp( SetOp(
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out}); &prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out});
SetOp( SetOp(&prog,
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out}); "dequantize",
"Dequant2",
{"f"},
{"g"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp( SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out}); &prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
...@@ -1141,13 +1171,12 @@ TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) { ...@@ -1141,13 +1171,12 @@ TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
} }
TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) { TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) // removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12; auto remove_nodes = 8;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}}; {"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 2}};
CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), CheckNodesTest(
expected_operators, BuildU8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators, remove_nodes);
remove_nodes);
} }
} // namespace ir } // namespace ir
......
...@@ -22,6 +22,9 @@ namespace paddle { ...@@ -22,6 +22,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
using StringPairMap =
std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
static void SaveInfoInTheFirstOp( static void SaveInfoInTheFirstOp(
ir::Graph* graph, ir::Graph* graph,
const std::string& flag, const std::string& flag,
...@@ -44,6 +47,31 @@ static void SaveInfoInTheFirstOp( ...@@ -44,6 +47,31 @@ static void SaveInfoInTheFirstOp(
} }
} }
static void SaveInfoInTheFirstOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
const StringPairMap& info_map) {
VLOG(3) << "save variables in the first op's attr";
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
op_node->Op()->SetAttr(flag, true);
for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
auto* data = iter->second.second.data<float>();
std::vector<float> data_v(data, data + iter->second.second.numel());
op_node->Op()->SetAttr(iter->first + suffix + "_unsigned",
iter->second.first);
op_node->Op()->SetAttr(iter->first + suffix, data_v);
}
break;
}
}
static void GetInfoFromTheFirstOp( static void GetInfoFromTheFirstOp(
ir::Graph* graph, ir::Graph* graph,
const std::string& flag, const std::string& flag,
...@@ -77,6 +105,54 @@ static void GetInfoFromTheFirstOp( ...@@ -77,6 +105,54 @@ static void GetInfoFromTheFirstOp(
} }
} }
static void GetInfoFromTheFirstOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
StringPairMap* info_map) {
VLOG(3) << "get variables from the first op's attr";
const std::string unsigned_flag = "_unsigned";
const std::string suffix = "_" + key_suffix + "_" + flag;
const std::string suffix_is_unsigned = suffix + unsigned_flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
auto* op_desc = op_node->Op();
if (op_desc->GetAttrIfExists<bool>(flag)) {
op_desc->RemoveAttr(flag);
std::vector<std::string> attr_names = op_desc->AttrNames();
for (auto fake_name : attr_names) {
auto is_unsigned = false;
size_t pos = fake_name.find(suffix_is_unsigned);
if (pos != std::string::npos) {
std::string unsigned_var_name = fake_name;
is_unsigned =
PADDLE_GET_CONST(bool, op_desc->GetAttr(unsigned_var_name));
std::string var_name = fake_name.substr(0, pos);
size_t unsigned_pos = fake_name.find(unsigned_flag);
std::string vector_name =
fake_name.erase(unsigned_pos, unsigned_flag.length());
auto scales_vector = PADDLE_GET_CONST(std::vector<float>,
op_desc->GetAttr(vector_name));
phi::DenseTensor tensor;
const int size = static_cast<int>(scales_vector.size());
auto data = tensor.mutable_data<double>({size}, platform::CPUPlace());
std::copy(scales_vector.begin(), scales_vector.end(), data);
auto pair = std::make_pair(is_unsigned, tensor);
info_map->insert(std::make_pair(var_name, pair));
op_desc->RemoveAttr(unsigned_var_name);
op_desc->RemoveAttr(vector_name);
}
}
break;
}
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册