Unverified commit 66dccd7d authored by yeliang2258, committed by GitHub

Add unsigned int8 scale propagation (#46378) (#47156)

* Add unsigned int8 propagation

* Add or modify unit tests

* Correct concat scale checking

* Apply review suggestions

* Corrections
Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
Parent 5a9befea
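
For context (not part of the patch): the pass tracks, for every variable, both its quantization scale and whether its values are guaranteed non-negative. The minimal sketch below illustrates that idea only; the names are illustrative and a plain float stands in for the phi::DenseTensor scale tensor held by the real StringPairMap.

// Sketch only: pair.first marks the variable as unsigned (uint8),
// pair.second stands in for its quantization scale.
#include <string>
#include <unordered_map>
#include <utility>

using ScaleMap = std::unordered_map<std::string, std::pair<bool, float>>;

// Outputs of relu, or of conv2d/fc fused with relu/relu6, cannot be negative,
// so their existing scale entries get flagged as unsigned.
void MarkOutputUnsigned(const std::string& out_var, ScaleMap* scales) {
  auto it = scales->find(out_var);
  if (it != scales->end()) it->second.first = true;
}

// When a concat output scale is propagated back to an input, an input that
// already has an entry keeps its own signedness flag and only the scale value
// is overwritten (the corrected concat scale checking).
void PropagateConcatScale(const std::string& in_var,
                          const std::pair<bool, float>& out_entry,
                          ScaleMap* scales) {
  auto it = scales->find(in_var);
  if (it == scales->end()) {
    (*scales)[in_var] = out_entry;
  } else {
    it->second.second = out_entry.second;
  }
}
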
......@@ -19,7 +19,6 @@
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -394,8 +393,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]);
if (out_iter != var_quant_scales->end()) {
std::vector<std::string> input_names = op_node->Op()->Input("X");
for (auto input_name : input_names)
(*var_quant_scales)[input_name] = out_iter->second;
for (auto input_name : input_names) {
auto concat_in_iter = var_quant_scales->find(input_name);
if (concat_in_iter == var_quant_scales->end())
(*var_quant_scales)[input_name] = out_iter->second;
else
(*var_quant_scales)[input_name].second = out_iter->second.second;
}
}
} else if (op_name == "scale") {
const std::string output_name = op_node->Op()->Output("Out")[0];
......@@ -409,6 +413,40 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
}
return waiting_for_scale;
}
void ComputePropagateScalesMkldnnPass::UpdateReluOutputScales(
ir::Graph* graph, StringPairMap* var_quant_scales) const {
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
auto op = op_node->Op();
bool is_unsigned = false;
std::string output_name = "Out";
std::string act_name;
if (op->Type() == "relu") {
is_unsigned = true;
} else {
if (op->Type() == "conv2d") {
act_name = "fuse_activation";
output_name = "Output";
} else if (op->Type() == "fc") {
act_name = "activation_type";
}
if (!act_name.empty()) {
auto act = op->GetAttrIfExists<std::string>(act_name);
if (act == "relu" || act == "relu6") {
is_unsigned = true;
}
}
}
if (is_unsigned) {
std::string output_var_name = op->Output(output_name)[0];
auto out_iter = var_quant_scales->find(output_var_name);
if (out_iter != var_quant_scales->end()) {
(*var_quant_scales)[output_var_name].first = true;
}
}
}
}
void ComputePropagateScalesMkldnnPass::PropagateScales(
ir::Graph* graph,
......@@ -427,21 +465,6 @@ void ComputePropagateScalesMkldnnPass::PropagateScales(
}
}
void ComputePropagateScalesMkldnnPass::ConvertStringPairMap(
const StringPairMap& var_quant_scales,
std::unordered_map<std::string, std::vector<float>>* info_map) const {
for (auto iter = var_quant_scales.begin(); iter != var_quant_scales.end();
iter++) {
auto* data = iter->second.second.data<float>();
std::vector<float> data_v;
for (int i = 0; i < iter->second.second.numel(); i++) {
data_v.push_back(data[i]);
}
info_map->insert(std::make_pair(iter->first, data_v));
}
}
void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Convert paddle model to mkldnn quantized model.";
const std::string pattern_name = "compute_propagate_scales_mkldnn_pass";
......@@ -461,13 +484,13 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
auto* scope = param_scope();
GetQuantInfo(graph, &var_quant_scales);
ComputeWeightScales(graph, scope, &var_quant_scales);
UpdateReluOutputScales(graph, &var_quant_scales);
PropagateScales(graph, &var_quant_scales, scale_immutable_ops);
// save var_quant_scales in the first op's attr
// for cpu_quantize_pass
std::unordered_map<std::string, std::vector<float>> info_map;
ConvertStringPairMap(var_quant_scales, &info_map);
SaveInfoInTheFirstOp(graph, "has_quant_info", "var_quant_scales", info_map);
SaveInfoInTheFirstOp(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}
} // namespace ir
......
......@@ -17,13 +17,12 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
namespace paddle {
namespace framework {
namespace ir {
using StringPairMap = std::unordered_map<std::string, std::pair<bool, Tensor>>;
class ComputePropagateScalesMkldnnPass : public FusePassBase {
public:
ComputePropagateScalesMkldnnPass() = default;
......@@ -77,6 +76,9 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
Scope* scope,
StringPairMap* var_quant_scales) const;
void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const;
void UpdateScaleOpInScale(Node* op_node,
const std::string& input_name,
const std::string& output_name,
......@@ -91,10 +93,6 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
ir::Graph* graph,
StringPairMap* var_quant_scales,
const std::unordered_set<std::string>& scale_immutable_ops) const;
void ConvertStringPairMap(
const StringPairMap& var_quant_scales,
std::unordered_map<std::string, std::vector<float>>* info_map) const;
};
} // namespace ir
} // namespace framework
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <gtest/gtest.h>
#include <unordered_map>
#include "paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
......@@ -91,11 +92,16 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
graph, scope, wx_name, wh_name, var_quant_scales);
}
void UpdateReluOutputScales(ir::Graph* graph,
StringPairMap* var_quant_scales) const {
pass->UpdateReluOutputScales(graph, var_quant_scales);
}
void InitTensorHolder(Scope* scope,
const paddle::platform::Place& place,
const std::string& var_name) {
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
auto tensor = x->GetMutable<phi::DenseTensor>();
auto tensor_size = 1;
if (var_name == "filter") {
tensor_size = positive_and_negative_values.size();
......@@ -124,7 +130,6 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
}
void ComputeRnnWeightScalesTest(const std::string& type,
const std::initializer_list<std::string>& ops,
const framework::ProgramDesc& prog,
std::vector<double> scales) {
ir::Graph* graph(new ir::Graph(prog));
......@@ -140,7 +145,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
StringPairMap var_quant_scales;
auto* wx_var = scope.FindVar(wx_var_names);
auto* wx_tensor = wx_var->GetMutable<LoDTensor>();
auto* wx_tensor = wx_var->GetMutable<phi::DenseTensor>();
wx_tensor->Resize(phi::make_dim(wx.size(), wx[0].size()));
for (size_t i = 0; i < wx.size(); i++)
std::copy(begin(wx[i]),
......@@ -149,7 +154,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
i * wx[0].size());
auto* wh_var = scope.FindVar(wh_var_names);
auto* wh_tensor = wh_var->GetMutable<LoDTensor>();
auto* wh_tensor = wh_var->GetMutable<phi::DenseTensor>();
wh_tensor->Resize(phi::make_dim(wh.size(), wh[0].size()));
for (size_t i = 0; i < wh.size(); i++)
std::copy(begin(wh[i]),
......@@ -174,6 +179,24 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test {
}
}
void UpdateReluOutputScaleTest(
const framework::ProgramDesc& prog,
StringPairMap* var_quant_scales,
const std::initializer_list<std::string>& variable_names) {
ir::Graph* graph(new ir::Graph(prog));
Scope scope;
PrepareGraph(graph, prog, &scope, conv_variable_names);
UpdateReluOutputScales(graph, var_quant_scales);
for (auto& var_name : variable_names) {
auto iter = var_quant_scales->find(var_name);
ASSERT_NE(iter, var_quant_scales->end());
ASSERT_EQ((*var_quant_scales)[var_name].first, true);
}
}
private:
std::unique_ptr<ComputePropagateScalesMkldnnPass> pass;
};
......@@ -182,11 +205,15 @@ void SetOp(ProgramDesc* prog,
const std::string& type,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& attrs = {}) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", name);
if (!attrs.empty())
for (auto& attr : attrs) op->SetAttr(attr.first, attr.second);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
......@@ -211,6 +238,23 @@ ProgramDesc BuildConv2dProgramDesc() {
return prog;
}
ProgramDesc BuildConv2dReluProgramDesc() {
ProgramDesc prog;
for (auto& v : conv_variable_names) {
prog.MutableBlock(0)->Var(v);
}
std::unordered_map<std::string, std::string> attrs = {
{"fuse_activation", "relu"}};
SetOp(&prog,
"conv2d",
"Conv2d",
{"conv_in", "filter", "bias"},
{"conv_out"},
attrs);
return prog;
}
ProgramDesc BuildFusionGruProgramDesc() {
ProgramDesc prog;
for (auto& v : rnn_variable_names) {
......@@ -262,7 +306,7 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) {
StringPairMap var_quant_scales;
auto* var = scope.FindVar(weight_var_name);
auto* weight_tensor = var->GetMutable<LoDTensor>();
auto* weight_tensor = var->GetMutable<phi::DenseTensor>();
weight_tensor->Resize(phi::make_dim(1, values.size()));
std::copy(begin(values),
end(values),
......@@ -283,15 +327,24 @@ TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) {
}
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_gru_weight_scales) {
ComputeRnnWeightScalesTest("gru",
{"fusion_gru", "multi_gru"},
BuildFusionGruProgramDesc(),
gru_scales);
ComputeRnnWeightScalesTest("gru", BuildFusionGruProgramDesc(), gru_scales);
}
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_lstm_weight_scales) {
ComputeRnnWeightScalesTest(
"lstm", {"fusion_lstm"}, BuildFusionLstmProgramDesc(), lstm_scales);
ComputeRnnWeightScalesTest("lstm", BuildFusionLstmProgramDesc(), lstm_scales);
}
TEST_F(ComputePropagateScalesMkldnnPassTest, update_relu_output_scales) {
StringPairMap var_quant_scales;
for (auto& var_name : conv_variable_names) {
phi::DenseTensor tensor;
auto* data = tensor.mutable_data<float>({1}, platform::CPUPlace());
data[0] = 10;
auto pair = std::make_pair(false, tensor);
var_quant_scales.insert(std::make_pair(var_name, pair));
}
UpdateReluOutputScaleTest(
BuildConv2dReluProgramDesc(), &var_quant_scales, {"conv_out"});
}
} // namespace ir
......
......@@ -229,6 +229,7 @@ void CPUQuantizePass::DequantizeOutput(Graph* g,
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
deq_desc.SetAttr("Scale", scale);
deq_desc.SetAttr("is_negative_input", !is_unsigned);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
// update op's output
......@@ -332,20 +333,8 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
}
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
std::unordered_map<std::string, std::vector<float>> info_map{};
GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map);
for (auto iter = info_map.begin(); iter != info_map.end(); iter++) {
LoDTensor tensor;
const int size = static_cast<int>(iter->second.size());
auto* data = tensor.mutable_data<double>({size}, platform::CPUPlace());
for (int i = 0; i < size; i++) {
data[i] = static_cast<double>(iter->second[i]);
}
auto pair = std::make_pair(false, tensor);
var_quant_scales_->insert(std::make_pair(iter->first, pair));
}
GetInfoFromTheFirstOp(
graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}
void CPUQuantizePass::QuantizeConv(Graph* graph,
......@@ -597,6 +586,20 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
return;
}
bool are_all_inputs_unsigned{true};
// if all inputs were unsigned, then the output was set to unsigned
// during the scale calculation step
auto inputs = concat_op->inputs;
for (size_t i = 0; i < inputs.size(); i++) {
if (AreScalesPresentForVarNames({inputs[i]->Name()})) {
auto scale_data = GetScaleDataByName(inputs[i]->Name());
if (scale_data.first == false) {
are_all_inputs_unsigned = false;
break;
}
}
}
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
if (!AreScalesPresentForNodes({concat_out})) {
......@@ -605,17 +608,12 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
return;
}
// if all inputs were unsigned, then the output was set to unsigned
// during the scale calculation step
bool are_all_inputs_unsigned{false};
auto output_scale =
GetScaleValueForNode(concat_out, &are_all_inputs_unsigned);
auto output_scale = GetScaleValueForNode(concat_out);
QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);
DequantizeOutput(
g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);
++quantize_concat_count;
};
......
......@@ -104,51 +104,24 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count);
}
bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
const Node* dequant_in) const {
PADDLE_ENFORCE_EQ(
dequant_in->inputs.size(),
1,
platform::errors::InvalidArgument(
"Dequantize (id: %f) should have only one input.", dequant_in->id()));
if (dequant_in->inputs[0]->IsOp()) {
auto prev_op = dequant_in->inputs[0]->Op();
std::string act_name;
if (prev_op->Type() == "relu") {
return true;
} else {
if (prev_op->Type() == "conv2d") {
act_name = "fuse_activation";
} else if (prev_op->Type() == "fc") {
act_name = "activation_type";
}
if (!act_name.empty()) {
auto act = prev_op->GetAttrIfExists<std::string>(act_name);
if (act == "relu" || act == "relu6") {
return true;
}
}
}
}
return false;
}
bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
Node* quant_op, Node* dequant_in, Node* next_op) const {
bool is_concat_signed =
Node* quant_op, Node* dequant_op, Node* next_op) const {
bool is_next_op_signed =
quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
bool is_input_unsigned = IsDequantizeInputUint8(dequant_in);
bool is_input_signed =
dequant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
/* TODO(sfraczek): remove elementwise from this condition when the
   BinaryMKLDNN kernel supports two different input data types */
bool is_next_op_concat_or_elementwise =
next_op->Op()->Type() == "concat" ||
next_op->Op()->Type().find("elementwise") == 0;
if (is_next_op_concat_or_elementwise && is_concat_signed &&
is_input_unsigned) {
if (is_next_op_concat_or_elementwise &&
(is_next_op_signed ^ is_input_signed)) {
VLOG(4) << "Do not squash dequant-quant, because "
<< "next_op is: " << next_op->Op()->Type()
<< ", is_concat_signed: " << is_concat_signed
<< ", is_input_unsigned: " << is_input_unsigned << ".";
<< ", is_next_op_signed: " << is_next_op_signed
<< ", is_input_signed: " << is_input_signed << ".";
return true;
}
return false;
......@@ -173,7 +146,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_op, next_op)) {
return;
}
......
......@@ -43,11 +43,6 @@ class CPUQuantizeSquashPass : public FusePassBase {
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
* Check if input to dequantize is uint8
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;
/*
* Don't squash unsigned dequantize with signed quantize.
* This is important for concat and elementwise ops.
......@@ -55,7 +50,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
* elementwise assumes first input type.
*/
bool IsDequantizeQuantizeIncompatible(Node* quant_op,
Node* dequant_in,
Node* dequant_op,
Node* next_op) const;
/*
......
......@@ -68,15 +68,11 @@ void SetOp(ProgramDesc* prog,
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("force_fp32_output", false);
} else if (type == "quantize") {
} else if (type == "quantize" || type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale[0]);
op->SetAttr("is_negative_input", is_negative_input);
} else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale[0]);
} else if (type == "requantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
......@@ -303,31 +299,22 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn,
return prog;
}
/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->relu->j->Dequant->k(u8)->Quant->l-/
/* a->relu->b->Dequant(u8)->c->Quant(u8)->d-\
* e->relu->f->Dequant(u8)->g->Quant(u8)->h--Concat1->i
*/
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc BuildU8U8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
SetOp(&prog,
"quantize",
"Quant1",
"dequantize",
"Dequant1",
{"b"},
{"c"},
{"d"},
true,
{scale, scale_out},
0.0f,
......@@ -336,10 +323,23 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
1,
false); // is_negative_input = false
SetOp(&prog,
"quantize",
"Quant2",
"dequantize",
"Dequant2",
{"f"},
{"g"},
{"h"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(&prog,
"quantize",
"Quant1",
{"c"},
{"d"},
true,
{scale, scale_out},
0.0f,
......@@ -349,9 +349,9 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
false); // is_negative_input = false
SetOp(&prog,
"quantize",
"Quant3",
{"k"},
{"l"},
"Quant2",
{"g"},
{"h"},
true,
{scale, scale_out},
0.0f,
......@@ -360,27 +360,47 @@ ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
1,
false); // is_negative_input = false
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
SetOp(&prog, "concat", "Concat1", {"d", "h"}, {"i"}, true);
return prog;
}
/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
/* a->relu->b->Dequant(u8)->c->Quant(s8)->d-\
* e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x
* i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/
*/
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out});
SetOp(&prog,
"dequantize",
"Dequant1",
{"b"},
{"c"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(&prog,
"dequantize",
"Dequant2",
{"f"},
{"g"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
......@@ -392,9 +412,9 @@ ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
return prog;
}
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
/* a->pool2d->b->Dequant(s8)->c->Quant(s8)->d-\
* e->relu->f->Dequant(u8)->g->Quant(s8)->h--Concat1->x
* i->pool2d->j->Dequant(s8)->k->Quant(s8)->l-/
*/
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
......@@ -407,8 +427,18 @@ ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
SetOp(
&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, {scale, scale_out});
SetOp(
&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, {scale, scale_out});
SetOp(&prog,
"dequantize",
"Dequant2",
{"f"},
{"g"},
true,
{scale, scale_out},
0.0f,
"float32",
false,
1,
false); // is_negative_input = false
SetOp(
&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, {scale, scale_out});
......@@ -1141,13 +1171,12 @@ TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
}
TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12;
// removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}};
CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f),
expected_operators,
remove_nodes);
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 2}};
CheckNodesTest(
BuildU8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators, remove_nodes);
}
} // namespace ir
......
......@@ -22,6 +22,9 @@ namespace paddle {
namespace framework {
namespace ir {
using StringPairMap =
std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
static void SaveInfoInTheFirstOp(
ir::Graph* graph,
const std::string& flag,
......@@ -44,6 +47,31 @@ static void SaveInfoInTheFirstOp(
}
}
static void SaveInfoInTheFirstOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
const StringPairMap& info_map) {
VLOG(3) << "save variables in the first op's attr";
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
op_node->Op()->SetAttr(flag, true);
for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
auto* data = iter->second.second.data<float>();
std::vector<float> data_v(data, data + iter->second.second.numel());
op_node->Op()->SetAttr(iter->first + suffix + "_unsigned",
iter->second.first);
op_node->Op()->SetAttr(iter->first + suffix, data_v);
}
break;
}
}
static void GetInfoFromTheFirstOp(
ir::Graph* graph,
const std::string& flag,
......@@ -77,6 +105,54 @@ static void GetInfoFromTheFirstOp(
}
}
static void GetInfoFromTheFirstOp(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
StringPairMap* info_map) {
VLOG(3) << "get variables from the first op's attr";
const std::string unsigned_flag = "_unsigned";
const std::string suffix = "_" + key_suffix + "_" + flag;
const std::string suffix_is_unsigned = suffix + unsigned_flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
auto* op_desc = op_node->Op();
if (op_desc->GetAttrIfExists<bool>(flag)) {
op_desc->RemoveAttr(flag);
std::vector<std::string> attr_names = op_desc->AttrNames();
for (auto fake_name : attr_names) {
auto is_unsigned = false;
size_t pos = fake_name.find(suffix_is_unsigned);
if (pos != std::string::npos) {
std::string unsigned_var_name = fake_name;
is_unsigned =
PADDLE_GET_CONST(bool, op_desc->GetAttr(unsigned_var_name));
std::string var_name = fake_name.substr(0, pos);
size_t unsigned_pos = fake_name.find(unsigned_flag);
std::string vector_name =
fake_name.erase(unsigned_pos, unsigned_flag.length());
auto scales_vector = PADDLE_GET_CONST(std::vector<float>,
op_desc->GetAttr(vector_name));
phi::DenseTensor tensor;
const int size = static_cast<int>(scales_vector.size());
auto data = tensor.mutable_data<double>({size}, platform::CPUPlace());
std::copy(scales_vector.begin(), scales_vector.end(), data);
auto pair = std::make_pair(is_unsigned, tensor);
info_map->insert(std::make_pair(var_name, pair));
op_desc->RemoveAttr(unsigned_var_name);
op_desc->RemoveAttr(vector_name);
}
}
break;
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
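
A note on the attribute encoding above (illustrative, not part of the diff): with flag = "has_quant_info" and key_suffix = "var_quant_scales", the StringPairMap overload of SaveInfoInTheFirstOp serializes a map entry such as {"conv_out", {true, scale_tensor}} onto the first non-feed/fetch op roughly as sketched below; GetInfoFromTheFirstOp reverses this by stripping the "_unsigned" suffix to recover the variable name and rebuilding the (is_unsigned, tensor) pair.

// Resulting attributes for one map entry ("conv_out" and the scale value are
// example data only):
op_desc->SetAttr("has_quant_info", true);
op_desc->SetAttr("conv_out_var_quant_scales_has_quant_info_unsigned", true);
op_desc->SetAttr("conv_out_var_quant_scales_has_quant_info",
                 std::vector<float>{0.1f});  // scale tensor flattened to floats
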