提交 6c405fca 编写于 作者: - --get 提交者: jackzhang235

(bugfix): fix Tensor.ToFile trans bug

上级 a74dfefb
......@@ -312,8 +312,28 @@ void MLUTensor::ToFile(std::string file_name) {
// trans to nchw
std::vector<float> cpu_data_trans(count);
if (data_order_ != CNML_NCHW) {
transpose(
cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0, 3, 1, 2});
switch (shape_.size()) {
case 4:
transpose(cpu_data_fp32.data(),
cpu_data_trans.data(),
shape_,
{0, 3, 1, 2});
break;
case 3:
transpose(
cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0, 2, 1});
break;
case 2:
transpose(
cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0, 1});
break;
case 1:
transpose(cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0});
break;
default:
CHECK(0) << "ToFile only support dim <=4";
break;
}
}
// to file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册