未验证 提交 9aa89b99 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Correct elementwise quantization (#43693)

上级 bd5e97d3
......@@ -2078,6 +2078,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
return out_var;
}
PDNode *patterns::ElementwiseOp::operator()(
const std::string elementwise_type) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_op_output(elementwise_type, "Out");
elementwise_op->LinksTo({out_var});
return out_var;
}
PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
bool as_x) {
......
......@@ -1072,6 +1072,19 @@ struct Elementwise : public PatternBase {
PATTERN_DECL_NODE(elementwise_out);
};
// Elementwise ops
// Forward pass for element-wise operators
// elementwise_out is the result of the operator
struct ElementwiseOp : public PatternBase {
ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {}
PDNode* operator()(const std::string elementwise_type);
PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_out);
};
// Residual Elementwise ops
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
......
......@@ -858,12 +858,9 @@ void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Elementwise elementwise_pattern{pattern, name_scope_};
patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
elementwise_pattern(
pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);
elementwise_pattern(elementwise_type);
int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -878,10 +875,18 @@ void CPUQuantizePass::QuantizeElementwise(
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
elementwise_pattern);
auto x_name = elementwise_op->Op()->Input("X");
auto y_name = elementwise_op->Op()->Input("Y");
Node *elementwise_x, *elementwise_y;
for (auto& input : elementwise_op->inputs) {
if (input->Name() == x_name[0]) elementwise_x = input;
if (input->Name() == y_name[0]) elementwise_y = input;
}
if (!elementwise_x || !elementwise_y) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册