diff --git a/.github/workflows/github-CI.yml b/.github/workflows/github-CI.yml
index b277d007bd039ecd4509c6c28ce851986d591157..a78a79ef93da8f66bed53d5e63f341a5721a6763 100644
--- a/.github/workflows/github-CI.yml
+++ b/.github/workflows/github-CI.yml
@@ -41,6 +41,12 @@ jobs:
         run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_softmax_converter_mlu
       - name: test_transpose_converter_mlu
         run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_transpose_converter_mlu
+      - name: test_slice_converter_mlu
+        run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_slice_converter_mlu
+      - name: test_argmax_converter_mlu
+        run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_argmax_converter_mlu
+      - name: test_split_converter_mlu
+        run: ./build.lite.mlu/lite/kernels/mlu/bridges/test_split_converter_mlu
       - name: test_classification
         run: |
           cd ..
diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index 611bb3bb9a67b1afe10de224f4865f210d939ded..e1fc2f91e7d448015c0cf3ff3d033de0a2b2f473 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -21,6 +21,7 @@ lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_
 lite_cc_library(subgraph_bridge_transpose_op_mlu SRCS transpose_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
+lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
 set(mlu_subgraph_bridges
     subgraph_bridge_registry
@@ -39,6 +40,7 @@ set(mlu_subgraph_bridges
     subgraph_bridge_concat_op_mlu
     subgraph_bridge_dropout_op_mlu
     subgraph_bridge_slice_op_mlu
+    subgraph_bridge_split_op_mlu
     subgraph_bridge_argmax_op_mlu
     CACHE INTERNAL "mlu_subgraph_bridges")
 
@@ -62,6 +64,7 @@ lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimiz
 lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 if (LITE_BUILD_EXTRA)
   lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h
index 8a6ecf64e574c639c4353bd85c5ebc3b63f677ae..8c296b04ce253a594defa880df9b833b35514313 100644
--- a/lite/kernels/mlu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h
@@ -32,6 +32,7 @@ USE_SUBGRAPH_BRIDGE(sigmoid, kMLU);
 USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
 USE_SUBGRAPH_BRIDGE(dropout, kMLU);
 USE_SUBGRAPH_BRIDGE(argmax, kMLU);
+USE_SUBGRAPH_BRIDGE(split, kMLU);
 #ifdef LITE_BUILD_EXTRA
 USE_SUBGRAPH_BRIDGE(lrn, kMLU)
 #endif
diff --git a/lite/kernels/mlu/bridges/split_op.cc b/lite/kernels/mlu/bridges/split_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aeee1a3669245beb188bacead03e1be73c5e8c68
--- /dev/null
+++ b/lite/kernels/mlu/bridges/split_op.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  auto x_var_name = op_info->Input("X").front();
+  auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
+  auto x_dims = x->dims().Vectorize();
+
+  auto out_var_name = op_info->Output("Out");
+
+  auto param_axis = op_info->GetAttr<int>("axis");
+
+  auto num = op_info->GetAttr<int>("num");
+  auto sections = op_info->GetAttr<std::vector<int>>("sections");
+  int64_t sections_num = static_cast<int64_t>(sections.size());
+  auto output_num = num > 0 ? num : sections_num;
+
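+  // Create one CNML tensor node per Paddle output variable; the handles are
+  // gathered below so cnmlCreateNdSplitOp can write to every output.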
+  std::vector<cnmlTensor_t> output_tensor;
+  for (auto out_name : out_var_name) {
+    auto out = scope->FindVar(out_name)->GetMutable<Tensor>();
+    auto out_dims = out->dims().Vectorize();
+    auto out_tensor = graph->AddNode(
+        out_name, out_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
+    output_tensor.push_back(out_tensor->mlu_tensor());
+  }
+
+  auto dims = x_dims.size();
+  int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
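+  // The MLU executes with NHWC data layout while Paddle tensors are NCHW,
+  // so remap the split axis (e.g. NCHW axis 1 becomes NHWC axis 3).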
+  CHECK_LT(axis, 4) << "Unsupported axis in MLU split";
+  int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
+  int nhwc_axis = nchw_to_nhwc_axis_map[axis];
+
+  CHECK(graph->HasNode(x_var_name));
+  auto input_tensor = graph->GetNode(x_var_name);
+
+  cnmlBaseOp_t split_op;
+  cnmlTensor_t inputs = input_tensor->mlu_tensor();
+  CNML_CALL(cnmlCreateNdSplitOp(
+      &split_op, nhwc_axis, &inputs, 1, output_tensor.data(), output_num));
+  graph->FuseOp(split_op);
+  CNML_CALL(cnmlDestroyBaseOp(&split_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(split,
+                         kMLU,
+                         paddle::lite::subgraph::mlu::SplitConverter);
diff --git a/lite/kernels/mlu/bridges/split_op_test.cc b/lite/kernels/mlu/bridges/split_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..18bd74ea94ec46a7c01d39cab883ec325cfc0fd0
--- /dev/null
+++ b/lite/kernels/mlu/bridges/split_op_test.cc
@@ -0,0 +1,199 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/split_op.h"
+#include <gtest/gtest.h>
+#include <cstring>
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+template <typename dtype>
+void split_ref(const std::shared_ptr<operators::SplitOp> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  int num = op_info->GetAttr<int>("num");
+  int axis = op_info->GetAttr<int>("axis");
+  std::vector<int> sections = op_info->GetAttr<std::vector<int>>("sections");
+  std::vector<Tensor*> output_vec;
+  auto output = op_info->Output("Out");
+  for (auto out_var : output) {
+    output_vec.push_back(scope->Var(out_var)->GetMutable<Tensor>());
+  }
+  auto in_dims = x->dims();
+  auto rank = in_dims.size();
+  int outs_number = output_vec.size();
+  std::vector<DDim> outs_dims;
+  outs_dims.reserve(outs_number);
+  if (axis < 0) {
+    axis += rank;
+  }
+  if (num > 0) {
+    int out_axis_dim = in_dims[axis] / num;
+    for (int i = 0; i < outs_number; ++i) {
+      auto dim = in_dims;
+      dim[axis] = out_axis_dim;
+      outs_dims.push_back(dim);
+    }
+  } else if (sections.size() > 0) {
+    for (int i = 0; i < outs_number; ++i) {
+      auto dim = in_dims;
+      dim[axis] = sections[i];
+      outs_dims.push_back(dim);
+    }
+  }
+  for (size_t j = 0; j < outs_dims.size(); ++j) {
+    output_vec[j]->Resize(outs_dims[j]);
+  }
+
+  const dtype* din = x->mutable_data<dtype>();
+  std::vector<int64_t> in_strides(in_dims.size());
+  in_strides[in_dims.size() - 1] = in_dims[in_dims.size() - 1];
+  for (int i = in_dims.size() - 2; i >= 0; --i) {
+    in_strides[i] = in_strides[i + 1] * in_dims[i];
+  }
+
+  int input_offset = 0;
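+  // Strides are suffix products of the dims: the input is viewed as `before`
+  // contiguous blocks along the split axis, and each output receives one
+  // chunk of `out_after` elements per block.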
+  for (auto out : output_vec) {
+    auto out_dim = out->dims();
+    std::vector<int64_t> out_strides(out_dim.size());
+    out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
+    for (int i = out_dim.size() - 2; i >= 0; --i) {
+      out_strides[i] = out_strides[i + 1] * out_dim[i];
+    }
+
+    dtype* out_data = out->mutable_data<dtype>();
+    int before = out_strides[0] / out_strides[axis];
+    int in_after = in_strides[axis];
+    int out_after = out_strides[axis];
+
+    for (int i = 0; i < before; ++i) {
+      std::memcpy(out_data + i * out_after,
+                  din + input_offset + i * in_after,
+                  sizeof(dtype) * out_after);
+    }
+    input_offset += out_strides[axis];
+  }
+}
+
+void test_split(int bs,
+                int ic,
+                int ih,
+                int iw,
+                int axis,
+                int num,
+                std::vector<int> sections) {
+  // prepare input&output variables
+  std::string x_var_name = "x";
+  std::string out_var_name_1 = "out_1";
+  std::string out_var_name_2 = "out_2";
+  std::string out_ref_var_name_1 = "out_ref_1";
+  std::string out_ref_var_name_2 = "out_ref_2";
+
+  Scope scope;
+  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
+  auto* out_1 = scope.Var(out_var_name_1)->GetMutable<Tensor>();
+  auto* out_2 = scope.Var(out_var_name_2)->GetMutable<Tensor>();
+  auto* out_ref_1 = scope.Var(out_ref_var_name_1)->GetMutable<Tensor>();
+  auto* out_ref_2 = scope.Var(out_ref_var_name_2)->GetMutable<Tensor>();
+  x->Resize({bs, ic, ih, iw});
+  // initialize input&output data
+  FillTensor<float>(x);
+
+  // initialize op desc
+  cpp::OpDesc opdesc;
+  opdesc.SetType("split");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetOutput("Out", {out_var_name_1, out_var_name_2});
+  opdesc.SetAttr("axis", axis);
+  opdesc.SetAttr("sections", sections);
+  opdesc.SetAttr("num", num);
+
+  auto op = CreateOp<operators::SplitOp>(opdesc, &scope);
+  // execute the reference implementation and save its results to the
+  // reference tensors before the MLU run overwrites out_1/out_2
+  split_ref<float>(op);
+  out_ref_1->CopyDataFrom(*out_1);
+  out_ref_2->CopyDataFrom(*out_2);
+
+  // transpose the input to the NHWC layout the MLU kernel expects
+  Tensor input;
+  input.Resize({bs, ic, ih, iw});
+  transpose(x->mutable_data<float>(),
+            input.mutable_data<float>(),
+            {static_cast<int>(bs),
+             static_cast<int>(ic),
+             static_cast<int>(ih),
+             static_cast<int>(iw)},
+            {0, 2, 3, 1});
+  x->CopyDataFrom(input);
+  LaunchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2});
+
+  // transpose the MLU outputs back to NCHW and compare with the reference
+  auto* out_data_1 = out_1->mutable_data<float>();
+  auto* out_data_2 = out_2->mutable_data<float>();
+  auto* out_ref_data_1 = out_ref_1->mutable_data<float>();
+  auto* out_ref_data_2 = out_ref_2->mutable_data<float>();
+
+  Tensor output1, output2;
+  output1.Resize(out_1->dims());
+  output2.Resize(out_2->dims());
+  transpose(out_data_1,
+            output1.mutable_data<float>(),
+            {static_cast<int>(out_1->dims()[0]),
+             static_cast<int>(out_1->dims()[2]),
+             static_cast<int>(out_1->dims()[3]),
+             static_cast<int>(out_1->dims()[1])},
+            {0, 3, 1, 2});
+  transpose(out_data_2,
+            output2.mutable_data<float>(),
+            {static_cast<int>(out_2->dims()[0]),
+             static_cast<int>(out_2->dims()[2]),
+             static_cast<int>(out_2->dims()[3]),
+             static_cast<int>(out_2->dims()[1])},
+            {0, 3, 1, 2});
+  out_data_1 = output1.mutable_data<float>();
+  out_data_2 = output2.mutable_data<float>();
+  for (int i = 0; i < out_1->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data_1[i], out_ref_data_1[i], 5e-4);
+  }
+  for (int i = 0; i < out_2->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data_2[i], out_ref_data_2[i], 5e-4);
+  }
+}
+
+TEST(MLUBridges, split) {
+  test_split(4, 2, 3, 1, 0, 2, {});
+  test_split(4, 2, 3, 1, 0, 0, {3, 1});
+  test_split(4, 6, 3, 1, 1, 2, {});
+  test_split(4, 6, 3, 1, 1, 0, {2, 4});
+  test_split(4, 2, 2, 1, 2, 2, {});
+  test_split(4, 2, 6, 1, 2, 0, {3, 3});
+  test_split(4, 2, 3, 4, 3, 2, {});
+  test_split(4, 2, 3, 6, 3, 0, {5, 1});
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(split, kMLU);