Unverified commit 8ad38701 authored by Tomasz Socha and committed by GitHub

Bfloat16 refactor (#42238)

* Refactor Quantization

* Refactor Dequantization

* Classy solution

* Style I

* Style II

* Style III

* Use VLOG(4) for debug info

* Style IV
Parent afa846d9
@@ -2665,41 +2665,8 @@ PDNode *patterns::UnsupportedBfloat16::operator()() {
return op;
}
PDNode *patterns::LastBfloat16Ops::operator()() {
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
op->LinksTo({op_out});
return op_out;
}
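The four near-identical patterns this commit removes (LastBfloat16Ops, FirstBfloat16Ops, DuplicatedInputs, DuplicatedOutputs) all boil down to one check: match any op whose mkldnn_data_type attribute equals "bfloat16", which is exactly what the single new Bloat16Ops pattern asserts. A minimal standalone sketch of the assert_more idea (FakeNode/FakePDNode are illustrative stand-ins, not Paddle's real classes):

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct FakeNode {
  std::string op_type;
  std::map<std::string, std::string> attrs;
};

struct FakePDNode {
  std::vector<std::function<bool(const FakeNode&)>> asserts;
  // Register an extra predicate, like PDNode::assert_more.
  FakePDNode* assert_more(std::function<bool(const FakeNode&)> f) {
    asserts.push_back(std::move(f));
    return this;
  }
  // A node matches only if every registered predicate holds.
  bool Matches(const FakeNode& n) const {
    for (const auto& f : asserts)
      if (!f(n)) return false;
    return true;
  }
};

int main() {
  FakePDNode pd;
  pd.assert_more([](const FakeNode& n) {
    auto it = n.attrs.find("mkldnn_data_type");
    return it != n.attrs.end() && it->second == "bfloat16";
  });
  FakeNode conv{"conv2d", {{"mkldnn_data_type", "bfloat16"}}};
  FakeNode pool{"pool2d", {{"mkldnn_data_type", "float32"}}};
  std::cout << pd.Matches(conv) << " " << pd.Matches(pool) << "\n";  // 1 0
}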
PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
op->LinksFrom({op_in});
return op;
}
PDNode *patterns::DuplicatedInputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"concat", "sum"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}
PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
PDNode *patterns::Bloat16Ops::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
@@ -1565,36 +1565,9 @@ struct UnsupportedBfloat16 : public PatternBase {
PATTERN_DECL_NODE(op);
};
struct LastBfloat16Ops : public PatternBase {
LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
};
struct FirstBfloat16Ops : public PatternBase {
FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
struct DuplicatedInputs : public PatternBase {
DuplicatedInputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_inputs_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
};
struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}
struct Bloat16Ops : public PatternBase {
Bloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_bfloat16_ops") {}
PDNode* operator()();
@@ -22,290 +22,226 @@ namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
namespace {
class Quanter {
public:
void AddQuantOps() {
if (IsNotPermittedOpType()) return;
void UnlinkNodes(ir::Node* a, ir::Node* b) {
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
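UnlinkNodes removes the bidirectional edge between two graph nodes with the classic erase-remove idiom: std::remove shifts the surviving elements to the front and returns the new logical end, and erase trims the leftover tail. A minimal self-contained sketch of that idiom:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> outputs = {1, 2, 3, 2};
  // std::remove moves all elements != 2 forward; erase drops the tail.
  outputs.erase(std::remove(outputs.begin(), outputs.end(), 2), outputs.end());
  for (int v : outputs) std::cout << v << " ";  // prints: 1 3
}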
std::vector<std::string> linked_xputs;
// Checking whether a reorder from FP32 to BF16 should be added before the input
// to the operator
bool IsPermittedInputName(const std::string& input_name) {
// Only the inputs listed in "permitted_names" require quantization before
// the bfloat16 operator. Other inputs, such as Filter and Bias, are reordered
// in the kernel.
const std::vector<std::string> permitted_names = {"X", "Y", "Input",
"ResidualData"};
return (std::find(permitted_names.begin(), permitted_names.end(),
input_name) != permitted_names.end());
}
for (const auto& logical_xput : op_xputs) {
std::vector<std::string> quant_xput_names;
quant_xput_names.reserve(xputs_map.size());
// Checking whether a reorder from BF16 to FP32 should be added after the
// output of the operator
bool IsPermittedOutputName(const std::string& output_name) {
// XShape is an output of the transpose2 and reshape2 operators, used to store
// the shape and lod of X, so this output does not need to be dequantized.
return (output_name != "XShape");
}
const auto& logical_xput_name = logical_xput.first;
if (IsNotPermittedName(logical_xput_name)) continue;
void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int& quantize_counter) {
std::vector<std::string> input_names;
// Find the name of the input linking op to op_in
for (auto name : op->Op()->InputNames())
for (auto input_name : op->Op()->Input(name))
if (input_name == op_in->Name() && IsPermittedInputName(name))
input_names.push_back(name);
if (input_names.empty()) return;
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
OpDesc q_desc;
q_desc.SetType("quantize");
q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
const auto& physical_xputs_names = logical_xput.second;
for (const auto& physical_xput_name : physical_xputs_names) {
if (IsAlreadyLinked(linked_xputs, physical_xput_name)) continue;
for (auto name = input_names.begin(); name < input_names.end(); name++)
op->Op()->SetInput(*name,
std::vector<std::string>({quantize_out_node->Name()}));
VarDesc quant_x_desc(
patterns::PDNodeName(get_op_type(), get_op_edge()));
auto quant_x_node = graph.CreateVarNode(&quant_x_desc);
const auto xput_name = quant_x_node->Name();
quant_xput_names.emplace_back(xput_name);
UnlinkNodes(op_in, op);
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
quantize_counter++;
}
auto quant_op = create_quant_op(physical_xput_name, xput_name);
void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.",
op->Name(), inputs.size()));
PADDLE_ENFORCE_EQ(op->outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal to 1.", op->Name(),
op->outputs.size()));
OpDesc q_desc;
q_desc.SetType("quantize");
std::vector<Node*> quantize_out_nodes(inputs.size());
std::vector<std::string> quantize_out_node_names(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("Shift", 0.0f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
auto physical_xput_node = xputs_map[physical_xput_name];
link_nodes(physical_xput_node, quant_op, quant_x_node);
counter++;
linked_xputs.push_back(physical_xput_name);
}
set_edge(logical_xput_name, quant_xput_names);
}
}
int get_counter() const { return counter; }
virtual ~Quanter() = default;
protected:
Graph& graph;
ir::Node* const op;
std::map<std::string, ir::Node*> xputs_map;
const VariableNameMap& op_xputs;
int counter = 0;
Quanter(Graph& graph, ir::Node* const op, const VariableNameMap& op_xputs)
: graph(graph), op(op), op_xputs(op_xputs){};
virtual bool IsNotPermittedOpType() const = 0;
virtual bool IsNotPermittedName(const std::string& input_name) const = 0;
virtual std::string get_op_type() const = 0;
virtual std::string get_op_edge() const = 0;
virtual void link_nodes(ir::Node* const physical_xput_node,
ir::Node* const quant_op,
ir::Node* const quant_x_node) = 0;
virtual void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) = 0;
bool IsAlreadyLinked(const std::vector<std::string>& node_names,
const std::string& node_name) const {
return std::find(node_names.begin(), node_names.end(), node_name) !=
node_names.end();
}
virtual ir::Node* create_quant_op(const std::string& input_name,
const std::string& output_name) const {
OpDesc op_desc;
op_desc.SetType(get_op_type());
op_desc.SetInput("Input", std::vector<std::string>({input_name}));
op_desc.SetOutput("Output", std::vector<std::string>({output_name}));
op_desc.SetAttr("Scale", 1.f);
op_desc.SetAttr("Shift", 0.0f);
op_desc.SetAttr("bfloat16", true);
op_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
return graph.CreateOpNode(&op_desc); // OpDesc will be copied.
}
UnlinkNodes(inputs[i], op);
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
quantize_counter++;
void UnlinkNodes(ir::Node* a, ir::Node* b) const {
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
};
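The Quanter base class is a textbook template method: the shared graph-rewiring loop (AddQuantOps) lives in the base, while subclasses supply the quantize- or dequantize-specific decisions through virtual hooks. A standalone sketch of that structure, with the graph-mutation details stubbed out (the class names mirror the real ones, the bodies are illustrative):

#include <iostream>
#include <string>

class Quanter {
 public:
  // Shared, non-virtual algorithm; defers decisions to the hooks below.
  void AddQuantOps() {
    if (IsNotPermittedOpType()) return;
    std::cout << "insert " << get_op_type() << " op\n";
    ++counter_;
  }
  int get_counter() const { return counter_; }
  virtual ~Quanter() = default;

 protected:
  virtual bool IsNotPermittedOpType() const = 0;  // hook
  virtual std::string get_op_type() const = 0;    // hook
  int counter_ = 0;
};

class Quantizer final : public Quanter {
 protected:
  bool IsNotPermittedOpType() const override { return false; }
  std::string get_op_type() const override { return "quantize"; }
};

class DeQuantizer final : public Quanter {
 protected:
  // The real pass skips prior_box here; stubbed out in this sketch.
  bool IsNotPermittedOpType() const override { return false; }
  std::string get_op_type() const override { return "dequantize"; }
};

int main() {
  Quantizer q;
  q.AddQuantOps();
  DeQuantizer d;
  d.AddQuantOps();
  std::cout << q.get_counter() + d.get_counter() << " ops added\n";  // 2
}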
op->Op()->SetInput("X", quantize_out_node_names);
}
class Quantizer final : public Quanter {
public:
Quantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Inputs()) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(
inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.", op->Name(),
inputs.size()));
// Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find
// the pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"};
duplicated_inputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_inputs);
AddQuantizes(g, op, quantize_counter);
for (auto input : inputs) xputs_map[input->Name()] = input;
};
gpd(graph, handler);
}
// Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
bfloat16_ops();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") {
AddQuantize(g, op, op_in, quantize_counter);
protected:
bool IsNotPermittedOpType() const override { return false; }
// Checking whether a reorder from FP32 to BF16
// should be added before the input to the operator
bool IsNotPermittedName(const std::string& input_name) const override {
// Only the inputs listed in "permitted_names"
// require quantization before the bfloat16 operator.
// Other inputs, such as Filter and Bias, are reordered in the kernel.
const std::vector<std::string> permitted_names = {"X", "Y", "Input",
"ResidualData"};
return std::none_of(
permitted_names.begin(), permitted_names.end(),
[&input_name](const std::string& name) { return name == input_name; });
}
};
gpd(graph, handler);
}
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}
std::string get_op_type() const override { return "quantize"; };
std::string get_op_edge() const override { return "out"; };
void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
if (op->Op()->Type() == "prior_box") return;
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);
if (output_names.empty()) return;
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op,
ir::Node* const quant_x_node) override {
UnlinkNodes(physical_xput_node, op);
IR_NODE_LINK_TO(physical_xput_node, quant_op);
IR_NODE_LINK_TO(quant_op, quant_x_node);
IR_NODE_LINK_TO(quant_x_node, op);
}
void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) override {
op->Op()->SetInput(logical_xput_name, quant_xput_names);
}
};
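The std::none_of rewrite of IsNotPermittedName is equivalent to the old std::find membership test, but states the intent directly: the input name is rejected when it matches none of the permitted names. A self-contained sketch of the same predicate:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

bool IsNotPermittedName(const std::string& input_name) {
  const std::vector<std::string> permitted_names = {"X", "Y", "Input",
                                                    "ResidualData"};
  // True when input_name matches none of the permitted names.
  return std::none_of(
      permitted_names.begin(), permitted_names.end(),
      [&input_name](const std::string& name) { return name == input_name; });
}

int main() {
  std::cout << IsNotPermittedName("Bias") << "\n";  // 1: skip quantization
  std::cout << IsNotPermittedName("X") << "\n";     // 0: quantize this input
}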
void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
class DeQuantizer final : public Quanter {
public:
DeQuantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Outputs()) {
auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1,
PADDLE_ENFORCE_GE(
outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(), outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));
"OP(%s)'s outputs(%d) must be equal or greater than 1.", op->Name(),
outputs.size()));
OpDesc deq_desc;
deq_desc.SetType("dequantize");
for (auto output : outputs) xputs_map[output->Name()] = output;
};
std::vector<Node*> dequantize_in_nodes(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());
protected:
bool IsNotPermittedOpType() const override {
// Prior_box operator output is always FP32 so no dequantization is needed.
return op->Op()->Type() == "prior_box";
}
for (size_t i = 0; i < outputs.size(); i++) {
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name();
// Checking whether a reorder from BF16 to FP32
// should be added after the output of the operator
bool IsNotPermittedName(const std::string& output_name) const override {
// XShape is an output of the transpose2 and reshape2 operators, used to store
// the shape and lod of X, so this output does not need to be dequantized.
return (output_name == "XShape");
}
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));
std::string get_op_type() const override { return "dequantize"; };
std::string get_op_edge() const override { return "in"; };
deq_desc.SetAttr("Scale", 1.f);
deq_desc.SetAttr("Shift", 0.0f);
deq_desc.SetAttr("bfloat16", true);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);
void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op,
ir::Node* const quant_x_node) override {
UnlinkNodes(op, physical_xput_node);
IR_NODE_LINK_TO(quant_op, physical_xput_node);
IR_NODE_LINK_TO(quant_x_node, quant_op);
IR_NODE_LINK_TO(op, quant_x_node);
}
dequantize_counter++;
void set_edge(const std::string& logical_xput_name,
const std::vector<std::string>& quant_xput_names) override {
op->Op()->SetOutput(logical_xput_name, quant_xput_names);
}
op->Op()->SetOutput("Out", dequantize_in_node_names);
ir::Node* create_quant_op(const std::string& input_name,
const std::string& output_name) const override {
return Quanter::create_quant_op(output_name, input_name);
}
};
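Note how DeQuantizer::create_quant_op reuses the base helper with the two names swapped: the base wires op(Input = first argument, Output = second), so reversing the arguments reverses the direction of the inserted edge. A toy sketch of the swap (FakeOp and the variable names are illustrative):

#include <iostream>
#include <string>

struct FakeOp { std::string input, output; };

// Base helper: Input = first argument, Output = second.
FakeOp create_op(const std::string& input_name, const std::string& output_name) {
  return {input_name, output_name};
}

// Dequantizer override: swap the arguments, as in the real override.
FakeOp create_dequant_op(const std::string& input_name,
                         const std::string& output_name) {
  return create_op(output_name, input_name);
}

int main() {
  FakeOp q = create_op("x_fp32", "x_bf16");           // quantize: fp32 -> bf16
  FakeOp dq = create_dequant_op("y_fp32", "y_bf16");  // dequantize: bf16 -> fp32
  std::cout << q.input << "->" << q.output << "  "
            << dq.input << "->" << dq.output << "\n";
  // prints: x_fp32->x_bf16  y_bf16->y_fp32
}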
}
using string::PrettyLogDetail;
// Operators like split have a single output name Out, which actually
// consists of multiple outputs. Such operators require a different way to find
// the pattern and add dequantize ops.
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
int quantize_counter = 0;
int dequantize_counter = 0;
// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
patterns::Bloat16Ops Bloat16Ops{gpd.mutable_pattern(), "Bloat16Ops"};
Bloat16Ops();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "split") {
AddDequantize(g, op, op_out, dequantize_counter);
}
Graph* graph) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, Bloat16Ops);
Quantizer quantizer(graph, op);
quantizer.AddQuantOps();
quantize_counter += quantizer.get_counter();
DeQuantizer dequantizer(graph, op);
dequantizer.AddQuantOps();
dequantize_counter += dequantizer.get_counter();
};
gpd(graph, handler);
}
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
int dequantize_counter = 0;
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter);
}
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
SetInputDataType(graph);
SetOutputDataType(graph);
}
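After the refactor, a single pattern and a single handler drive both rewrites: the detector invokes the lambda once per matching bfloat16 op, and the lambda runs a Quantizer and a DeQuantizer on the same node while accumulating counters captured by reference. A minimal standalone sketch of that detector-and-handler flow (FakeOp and ForEachMatch are stand-ins, not Paddle's GraphPatternDetector):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct FakeOp {
  std::string type;
  std::string dtype;
};

// Invoke handler once per op satisfying the predicate, like the
// pattern detector invoking its registered handler per subgraph match.
void ForEachMatch(std::vector<FakeOp>& graph,
                  const std::function<bool(const FakeOp&)>& pred,
                  const std::function<void(FakeOp&)>& handler) {
  for (auto& op : graph)
    if (pred(op)) handler(op);
}

int main() {
  std::vector<FakeOp> graph = {
      {"conv2d", "bfloat16"}, {"pool2d", "float32"}, {"matmul", "bfloat16"}};
  int quantize_counter = 0;
  int dequantize_counter = 0;
  ForEachMatch(
      graph, [](const FakeOp& op) { return op.dtype == "bfloat16"; },
      [&](FakeOp& op) {
        (void)op;  // both rewrites run on the same matched node
        ++quantize_counter;    // stand-in for Quantizer(...).AddQuantOps()
        ++dequantize_counter;  // stand-in for DeQuantizer(...).AddQuantOps()
      });
  std::cout << "added " << quantize_counter << " quantize and "
            << dequantize_counter << " dequantize ops\n";  // 2 and 2
}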
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -24,8 +24,6 @@ namespace ir {
class CPUBFloat16Pass : public Pass {
protected:
void SetInputDataType(ir::Graph* graph) const;
void SetOutputDataType(ir::Graph* graph) const;
void ApplyImpl(ir::Graph* graph) const override;
};
@@ -27,8 +27,16 @@ namespace ir {
using string::PrettyLogDetail;
void CPUBfloat16PlacementPass::SetMkldnnDataType(
ir::Graph* graph, int* bfloat16_operators) const {
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
bfloat16_operators += SetMkldnnDataType(graph);
bfloat16_operators -= RemoveOrphanedOperators(graph);
bfloat16_operators -= RemoveUnsupportedOperators(graph);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
}
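The placement pass drops the int* out-parameters in favor of return values: each helper reports how many operators it touched, and ApplyImpl adds and subtracts the counts. A compilable sketch of the new style, with stand-in counts:

#include <iostream>

// Stand-in helpers: each returns the number of operators it processed,
// mirroring the new int SetMkldnnDataType(ir::Graph*) signatures.
int SetMkldnnDataType() { return 5; }
int RemoveOrphanedOperators() { return 1; }
int RemoveUnsupportedOperators() { return 1; }

int main() {
  int bfloat16_operators = 0;
  bfloat16_operators += SetMkldnnDataType();
  bfloat16_operators -= RemoveOrphanedOperators();
  bfloat16_operators -= RemoveUnsupportedOperators();
  std::cout << "--- marked " << bfloat16_operators
            << " operators to bfloat16\n";  // --- marked 3 operators to bfloat16
}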
int CPUBfloat16PlacementPass::SetMkldnnDataType(ir::Graph* graph) const {
const auto& op_types_list =
Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types");
// set mkldnn_data_type to bfloat16 to all operators that are in
@@ -39,6 +47,7 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
"bfloat16_placement"};
bfloat16_placement_pattern(op_types_list);
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
@@ -50,58 +59,58 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
VLOG(4) << "--- marked " << op->Op()->Type()
<< " operator to bfloat16 ";
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
(*bfloat16_operators)++;
detected_operators++;
}
};
gpd(graph, handler);
return detected_operators;
}
void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
ir::Graph* graph, int* bfloat16_operators) const {
int CPUBfloat16PlacementPass::RemoveOrphanedOperators(ir::Graph* graph) const {
// Find orphaned bfloat16 operators, i.e. those between two float32 operators,
// and revert their mkldnn_data_type attr to float32.
GraphPatternDetector gpd;
patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(),
"orphaned_bfloat16"};
orphaned_bfloat16_pattern();
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern);
op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
bfloat16_operators--;
VLOG(4) << "--- demarked " << op->Op()->Type() << " operator to bfloat16 ";
detected_operators++;
};
gpd(graph, handler);
return detected_operators;
}
void CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
ir::Graph* graph, int* bfloat16_operators) const {
int CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
ir::Graph* graph) const {
// Currently the quantize op supports FP32 only, so look for bfloat16
// operators whose input type is not FP32.
GraphPatternDetector gpd;
patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{
gpd.mutable_pattern(), "unsupported_bfloat16"};
unsupported_bfloat16_pattern();
int detected_operators = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern);
GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern);
if ((prev_out->Var()->GetDataType() != proto::VarType::FP32)) {
op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
bfloat16_operators--;
VLOG(4) << "--- demarked " << op->Op()->Type()
<< " operator to bfloat16 ";
detected_operators++;
}
};
gpd(graph, handler);
}
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
SetMkldnnDataType(graph, &bfloat16_operators);
RemoveOrphanedOperators(graph, &bfloat16_operators);
RemoveUnsupportedOperators(graph, &bfloat16_operators);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
return detected_operators;
}
} // namespace ir
@@ -26,14 +26,11 @@ namespace ir {
*/
class CPUBfloat16PlacementPass : public Pass {
protected:
void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveUnsupportedOperators(ir::Graph* graph,
int* bfloat16_operators) const;
void ApplyImpl(ir::Graph* graph) const override;
int SetMkldnnDataType(ir::Graph* graph) const;
int RemoveOrphanedOperators(ir::Graph* graph) const;
int RemoveUnsupportedOperators(ir::Graph* graph) const;
};
} // namespace ir