提交 f8681ffc 编写于 作者: J joanna.wozna.intel 提交者: lidanqing-intel

Correct elementwise quantization (#43693)

上级 ae7192a8
...@@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, ...@@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
return out_var; return out_var;
} }
PDNode *patterns::ElementwiseOp::operator()(
const std::string elementwise_type) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_op_output(elementwise_type, "Out");
elementwise_op->LinksTo({out_var});
return out_var;
}
PDNode *patterns::ResidualElementwise::operator()( PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
bool as_x) { bool as_x) {
......
...@@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase { ...@@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase {
PATTERN_DECL_NODE(elementwise_out); PATTERN_DECL_NODE(elementwise_out);
}; };
// Elementwise ops
// Forward pass for element-wise operators
// elementwise_out is the result of the operator
struct ElementwiseOp : public PatternBase {
ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {}
PDNode* operator()(const std::string elementwise_type);
PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_out);
};
// Residual Elementwise ops // Residual Elementwise ops
// This pattern allows operator output to be X or Y // This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag // and residual data Y or X, based on as_x flag
......
...@@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const { Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::Elementwise elementwise_pattern{pattern, name_scope_}; patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
elementwise_pattern( elementwise_pattern(elementwise_type);
pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);
int quantize_elementwise_count = 0; int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise( ...@@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise(
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, auto x_name = elementwise_op->Op()->Input("X");
elementwise_pattern); auto y_name = elementwise_op->Op()->Input("Y");
GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y, Node *elementwise_x, *elementwise_y;
elementwise_pattern);
for (auto& input : elementwise_op->inputs) {
if (input->Name() == x_name[0]) elementwise_x = input;
if (input->Name() == y_name[0]) elementwise_y = input;
}
if (!elementwise_x || !elementwise_y) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern); elementwise_pattern);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册