Unverified commit a7e0cdea, authored by lidanqing, committed by GitHub

[cherry-pick] release/2.3 elementwise_mul and matmul mkldnn fix (#43725)

* Correct elementwise quantization (#43693)

* [Bug fix] Do not quantize weight Y when matmul inputs X and Y are both outputs of other ops (#43297)

* Fix matmul cases where X and Y are both outputs of other ops: do not dequantize Y.

* fix CI format

* fix according to review
Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
Parent d0bbf46c
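The core idea of the matmul fix, as a minimal hedged sketch (the struct and helper below are illustrative stand-ins, not Paddle API): only a persistable parameter should take the weight (de)quantization path; an input produced by another operator is activation data and must be left alone.

```cpp
#include <string>

// Hypothetical, simplified view of a graph variable; not Paddle's ir::Node.
struct VarInfo {
  std::string name;
  bool persistable;     // true for parameters loaded with the model
  bool produced_by_op;  // true when another operator writes this variable
};

// Hypothetical helper: should input Y be treated as a quantizable weight?
bool TreatYAsWeight(const VarInfo& y) {
  // Only persistable parameters qualify; when Y is another op's output it is
  // activation data, so the weight (de)quantization path is skipped.
  return y.persistable && !y.produced_by_op;
}
```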
@@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
   return out_var;
 }
 
+PDNode *patterns::ElementwiseOp::operator()(
+    const std::string elementwise_type) {
+  auto elementwise_op =
+      pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
+
+  auto out_var = pattern->NewNode(elementwise_out_repr())
+                     ->AsOutput()
+                     ->assert_is_op_output(elementwise_type, "Out");
+
+  elementwise_op->LinksTo({out_var});
+
+  return out_var;
+}
+
 PDNode *patterns::ResidualElementwise::operator()(
     PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
     bool as_x) {
......
@@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase {
   PATTERN_DECL_NODE(elementwise_out);
 };
 
+// Elementwise ops
+// Forward pass for element-wise operators
+// elementwise_out is the result of the operator
+struct ElementwiseOp : public PatternBase {
+  ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "elementwise") {}
+
+  PDNode* operator()(const std::string elementwise_type);
+
+  PATTERN_DECL_NODE(elementwise_op);
+  PATTERN_DECL_NODE(elementwise_out);
+};
+
 // Residual Elementwise ops
 // This pattern allows operator output to be X or Y
 // and residual data Y or X, based on as_x flag
......
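For context, a rough sketch of how a pass drives this new single-node pattern, pieced together from the calls visible in the hunks of this diff (GraphPatternDetector, GET_IR_NODE_FROM_SUBGRAPH); the handler body and the final gpd(graph, handler) call are illustrative and not shown verbatim in the diff:

```cpp
// Illustrative only: matching every op of a given elementwise type with the
// new ElementwiseOp pattern, without constraining its X/Y input nodes.
void RunElementwisePattern(Graph* graph, const std::string& op_type) {
  GraphPatternDetector gpd;
  patterns::ElementwiseOp elementwise_pattern{gpd.mutable_pattern(),
                                              "elementwise_example"};
  elementwise_pattern(op_type);  // e.g. "elementwise_mul"

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
                              elementwise_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
                              elementwise_pattern);
    // ... inspect elementwise_op->inputs here, as QuantizeElementwise does.
  };
  gpd(graph, handler);
}
```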
@@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise(
     Graph* graph, const std::string elementwise_type) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
-  patterns::Elementwise elementwise_pattern{pattern, name_scope_};
-  elementwise_pattern(
-      pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
-      pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
-      elementwise_type);
+  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
+  elementwise_pattern(elementwise_type);
 
   int quantize_elementwise_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise(
       return;
     }
 
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
-                              elementwise_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
-                              elementwise_pattern);
+    auto x_name = elementwise_op->Op()->Input("X");
+    auto y_name = elementwise_op->Op()->Input("Y");
+    Node *elementwise_x, *elementwise_y;
+
+    for (auto& input : elementwise_op->inputs) {
+      if (input->Name() == x_name[0]) elementwise_x = input;
+      if (input->Name() == y_name[0]) elementwise_y = input;
+    }
+
+    if (!elementwise_x || !elementwise_y) {
+      return;
+    }
+
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
                               elementwise_pattern);
......
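A compact standalone restatement of the lookup the handler now performs: resolve the matched op's X and Y inputs by argument name rather than by pattern nodes. The Node type here is a simplified stand-in (not Paddle's ir::Node); starting both pointers from nullptr is what makes the subsequent "either input missing" guard meaningful.

```cpp
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for an IR node exposing a name.
struct Node {
  std::string name;
  const std::string& Name() const { return name; }
};

// Simplified stand-in for a matched operator node with an input list.
struct MatchedOp {
  std::vector<Node*> inputs;
};

std::pair<Node*, Node*> FindXY(const MatchedOp& op, const std::string& x_name,
                               const std::string& y_name) {
  Node* x = nullptr;  // start from nullptr so a missing input is detectable
  Node* y = nullptr;
  for (Node* input : op.inputs) {
    if (input->Name() == x_name) x = input;
    if (input->Name() == y_name) y = input;
  }
  return {x, y};  // the caller bails out when either is still nullptr
}
```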
@@ -354,10 +354,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
   auto* op_desc = op_node->Op();
   auto var_name = op_desc->Input(weight_name)[0];
   auto* var = scope->FindVar(var_name);
-  PADDLE_ENFORCE_NOT_NULL(
-      var, platform::errors::NotFound(
-               "The input persistable [%s] var of [%s] op is not found.",
-               var_name, op_desc->Type()));
+  if (var == nullptr) {
+    return false;
+  }
   auto* weight_tensor = var->GetMutable<LoDTensor>();
   auto* weight_data = weight_tensor->data<float>();
   bool is_int8 = true;
......
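The rest of IsInt8Weight is truncated out of this diff. As a hedged illustration of the general idea only (the pass's exact criterion is not shown here), a float-typed weight buffer can be regarded as already int8-quantized when every value is an integer within the int8 range:

```cpp
#include <cmath>
#include <cstddef>

// Assumed heuristic for illustration; not the pass's verbatim check.
bool LooksLikeInt8Weight(const float* data, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    const float v = data[i];
    if (v < -128.0f || v > 127.0f) return false;  // outside the int8 range
    if (v != std::floor(v)) return false;         // not an integer value
  }
  return true;
}
```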