未验证 提交 781df300 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Unification of BF16 enablement process (#31034)

* Unification of bfloat16 enablement process and refactor

* Remove unnecessary function

* Standardize the output name search
上级 16fe11d7
......@@ -1829,9 +1829,8 @@ PDNode *patterns::OpDequant::operator()() {
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->Type() == "matmul" ||
node->Op()->Type() == "conv2d" ||
node->Op()->Type() == "fc");
return (node->Op()->HasAttr("force_fp32_output") ||
node->Op()->HasProtoAttr("force_fp32_output"));
});
auto dequant_in = pattern->NewNode(dequant_in_repr())
->assert_is_op_input("dequantize", "Input");
......@@ -1865,6 +1864,44 @@ PDNode *patterns::DequantScale::operator()() {
return scale_out;
}
PDNode *patterns::ScaleQuant::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput()
->assert_is_op_input("scale", "X");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
scale_op->LinksFrom({scale_in}).LinksTo({quant_in});
quant_op->LinksFrom({quant_in});
return quant_op;
}
PDNode *patterns::QuantConv::operator()() {
auto quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto conv_in = pattern->NewNode(conv_in_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
conv_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
conv_op->LinksFrom({conv_in});
return quant_op;
}
PDNode *patterns::ScaleMatmul::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput()
......@@ -2191,10 +2228,11 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
"elementwise_add", "elementwise_mul",
"fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "relu", "reshape2",
"softmax", "sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......@@ -2240,25 +2278,12 @@ PDNode *patterns::LastBfloat16Ops::operator()() {
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
op->LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
return op_out;
}
PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
auto *op_in = pattern->NewNode(op_in_repr())->AsOutput();
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
......@@ -2266,7 +2291,6 @@ PDNode *patterns::FirstBfloat16Ops::operator()() {
"bfloat16";
});
prev_op->LinksTo({op_in});
op->LinksFrom({op_in});
return op;
}
......@@ -2280,27 +2304,6 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op;
}
PDNode *patterns::UnnecessaryReorders::operator()() {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *quant_in = pattern->NewNode(quant_in_repr())
->assert_is_op_input("quantize", "Input");
auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
auto *quant_out = pattern->NewNode(quant_out_repr())
->assert_is_op_output("quantize", "Output");
prev_op->LinksTo({quant_in});
quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
return quant_out;
}
PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs",
......
......@@ -1135,11 +1135,36 @@ struct DequantScale : public PatternBase {
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};
// Scale + Quantize
struct ScaleQuant : public PatternBase {
ScaleQuant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "scale_quant") {}
PDNode* operator()();
PATTERN_DECL_NODE(scale_in);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
};
// Quantize + Conv2d
struct QuantConv : public PatternBase {
QuantConv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_conv") {}
PDNode* operator()();
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(conv_in);
PATTERN_DECL_NODE(conv_op);
};
// Scale + Matmul
struct ScaleMatmul : public PatternBase {
ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
......@@ -1338,7 +1363,6 @@ struct LastBfloat16Ops : public PatternBase {
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};
struct FirstBfloat16Ops : public PatternBase {
......@@ -1346,7 +1370,6 @@ struct FirstBfloat16Ops : public PatternBase {
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
......@@ -1360,17 +1383,6 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op);
};
struct UnnecessaryReorders : public PatternBase {
UnnecessaryReorders(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "unnecessary_reorders") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
};
// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
......
......@@ -12,12 +12,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -33,8 +31,38 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
b->inputs.end());
}
// Checking whether a reorder from FP32 to BF16 should be added before the input
// to the operator
bool IsPermittedInputName(const std::string& input_name) {
// Only the inputs listed in \"permitted_names\" requires quanitization before
// the bfloat16 operator. Other inputs, such as Filter and Bias are reordered
// in the kernel.
const std::vector<std::string> permitted_names = {"X", "Y", "Input",
"ResidualData"};
return (std::find(permitted_names.begin(), permitted_names.end(),
input_name) != permitted_names.end());
}
// Checking whether a reorder from BF16 to FP32 should be added after the output
// to the operator
bool IsPermittedOutputName(const std::string& output_name) {
// XShape is output in transpose2 and reshape2 operators used to store the
// shape and lod of X. So this output do not need dequantize before.
return (output_name != "XShape");
}
void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int* quantize_counter) {
std::vector<std::string> input_names;
// Find the name of the input linking op to op_in
for (auto name : op->Op()->InputNames())
for (auto input_name : op->Op()->Input(name))
if (input_name == op_in->Name() && IsPermittedInputName(name))
input_names.push_back(name);
if (input_names.empty()) return;
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
......@@ -44,23 +72,12 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc);
std::vector<std::string> input_names;
for (auto name : op->Op()->InputNames()) {
for (auto input_name : op->Op()->Input(name)) {
if (input_name == op_in->Name()) input_names.push_back(name);
}
}
PADDLE_ENFORCE_NE(
input_names.empty(), true,
platform::errors::NotFound(
"Operator before operator should have input as op output"));
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
for (auto name = input_names.begin(); name < input_names.end(); name++)
op->Op()->SetInput(*name,
......@@ -99,11 +116,12 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
UnlinkNodes(inputs[i], op);
IR_NODE_LINK_TO(inputs[i], quantize_op);
......@@ -115,6 +133,9 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
op->Op()->SetInput("X", quantize_out_node_names);
}
// Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
......@@ -128,38 +149,8 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
gpd(graph, handler);
}
void RemoveUnnecessaryReorders(ir::Graph* graph, int* quantize_counter) {
GraphPatternDetector gpd;
patterns::UnnecessaryReorders unnecessary_reorders{gpd.mutable_pattern(),
"unnecessary_reorders"};
unnecessary_reorders();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, unnecessary_reorders);
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, unnecessary_reorders);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, unnecessary_reorders);
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, unnecessary_reorders);
std::string op_output_name;
for (auto name : prev_op->Op()->OutputNames())
for (auto output_name : prev_op->Op()->Output(name))
if (output_name == quant_in->Name()) op_output_name = name;
PADDLE_ENFORCE_NE(
op_output_name.empty(), true,
platform::errors::NotFound(
"Operator before operator should have input as op output"));
prev_op->Op()->SetOutput(op_output_name,
std::vector<std::string>({quant_out->Name()}));
IR_NODE_LINK_TO(prev_op, quant_out);
GraphSafeRemoveNodes(graph, {quant_in, quant_op});
(*quantize_counter)--;
};
gpd(graph, handler);
}
// Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
......@@ -167,12 +158,9 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
bfloat16_ops();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
auto prev_op_type = prev_op->Op()->Type();
if (op->Op()->Type() != "conv2d" && prev_op_type != "quantize" &&
prev_op_type != "sum" && prev_op_type != "concat") {
if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") {
AddQuantize(g, op, op_in, quantize_counter);
}
};
......@@ -182,9 +170,8 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
RemoveUnnecessaryReorders(graph, &quantize_counter);
AddReoderBeforeSingleInputs(graph, &quantize_counter);
PrettyLogDetail("--- added %d quantize op before bfloat16 op",
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}
......@@ -193,55 +180,51 @@ void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int force_fp32_counter = 0, dequantize_counter = 0;
int dequantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
if ((op->Op()->HasAttr("force_fp32_output") ||
op->Op()->HasProtoAttr("force_fp32_output")) &&
!op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
op->Op()->SetAttr("force_fp32_output", true);
force_fp32_counter++;
} else if (op->Op()->Type() != "prior_box") {
VarDesc dequantize_out_desc(patterns::PDNodeName("dequantize", "out"));
auto* dequantize_out_node = g->CreateVarNode(&dequantize_out_desc);
if (op->Op()->Type() != "prior_box") {
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);
if (output_names.empty()) return;
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input", std::vector<std::string>({op_out->Name()}));
deq_desc.SetOutput(
"Output", std::vector<std::string>({dequantize_out_node->Name()}));
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc);
std::string next_op_input_name;
for (auto name : next_op->Op()->InputNames()) {
for (auto input_name : next_op->Op()->Input(name)) {
if (input_name == op_out->Name()) next_op_input_name = name;
}
}
PADDLE_ENFORCE_NE(
next_op_input_name.empty(), true,
platform::errors::NotFound(
"Operator before operator should have input as op output"));
next_op->Op()->SetInput(
next_op_input_name,
std::vector<std::string>({dequantize_out_node->Name()}));
UnlinkNodes(op_out, next_op);
IR_NODE_LINK_TO(op_out, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, dequantize_out_node);
IR_NODE_LINK_TO(dequantize_out_node, next_op);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op =
g->CreateOpNode(&deq_desc); // OpDesc will be copied.
for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
};
gpd(graph, handler);
PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output",
dequantize_counter, force_fp32_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter);
}
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
......
......@@ -26,8 +26,7 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::string& mkldnn_data_type = "float32",
const bool force_fp32_output = false) {
const std::string& mkldnn_data_type = "float32") {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
......@@ -37,7 +36,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("force_fp32_output", force_fp32_output);
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
type == "dropout") {
op->SetInput("X", {inputs[0]});
......@@ -47,7 +45,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("force_fp32_output", force_fp32_output);
} else if (type == "concat" || type == "sum") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
......@@ -58,7 +55,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
if (type == "matmul") op->SetAttr("force_fp32_output", force_fp32_output);
} else if (type == "layer_norm") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Y", {outputs[0]});
......@@ -79,8 +75,8 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
*current_nodes_num = (*graph)->Nodes().size();
}
void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
int force_fp32_count, int added_nodes_count) {
void MainTest(const ProgramDesc& prog, const int& quant_count,
const int& dequant_count, const int& added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
......@@ -88,7 +84,6 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int force_fp32_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
......@@ -96,16 +91,11 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
} else if (op->Type() == "conv2d" || op->Type() == "matmul" ||
op->Type() == "fc") {
if (op->GetAttrIfExists<bool>("force_fp32_output"))
force_fp32_nodes_count++;
}
}
}
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(force_fp32_nodes_count, force_fp32_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
......@@ -125,9 +115,10 @@ ProgramDesc BuildProgramDescConv(bool use_mkldnn) {
TEST(CpuBfloat16Pass, convolution) {
bool use_mkldnn = true;
// 0 added + 1 force_fp32_output
int added_nodes = 0;
MainTest(BuildProgramDescConv(use_mkldnn), 0, 0, 1, added_nodes);
int quant_op = 3;
int dequant_op = 3;
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
}
ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) {
......@@ -147,9 +138,11 @@ ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) {
TEST(CpuBfloat16Pass, double_input_ops) {
bool use_mkldnn = true;
// 2 quant + 2 quant out
int added_nodes = 4;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), 2, 0, 0, added_nodes);
int quant_op = 4;
int dequant_op = 3;
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}
ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) {
......@@ -169,9 +162,11 @@ ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) {
TEST(CpuBfloat16Pass, duplicated_input_ops) {
bool use_mkldnn = true;
// 3 quant + 3 quant out
int added_nodes = 6;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), 3, 0, 0, added_nodes);
int quant_op = 5;
int dequant_op = 3;
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}
ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
......@@ -193,9 +188,11 @@ ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
TEST(CpuBfloat16Pass, double_outputs_ops) {
bool use_mkldnn = true;
// 3 dequant + 3 dequant out
int added_nodes = 6;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), 0, 3, 0, added_nodes);
int quant_op = 3;
int dequant_op = 3;
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
added_nodes);
}
} // namespace ir
......
......@@ -255,14 +255,21 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
if (dequant_in->outputs.size() == 1) {
auto output_name = "Out";
if (any_op->Op()->Type() == "conv2d") {
if (any_op->Op()->Type() == "conv2d" ||
any_op->Op()->Type() == "conv2d_transpose") {
// do not squash if fuse residual connection is true
// because residual fusion does not support force output with fp32
if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
return;
output_name = "Output";
}
// Find the name of the output linking any_op to dequant_in
std::string output_name;
for (auto name : any_op->Op()->OutputNames())
for (auto out_name : any_op->Op()->Output(name))
if (out_name == dequant_in->Name()) output_name = name;
if (output_name.empty()) return;
any_op->Op()->SetAttr("force_fp32_output", true);
any_op->Op()->SetOutput(output_name,
std::vector<std::string>({dequant_out->Name()}));
......@@ -363,10 +370,10 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
platform::errors::InvalidArgument(
"Dequantize scale(%f) should have positive value.",
dequant_scale));
PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument(
"Scale(%f) of scale op should have positive value.",
scale_scale));
PADDLE_ENFORCE_NE(
scale_scale, 0.0f,
platform::errors::InvalidArgument(
"Scale(%f) should have a non-zero value", scale_scale));
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
dequant_op->Op()->SetOutput(
......@@ -378,10 +385,86 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
};
gpd(graph, handler);
AddStatis(found_dequant_scale_squash_count);
PrettyLogDetail("--- squashed %d scale with dequant",
PrettyLogDetail("--- squashed %d scale with dequantize op",
found_dequant_scale_squash_count);
}
// squash scale with quantize
void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ScaleQuant scale_quant_pattern{gpd.mutable_pattern(),
"scale_quant"};
scale_quant_pattern();
int found_scale_quant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash scale-quant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_quant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_quant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, scale_quant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, scale_quant_pattern);
if (quant_in->outputs.size() == 1 &&
scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
auto quant_scale = quant_op->Op()->GetAttrIfExists<float>("Scale");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
PADDLE_ENFORCE_GT(
quant_scale, 0.0f,
platform::errors::InvalidArgument(
"Quantize scale(%f) should have positive value.", quant_scale));
PADDLE_ENFORCE_NE(
scale_scale, 0.0f,
platform::errors::InvalidArgument(
"Scale(%f) should have a non-zero value", scale_scale));
quant_op->Op()->SetAttr("Scale", quant_scale * scale_scale);
quant_op->Op()->SetInput("Input",
std::vector<std::string>({scale_in->Name()}));
IR_NODE_LINK_TO(scale_in, quant_op);
GraphSafeRemoveNodes(graph, {scale_op, quant_in});
found_scale_quant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_scale_quant_squash_count);
PrettyLogDetail("--- squashed %d scale with quantize op",
found_scale_quant_squash_count);
}
// squash quantize if is before bfloat16 conv2d
void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const {
GraphPatternDetector gpd;
patterns::QuantConv pattern{gpd.mutable_pattern(), "quant_conv"};
pattern();
int found_quant_conv_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash quant-conv2d ops pair";
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_in, conv_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, pattern);
if (conv_in->outputs.size() == 1 &&
quant_op->Op()->GetAttrIfExists<float>("Scale") == 1.0) {
conv_op->Op()->SetInput("Input",
std::vector<std::string>({quant_in->Name()}));
IR_NODE_LINK_TO(quant_in, conv_op);
GraphSafeRemoveNodes(graph, {quant_op, conv_in});
found_quant_conv_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_quant_conv_squash_count);
PrettyLogDetail("--- squashed %d quantize with bfloat16 conv2d op",
found_quant_conv_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
......@@ -389,6 +472,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
FusePassBase::Init("cpu_quantize_squash_pass", graph);
DequantScaleSquash(graph);
ScaleQuantSquash(graph);
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
......@@ -396,7 +481,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
RequantOpSquash(graph);
OpDequantSquash(graph);
MultipleQuantizeSquash(graph);
DequantScaleSquash(graph);
QuantizeBf16Conv(graph);
}
} // namespace ir
......
......@@ -78,6 +78,16 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void DequantScaleSquash(Graph* graph) const;
/*
* Squash scale if scale is before quantize
*/
void ScaleQuantSquash(Graph* graph) const;
/*
* Squash quantize if is before bfloat16 conv2d
*/
void QuantizeBf16Conv(Graph* graph) const;
const std::string name_scope_{"squash"};
};
......
......@@ -24,7 +24,8 @@ namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::vector<float> scale = {}, float bias = 0.0) {
const std::vector<float> scale = {}, float bias = 0.0,
const std::string& mkldnn_data_type = "float32") {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
......@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("force_fp32_output", false);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "quantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
......@@ -52,6 +55,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "fc") {
op->SetInput("Input", {inputs[0]});
PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
......@@ -63,6 +67,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetOutput("Out", outputs);
if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetAttr("force_fp32_output", false);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "scale") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
......@@ -74,6 +80,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetOutput("Out", {outputs[0]});
if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]);
if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
op->SetAttr("force_fp32_output", false);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
}
}
......@@ -299,6 +307,20 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
return prog;
}
// a->Scale->b
// b->Quant->c
ProgramDesc BuildScaleQuantProgramDesc(bool use_mkldnn, float scale_scale,
float quant_scale, float bias) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "scale", "Scale", {"a"}, {"b"}, use_mkldnn, {scale_scale}, bias);
SetOp(&prog, "quantize", "Quant", {"b"}, {"c"}, use_mkldnn, {quant_scale});
return prog;
}
// {x,y}->Matmul->b
// b->Dequant->c
ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
......@@ -341,6 +363,22 @@ ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in,
return prog;
}
// a->Quant->b
// b->Conv2d->c
ProgramDesc BuildQuantConv2dProgramDesc(const bool& use_mkldnn,
const float& quant_scale,
const std::string& mkldnn_data_type) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "quantize", "Quant", {"a"}, {"b"}, use_mkldnn, {quant_scale});
SetOp(&prog, "conv2d", "Conv2d", {"b"}, {"c"}, use_mkldnn, {}, 0.0f,
mkldnn_data_type);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
......@@ -664,6 +702,22 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
"Dequant", "Scale", dequant_scale);
}
// if scale has no bias
TEST(CpuQuantizeSquashPass, scale_with_no_bias_quantize) {
constexpr auto scale_scale = 1.5432f;
constexpr auto quant_scale = 1.2345f;
constexpr auto bias = 0.0f;
auto use_mkldnn = true;
// remove: dequant out, scale op
auto remove_nodes = 2;
CountNodeTest(
BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias),
remove_nodes);
EqualScaleTest(
BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias),
"Scale", "Quant", quant_scale * scale_scale);
}
TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
auto dequant_scale = 1.2345f;
auto use_mkldnn = true;
......@@ -688,6 +742,17 @@ TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) {
EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in);
}
TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
auto quant_scale = 1.0f;
auto use_mkldnn = true;
auto mkldnn_data_type = "bfloat16";
// remove: quant_op, conv_in
auto remove_nodes = 2;
CountNodeTest(
BuildQuantConv2dProgramDesc(use_mkldnn, quant_scale, mkldnn_data_type),
remove_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -268,6 +268,7 @@ void CpuPassStrategy::EnableMkldnnBfloat16() {
if (!use_mkldnn_bfloat16_) {
passes_.push_back("cpu_bfloat16_placement_pass");
passes_.push_back("cpu_bfloat16_pass");
passes_.push_back("cpu_quantize_squash_pass");
}
use_mkldnn_bfloat16_ = true;
#else
......
......@@ -156,4 +156,5 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace,
ops::ReQuantOpKernel<int8_t>, ops::ReQuantOpKernel<uint8_t>);
ops::ReQuantOpKernel<int8_t>, ops::ReQuantOpKernel<uint8_t>,
ops::ReQuantOpKernel<paddle::platform::bfloat16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册