diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index 00f544c90b4c2580aaa3c20493b4bed2dbeab0ba..93dcff45c440a1da28933db1409182373d90671a 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -17,6 +17,7 @@ lite_cc_library(subgraph_bridge_softmax_op_mlu SRCS softmax_op.cc DEPS ${subgrap
 lite_cc_library(subgraph_bridge_fc_op_mlu SRCS fc_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_scale_op_mlu SRCS scale_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_interp_op_mlu SRCS interpolate_op.cc DEPS ${subgraph_bridge_deps_mlu})
+lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_bridge_deps_mlu})
 set(mlu_subgraph_bridges
         subgraph_bridge_registry
         subgraph_bridge_utility_mlu
@@ -30,6 +31,7 @@ set(mlu_subgraph_bridges
         subgraph_bridge_batch_norm_op_mlu
         subgraph_bridge_scale_op_mlu
         subgraph_bridge_interp_op_mlu
+        subgraph_bridge_concat_op_mlu
         CACHE INTERNAL "mlu_subgraph_bridges")
 
 lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
@@ -42,5 +44,6 @@ lite_cc_test(test_softmax_converter_mlu SRCS softmax_op_test.cc DEPS scope optim
 lite_cc_test(test_fc_converter_mlu SRCS fc_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_scale_converter_mlu SRCS scale_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 
 message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
diff --git a/lite/kernels/mlu/bridges/concat_op.cc b/lite/kernels/mlu/bridges/concat_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..075529c3e8c9897d2dfc6f86815af5aa2d2cb114
--- /dev/null
+++ b/lite/kernels/mlu/bridges/concat_op.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  auto x_var_name = op_info->Input("X");
+  auto out_var_name = op_info->Output("Out").front();
+  auto param_axis = op_info->GetAttr<int>("axis");
+
+  auto input_num = x_var_name.size();
+
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+  auto output_tensor = graph->AddNode(
+      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+
+  int axis = (param_axis < 0) ? (param_axis + output_dims.size()) : param_axis;
+
+  std::vector<cnmlTensor_t> input_tensor;
+  for (auto x_name : x_var_name) {
+    CHECK(graph->HasNode(x_name));
+    input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor());
+  }
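+  // The MLU graph keeps tensors in NHWC layout while Paddle dims are NCHW,
+  // so remap the (already normalized) NCHW concat axis to its NHWC position:
+  // e.g. the NCHW channel axis 1 becomes axis 3 in NHWC.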
+ +#include "lite/kernels/mlu/bridges/graph.h" +#include "lite/kernels/mlu/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace mlu { + +int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto scope = op->scope(); + VLOG(3) << "[MLU] Converting " + op_type + "..."; + + auto x_var_name = op_info->Input("X"); + auto out_var_name = op_info->Output("Out").front(); + auto param_axis = op_info->GetAttr("axis"); + // auto x = scope->FindVar(x_var_name[0])->GetMutable(); + + auto input_num = x_var_name.size(); + + auto output = scope->FindVar(out_var_name)->GetMutable(); + auto output_dims = output->dims().Vectorize(); + auto output_tensor = graph->AddNode( + out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType()); + + int axis = (param_axis < 0) ? (param_axis + output_dims.size()) : param_axis; + + std::vector input_tensor; + for (auto x_name : x_var_name) { + CHECK(graph->HasNode(x_name)); + input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor()); + } + int nchw_to_nhwc_aixs_map[4] = {0, 3, 1, 2}; + int nhwc_axis = nchw_to_nhwc_aixs_map[axis]; + + cnmlBaseOp_t concat_op; + auto output_t = output_tensor->mlu_tensor(); + CNML_CALL(cnmlCreateNdConcatOp( + &concat_op, nhwc_axis, input_tensor.data(), input_num, &output_t, 1)); + graph->FuseOp(concat_op); + return SUCCESS; +} + +} // namespace mlu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(concat, + kMLU, + paddle::lite::subgraph::mlu::ConcatConverter); diff --git a/lite/kernels/mlu/bridges/concat_op_test.cc b/lite/kernels/mlu/bridges/concat_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b75ebc0951af91799f40ef8782a5bca43d0a87e1 --- /dev/null +++ b/lite/kernels/mlu/bridges/concat_op_test.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/concat_op.h" +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/mlu/bridges/test_helper.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace mlu { + +void concat_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = op_info->Input("X"); + std::vector inputs; + for (auto var : x) { + inputs.push_back(scope->FindVar(var)->GetMutable()); + } + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + int axis = op_info->GetAttr("axis"); + std::vector inputs_concat(inputs.size()); + for (int j = 0; j < inputs.size(); ++j) { + inputs_concat[j] = inputs[j]; + } + size_t num = inputs.size(); + int rows = 1; + auto dim_0 = inputs[0]->dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + std::vector inputs_cols(inputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = inputs[i]->numel() / rows; + out_cols += t_cols; + inputs_cols[i] = t_cols; + } + for (int k = 0; k < out_rows; ++k) { + float* dst_ptr = out->mutable_data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = inputs_cols[j]; + const float* src_prt = inputs[j]->data() + k * col_len; + std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len); + col_idx += col_len; + } + } +} + +void test_concat(std::vector> input, int axis) { + std::string x_var_name = "x"; + std::string y_var_name = "y"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + + // prepare input&output variables + Scope scope; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* y = scope.Var(y_var_name)->GetMutable(); + x->Resize(DDim(input[0])); + y->Resize(DDim(input[1])); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + CHECK_EQ(out->dims(), out_ref->dims()); + + // initialize input&output data + FillTensor(x); + FillTensor(y); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("concat"); + opdesc.SetInput("X", {x_var_name, y_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("axis", axis); + + auto op = CreateOp(opdesc, &scope); + concat_ref(op); + out_ref->CopyDataFrom(*out); + + Tensor input_x, input_y; + input_x.Resize(DDim(input[0])); + input_y.Resize(DDim(input[1])); + transpose(x->mutable_data(), + input_x.mutable_data(), + {static_cast(input[0][0]), + static_cast(input[0][1]), + static_cast(input[0][2]), + static_cast(input[0][3])}, + {0, 2, 3, 1}); + transpose(y->mutable_data(), + input_y.mutable_data(), + {static_cast(input[1][0]), + static_cast(input[1][1]), + static_cast(input[1][2]), + static_cast(input[1][3])}, + {0, 2, 3, 1}); + auto os = out->dims(); + out->Resize({static_cast(os[0]), + static_cast(os[2]), + static_cast(os[3]), + static_cast(os[1])}); + x->CopyDataFrom(input_x); + y->CopyDataFrom(input_y); + x->Resize({static_cast(input[0][0]), + static_cast(input[0][2]), + static_cast(input[0][3]), + static_cast(input[0][1])}); + y->Resize({static_cast(input[1][0]), + static_cast(input[1][2]), + static_cast(input[1][3]), + static_cast(input[1][1])}); + + LaunchOp(op, {x_var_name, y_var_name}, {out_var_name}); + + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + + Tensor output_trans; + output_trans.Resize(out->dims()); + transpose(out_data, + 
+  auto* out_data = out->mutable_data<float>();
+  auto* out_ref_data = out_ref->mutable_data<float>();
+
+  Tensor output_trans;
+  output_trans.Resize(out->dims());
+  transpose(out_data,
+            output_trans.mutable_data<float>(),
+            {static_cast<int>(os[0]),
+             static_cast<int>(os[2]),
+             static_cast<int>(os[3]),
+             static_cast<int>(os[1])},
+            {0, 3, 1, 2});
+  out_data = output_trans.mutable_data<float>();
+
+  for (int i = 0; i < out->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
+  }
+}
+
+// one case per concat axis: batch, channel, height, and width
+TEST(MLUBridges, concat) {
+  test_concat({{3, 3, 5, 2}, {2, 3, 5, 2}}, 0);
+  test_concat({{3, 5, 5, 2}, {3, 1, 5, 2}}, 1);
+  test_concat({{3, 3, 2, 2}, {3, 3, 4, 2}}, 2);
+  test_concat({{3, 3, 5, 2}, {3, 3, 5, 6}}, 3);
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(concat, kMLU);