Unverified commit 55dd3f2c, authored by 周周周, committed by GitHub

[Paddle-TRT] revert to paddle op when matrix_multiply cannot enter into paddle-trt (#54278)

* revert to paddle op when matrix_multiply cannot enter into paddle-trt 
Parent: d6c8ca82
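
In outline, the change works like this: TrtMapOpsToMatrixMultiplyPass now rewrites mul/matmul/matmul_v2 in place into the TRT-only matrix_multiply op while stashing the old op type in an "original_type" attribute, and TensorRtSubgraphPass later restores that type for any matrix_multiply left outside a TRT engine. Below is a minimal standalone sketch of that save-and-restore round trip; FakeOpDesc and the two function names are illustrative stand-ins, not Paddle APIs.

    #include <iostream>
    #include <map>
    #include <string>

    // Illustrative stand-in for an op descriptor; not Paddle's OpDesc.
    struct FakeOpDesc {
      std::string type;
      std::map<std::string, std::string> attrs;
    };

    // Mapping pass: unify the op under "matrix_multiply" but remember its origin.
    void MapToMatrixMultiply(FakeOpDesc* op) {
      op->attrs["original_type"] = op->type;  // e.g. "mul", "matmul", "matmul_v2"
      op->type = "matrix_multiply";
    }

    // Subgraph pass: if the op never made it into a TRT engine, fall back.
    void RevertToPaddleOp(FakeOpDesc* op) {
      if (op->type == "matrix_multiply") {
        op->type = op->attrs["original_type"];
      }
    }

    int main() {
      FakeOpDesc op{"matmul_v2", {}};
      MapToMatrixMultiply(&op);   // op.type == "matrix_multiply"
      RevertToPaddleOp(&op);      // op.type == "matmul_v2" again
      std::cout << op.type << "\n";
    }

Mutating the existing node, rather than creating a replacement with CreateOpNode and deleting the old one as the removed code below did, also keeps the op's inputs, outputs, and remaining attributes in place, which is what makes the later revert cheap.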
@@ -57,60 +57,51 @@ void TrtMapOpsToMatrixMultiplyPass::ApplyImpl(ir::Graph* graph) const {
     VLOG(4) << "trt map some ops to matrix_multiply";
     GET_IR_NODE_FROM_SUBGRAPH(ops, ops, mul_matmul_matmul_v2);
     GET_IR_NODE_FROM_SUBGRAPH(ops_out, ops_out, mul_matmul_matmul_v2);
-    OpDesc desc(ops->Op()->Block());
-    desc.SetType("matrix_multiply");
-    desc.SetInput("X", {ops->Op()->Input("X").front()});
-    desc.SetInput("Y", {ops->Op()->Input("Y").front()});
-    desc.SetOutput("Out", {ops_out->Name()});
+    auto op_desc = ops->Op();
+    op_desc->SetAttr("original_type", op_desc->Type());
+    op_desc->SetType("matrix_multiply");
+    ops->RenameOp("matrix_multiply");
+    // OpDesc original_desc(*(ops->Op()));
-    if (ops->Op()->HasAttr("transpose_X") || ops->Op()->HasAttr("trans_x")) {
-      if (ops->Op()->HasAttr("transpose_X")) {
-        desc.SetAttr("transpose_x", ops->Op()->GetAttr("transpose_X"));
+    if (op_desc->HasAttr("transpose_X") || op_desc->HasAttr("trans_x")) {
+      if (op_desc->HasAttr("transpose_X")) {
+        op_desc->SetAttr("transpose_x", op_desc->GetAttr("transpose_X"));
       } else {
-        desc.SetAttr("transpose_x", ops->Op()->GetAttr("trans_x"));
+        op_desc->SetAttr("transpose_x", op_desc->GetAttr("trans_x"));
       }
     } else {
-      desc.SetAttr("transpose_x", false);
+      op_desc->SetAttr("transpose_x", false);
     }
-    if (ops->Op()->HasAttr("transpose_Y") || ops->Op()->HasAttr("trans_y")) {
-      if (ops->Op()->HasAttr("transpose_Y")) {
-        desc.SetAttr("transpose_y", ops->Op()->GetAttr("transpose_Y"));
+    if (op_desc->HasAttr("transpose_Y") || op_desc->HasAttr("trans_y")) {
+      if (op_desc->HasAttr("transpose_Y")) {
+        op_desc->SetAttr("transpose_y", op_desc->GetAttr("transpose_Y"));
       } else {
-        desc.SetAttr("transpose_y", ops->Op()->GetAttr("trans_y"));
+        op_desc->SetAttr("transpose_y", op_desc->GetAttr("trans_y"));
       }
     } else {
-      desc.SetAttr("transpose_y", false);
-    }
-    if (ops->Op()->HasAttr("out_threshold")) {
-      desc.SetAttr("out_threshold", ops->Op()->GetAttr("out_threshold"));
+      op_desc->SetAttr("transpose_y", false);
     }
     // Todo: remove attr(x_num_col_dims, y_num_col_dims, alpha)
-    if (ops->Op()->HasAttr("x_num_col_dims")) {
-      desc.SetAttr("x_num_col_dims", ops->Op()->GetAttr("x_num_col_dims"));
+    if (op_desc->HasAttr("x_num_col_dims")) {
+      op_desc->SetAttr("x_num_col_dims", op_desc->GetAttr("x_num_col_dims"));
     } else {
       int32_t x_num_col_dims = -1;
-      desc.SetAttr("x_num_col_dims", x_num_col_dims);
+      op_desc->SetAttr("x_num_col_dims", x_num_col_dims);
     }
     // op_teller: Only support y_num_col_dims == y.rank - 1;
     int32_t y_num_col_dims = -1;
-    desc.SetAttr("y_num_col_dims", y_num_col_dims);
+    op_desc->SetAttr("y_num_col_dims", y_num_col_dims);
     float alpha = 1;
-    if (ops->Op()->HasAttr("alpha")) {
-      alpha = PADDLE_GET_CONST(float, ops->Op()->GetAttr("alpha"));
+    if (op_desc->HasAttr("alpha")) {
+      alpha = PADDLE_GET_CONST(float, op_desc->GetAttr("alpha"));
     }
-    desc.SetAttr("alpha", alpha);
-    auto matrix_multiply_node = g->CreateOpNode(&desc);
-    for (auto node : ops->inputs) {
-      IR_NODE_LINK_TO(node, matrix_multiply_node);
-    }
-    IR_NODE_LINK_TO(matrix_multiply_node, ops_out);
-    GraphSafeRemoveNodes(graph, {ops});
+    op_desc->SetAttr("alpha", alpha);
     ++found_count;
   };
   gpd(graph, handler);
...
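
Aside from the rename, the hunk above folds the legacy transpose spellings into the single keys that matrix_multiply reads: "transpose_X" (mul/matmul) and "trans_x" (matmul_v2) both become "transpose_x", defaulting to false when neither is present, and likewise for the Y side. A standalone sketch of that normalization follows; the AttrMap alias and the helper name are hypothetical, and real Paddle attributes are typed variants rather than plain bools.

    #include <iostream>
    #include <map>
    #include <string>

    // Hypothetical attribute store keyed by name.
    using AttrMap = std::map<std::string, bool>;

    // Fold two legacy attribute spellings into one canonical key,
    // defaulting to false when neither legacy key is present.
    void NormalizeTranspose(AttrMap* attrs, const std::string& legacy_a,
                            const std::string& legacy_b, const std::string& out) {
      if (attrs->count(legacy_a)) {
        (*attrs)[out] = (*attrs)[legacy_a];
      } else if (attrs->count(legacy_b)) {
        (*attrs)[out] = (*attrs)[legacy_b];
      } else {
        (*attrs)[out] = false;
      }
    }

    int main() {
      AttrMap attrs{{"trans_x", true}};  // attrs as a matmul_v2 op would carry them
      NormalizeTranspose(&attrs, "transpose_X", "trans_x", "transpose_x");
      NormalizeTranspose(&attrs, "transpose_Y", "trans_y", "transpose_y");
      std::cout << attrs["transpose_x"] << " " << attrs["transpose_y"] << "\n";
      // prints "1 0": trans_x carried over, transpose_y defaulted to false
    }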
@@ -224,6 +224,20 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
           ->SetAllNodesLowerToTrt(use_cuda_graph);
     }
   }
+  // Some ops are implemented only in Paddle-TRT, not in Paddle itself,
+  // so an op that could not enter Paddle-TRT must be reverted here.
+  for (auto *op_node : framework::ir::TopologyVarientSort(
+           *graph, static_cast<framework::ir::SortKind>(0))) {
+    if (op_node->Op()->Type() == "matrix_multiply") {
+      auto origin_type =
+          op_node->Op()->GetAttrIfExists<std::string>("original_type");
+      LOG(WARNING) << "matrix_multiply can't enter into paddle-trt, "
+                   << "we will revert to " << origin_type;
+      op_node->Op()->SetType(origin_type);
+      op_node->RenameOp(origin_type);
+    }
+  }
 }

 std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
...
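
The added loop runs after subgraph fusion, so any matrix_multiply still in the main graph is one the TRT op teller rejected, and its stashed "original_type" is restored. A standalone sketch of that sweep is below; FakeNode and the flat node vector are illustrative stand-ins for Paddle's graph nodes, and the emptiness guard is an assumption not present in the committed code.

    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative stand-in for an op node left outside all TRT engines.
    struct FakeNode {
      std::string type;
      std::string original_type;  // stands in for the "original_type" attribute
    };

    // Any matrix_multiply still in the main graph was rejected by TRT,
    // so restore the Paddle op type it was mapped from.
    void RevertUnfusedOps(std::vector<FakeNode>* graph) {
      for (auto& node : *graph) {
        if (node.type == "matrix_multiply" && !node.original_type.empty()) {
          std::cerr << "matrix_multiply can't enter into paddle-trt, "
                    << "we will revert to " << node.original_type << "\n";
          node.type = node.original_type;
        }
      }
    }

    int main() {
      std::vector<FakeNode> graph{{"matrix_multiply", "mul"}, {"relu", ""}};
      RevertUnfusedOps(&graph);
      std::cout << graph[0].type << "\n";  // prints "mul"
    }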