未验证 提交 c16421c2 编写于 作者: Q Qi Li 提交者: GitHub

fix npu compile error, test=develop (#34656)

上级 ce733495
......@@ -59,7 +59,7 @@ cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS grap
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
......
......@@ -39,7 +39,26 @@ class ExpandNPUKernel : public framework::OpKernel<T> {
"The number of dimensions of the input 'x' for Op(expand) "
"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
switch (rank) {
case 1:
Expand<1>(context);
break;
case 2:
Expand<2>(context);
break;
case 3:
Expand<3>(context);
break;
case 4:
Expand<4>(context);
break;
case 5:
Expand<5>(context);
break;
case 6:
Expand<6>(context);
break;
}
}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册