未验证 提交 8ad38701 编写于 作者: T Tomasz Socha 提交者: GitHub

Bfloat16 refactor (#42238)

* Refactor Quantization

* Refactor Dequantization

* Classy solution

* Style I

* Style II

* Style III

* Use VLOG(4) for debug info

* Style IV
上级 afa846d9
...@@ -2665,41 +2665,8 @@ PDNode *patterns::UnsupportedBfloat16::operator()() { ...@@ -2665,41 +2665,8 @@ PDNode *patterns::UnsupportedBfloat16::operator()() {
return op; return op;
} }
PDNode *patterns::LastBfloat16Ops::operator()() { PDNode *patterns::Bloat16Ops::operator()() {
auto *op = pattern->NewNode(op_repr())->assert_is_op(); auto op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
op->LinksTo({op_out});
return op_out;
}
PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
op->LinksFrom({op_in});
return op;
}
PDNode *patterns::DuplicatedInputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"concat", "sum"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}
PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
op->assert_more([&](Node *node) { op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") == return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16"; "bfloat16";
......
...@@ -1565,36 +1565,9 @@ struct UnsupportedBfloat16 : public PatternBase { ...@@ -1565,36 +1565,9 @@ struct UnsupportedBfloat16 : public PatternBase {
PATTERN_DECL_NODE(op); PATTERN_DECL_NODE(op);
}; };
struct LastBfloat16Ops : public PatternBase { struct Bloat16Ops : public PatternBase {
LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope) Bloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "last_bfloat16_ops") {} : PatternBase(pattern, name_scope, "many_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
};
struct FirstBfloat16Ops : public PatternBase {
FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
struct DuplicatedInputs : public PatternBase {
DuplicatedInputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_inputs_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
};
struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}
PDNode* operator()(); PDNode* operator()();
......
...@@ -22,290 +22,226 @@ namespace paddle { ...@@ -22,290 +22,226 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
using string::PrettyLogDetail; namespace {
class Quanter {
public:
void AddQuantOps() {
if (IsNotPermittedOpType()) return;
void UnlinkNodes(ir::Node* a, ir::Node* b) { std::vector<std::string> linked_xputs;
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
// Checking whether a reorder from FP32 to BF16 should be added before the input for (const auto& logical_xput : op_xputs) {
// to the operator std::vector<std::string> quant_xput_names;
bool IsPermittedInputName(const std::string& input_name) { quant_xput_names.reserve(xputs_map.size());
// Only the inputs listed in \"permitted_names\" requires quanitization before
// the bfloat16 operator. Other inputs, such as Filter and Bias are reordered
// in the kernel.
const std::vector<std::string> permitted_names = {"X", "Y", "Input",
"ResidualData"};
return (std::find(permitted_names.begin(), permitted_names.end(),
input_name) != permitted_names.end());
}
// Checking whether a reorder from BF16 to FP32 should be added after the output const auto& logical_xput_name = logical_xput.first;
// to the operator if (IsNotPermittedName(logical_xput_name)) continue;
bool IsPermittedOutputName(const std::string& output_name) {
// XShape is output in transpose2 and reshape2 operators used to store the
// shape and lod of X. So this output do not need dequantize before.
return (output_name != "XShape");
}
void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, const auto& physical_xputs_names = logical_xput.second;
int& quantize_counter) { for (const auto& physical_xput_name : physical_xputs_names) {
std::vector<std::string> input_names; if (IsAlreadyLinked(linked_xputs, physical_xput_name)) continue;
// Find the name of the input linking op to op_in
for (auto name : op->Op()->InputNames())
for (auto input_name : op->Op()->Input(name))
if (input_name == op_in->Name() && IsPermittedInputName(name))
input_names.push_back(name);
if (input_names.empty()) return;
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
OpDesc q_desc;
q_desc.SetType("quantize");
q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
for (auto name = input_names.begin(); name < input_names.end(); name++) VarDesc quant_x_desc(
op->Op()->SetInput(*name, patterns::PDNodeName(get_op_type(), get_op_edge()));
std::vector<std::string>({quantize_out_node->Name()})); auto quant_x_node = graph.CreateVarNode(&quant_x_desc);
const auto xput_name = quant_x_node->Name();
quant_xput_names.emplace_back(xput_name);
UnlinkNodes(op_in, op); auto quant_op = create_quant_op(physical_xput_name, xput_name);
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
quantize_counter++;
}
void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) { auto physical_xput_node = xputs_map[physical_xput_name];
auto inputs = op->inputs; link_nodes(physical_xput_node, quant_op, quant_x_node);
PADDLE_ENFORCE_GE(inputs.size(), 1, counter++;
platform::errors::InvalidArgument( linked_xputs.push_back(physical_xput_name);
"OP(%s)'s inputs(%d) must be equal or greater than 1.", }
op->Name(), inputs.size()));
PADDLE_ENFORCE_EQ(op->outputs.size(), 1, set_edge(logical_xput_name, quant_xput_names);
platform::errors::InvalidArgument( }
"OP(%s)'s outputs(%d) must be equal to 1.", op->Name(), }
op->outputs.size()));
int get_counter() const { return counter; }
OpDesc q_desc;
q_desc.SetType("quantize"); virtual ~Quanter() = default;
std::vector<Node*> quantize_out_nodes(inputs.size()); protected:
std::vector<std::string> quantize_out_node_names(inputs.size()); Graph& graph;
ir::Node* const op;
for (size_t i = 0; i < inputs.size(); i++) {
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); std::map<std::string, ir::Node*> xputs_map;
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc); const VariableNameMap& op_xputs;
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
int counter = 0;
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output", Quanter(Graph& graph, ir::Node* const op, const VariableNameMap& op_xputs)
std::vector<std::string>({quantize_out_node_names[i]})); : graph(graph), op(op), op_xputs(op_xputs){};
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f); virtual bool IsNotPermittedOpType() const = 0;
q_desc.SetAttr("bfloat16", true); virtual bool IsNotPermittedName(const std::string& input_name) const = 0;
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") virtual std::string get_op_type() const = 0;
virtual std::string get_op_edge() const = 0;
virtual void link_nodes(ir::Node* const physical_xput_node,
ir::Node* const quant_op,
ir::Node* const quant_x_node) = 0;
virtual void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) = 0;
bool IsAlreadyLinked(const std::vector<std::string>& node_names,
const std::string& node_name) const {
return std::find(node_names.begin(), node_names.end(), node_name) !=
node_names.end();
}
virtual ir::Node* create_quant_op(const std::string& input_name,
const std::string& output_name) const {
OpDesc op_desc;
op_desc.SetType(get_op_type());
op_desc.SetInput("Input", std::vector<std::string>({input_name}));
op_desc.SetOutput("Output", std::vector<std::string>({output_name}));
op_desc.SetAttr("Scale", 1.f);
op_desc.SetAttr("Shift", 0.0f);
op_desc.SetAttr("bfloat16", true);
op_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout") ? op->Op()->GetAttr("data_layout")
: std::string("NCHW")); : std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. return graph.CreateOpNode(&op_desc); // OpDesc will be copied.
}
UnlinkNodes(inputs[i], op); void UnlinkNodes(ir::Node* a, ir::Node* b) const {
IR_NODE_LINK_TO(inputs[i], quantize_op); a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]); a->outputs.end());
IR_NODE_LINK_TO(quantize_out_nodes[i], op); b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
quantize_counter++; b->inputs.end());
} }
};
op->Op()->SetInput("X", quantize_out_node_names); class Quantizer final : public Quanter {
} public:
Quantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Inputs()) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(
inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.", op->Name(),
inputs.size()));
// Operators like Concat and Sum have a single input name X, which actually for (auto input : inputs) xputs_map[input->Name()] = input;
// consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"};
duplicated_inputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_inputs);
AddQuantizes(g, op, quantize_counter);
}; };
gpd(graph, handler);
}
// Adding quantize ops before all operators except Concat and Sum, which have protected:
// already been handled in AddReoderBeforeDuplicatedInputs bool IsNotPermittedOpType() const override { return false; }
void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd; // Checking whether a reorder from FP32 to BF16
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), // should be added before the input to the operator
"first_bfloat16_ops"}; bool IsNotPermittedName(const std::string& input_name) const override {
bfloat16_ops(); // Only the inputs listed in \"permitted_names\"
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, // requires quanitization before the bfloat16 operator.
Graph* g) { // Other inputs, such as Filter and Bias are reordered in the kernel.
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops); const std::vector<std::string> permitted_names = {"X", "Y", "Input",
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); "ResidualData"};
if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") {
AddQuantize(g, op, op_in, quantize_counter); return std::none_of(
permitted_names.begin(), permitted_names.end(),
[&input_name](const std::string& name) { return name == input_name; });
} }
};
gpd(graph, handler);
}
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const { std::string get_op_type() const override { return "quantize"; };
int quantize_counter = 0; std::string get_op_edge() const override { return "out"; };
AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}
void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out, void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op,
int& dequantize_counter) { ir::Node* const quant_x_node) override {
if (op->Op()->Type() == "prior_box") return; UnlinkNodes(physical_xput_node, op);
IR_NODE_LINK_TO(physical_xput_node, quant_op);
// Find the name of the output linking op to op_out IR_NODE_LINK_TO(quant_op, quant_x_node);
std::vector<std::string> output_names; IR_NODE_LINK_TO(quant_x_node, op);
for (auto name : op->Op()->OutputNames()) }
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name)) void set_edge(const std::string& logical_xput_name,
output_names.push_back(name); const std::vector<std::string>& quant_xput_names) override {
op->Op()->SetInput(logical_xput_name, quant_xput_names);
if (output_names.empty()) return; }
};
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) { class DeQuantizer final : public Quanter {
public:
DeQuantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Outputs()) {
auto outputs = op->outputs; auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1, PADDLE_ENFORCE_GE(
outputs.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.", "OP(%s)'s outputs(%d) must be equal or greater than 1.", op->Name(),
op->Name(), outputs.size())); outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));
OpDesc deq_desc; for (auto output : outputs) xputs_map[output->Name()] = output;
deq_desc.SetType("dequantize"); };
std::vector<Node*> dequantize_in_nodes(outputs.size()); protected:
std::vector<std::string> dequantize_in_node_names(outputs.size()); bool IsNotPermittedOpType() const override {
// Prior_box operator output is always FP32 so no dequantization is needed.
return op->Op()->Type() == "prior_box";
}
for (size_t i = 0; i < outputs.size(); i++) { // Checking whether a reorder from BF16 to FP32
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); // should be added after the output to the operator
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc); bool IsNotPermittedName(const std::string& output_name) const override {
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name(); // XShape is output in transpose2 and reshape2 operators used to store the
// shape and lod of X. So this output do not need dequantize before.
return (output_name == "XShape");
}
deq_desc.SetInput("Input", std::string get_op_type() const override { return "dequantize"; };
std::vector<std::string>({dequantize_in_node_names[i]})); std::string get_op_edge() const override { return "in"; };
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));
deq_desc.SetAttr("Scale", 1.f); void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op,
deq_desc.SetAttr("Shift", 0.0f); ir::Node* const quant_x_node) override {
deq_desc.SetAttr("bfloat16", true); UnlinkNodes(op, physical_xput_node);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") IR_NODE_LINK_TO(quant_op, physical_xput_node);
? op->Op()->GetAttr("data_layout") IR_NODE_LINK_TO(quant_x_node, quant_op);
: std::string("NCHW")); IR_NODE_LINK_TO(op, quant_x_node);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. }
UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
dequantize_counter++; void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) override {
op->Op()->SetOutput(logical_xput_name, quant_xput_names);
} }
op->Op()->SetOutput("Out", dequantize_in_node_names); ir::Node* create_quant_op(const std::string& input_name,
const std::string& output_name) const override {
return Quanter::create_quant_op(output_name, input_name);
}
};
} }
using string::PrettyLogDetail;
// Operators like split have a single output name Out, which actually void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
// consists of multiple outputs. Such operators require a different way to find int quantize_counter = 0;
// pattern and add dequantize ops. int dequantize_counter = 0;
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}
// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), patterns::Bloat16Ops Bloat16Ops{gpd.mutable_pattern(), "Bloat16Ops"};
"last_bfloat16_ops"}; Bloat16Ops();
bfloat16_ops();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* graph) {
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops); GET_IR_NODE_FROM_SUBGRAPH(op, op, Bloat16Ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "split") { Quantizer quantizer(graph, op);
AddDequantize(g, op, op_out, dequantize_counter); quantizer.AddQuantOps();
} quantize_counter += quantizer.get_counter();
DeQuantizer dequantizer(graph, op);
dequantizer.AddQuantOps();
dequantize_counter += dequantizer.get_counter();
}; };
gpd(graph, handler); gpd(graph, handler);
}
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
int dequantize_counter = 0; quantize_counter);
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op", PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter); dequantize_counter);
} }
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
SetInputDataType(graph);
SetOutputDataType(graph);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -24,8 +24,6 @@ namespace ir { ...@@ -24,8 +24,6 @@ namespace ir {
class CPUBFloat16Pass : public Pass { class CPUBFloat16Pass : public Pass {
protected: protected:
void SetInputDataType(ir::Graph* graph) const;
void SetOutputDataType(ir::Graph* graph) const;
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
}; };
......
...@@ -27,8 +27,16 @@ namespace ir { ...@@ -27,8 +27,16 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void CPUBfloat16PlacementPass::SetMkldnnDataType( void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
ir::Graph* graph, int* bfloat16_operators) const { int bfloat16_operators = 0;
bfloat16_operators += SetMkldnnDataType(graph);
bfloat16_operators -= RemoveOrphanedOperators(graph);
bfloat16_operators -= RemoveUnsupportedOperators(graph);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
}
int CPUBfloat16PlacementPass::SetMkldnnDataType(ir::Graph* graph) const {
const auto& op_types_list = const auto& op_types_list =
Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types"); Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types");
// set mkldnn_data_type to bfloat16 to all operators that are in // set mkldnn_data_type to bfloat16 to all operators that are in
...@@ -39,6 +47,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType( ...@@ -39,6 +47,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
"bfloat16_placement"}; "bfloat16_placement"};
bfloat16_placement_pattern(op_types_list); bfloat16_placement_pattern(op_types_list);
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern); GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
...@@ -50,58 +59,58 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType( ...@@ -50,58 +59,58 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
if ((op->Op()->HasAttr("mkldnn_data_type") || if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) && op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) { !platform::HasOpINT8DataType(op->Op())) {
VLOG(4) << "--- marked " << op->Op()->Type()
<< " operator to bfloat16 ";
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16")); op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
(*bfloat16_operators)++; detected_operators++;
} }
}; };
gpd(graph, handler); gpd(graph, handler);
return detected_operators;
} }
void CPUBfloat16PlacementPass::RemoveOrphanedOperators( int CPUBfloat16PlacementPass::RemoveOrphanedOperators(ir::Graph* graph) const {
ir::Graph* graph, int* bfloat16_operators) const {
// find orphaned bfloat16 operator that is between two float32 operators // find orphaned bfloat16 operator that is between two float32 operators
// revert mkldnn_data_type attr to float32 // revert mkldnn_data_type attr to float32
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(), patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(),
"orphaned_bfloat16"}; "orphaned_bfloat16"};
orphaned_bfloat16_pattern(); orphaned_bfloat16_pattern();
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern); GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern);
op->Op()->SetAttr("mkldnn_data_type", std::string("float32")); op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
bfloat16_operators--; VLOG(4) << "--- demarked " << op->Op()->Type() << " operator to bfloat16 ";
detected_operators++;
}; };
gpd(graph, handler); gpd(graph, handler);
return detected_operators;
} }
void CPUBfloat16PlacementPass::RemoveUnsupportedOperators( int CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
ir::Graph* graph, int* bfloat16_operators) const { ir::Graph* graph) const {
// now quantize is supported FP32 only, so try to find // now quantize is supported FP32 only, so try to find
// bfloat16 operator that input type is not FP32 // bfloat16 operator that input type is not FP32
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{ patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{
gpd.mutable_pattern(), "unsupported_bfloat16"}; gpd.mutable_pattern(), "unsupported_bfloat16"};
unsupported_bfloat16_pattern(); unsupported_bfloat16_pattern();
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern);
GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern); GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern);
if ((prev_out->Var()->GetDataType() != proto::VarType::FP32)) { if ((prev_out->Var()->GetDataType() != proto::VarType::FP32)) {
op->Op()->SetAttr("mkldnn_data_type", std::string("float32")); op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
bfloat16_operators--; VLOG(4) << "--- demarked " << op->Op()->Type()
<< " operator to bfloat16 ";
detected_operators++;
} }
}; };
gpd(graph, handler); gpd(graph, handler);
} return detected_operators;
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
SetMkldnnDataType(graph, &bfloat16_operators);
RemoveOrphanedOperators(graph, &bfloat16_operators);
RemoveUnsupportedOperators(graph, &bfloat16_operators);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
} }
} // namespace ir } // namespace ir
......
...@@ -26,14 +26,11 @@ namespace ir { ...@@ -26,14 +26,11 @@ namespace ir {
*/ */
class CPUBfloat16PlacementPass : public Pass { class CPUBfloat16PlacementPass : public Pass {
protected: protected:
void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveUnsupportedOperators(ir::Graph* graph,
int* bfloat16_operators) const;
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
int SetMkldnnDataType(ir::Graph* graph) const;
int RemoveOrphanedOperators(ir::Graph* graph) const;
int RemoveUnsupportedOperators(ir::Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册