提交 d3a96632 编写于 作者: L lidanqing 提交者: Tao Luo

Add fc-dequantize squash in cpu_quantize_squash_pass for ernie model (#21714)

* fc-dequantize squash
test=develop

* change according to reviews
test=develop

* change PADDLE_ENFORCE
test=develop

* add second test when fc-dequant do not fuse
test=develop

* change all related PADDLE_ENFORCE
test=develop
上级 1fd1f06f
......@@ -41,9 +41,10 @@ size_t PDPattern::id_ = 0UL;
PDNode *PDPattern::NewNode(const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0UL,
"PDNode's name should be unique, get duplicate [%s]",
name);
PADDLE_ENFORCE_EQ(
node_map_.count(name), 0UL,
platform::errors::PreconditionNotMet(
"PDNode's name should be unique, get duplicate [%s]", name));
}
nodes_.emplace_back(new PDNode(this, name));
......@@ -54,9 +55,10 @@ PDNode *PDPattern::NewNode(const std::string &name) {
PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0UL,
"PDNode's name should be unique, get duplicate [%s]",
name);
PADDLE_ENFORCE_EQ(
node_map_.count(name), 0UL,
platform::errors::PreconditionNotMet(
"PDNode's name should be unique, get duplicate [%s]", name));
}
nodes_.emplace_back(new PDNode(std::move(teller), this, name));
......@@ -75,8 +77,10 @@ PDNode *PDPattern::RetrieveNode(const std::string &id) const {
}
void PDPattern::AddEdge(PDNode *a, PDNode *b) {
PADDLE_ENFORCE(a);
PADDLE_ENFORCE(b);
PADDLE_ENFORCE_NOT_NULL(
a, platform::errors::NotFound("PDNode %s is not found.", a->name()));
PADDLE_ENFORCE_NOT_NULL(
b, platform::errors::NotFound("PDNode %s is not found.", b->name()));
PADDLE_ENFORCE_NE(a, b, platform::errors::PermissionDenied(
"Cannot connect the same node in the graph."));
edges_.emplace_back(a, b);
......@@ -610,15 +614,24 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
}
// Returns true iff `var` is exactly the nth tensor listed under input slot
// `argument` of `op`.
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
  // Drop the superseded PADDLE_ENFORCE(var->IsVar()) / PADDLE_ENFORCE(op->IsOp())
  // calls: the PADDLE_ENFORCE_EQ forms below carry proper error payloads and
  // make the old, message-less checks redundant.
  PADDLE_ENFORCE_EQ(
      var->IsVar(), true,
      platform::errors::InvalidArgument(
          "First parameter of function IsNthInput must be Node::Var"));
  PADDLE_ENFORCE_EQ(
      op->IsOp(), true,
      platform::errors::InvalidArgument(
          "Second parameter of function IsNthInput must be Node::Op"));
  // A missing argument slot or an out-of-range index is simply "no match",
  // not an error.
  if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
    return false;
  return var->Name() == op->Op()->Input(argument)[nth];
}
bool HasInput(Node *op, const std::string &argument) {
PADDLE_ENFORCE(op->IsOp());
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
platform::errors::InvalidArgument(
"First parameter of function HasInput must be Node::Op"));
auto const &names = op->Op()->InputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
......@@ -626,8 +639,14 @@ bool HasInput(Node *op, const std::string &argument) {
}
// Returns true iff `var` is exactly the nth tensor listed under output slot
// `argument` of `op`.
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
  // Drop the superseded PADDLE_ENFORCE(var->IsVar()) / PADDLE_ENFORCE(op->IsOp())
  // calls: the PADDLE_ENFORCE_EQ forms below carry proper error payloads and
  // make the old, message-less checks redundant.
  PADDLE_ENFORCE_EQ(
      var->IsVar(), true,
      platform::errors::InvalidArgument(
          "First parameter of function IsNthOutput must be Node::Var"));
  PADDLE_ENFORCE_EQ(
      op->IsOp(), true,
      platform::errors::InvalidArgument(
          "Second parameter of function IsNthOutput must be Node::Op"));
  // Out-of-range index means "no match", not an error.
  if (op->Op()->Output(argument).size() <= nth) return false;
  return var->Name() == op->Op()->Output(argument)[nth];
}
......@@ -1344,6 +1363,24 @@ PDNode *patterns::ConvDequant::operator()() {
return dequant_out;
}
// Builds the fc -> dequantize pattern and returns the dequantize output node.
PDNode *patterns::FcDequant::operator()() {
  // Operator nodes of the chain.
  auto fc = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
  auto fc_output =
      pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
  auto dequant =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
  auto dequant_output = pattern->NewNode(dequant_out_repr())
                            ->AsOutput()
                            ->assert_is_op_output("dequantize", "Output");

  // fc produces fc_output, which is consumed by the dequantize op.
  fc->LinksTo({fc_output});
  dequant->LinksFrom({fc_output}).LinksTo({dequant_output});

  return dequant_output;
}
PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
......
......@@ -153,7 +153,9 @@ struct PDNode {
pattern_(pattern),
name_(name),
type_(type) {
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
PADDLE_ENFORCE_NOT_NULL(
teller_,
platform::errors::NotFound("invalid teller is set, teller is null"));
}
PDNode(PDNode&& other) = default;
......@@ -371,10 +373,13 @@ static std::string UniqueKey(const std::string& repr) {
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
// Fetches the Node* matched for PDNode `arg` from `subgraph` into a new local
// `var`, enforcing that the pattern node was matched and is non-null.
// NOTE: the scrape interleaved the pre-refactor PADDLE_ENFORCE lines with
// their PADDLE_ENFORCE_NE / PADDLE_ENFORCE_NOT_NULL replacements; only the
// new forms are kept here.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat)                               \
  PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), 0UL,                        \
                    platform::errors::NotFound("Node not found for PDNode %s", \
                                               pat.arg##_repr()));             \
  Node* var = subgraph.at(pat.arg##_n());                                      \
  PADDLE_ENFORCE_NOT_NULL(                                                     \
      var, platform::errors::NotFound("node %s not exists in the sub-graph",   \
                                      #arg));
// The base class of all the patterns.
struct PatternBase {
......@@ -844,6 +849,20 @@ struct ConvDequant : public PatternBase {
PATTERN_DECL_NODE(dequant_out);
};
// Fc + Dequant
// Pattern: fc_op -> fc_out -> dequant_op -> dequant_out.
// Matches an fc operator whose output feeds a dequantize operator, so the
// squash pass can fold the dequantize away (see FcDequantSquash).
struct FcDequant : public PatternBase {
  FcDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "fc_dequant") {}

  // Registers the pattern's nodes/edges and returns the dequantize output.
  PDNode* operator()();

  PATTERN_DECL_NODE(fc_op);
  PATTERN_DECL_NODE(fc_out);
  PATTERN_DECL_NODE(dequant_op);
  PATTERN_DECL_NODE(dequant_out);
};
// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
......
......@@ -71,8 +71,9 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
auto* next_op_desc = next_op->Op();
float dequant_scale = boost::get<float>(dequant_op->Op()->GetAttr("Scale"));
float quant_scale = boost::get<float>(quant_op->Op()->GetAttr("Scale"));
PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) !=
nodes_keep_counter->end());
PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
platform::errors::NotFound("The dequant output node is not found"));
// check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
......@@ -195,14 +196,50 @@ void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
found_conv_dequant_squash_count);
}
// squash fc with dequant
// For every fc -> dequantize pair in which the dequantize op is the only
// consumer of the fc output, make fc emit fp32 directly and remove the
// dequantize op and its input var.
void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"};
  fc_dequant_pattern();

  int found_fc_dequant_squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash fc-dequant ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern);

    // Squash only when dequantize is the sole consumer of fc's output;
    // if other ops also read fc_out they still need the quantized tensor.
    if (fc_out->outputs.size() == 1) {
      // Make fc produce fp32 directly and rewire its output to the
      // dequantize output, so fc_out and dequant_op become dead nodes.
      fc_op->Op()->SetAttr("force_fp32_output", true);
      fc_op->Op()->SetOutput("Out",
                             std::vector<std::string>({dequant_out->Name()}));
      IR_NODE_LINK_TO(fc_op, dequant_out);
      GraphSafeRemoveNodes(graph, {fc_out, dequant_op});
      found_fc_dequant_squash_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_fc_dequant_squash_count);
  PrettyLogDetail("--- squashed %d dequant with fcs",
                  found_fc_dequant_squash_count);
}
// Entry point of the pass: runs each squash step over `graph` in order.
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
  // Drop the superseded bare PADDLE_ENFORCE(graph): the NOT_NULL form below
  // carries a proper error payload, so the old check was redundant.
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::NotFound(
          "The graph in function CPUQuantizeSquashPass::ApplyImpl is null"));
  FusePassBase::Init("cpu_quantize_squash_pass", graph);

  // Dequant outputs may be shared; count how many consumers each one keeps
  // before DequantQuantSquash decides which dequantize ops can be removed.
  std::unordered_map<const Node*, int> nodes_keep_counter;
  FindNodesToKeep(graph, &nodes_keep_counter);
  DequantQuantSquash(graph, &nodes_keep_counter);
  ConvDequantSquash(graph);
  FcDequantSquash(graph);
}
} // namespace ir
......
......@@ -60,6 +60,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void ConvDequantSquash(Graph* graph) const;
/*
* Squash fc with dequant when dequant is the next op after fc
*/
void FcDequantSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};
......
......@@ -50,6 +50,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
} else if (type == "fc") {
op->SetInput("Input", {inputs[0]});
PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
platform::errors::InvalidArgument(
"The fc inputs should contain input and weights, but "
"now the size of inputs is %d",
inputs.size()));
op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs);
}
}
......@@ -176,6 +185,36 @@ ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}
// Builds the chain: a->fc->b, b->Dequant1->c, c->Concat1->d.
// Dequant1 is the only consumer of the fc output, so the squash should fire.
ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;
  // Declare every named variable in the main block up front.
  for (auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }
  SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
  return prog;
}
// a->fc->b
// b->Dequant1->c
// b->concat->d
// NOTE(review): despite the "Fc" suffix, the second consumer of b is a concat
// op; what matters for the test is only that b has two consumers, so the
// fc-dequant squash must NOT fire. Consider renaming once callers allow it.
ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out,
                                        float scale) {
  ProgramDesc prog;
  // Declare every named variable in the main block up front.
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }
  SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"b"}, {"d"}, use_mkldnn);
  return prog;
}
// a->Conv1->b
// b->Dequant1(Scale1)->c
// b->Conv2->d
......@@ -261,6 +300,23 @@ void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
}
}
// Runs the squash pass on `prog` and checks that every op of type `op_type`
// has force_fp32_output equal to `target_is_force_fp32_output`.
// (The old header comment "check requant_op scales" was a stale copy-paste
// from CheckRequantScalesTest.)
void IsForceFp32OutputTest(const ProgramDesc& prog, const std::string& op_type,
                           bool target_is_force_fp32_output) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  PrepareGraph(&graph, prog);
  RegisterPass(&graph);
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == op_type) {
      // GetAttrIfExists yields false (default bool) when the pass never set
      // the attribute, so unfused ops compare against `false` as expected.
      bool is_force_fp32_output =
          node->Op()->GetAttrIfExists<bool>("force_fp32_output");
      EXPECT_EQ(is_force_fp32_output, target_is_force_fp32_output);
    }
  }
}
// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST(CpuQuantizeSquashPass, equal_scales) {
......@@ -362,8 +418,12 @@ TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
auto remove_nodes = 2;
CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
IsForceFp32OutputTest(
BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale), "conv2d",
true);
}
// If there are more than one op after conv->dequantize, do not fuse
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
auto scale_out = 1.0f;
auto scale = 1.2345f;
......@@ -372,6 +432,39 @@ TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
auto remove_nodes = 0;
CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
remove_nodes);
IsForceFp32OutputTest(
BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale), "conv2d",
false);
}
// from
//  a->fc->b->Dequant1->c->Concat1->d
// to
//  a->fc->c->Concat->d
TEST(CpuQuantizeSquashPass, fc_dequant_only_one_output) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // remove 2 nodes: b, Dequant1
  auto remove_nodes = 2;
  // Build the identical program once instead of twice; each helper still
  // constructs and transforms its own graph from it.
  const auto prog =
      BuildFcDequantConcatProgramDesc(use_mkldnn, scale_out, scale);
  CountNodeTest(prog, remove_nodes);
  IsForceFp32OutputTest(prog, "fc", true);
}
// If the fc output has more than one consumer (dequantize plus another op),
// do not fuse.
TEST(CpuQuantizeSquashPass, fc_dequant_more_than_one_op_after_dequant) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // nothing changes in the graph
  auto remove_nodes = 0;
  // Build the identical program once instead of twice; each helper still
  // constructs and transforms its own graph from it.
  const auto prog = BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale);
  CountNodeTest(prog, remove_nodes);
  IsForceFp32OutputTest(prog, "fc", false);
}
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册