diff --git a/lite/kernels/mlu/bridges/concat_op.cc b/lite/kernels/mlu/bridges/concat_op.cc
index 1c3c0b1e35b26950ef07f7a4d63d84e0df06c4c5..76037c3358201fb99c2a1415c1f7bfd91ef06103 100644
--- a/lite/kernels/mlu/bridges/concat_op.cc
+++ b/lite/kernels/mlu/bridges/concat_op.cc
@@ -44,9 +44,14 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto dims = output_dims.size();
   int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
 
-  CHECK_LE(axis, 4) << "Unsupport dims in mlu concat";
-  int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
-  int nhwc_axis = nchw_to_nhwc_axis_map[axis];
+  CHECK_LT(axis, dims) << "Unsupport dims in mlu concat";
+  std::vector<int> nchw2nhwc_axis(dims);
+  nchw2nhwc_axis[0] = 0;
+  if (dims > 1) nchw2nhwc_axis[1] = dims - 1;
+  for (size_t i = 2; i < dims; ++i) {
+    nchw2nhwc_axis[i] = i - 1;
+  }
+  int nhwc_axis = nchw2nhwc_axis[axis];
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
diff --git a/lite/kernels/mlu/bridges/transpose_op.cc b/lite/kernels/mlu/bridges/transpose_op.cc
index 130e8417f1f9ad5ea4bc8c0b4ffaacf20124fae7..43f25550c26be484a452b0fc6e9a5e64226cf38c 100644
--- a/lite/kernels/mlu/bridges/transpose_op.cc
+++ b/lite/kernels/mlu/bridges/transpose_op.cc
@@ -22,12 +22,24 @@ namespace subgraph {
 namespace mlu {
 
 std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
-  CHECK_EQ(axis.size(), 4) << "Unsupport dim in mlu transpose";
-  std::vector<int> new_axis(4, 0);
-  const std::vector<int> axis_map1 = {0, 2, 3, 1};
-  const std::vector<int> axis_map2 = {0, 3, 1, 2};
+  std::vector<int> new_axis(axis.size());
+
+  std::vector<int> nhwc2nchw_axis(axis.size());
+  nhwc2nchw_axis[0] = 0;
+  if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
+  for (size_t i = 2; i < axis.size(); ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
+
+  std::vector<int> nchw2nhwc_axis(axis.size());
+  nchw2nhwc_axis[0] = 0;
+  for (size_t i = 1; i < axis.size() - 1; ++i) {
+    nchw2nhwc_axis[i] = i + 1;
+  }
+  if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;
+
   for (size_t i = 0; i < new_axis.size(); ++i) {
-    new_axis[i] = axis_map2[axis[axis_map1[i]]];
+    new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
   }
   return new_axis;
 }
@@ -51,9 +63,6 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_dims = output->dims().Vectorize();
 
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
-  while (axis.size() < 4) {
-    axis.push_back(axis.size());
-  }
   std::vector<int> axis_nhwc = axis_to_nhwc(axis);
 
   auto output_tensor = graph->AddNode(
diff --git a/lite/kernels/mlu/io_copy_compute.cc b/lite/kernels/mlu/io_copy_compute.cc
index 7178cdb109fab8cfeda797a7e6ef53af34ec834d..d11279f767d77b04a78d51c3b283ed638070b64d 100644
--- a/lite/kernels/mlu/io_copy_compute.cc
+++ b/lite/kernels/mlu/io_copy_compute.cc
@@ -173,14 +173,14 @@ REGISTER_LITE_KERNEL(
                      kMLU,
                      kInt8,
                      kNHWC,
-                     paddle::lite::kernels::mlu::IoCopyMluToHostCompute,
-                     device_to_host_kInt8)
+                     paddle::lite::kernels::mlu::IoCopyHostToMluCompute,
+                     host_to_device_to_kInt8)
     .BindInput("Input",
-               {LiteType::GetTensorTy(TARGET(kMLU),
+               {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
+                {LiteType::GetTensorTy(TARGET(kMLU),
                                        PRECISION(kInt8),
                                        DATALAYOUT(kAny))})
    .Finalize();
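
Note on the transpose change: the patched axis_to_nhwc generalizes the old hard-coded rank-4 tables (axis_map1 = {0, 2, 3, 1}, axis_map2 = {0, 3, 1, 2}) to arbitrary rank by building the NCHW<->NHWC index permutations at runtime. The standalone sketch below is not part of the patch: it mirrors the new function, and its main() harness and sample axis attribute are illustrative assumptions used to check that, at rank 4, the generalized code reproduces exactly what the removed tables computed.

// Standalone sketch mirroring the patched axis_to_nhwc; the main() harness
// below is illustrative only, not PaddleLite code.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
  std::vector<int> new_axis(axis.size());

  // NHWC index -> NCHW index, e.g. {0, 3, 1, 2} at rank 4.
  std::vector<int> nhwc2nchw_axis(axis.size());
  nhwc2nchw_axis[0] = 0;
  if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
  for (size_t i = 2; i < axis.size(); ++i) nhwc2nchw_axis[i] = i - 1;

  // Inverse permutation, NCHW index -> NHWC index, e.g. {0, 2, 3, 1} at rank 4.
  std::vector<int> nchw2nhwc_axis(axis.size());
  nchw2nhwc_axis[0] = 0;
  for (size_t i = 1; i < axis.size() - 1; ++i) nchw2nhwc_axis[i] = i + 1;
  if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;

  // Conjugate the NCHW-space permutation by the layout change so it acts
  // on NHWC-ordered dimensions instead.
  for (size_t i = 0; i < new_axis.size(); ++i) {
    new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
  }
  return new_axis;
}

int main() {
  // Sample NCHW transpose attribute (illustrative): move C to the end.
  const std::vector<int> axis = {0, 2, 3, 1};
  const std::vector<int> got = axis_to_nhwc(axis);

  // The removed rank-4 tables computed axis_map2[axis[axis_map1[i]]].
  const std::vector<int> axis_map1 = {0, 2, 3, 1};
  const std::vector<int> axis_map2 = {0, 3, 1, 2};
  for (size_t i = 0; i < 4; ++i) {
    assert(got[i] == axis_map2[axis[axis_map1[i]]]);
  }
  for (int v : got) std::printf("%d ", v);
  std::printf("\n");  // prints: 0 2 3 1
  return 0;
}

Compiled with, e.g., g++ -std=c++11, the harness prints 0 2 3 1 for the sample permutation and the asserts pass, since at rank 4 nhwc2nchw_axis and nchw2nhwc_axis reduce to the removed axis_map2 and axis_map1 respectively. The io_copy_compute.cc hunk is independent: it re-registers the kInt8 copy kernel as host-to-device (IoCopyHostToMluCompute, kHost input, kMLU output) where the device-to-host kernel and types had been bound.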