diff --git a/lite/kernels/mlu/bridges/act_op.cc b/lite/kernels/mlu/bridges/act_op.cc index 5d6e2f9ca0932c03ade897b416420e41019d775a..286195d9d5f961288dd0156db31ff8aacae58227 100644 --- a/lite/kernels/mlu/bridges/act_op.cc +++ b/lite/kernels/mlu/bridges/act_op.cc @@ -37,7 +37,7 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output = scope->FindVar(out_var_name)->GetMutable(); auto output_dims = output->dims().Vectorize(); auto output_tensor = graph->AddNode( - out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, fp_type); + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type); CHECK(graph->HasNode(x_var_name)); auto input_tensor = graph->GetNode(x_var_name); cnmlBaseOp_t activation_op; diff --git a/lite/kernels/mlu/bridges/batch_norm_op.cc b/lite/kernels/mlu/bridges/batch_norm_op.cc index d95a5115c96c10a8881f50c44fee9881c6a9e218..7353a685dd5fd3a5bcc8c88def8ffb8b96fdde55 100644 --- a/lite/kernels/mlu/bridges/batch_norm_op.cc +++ b/lite/kernels/mlu/bridges/batch_norm_op.cc @@ -42,7 +42,7 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output = scope->FindVar(y_var_name)->GetMutable(); auto output_dims = output->dims().Vectorize(); auto output_tensor = graph->AddNode( - y_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType()); + y_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); CHECK(graph->HasNode(x_var_name)); diff --git a/lite/kernels/mlu/bridges/concat_op.cc b/lite/kernels/mlu/bridges/concat_op.cc index e2986b964853ab90e5d7317638ee8c2c2969a4d0..14f0da746a00c1ea10ffae824217dbb2df84df55 100644 --- a/lite/kernels/mlu/bridges/concat_op.cc +++ b/lite/kernels/mlu/bridges/concat_op.cc @@ -32,60 +32,33 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto x_var_name = op_info->Input("X"); auto out_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(out_var_name)->GetMutable(); + auto output_dims = output->dims().Vectorize(); auto param_axis = 
op_info->GetAttr("axis"); - // auto x = scope->FindVar(x_var_name[0])->GetMutable(); - - auto input_num = x_var_name.size(); std::vector input_tensor; - std::vector> input_dims; for (auto x_name : x_var_name) { CHECK(graph->HasNode(x_name)); input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor()); - auto x = scope->FindVar(x_name)->GetMutable(); - input_dims.push_back(x->dims().Vectorize()); } - auto dims = input_dims[0].size(); + auto dims = output_dims.size(); int axis = (param_axis < 0) ? (param_axis + dims) : param_axis; - int nhwc_axis = -1; - if (dims == 4) { - int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2}; - nhwc_axis = nchw_to_nhwc_axis_map[axis]; - } else if (dims == 3) { - int nchw_to_nhwc_axis_map[3] = {0, 2, 1}; - nhwc_axis = nchw_to_nhwc_axis_map[axis]; - } else { - CHECK(0) << "Unsupport dims in mlu concat"; - } - - std::vector output_dims; - output_dims.assign(dims, 0); - - /* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis, - * nhwc_axis) << std::endl; */ - - for (int i = 0; i < output_dims.size(); ++i) { - if (i == nhwc_axis) { - for (auto& dim : input_dims) output_dims[i] += dim[i]; - } else { - output_dims[i] = input_dims[0][i]; - } - } - - /* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") << - * std::endl; */ + /* axis indexes the 4-entry nchw_to_nhwc_axis_map below, so it must lie in [0, 3] */ CHECK(axis >= 0 && axis < 4) << "Unsupport dims in mlu concat"; + int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2}; + int nhwc_axis = nchw_to_nhwc_axis_map[axis]; - auto* output = scope->FindVar(out_var_name)->GetMutable(); - output->Resize(output_dims); auto output_tensor = graph->AddNode( - out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType()); + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); cnmlBaseOp_t concat_op; - cnmlTensor_t outputs[1]; - outputs[0] = output_tensor->mlu_tensor(); - CNML_CALL(cnmlCreateNdConcatOp( - &concat_op, nhwc_axis, input_tensor.data(), input_num, outputs, 1)); + cnmlTensor_t outputs = output_tensor->mlu_tensor(); + 
CNML_CALL(cnmlCreateNdConcatOp(&concat_op, + nhwc_axis, + input_tensor.data(), + x_var_name.size(), + &outputs, + 1)); graph->FuseOp(concat_op); return SUCCESS; } diff --git a/lite/kernels/mlu/bridges/conv_op.cc b/lite/kernels/mlu/bridges/conv_op.cc index e4f672e06e38c0212d1887de5cebed6a35bd0e0d..6a7ef408eb7432950d5a0985dd6e174236e937e0 100644 --- a/lite/kernels/mlu/bridges/conv_op.cc +++ b/lite/kernels/mlu/bridges/conv_op.cc @@ -33,13 +33,14 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { // get input, filter and op attributes const auto input_var_name = op_info->Input("Input").front(); - const auto& input_dims_nhwc = + const auto& input_dims = scope->FindVar(input_var_name)->GetMutable()->dims(); - const auto input_dims = DimNHWC2NCHW(input_dims_nhwc); const auto filter_var_name = op_info->Input("Filter").front(); auto* filter = scope->FindVar(filter_var_name)->GetMutable(); const auto& filter_dims = filter->dims(); const auto output_var_name = op_info->Output("Output").front(); + auto* output = scope->FindVar(output_var_name)->GetMutable(); + const auto output_shape = output->dims().Vectorize(); const auto bs = input_dims[0]; const auto oc = filter_dims[0]; CHECK_EQ(input_dims.size(), 4); @@ -70,24 +71,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { input_dims, filter_dims); - std::vector output_shape({bs, oc}); - for (size_t i = 0; i < 2; i++) { - const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1; - output_shape.push_back( - (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - dkernel) / - strides[i] + - 1); - } - - const auto output_shape_nhwc = DimNCHW2NHWC(output_shape); - const auto output_tensor = graph->AddNode(output_var_name, - output_shape_nhwc, - CNML_TENSOR, - CNML_NHWC, - graph->FPType()); - scope->FindVar(output_var_name) - ->GetMutable<::paddle::lite::Tensor>() - ->Resize(output_shape_nhwc); + const auto output_tensor = graph->AddNode( + output_var_name, output_shape, CNML_TENSOR, 
CNML_NCHW, graph->FPType()); // Create filter node const auto filter_tensor = graph->AddNode(filter_var_name, @@ -156,7 +141,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { const auto input_scale = op_info->GetAttr("input_scale"); bool use_first_conv = false; - if (lite::DeviceInfo::Global().UseFirstConv() && input_dims_nhwc[3] == 3) { + if (lite::DeviceInfo::Global().UseFirstConv() && input_dims[1] == 3) { use_first_conv = true; } diff --git a/lite/kernels/mlu/bridges/elementwise_ops.cc b/lite/kernels/mlu/bridges/elementwise_ops.cc index 4ef949925d20e0a2cb1c7f25d840e2041d79dd7a..41526a0100ba71be9eda25983cb96aa888d6cf4d 100644 --- a/lite/kernels/mlu/bridges/elementwise_ops.cc +++ b/lite/kernels/mlu/bridges/elementwise_ops.cc @@ -77,7 +77,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output_tensor = graph->AddNode(out_var_name, x->dims().Vectorize(), CNML_TENSOR, - CNML_NHWC, + CNML_NCHW, graph->FPType()); cnmlBaseOp_t elementwise_op; @@ -90,7 +90,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto mid_tensor = graph->AddNode(out_var_name + "_mid", x->dims().Vectorize(), CNML_TENSOR, - CNML_NHWC, + CNML_NCHW, graph->FPType()); CNML_CALL(cnmlCreateBroadcastAddOp(&elementwise_op, x_tensor->mlu_tensor(), diff --git a/lite/kernels/mlu/bridges/fc_op.cc b/lite/kernels/mlu/bridges/fc_op.cc index f480a9110790406ddb2aa7464221c7062b26268e..286feec8d4d44eaa025f333d559c32ca72f042ff 100644 --- a/lite/kernels/mlu/bridges/fc_op.cc +++ b/lite/kernels/mlu/bridges/fc_op.cc @@ -37,6 +37,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { // int in_num_col_dims = op_info->GetAttr("in_num_col_dims"); auto x = scope->FindVar(x_var_name)->GetMutable(); auto w = scope->FindVar(w_var_name)->GetMutable(); + auto output = scope->FindVar(output_var_name)->GetMutable(); auto x_dims = x->dims(); auto w_dims = w->dims(); @@ -50,15 +51,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { 
auto input_scale = op_info->GetAttr("input_scale"); - std::vector output_shape_nhwc({x_dims[0], 1, 1, w_dims[1]}); auto output_tensor = graph->AddNode(output_var_name, - output_shape_nhwc, + output->dims().Vectorize(), CNML_TENSOR, - CNML_NHWC, + CNML_NCHW, graph->FPType()); - scope->FindVar(output_var_name) - ->GetMutable<::paddle::lite::Tensor>() - ->Resize(output_shape_nhwc); std::string bias_var_name; std::shared_ptr bias_tensor; diff --git a/lite/kernels/mlu/bridges/graph.cc b/lite/kernels/mlu/bridges/graph.cc index 27c6ab2597fa6930b14c4c4e34750030608167b6..65c2f8214c13ee8d004dbe4b2e706523d007469c 100644 --- a/lite/kernels/mlu/bridges/graph.cc +++ b/lite/kernels/mlu/bridges/graph.cc @@ -25,12 +25,12 @@ namespace mlu { std::shared_ptr Graph::AddNode(const std::string& name, std::vector shape, cnmlTensorType_t tensor_type, - cnmlDataOrder_t data_order, + cnmlDataOrder_t shape_order, cnmlDataType_t mlu_dtype, void* raw_ptr) { CHECK(!HasNode(name)); auto node = std::shared_ptr( - new MLUTensor(shape, tensor_type, data_order, mlu_dtype)); + new MLUTensor(shape, tensor_type, shape_order, mlu_dtype)); node->set_mlu_ptr(raw_ptr); nodes_.insert(std::make_pair(name, node)); return node; diff --git a/lite/kernels/mlu/bridges/interpolate_op.cc b/lite/kernels/mlu/bridges/interpolate_op.cc index 77a6722e03a63c4ee06a0159916509aa0ca36139..e201199824d8042abd6002ccbe5bb659a9ca2898 100644 --- a/lite/kernels/mlu/bridges/interpolate_op.cc +++ b/lite/kernels/mlu/bridges/interpolate_op.cc @@ -45,8 +45,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { CHECK(graph->HasNode(x_var_name)); auto input_tensor = graph->GetNode(x_var_name); - auto in_h = x_dims[1]; - auto in_w = x_dims[2]; + auto in_h = x_dims[2]; + auto in_w = x_dims[3]; // Priority: SizeTensor > OutSize > Scale > scale > out_h/out_w if (HasInputArg(op_info, scope, "SizeTensor")) { @@ -69,25 +69,13 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { } } - out->Resize({x_dims[0], 
out_h, out_w, x_dims[3]}); - auto output_tensor = graph->AddNode(out_var_name, out->dims().Vectorize(), CNML_TENSOR, - CNML_NHWC, + CNML_NCHW, graph->FPType()); cnmlBaseOp_t interp_op; - /* if (interp_method == "bilinear") { */ - /* cnmlInterpOpParam_t interp_param; */ - /* CNML_CALL(cnmlCreateInterpOpParam(&interp_param, out_w, out_h, - * align_corners)); */ - /* CNML_CALL(cnmlCreateInterpOp(&interp_op, */ - /* input_tensor->mlu_tensor(), */ - /* output_tensor->mlu_tensor(), */ - /* interp_param)); */ - /* CNML_CALL(cnmlDestroyInterpOpParam(&interp_param)); */ - /* } else if (interp_method == "nearest") { */ cnmlNearestNeighborOpParam_t nn_param; CNML_CALL(cnmlCreateNearestNeighborOpParam(&nn_param, out_w, out_h)); CNML_CALL(cnmlSetNearestNeighborAlignCorner(&nn_param, align_corners)); @@ -96,11 +84,6 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { output_tensor->mlu_tensor(), nn_param)); CNML_CALL(cnmlDestroyNearestNeighborOpParam(&nn_param)); - /* } else { */ - /* LOG(WARNING) << "[MLU] Unsupported interpolate method: " << - * interp_method; */ - /* return FAILED; */ - /* } */ graph->FuseOp(interp_op); return SUCCESS; diff --git a/lite/kernels/mlu/bridges/pool_op.cc b/lite/kernels/mlu/bridges/pool_op.cc index 3119b6c77dca10641c7c7c32072969fedb1ecef6..f77c8084c76fc52c39938e723f02bde9b3cac41b 100644 --- a/lite/kernels/mlu/bridges/pool_op.cc +++ b/lite/kernels/mlu/bridges/pool_op.cc @@ -47,9 +47,8 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { // Get input, and attributes auto x_var_name = op_info->Input("X").front(); auto x = scope->FindTensor(x_var_name); - auto input_dims_nhwc = x->dims(); - const auto input_dims = DimNHWC2NCHW(input_dims_nhwc); auto output_var_name = op_info->Output("Out").front(); + auto output_shape = scope->FindTensor(output_var_name)->dims().Vectorize(); auto pooling_type = op_info->GetAttr("pooling_type"); auto ceil_mode = op_info->GetAttr("ceil_mode"); auto paddings = 
op_info->GetAttr>("paddings"); @@ -81,23 +80,17 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) { strides, ksize); - std::vector output_shape({input_dims[0], input_dims[1]}); - for (size_t i = 0; i < 2; i++) { - output_shape.push_back( - (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - ksize[0]) / - strides[i] + - 1); - } + // std::vector output_shape({input_dims[0], input_dims[1]}); + // for (size_t i = 0; i < 2; i++) { + // output_shape.push_back( + // (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - + // ksize[0]) / + // strides[i] + + // 1); + // } - auto output_shape_nhwc = DimNCHW2NHWC(output_shape); - auto output_tensor = graph->AddNode(output_var_name, - output_shape_nhwc, - CNML_TENSOR, - CNML_NHWC, - graph->FPType()); - scope->FindVar(output_var_name) - ->GetMutable<::paddle::lite::Tensor>() - ->Resize(output_shape_nhwc); + auto output_tensor = graph->AddNode( + output_var_name, output_shape, CNML_TENSOR, CNML_NCHW, graph->FPType()); cnmlPoolOpParam_t pool_param; CNML_CALL( diff --git a/lite/kernels/mlu/bridges/scale_op.cc b/lite/kernels/mlu/bridges/scale_op.cc index d500786006286884af0843967410fbc907923e56..5557602bd7576ccd71c51f52a538a45fe27f7ada 100644 --- a/lite/kernels/mlu/bridges/scale_op.cc +++ b/lite/kernels/mlu/bridges/scale_op.cc @@ -36,7 +36,7 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output = scope->FindVar(out_var_name)->GetMutable(); auto output_dims = output->dims().Vectorize(); auto output_tensor = graph->AddNode( - out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType()); + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); auto bias_after_scale = op_info->GetAttr("bias_after_scale"); auto scale = op_info->GetAttr("scale"); auto bias = op_info->GetAttr("bias"); diff --git a/lite/kernels/mlu/bridges/softmax_op.cc b/lite/kernels/mlu/bridges/softmax_op.cc index b9e2b1116dc95ec276f8d85a5669cec45d98ea39..17c911675718a15c7ede4888b268ffcd62b4d8ed 
100644 --- a/lite/kernels/mlu/bridges/softmax_op.cc +++ b/lite/kernels/mlu/bridges/softmax_op.cc @@ -45,11 +45,10 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { axis = output_dims.size() + axis; } } - int nhwc_axis = nchw_to_nhwc_aixs_map[axis]; auto output_tensor = graph->AddNode( - out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType()); + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); cnmlBaseOp_t softmax_op; CNML_CALL(cnmlCreateNdSoftmaxOp(&softmax_op, nhwc_axis, diff --git a/lite/kernels/mlu/bridges/transpose_op.cc b/lite/kernels/mlu/bridges/transpose_op.cc index 74af692b6f5b8834e29b8e008a4e48801a1e8820..f9b0caeb3e0d0977ebfbb32566b8af5936f7724e 100644 --- a/lite/kernels/mlu/bridges/transpose_op.cc +++ b/lite/kernels/mlu/bridges/transpose_op.cc @@ -21,8 +21,30 @@ namespace lite { namespace subgraph { namespace mlu { -std::vector axis_to_nhwc4d(const std::vector& axis) { - CHECK_EQ(axis.size(), 4); +// std::vector axis_to_nhwc4d(const std::vector& axis) { +// CHECK_EQ(axis.size(), 4); +// std::vector new_axis(4, 0); +// const std::vector axis_map1 = {0, 2, 3, 1}; +// const std::vector axis_map2 = {0, 3, 1, 2}; +// for (size_t i = 0; i < new_axis.size(); ++i) { +// new_axis[i] = axis_map2[axis[axis_map1[i]]]; +// } +// return new_axis; +//} +// +// std::vector axis_to_nhw3d(const std::vector& axis) { +// CHECK_EQ(axis.size(), 3); +// std::vector new_axis(3, 0); +// const std::vector axis_map = {0, 2, 1}; +// for (size_t i = 0; i < new_axis.size(); ++i) { +// new_axis[i] = axis_map[axis[axis_map[i]]]; +// } +// new_axis.push_back(3); +// return new_axis; +//} + +std::vector axis_to_nhwc(const std::vector& axis) { + CHECK_EQ(axis.size(), 4) << "Unsupport dim in mlu transpose"; std::vector new_axis(4, 0); const std::vector axis_map1 = {0, 2, 3, 1}; const std::vector axis_map2 = {0, 3, 1, 2}; @@ -32,26 +54,6 @@ std::vector axis_to_nhwc4d(const std::vector& axis) { return new_axis; } -std::vector 
axis_to_nhw3d(const std::vector& axis) { - CHECK_EQ(axis.size(), 3); - std::vector new_axis(3, 0); - const std::vector axis_map = {0, 2, 1}; - for (size_t i = 0; i < new_axis.size(); ++i) { - new_axis[i] = axis_map[axis[axis_map[i]]]; - } - new_axis.push_back(3); - return new_axis; -} - -std::vector infer_shape(const std::vector& x_dims, - const std::vector& axis_nhwc) { - std::vector out_dims(x_dims); - for (size_t i = 0; i < out_dims.size(); ++i) { - out_dims[i] = x_dims[axis_nhwc[i]]; - } - return out_dims; -} - int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { CHECK(ctx != nullptr); CHECK(op != nullptr); @@ -71,21 +73,13 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto output_dims = output->dims().Vectorize(); auto axis = op_info->GetAttr>("axis"); - - std::vector axis_nhwc; - if (axis.size() == 4) { - axis_nhwc = axis_to_nhwc4d(axis); - } else if (axis.size() == 3) { - axis_nhwc = axis_to_nhw3d(axis); - } else { - CHECK(0) << "Unsupport dim in mlu transpose"; + while (axis.size() < 4) { + axis.push_back(axis.size()); } - - auto output_dims_nhwc = infer_shape(x_dims, axis_nhwc); - output->Resize(output_dims_nhwc); + std::vector axis_nhwc = axis_to_nhwc(axis); auto output_tensor = graph->AddNode( - out_var_name, output_dims_nhwc, CNML_TENSOR, CNML_NHWC, graph->FPType()); + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); CHECK(graph->HasNode(x_var_name)); auto input_tensor = graph->GetNode(x_var_name); @@ -113,7 +107,6 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) { REGISTER_SUBGRAPH_BRIDGE(transpose, kMLU, paddle::lite::subgraph::mlu::TransposeConverter); - REGISTER_SUBGRAPH_BRIDGE(transpose2, kMLU, paddle::lite::subgraph::mlu::TransposeConverter); diff --git a/lite/kernels/mlu/layout_compute.h b/lite/kernels/mlu/layout_compute.h index 2d355e5ddf590b55c22a103d3d2a24ad4357da4c..5e87e3526417573f2e0f01280b1d86ccb5691093 100644 --- a/lite/kernels/mlu/layout_compute.h +++ 
b/lite/kernels/mlu/layout_compute.h @@ -67,6 +67,8 @@ class LayoutNchwToNhwcCompute auto x_dims = param.x->dims().size(); auto& context = this->ctx_->template As(); + const auto origin_dims = out->dims().Vectorize(); + std::vector axis; switch (x_dims) { case 2: @@ -88,6 +90,10 @@ class LayoutNchwToNhwcCompute LayoutTransCompute( x_dims, context, *x, out, axis); + + if (x_dims > 2) { + out->Resize(origin_dims); + } } std::string doc() const override { @@ -109,20 +115,22 @@ class LayoutNhwcToNchwCompute auto x_dims = param.x->dims().size(); auto& context = this->ctx_->template As(); + const auto origin_dims = out->dims().Vectorize(); + std::vector axis; switch (x_dims) { case 2: axis = {0, 1}; break; case 3: - axis = {0, 2, 1}; out->Resize(std::vector{ out->dims()[0], out->dims()[2], out->dims()[1]}); + axis = {0, 2, 1}; break; case 4: - axis = {0, 3, 1, 2}; out->Resize(std::vector{ out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]}); + axis = {0, 3, 1, 2}; break; default: CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw"; @@ -130,6 +138,10 @@ class LayoutNhwcToNchwCompute LayoutTransCompute( x_dims, context, *x, out, axis); + + if (x_dims > 2) { + out->Resize(origin_dims); + } } std::string doc() const override { diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h index 0e79e54eb2888fa9c2d6867d16de81c2f334af29..51a9c0ffe05232bd807017e79c490d947e26c0f7 100644 --- a/lite/kernels/mlu/subgraph_compute.h +++ b/lite/kernels/mlu/subgraph_compute.h @@ -83,7 +83,7 @@ class SubgraphEngine : public subgraph::Engine { graph_.AddNode(input_name, input_tensor->dims().Vectorize(), CNML_TENSOR, - CNML_NHWC, + CNML_NCHW, graph_.FPType(), const_cast(input_tensor->raw_data())); CHECK(input_node); @@ -99,9 +99,7 @@ class SubgraphEngine : public subgraph::Engine { CHECK(op); std::string op_type = op->op_info()->Type(); op->CheckShape(); - if (op_type != "concat") { - op->InferShape(); - } + op->InferShape(); if 
(!bridges.Exists(op_type, TARGET(kMLU))) { LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type; return subgraph::FAILED;