提交 ed48feaa 编写于 作者: D dingminghui 提交者: jackzhang235

fix(concat): fix concat mlu kernel wrong output shape

上级 a112fb87
......@@ -37,25 +37,44 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto input_num = x_var_name.size();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
int axis = (param_axis < 0) ? (param_axis + output_dims.size()) : param_axis;
std::vector<cnmlTensor_t> input_tensor;
std::vector<std::vector<int64_t>> input_dims;
for (auto x_name : x_var_name) {
CHECK(graph->HasNode(x_name));
input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor());
auto x = scope->FindVar(x_name)->GetMutable<Tensor>();
input_dims.push_back(x->dims().Vectorize());
}
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
int axis = (param_axis < 0) ? (param_axis + output->dims().size()) : param_axis;
int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
int nhwc_axis = nchw_to_nhwc_axis_map[axis];
std::vector<int64_t> output_dims;
output_dims.assign(output->dims().size(), 0);
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis, nhwc_axis) << std::endl; */
for (int i = 0; i < output_dims.size(); ++i) {
if (i == nhwc_axis) {
for (auto &dim : input_dims) output_dims[i] += dim[i];
} else {
output_dims[i] = input_dims[0][i];
}
}
int nchw_to_nhwc_aixs_map[4] = {0, 3, 1, 2};
int nhwc_axis = nchw_to_nhwc_aixs_map[axis];
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") << std::endl; */
output->Resize(output_dims);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
cnmlBaseOp_t concat_op;
auto output_t = output_tensor->mlu_tensor();
cnmlTensor_t outputs[1];
outputs[0] = output_tensor->mlu_tensor();
CNML_CALL(cnmlCreateNdConcatOp(
&concat_op, nhwc_axis, input_tensor.data(), input_num, &output_t, 1));
&concat_op, nhwc_axis, input_tensor.data(), input_num, outputs, 1));
graph->FuseOp(concat_op);
return SUCCESS;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册