Unverified · Commit 1483ea23 authored by joanna.wozna.intel, committed by GitHub

Add bfloat16 passes (#26999)

Parent 6947a58a
@@ -102,6 +102,8 @@ if(WITH_MKLDNN)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
@@ -162,4 +164,6 @@ endif()
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
endif ()
@@ -1892,6 +1892,82 @@ PDNode *patterns::QuantizePlacement::operator()(
return op;
}
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>();
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
return op;
}
PDNode *patterns::OrphanedBfloat16::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
});
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
});
prev_op->LinksTo({prev_out});
op->LinksFrom({prev_out}).LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
}
PDNode *patterns::LastBfloat16Ops::operator()() {
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
op->LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
}
PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
auto *op_in = pattern->NewNode(op_in_repr())->AsOutput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
prev_op->LinksTo({op_in});
op->LinksFrom({op_in});
return op;
}
PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs",
...
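The three new patterns above are what the bfloat16 passes anchor on; a sketch of the subgraphs they match, read directly off the assertions:

// Bfloat16Placement : op                         (type in bfloat16_enabled_op_types)
// OrphanedBfloat16  : prev_op(fp32) -> prev_out -> op(bf16) -> op_out -> next_op(fp32)
// LastBfloat16Ops   : op(bf16) -> op_out -> next_op(not bf16)
// FirstBfloat16Ops  : prev_op(not bf16) -> op_in -> op(bf16)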
@@ -1129,6 +1129,47 @@ struct QuantizePlacement : public PatternBase {
PATTERN_DECL_NODE(op);
};
struct Bfloat16Placement : public PatternBase {
Bfloat16Placement(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bfloat16_placement") {}
PDNode* operator()(
const std::unordered_set<std::string>& bfloat16_enabled_op_types);
PATTERN_DECL_NODE(op);
};
struct OrphanedBfloat16 : public PatternBase {
OrphanedBfloat16(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "orphaned_bfloat16") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(prev_out);
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};
struct LastBfloat16Ops : public PatternBase {
LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};
struct FirstBfloat16Ops : public PatternBase {
FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
// Pattern used to enforce in-place computation for DNNL ops that support
// it: softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
...
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
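// Removes the edge a -> b: erases b from a's output list and a from b's
// input list.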
void UnlinkNodes(ir::Node* a, ir::Node* b) {
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
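// Inserts a quantize op wherever a non-bfloat16 op feeds a bfloat16 op
// (FirstBfloat16Ops pattern). conv2d is skipped because it can consume fp32
// input directly, as is any op already fed by a quantize op.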
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
bfloat16_ops();
int quantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "conv2d" && prev_op->Op()->Type() != "quantize") {
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
// create a quantize op node
OpDesc q_desc;
q_desc.SetType("quantize");
q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", Has("data_layout")
? Get<std::string>("data_layout")
: "NCHW");
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
std::string op_input_name;
for (auto name : op->Op()->InputNames()) {
for (auto input_name : op->Op()->Input(name)) {
if (input_name == op_in->Name()) op_input_name = name;
}
}
PADDLE_ENFORCE_NE(
op_input_name.empty(), true,
platform::errors::NotFound(
"The operator's input matching the preceding op's output was not found."));
op->Op()->SetInput(op_input_name,
std::vector<std::string>({quantize_out_node->Name()}));
UnlinkNodes(op_in, op);
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
quantize_counter++;
}
};
gpd(graph, handler);
PrettyLogDetail("--- added %d quantize op before bfloat16 op",
quantize_counter);
}
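// Handles trailing bfloat16 ops (LastBfloat16Ops pattern): sets
// force_fp32_output on ops that support it (unless a residual connection is
// fused), otherwise inserts a dequantize op after the op's output.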
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int force_fp32_counter = 0, dequantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
if ((op->Op()->HasAttr("force_fp32_output") ||
op->Op()->HasProtoAttr("force_fp32_output")) &&
!op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
op->Op()->SetAttr("force_fp32_output", true);
force_fp32_counter++;
} else if (op->Op()->Type() != "prior_box") {
// Create dequantize input variable
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
// create a dequantize op node for output.
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc);
std::string op_output_name;
for (auto name : op->Op()->OutputNames()) {
for (auto output_name : op->Op()->Output(name)) {
if (output_name == op_out->Name()) op_output_name = name;
}
}
PADDLE_ENFORCE_NE(
op_output_name.empty(), true,
platform::errors::NotFound(
"The operator's output feeding the following op was not found."));
op->Op()->SetOutput(op_output_name, std::vector<std::string>(
{dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
};
gpd(graph, handler);
PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output",
dequantize_counter, force_fp32_counter);
}
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
SetInputDataType(graph);
SetOutputDataType(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_bfloat16_pass, paddle::framework::ir::CPUBFloat16Pass);
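For reference, a minimal sketch of applying the registered pass standalone, mirroring what the tester below does (graph is assumed to be a std::unique_ptr<ir::Graph> built from a ProgramDesc):

// Fetch the registered pass and rewrite the graph in place.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get("cpu_bfloat16_pass");
graph.reset(pass->Apply(graph.release()));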
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class CPUBFloat16Pass : public Pass {
protected:
void SetInputDataType(ir::Graph* graph) const;
void SetOutputDataType(ir::Graph* graph) const;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::string& mkldnn_data_type = "float32",
const bool force_fp32_output = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("force_fp32_output", force_fp32_output);
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
type == "dropout") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "fc") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "matmul" || type == "elementwise_add") {
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
}
}
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
const std::initializer_list<std::string> variable_names,
int* original_nodes_num, int* current_nodes_num) {
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
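// Note: the pass is applied twice; original_nodes_num is captured after the
// first application, so added_nodes_count in MainTest measures only the
// nodes introduced by the second application.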
graph->reset(pass->Apply(graph->release()));
*original_nodes_num = (*graph)->Nodes().size();
(*graph).reset(pass->Apply((*graph).release()));
*current_nodes_num = (*graph)->Nodes().size();
}
static const std::initializer_list<std::string> variable_names{
"z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
ProgramDesc BuildProgramDesc(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout1", {"z"}, {"a"}, use_mkldnn, "float32");
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, "bfloat16");
SetOp(&prog, "pool2d", "Pool1", {"b"}, {"c"}, use_mkldnn, "bfloat16");
SetOp(&prog, "conv2d", "Conv1", {"c"}, {"d"}, use_mkldnn, "bfloat16");
SetOp(&prog, "dropout", "Dropout2", {"d"}, {"e"}, use_mkldnn, "float32");
SetOp(&prog, "transpose2", "Transpose1", {"e"}, {"f"}, use_mkldnn,
"bfloat16");
SetOp(&prog, "reshape2", "Reshape1", {"f"}, {"g"}, use_mkldnn, "bfloat16");
SetOp(&prog, "concat", "Concat1", {"g"}, {"h"}, use_mkldnn, "bfloat16");
SetOp(&prog, "dropout", "Dropout3", {"h"}, {"i"}, use_mkldnn, "float32");
return prog;
}
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int transpose_count, int quant_count, int dequant_count,
int added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int conv2d_nodes_count = 0;
int pool2d_nodes_count = 0;
int transpose2_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
conv2d_nodes_count++;
} else if (op->Type() == "pool2d") {
pool2d_nodes_count++;
} else if (op->Type() == "transpose2") {
transpose2_nodes_count++;
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(conv2d_nodes_count, conv_count);
EXPECT_EQ(pool2d_nodes_count, pool_count);
EXPECT_EQ(transpose2_nodes_count, transpose_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuBfloat16Pass, quantize) {
bool use_mkldnn = true;
// the second pass application adds 1 dequantize op + its input variable
int added_nodes = 2;
MainTest(BuildProgramDesc(use_mkldnn), 2, 1, 1, 1, 2, added_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_bfloat16_pass);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void CPUBfloat16PlacementPass::SetMkldnnDataType(
ir::Graph* graph, int* bfloat16_operators) const {
const auto& op_types_list =
Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types");
// Set mkldnn_data_type to bfloat16 for all operators that are listed in
// bfloat16_enabled_op_types or are matched by the Bfloat16Placement
// pattern
GraphPatternDetector gpd;
patterns::Bfloat16Placement bfloat16_placement_pattern{gpd.mutable_pattern(),
"bfloat16_placement"};
bfloat16_placement_pattern(op_types_list);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);
if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
(*bfloat16_operators)++;
}
};
gpd(graph, handler);
}
void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
ir::Graph* graph, int* bfloat16_operators) const {
// Find bfloat16 operators orphaned between two float32 operators and
// revert their mkldnn_data_type attribute to float32
GraphPatternDetector gpd;
patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(),
"orphaned_bfloat16"};
orphaned_bfloat16_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern);
op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
(*bfloat16_operators)--;
};
gpd(graph, handler);
}
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
SetMkldnnDataType(graph, &bfloat16_operators);
RemoveOrphanedOperators(graph, &bfloat16_operators);
PrettyLogDetail("--- marked %d operators to bfloat16", bfloat16_operators);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_bfloat16_placement_pass,
paddle::framework::ir::CPUBfloat16PlacementPass)
// a set of operator type names with bfloat16 support ("conv2d" etc.)
// the second param is the default value for this set
.DefaultPassAttr("bfloat16_enabled_op_types",
new std::unordered_set<std::string>());
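A minimal sketch of overriding that default (empty) op-type set when fetching the pass, as the tester below also does:

auto pass = paddle::framework::ir::PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
pass->Set("bfloat16_enabled_op_types",
          new std::unordered_set<std::string>({"conv2d", "pool2d"}));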
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * Specifies which operators should run in bfloat16.
 */
class CPUBfloat16PlacementPass : public Pass {
protected:
void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveOrphanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::string& mkldnn_data_type = "float32") {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
if (type == "conv2d") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "concat") {
op->SetAttr("axis", 1);
op->SetInput("X", {inputs[0], inputs[1]});
} else if (type == "pool2d") {
op->SetInput("X", {inputs[0]});
} else {
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
// operator mkldnn_data_type
// ---------------------------------------
// (a,b)->concat->c float32
// c->conv->f float32
// f->relu->g float32
// g->pool->h float32
// h->conv->k float32
// k->pool->l float32
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l"})) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
SetOp(&prog, "conv2d", "conv1", {"c"}, {"f"});
SetOp(&prog, "relu", "relu1", {"f"}, {"g"});
SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"});
SetOp(&prog, "conv2d", "conv2", {"h"}, {"k"});
SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"});
return prog;
}
void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
pass->Set("bfloat16_enabled_op_types",
new std::unordered_set<std::string>(bfloat16_enabled_op_types));
graph.reset(pass->Apply(graph.release()));
unsigned bfloat16_data_type_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
if (platform::HasOpBFLOAT16DataType(node->Op())) {
++bfloat16_data_type_count;
}
}
}
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}
void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
graph.reset(pass->Apply(graph.release()));
unsigned bfloat16_data_type_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
if (platform::HasOpBFLOAT16DataType(node->Op())) {
++bfloat16_data_type_count;
}
}
}
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}
TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "relu", "concat"}, 6);
}
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
// 2 conv2d + 2 pool2d - 1 orphaned conv2d
MainTest({"conv2d", "pool2d"}, 3);
}
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(0); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_bfloat16_placement_pass);
@@ -231,6 +231,10 @@ void CpuPassStrategy::EnableMkldnnQuantizer() {
void CpuPassStrategy::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_bfloat16_) {
passes_.push_back("cpu_bfloat16_placement_pass");
passes_.push_back("cpu_bfloat16_pass");
}
use_mkldnn_bfloat16_ = true;
#else
use_mkldnn_bfloat16_ = false;
...
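A hedged sketch of enabling these passes from the C++ inference API, assuming AnalysisConfig exposes EnableMkldnnBfloat16() and forwards it to this pass strategy (model_dir is a placeholder path):

paddle::AnalysisConfig config(model_dir);
config.EnableMKLDNN();          // required so the MKL-DNN CPU pass strategy is used
config.EnableMkldnnBfloat16();  // assumed to call CpuPassStrategy::EnableMkldnnBfloat16()
auto predictor = paddle::CreatePaddlePredictor(config);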
@@ -48,6 +48,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
bool is_negative = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data,
is_negative, ctx.OutputName("Output"));
@@ -74,7 +75,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
src_md, engine, to_void_cast<T>(input_data));
std::shared_ptr<mkldnn::memory::desc> dst_md;
if (bfloat16) {
platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
} else if (is_negative) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
} else {
@@ -96,7 +100,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
auto place = ctx.GetPlace();
if (bfloat16) {
dst_memory->set_data_handle(
output->mutable_data<paddle::platform::bfloat16>(place));
} else if (is_negative) {
dst_memory->set_data_handle(output->mutable_data<int8_t>(place));
} else {
dst_memory->set_data_handle(output->mutable_data<uint8_t>(place));
...
@@ -40,6 +40,8 @@ void QuantOpMaker::Make() {
AddAttr<std::string>("output_format",
"Convert format to NHWC or NCHW during quantization.")
.SetDefault("NHWC");
AddAttr<bool>("bfloat16", "(bool, default false) Convert to bfloat16")
.SetDefault(false);
AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC");
}
...
@@ -443,6 +443,13 @@ inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
op->GetAttrIfExists<bool>("use_quantizer"));
}
inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
}  // namespace platform
...
@@ -184,6 +184,7 @@ void BindVarDsec(pybind11::module *m) {
.value("FP16", pd::proto::VarType::FP16)
.value("FP32", pd::proto::VarType::FP32)
.value("FP64", pd::proto::VarType::FP64)
.value("BF16", pd::proto::VarType::BF16)
.value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR) .value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR)
.value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS) .value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS)
.value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH) .value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH)
......