未验证 提交 a7e0cdea 编写于 作者: L lidanqing 提交者: GitHub

[cherry-pick] release/2.3 elementwise_mul and matmul mkldnn fix (#43725)

* Correct elementwise quantization (#43693)

* [Bug fix] Do not quantize weights Y when matmul X and Y both other ops outputs (#43297)

* Fix matmul ops whose X and Y inputs are both outputs of other ops: do not dequantize Y.

* fix CI format

* fix according to review
Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
上级 d0bbf46c
......@@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
return out_var;
}
PDNode *patterns::ElementwiseOp::operator()(
    const std::string elementwise_type) {
  // Build a pattern matching a single elementwise op of the requested type
  // together with its "Out" output variable. Unlike patterns::Elementwise,
  // no input (X/Y) nodes are declared here; callers resolve the op's inputs
  // from the matched op node themselves.
  auto *op_node =
      pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);

  auto *result_var = pattern->NewNode(elementwise_out_repr());
  result_var->AsOutput()->assert_is_op_output(elementwise_type, "Out");

  // Wire the op to its output and hand the output node back so callers can
  // chain further pattern constraints onto it.
  op_node->LinksTo({result_var});
  return result_var;
}
PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
bool as_x) {
......
......@@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase {
PATTERN_DECL_NODE(elementwise_out);
};
// Elementwise ops
// Pattern for a single elementwise operator matched only through its "Out"
// output variable. The op's X/Y inputs are deliberately NOT part of the
// pattern, so it also matches ops whose inputs are produced by arbitrary
// other operators; callers look the inputs up on the matched op node.
// NOTE(review): the PatternBase name "elementwise" appears to be shared with
// the sibling Elementwise pattern — confirm the two are never used in the
// same detector, or their node reprs could collide.
struct ElementwiseOp : public PatternBase {
  ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "elementwise") {}

  // elementwise_type: the op type to match (e.g. "elementwise_add").
  // Returns the pattern node for the op's "Out" output.
  PDNode* operator()(const std::string elementwise_type);

  PATTERN_DECL_NODE(elementwise_op);   // the matched elementwise operator
  PATTERN_DECL_NODE(elementwise_out);  // its "Out" output variable
};
// Residual Elementwise ops
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
......
......@@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Elementwise elementwise_pattern{pattern, name_scope_};
patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
elementwise_pattern(
pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);
elementwise_pattern(elementwise_type);
int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -839,10 +836,18 @@ void CPUQuantizePass::QuantizeElementwise(
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
elementwise_pattern);
auto x_name = elementwise_op->Op()->Input("X");
auto y_name = elementwise_op->Op()->Input("Y");
Node *elementwise_x, *elementwise_y;
for (auto& input : elementwise_op->inputs) {
if (input->Name() == x_name[0]) elementwise_x = input;
if (input->Name() == y_name[0]) elementwise_y = input;
}
if (!elementwise_x || !elementwise_y) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);
......
......@@ -354,10 +354,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
auto* op_desc = op_node->Op();
auto var_name = op_desc->Input(weight_name)[0];
auto* var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound(
"The input persistable [%s] var of [%s] op is not found.",
var_name, op_desc->Type()));
if (var == nullptr) {
return false;
}
auto* weight_tensor = var->GetMutable<LoDTensor>();
auto* weight_data = weight_tensor->data<float>();
bool is_int8 = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册