Unverified commit 781df300, authored by J joanna.wozna.intel, committed by GitHub

Unification of BF16 enablement process (#31034)

* Unification of bfloat16 enablement process and refactor

* Remove unnecessary function

* Standardize the output name search
Parent 16fe11d7
@@ -1829,9 +1829,8 @@ PDNode *patterns::OpDequant::operator()() {
   auto any_op = pattern->NewNode(any_op_repr())
                     ->assert_is_op()
                     ->assert_more([&](Node *node) {
-                      return (node->Op()->Type() == "matmul" ||
-                              node->Op()->Type() == "conv2d" ||
-                              node->Op()->Type() == "fc");
+                      return (node->Op()->HasAttr("force_fp32_output") ||
+                              node->Op()->HasProtoAttr("force_fp32_output"));
                     });
   auto dequant_in = pattern->NewNode(dequant_in_repr())
                         ->assert_is_op_input("dequantize", "Input");
@@ -1865,6 +1864,44 @@ PDNode *patterns::DequantScale::operator()() {
   return scale_out;
 }
 
+PDNode *patterns::ScaleQuant::operator()() {
+  auto scale_in = pattern->NewNode(scale_in_repr())
+                      ->AsInput()
+                      ->assert_is_op_input("scale", "X");
+  auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
+
+  auto quant_in = pattern->NewNode(quant_in_repr())
+                      ->AsInput()
+                      ->assert_is_op_input("quantize", "Input");
+  auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
+
+  scale_op->LinksFrom({scale_in}).LinksTo({quant_in});
+  quant_op->LinksFrom({quant_in});
+
+  return quant_op;
+}
+
+PDNode *patterns::QuantConv::operator()() {
+  auto quant_in = pattern->NewNode(quant_in_repr())
+                      ->AsInput()
+                      ->assert_is_op_input("quantize", "Input");
+  auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
+
+  auto conv_in = pattern->NewNode(conv_in_repr())
+                     ->AsInput()
+                     ->assert_is_op_input("conv2d", "Input");
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+  conv_op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+           "bfloat16";
+  });
+
+  quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
+  conv_op->LinksFrom({conv_in});
+
+  return quant_op;
+}
+
 PDNode *patterns::ScaleMatmul::operator()() {
   auto scale_in = pattern->NewNode(scale_in_repr())
                       ->AsInput()
@@ -2191,10 +2228,11 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
   std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>(
-          {"concat", "conv2d", "conv2d_transpose", "elementwise_add",
-           "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
-           "matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
+      std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
+                                       "elementwise_add", "elementwise_mul",
+                                       "fc", "fusion_gru", "gelu", "layer_norm",
+                                       "matmul", "pool2d", "relu", "reshape2",
+                                       "softmax", "sum", "transpose2"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
@@ -2240,25 +2278,12 @@ PDNode *patterns::LastBfloat16Ops::operator()() {
            "bfloat16";
   });
   auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
-  auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
-  next_op->assert_more([&](Node *node) {
-    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
-           "bfloat16";
-  });
   op->LinksTo({op_out});
-  next_op->LinksFrom({op_out});
-  return next_op;
+  return op_out;
 }
 PDNode *patterns::FirstBfloat16Ops::operator()() {
-  auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
-  prev_op->assert_more([&](Node *node) {
-    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
-           "bfloat16";
-  });
-
-  auto *op_in = pattern->NewNode(op_in_repr())->AsOutput();
+  auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
@@ -2266,7 +2291,6 @@ PDNode *patterns::FirstBfloat16Ops::operator()() {
            "bfloat16";
   });
-  prev_op->LinksTo({op_in});
   op->LinksFrom({op_in});
   return op;
 }
@@ -2280,27 +2304,6 @@ PDNode *patterns::DuplicatedInputs::operator()() {
   return op;
 }
 
-PDNode *patterns::UnnecessaryReorders::operator()() {
-  auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
-  prev_op->assert_more([&](Node *node) {
-    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
-           "bfloat16";
-  });
-
-  auto *quant_in = pattern->NewNode(quant_in_repr())
-                       ->assert_is_op_input("quantize", "Input");
-  auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");
-  auto *quant_out = pattern->NewNode(quant_out_repr())
-                        ->assert_is_op_output("quantize", "Output");
-
-  prev_op->LinksTo({quant_in});
-  quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
-
-  return quant_out;
-}
-
 PDNode *patterns::MKLDNNInPlace::operator()() {
   const std::unordered_set<std::string> &supported_op_types = {
       "abs",
......
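For context, this is roughly how a pass consumes one of the new patterns above; the sketch mirrors the `ScaleQuantSquash` handler added later in this commit, with the surrounding pass boilerplate assumed:

```cpp
// Minimal usage sketch (assumed boilerplate; mirrors ScaleQuantSquash below):
GraphPatternDetector gpd;
patterns::ScaleQuant scale_quant{gpd.mutable_pattern(), "scale_quant"};
scale_quant();  // operator() wires scale_in -> scale_op -> quant_in -> quant_op

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
  // Retrieve the matched graph nodes declared with PATTERN_DECL_NODE
  GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_quant);
  GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, scale_quant);
  // ...rewrite or remove the matched subgraph here...
};
gpd(graph, handler);  // runs the handler once per match found in the graph
```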
@@ -1135,11 +1135,36 @@ struct DequantScale : public PatternBase {
   PATTERN_DECL_NODE(dequant_op);
   PATTERN_DECL_NODE(dequant_out);
   PATTERN_DECL_NODE(scale_op);
   PATTERN_DECL_NODE(scale_out);
 };
 
+// Scale + Quantize
+struct ScaleQuant : public PatternBase {
+  ScaleQuant(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "scale_quant") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(scale_in);
+  PATTERN_DECL_NODE(scale_op);
+  PATTERN_DECL_NODE(quant_in);
+  PATTERN_DECL_NODE(quant_op);
+};
+
+// Quantize + Conv2d
+struct QuantConv : public PatternBase {
+  QuantConv(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "quant_conv") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(quant_in);
+  PATTERN_DECL_NODE(quant_op);
+  PATTERN_DECL_NODE(conv_in);
+  PATTERN_DECL_NODE(conv_op);
+};
+
 // Scale + Matmul
 struct ScaleMatmul : public PatternBase {
   ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
@@ -1338,7 +1363,6 @@ struct LastBfloat16Ops : public PatternBase {
   PATTERN_DECL_NODE(op);
   PATTERN_DECL_NODE(op_out);
-  PATTERN_DECL_NODE(next_op);
 };
 
 struct FirstBfloat16Ops : public PatternBase {
@@ -1346,7 +1370,6 @@ struct FirstBfloat16Ops : public PatternBase {
       : PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
   PDNode* operator()();
 
-  PATTERN_DECL_NODE(prev_op);
   PATTERN_DECL_NODE(op_in);
   PATTERN_DECL_NODE(op);
 };
@@ -1360,17 +1383,6 @@ struct DuplicatedInputs : public PatternBase {
   PATTERN_DECL_NODE(op);
 };
 
-struct UnnecessaryReorders : public PatternBase {
-  UnnecessaryReorders(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "unnecessary_reorders") {}
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(prev_op);
-  PATTERN_DECL_NODE(quant_in);
-  PATTERN_DECL_NODE(quant_op);
-  PATTERN_DECL_NODE(quant_out);
-};
-
 // Pattern used for enforcing inplace computation for in-place computation
 // supporting DNNL ops. softmax, batch_norm and layer_norm
 struct MKLDNNInPlace : public PatternBase {
......
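A note on the `PATTERN_DECL_NODE` declarations above: in graph_pattern_detector.h the macro expands to a pair of small helpers, roughly as sketched below (an approximation from memory of the Paddle source, not part of this diff), which is why the pattern implementations call `scale_in_repr()`, `quant_op_repr()`, and so on:

```cpp
// Rough expansion of PATTERN_DECL_NODE(name__) -- check
// graph_pattern_detector.h for the exact definition:
#define PATTERN_DECL_NODE(name__)                        \
  std::string name__##_repr() const {                    \
    return PDNodeName(name_scope_, repr_, id_, #name__); \
  }                                                      \
  PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
```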
@@ -12,12 +12,10 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
 
 #include <string>
-#include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -33,8 +31,38 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
                   b->inputs.end());
 }
 
+// Checks whether a reorder from FP32 to BF16 should be added before this
+// input of the operator
+bool IsPermittedInputName(const std::string& input_name) {
+  // Only the inputs listed in "permitted_names" require quantization before
+  // the bfloat16 operator. Other inputs, such as Filter and Bias, are
+  // reordered in the kernel.
+  const std::vector<std::string> permitted_names = {"X", "Y", "Input",
+                                                    "ResidualData"};
+  return (std::find(permitted_names.begin(), permitted_names.end(),
+                    input_name) != permitted_names.end());
+}
+
+// Checks whether a reorder from BF16 to FP32 should be added after this
+// output of the operator
+bool IsPermittedOutputName(const std::string& output_name) {
+  // XShape is an output of the transpose2 and reshape2 operators, used to
+  // store the shape and LoD of X, so it does not need a dequantize op.
+  return (output_name != "XShape");
+}
+
 void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
                  int* quantize_counter) {
+  std::vector<std::string> input_names;
+
+  // Find the names of the inputs linking op to op_in
+  for (auto name : op->Op()->InputNames())
+    for (auto input_name : op->Op()->Input(name))
+      if (input_name == op_in->Name() && IsPermittedInputName(name))
+        input_names.push_back(name);
+
+  if (input_names.empty()) return;
+
   VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
   auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
@@ -44,23 +72,12 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
   q_desc.SetOutput("Output",
                    std::vector<std::string>({quantize_out_node->Name()}));
   q_desc.SetAttr("Scale", 1.f);
+  q_desc.SetAttr("Shift", 0.0f);
   q_desc.SetAttr("bfloat16", true);
   q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
                                       ? op->Op()->GetAttr("data_layout")
                                       : std::string("NCHW"));
-  auto quantize_op = g->CreateOpNode(&q_desc);
-
-  std::vector<std::string> input_names;
-  for (auto name : op->Op()->InputNames()) {
-    for (auto input_name : op->Op()->Input(name)) {
-      if (input_name == op_in->Name()) input_names.push_back(name);
-    }
-  }
-
-  PADDLE_ENFORCE_NE(
-      input_names.empty(), true,
-      platform::errors::NotFound(
-          "Operator before operator should have input as op output"));
+  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.
 
   for (auto name = input_names.begin(); name < input_names.end(); name++)
     op->Op()->SetInput(*name,
@@ -99,11 +116,12 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
     q_desc.SetOutput("Output",
                      std::vector<std::string>({quantize_out_node_names[i]}));
     q_desc.SetAttr("Scale", 1.f);
+    q_desc.SetAttr("Shift", 0.0f);
     q_desc.SetAttr("bfloat16", true);
     q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
                                         ? op->Op()->GetAttr("data_layout")
                                         : std::string("NCHW"));
-    auto quantize_op = g->CreateOpNode(&q_desc);
+    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.
 
     UnlinkNodes(inputs[i], op);
     IR_NODE_LINK_TO(inputs[i], quantize_op);
@@ -115,6 +133,9 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
   op->Op()->SetInput("X", quantize_out_node_names);
 }
 
+// Operators like Concat and Sum have a single input name X, which actually
+// consists of multiple inputs. Such operators require a different way of
+// finding the pattern and adding quantize ops.
 void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
   GraphPatternDetector gpd;
   patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
@@ -128,38 +149,8 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
   gpd(graph, handler);
 }
 
-void RemoveUnnecessaryReorders(ir::Graph* graph, int* quantize_counter) {
-  GraphPatternDetector gpd;
-  patterns::UnnecessaryReorders unnecessary_reorders{gpd.mutable_pattern(),
-                                                     "unnecessary_reorders"};
-  unnecessary_reorders();
-
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, unnecessary_reorders);
-    GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, unnecessary_reorders);
-    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, unnecessary_reorders);
-    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, unnecessary_reorders);
-
-    std::string op_output_name;
-    for (auto name : prev_op->Op()->OutputNames())
-      for (auto output_name : prev_op->Op()->Output(name))
-        if (output_name == quant_in->Name()) op_output_name = name;
-
-    PADDLE_ENFORCE_NE(
-        op_output_name.empty(), true,
-        platform::errors::NotFound(
-            "Operator before operator should have input as op output"));
-
-    prev_op->Op()->SetOutput(op_output_name,
-                             std::vector<std::string>({quant_out->Name()}));
-
-    IR_NODE_LINK_TO(prev_op, quant_out);
-    GraphSafeRemoveNodes(graph, {quant_in, quant_op});
-    (*quantize_counter)--;
-  };
-  gpd(graph, handler);
-}
-
+// Adds quantize ops before all operators except Concat and Sum, which have
+// already been handled in AddReoderBeforeDuplicatedInputs
 void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
   GraphPatternDetector gpd;
   patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
@@ -167,12 +158,9 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
   bfloat16_ops();
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
-    auto prev_op_type = prev_op->Op()->Type();
-    if (op->Op()->Type() != "conv2d" && prev_op_type != "quantize" &&
-        prev_op_type != "sum" && prev_op_type != "concat") {
+    if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") {
       AddQuantize(g, op, op_in, quantize_counter);
     }
   };
@@ -182,9 +170,8 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
 void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
   int quantize_counter = 0;
   AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
-  RemoveUnnecessaryReorders(graph, &quantize_counter);
   AddReoderBeforeSingleInputs(graph, &quantize_counter);
-  PrettyLogDetail("--- added %d quantize op before bfloat16 op",
+  PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
                   quantize_counter);
 }
@@ -193,55 +180,51 @@ void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
   patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
                                          "last_bfloat16_ops"};
   bfloat16_ops();
-  int force_fp32_counter = 0, dequantize_counter = 0;
+  int dequantize_counter = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
     GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
-    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
 
-    if ((op->Op()->HasAttr("force_fp32_output") ||
-         op->Op()->HasProtoAttr("force_fp32_output")) &&
-        !op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
-      op->Op()->SetAttr("force_fp32_output", true);
-      force_fp32_counter++;
-    } else if (op->Op()->Type() != "prior_box") {
-      VarDesc dequantize_out_desc(patterns::PDNodeName("dequantize", "out"));
-      auto* dequantize_out_node = g->CreateVarNode(&dequantize_out_desc);
+    if (op->Op()->Type() != "prior_box") {
+      // Find the names of the outputs linking op to op_out
+      std::vector<std::string> output_names;
+      for (auto name : op->Op()->OutputNames())
+        for (auto output_name : op->Op()->Output(name))
+          if (output_name == op_out->Name() && IsPermittedOutputName(name))
+            output_names.push_back(name);
+
+      if (output_names.empty()) return;
+
+      VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
+      auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
 
       OpDesc deq_desc;
       deq_desc.SetType("dequantize");
-      deq_desc.SetInput("Input", std::vector<std::string>({op_out->Name()}));
-      deq_desc.SetOutput(
-          "Output", std::vector<std::string>({dequantize_out_node->Name()}));
+      deq_desc.SetInput("Input",
+                        std::vector<std::string>({dequantize_in_node->Name()}));
+      deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
       deq_desc.SetAttr("Scale", 1.0f);
-      auto dequantize_op = g->CreateOpNode(&deq_desc);
-
-      std::string next_op_input_name;
-      for (auto name : next_op->Op()->InputNames()) {
-        for (auto input_name : next_op->Op()->Input(name)) {
-          if (input_name == op_out->Name()) next_op_input_name = name;
-        }
-      }
-
-      PADDLE_ENFORCE_NE(
-          next_op_input_name.empty(), true,
-          platform::errors::NotFound(
-              "Operator before operator should have input as op output"));
-      next_op->Op()->SetInput(
-          next_op_input_name,
-          std::vector<std::string>({dequantize_out_node->Name()}));
+      deq_desc.SetAttr("Shift", 0.0f);
+      auto dequantize_op =
+          g->CreateOpNode(&deq_desc);  // OpDesc will be copied.
+
+      for (auto name = output_names.begin(); name < output_names.end(); name++)
+        op->Op()->SetOutput(
+            *name, std::vector<std::string>({dequantize_in_node->Name()}));
 
-      UnlinkNodes(op_out, next_op);
-      IR_NODE_LINK_TO(op_out, dequantize_op);
-      IR_NODE_LINK_TO(dequantize_op, dequantize_out_node);
-      IR_NODE_LINK_TO(dequantize_out_node, next_op);
+      UnlinkNodes(op, op_out);
+      IR_NODE_LINK_TO(op, dequantize_in_node);
+      IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
+      IR_NODE_LINK_TO(dequantize_op, op_out);
 
       dequantize_counter++;
     }
   };
   gpd(graph, handler);
-  PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output",
-                  dequantize_counter, force_fp32_counter);
+  PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
                  dequantize_counter);
 }
 
 void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
......
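To make the effect of the two functions above concrete, here is a before/after sketch of the rewiring this pass performs around a single bfloat16 op (variable names are illustrative, not taken from the code):

```cpp
// Before the pass (conceptual graph; x and y are fp32 variables):
//   prev_op(fp32) -> x -> conv2d(mkldnn_data_type = "bfloat16") -> y
//
// After SetInputDataType and SetOutputDataType:
//   prev_op(fp32) -> x -> quantize(Scale=1, Shift=0, bfloat16=true)
//       -> quantize_out -> conv2d(bfloat16) -> dequantize_in
//       -> dequantize(Scale=1, Shift=0) -> y
//
// Each inserted reorder therefore adds exactly two graph nodes: the op node
// itself plus its new intermediate variable (quantize_out / dequantize_in) --
// which is why the tester below expects quant_op * 2 + dequant_op * 2 nodes.
```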
@@ -26,8 +26,7 @@ namespace ir {
 void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs, bool use_mkldnn,
-           const std::string& mkldnn_data_type = "float32",
-           const bool force_fp32_output = false) {
+           const std::string& mkldnn_data_type = "float32") {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
@@ -37,7 +36,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
-    op->SetAttr("force_fp32_output", force_fp32_output);
   } else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
              type == "dropout") {
     op->SetInput("X", {inputs[0]});
@@ -47,7 +45,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
-    op->SetAttr("force_fp32_output", force_fp32_output);
   } else if (type == "concat" || type == "sum") {
     op->SetInput("X", inputs);
     op->SetOutput("Out", outputs);
@@ -58,7 +55,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
     op->SetOutput("Out", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
-    if (type == "matmul") op->SetAttr("force_fp32_output", force_fp32_output);
   } else if (type == "layer_norm") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Y", {outputs[0]});
@@ -79,8 +75,8 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
   *current_nodes_num = (*graph)->Nodes().size();
 }
 
-void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
-              int force_fp32_count, int added_nodes_count) {
+void MainTest(const ProgramDesc& prog, const int& quant_count,
+              const int& dequant_count, const int& added_nodes_count) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   int original_nodes_num, current_nodes_num;
   PreparePass(&graph, prog, variable_names, &original_nodes_num,
@@ -88,7 +84,6 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
   int quantize_nodes_count = 0;
   int dequantize_nodes_count = 0;
-  int force_fp32_nodes_count = 0;
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
       auto* op = node->Op();
@@ -96,16 +91,11 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count,
         quantize_nodes_count++;
       } else if (op->Type() == "dequantize") {
         dequantize_nodes_count++;
-      } else if (op->Type() == "conv2d" || op->Type() == "matmul" ||
-                 op->Type() == "fc") {
-        if (op->GetAttrIfExists<bool>("force_fp32_output"))
-          force_fp32_nodes_count++;
       }
     }
   }
   EXPECT_EQ(quantize_nodes_count, quant_count);
   EXPECT_EQ(dequantize_nodes_count, dequant_count);
-  EXPECT_EQ(force_fp32_nodes_count, force_fp32_count);
   EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
 }
@@ -125,9 +115,10 @@ ProgramDesc BuildProgramDescConv(bool use_mkldnn) {
 TEST(CpuBfloat16Pass, convolution) {
   bool use_mkldnn = true;
-  // 0 added + 1 force_fp32_output
-  int added_nodes = 0;
-  MainTest(BuildProgramDescConv(use_mkldnn), 0, 0, 1, added_nodes);
+  int quant_op = 3;
+  int dequant_op = 3;
+  int added_nodes = quant_op * 2 + dequant_op * 2;
+  MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
 }
 ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) {
@@ -147,9 +138,11 @@ ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) {
 TEST(CpuBfloat16Pass, double_input_ops) {
   bool use_mkldnn = true;
-  // 2 quant + 2 quant out
-  int added_nodes = 4;
-  MainTest(BuildProgramDescDoubleInput(use_mkldnn), 2, 0, 0, added_nodes);
+  int quant_op = 4;
+  int dequant_op = 3;
+  int added_nodes = quant_op * 2 + dequant_op * 2;
+  MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
+           added_nodes);
 }
 ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) {
@@ -169,9 +162,11 @@ ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) {
 TEST(CpuBfloat16Pass, duplicated_input_ops) {
   bool use_mkldnn = true;
-  // 3 quant + 3 quant out
-  int added_nodes = 6;
-  MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), 3, 0, 0, added_nodes);
+  int quant_op = 5;
+  int dequant_op = 3;
+  int added_nodes = quant_op * 2 + dequant_op * 2;
+  MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
+           added_nodes);
 }
 ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
@@ -193,9 +188,11 @@ ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
 TEST(CpuBfloat16Pass, double_outputs_ops) {
   bool use_mkldnn = true;
-  // 3 dequant + 3 dequant out
-  int added_nodes = 6;
-  MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), 0, 3, 0, added_nodes);
+  int quant_op = 3;
+  int dequant_op = 3;
+  int added_nodes = quant_op * 2 + dequant_op * 2;
+  MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
+           added_nodes);
 }
 
 }  // namespace ir
......
@@ -255,14 +255,21 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
 
     if (dequant_in->outputs.size() == 1) {
-      auto output_name = "Out";
-      if (any_op->Op()->Type() == "conv2d") {
+      if (any_op->Op()->Type() == "conv2d" ||
+          any_op->Op()->Type() == "conv2d_transpose") {
         // do not squash if fuse residual connection is true
         // because residual fusion does not support force output with fp32
         if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
           return;
-        output_name = "Output";
       }
+
+      // Find the name of the output linking any_op to dequant_in
+      std::string output_name;
+      for (auto name : any_op->Op()->OutputNames())
+        for (auto out_name : any_op->Op()->Output(name))
+          if (out_name == dequant_in->Name()) output_name = name;
+
+      if (output_name.empty()) return;
+
       any_op->Op()->SetAttr("force_fp32_output", true);
       any_op->Op()->SetOutput(output_name,
                               std::vector<std::string>({dequant_out->Name()}));
@@ -363,10 +370,10 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
                         platform::errors::InvalidArgument(
                             "Dequantize scale(%f) should have positive value.",
                             dequant_scale));
-      PADDLE_ENFORCE_GT(scale_scale, 0.0f,
-                        platform::errors::InvalidArgument(
-                            "Scale(%f) of scale op should have positive value.",
-                            scale_scale));
+      PADDLE_ENFORCE_NE(
+          scale_scale, 0.0f,
+          platform::errors::InvalidArgument(
+              "Scale(%f) should have a non-zero value", scale_scale));
 
       dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
       dequant_op->Op()->SetOutput(
@@ -378,10 +385,86 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
   };
   gpd(graph, handler);
   AddStatis(found_dequant_scale_squash_count);
-  PrettyLogDetail("--- squashed %d scale with dequant",
+  PrettyLogDetail("--- squashed %d scale with dequantize op",
                   found_dequant_scale_squash_count);
 }
+// squash scale with quantize
+void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ScaleQuant scale_quant_pattern{gpd.mutable_pattern(),
+                                           "scale_quant"};
+  scale_quant_pattern();
+
+  int found_scale_quant_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash scale-quant ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_quant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_quant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, scale_quant_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, scale_quant_pattern);
+
+    if (quant_in->outputs.size() == 1 &&
+        scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
+      auto quant_scale = quant_op->Op()->GetAttrIfExists<float>("Scale");
+      auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
+
+      PADDLE_ENFORCE_GT(
+          quant_scale, 0.0f,
+          platform::errors::InvalidArgument(
+              "Quantize scale(%f) should have positive value.", quant_scale));
+      PADDLE_ENFORCE_NE(
+          scale_scale, 0.0f,
+          platform::errors::InvalidArgument(
+              "Scale(%f) should have a non-zero value", scale_scale));
+
+      quant_op->Op()->SetAttr("Scale", quant_scale * scale_scale);
+      quant_op->Op()->SetInput("Input",
+                               std::vector<std::string>({scale_in->Name()}));
+      IR_NODE_LINK_TO(scale_in, quant_op);
+      GraphSafeRemoveNodes(graph, {scale_op, quant_in});
+      found_scale_quant_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_scale_quant_squash_count);
+  PrettyLogDetail("--- squashed %d scale with quantize op",
+                  found_scale_quant_squash_count);
+}
+// squash quantize if it is before bfloat16 conv2d
+void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::QuantConv pattern{gpd.mutable_pattern(), "quant_conv"};
+  pattern();
+
+  int found_quant_conv_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash quant-conv2d ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_in, conv_in, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, pattern);
+
+    if (conv_in->outputs.size() == 1 &&
+        quant_op->Op()->GetAttrIfExists<float>("Scale") == 1.0) {
+      conv_op->Op()->SetInput("Input",
+                              std::vector<std::string>({quant_in->Name()}));
+      IR_NODE_LINK_TO(quant_in, conv_op);
+      GraphSafeRemoveNodes(graph, {quant_op, conv_in});
+      found_quant_conv_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_quant_conv_squash_count);
+  PrettyLogDetail("--- squashed %d quantize with bfloat16 conv2d op",
+                  found_quant_conv_squash_count);
+}
 void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
@@ -389,6 +472,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
           "The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
   FusePassBase::Init("cpu_quantize_squash_pass", graph);
+  DequantScaleSquash(graph);
+  ScaleQuantSquash(graph);
   std::unordered_map<const Node*, int> nodes_keep_counter;
   FindNodesToKeep(graph, &nodes_keep_counter);
   DequantQuantSquash(graph, &nodes_keep_counter);
@@ -396,7 +481,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   RequantOpSquash(graph);
   OpDequantSquash(graph);
   MultipleQuantizeSquash(graph);
-  DequantScaleSquash(graph);
+  QuantizeBf16Conv(graph);
 }
 
 }  // namespace ir
......
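The arithmetic behind the scale-folding squashes above, as a sketch. It uses the convention the folded attribute values imply: quantize multiplies by its `Scale` attribute, dequantize divides by it, and scale multiplies by `scale` when `bias` is 0. The QuantizeBf16Conv rationale is an inference from the pass, not stated in the source:

```cpp
// ScaleQuantSquash: quantize(scale(x)) with bias == 0
//   scale:     y = x * scale_scale
//   quantize:  q = y * quant_scale
//   => q = x * (quant_scale * scale_scale), hence
//      quant_op->Op()->SetAttr("Scale", quant_scale * scale_scale)
//
// DequantScaleSquash: scale(dequantize(q)) with bias == 0
//   dequantize: y = q / dequant_scale
//   scale:      z = y * scale_scale = q / (dequant_scale / scale_scale), hence
//      dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale)
//
// QuantizeBf16Conv: a quantize with Scale == 1.0 feeding a bfloat16 conv2d is
// a plain fp32->bf16 reorder; presumably the conv kernel can perform that
// reorder itself, which is why the quantize op and its output var are dropped.
```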
@@ -78,6 +78,16 @@ class CPUQuantizeSquashPass : public FusePassBase {
   */
   void DequantScaleSquash(Graph* graph) const;
 
+  /*
+  * Squash scale if it is before quantize
+  */
+  void ScaleQuantSquash(Graph* graph) const;
+
+  /*
+  * Squash quantize if it is before bfloat16 conv2d
+  */
+  void QuantizeBf16Conv(Graph* graph) const;
+
   const std::string name_scope_{"squash"};
 };
......
@@ -24,7 +24,8 @@ namespace ir {
 void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs, bool use_mkldnn,
-           const std::vector<float> scale = {}, float bias = 0.0) {
+           const std::vector<float> scale = {}, float bias = 0.0,
+           const std::string& mkldnn_data_type = "float32") {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
     if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
     op->SetOutput("Output", {outputs[0]});
+    op->SetAttr("force_fp32_output", false);
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "quantize") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
@@ -52,6 +55,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   } else if (type == "concat") {
     op->SetInput("X", inputs);
     op->SetOutput("Out", outputs);
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "fc") {
     op->SetInput("Input", {inputs[0]});
     PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
@@ -63,6 +67,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetOutput("Out", outputs);
     if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
     if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
+    op->SetAttr("force_fp32_output", false);
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "scale") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
@@ -74,6 +80,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetOutput("Out", {outputs[0]});
     if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]);
     if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
+    op->SetAttr("force_fp32_output", false);
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   }
 }
@@ -299,6 +307,20 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
   return prog;
 }
 
+// a->Scale->b
+// b->Quant->c
+ProgramDesc BuildScaleQuantProgramDesc(bool use_mkldnn, float scale_scale,
+                                       float quant_scale, float bias) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "scale", "Scale", {"a"}, {"b"}, use_mkldnn, {scale_scale}, bias);
+  SetOp(&prog, "quantize", "Quant", {"b"}, {"c"}, use_mkldnn, {quant_scale});
+  return prog;
+}
+
 // {x,y}->Matmul->b
 // b->Dequant->c
 ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
@@ -341,6 +363,22 @@ ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in,
   return prog;
 }
 
+// a->Quant->b
+// b->Conv2d->c
+ProgramDesc BuildQuantConv2dProgramDesc(const bool& use_mkldnn,
+                                        const float& quant_scale,
+                                        const std::string& mkldnn_data_type) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "quantize", "Quant", {"a"}, {"b"}, use_mkldnn, {quant_scale});
+  SetOp(&prog, "conv2d", "Conv2d", {"b"}, {"c"}, use_mkldnn, {}, 0.0f,
+        mkldnn_data_type);
+  return prog;
+}
+
 void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                       const char* var_name) {
   auto x = scope->Var(var_name);
@@ -664,6 +702,22 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
                  "Dequant", "Scale", dequant_scale);
 }
 
+// if scale has no bias
+TEST(CpuQuantizeSquashPass, scale_with_no_bias_quantize) {
+  constexpr auto scale_scale = 1.5432f;
+  constexpr auto quant_scale = 1.2345f;
+  constexpr auto bias = 0.0f;
+  auto use_mkldnn = true;
+  // remove: scale_op, quant_in (the scale op and its output var)
+  auto remove_nodes = 2;
+  CountNodeTest(
+      BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias),
+      remove_nodes);
+  EqualScaleTest(
+      BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias),
+      "Scale", "Quant", quant_scale * scale_scale);
+}
+
 TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
   auto dequant_scale = 1.2345f;
   auto use_mkldnn = true;
@@ -688,6 +742,17 @@ TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) {
   EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in);
 }
 
+TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
+  auto quant_scale = 1.0f;
+  auto use_mkldnn = true;
+  auto mkldnn_data_type = "bfloat16";
+  // remove: quant_op, conv_in
+  auto remove_nodes = 2;
+  CountNodeTest(
+      BuildQuantConv2dProgramDesc(use_mkldnn, quant_scale, mkldnn_data_type),
+      remove_nodes);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
......
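A note on the two assertions used by the new tests above. `CountNodeTest` and `EqualScaleTest` are pre-existing helpers in this tester file, defined outside the hunks shown; the sketch below is a hypothetical restatement of the invariants they appear to check, judging by how the tests call them:

```cpp
// Assumed invariants (see the helpers' actual definitions in the tester):
//   CountNodeTest(prog, remove_nodes):
//     nodes(before pass) - remove_nodes == nodes(after pass)
//   EqualScaleTest(prog, op_name, attr_name, expected):
//     after the pass, op `op_name` has attribute `attr_name` == expected
//
// E.g. in scale_with_no_bias_quantize, ScaleQuantSquash removes exactly two
// nodes ({scale_op, quant_in}) and folds the scales, so the quantize op must
// end up with Scale == quant_scale * scale_scale == 1.2345f * 1.5432f.
```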
@@ -268,6 +268,7 @@ void CpuPassStrategy::EnableMkldnnBfloat16() {
   if (!use_mkldnn_bfloat16_) {
     passes_.push_back("cpu_bfloat16_placement_pass");
     passes_.push_back("cpu_bfloat16_pass");
+    passes_.push_back("cpu_quantize_squash_pass");
   }
   use_mkldnn_bfloat16_ = true;
 #else
......
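With the squash pass appended after cpu_bfloat16_pass, the redundant reorders produced during BF16 enablement are cleaned up automatically. A hedged sketch of how this strategy is typically triggered from the C++ inference API; the include path, model path, and predictor-creation call are standard AnalysisConfig usage assumed here, not part of this commit:

```cpp
#include "paddle_inference_api.h"  // header name may vary by install layout

paddle::AnalysisConfig config;
config.SetModel("model_dir");  // assumed model location
config.EnableMKLDNN();
// Enables the BF16 pass list shown above: cpu_bfloat16_placement_pass,
// cpu_bfloat16_pass and, after this commit, cpu_quantize_squash_pass.
config.EnableMkldnnBfloat16();
auto predictor = paddle::CreatePaddlePredictor(config);
```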
@@ -156,4 +156,5 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::ReQuantOpKernel<int8_t>, ops::ReQuantOpKernel<uint8_t>);
+                   ops::ReQuantOpKernel<int8_t>, ops::ReQuantOpKernel<uint8_t>,
+                   ops::ReQuantOpKernel<paddle::platform::bfloat16>);