未验证 提交 f29da150 编写于 作者: S Sylwester Fraczek 提交者: GitHub

[bugfix] to concat input squash (#39593)

* fix and add more tests

* remove unwanted changes

* check only concat and elementwise

* move check to a function

* add todo comment

* Revert "fix ptq fc attr name fuse_activation->activation_type"

This reverts commit ffd023353a5e9b0fd15e81b9e9f9fe1794035017.
上级 2d2f11d1
......@@ -132,6 +132,27 @@ bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
return false;
}
bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
Node* quant_op, Node* dequant_in, Node* next_op) const {
bool is_concat_signed =
quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
bool is_input_unsigned = IsDequantizeInputUint8(dequant_in);
/* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN
kernel will support two different input data types */
bool is_next_op_concat_or_elementwise =
next_op->Op()->Type() == "concat" ||
next_op->Op()->Type().find("elementwise") == 0;
if (is_next_op_concat_or_elementwise && is_concat_signed &&
is_input_unsigned) {
VLOG(4) << "Do not squash dequant-quant, because "
<< "next_op is: " << next_op->Op()->Type()
<< ", is_concat_signed: " << is_concat_signed
<< ", is_input_unsigned: " << is_input_unsigned << ".";
return true;
}
return false;
}
void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
......@@ -151,9 +172,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
// Don't squash if e.g. just one concat input is unsigned
if (IsDequantizeInputUint8(dequant_in) &&
!quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
return;
}
......
......@@ -48,6 +48,15 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;
/*
* Don't squash unsigned dequantize with signed quantize.
* This is important for concat and elementwise ops.
* When inputs have different sign, concat will assume signed type and
* elementwise assumes first input type.
*/
bool IsDequantizeQuantizeIncompatible(Node* quant_op, Node* dequant_in,
Node* next_op) const;
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
......
......@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
......@@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}
/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->relu->j->Dequant->k(u8)->Quant->l-/
*/
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}
/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
......@@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
{scale, scale_out});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
0.0, "float32", false, 1, false);
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}
/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildS8S8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d3", {"i"}, {"j"}, true, {scale, scale_out});
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
......@@ -834,7 +921,7 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
remove_nodes);
}
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1) {
// removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
......@@ -842,8 +929,38 @@ TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
{"dequantize", 1},
{"relu", 1},
{"pool2d", 2}};
CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
expected_operators, remove_nodes);
CheckNodesTest(BuildS8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) {
// removed 1 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 4;
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
{"quantize", 2},
{"dequantize", 2},
{"relu", 2},
{"pool2d", 1}};
CheckNodesTest(BuildU8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}
TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12;
std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"pool2d", 3}};
CheckNodesTest(BuildS8S8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}
TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12;
std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}};
CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}
} // namespace ir
......
......@@ -116,15 +116,11 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
if (op->Type() == "conv2d") {
if (op->Type() == "conv2d" || op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "fc") {
std::string activation_type =
op->GetAttrIfExists<std::string>("activation_type");
is_unsigned = (activation_type == "relu" || activation_type == "relu6");
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册