From 30c273de98a95508a2b8ab6c34b92973a5368da8 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 22 Aug 2019 11:17:42 +0800 Subject: [PATCH] port lite code (#1819) --- CMakeLists.txt | 1 - lite/api/CMakeLists.txt | 7 + lite/api/benchmark.cc | 188 ++++++++ lite/api/paddle_use_kernels.h | 12 +- lite/api/paddle_use_ops.h | 6 + lite/arm/math/activation.cc | 34 ++ lite/arm/math/activation.h | 4 + lite/core/mir/graph_visualize_pass.cc | 29 +- lite/core/mir/subgraph/CMakeLists.txt | 2 + .../mir/subgraph/generate_npu_program_pass.cc | 288 +++++------ .../mir/subgraph/generate_npu_program_pass.h | 37 +- .../mir/subgraph/subgraph_program_pass.cc | 73 ++- .../core/mir/subgraph/subgraph_program_pass.h | 10 +- .../subgraph/subgraph_program_pass_test.cc | 83 +++- lite/kernels/arm/CMakeLists.txt | 6 + lite/kernels/arm/activation_compute.cc | 17 +- lite/kernels/arm/activation_compute.h | 10 + lite/kernels/arm/conv_transpose_compute.cc | 8 +- lite/kernels/arm/expand_compute.cc | 72 +++ lite/kernels/arm/expand_compute.h | 34 ++ lite/kernels/arm/matmul_compute.cc | 316 ++++++++++++ lite/kernels/arm/matmul_compute.h | 42 ++ lite/kernels/arm/multiclass_nms_compute.cc | 17 +- .../arm/multiclass_nms_compute_test.cc | 2 + lite/kernels/arm/squeeze_compute.cc | 70 +++ lite/kernels/arm/squeeze_compute.h | 42 ++ lite/kernels/npu/graph_compute.cc | 8 + lite/model_parser/model_parser.cc | 173 +++++-- lite/model_parser/model_parser.h | 14 +- lite/model_parser/naive_buffer/CMakeLists.txt | 7 +- .../naive_buffer/combined_params_desc.cc | 15 + .../naive_buffer/combined_params_desc.h | 63 +++ .../naive_buffer/naive_buffer_wrapper_test.cc | 89 +++- lite/model_parser/naive_buffer/param_desc.cc | 10 + lite/model_parser/naive_buffer/param_desc.h | 4 + .../naive_buffer/proto/framework.nb.h | 3 + lite/npu/bridge/CMakeLists.txt | 18 + lite/npu/bridge/act_op.cc | 5 +- lite/npu/bridge/bilinear_interp_op.cc | 121 +++++ lite/npu/bridge/bilinear_interp_op_test.cc | 314 ++++++++++++ lite/npu/bridge/concat_op.cc | 74 +++ lite/npu/bridge/concat_op_test.cc | 128 +++++ lite/npu/bridge/conv_op.cc | 52 +- lite/npu/bridge/conv_transpose_op.cc | 146 ++++++ lite/npu/bridge/conv_transpose_op_test.cc | 369 ++++++++++++++ lite/npu/bridge/paddle_use_npu_bridges.h | 5 + lite/npu/bridge/reshape_op.cc | 121 +++++ lite/npu/bridge/reshape_op_test.cc | 202 ++++++++ lite/npu/bridge/scale_op.cc | 9 +- lite/npu/bridge/shuffle_channel_op.cc | 58 +++ lite/npu/bridge/shuffle_channel_op_test.cc | 115 +++++ lite/npu/bridge/split_op.cc | 86 ++++ lite/npu/bridge/split_op_test.cc | 170 +++++++ lite/npu/bridge/transpose_op_test.cc | 19 +- lite/npu/bridge/utils.h | 32 +- lite/operators/CMakeLists.txt | 6 + lite/operators/activation_ops.cc | 1 + lite/operators/conv_transpose_op.cc | 14 +- lite/operators/expand_op.cc | 57 +++ lite/operators/expand_op.h | 44 ++ lite/operators/interpolate_op.cc | 9 +- lite/operators/matmul_op.cc | 138 ++++++ lite/operators/matmul_op.h | 50 ++ lite/operators/op_params.h | 26 + lite/operators/squeeze_op.cc | 133 ++++++ lite/operators/squeeze_op.h | 61 +++ lite/tests/kernels/CMakeLists.txt | 3 + lite/tests/kernels/activation_compute_test.cc | 78 ++- .../kernels/conv2d_transpose_compute_test.cc | 10 +- lite/tests/kernels/expand_compute_test.cc | 135 ++++++ lite/tests/kernels/matmul_compute_test.cc | 452 ++++++++++++++++++ lite/tests/kernels/shape_compute_test.cc | 3 +- lite/tests/kernels/squeeze_compute_test.cc | 253 ++++++++++ lite/tools/benchmark.sh | 36 ++ lite/tools/ci_build.sh | 4 + lite/tools/search_support_ops.py | 66 +++ lite/utils/io.h | 5 +- 77 files changed, 5079 insertions(+), 345 deletions(-) create mode 100644 lite/api/benchmark.cc create mode 100644 lite/kernels/arm/expand_compute.cc create mode 100644 lite/kernels/arm/expand_compute.h create mode 100644 lite/kernels/arm/matmul_compute.cc create mode 100644 lite/kernels/arm/matmul_compute.h create mode 100644 lite/kernels/arm/squeeze_compute.cc create mode 100644 lite/kernels/arm/squeeze_compute.h create mode 100644 lite/model_parser/naive_buffer/combined_params_desc.cc create mode 100644 lite/model_parser/naive_buffer/combined_params_desc.h create mode 100644 lite/npu/bridge/bilinear_interp_op.cc create mode 100644 lite/npu/bridge/bilinear_interp_op_test.cc create mode 100644 lite/npu/bridge/concat_op.cc create mode 100644 lite/npu/bridge/concat_op_test.cc create mode 100644 lite/npu/bridge/conv_transpose_op.cc create mode 100644 lite/npu/bridge/conv_transpose_op_test.cc create mode 100644 lite/npu/bridge/reshape_op.cc create mode 100644 lite/npu/bridge/reshape_op_test.cc create mode 100644 lite/npu/bridge/shuffle_channel_op.cc create mode 100644 lite/npu/bridge/shuffle_channel_op_test.cc create mode 100644 lite/npu/bridge/split_op.cc create mode 100644 lite/npu/bridge/split_op_test.cc create mode 100644 lite/operators/expand_op.cc create mode 100644 lite/operators/expand_op.h create mode 100644 lite/operators/matmul_op.cc create mode 100644 lite/operators/matmul_op.h create mode 100644 lite/operators/squeeze_op.cc create mode 100644 lite/operators/squeeze_op.h create mode 100644 lite/tests/kernels/expand_compute_test.cc create mode 100644 lite/tests/kernels/matmul_compute_test.cc create mode 100644 lite/tests/kernels/squeeze_compute_test.cc create mode 100644 lite/tools/benchmark.sh create mode 100644 lite/tools/search_support_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a829ec6bf6..bdc212835a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,7 +173,6 @@ include(ccache) # set ccache for compilation include(util) # set unittest and link libs include(version) # set PADDLE_VERSION - set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 7dc74e1570..5212d7a4ca 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -211,6 +211,13 @@ if(NOT IOS) CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} X86_DEPS ${x86_kernels}) + lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags + ${ops} + ARM_DEPS ${arm_kernels} + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + X86_DEPS ${x86_kernels}) endif() #lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc new file mode 100644 index 0000000000..42f89e7e66 --- /dev/null +++ b/lite/api/benchmark.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/cpu_info.h" +#include "lite/utils/cp_logging.h" +#include "lite/utils/string.h" + +DEFINE_string(input_shape, + "1,3,224,224", + "input shapes, separated by colon and comma"); +DEFINE_string(result_filename, "", "save test result"); + +namespace paddle { +namespace lite_api { + +void OutputOptModel(const std::string& load_model_dir, + const std::string& save_optimized_model_dir, + const std::vector>& input_shapes) { + lite_api::CxxConfig config; + config.set_model_dir(load_model_dir); + config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)}); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + auto predictor = lite_api::CreatePaddlePredictor(config); + + // delete old optimized model + int ret = system( + paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str()) + .c_str()); + if (ret == 0) { + LOG(INFO) << "delete old optimized model " << save_optimized_model_dir; + } + predictor->SaveOptimizedModel(save_optimized_model_dir, + LiteModelType::kNaiveBuffer); + LOG(INFO) << "Load model from " << load_model_dir; + LOG(INFO) << "Save optimized model to " << save_optimized_model_dir; +} + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +void Run(const std::vector>& input_shapes, + const std::string& model_dir, + const int repeat, + const int thread_num, + const int warmup_times, + const std::string model_name) { +#ifdef LITE_WITH_ARM + lite::DeviceInfo::Init(); + if (thread_num == 1) { + lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_HIGH, thread_num); + LOG(INFO) << "LITE_POWER_HIGH"; + } else { + lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_NO_BIND, thread_num); + LOG(INFO) << "LITE_POWER_NO_BIND"; + } +#endif + lite_api::MobileConfig config; + config.set_model_dir(model_dir); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + for (int j = 0; j < input_shapes.size(); ++j) { + auto input_tensor = predictor->GetInput(j); + input_tensor->Resize(input_shapes[j]); + auto input_data = input_tensor->mutable_data(); + int input_num = 1; + for (int i = 0; i < input_shapes[j].size(); ++i) { + input_num *= input_shapes[j][i]; + } + for (int i = 0; i < input_num; ++i) { + input_data[i] = 1.f; + } + } + + for (int i = 0; i < warmup_times; ++i) { + predictor->Run(); + } + + auto start = lite::GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + predictor->Run(); + } + auto end = lite::GetCurrentUS(); + + std::FILE* pf = std::fopen(FLAGS_result_filename.c_str(), "a"); + if (nullptr == pf) { + LOG(INFO) << "create result file error"; + exit(0); + } + fprintf(pf, + "-- %-18s avg = %5.4f ms\n", + model_name.c_str(), + (end - start) / repeat / 1000.0); + std::fclose(pf); +} +#endif + +} // namespace lite_api +} // namespace paddle + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_model_dir == "" || FLAGS_result_filename == "") { + LOG(INFO) << "usage: " + << "--model_dir /path/to/your/model --result_filename " + "/path/to/resultfile"; + exit(0); + } + + std::size_t found = FLAGS_model_dir.find_last_of("/"); + std::string model_name = FLAGS_model_dir.substr(found + 1); + std::string save_optimized_model_dir = FLAGS_model_dir + "opt2"; + + auto split_string = + [](const std::string& str_in) -> std::vector { + std::vector str_out; + std::string tmp_str = str_in; + while (!tmp_str.empty()) { + size_t next_offset = tmp_str.find(":"); + str_out.push_back(tmp_str.substr(0, next_offset)); + if (next_offset == std::string::npos) { + break; + } else { + tmp_str = tmp_str.substr(next_offset + 1); + } + } + return str_out; + }; + + auto get_shape = [](const std::string& str_shape) -> std::vector { + std::vector shape; + std::string tmp_str = str_shape; + while (!tmp_str.empty()) { + int dim = atoi(tmp_str.data()); + shape.push_back(dim); + size_t next_offset = tmp_str.find(","); + if (next_offset == std::string::npos) { + break; + } else { + tmp_str = tmp_str.substr(next_offset + 1); + } + } + return shape; + }; + + std::vector str_input_shapes = split_string(FLAGS_input_shape); + std::vector> input_shapes; + for (int i = 0; i < str_input_shapes.size(); ++i) { + input_shapes.push_back(get_shape(str_input_shapes[i])); + } + + // Output optimized model + paddle::lite_api::OutputOptModel( + FLAGS_model_dir, save_optimized_model_dir, input_shapes); + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + // Run inference using optimized model + paddle::lite_api::Run(input_shapes, + save_optimized_model_dir, + FLAGS_repeats, + FLAGS_threads, + FLAGS_warmup, + model_name); +#endif + return 0; +} diff --git a/lite/api/paddle_use_kernels.h b/lite/api/paddle_use_kernels.h index d18a86a8a7..f2fe0ce34f 100644 --- a/lite/api/paddle_use_kernels.h +++ b/lite/api/paddle_use_kernels.h @@ -21,16 +21,19 @@ #ifndef LITE_WITH_FPGA USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); -USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); -USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); #else USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def); USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); #endif +// host kernels +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); + #ifdef LITE_WITH_ARM USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(matmul, kARM, kFloat, kNCHW, def); // for x2paddle USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(lrn, kARM, kFloat, kNCHW, def); @@ -49,6 +52,7 @@ USE_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(relu6, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); @@ -64,6 +68,7 @@ USE_LITE_KERNEL(sigmoid, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(tanh, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(swish, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(log, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(exp, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(pad2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(prior_box, kARM, kFloat, kNCHW, def); @@ -91,6 +96,9 @@ USE_LITE_KERNEL(shape, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fill_constant, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(cast, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle +USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle +USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32); diff --git a/lite/api/paddle_use_ops.h b/lite/api/paddle_use_ops.h index d11afb358b..cf01c74adb 100644 --- a/lite/api/paddle_use_ops.h +++ b/lite/api/paddle_use_ops.h @@ -19,8 +19,10 @@ #include "paddle_lite_factory_helper.h" // NOLINT USE_LITE_OP(mul); +USE_LITE_OP(matmul); // for x2paddle USE_LITE_OP(fc); USE_LITE_OP(relu); +USE_LITE_OP(relu6); USE_LITE_OP(scale); USE_LITE_OP(feed); USE_LITE_OP(lrn); @@ -56,6 +58,7 @@ USE_LITE_OP(sigmoid) USE_LITE_OP(tanh) USE_LITE_OP(swish) USE_LITE_OP(log) +USE_LITE_OP(exp) USE_LITE_OP(conv2d_transpose) USE_LITE_OP(negative) USE_LITE_OP(pad2d) @@ -104,3 +107,6 @@ USE_LITE_OP(is_empty) USE_LITE_OP(shape) USE_LITE_OP(slice) USE_LITE_OP(cast) +USE_LITE_OP(squeeze) // for x2paddle +USE_LITE_OP(squeeze2) // for x2paddle +USE_LITE_OP(expand) // for x2paddle diff --git a/lite/arm/math/activation.cc b/lite/arm/math/activation.cc index b5df8e793c..009c778aee 100644 --- a/lite/arm/math/activation.cc +++ b/lite/arm/math/activation.cc @@ -632,6 +632,40 @@ void act_log(const float* din, float* dout, int size, int threads) { } } +template <> +void act_exp(const float* din, float* dout, int size, int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt_dim4 = nums_per_thread >> 2; + int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); + + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + float32x4_t exp_vec = vdupq_n_f32(0.0f); + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + for (int k = 0; k < neon_loop_cnt_dim4; ++k) { + exp_vec = exp_ps(vld1q_f32(ptr_in_thread)); + vst1q_f32(ptr_out_thread, exp_vec); + ptr_out_thread += 4; + ptr_in_thread += 4; + } + for (int j = 0; j < neon_loop_remain_dim4; ++j) { + ptr_out_thread[0] = expf(ptr_in_thread[0]); + ptr_in_thread++; + ptr_out_thread++; + } + } + float* ptr_out = dout + threads * nums_per_thread; + const float* ptr_in = din + threads * nums_per_thread; + for (int j = 0; j < remain; ++j) { + ptr_out[0] = expf(ptr_in[0]); + ptr_in++; + ptr_out++; + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/arm/math/activation.h b/lite/arm/math/activation.h index c22c963cc1..3a1530b8db 100644 --- a/lite/arm/math/activation.h +++ b/lite/arm/math/activation.h @@ -51,6 +51,10 @@ void act_swish(const T* din, T* dout, int size, float coef, int threads); template void act_log(const T* din, T* dout, int size, int threads); + +template +void act_exp(const T* din, T* dout, int size, int threads); + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index 1aa9ea77d0..74245ad11d 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "lite/core/mir/graph_visualize_pass.h" +#include #include #include #include +#include #include "lite/core/mir/pass_registry.h" #include "lite/utils/string.h" @@ -34,7 +36,15 @@ std::string Visualize(mir::SSAGraph* graph) { int id = 0; std::set exists_args; - + std::map graph_col; // Different colors of subgraphs + graph_col.insert({{1, "red"}, + {2, "green"}, + {3, "cyan"}, + {4, "bisque3"}, + {5, "coral"}, + {6, "darkseagreen1"}, + {7, "goldenrod1"}, + {8, "darkorchid"}}); for (auto& node : graph->mutable_nodes()) { std::string key; if (node.IsArg()) { @@ -44,7 +54,22 @@ std::string Visualize(mir::SSAGraph* graph) { } if (node.IsStmt()) { - dot.AddNode(key, {Dot::Attr("shape", "box")}); + auto& stmt = node.AsStmt(); + auto sub_id = stmt.subgraph_id(); + auto it = graph_col.find(sub_id); + if (sub_id > 0 && it != graph_col.end()) { + dot.AddNode(key, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", it->second)}); + } else { + dot.AddNode(key, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", "yellow")}); + } for (auto& x : node.inlinks) { auto name = x->AsArg().name; if (!exists_args.count(name)) { diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt index 4b4eb562c7..fa2af906b2 100644 --- a/lite/core/mir/subgraph/CMakeLists.txt +++ b/lite/core/mir/subgraph/CMakeLists.txt @@ -7,6 +7,7 @@ lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL) if (WITH_TESTING) add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz) + add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") endif() @@ -23,6 +24,7 @@ if(LITE_WITH_NPU) --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) if (WITH_TESTING) add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz) + add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") endif() diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index 798faf4544..d4370837c0 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -14,6 +14,7 @@ #include "lite/core/mir/subgraph/generate_npu_program_pass.h" #include +#include #include #include #include @@ -36,182 +37,143 @@ namespace lite { namespace mir { namespace subgraph { -// call convert function from start node -// return if convert success and the nodes to remove -// return the output npu op -lite::npu::bridge::node_map_type GenerateNPUProgramPass::CvtOpNodes( - const lite::npu::bridge::cvt_map_type& cvtfunc_map, - const Node* op_node, - const lite::npu::bridge::node_map_type& inputs_map, - int sub_id, - std::unordered_set* nodes2rm, - key2nodes_t* matched) { - lite::npu::bridge::node_map_type failed; - if (!op_node->IsStmt()) { - LOG(INFO) << "stop return failed"; - return failed; - } - auto* stmt = op_node->stmt(); - auto op_type = stmt->op_type(); - LOG(INFO) << "cvt op type: " << op_type; - - if (stmt->subgraph_id() != sub_id) { - LOG(INFO) << "return as subgraph_id(" << stmt->subgraph_id() - << ") != sub_id(" << sub_id << ")"; - return failed; - } else { - CHECK(cvtfunc_map.count(op_type)) << "Should be supported " << op_type - << ", with subgraph_id: " << sub_id; - } - - auto outputs_map = cvtfunc_map.at(op_type)(stmt->op(), inputs_map); - if (outputs_map.empty()) { - return outputs_map; - } - - nodes2rm->insert(op_node); - for (auto& var_node : op_node->outlinks) { - for (auto& next_op_node : var_node->outlinks) { - LOG(INFO) << "next op type: " << next_op_node->AsStmt().op_type(); - if (next_op_node->AsStmt().subgraph_id() != sub_id) { - // this is the end condition - // TODO(TJ): when enable more inputs and outputs this is bugy - LOG(INFO) << "--- should return once ---"; - // TODO(TJ): matched output could be vector - matched->insert(std::make_pair("Output", var_node)); - return outputs_map; - } else { - // LOG(INFO) << "argnames: "; - // for (auto sss : next_op_node->AsStmt().op_info()->input_argnames()) { - // LOG(INFO) << sss; - // } - // LOG(INFO) << "input argnames: "; - // for (auto sss : next_op_node->AsStmt().op_info()->input_names()) { - // LOG(INFO) << sss; - // } - for (auto& i_node : next_op_node->inlinks) { - CHECK(i_node->IsArg()); - auto& arg = i_node->AsArg(); - LOG(INFO) << arg.name; - if (outputs_map.count(arg.name)) continue; - if (!arg.is_weight) { - LOG(INFO) << "Data arg name:" << arg.name; - outputs_map.insert(std::make_pair( - arg.name, - lite::npu::bridge::CvtNode( - i_node, next_op_node->AsStmt().op()->scope()))); - } - } - nodes2rm->insert(var_node); - return CvtOpNodes( - cvtfunc_map, next_op_node, outputs_map, sub_id, nodes2rm, matched); - } +void GenerateNPUProgramPass::NPUSortHelper( + Node* node, + const std::unordered_set& nodes_all, + std::unordered_set* visited_nodes, + std::vector* ret) { + for (auto& var_node : node->inlinks) { + if (var_node->inlinks.empty()) continue; + auto* op_node = var_node->inlinks.front(); + if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) { + NPUSortHelper(op_node, nodes_all, visited_nodes, ret); } } + ret->push_back(node); + visited_nodes->insert(node); } -void GenerateNPUProgramPass::ConvertSubgraph( - const std::unique_ptr& graph, int sub_num) { +void GenerateNPUProgramPass::CvtOpNodes( + const std::vector& nodes2cvt, + std::vector* in_vars_name, + std::vector* out_vars_name, + lite::npu::bridge::node_map_type* cvted_vars, + std::unordered_set* nodes2rm) { const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& cvtfunc_map = bridges.AllFunctions(); - std::unordered_set nodes2rm_all; - - auto items = graph->StmtTopologicalOrder(); - for (int id = 1; id <= sub_num; ++id) { - LOG(INFO) << "Converting subgraph_id:" << id; - for (auto& op_node : items) { - std::unordered_set nodes2rm; - if (!op_node->IsStmt()) continue; - auto& stmt = op_node->AsStmt(); - if (stmt.subgraph_id() != id) continue; - CHECK(bridges.HasType(stmt.op_type())); - key2nodes_t matched; - matched["target_op"] = op_node; - auto& op = stmt.op(); - auto* scope = op->scope(); - // prepare inputs data. - std::string data_name = "data_subgraph_" + std::to_string(id); - lite::npu::bridge::node_map_type npu_inputs_map; - int name_id = 0; - LOG(INFO) << "op_type: " << stmt.op_type(); - std::vector actual_input_argnames; - for (auto& arg_node : op_node->inlinks) { - CHECK(arg_node->IsArg()); - const auto& arg = arg_node->AsArg(); - if (!arg_node->AsArg().is_weight) { - LOG(INFO) << "Input arg name: " << arg.name; - npu_inputs_map.insert(std::make_pair( - arg.name, lite::npu::bridge::CvtNode(arg_node, scope))); - // TODO(TJ): Here matched inputs should also be input vector - matched["Input"] = arg_node; - name_id++; - } + for (auto& node : nodes2cvt) { + lite::npu::bridge::node_map_type node_inputs; + auto& stmt = node->AsStmt(); + for (auto& var_node : node->inlinks) { + auto& arg = var_node->AsArg(); + auto var_name = arg.name; + if (!cvted_vars->count(var_name)) { + if (arg.is_weight) continue; + cvted_vars->insert(std::make_pair( + var_name, + lite::npu::bridge::CvtNode(var_node, stmt.op()->scope()))); + in_vars_name->push_back(var_name); } - CHECK_EQ(name_id, 1) << "mobilenetv1 only have one input data!"; - auto npu_outputs_map = CvtOpNodes( - cvtfunc_map, op_node, npu_inputs_map, id, &nodes2rm, &matched); - if (!npu_outputs_map.empty()) { - LOG(INFO) << "[NPU] subgraph " << id << ": output not empty "; - std::vector inputs; - std::vector outputs; - for (auto& i : npu_inputs_map) { - LOG(INFO) << "input data argname:" << i.first - << ", ptr: " << i.second; - inputs.emplace_back(*(i.second)); - } - for (auto& i : npu_outputs_map) { - LOG(INFO) << "output data argname:" << i.first - << ", ptr: " << i.second; - outputs.emplace_back(*(i.second)); - } - - std::string model_name("hiai_npu_client_" + std::to_string(id) + ".om"); - if (!npu::BuildNPUClient(inputs, outputs, model_name)) { - // build failed, so this subgraph is abandoned - nodes2rm.clear(); - LOG(WARNING) << "Build NPU failed subgraph " << id; + node_inputs.insert(*cvted_vars->find(var_name)); + } + auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs); + cvted_vars->insert(node_outputs.begin(), node_outputs.end()); + nodes2rm->insert(node); + for (auto& var_node : node->outlinks) { + for (auto& next_op_node : var_node->outlinks) { + if (std::find(nodes2cvt.begin(), nodes2cvt.end(), next_op_node) == + nodes2cvt.end()) { + out_vars_name->push_back(var_node->AsArg().name); break; } - LOG(INFO) << "[NPU] Build NPU Client success subgraph " << id; + } + } + } +} + +void GenerateNPUProgramPass::GenNPUGraphOpNode( + const std::unique_ptr& graph, + int sub_id, + const std::unordered_set& nodes_all) { + std::unordered_set visited_nodes; + std::vector ret; + for (auto& node : nodes_all) { + if (!node->IsStmt()) continue; + if (visited_nodes.count(node)) continue; + NPUSortHelper(node, nodes_all, &visited_nodes, &ret); + } - // Then InsertNewNode(graph, matched); make one function - cpp::OpDesc op_desc; - op_desc.SetType("graph_op"); - // change to vectors - op_desc.SetInput("Inputs", {matched.at("Input")->arg()->name}); - op_desc.SetOutput("Outputs", {matched.at("Output")->arg()->name}); - op_desc.SetAttr("model_name", model_name); - auto graph_op = LiteOpRegistry::Global().Create("graph_op"); - auto target_op = matched.at("target_op")->stmt()->op(); - auto* scope = target_op->scope(); - CHECK(scope); - CHECK(graph_op); - graph_op->Attach(op_desc, scope); + std::vector in_vars_name; + std::vector out_vars_name; + lite::npu::bridge::node_map_type cvted_vars; + std::unordered_set nodes2rm; + CvtOpNodes(ret, &in_vars_name, &out_vars_name, &cvted_vars, &nodes2rm); + // insert new graph op node + std::vector inputs; + std::vector outputs; + for (auto i : in_vars_name) { + inputs.push_back(*cvted_vars.at(i)); + } + for (auto i : out_vars_name) { + outputs.push_back(*cvted_vars.at(i)); + } + std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om"); + if (!npu::BuildNPUClient(inputs, outputs, model_name)) { + LOG(FATAL) << "Build NPU failed subgraph " << sub_id; + } + LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id; + + cpp::OpDesc op_desc; + op_desc.SetType("graph_op"); + op_desc.SetInput("Inputs", in_vars_name); + op_desc.SetOutput("Outputs", out_vars_name); + op_desc.SetAttr("model_name", model_name); + auto graph_op = LiteOpRegistry::Global().Create("graph_op"); + // TODO(zpy): support multi inputs op + auto start_op = ret.front()->AsStmt().op(); + auto* scope = start_op->scope(); + graph_op->Attach(op_desc, scope); + + auto valid_places = start_op->valid_places(); + auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); + + for (auto& var_node : ret.front()->inlinks) { + auto& arg = var_node->AsArg(); + if (arg.is_weight) continue; + IR_NODE_LINK_TO(var_node, new_op_node); + } + for (auto& var_node : ret.back()->outlinks) { + auto& arg = var_node->AsArg(); + if (arg.is_weight) continue; + IR_NODE_LINK_TO(var_node, new_op_node); + } - auto valid_places = - target_op->valid_places(); // TODO(TJ): add npu place? - auto* new_op_node = - graph->GraphCreateInstructNode(graph_op, valid_places); + // assign context + auto& inst = new_op_node->AsStmt(); + inst.picked_kernel().SetContext( + ContextScheduler::Global().NewContext(inst.picked_kernel().target())); - IR_NODE_LINK_TO(matched.at("Input"), new_op_node); - IR_NODE_LINK_TO(new_op_node, matched.at("Output")); + GraphSafeRemoveNodes(graph.get(), nodes2rm); +} - // assign context - auto& inst = new_op_node->AsStmt(); - inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext( - inst.picked_kernel().target())); +void GenerateNPUProgramPass::ConvertSubgraph( + const std::unique_ptr& graph, int sub_num) { + std::unordered_map> nodes_all; + for (auto& item : graph->StmtTopologicalOrder()) { + if (!item->IsStmt()) continue; + auto& stmt = item->AsStmt(); + int sub_id = stmt.subgraph_id(); + if (sub_id < 1) continue; + if (nodes_all.count(sub_id) == 0) { + nodes_all[sub_id] = std::unordered_set(); + } + nodes_all.at(sub_id).insert(item); + } - if (!nodes2rm.empty()) { - nodes2rm_all.insert(nodes2rm.begin(), nodes2rm.end()); - } - break; - } // if npu output success - } // for op_nodes - } // for subgraph id - // remove all unused node once - GraphSafeRemoveNodes(graph.get(), nodes2rm_all); - // clear all npu ops - npu::OpList::Global().clear(); + for (int id = 1; id <= sub_num; ++id) { + LOG(INFO) << "Converting subgraph_id:" << id; + GenNPUGraphOpNode(graph, id, nodes_all.at(id)); + } } void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { @@ -228,8 +190,6 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { InferOnce(graph); ConvertSubgraph(graph, num_subgraph); - // auto graph1 = GenerateFusedGraph(std::move(graph)); - // GraphSafeRemoveNodes(graph, nodes2rm); LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get()); for (auto& item : graph->StmtTopologicalOrder()) { diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h index 908190e4e9..0ce60fb22b 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ b/lite/core/mir/subgraph/generate_npu_program_pass.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include "lite/core/mir/pass.h" @@ -37,23 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { std::unique_ptr GenProgram(); protected: - // TODO(TJ): maybe change a name - // convert all fused subgraphs to npu clients - // 1. if some subgraph failed, then skip. - // 2. add new graph nodes, kernels and context - // 3. remove unused nodes - void ConvertSubgraph(const std::unique_ptr& graph, int sub_num); + void NPUSortHelper(Node* node, + const std::unordered_set& nodes_all, + std::unordered_set* visited_nodes, + std::vector* ret); + + // nodes2cvt: op nodes to convert + // in_vars_name: graph op's inputs var name + // out_vars_name: graph op's outputs var name + // vcted_vars: + // nodes2rm: op nodes and var nodes that need to be removed + void CvtOpNodes(const std::vector& nodes2cvt, + std::vector* in_vars_name, + std::vector* out_vars_name, + lite::npu::bridge::node_map_type* cvted_vars, + std::unordered_set* nodes2rm); - // call convert function from start node - // return if convert success and the nodes to remove - // return the output(arg.name, npu op) - lite::npu::bridge::node_map_type CvtOpNodes( - const lite::npu::bridge::cvt_map_type& cvtfunc_map, - const Node* op_node, - const lite::npu::bridge::node_map_type& inputs_map, - int sub_id, - std::unordered_set* nodes2rm, - key2nodes_t* matched); + void GenNPUGraphOpNode(const std::unique_ptr& graph, + int sub_id, + const std::unordered_set& nodes_all); + + void ConvertSubgraph(const std::unique_ptr& graph, int sub_num); private: std::vector insts_; diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc index 5816eefe18..91edadf895 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass.cc @@ -85,21 +85,31 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node, for (auto& i : node->outlinks) { if (!i->IsStmt()) return; auto& stmt = i->AsStmt(); - if (stmt.subgraph_id() != from_id) { + if (stmt.subgraph_id() < from_id) { all_out_op_supported = false; } } if (!all_out_op_supported) { return; } + nodes2rm_[to_id].insert(node); for (auto& i : node->outlinks) { CHECK(i->IsStmt()); auto& stmt = i->AsStmt(); - CHECK_EQ(stmt.subgraph_id(), from_id); - stmt.SetSubgraphID(to_id); - nodes2rm_[to_id].insert(i); - for (auto& o : i->outlinks) { - ChangeAllOutConnectedID(o, to_id, from_id); + if (stmt.subgraph_id() == from_id) { + stmt.SetSubgraphID(to_id); + nodes2rm_[to_id].insert(i); + for (auto& o : i->outlinks) { + for (auto& j : o->outlinks) { + if (j->IsStmt()) { + auto& Nstmt = j->AsStmt(); + if (Nstmt.subgraph_id() < from_id) { + o_nodes_[to_id].insert(o); + } + } + } + ChangeAllOutConnectedID(o, to_id, from_id); + } } } } @@ -109,12 +119,62 @@ int SubgraphProgramPass::FuseSubgraphID( const std::unique_ptr& graph) { int sub_id = 1; // id start from 1 not 0 for (auto& item : graph->StmtTopologicalOrder()) { + bool inputvar = 0; if (!item->IsStmt()) continue; auto& stmt = item->AsStmt(); + if (stmt.subgraph_id() == -1) { + for (auto& i : item->outlinks) { + for (auto& j : i->outlinks) { + if (j->IsStmt()) { + auto& jstmt = j->AsStmt(); + // LOG(INFO) << "initial: "<outlinks) i_nodes_[sub_id].insert(i); + } + } if (stmt.subgraph_id() != 0) continue; ChangeAllOutConnectedID(item, sub_id); sub_id++; } + for (auto& i : nodes2rm_) { + for (auto& item : i.second) { + if (item->IsStmt()) { + auto& stmt = item->AsStmt(); + LOG(INFO) << "nodes2rm_:" << stmt.op_type(); + } else if (item->IsArg()) { + auto& arg = item->AsArg(); + LOG(INFO) << "nodes2rm_:" << arg.name; + } + } + } + for (auto& i : i_nodes_) { + for (auto& item : i.second) { + if (item->IsStmt()) { + auto& stmt = item->AsStmt(); + LOG(INFO) << "i_nodes_: " << i.first << " " << stmt.op_type(); + } else if (item->IsArg()) { + auto& arg = item->AsArg(); + LOG(INFO) << "i_nodes_: " << i.first << " " << arg.name; + } + } + } + for (auto& i : o_nodes_) { + for (auto& item : i.second) { + if (item->IsStmt()) { + auto& stmt = item->AsStmt(); + LOG(INFO) << "o_nodes_:" << i.first << " " << stmt.op_type(); + } else if (item->IsArg()) { + auto& arg = item->AsArg(); + LOG(INFO) << "o_nodes_: " << i.first << " " << arg.name; + } + } + } return sub_id - 1; } @@ -129,7 +189,6 @@ int SubgraphProgramPass::FuseSubgraph( LOG(INFO) << "detected " << num_subgraph << " subgraph"; return num_subgraph; } - } // namespace subgraph } // namespace mir } // namespace lite diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h index 4348c3439f..e80f87333e 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.h +++ b/lite/core/mir/subgraph/subgraph_program_pass.h @@ -57,11 +57,15 @@ class SubgraphProgramPass : public ProgramPass { private: // {1: {nodes2rm_in_subgraph1, ...}, // 2: {nodes2rm_in_subgraph2, ...}} - std::unordered_map> nodes2rm_; + // delete nodes + std::unordered_map> nodes2rm_; + // std::unordered_map> nodes2rm_; // inputs nodes - std::unordered_map> i_nodes_; + std::unordered_map> i_nodes_; + // std::unordered_map> i_nodes_; // outputs nodes - std::unordered_map> o_nodes_; + std::unordered_map> o_nodes_; + // std::unordered_map> o_nodes_; }; } // namespace subgraph diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_program_pass_test.cc index f9b8dd38b1..2acf0f13aa 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass_test.cc @@ -29,7 +29,7 @@ DEFINE_string(model_dir, "", "model_dir"); namespace paddle { namespace lite { -TEST(SubgraphTest, mobilenetv1) { +TEST(SubgraphTest, mobilenetv2) { cpp::ProgramDesc program_desc; auto scope = std::make_shared(); LoadModelPb(FLAGS_model_dir, scope.get(), &program_desc); @@ -46,7 +46,8 @@ TEST(SubgraphTest, mobilenetv1) { auto graph = std::unique_ptr(new mir::SSAGraph()); graph->Build(program, valid_places); - std::vector supported_op_types{"conv2d", + std::vector supported_op_types{"concat", + "conv2d", "depthwise_conv2d", "batch_norm", "scale", @@ -54,9 +55,13 @@ TEST(SubgraphTest, mobilenetv1) { "mul", "elementwise_add", "softmax", - "relu"}; + "split", + "relu", + "reshape2", + "transpose2"}; auto* pass = new mir::subgraph::SubgraphProgramPass; ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1); + LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get()); } // return output_var_names @@ -99,6 +104,77 @@ std::vector AddFCDesc( return out_var_names; } +std::vector AddElementwiseAddDesc( + cpp::BlockDesc* block_desc, + const std::shared_ptr& scope, + const std::vector& input_X_names, + const std::vector& input_Y_names) { + // CHECK_EQ(input_var_names.size(), 2); + static int id = 0; + std::string prefix = "elementwise_add_" + std::to_string(id); + auto* op_desc = block_desc->AddOp(); + auto* out = block_desc->AddVar(); + + out->SetName(prefix + "_Out"); + std::vector out_var_names{prefix + "_Out"}; + + scope->Var(prefix + "_Out")->GetMutable(); + + op_desc->SetType("elementwise_add"); + op_desc->SetInput("X", input_X_names); + op_desc->SetInput("Y", input_Y_names); + op_desc->SetOutput("Out", out_var_names); + op_desc->SetAttr("axis", -1); + id++; + return out_var_names; +} + +std::vector AddFeedDesc( + cpp::BlockDesc* block_desc, + const std::shared_ptr& scope, + const std::vector& input_X_names) { + // CHECK_EQ(input_var_names.size(), 1); + static int id = 0; + std::string prefix = "feed_" + std::to_string(id); + auto* op_desc = block_desc->AddOp(); + auto* out = block_desc->AddVar(); + + out->SetName(prefix + "_Out"); + std::vector out_var_names{prefix + "_Out"}; + + scope->Var(prefix + "_Out")->GetMutable(); + + op_desc->SetType("feed"); + op_desc->SetInput("X", input_X_names); + op_desc->SetOutput("Out", out_var_names); + op_desc->SetAttr("col", 1); + id++; + return out_var_names; +} + +std::vector AddFetchDesc( + cpp::BlockDesc* block_desc, + const std::shared_ptr& scope, + const std::vector& input_X_names) { + // CHECK_EQ(input_var_names.size(), 1); + static int id = 0; + std::string prefix = "fetch_" + std::to_string(id); + auto* op_desc = block_desc->AddOp(); + auto* out = block_desc->AddVar(); + + out->SetName(prefix + "_Out"); + std::vector out_var_names{prefix + "_Out"}; + + scope->Var(prefix + "_Out")->GetMutable(); + + op_desc->SetType("fetch"); + op_desc->SetInput("X", input_X_names); + op_desc->SetOutput("Out", out_var_names); + op_desc->SetAttr("col", 1); + id++; + return out_var_names; +} + std::unique_ptr BuildSimpleNet( cpp::ProgramDesc* program_desc, const std::shared_ptr& scope, @@ -134,6 +210,7 @@ TEST(SubGraphTest, SimpleNet) { const int num_nodes = graph->nodes().size(); ASSERT_EQ(graph->nodes().size(), 9); + // LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get()); } } // namespace lite diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index a59737809a..99098102dc 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -7,6 +7,7 @@ message(STATUS "compile with lite ARM kernels") lite_cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(activation_compute_arm SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(matmul_compute_arm SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -59,6 +60,8 @@ lite_cc_library(is_empty_compute_arm SRCS is_empty_compute.cc DEPS ${lite_kernel lite_cc_library(shape_compute_arm SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(slice_compute_arm SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_library(cast_compute_arm SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(squeeze_compute_arm SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(expand_compute_arm SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) @@ -84,6 +87,7 @@ set(arm_kernels fc_compute_arm activation_compute_arm mul_compute_arm + matmul_compute_arm scale_compute_arm softmax_compute_arm conv_compute_arm @@ -136,6 +140,8 @@ set(arm_kernels shape_compute_arm slice_compute_arm cast_compute_arm + squeeze_compute_arm + expand_compute_arm ) set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/lite/kernels/arm/activation_compute.cc b/lite/kernels/arm/activation_compute.cc index 6a56633965..9e7379bcae 100644 --- a/lite/kernels/arm/activation_compute.cc +++ b/lite/kernels/arm/activation_compute.cc @@ -127,6 +127,16 @@ void LogCompute::Run() { x_data, output_data, x_dims.production(), ctx.threads()); } +void ExpCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_exp( + x_data, output_data, x_dims.production(), ctx.threads()); +} + } // namespace arm } // namespace kernels } // namespace lite @@ -185,7 +195,7 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); REGISTER_LITE_KERNEL( - relu6, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ReluCompute, def) + relu6, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::Relu6Compute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); @@ -194,3 +204,8 @@ REGISTER_LITE_KERNEL( .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +REGISTER_LITE_KERNEL( + exp, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ExpCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/activation_compute.h b/lite/kernels/arm/activation_compute.h index 9360528b81..eac169d97f 100644 --- a/lite/kernels/arm/activation_compute.h +++ b/lite/kernels/arm/activation_compute.h @@ -102,6 +102,16 @@ class LogCompute : public KernelLite { virtual ~LogCompute() = default; }; + +class ExpCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~ExpCompute() = default; +}; + } // namespace arm } // namespace kernels } // namespace lite diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc index b68b346270..2e2d3906bb 100644 --- a/lite/kernels/arm/conv_transpose_compute.cc +++ b/lite/kernels/arm/conv_transpose_compute.cc @@ -157,8 +157,8 @@ REGISTER_LITE_KERNEL(conv2d_transpose, kNCHW, paddle::lite::kernels::arm::Conv2DTransposeCompute, def) - .BindInput("x", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("bias", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("filter", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("output", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/expand_compute.cc b/lite/kernels/arm/expand_compute.cc new file mode 100644 index 0000000000..73bcae909e --- /dev/null +++ b/lite/kernels/arm/expand_compute.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/expand_compute.h" +#include +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ExpandCompute::Run() { + auto& param = Param(); + const auto* x = param.X; + auto* out = param.Out; + std::vector expand_times = param.expand_times; + + const float* src = x->data(); + float* dst = out->mutable_data(); + + int dims = expand_times.size(); + DDim in_shape = x->dims(); + + int inner_num = 1; + int i = dims - 1; + int outer_num = in_shape.count(0, i); + inner_num *= in_shape[i]; + for (int j = 0; j < outer_num; ++j) { + for (int k = 0; k < expand_times[i]; ++k) { + memcpy(dst + (j * expand_times[i] + k) * inner_num, + src + j * inner_num, + sizeof(float) * inner_num); + } + } + inner_num *= expand_times[i]; + for (int i = dims - 2; i >= 0; --i) { + int outer_num = in_shape.count(0, i); + inner_num *= in_shape[i]; + for (int j = outer_num - 1; j >= 0; --j) { + for (int k = expand_times[i] - 1; k >= 0; --k) { + memcpy(dst + (j * expand_times[i] + k) * inner_num, + dst + j * inner_num, + sizeof(float) * inner_num); + } + } + inner_num *= expand_times[i]; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + expand, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ExpandCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/expand_compute.h b/lite/kernels/arm/expand_compute.h new file mode 100644 index 0000000000..d872c2a60b --- /dev/null +++ b/lite/kernels/arm/expand_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ExpandCompute : public KernelLite { + public: + void Run() override; + + virtual ~ExpandCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc new file mode 100644 index 0000000000..ba34228b48 --- /dev/null +++ b/lite/kernels/arm/matmul_compute.cc @@ -0,0 +1,316 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/matmul_compute.h" +#include +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void MatMulCompute::PrepareForRun() { + auto& ctx = this->ctx_->template As(); +} + +void MatMulCompute::Run() { + auto& param = Param(); + + const auto* x_data = param.X->data(); + const auto* y_data = param.Y->data(); + auto* o_data = param.Out->mutable_data(); + + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + bool x_transpose = param.transpose_X; + bool y_transpose = param.transpose_Y; + float alpha = param.alpha; + auto& ctx = this->ctx_->template As(); + + if (x_dims.size() > 2 && y_dims.size() >= 2) { + // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [B, M, K], y: [K, N], out: [B, M, N] + if (x_transpose || y_transpose) { + LOG(FATAL) << "not supported transpose for x or y."; + } + CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + + if (y_dims.size() > 2) { + m_ = x_dims[x_dims.size() - 2]; + k_ = y_dims[y_dims.size() - 2]; + n_ = y_dims[y_dims.size() - 1]; + int hblock = lite::arm::math::get_hblock(ctx.arch()); + int m_round = 0; + m_round = hblock * ((m_ + hblock - 1) / hblock); + ctx.ExtendWorkspace(m_round * k_ * sizeof(float)); + int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1]; + int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1]; + int out_inner = x_dims[x_dims.size() - 2] * y_dims[y_dims.size() - 1]; + if (n_ == 1) { + for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { + lite::arm::math::sgemv(x_data + i * x_inner, + y_data + i * y_inner, + o_data + i * out_inner, + false, + m_, + k_, + false, + nullptr, + false); + } + if (fabsf(alpha - 1.f) > 1e-8f) { + for (size_t i = 0; i < param.Out->dims().production(); ++i) { + o_data[i] *= alpha; + } + } + } else { + for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { + float* packed_x = static_cast(ctx.workspace_data()) + + ctx.llc_size() / sizeof(float); + lite::arm::math::prepackA(packed_x, + x_data + i * x_inner, + alpha, + k_, + 0, + m_, + 0, + k_, + false, + &ctx); + int ldb = n_; + if (y_transpose) { + ldb = k_; + } + lite::arm::math::sgemm_prepack(y_transpose, + m_, + n_, + k_, + packed_x, + y_data + i * y_inner, + ldb, + 0.f, + o_data + i * out_inner, + n_, + nullptr, + false, + false, + &ctx); + } + } + } else { + m_ = x_dims[x_dims.size() - 2]; + k_ = y_dims[0]; + n_ = y_dims[1]; + int hblock = lite::arm::math::get_hblock(ctx.arch()); + int m_round = 0; + m_round = hblock * ((m_ + hblock - 1) / hblock); + ctx.ExtendWorkspace(m_round * k_ * sizeof(float)); + int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1]; + int out_inner = x_dims[x_dims.size() - 2] * y_dims[1]; + if (n_ == 1) { + for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { + lite::arm::math::sgemv(x_data + i * x_inner, + y_data, + o_data + i * out_inner, + false, + m_, + k_, + false, + nullptr, + false); + } + if (fabsf(param.alpha - 1.f) > 1e-8f) { + for (size_t i = 0; i < param.Out->dims().production(); ++i) { + o_data[i] *= param.alpha; + } + } + } else { + for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { + float* packed_x = static_cast(ctx.workspace_data()) + + ctx.llc_size() / sizeof(float); + lite::arm::math::prepackA(packed_x, + x_data + i * x_inner, + alpha, + k_, + 0, + m_, + 0, + k_, + false, + &ctx); + int ldb = n_; + if (y_transpose) { + ldb = k_; + } + lite::arm::math::sgemm_prepack(y_transpose, + m_, + n_, + k_, + packed_x, + y_data, + ldb, + 0.f, + o_data + i * out_inner, + n_, + nullptr, + false, + false, + &ctx); + } + } + } + } else if (x_dims.size() == 2 && y_dims.size() == 2) { + // x: [M, K], y: [K, N], out: [M, N] + if (!x_transpose && !y_transpose) { + CHECK_EQ(x_dims[1], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else if (!x_transpose && y_transpose) { + CHECK_EQ(x_dims[1], y_dims[1]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else if (x_transpose && !y_transpose) { + CHECK_EQ(x_dims[0], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else { + CHECK_EQ(x_dims[0], y_dims[1]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } + // not supported transpose + if (x_transpose || y_transpose) { + LOG(FATAL) << "not supported transpose for x and y."; + } + m_ = x_dims[0]; + k_ = x_dims[1]; + n_ = y_dims[1]; + int hblock = lite::arm::math::get_hblock(ctx.arch()); + int m_round = 0; + m_round = hblock * ((m_ + hblock - 1) / hblock); + ctx.ExtendWorkspace(m_round * k_ * sizeof(float)); + + if (n_ == 1) { + lite::arm::math::sgemv( + x_data, y_data, o_data, x_transpose, m_, k_, false, nullptr, false); + if (fabsf(param.alpha - 1.f) > 1e-8f) { + for (size_t i = 0; i < param.Out->dims().production(); ++i) { + o_data[i] *= param.alpha; + } + } + } else { + float* packed_x = static_cast(ctx.workspace_data()) + + ctx.llc_size() / sizeof(float); + lite::arm::math::prepackA( + packed_x, x_data, alpha, k_, 0, m_, 0, k_, x_transpose, &ctx); + int ldb = n_; + if (y_transpose) { + ldb = k_; + } + lite::arm::math::sgemm_prepack(y_transpose, + m_, + n_, + k_, + packed_x, + y_data, + ldb, + 0.f, + o_data, + n_, + nullptr, + false, + false, + &ctx); + } + } else if (x_dims.size() > 2 && y_dims.size() == 1) { + // x: [B, M, K], y: [K], out: [B, M] + CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 1); ++i) { + o_data[i] = 0; + for (size_t j = 0; j < y_dims[0]; ++j) { + o_data[i] += x_data[i * y_dims[0] + j] * y_data[j] * alpha; + } + } + } else if (x_dims.size() == 1 && y_dims.size() == 1) { + // x: [K], y: [K], out: [1] + if (x_dims[0] == y_dims[0] && x_transpose == false && + y_transpose == false) { + o_data[0] = 0.; + for (size_t i = 0; i < x_dims[0]; ++i) { + o_data[0] += x_data[i] * y_data[i] * alpha; + } + } + // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N] + if (x_transpose == true && y_transpose == true) { + m_ = x_dims[0]; + k_ = 1; + n_ = y_dims[0]; + if (n_ == 1) { + lite::arm::math::sgemv( + x_data, y_data, o_data, false, m_, k_, false, nullptr, false); + if (fabsf(alpha - 1.f) > 1e-8f) { + for (size_t i = 0; i < param.Out->dims().production(); ++i) { + o_data[i] *= alpha; + } + } + } else { + float* packed_x = static_cast(ctx.workspace_data()) + + ctx.llc_size() / sizeof(float); + lite::arm::math::prepackA( + packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx); + int ldb = n_; + lite::arm::math::sgemm_prepack(false, + m_, + n_, + k_, + packed_x, + y_data, + ldb, + 0.f, + o_data, + n_, + nullptr, + false, + false, + &ctx); + } + } + } else { + LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + matmul, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MatMulCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/matmul_compute.h b/lite/kernels/arm/matmul_compute.h new file mode 100644 index 0000000000..7050a05fcf --- /dev/null +++ b/lite/kernels/arm/matmul_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class MatMulCompute : public KernelLite { + public: + using param_t = operators::MatMulParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~MatMulCompute() = default; + + private: + int m_, n_, k_; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/multiclass_nms_compute.cc b/lite/kernels/arm/multiclass_nms_compute.cc index 41673ac22c..c36a81d152 100644 --- a/lite/kernels/arm/multiclass_nms_compute.cc +++ b/lite/kernels/arm/multiclass_nms_compute.cc @@ -78,12 +78,19 @@ void MulticlassNmsCompute::Run() { } } lod_info.push_back(num); - (*lod).push_back(lod_info); - param.out->Resize({static_cast(result_corrected.size() / 6), 6}); - float* out = param.out->mutable_data(); - std::memcpy( - out, result_corrected.data(), sizeof(float) * result_corrected.size()); + + if (result_corrected.empty()) { + (*lod).clear(); + (*lod).push_back(std::vector({0, 1})); + param.out->Resize({static_cast(1)}); + param.out->mutable_data()[0] = -1.; + } else { + param.out->Resize({static_cast(result_corrected.size() / 6), 6}); + float* out = param.out->mutable_data(); + std::memcpy( + out, result_corrected.data(), sizeof(float) * result_corrected.size()); + } } } // namespace arm diff --git a/lite/kernels/arm/multiclass_nms_compute_test.cc b/lite/kernels/arm/multiclass_nms_compute_test.cc index 8d7b1a1850..b0352f77c5 100644 --- a/lite/kernels/arm/multiclass_nms_compute_test.cc +++ b/lite/kernels/arm/multiclass_nms_compute_test.cc @@ -235,6 +235,8 @@ void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param, if (num_kept == 0) { (*result).clear(); + (*result).resize(1); + (*result)[0] = -1; return; } else { (*result).resize(num_kept * 6); diff --git a/lite/kernels/arm/squeeze_compute.cc b/lite/kernels/arm/squeeze_compute.cc new file mode 100644 index 0000000000..0f79d5c385 --- /dev/null +++ b/lite/kernels/arm/squeeze_compute.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/squeeze_compute.h" +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void SqueezeCompute::Run() { + auto& param = Param(); + auto x = param.X; + auto output = param.Out; + auto x_dims = x->dims(); + auto* x_data = x->data(); + auto* out_data = output->mutable_data(); + memcpy(out_data, x_data, x_dims.production() * sizeof(float)); +} + +void Squeeze2Compute::Run() { + auto& param = Param(); + auto x = param.X; + auto output = param.Out; + auto xshape = param.XShape; + auto x_dims = x->dims(); + auto* x_data = x->data(); + auto* out_data = output->mutable_data(); + auto* xshape_data = xshape->mutable_data(); + memcpy(out_data, x_data, x_dims.production() * sizeof(float)); + memcpy(xshape_data, x_data, x_dims.production() * sizeof(float)); +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(squeeze, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::host::SqueezeCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +REGISTER_LITE_KERNEL(squeeze2, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::host::Squeeze2Compute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/squeeze_compute.h b/lite/kernels/arm/squeeze_compute.h new file mode 100644 index 0000000000..c9e4c2a17c --- /dev/null +++ b/lite/kernels/arm/squeeze_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class SqueezeCompute : public KernelLite { + public: + void Run() override; + + virtual ~SqueezeCompute() = default; +}; + +class Squeeze2Compute : public KernelLite { + public: + void Run() override; + + virtual ~Squeeze2Compute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/npu/graph_compute.cc b/lite/kernels/npu/graph_compute.cc index 9e3c028697..9f0f557f5c 100644 --- a/lite/kernels/npu/graph_compute.cc +++ b/lite/kernels/npu/graph_compute.cc @@ -40,6 +40,10 @@ void GraphCompute::PrepareForRun() { npu_otensors_.resize(npu_odims_.size()); for (size_t i = 0; i < npu_idims_.size(); ++i) { + VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << "," + << npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight() + << "," << npu_idims_[i].GetWidth(); + VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i]->dims(); CHECK_EQ(param.inputs[i]->dims().production(), npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() * npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth()); @@ -48,6 +52,10 @@ void GraphCompute::PrepareForRun() { } for (size_t i = 0; i < npu_odims_.size(); ++i) { + VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << "," + << npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight() + << "," << npu_odims_[i].GetWidth(); + VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i]->dims(); auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() * npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth(); if (param.outputs[i]->dims().production() != out_size) { diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc index e843aafdb6..eff6e5ffa6 100644 --- a/lite/model_parser/model_parser.cc +++ b/lite/model_parser/model_parser.cc @@ -16,10 +16,12 @@ #include #include #include +#include #include "lite/core/scope.h" #include "lite/core/tensor.h" #include "lite/core/variable.h" #include "lite/model_parser/desc_apis.h" +#include "lite/model_parser/naive_buffer/combined_params_desc.h" #include "lite/model_parser/naive_buffer/param_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/var_desc.h" @@ -316,18 +318,19 @@ void SerializeTensor(std::ostream &os, } /// For navie buffer -void SaveParamNaive(const std::string &path, - const lite::Scope &scope, - const std::string &var_name) { +void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc, + const lite::Scope &scope, + const std::string &var_name) { + CHECK(param_desc); + auto &desc = *param_desc; + // the 1st field, uint32_t version constexpr uint32_t version = 0; auto *var = scope.FindVar(var_name); const auto &tensor = var->Get(); - naive_buffer::BinaryTable table; - naive_buffer::proto::ParamDesc pt_desc(&table); - naive_buffer::ParamDesc desc(&pt_desc); + desc.SetName(var_name); desc.SetModelVersion(version); desc.SetTensorVersion(version); @@ -355,18 +358,50 @@ void SaveParamNaive(const std::string &path, { desc.SetData(tensor.data(), tensor.data_size()); } +} + +void SaveParamNaive(const std::string &path, + const lite::Scope &scope, + const std::string &var_name) { + naive_buffer::BinaryTable table; + naive_buffer::proto::ParamDesc pt_desc(&table); + naive_buffer::ParamDesc desc(&pt_desc); + + SetParamInfoNaive(&desc, scope, var_name); // Save param pt_desc.Save(); table.SaveToFile(path); } +void SaveCombinedParamsNaive(const std::string &path, + const lite::Scope &exec_scope, + const cpp::ProgramDesc &cpp_prog) { + naive_buffer::BinaryTable table; + naive_buffer::proto::CombinedParamsDesc pt_desc(&table); + naive_buffer::CombinedParamsDesc desc(&pt_desc); + + auto prog = cpp_prog; + auto &main_block_desc = *prog.GetBlock(0); + for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { + auto &var = *main_block_desc.GetVar(i); + if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) + continue; + naive_buffer::ParamDesc param_desc(desc.AddParam()); + SetParamInfoNaive(¶m_desc, exec_scope, var.Name()); + } + + pt_desc.Save(); + table.SaveToFile(path); +} + void SaveModelNaive(const std::string &model_dir, const Scope &exec_scope, - const cpp::ProgramDesc &cpp_prog) { + const cpp::ProgramDesc &cpp_prog, + bool combined) { MkDirRecur(model_dir); // Save program - const std::string prog_path = model_dir + "/__model__"; + const std::string prog_path = model_dir + "/__model__.nb"; naive_buffer::BinaryTable table; naive_buffer::proto::ProgramDesc nb_proto_prog(&table); naive_buffer::ProgramDesc nb_prog(&nb_proto_prog); @@ -376,14 +411,19 @@ void SaveModelNaive(const std::string &model_dir, // Save Params // NOTE: Only main block be used now. - auto prog = cpp_prog; - auto &main_block_desc = *prog.GetBlock(0); - for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { - auto &var = *main_block_desc.GetVar(i); - if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) - continue; - const std::string path = model_dir + "/" + var.Name(); - SaveParamNaive(path, exec_scope, var.Name()); + if (combined) { + const std::string combined_params_path = model_dir + "/param.nb"; + SaveCombinedParamsNaive(combined_params_path, exec_scope, cpp_prog); + } else { + auto prog = cpp_prog; + auto &main_block_desc = *prog.GetBlock(0); + for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { + auto &var = *main_block_desc.GetVar(i); + if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) + continue; + const std::string path = model_dir + "/" + var.Name() + ".nb"; + SaveParamNaive(path, exec_scope, var.Name()); + } } VLOG(4) << "Save naive buffer model in '" << model_dir << "'' successfully"; } @@ -398,18 +438,15 @@ void SetTensorDataNaive(T *out, size_t size, const std::vector &src) { } } -void LoadParamNaive(const std::string &path, - lite::Scope *scope, - const std::string &name) { +void GetParamInfoNaive(const naive_buffer::ParamDesc &desc, + lite::Scope *scope, + const std::string &name) { CHECK(scope); - auto *tensor = scope->Var(name)->GetMutable(); + CHECK_EQ(desc.Name(), name) + << "Var name not equal: ParamDesc.name=" << desc.Name() + << "vs filename=" << name; - // Load param - naive_buffer::BinaryTable table; - table.LoadFromFile(path); - naive_buffer::proto::ParamDesc pt_desc(&table); - pt_desc.Load(); - naive_buffer::ParamDesc desc(&pt_desc); + auto *tensor = scope->Var(name)->GetMutable(); VLOG(3) << "model version " << desc.ModelVersion(); CHECK_EQ(desc.TensorVersion(), 0U) << "Only version 0 is supported"; @@ -442,15 +479,56 @@ void LoadParamNaive(const std::string &path, } } +void LoadParamNaive(const std::string &path, + lite::Scope *scope, + const std::string &name) { + // Load param + naive_buffer::BinaryTable table; + table.LoadFromFile(path); + naive_buffer::proto::ParamDesc pt_desc(&table); + pt_desc.Load(); + naive_buffer::ParamDesc desc(&pt_desc); + GetParamInfoNaive(desc, scope, name); +} + +void LoadCombinedParamsNaive(const std::string &path, + lite::Scope *scope, + const cpp::ProgramDesc &cpp_prog) { + naive_buffer::BinaryTable table; + table.LoadFromFile(path); + naive_buffer::proto::CombinedParamsDesc pt_desc(&table); + pt_desc.Load(); + naive_buffer::CombinedParamsDesc desc(&pt_desc); + + std::set param_names; + for (size_t i = 0; i < desc.ParamsSize(); ++i) { + naive_buffer::ParamDesc param_desc(desc.GetParam(i)); + GetParamInfoNaive(param_desc, scope, param_desc.Name()); + param_names.insert(param_desc.Name()); + } + + // Check all params loaded + auto prog = cpp_prog; + auto &main_block_desc = *prog.GetBlock(0); + for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { + auto &var = *main_block_desc.GetVar(i); + if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) + continue; + CHECK(param_names.count(var.Name())) << "Persistable var[" << var.Name() + << "] not found"; + } +} + void LoadModelNaive(const std::string &model_dir, Scope *scope, - cpp::ProgramDesc *cpp_prog) { + cpp::ProgramDesc *cpp_prog, + bool combined) { CHECK(cpp_prog); CHECK(scope); cpp_prog->ClearBlocks(); // Load model - const std::string prog_path = model_dir + "/__model__"; + const std::string prog_path = model_dir + "/__model__.nb"; naive_buffer::BinaryTable table; table.LoadFromFile(prog_path); naive_buffer::proto::ProgramDesc nb_proto_prog(&table); @@ -462,26 +540,33 @@ void LoadModelNaive(const std::string &model_dir, // Load Params // NOTE: Only main block be used now. - auto &prog = *cpp_prog; - auto &main_block_desc = *prog.GetBlock(0); - for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { - auto &var = *main_block_desc.GetVar(i); - if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) - continue; - - std::string file_path = model_dir + "/" + var.Name(); - VLOG(4) << "reading weight " << var.Name(); - - switch (var.GetType()) { - case VarDescAPI::Type::LOD_TENSOR: - LoadParamNaive(file_path, scope, var.Name()); - break; - default: - CHECK(false) << "unknown weight type"; + if (combined) { + const std::string combined_params_path = model_dir + "/param.nb"; + LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog); + } else { + auto &prog = *cpp_prog; + auto &main_block_desc = *prog.GetBlock(0); + for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { + auto &var = *main_block_desc.GetVar(i); + if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable()) + continue; + + std::string file_path = model_dir + "/" + var.Name() + ".nb"; + VLOG(4) << "reading weight " << var.Name(); + + switch (var.GetType()) { + case VarDescAPI::Type::LOD_TENSOR: + LoadParamNaive(file_path, scope, var.Name()); + break; + default: + CHECK(false) << "unknown weight type"; + } } } #ifdef LITE_WITH_NPU + auto &prog = *cpp_prog; + auto &main_block_desc = *prog.GetBlock(0); for (size_t i = 0; i < main_block_desc.OpsSize(); ++i) { auto &op = *main_block_desc.GetOp(i); if (op.Type() != "graph_op") { diff --git a/lite/model_parser/model_parser.h b/lite/model_parser/model_parser.h index c78ab9bd04..199c36ac21 100644 --- a/lite/model_parser/model_parser.h +++ b/lite/model_parser/model_parser.h @@ -66,18 +66,28 @@ void SaveParamNaive(const std::string& path, const lite::Scope& exec_scope, const std::string& var_name); +void SaveCombinedParamsNaive(const std::string& path, + const lite::Scope& exec_scope, + const cpp::ProgramDesc& cpp_prog); + void SaveModelNaive(const std::string& model_dir, const Scope& exec_scope, - const cpp::ProgramDesc& cpp_prog); + const cpp::ProgramDesc& cpp_prog, + bool combined = true); #endif void LoadParamNaive(const std::string& path, lite::Scope* scope, const std::string& name); +void LoadCombinedParamsNaive(const std::string& path, + lite::Scope* scope, + const cpp::ProgramDesc& cpp_prog); + void LoadModelNaive(const std::string& model_dir, lite::Scope* scope, - cpp::ProgramDesc* prog); + cpp::ProgramDesc* prog, + bool combined = true); } // namespace lite } // namespace paddle diff --git a/lite/model_parser/naive_buffer/CMakeLists.txt b/lite/model_parser/naive_buffer/CMakeLists.txt index 92a0227d62..f85482e5d6 100644 --- a/lite/model_parser/naive_buffer/CMakeLists.txt +++ b/lite/model_parser/naive_buffer/CMakeLists.txt @@ -5,14 +5,15 @@ add_subdirectory(proto) lite_cc_library(nb_op_desc SRCS op_desc.cc DEPS framework_nb) lite_cc_library(nb_var_desc SRCS var_desc.cc DEPS framework_nb) lite_cc_library(nb_param_desc SRCS param_desc.cc DEPS framework_nb) +lite_cc_library(nb_combined_params_desc SRCS combined_params_desc.cc DEPS nb_param_desc framework_nb) lite_cc_library(nb_block_desc SRCS block_desc.cc DEPS framework_nb) lite_cc_library(nb_program_desc SRCS program_desc.cc DEPS framework_nb) set(naive_wrapper - nb_op_desc nb_var_desc nb_param_desc + nb_op_desc nb_var_desc nb_param_desc nb_combined_params_desc nb_block_desc nb_program_desc PARENT_SCOPE) lite_cc_test(test_naive_buffer SRCS naive_buffer_test.cc DEPS naive_buffer) lite_cc_test(test_naive_buffer_wrapper SRCS naive_buffer_wrapper_test.cc - DEPS nb_op_desc nb_var_desc nb_param_desc nb_block_desc - nb_program_desc) + DEPS nb_op_desc nb_var_desc nb_param_desc nb_combined_params_desc + nb_block_desc nb_program_desc) diff --git a/lite/model_parser/naive_buffer/combined_params_desc.cc b/lite/model_parser/naive_buffer/combined_params_desc.cc new file mode 100644 index 0000000000..72a556b852 --- /dev/null +++ b/lite/model_parser/naive_buffer/combined_params_desc.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/model_parser/naive_buffer/combined_params_desc.h" diff --git a/lite/model_parser/naive_buffer/combined_params_desc.h b/lite/model_parser/naive_buffer/combined_params_desc.h new file mode 100644 index 0000000000..a5462ef5ee --- /dev/null +++ b/lite/model_parser/naive_buffer/combined_params_desc.h @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/model_parser/desc_apis.h" +#include "lite/model_parser/naive_buffer/param_desc.h" +#include "lite/model_parser/naive_buffer/proto/framework.nb.h" + +namespace paddle { +namespace lite { +namespace naive_buffer { + +class CombinedParamsDesc { + public: + CombinedParamsDesc() = delete; + + explicit CombinedParamsDesc(proto::CombinedParamsDesc *desc) : desc_(desc) { + CHECK(desc_); + } + + void CopyFrom(CombinedParamsDesc &combined_params_desc) { // NOLINT + CHECK(combined_params_desc.Proto()) + << "Source proto::CombinedParamsDesc pointer can't be null"; + desc_ = combined_params_desc.Proto(); + } + + proto::CombinedParamsDesc *Proto() { return desc_; } + + const proto::CombinedParamsDesc &ReadonlyProto() const { return *desc_; } + + size_t ParamsSize() const { return desc_->size(); } + + void ClearParams() { desc_->Clear(); } + + proto::ParamDesc *GetParam(int32_t idx) { + CHECK_LT(idx, ParamsSize()) << "idx >= params.size()"; + return desc_->GetMutable(idx); + } + + proto::ParamDesc *AddParam() { return desc_->New(); } + + private: + proto::CombinedParamsDesc *desc_; +}; + +} // namespace naive_buffer +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc b/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc index e1a9810c88..45224de122 100644 --- a/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc +++ b/lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc @@ -14,6 +14,7 @@ #include #include "lite/model_parser/naive_buffer/block_desc.h" +#include "lite/model_parser/naive_buffer/combined_params_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/param_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h" @@ -97,6 +98,7 @@ TEST(NaiveBufferWrapper, ParamDesc) { ParamDesc nb_desc0(&pt_desc0); // Set ParamDesc + nb_desc0.SetName("fc_w.0"); nb_desc0.SetModelVersion(0); nb_desc0.SetTensorVersion(1); std::vector> lod({{1, 2, 3}, {4, 5}}); @@ -122,6 +124,7 @@ TEST(NaiveBufferWrapper, ParamDesc) { pt_desc1.Load(); ParamDesc nb_desc1(&pt_desc1); + ASSERT_EQ(nb_desc1.Name(), "fc_w.0"); ASSERT_EQ(nb_desc1.ModelVersion(), 0); ASSERT_EQ(nb_desc1.TensorVersion(), 1); ASSERT_EQ(nb_desc1.LoDLevel(), 2); @@ -134,6 +137,84 @@ TEST(NaiveBufferWrapper, ParamDesc) { } } +TEST(NaiveBufferWrapper, CombinedParamsDesc) { + BinaryTable table0; + proto::CombinedParamsDesc pt_desc0(&table0); + CombinedParamsDesc nb_desc0(&pt_desc0); + + // Set ParamDesc + ParamDesc param_desc0_0(nb_desc0.AddParam()); + param_desc0_0.SetName("fc_w.0"); + param_desc0_0.SetModelVersion(0); + param_desc0_0.SetTensorVersion(1); + std::vector> param_desc0_0_lod({{1, 2, 3}, {4, 5}}); + param_desc0_0.SetLoDLevel(2); + param_desc0_0.SetLoD(param_desc0_0_lod); + std::vector param_desc0_0_dim({1, 2, 5}); + param_desc0_0.SetDim(param_desc0_0_dim); + param_desc0_0.SetDataType(VarDescAPI::VarDataType::FP32); + std::vector param_desc0_0_data; + for (int i = 0; i < 10; ++i) { + param_desc0_0_data.push_back(i / 10.0); + } + param_desc0_0.SetData(param_desc0_0_data); + + ParamDesc param_desc0_1(nb_desc0.AddParam()); + param_desc0_1.SetName("fc_b.0"); + param_desc0_1.SetModelVersion(0); + param_desc0_1.SetTensorVersion(1); + std::vector> param_desc0_1_lod({{1}, {2, 3}, {4, 5}}); + param_desc0_1.SetLoDLevel(3); + param_desc0_1.SetLoD(param_desc0_1_lod); + std::vector param_desc0_1_dim({1, 2, 2, 5}); + param_desc0_1.SetDim(param_desc0_1_dim); + param_desc0_1.SetDataType(VarDescAPI::VarDataType::FP32); + std::vector param_desc0_1_data; + for (int i = 0; i < 20; ++i) { + param_desc0_1_data.push_back((i - 10) / 10.0); + } + param_desc0_1.SetData(param_desc0_1_data); + + // Save model + pt_desc0.Save(); + table0.SaveToFile("4.bf"); + + // Load model + BinaryTable table1; + table1.LoadFromFile("4.bf"); + proto::CombinedParamsDesc pt_desc1(&table1); + pt_desc1.Load(); + CombinedParamsDesc nb_desc1(&pt_desc1); + + ASSERT_EQ(nb_desc1.ParamsSize(), 2); + + ParamDesc param_desc1_0(nb_desc1.GetParam(0)); + ASSERT_EQ(param_desc1_0.Name(), "fc_w.0"); + ASSERT_EQ(param_desc1_0.ModelVersion(), 0); + ASSERT_EQ(param_desc1_0.TensorVersion(), 1); + ASSERT_EQ(param_desc1_0.LoDLevel(), 2); + ASSERT_EQ(param_desc1_0.LoD(), param_desc0_0_lod); + ASSERT_EQ(param_desc1_0.Dim(), param_desc0_0_dim); + auto param_desc1_0_data = param_desc1_0.Data(); + ASSERT_EQ(param_desc1_0_data.size(), param_desc0_0_data.size()); + for (size_t i = 0; i < param_desc1_0_data.size(); ++i) { + EXPECT_NEAR(param_desc1_0_data[i], param_desc0_0_data[i], 1e-6); + } + + ParamDesc param_desc1_1(nb_desc1.GetParam(1)); + ASSERT_EQ(param_desc1_1.Name(), "fc_b.0"); + ASSERT_EQ(param_desc1_1.ModelVersion(), 0); + ASSERT_EQ(param_desc1_1.TensorVersion(), 1); + ASSERT_EQ(param_desc1_1.LoDLevel(), 3); + ASSERT_EQ(param_desc1_1.LoD(), param_desc0_1_lod); + ASSERT_EQ(param_desc1_1.Dim(), param_desc0_1_dim); + auto param_desc1_1_data = param_desc1_1.Data(); + ASSERT_EQ(param_desc1_1_data.size(), param_desc0_1_data.size()); + for (size_t i = 0; i < param_desc1_1_data.size(); ++i) { + EXPECT_NEAR(param_desc1_1_data[i], param_desc0_1_data[i], 1e-6); + } +} + TEST(NaiveBufferWrapper, BlockDesc) { BinaryTable table0; proto::BlockDesc pt_desc0(&table0); @@ -161,11 +242,11 @@ TEST(NaiveBufferWrapper, BlockDesc) { // Save model pt_desc0.Save(); - table0.SaveToFile("4.bf"); + table0.SaveToFile("5.bf"); // Load model BinaryTable table1; - table1.LoadFromFile("4.bf"); + table1.LoadFromFile("5.bf"); proto::BlockDesc pt_desc1(&table1); pt_desc1.Load(); BlockDesc nb_desc1(&pt_desc1); @@ -217,11 +298,11 @@ TEST(NaiveBufferWrapper, ProgramDesc) { // Save model pt_desc0.Save(); - table0.SaveToFile("5.bf"); + table0.SaveToFile("6.bf"); // Load model BinaryTable table1; - table1.LoadFromFile("5.bf"); + table1.LoadFromFile("6.bf"); proto::ProgramDesc pt_desc1(&table1); pt_desc1.Load(); ProgramDesc nb_desc1(&pt_desc1); diff --git a/lite/model_parser/naive_buffer/param_desc.cc b/lite/model_parser/naive_buffer/param_desc.cc index 8b7e99a782..b8f2654277 100644 --- a/lite/model_parser/naive_buffer/param_desc.cc +++ b/lite/model_parser/naive_buffer/param_desc.cc @@ -21,6 +21,16 @@ namespace paddle { namespace lite { namespace naive_buffer { +std::string ParamDesc::Name() const { + return desc_->GetField("name").data(); +} + +void ParamDesc::SetName(const std::string& name) { + auto* build = desc_->GetMutableField("name"); + CHECK(build); + build->set(name); +} + uint32_t ParamDesc::ModelVersion() const { return Version("model_version"); } void ParamDesc::SetModelVersion(uint32_t version) { diff --git a/lite/model_parser/naive_buffer/param_desc.h b/lite/model_parser/naive_buffer/param_desc.h index 631ffb40be..0a20b15331 100644 --- a/lite/model_parser/naive_buffer/param_desc.h +++ b/lite/model_parser/naive_buffer/param_desc.h @@ -40,6 +40,10 @@ class ParamDesc { const proto::ParamDesc &ReadonlyProto() const { return *desc_; } + std::string Name() const; + + void SetName(const std::string &name); + uint32_t ModelVersion() const; void SetModelVersion(uint32_t version); diff --git a/lite/model_parser/naive_buffer/proto/framework.nb.h b/lite/model_parser/naive_buffer/proto/framework.nb.h index 4faa6da699..f495a12b46 100644 --- a/lite/model_parser/naive_buffer/proto/framework.nb.h +++ b/lite/model_parser/naive_buffer/proto/framework.nb.h @@ -185,6 +185,7 @@ class ParamDesc : public StructBuilder { public: using lod_type = ListBuilder>; explicit ParamDesc(BinaryTable* table) : StructBuilder(table) { + NewStr("name"); NewUInt32("model_version"); NewUInt64("lod_level"); New("lod"); @@ -194,6 +195,8 @@ class ParamDesc : public StructBuilder { } }; +using CombinedParamsDesc = ListBuilder; + } // namespace proto } // namespace naive_buffer } // namespace lite diff --git a/lite/npu/bridge/CMakeLists.txt b/lite/npu/bridge/CMakeLists.txt index 3a2270ac99..583ea0e2ee 100644 --- a/lite/npu/bridge/CMakeLists.txt +++ b/lite/npu/bridge/CMakeLists.txt @@ -13,7 +13,13 @@ lite_cc_library(npu_bridge_softmax_op SRCS softmax_op.cc DEPS ${npu_bridge_deps} lite_cc_library(npu_bridge_pool_op SRCS pool_op.cc DEPS ${npu_bridge_deps}) lite_cc_library(npu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${npu_bridge_deps}) lite_cc_library(npu_bridge_elementwise_op SRCS elementwise_ops.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_reshape_op SRCS reshape_op.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_conv_transpose_op SRCS conv_transpose_op.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_bilinear_interp_op SRCS bilinear_interp_op.cc DEPS ${npu_bridge_deps}) lite_cc_library(npu_bridge_transpose_op SRCS transpose_op.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_split_op SRCS split_op.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_concat_op SRCS concat_op.cc DEPS ${npu_bridge_deps}) +lite_cc_library(npu_bridge_shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${npu_bridge_deps}) set(npu_bridges npu_bridge_registry @@ -27,7 +33,13 @@ set(npu_bridges npu_bridge_pool_op npu_bridge_batch_norm_op npu_bridge_elementwise_op + npu_bridge_reshape_op + npu_bridge_conv_transpose_op + npu_bridge_bilinear_interp_op npu_bridge_transpose_op + npu_bridge_split_op + npu_bridge_concat_op + npu_bridge_shuffle_channel_op CACHE INTERNAL "npu_bridges") lite_cc_library(npu_test_helper SRCS test_helper.cc DEPS npu_helper ${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops}) @@ -41,6 +53,12 @@ lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc DEPS npu_test_he lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc DEPS npu_test_helper) lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc DEPS npu_test_helper) lite_cc_test(test_npu_bridge_elementwise_op SRCS elementwise_ops_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_bilinear_interp_op SRCS bilinear_interp_op_test.cc DEPS npu_test_helper) lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc DEPS npu_test_helper) +lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc DEPS npu_test_helper) message(STATUS "+++++ npu_bridges: ${npu_bridges}") diff --git a/lite/npu/bridge/act_op.cc b/lite/npu/bridge/act_op.cc index f2f398c4f1..34299844f2 100644 --- a/lite/npu/bridge/act_op.cc +++ b/lite/npu/bridge/act_op.cc @@ -29,14 +29,15 @@ namespace bridge { node_map_type ActConverter(const std::shared_ptr act_op, const node_map_type& inputs_map) { - VLOG(3) << "invoking ActConverter..."; auto scope = act_op->scope(); auto op_info = act_op->op_info(); auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; // create act node and set input node from inputs_map auto x_var_name = op_info->Input("X").front(); - auto act_node = std::make_shared(UniqueName(op_type)); + auto act_node = std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); act_node->set_input_x(*inputs_map.at(x_var_name)); OpList::Global().add(inputs_map.at(x_var_name)); diff --git a/lite/npu/bridge/bilinear_interp_op.cc b/lite/npu/bridge/bilinear_interp_op.cc new file mode 100644 index 0000000000..c7f3289af3 --- /dev/null +++ b/lite/npu/bridge/bilinear_interp_op.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +node_map_type BilinearInterpConverter( + const std::shared_ptr interp_op, + const node_map_type& inputs_map) { + auto scope = interp_op->scope(); + auto op_info = interp_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; + + // get input, output and attributes from lite op + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + auto x_h = x_dims[2]; + auto x_w = x_dims[3]; + CHECK_EQ(x_dims.size(), 4); + auto scale = op_info->GetAttr("scale"); + auto out_w = op_info->GetAttr("out_w"); + auto out_h = op_info->GetAttr("out_h"); + auto align_corners = op_info->GetAttr("align_corners"); + auto interp_method = op_info->GetAttr("interp_method"); + int align_mode = op_info->GetAttr("align_mode"); + CHECK(!(align_mode == 0 && !align_corners)) + << "align_mode = 0 && align_corners = false isn't supported in NPU DDK"; + + // priority: OutSize > scale > out_h/out_w + if (scale > 0) { + out_h = static_cast(x_h * scale); + out_w = static_cast(x_w * scale); + out_h = out_h > 0 ? out_h : -1; + out_w = out_w > 0 ? out_w : -1; + } + + // create interp node and set input node from inputs_map + auto interp_node = std::make_shared(unique_op_type); + CHECK(inputs_map.count(x_var_name)); + interp_node->set_input_x(*inputs_map.at(x_var_name)); + OpList::Global().add(inputs_map.at(x_var_name)); + OpList::Global().add(interp_node); + + // update out_h and out_w if has OutSize + bool is_dyn_out_size = false; + if (HasInputArg(op_info, scope, "OutSize")) { + auto out_size_var_name = op_info->Input("OutSize").front(); + if (!inputs_map.count(out_size_var_name)) { + auto out_size = + scope->FindVar(out_size_var_name)->GetMutable(); + auto out_size_dims = out_size->dims(); + CHECK_EQ(out_size_dims.size(), 1); + CHECK_EQ(out_size_dims.production(), 2); + auto out_size_data = out_size->mutable_data(); + // update out_h and out_w if has OutSize + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } else { + interp_node->set_input_w(*inputs_map.at(out_size_var_name)); + OpList::Global().add(inputs_map.at(out_size_var_name)); + is_dyn_out_size = true; // using dynamic output size + } + } + if (!is_dyn_out_size) { + CHECK_GT(out_h, 0); + CHECK_GT(out_w, 0); + const float largest_multiple = 7.0f; + float multiple = static_cast(x_h * x_w) / (out_h * out_w); + CHECK_LT(multiple, largest_multiple) + << "multiple=(ih*iw)/(oh*ow)=" << multiple + << " is too large, should not exceed " << largest_multiple + << " in NPU DDK"; + auto w_const_node = std::make_shared(unique_op_type + "/w"); + w_const_node->set_attr_value( + CreateTensorAndFillData(std::vector({out_h, out_w}))); + interp_node->set_input_w(*w_const_node); + OpList::Global().add(w_const_node); + } + + // set attributes + interp_node->set_attr_output_dim_mode( + 2); // 0: zoom_factor, 1: shrink_factor, 2: height/width + interp_node->set_attr_align_corners(align_corners); + + node_map_type outputs_map; + outputs_map[op_info->Output("Out").front()] = interp_node; + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(bilinear_interp, + paddle::lite::npu::bridge::BilinearInterpConverter); diff --git a/lite/npu/bridge/bilinear_interp_op_test.cc b/lite/npu/bridge/bilinear_interp_op_test.cc new file mode 100644 index 0000000000..402b939096 --- /dev/null +++ b/lite/npu/bridge/bilinear_interp_op_test.cc @@ -0,0 +1,314 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" +#include "lite/operators/interpolate_op.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +template +void bilinear_interp_ref(const std::shared_ptr op) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto x_dims = x->dims(); + int batch_size = x_dims[0]; + int channel_size = x_dims[1]; + auto x_h = x_dims[2]; + auto x_w = x_dims[3]; + CHECK_EQ(x_dims.size(), 4); + auto scale = op_info->GetAttr("scale"); + auto out_w = op_info->GetAttr("out_w"); + auto out_h = op_info->GetAttr("out_h"); + auto align_corners = op_info->GetAttr("align_corners"); + int align_mode = op_info->GetAttr("align_mode"); + auto interp_method = op_info->GetAttr("interp_method"); + + // calc real out_h and out_w + if (scale > 0) { + out_h = static_cast(x_h * scale); + out_w = static_cast(x_w * scale); + } + if (op_info->HasInput("OutSize")) { + auto out_size_var_names = op_info->Input("OutSize"); + if (out_size_var_names.size() > 0) { + auto out_size_var_name = out_size_var_names.front(); + auto out_size = + scope->FindVar(out_size_var_name)->GetMutable(); + auto out_size_dims = out_size->dims(); + CHECK_EQ(out_size_dims.size(), 1); + CHECK_EQ(out_size_dims.production(), 2); + auto out_size_data = out_size->mutable_data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + } + CHECK_GT(out_h, 0); + CHECK_GT(out_w, 0); + out->Resize({batch_size, channel_size, out_h, out_w}); + + // copy from x if no change + if (x_h == out_h && x_w == out_w) { + out->CopyDataFrom(*x); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(x_h - 1) / (out_h - 1) + : static_cast(x_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(x_w - 1) / (out_w - 1) + : static_cast(x_w) / out_w; + } + + // naive bilinear interpolation + auto x_data = x->mutable_data(); + auto out_data = out->mutable_data(); + bool align_flag = (align_mode == 0 && !align_corners); + + std::vector vy_n, vy_s; + std::vector vd_n, vd_s; + vy_n.reserve(out_h); + vy_s.reserve(out_h); + vd_n.reserve(out_h); + vd_s.reserve(out_h); + for (int k = 0; k < out_h; k++) { + int yn = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); + yn = (yn > 0) ? yn : 0; + int ys = (yn + 1) < (x_h - 1) ? (yn + 1) : (x_h - 1); + float idx_src_y = ratio_h * (k + 0.5) - 0.5; + idx_src_y = (idx_src_y > 0) ? idx_src_y : 0; + float dn = align_flag ? idx_src_y - yn : ratio_h * k - yn; + float ds = 1.f - dn; + { + vy_n[k] = yn; + vy_s[k] = ys; + vd_n[k] = dn; + vd_s[k] = ds; + } + } + + std::vector vx_w, vx_e; + std::vector vd_w, vd_e; + vx_w.reserve(out_w); + vx_e.reserve(out_w); + vd_w.reserve(out_w); + vd_e.reserve(out_w); + for (int l = 0; l < out_w; l++) { + int xw = (align_mode == 0 && !align_corners) + ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); + xw = (xw > 0) ? xw : 0; + int xe = (xw + 1) < (x_w - 1) ? (xw + 1) : (x_w - 1); + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float dw = align_flag ? idx_src_x - xw : ratio_w * l - xw; + float de = 1.f - dw; + { + vx_w[l] = xw; + vx_e[l] = xe; + vd_w[l] = dw; + vd_e[l] = de; + } + } + + std::vector x_strides(x_dims.size(), 1); + for (int idx = x_strides.size() - 2; idx >= 0; idx--) { + x_strides[idx] = x_strides[idx + 1] * x_dims[idx + 1]; + } + for (int i = 0; i < batch_size; i++) { + for (int j = 0; j < channel_size; j++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + DType x0 = x_data[i * x_strides[0] + j * x_strides[1] + + vy_n[k] * x_strides[2] + vx_w[l] * x_strides[3]]; + DType x1 = x_data[i * x_strides[0] + j * x_strides[1] + + vy_s[k] * x_strides[2] + vx_w[l] * x_strides[3]]; + DType x2 = x_data[i * x_strides[0] + j * x_strides[1] + + vy_n[k] * x_strides[2] + vx_e[l] * x_strides[3]]; + DType x3 = x_data[i * x_strides[0] + j * x_strides[1] + + vy_s[k] * x_strides[2] + vx_e[l] * x_strides[3]]; + *out_data = x0 * vd_s[k] * vd_e[l] + x1 * vd_n[k] * vd_e[l] + + x2 * vd_s[k] * vd_w[l] + x3 * vd_n[k] * vd_w[l]; + out_data++; + } + } + } + } +} + +void test_bilinear_interp(int bs, + int ic, + int ih, + int iw, + int oh, + int ow, + float scale, + int out_size_h, + int out_size_w, + bool align_corners, + int align_mode) { + // prepare input&output variables + Scope scope; + std::string x_var_name("x"); + std::string out_size_var_name("out_size"); + std::string out_var_name("out"); + std::string out_ref_var_name("out_ref"); + auto x = scope.Var(x_var_name)->GetMutable(); + auto out_size = scope.Var(out_size_var_name)->GetMutable(); + auto out = scope.Var(out_var_name)->GetMutable(); + auto out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + out_size->Resize({2}); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("bilinear_interp"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("out_h", oh); + opdesc.SetAttr("out_w", ow); + opdesc.SetAttr("scale", scale); + opdesc.SetAttr("align_corners", static_cast(align_corners)); + opdesc.SetAttr("align_mode", static_cast(align_mode)); + opdesc.SetAttr("interp_method", std::string("bilinear")); + if (out_size_h > 0 && out_size_w > 0) { + auto out_size_dims = out_size->dims(); + CHECK_EQ(out_size_dims.size(), 1); + CHECK_EQ(out_size_dims.production(), 2); + auto out_size_data = out_size->mutable_data(); + out_size_data[0] = out_size_h; + out_size_data[1] = out_size_w; + opdesc.SetInput("OutSize", {out_size_var_name}); + } + + // create op and execute reference implementation + auto op = CreateOp(opdesc, &scope); + bilinear_interp_ref(op); + out_ref->CopyDataFrom(*out); + + // convert op to NPU model, then run it on NPU + LauchOp(op, {x_var_name}, {out_var_name}); + + // compare results + auto out_dims = out->dims(); + auto out_ref_dims = out_ref->dims(); + CHECK_EQ(out_dims.size(), out_ref_dims.size()); + for (int i = 0; i < out_dims.size(); i++) { + CHECK_EQ(out_dims[i], out_ref_dims[i]); + } + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2f); + } +} + +TEST(NPUBridges, bilinear_interp) { +#if 1 + for (auto bs : {1, 3}) { + for (auto ic : {3, 4}) { + for (auto ih : {4, 5}) { + for (auto iw : {3, 6}) { + for (auto oh : {0, 3, 8}) { + for (auto ow : {0, 4, 9}) { + for (auto scale : {0.f, 0.5f, 0.6f, 2.0f, 2.2f}) { + for (auto out_size_h : {0, 3, 11}) { + for (auto out_size_w : {0, 2, 12}) { + for (auto align_corners : {true, false}) { + for (auto align_mode : {0, 1}) { + int act_oh = 0, act_ow = 0; + if (out_size_h > 0 && out_size_w > 0) { + act_oh = out_size_h; + act_ow = out_size_w; + } else if (scale > 1e-5) { + act_oh = static_cast(ih * scale); + act_ow = static_cast(iw * scale); + } else if (oh > 0 && ow > 0) { + act_oh = oh; + act_ow = ow; + } + if (act_oh <= 0 || act_ow <= 0) { + continue; + } + // TODO(hong19860320) multiple=(ih*iw)/(oh*ow) should + // not exceed 7.0 in NPU DDK, delete the following lines + // if the limination is removed. + const float largest_multiple = 7.0f; + float multiple = + static_cast(ih * iw) / (act_oh * act_ow); + if (multiple > largest_multiple) { + continue; + } + if (align_mode == 0 && !align_corners) { + continue; + } + VLOG(3) + << "bs: " << bs << " ic: " << ic << " ih: " << ih + << " iw: " << iw << " oh: " << oh << " ow: " << ow + << " scale: " << scale + << " out_size: " << out_size_h << "," << out_size_w + << " align_corners: " << align_corners + << " align_mode: " << align_mode; + test_bilinear_interp(bs, + ic, + ih, + iw, + oh, + ow, + scale, + out_size_h, + out_size_w, + align_corners, + align_mode); + } + } + } + } + } + } + } + } + } + } + } +#else + test_bilinear_interp(3, 4, 5, 3, 8, 4, 0.6f, 3, 0, true, 0); +#endif +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(bilinear_interp); +USE_NPU_BRIDGE(bilinear_interp); diff --git a/lite/npu/bridge/concat_op.cc b/lite/npu/bridge/concat_op.cc new file mode 100644 index 0000000000..e3d937a477 --- /dev/null +++ b/lite/npu/bridge/concat_op.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/concat_op.h" +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" +#include "lite/npu/npu_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +node_map_type ConcatConverter(const std::shared_ptr concat_op, + const node_map_type& inputs_map) { + lite::Scope* scope = concat_op->scope(); + const lite::OpInfo* op_info = concat_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "converting " << op_type << " ... "; + + auto x_var_names = op_info->Input("X"); + auto axis = op_info->GetAttr("axis"); + int num = x_var_names.size(); + int index = 0; + + std::shared_ptr output_node = + std::make_shared(unique_op_type); + output_node->set_attr_axis(axis); + output_node->set_attr_N(num); + output_node->create_dynamic_input_x(num); + for (auto x_var_name : x_var_names) { + if (inputs_map.find(x_var_name) != inputs_map.end()) { + output_node->set_dynamic_input_x(index + 1, *inputs_map.at(x_var_name)); + OpList::Global().add(inputs_map.at(x_var_name)); + } else { + auto consty = std::make_shared(x_var_name); + auto* x = scope->FindVar(x_var_name)->GetMutable(); + consty->set_attr_value(CvtFromLiteTensor(x)); + output_node->set_dynamic_input_x(index + 1, *consty); + OpList::Global().add(consty); + } + index++; + } + OpList::Global().add(output_node); + + node_map_type outputs_map; + outputs_map[op_info->Output("Out").front()] = output_node; + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(concat, paddle::lite::npu::bridge::ConcatConverter); diff --git a/lite/npu/bridge/concat_op_test.cc b/lite/npu/bridge/concat_op_test.cc new file mode 100644 index 0000000000..c9aa157b75 --- /dev/null +++ b/lite/npu/bridge/concat_op_test.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/concat_op.h" +#include +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +std::vector stride_numel(const DDim& ddim) { + std::vector strides(ddim.size()); + strides[ddim.size() - 1] = ddim[ddim.size() - 1]; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i]; + } + return strides; +} + +void concat_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = op_info->Input("X"); + std::vector inputs; + for (auto var : x) { + inputs.push_back(scope->FindVar(var)->GetMutable()); + } + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + int axis = op_info->GetAttr("axis"); + std::vector inputs_concat(inputs.size()); + for (int j = 0; j < inputs.size(); ++j) { + inputs_concat[j] = inputs[j]; + } + size_t num = inputs.size(); + int rows = 1; + auto dim_0 = inputs[0]->dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + std::vector inputs_cols(inputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = inputs[i]->numel() / rows; + out_cols += t_cols; + inputs_cols[i] = t_cols; + } + for (int k = 0; k < out_rows; ++k) { + float* dst_ptr = out->mutable_data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = inputs_cols[j]; + const float* src_prt = inputs[j]->data() + k * col_len; + std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len); + col_idx += col_len; + } + } +} + +void test_concat(std::vector> input, int axis) { + std::string x_var_name = "x"; + std::string y_var_name = "y"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + + // prepare input&output variables + Scope scope; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* y = scope.Var(y_var_name)->GetMutable(); + x->Resize(DDim(input[0])); + y->Resize(DDim(input[1])); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + CHECK_EQ(out->dims(), out_ref->dims()); + + // initialize input&output data + FillTensor(x); + FillTensor(y); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("concat"); + opdesc.SetInput("X", {x_var_name, y_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("axis", axis); + + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name, y_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + concat_ref(op); + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4); + } +} + +TEST(NPUBridges, concat) { + test_concat({{3, 3, 5, 2}, {2, 3, 5, 2}}, 0); + test_concat({{3, 5, 5, 2}, {3, 1, 5, 2}}, 1); + test_concat({{3, 3, 2, 2}, {3, 3, 4, 2}}, 2); + test_concat({{3, 3, 5, 2}, {3, 3, 5, 6}}, 3); +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(concat); +USE_NPU_BRIDGE(concat); diff --git a/lite/npu/bridge/conv_op.cc b/lite/npu/bridge/conv_op.cc index a3bb7c53e4..532d308b99 100644 --- a/lite/npu/bridge/conv_op.cc +++ b/lite/npu/bridge/conv_op.cc @@ -33,20 +33,16 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto op_info = conv_op->op_info(); auto op_type = op_info->Type(); auto unique_op_type = UniqueName(op_type); - LOG(INFO) << "Converting " << op_type << " ... "; + LOG(INFO) << "Converting " << op_type << "... "; - // get input, output and op attributes + // get input, filter and op attributes auto input_var_name = op_info->Input("Input").front(); auto input = scope->FindVar(input_var_name)->GetMutable(); auto input_dims = input->dims(); - auto output_var_name = op_info->Output("Output").front(); - auto output = scope->FindVar(output_var_name)->GetMutable(); - auto output_dims = output->dims(); auto filter_var_name = op_info->Input("Filter").front(); auto filter = scope->FindVar(filter_var_name)->GetMutable(); auto filter_dims = filter->dims(); CHECK_EQ(input_dims.size(), 4); - CHECK_EQ(output_dims.size(), 4); CHECK_EQ(filter_dims.size(), 4); auto strides = op_info->GetAttr>("strides"); auto paddings = op_info->GetAttr>("paddings"); @@ -89,33 +85,9 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, auto* bias = scope->FindVar(bias_var_name)->GetMutable(); auto channel_size = bias->dims().production(); CHECK_EQ(channel_size, filter_dims[0]); - CHECK_EQ(channel_size, output_dims[1]); bias_const_node = std::make_shared(bias_var_name); - if (use_depthwise_conv && is_depthwise_mode) { - // broadcast bias(1, oc, 1, 1) to (n, oc, oh, ow) - ge::TensorDesc bias_desc( - ge::Shape(output_dims.Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorPtr bias_tensor = std::make_shared(); - bias_tensor->SetTensorDesc(bias_desc); - auto old_bias_data = bias->mutable_data(); - std::vector new_bias_data(output_dims.production()); - int batch_size = output_dims[0]; - int inner_size = output_dims[2] * output_dims[3]; - for (int k = 0; k < batch_size; k++) { - for (int j = 0; j < channel_size; j++) { - for (int i = 0; i < inner_size; i++) { - new_bias_data[i + j * inner_size + k * channel_size * inner_size] = - old_bias_data[j]; - } - } - } - bias_tensor->SetData(reinterpret_cast(new_bias_data.data()), - new_bias_data.size() * sizeof(float)); - bias_const_node->set_attr_value(bias_tensor); - } else { - bias_const_node->set_attr_value( - CvtFromLiteTensor(bias, {1, channel_size, 1, 1})); - } + bias_const_node->set_attr_value( + CvtFromLiteTensor(bias, {1, channel_size, 1, 1})); OpList::Global().add(bias_const_node); } @@ -142,13 +114,11 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, OpList::Global().add(depthwise_conv_node); conv_node = depthwise_conv_node; if (bias_const_node != nullptr) { - auto eltwise_add_node = - std::make_shared(unique_op_type + "/eltwise_add"); - eltwise_add_node->set_input_x1(*depthwise_conv_node); - eltwise_add_node->set_input_x2(*bias_const_node); - eltwise_add_node->set_attr_mode(1); // 0:product, 1:sum, 2:max - OpList::Global().add(eltwise_add_node); - conv_node = eltwise_add_node; + auto add_node = std::make_shared(unique_op_type + "/add"); + add_node->set_input_x1(*depthwise_conv_node); + add_node->set_input_x2(*bias_const_node); + OpList::Global().add(add_node); + conv_node = add_node; } } else { auto common_conv_node = @@ -182,9 +152,9 @@ node_map_type ConvConverter(const std::shared_ptr conv_op, relu_node->set_input_x(*conv_node); relu_node->set_attr_mode(1); OpList::Global().add(relu_node); - outputs_map[output_var_name] = relu_node; + outputs_map[op_info->Output("Output").front()] = relu_node; } else { - outputs_map[output_var_name] = conv_node; + outputs_map[op_info->Output("Output").front()] = conv_node; } return outputs_map; } diff --git a/lite/npu/bridge/conv_transpose_op.cc b/lite/npu/bridge/conv_transpose_op.cc new file mode 100644 index 0000000000..cc587289bf --- /dev/null +++ b/lite/npu/bridge/conv_transpose_op.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/conv_transpose_op.h" +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +node_map_type ConvTransposeConverter( + const std::shared_ptr conv_transpose_op, + const node_map_type& inputs_map) { + auto scope = conv_transpose_op->scope(); + auto op_info = conv_transpose_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " << op_type << "... "; + + // get input, output and op attributes + auto input_var_name = op_info->Input("Input").front(); + auto input = scope->FindVar(input_var_name)->GetMutable(); + auto input_shape = input->dims().Vectorize(); + auto filter_var_name = op_info->Input("Filter").front(); + auto filter = scope->FindVar(filter_var_name)->GetMutable(); + auto filter_shape = filter->dims().Vectorize(); + CHECK_EQ(input_shape.size(), 4); + CHECK_EQ(filter_shape.size(), 4); + auto strides = op_info->GetAttr>("strides"); + auto paddings = op_info->GetAttr>("paddings"); + auto groups = op_info->GetAttr("groups"); + auto dilations = op_info->GetAttr>("dilations"); + auto fuse_relu = op_info->GetAttr("fuse_relu"); + CHECK_EQ(strides.size(), 2); + CHECK_EQ(paddings.size(), 2); + CHECK_EQ(dilations.size(), 2); + + // create deconv node + auto conv_transpose_node = + std::make_shared(unique_op_type); + + // create input sizes node to describe the dimensions of input tensor + std::vector output_shape; + output_shape.push_back(input_shape[0]); + output_shape.push_back(filter_shape[1] * groups); + for (int i = 0; i < strides.size(); i++) { + int kernel_ext = dilations[i] * (filter_shape[i + 2] - 1) + 1; + int output_size = + (input_shape[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i]; + output_shape.push_back(output_size); + } + auto input_sizes_const_node = + std::make_shared(unique_op_type + "/input_size"); + input_sizes_const_node->set_attr_value(CreateTensorAndFillData(output_shape)); + conv_transpose_node->set_input_input_sizes(*input_sizes_const_node); + OpList::Global().add(input_sizes_const_node); + + // create filter node + CHECK(!inputs_map.count(filter_var_name)); + auto filter_const_node = std::make_shared(filter_var_name); + filter_const_node->set_attr_value(CvtFromLiteTensor(filter)); + conv_transpose_node->set_input_filter(*filter_const_node); + OpList::Global().add(filter_const_node); + + // set input node + CHECK(inputs_map.count(input_var_name)); + conv_transpose_node->set_input_x(*inputs_map.at(input_var_name)); + OpList::Global().add(inputs_map.at(input_var_name)); + + // set attributes + conv_transpose_node->set_attr_mode(1); + conv_transpose_node->set_attr_format(0); // NCHW + conv_transpose_node->set_attr_pad_mode(0); // NOTSET + conv_transpose_node->set_attr_group(groups); + conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT( + {paddings[0], paddings[0], paddings[1], paddings[1]})); + conv_transpose_node->set_attr_dilation( + ge::AttrValue::LIST_INT({dilations[0], dilations[1]})); + conv_transpose_node->set_attr_stride( + ge::AttrValue::LIST_INT({strides[0], strides[1]})); + conv_transpose_node->set_attr_kernel( + ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]})); + OpList::Global().add(conv_transpose_node); + + // append add node to add bias if has bias + std::shared_ptr output_node = conv_transpose_node; + if (HasInputArg(op_info, scope, "Bias")) { + // create bias node + auto bias_var_name = op_info->Input("Bias").front(); + CHECK(!inputs_map.count(bias_var_name)); + auto* bias = scope->FindVar(bias_var_name)->GetMutable(); + auto channel_size = bias->dims().production(); + CHECK_EQ(channel_size, filter_shape[1] * groups); + auto bias_const_node = std::make_shared(bias_var_name); + bias_const_node->set_attr_value( + CvtFromLiteTensor(bias, {1, channel_size, 1, 1})); + OpList::Global().add(bias_const_node); + // append add node to add bias node + auto add_node = std::make_shared(unique_op_type + "/add"); + add_node->set_input_x1(*conv_transpose_node); + add_node->set_input_x2(*bias_const_node); + OpList::Global().add(add_node); + output_node = add_node; + } + + node_map_type outputs_map; + if (fuse_relu) { + // append relu node if fuse_relu is true + auto relu_node = + std::make_shared(unique_op_type + "/relu"); + relu_node->set_input_x(*output_node); + relu_node->set_attr_mode(1); + OpList::Global().add(relu_node); + outputs_map[op_info->Output("Output").front()] = relu_node; + } else { + outputs_map[op_info->Output("Output").front()] = output_node; + } + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(conv2d_transpose, + paddle::lite::npu::bridge::ConvTransposeConverter); diff --git a/lite/npu/bridge/conv_transpose_op_test.cc b/lite/npu/bridge/conv_transpose_op_test.cc new file mode 100644 index 0000000000..3d55e291ce --- /dev/null +++ b/lite/npu/bridge/conv_transpose_op_test.cc @@ -0,0 +1,369 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/conv_transpose_op.h" +#include +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +template +void add_bias_with_relu(DType* data, + const DType* bias, + int channel_size, + int inner_size, + bool has_relu) { + for (int c = 0; c < channel_size; ++c) { + DType bias_val = bias != nullptr ? bias[c] : 0; + for (int i = 0; i < inner_size; i++) { + DType data_val = data[i]; + data_val += bias_val; + if (has_relu) { + data_val = data_val > 0 ? data_val : 0.f; + } + data[i] = data_val; + } + data += inner_size; + } +} + +template +void col2im(const DType* data_col, + const int channel_size, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + DType* data_im) { + memset(data_im, 0, height * width * channel_size * sizeof(DType)); + const int output_h = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int output_w = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + const int inner_size = height * width; + for (int c = channel_size; c--; data_im += inner_size) { + for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_row = -pad_h + kernel_row * dilation_h; + for (int output_rows = output_h; output_rows; output_rows--) { + if (input_row < 0 || input_row >= height) { + data_col += output_w; + } else { + int input_col = -pad_w + kernel_col * dilation_w; + for (int output_col = output_w; output_col; output_col--) { + if (input_col >= 0 && input_col < width) { + data_im[input_row * width + input_col] += *data_col; + } + data_col++; + input_col += stride_w; + } + } + input_row += stride_h; + } + } + } + } +} + +template +void gemm(int M, + int N, + int K, + const IType* A, + const IType* B, + OType* C, + OType alpha, + OType beta, + bool is_trans_A = false, + bool is_trans_B = false) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + OType sum = static_cast(0); + for (int k = 0; k < K; ++k) { + IType a; + IType b; + if (is_trans_A) { + a = A[k * M + m]; + } else { + a = A[m * K + k]; + } + if (is_trans_B) { + b = B[n * K + k]; + } else { + b = B[k * N + n]; + } + sum += a * b; + } + C[m * N + n] = alpha * sum + beta * C[m * N + n]; + } + } +} + +template +void conv_transpose_ref( + const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto input = + scope->FindVar(op_info->Input("Input").front())->GetMutable(); + auto filter = + scope->FindVar(op_info->Input("Filter").front())->GetMutable(); + auto output = + scope->FindVar(op_info->Output("Output").front())->GetMutable(); + std::vector strides = + op_info->GetAttr>("strides"); + std::vector paddings = + op_info->GetAttr>("paddings"); + int32_t groups = op_info->GetAttr("groups"); + std::vector dilations = + op_info->GetAttr>("dilations"); + bool fuse_relu = op_info->GetAttr("fuse_relu"); + Tensor* bias = nullptr; + OType* bias_data = nullptr; + if (op_info->HasInput("Bias")) { + auto bias_var_names = op_info->Input("Bias"); + if (bias_var_names.size() > 0) { + auto bias_var_name = bias_var_names.front(); + bias = scope->FindVar(bias_var_name)->GetMutable(); + bias_data = bias->mutable_data(); + } + } + auto input_dims = input->dims(); + auto filter_dims = filter->dims(); + auto output_dims = output->dims(); + auto input_data = input->mutable_data(); + auto filter_data = filter->mutable_data(); + auto output_data = output->mutable_data(); + int kernel_w = filter_dims[3]; + int kernel_h = filter_dims[2]; + int stride_w = strides[1]; + int stride_h = strides[0]; + int dila_w = dilations[1]; + int dila_h = dilations[0]; + int pad_w = paddings[1]; + int pad_h = paddings[0]; + int batch_size = input_dims[0]; + int in_ch_size = input_dims[1]; + int in_h = input_dims[2]; + int in_w = input_dims[3]; + int out_ch_size = output_dims[1]; + int out_h = output_dims[2]; + int out_w = output_dims[3]; + + int M = out_ch_size * kernel_w * kernel_h / groups; + int N = in_h * in_w; + int K = in_ch_size / groups; + + if (in_ch_size != out_ch_size || groups != in_ch_size) { + CHECK_EQ(in_ch_size % groups, 0); + CHECK_EQ(out_ch_size % groups, 0); + } + + auto workspace = std::vector(groups * M * N); + int group_input_size = in_w * in_h * in_ch_size / groups; + int group_output_size = out_w * out_h * out_ch_size / groups; + int group_col_size = M * N; + int group_filter_size = + in_ch_size * out_ch_size * kernel_w * kernel_h / (groups * groups); + bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && + (stride_w == 1) && (pad_w == 1) && (pad_h == 1) && + (dila_w == 1) && (dila_h == 1); + for (int n = 0; n < batch_size; ++n) { + input_data += n * in_ch_size * in_h * in_w; + output_data += n * out_ch_size * out_h * out_w; + auto col_data = workspace.data(); + if (flag_1x1s1p1) { + col_data = output_data; + } + memset(col_data, 0, sizeof(OType) * group_col_size); + for (int g = 0; g < groups; ++g) { + auto input_group_data = input_data + g * group_input_size; + auto filter_group_data = filter_data + g * group_filter_size; + auto col_group_data = col_data + g * group_col_size; + gemm(M, + N, + K, + filter_group_data, + input_group_data, + col_group_data, + static_cast(1), + static_cast(0), + true, + false); + } + if (!flag_1x1s1p1) { + col2im(col_data, + out_ch_size, + out_h, + out_w, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + output_data); + } + add_bias_with_relu( + output_data, bias_data, out_ch_size, out_w * out_h, fuse_relu); + } +} + +void test_conv_transpose(int bs, + int ic, + int ih, + int iw, + bool has_bias, + bool fuse_relu, + int filters, + int groups, + int dilation, + int stride, + int padding, + int kernel) { + // prepare input&output variables + Scope scope; + std::string input_var_name("input"); + std::string filter_var_name("filter"); + std::string bias_var_name("bias"); + std::string output_var_name("output"); + std::string output_ref_var_name("output_ref"); + auto* input = scope.Var(input_var_name)->GetMutable(); + auto* filter = scope.Var(filter_var_name)->GetMutable(); + auto* bias = scope.Var(bias_var_name)->GetMutable(); + auto* output = scope.Var(output_var_name)->GetMutable(); + auto* output_ref = scope.Var(output_ref_var_name)->GetMutable(); + + // get group size and input&filter shape + std::vector input_shape = {bs, ic, ih, iw}; + std::vector filter_shape = {ic, filters, kernel, kernel}; + input->Resize(input_shape); + filter->Resize(filter_shape); + + // initialize input&output data + FillTensor(input); + FillTensor(filter); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("conv2d_transpose"); + opdesc.SetInput("Input", {input_var_name}); + opdesc.SetInput("Filter", {filter_var_name}); + opdesc.SetOutput("Output", {output_var_name}); + opdesc.SetAttr("dilations", std::vector({dilation, dilation})); + opdesc.SetAttr("strides", std::vector({stride, stride})); + opdesc.SetAttr("paddings", std::vector({padding, padding})); + opdesc.SetAttr("groups", groups); + opdesc.SetAttr("fuse_relu", static_cast(fuse_relu)); + if (has_bias) { + bias->Resize({1, filters * groups, 1, 1}); + FillTensor(bias); + opdesc.SetInput("Bias", {bias_var_name}); + } + + // create and convert op to NPU model, then run it on NPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {input_var_name}, {output_var_name}); + output_ref->CopyDataFrom(*output); + + // execute reference implementation and save to output tensor('out') + conv_transpose_ref(op); + + // compare results + auto* output_data = output->mutable_data(); + auto* output_ref_data = output_ref->mutable_data(); + for (int i = 0; i < output->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } +} + +TEST(NPUBridges, conv_transpose) { +#if 1 + for (auto bs : {1, 2}) { + for (auto ic : {3, 6}) { + for (auto ih : {14, 28}) { + for (auto iw : {14, 28}) { + for (auto has_bias : {false, true}) { + for (auto fuse_relu : {false, true}) { + for (auto filters : {1, 2, 5}) { + for (auto groups : {1 /* , 2, 5*/}) { + for (auto dilation : {1, 2}) { + for (auto stride : {1, 2}) { + for (auto kernel : {1, 3, 5}) { + std::vector paddings = {kernel / 2}; + if (kernel / 2 != 0) { + paddings.push_back(0); + } + for (auto padding : paddings) { + VLOG(3) << "bs: " << bs << " ic: " << ic + << " ih: " << ih << " iw: " << iw + << " has_bias: " << has_bias + << " fuse_relu: " << fuse_relu + << " filters: " << filters + << " groups: " << groups + << " dilation: " << dilation + << " stride: " << stride + << " padding: " << padding + << " kernel: " << kernel; + test_conv_transpose(bs, + ic, + ih, + iw, + has_bias, + fuse_relu, + filters, + groups, + dilation, + stride, + padding, + kernel); + } + } + } + } + } + } + } + } + } + } + } + } +#else + test_conv_transpose(1, 6, 8, 8, false, false, 5, 2, 1, 1, 1, 3); +#endif +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(conv2d_transpose); +USE_NPU_BRIDGE(conv2d_transpose); diff --git a/lite/npu/bridge/paddle_use_npu_bridges.h b/lite/npu/bridge/paddle_use_npu_bridges.h index b55a511770..ba55f52212 100644 --- a/lite/npu/bridge/paddle_use_npu_bridges.h +++ b/lite/npu/bridge/paddle_use_npu_bridges.h @@ -25,3 +25,8 @@ USE_NPU_BRIDGE(relu); USE_NPU_BRIDGE(elementwise_add); USE_NPU_BRIDGE(scale); USE_NPU_BRIDGE(softmax); +USE_NPU_BRIDGE(concat); +USE_NPU_BRIDGE(split); +USE_NPU_BRIDGE(transpose); +USE_NPU_BRIDGE(transpose2); +USE_NPU_BRIDGE(shuffle_channel); diff --git a/lite/npu/bridge/reshape_op.cc b/lite/npu/bridge/reshape_op.cc new file mode 100644 index 0000000000..439d85c06f --- /dev/null +++ b/lite/npu/bridge/reshape_op.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/reshape_op.h" +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +node_map_type ReshapeConverter(const std::shared_ptr reshape_op, + const node_map_type& inputs_map) { + auto scope = reshape_op->scope(); + auto op_info = reshape_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; + + // get input, output and op attributes + auto x_var_name = op_info->Input("X").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto x_dims = x->dims(); + + // create reshape node and set input node from inputs_map + auto reshape_node = std::make_shared(unique_op_type); + CHECK(inputs_map.count(x_var_name)); + reshape_node->set_input_tensor(*inputs_map.at(x_var_name)); + OpList::Global().add(inputs_map.at(x_var_name)); + + // read shape from actual shape tensor as input "w" if 'Shape' is found + if (HasInputArg(op_info, scope, "Shape")) { + auto actual_shape_var_name = op_info->Input("Shape").front(); + if (!inputs_map.count(actual_shape_var_name)) { + auto actual_shape = + scope->FindVar(actual_shape_var_name)->GetMutable(); + auto actual_shape_dims = actual_shape->dims(); + auto actual_shape_data = actual_shape->mutable_data(); + auto shape = + std::vector(actual_shape_data, + actual_shape_data + actual_shape_dims.production()); + auto out_dims = operators::ValidateShape(shape, x_dims); + auto out_shape = out_dims.Vectorize(); + if (out_shape.size() > 4) { + LOG(WARNING) + << "NPU DDK only supports less than 4 dimensions, but Shape has " + << out_shape.size(); + } + auto actual_shape_const_node = + std::make_shared(actual_shape_var_name); + actual_shape_const_node->set_attr_value(CreateTensorAndFillData( + std::vector(out_shape.begin(), out_shape.end()))); + reshape_node->set_input_w(*actual_shape_const_node); + OpList::Global().add(actual_shape_const_node); + } else { + reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name)); + OpList::Global().add(inputs_map.at(actual_shape_var_name)); + } + } else { + auto shape = op_info->GetAttr>("shape"); + auto out_dims = operators::ValidateShape(shape, x_dims); + auto out_shape = out_dims.Vectorize(); + if (out_shape.size() > 4) { + LOG(WARNING) + << "NPU DDK only supports less than 4 dimensions, but shape has " + << out_shape.size(); + } + reshape_node->set_attr_shape( + ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end())); + } + OpList::Global().add(reshape_node); + + node_map_type outputs_map; + outputs_map[op_info->Output("Out").front()] = reshape_node; + if (op_type == "reshape2") { + // append an extra reshape node to calc XShape + std::vector xshape_dims(x_dims.size() + 1, 1); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + if (xshape_dims.size() > 4) { + LOG(WARNING) + << "NPU DDK only supports less than 4 dimensions, but XShape has " + << xshape_dims.size(); + } + auto xshape_node = + std::make_shared(unique_op_type + "/xshape"); + xshape_node->set_input_tensor(*inputs_map.at(x_var_name)); + xshape_node->set_attr_shape( + ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end())); + OpList::Global().add(xshape_node); + outputs_map[op_info->Output("XShape").front()] = xshape_node; + } + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(reshape, paddle::lite::npu::bridge::ReshapeConverter); +REGISTER_NPU_BRIDGE(reshape2, paddle::lite::npu::bridge::ReshapeConverter); diff --git a/lite/npu/bridge/reshape_op_test.cc b/lite/npu/bridge/reshape_op_test.cc new file mode 100644 index 0000000000..5f5377a0a1 --- /dev/null +++ b/lite/npu/bridge/reshape_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/reshape_op.h" +#include +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +void reshape_ref(const std::shared_ptr op) { + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto x_dims = x->dims(); + auto shape = op_info->GetAttr>("shape"); + auto inplace = op_info->GetAttr("inplace"); + if (op_info->HasInput("Shape")) { + auto actual_shape_var_names = op_info->Input("Shape"); + if (actual_shape_var_names.size() > 0) { + auto actual_shape = scope->FindVar(actual_shape_var_names.front()) + ->GetMutable(); + auto actual_shape_dims = actual_shape->dims(); + auto* actual_shape_data = actual_shape->data(); + shape = + std::vector(actual_shape_data, + actual_shape_data + actual_shape_dims.production()); + } + } + if (inplace) { + out->ShareDataWith(*x); + } else { + out->CopyDataFrom(*x); + } + auto out_dims = operators::ValidateShape(shape, x_dims); + out->Resize(out_dims); +} + +void test_reshape(const std::vector& x_shape, + const std::vector& shape, + const std::vector& act_shape, + bool inplace, + bool reshape2) { + // prepare input&output variables + Scope scope; + std::string x_var_name("x"); + std::string actual_shape_var_name("actual_shape"); + std::string out_var_name("out"); + std::string out_ref_var_name("out_ref"); + std::string xshape_var_name("xshape"); + std::string xshape_ref_var_name("xshape_ref"); + auto x = scope.Var(x_var_name)->GetMutable(); + auto actual_shape = scope.Var(actual_shape_var_name)->GetMutable(); + auto out = scope.Var(out_var_name)->GetMutable(); + auto out_ref = scope.Var(out_ref_var_name)->GetMutable(); + auto xshape = scope.Var(xshape_var_name)->GetMutable(); + auto xshape_ref = scope.Var(xshape_ref_var_name)->GetMutable(); + + x->Resize(x_shape); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType(reshape2 ? "reshape2" : "reshape"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("shape", shape); + opdesc.SetAttr("inplace", inplace); + if (!act_shape.empty()) { + int64_t act_shape_size = act_shape.size(); + actual_shape->Resize({act_shape_size}); + memcpy(actual_shape->mutable_data(), + act_shape.data(), + act_shape_size * sizeof(int)); + opdesc.SetInput("Shape", {actual_shape_var_name}); + } + if (reshape2) { + opdesc.SetOutput("XShape", {xshape_var_name}); + } + + // create op and execute reference implementation + auto op = reshape2 ? CreateOp(opdesc, &scope) + : CreateOp(opdesc, &scope); + reshape_ref(op); + out_ref->CopyDataFrom(*out); + if (reshape2) { + xshape_ref->CopyDataFrom(*xshape); + } + + // convert op to NPU model, then run it on NPU + LauchOp(op, + {x_var_name}, + {out_var_name}); // TODO(hong19860320) support XShape for reshape2 + + // compare results + auto out_dims = out->dims(); + auto out_ref_dims = out_ref->dims(); + CHECK_EQ(out_dims.size(), out_ref_dims.size()); + for (int i = 0; i < out_dims.size(); i++) { + CHECK_EQ(out_dims[i], out_ref_dims[i]); + } + auto out_data = out->mutable_data(); + auto out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } + // if (reshape2) { + // auto xshape_dims = xshape->dims(); + // auto xshape_ref_dims = xshape_ref->dims(); + // CHECK_EQ(xshape_dims.size(), xshape_ref_dims.size()); + // for (size_t i = 0; i < xshape_dims.size(); i++) { + // CHECK_EQ(xshape_dims[i], xshape_ref_dims[i]); + // } + // } +} + +TEST(NPUBridges, reshape) { +#if 1 + std::map, std::vector>> tests = { + {{1, 2, 4, 6}, + {{}, + {-1}, + {48}, + {-1, 48}, + {1, 48}, + {0, 48}, + {48, -1}, + {48, 1}, + {-1, 24}, + {2, 24}, + {24, 0}, + {-1, 0, 3, 2}, + {4, 2, 3, 2}, + {0, -1, 3, 2}, + {1, 8, 3, 2}}}}; + for (auto& i : tests) { + for (auto& shape : i.second) { + if (shape.empty()) { + continue; + } + for (auto& act_shape : i.second) { + for (auto& inplace : {true, false}) { + for (auto& reshape2 : {true, false}) { + std::stringstream ss; + ss << "x:{ "; + for (auto s : i.first) { + ss << s << " "; + } + ss << "} shape:{ "; + for (auto s : shape) { + ss << s << " "; + } + ss << "} act_shape:{ "; + for (auto s : act_shape) { + ss << s << " "; + } + VLOG(3) << ss.str() << "} inplace:" << inplace + << " reshape2:" << reshape2; + test_reshape(i.first, shape, act_shape, inplace, reshape2); + } + } + } + } + } +#else + test_reshape({2, 4, 6}, {-1, 0, 4, 3}, {}, true, true); + test_reshape({1, 232, 14, 14}, {-1, 2, 116, 14, 14}, {}, true, true); +#endif +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(reshape); +USE_NPU_BRIDGE(reshape); + +USE_LITE_OP(reshape2); +USE_NPU_BRIDGE(reshape2); diff --git a/lite/npu/bridge/scale_op.cc b/lite/npu/bridge/scale_op.cc index 094d49ca28..b2664dc963 100644 --- a/lite/npu/bridge/scale_op.cc +++ b/lite/npu/bridge/scale_op.cc @@ -29,10 +29,11 @@ namespace bridge { node_map_type ScaleConverter(const std::shared_ptr scale_op, const node_map_type& inputs_map) { - VLOG(3) << "invoking ScaleConverter..."; auto scope = scale_op->scope(); auto op_info = scale_op->op_info(); auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " + op_type + "..."; // get input, output and op attributes auto x_var_name = op_info->Input("X").front(); @@ -48,7 +49,7 @@ node_map_type ScaleConverter(const std::shared_ptr scale_op, } // create scale node and set input node from inputs_map - auto scale_node = std::make_shared(UniqueName(op_type)); + auto scale_node = std::make_shared(unique_op_type); CHECK(inputs_map.count(x_var_name)); scale_node->set_input_x(*inputs_map.at(x_var_name)); OpList::Global().add(inputs_map.at(x_var_name)); @@ -56,7 +57,7 @@ node_map_type ScaleConverter(const std::shared_ptr scale_op, // add filter node(fill with scale) auto filter_const_node = - std::make_shared(UniqueName(op_type + "/filter")); + std::make_shared(unique_op_type + "/filter"); filter_const_node->set_attr_value( CreateTensorAndFillData(scale, scale_bias_shape)); scale_node->set_input_filter(*filter_const_node); @@ -65,7 +66,7 @@ node_map_type ScaleConverter(const std::shared_ptr scale_op, // add bias node(fill with bias) if (fabs(bias) > 1e-6f) { auto bias_const_node = - std::make_shared(UniqueName(op_type + "/bias")); + std::make_shared(unique_op_type + "/bias"); bias_const_node->set_attr_value( CreateTensorAndFillData(bias, scale_bias_shape)); scale_node->set_input_bias(*bias_const_node); diff --git a/lite/npu/bridge/shuffle_channel_op.cc b/lite/npu/bridge/shuffle_channel_op.cc new file mode 100644 index 0000000000..cb1bcdbec5 --- /dev/null +++ b/lite/npu/bridge/shuffle_channel_op.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/shuffle_channel_op.h" +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +node_map_type ShuffleChannelConverter( + const std::shared_ptr shuffle_channel_op, + const node_map_type& inputs_map) { + LOG(INFO) << "converting shuffle_channel..."; + lite::Scope* scope = shuffle_channel_op->scope(); + const lite::OpInfo* op_info = shuffle_channel_op->op_info(); + + std::shared_ptr output_node = + std::make_shared(UniqueName("shuffle_channel")); + auto x_var_name = op_info->Input("X").front(); + + output_node->set_input_x(*inputs_map.at(x_var_name)); + output_node->set_attr_group(op_info->GetAttr("group")); + + OpList::Global().add(inputs_map.at(x_var_name)); + OpList::Global().add(output_node); + + node_map_type outputs_map; + outputs_map[op_info->Output("Out").front()] = output_node; + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(shuffle_channel, + paddle::lite::npu::bridge::ShuffleChannelConverter); diff --git a/lite/npu/bridge/shuffle_channel_op_test.cc b/lite/npu/bridge/shuffle_channel_op_test.cc new file mode 100644 index 0000000000..1ed6b59bc3 --- /dev/null +++ b/lite/npu/bridge/shuffle_channel_op_test.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/shuffle_channel_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +void shuffle_channel_ref( + const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + auto out = + scope->FindVar(op_info->Output("Out").front())->GetMutable(); + auto x_data = x->mutable_data(); + auto out_data = out->mutable_data(); + int group = op_info->GetAttr("group"); + auto x_dims = x->dims(); + + int n_size = x_dims.production() / x_dims[0]; + int c_size = n_size / x_dims[1]; + for (int n = 0; n < x_dims[0]; n++) { + int g_num = x_dims[1] / group; + auto tmp_out_data = out_data; + for (int g = 0; g < g_num; g++) { + auto tmp_x_data = x_data + g * c_size; + for (int i = 0; i < group; i++) { + std::memcpy(tmp_out_data, + tmp_x_data + i * g_num * c_size, + c_size * sizeof(float)); + tmp_out_data += c_size; + } + } + x_data += n_size; + out_data += n_size; + } +} + +void test_shuffle_channel(int bs, int ic, int ih, int iw, int group) { + // prepare input&output variables + Scope scope; + std::string x_var_name = "x"; + std::string out_var_name = "out"; + std::string out_ref_var_name = "out_ref"; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + auto* out_ref = scope.Var(out_ref_var_name)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("shuffle_channel"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + opdesc.SetAttr("group", group); + + // create and convert op to NPU model, then run it on NPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name}); + out_ref->CopyDataFrom(*out); + + // execute reference implementation and save to output tensor + shuffle_channel_ref(op); + + // compare results + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2); + } +} + +TEST(NPUBridges, softmax) { + for (auto bs : {1, 4}) { + for (auto ic : {1, 24, 35}) { + for (auto ih : {1, 4}) { + for (auto iw : {1, 4}) { + for (auto group : {1, 3, 7, 24, 35}) { + if (ic % group != 0) continue; + test_shuffle_channel(bs, ic, ih, iw, group); + } + } + } + } + } +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(shuffle_channel); +USE_NPU_BRIDGE(shuffle_channel); diff --git a/lite/npu/bridge/split_op.cc b/lite/npu/bridge/split_op.cc new file mode 100644 index 0000000000..f1348c8472 --- /dev/null +++ b/lite/npu/bridge/split_op.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/split_op.h" +#include "ai_ddk_lib/include/graph/buffer.h" +#include "ai_ddk_lib/include/graph/graph.h" +#include "ai_ddk_lib/include/graph/model.h" +#include "ai_ddk_lib/include/graph/op/all_ops.h" +#include "ai_ddk_lib/include/graph/operator.h" +#include "ai_ddk_lib/include/graph/operator_reg.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/utils.h" +#include "lite/npu/npu_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { +node_map_type SplitConverter(const std::shared_ptr split_op, + const node_map_type& inputs_map) { + lite::Scope* scope = split_op->scope(); + const lite::OpInfo* op_info = split_op->op_info(); + auto op_type = op_info->Type(); + auto unique_op_type = UniqueName(op_type); + LOG(INFO) << "Converting " << op_type << " ... "; + + auto x_var_name = op_info->Input("X").front(); + auto axis = op_info->GetAttr("axis"); + auto num = op_info->GetAttr("num"); + auto sections = op_info->GetAttr>("sections"); + int64_t sections_num = static_cast(sections.size()); + + std::shared_ptr output_node = + std::make_shared(unique_op_type); + CHECK(inputs_map.count(x_var_name)); + output_node->set_input_x(*inputs_map.at(x_var_name)); + OpList::Global().add(inputs_map.at(x_var_name)); + + output_node->set_attr_axis(static_cast(axis)); + if (num > 0) { + output_node->set_attr_output_num(static_cast(num)); + } else { + output_node->set_attr_output_num(sections_num); + auto size_split = ge::AttrValue::LIST_INT(sections.begin(), sections.end()); + output_node->set_attr_size_split(size_split); + } + + node_map_type outputs_map; + auto out_var_names = op_info->Output("Out"); + output_node->create_dynamic_output_y(out_var_names.size()); + int index = 1; + for (auto out_var_name : out_var_names) { + auto const_node = std::make_shared( + unique_op_type + "/const_zero" + std::to_string(index)); + const_node->set_attr_value(CreateTensorAndFillData(0)); + OpList::Global().add(const_node); + auto add_node = std::make_shared(unique_op_type + "/add" + + std::to_string(index)); + add_node->set_input_x1(*output_node, "y" + std::to_string(index)); + add_node->set_input_x2(*const_node); + outputs_map[out_var_name] = add_node; + OpList::Global().add(add_node); + index++; + } + + OpList::Global().add(output_node); + return outputs_map; +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +REGISTER_NPU_BRIDGE(split, paddle::lite::npu::bridge::SplitConverter); diff --git a/lite/npu/bridge/split_op_test.cc b/lite/npu/bridge/split_op_test.cc new file mode 100644 index 0000000000..c8a74ee5a9 --- /dev/null +++ b/lite/npu/bridge/split_op_test.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/split_op.h" +#include +#include "lite/core/op_registry.h" +#include "lite/npu/bridge/registry.h" +#include "lite/npu/bridge/test_helper.h" + +namespace paddle { +namespace lite { +namespace npu { +namespace bridge { + +template +void split_ref(const std::shared_ptr op) { + Scope* scope = op->scope(); + const OpInfo* op_info = op->op_info(); + auto x = scope->FindVar(op_info->Input("X").front())->GetMutable(); + int num = op_info->GetAttr("num"); + int axis = op_info->GetAttr("axis"); + std::vector sections = op_info->GetAttr>("sections"); + std::vector output_vec; + auto output = op_info->Output("Out"); + for (auto out_var : output) { + output_vec.push_back(scope->Var(out_var)->GetMutable()); + } + auto in_dims = x->dims(); + auto rank = in_dims.size(); + int outs_number = output_vec.size(); + std::vector outs_dims; + outs_dims.reserve(outs_number); + if (axis < 0) { + axis += rank; + } + if (num > 0) { + int out_axis_dim = in_dims[axis] / num; + for (int i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = sections[i]; + outs_dims.push_back(dim); + } + } + for (int j = 0; j < outs_dims.size(); ++j) { + output_vec[j]->Resize(outs_dims[j]); + } + + const dtype* din = x->mutable_data(); + std::vector in_strides(in_dims.size()); + in_strides[in_dims.size() - 1] = in_dims[in_dims.size() - 1]; + for (int i = in_dims.size() - 2; i >= 0; --i) { + in_strides[i] = in_strides[i + 1] * in_dims[i]; + } + + int input_offset = 0; + for (auto out : output_vec) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + dtype* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + std::memcpy(out_data + i * out_after, + din + input_offset + i * in_after, + sizeof(dtype) * out_after); + } + input_offset += out_strides[axis]; + } +} + +void test_split(int bs, + int ic, + int ih, + int iw, + int axis, + int num, + std::vector sections) { + const auto& bridges = lite::npu::bridge::Factory::Instance(); + const auto& supported_lists = bridges.AllFunctions(); + CHECK(bridges.HasType("split")); + // prepare input&output variables + std::string x_var_name = "x"; + std::string out_var_name_1 = "out_1"; + std::string out_var_name_2 = "out_2"; + std::string out_ref_var_name_1 = "out_ref_1"; + std::string out_ref_var_name_2 = "out_ref_2"; + + Scope scope; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out_1 = scope.Var(out_var_name_1)->GetMutable(); + auto* out_2 = scope.Var(out_var_name_2)->GetMutable(); + auto* out_ref_1 = scope.Var(out_ref_var_name_1)->GetMutable(); + auto* out_ref_2 = scope.Var(out_ref_var_name_2)->GetMutable(); + x->Resize({bs, ic, ih, iw}); + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("split"); + opdesc.SetInput("X", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name_1, out_var_name_2}); + opdesc.SetAttr("axis", axis); + opdesc.SetAttr("sections", sections); + opdesc.SetAttr("num", num); + // create and convert op to NPU model, then run it on NPU + auto op = CreateOp(opdesc, &scope); + LauchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2}); + out_ref_1->CopyDataFrom(*out_1); + out_ref_2->CopyDataFrom(*out_2); + // execute reference implementation and save to output tensor + split_ref(op); + + // compare results + auto* out_data_1 = out_1->mutable_data(); + auto* out_data_2 = out_2->mutable_data(); + auto* out_ref_data_1 = out_ref_1->mutable_data(); + auto* out_ref_data_2 = out_ref_2->mutable_data(); + for (int i = 0; i < out_1->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data_1[i], out_ref_data_1[i], 5e-4); + } + for (int i = 0; i < out_2->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data_2[i], out_ref_data_2[i], 5e-4); + } +} + +TEST(NPUBridges, split) { + test_split(4, 2, 3, 1, 0, 2, {}); + test_split(4, 2, 3, 1, 0, 0, {3, 1}); + test_split(4, 6, 3, 1, 1, 2, {}); + test_split(4, 6, 3, 1, 1, 0, {2, 4}); + test_split(4, 2, 2, 1, 2, 2, {}); + test_split(4, 2, 6, 1, 2, 0, {3, 3}); + test_split(4, 2, 3, 4, 3, 2, {}); + test_split(4, 2, 3, 6, 3, 0, {5, 1}); +} + +} // namespace bridge +} // namespace npu +} // namespace lite +} // namespace paddle + +USE_LITE_OP(split); +USE_NPU_BRIDGE(split); diff --git a/lite/npu/bridge/transpose_op_test.cc b/lite/npu/bridge/transpose_op_test.cc index 18421ad3b8..5bb3006f2b 100644 --- a/lite/npu/bridge/transpose_op_test.cc +++ b/lite/npu/bridge/transpose_op_test.cc @@ -33,7 +33,7 @@ int data_index(std::vector pos, DDimLite dims) { std::vector pos_trans(std::vector in_pos, std::vector axis) { std::vector out_pos(in_pos.size()); for (int i = 0; i < axis.size(); i++) { - out_pos[axis[i]] = in_pos[1]; + out_pos[axis[i]] = in_pos[i]; } return out_pos; } @@ -88,11 +88,7 @@ void test_transpose(int bs, int ic, int ih, int iw, std::vector axis) { x->Resize({bs, ic, ih, iw}); // initialize input&output data - // FillTensor(x); - auto* x_data = x->mutable_data(); - for (int i = 0; i < x->numel(); i++) { - x_data[i] = i; - } + FillTensor(x); // initialize op desc cpp::OpDesc opdesc; @@ -123,7 +119,12 @@ TEST(NPUBridges, transpose) { for (auto ic : {1, 4, 7}) { for (auto ih : {1, 4, 7}) { for (auto iw : {1, 4, 7}) { - for (auto axis : {std::vector{0, 1, 2, 3}}) { + for (auto axis : {std::vector{0, 1, 2, 3}, + std::vector{0, 1, 3, 2}, + std::vector{0, 3, 1, 2}, + std::vector{1, 2, 3, 0}, + std::vector{3, 2, 1, 0}, + std::vector{2, 3, 1, 0}}) { test_transpose(bs, ic, ih, iw, axis); } } @@ -131,8 +132,8 @@ TEST(NPUBridges, transpose) { } } #endif - // test_transpose(2, 3, 4, 5, std::vector{0,1,3,2}); - test_transpose(2, 3, 4, 5, std::vector{0, 1, 2, 3}); + test_transpose(2, 3, 4, 5, std::vector{0, 1, 3, 2}); + // test_transpose(2, 3, 4, 5, std::vector{0, 1, 2, 3}); // test_transpose(2, 2, 2, 2, std::vector{0,1,3,2}); // test_transpose(1, 1, 2, 2, std::vector{0,1,3,2}); // test_transpose(1, 1, 1, 2, std::vector{0,1,2,3}); diff --git a/lite/npu/bridge/utils.h b/lite/npu/bridge/utils.h index 4ad980fd2b..2bccbccb07 100644 --- a/lite/npu/bridge/utils.h +++ b/lite/npu/bridge/utils.h @@ -41,8 +41,8 @@ ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor, DataLayoutType in_ltype = DATALAYOUT(kNCHW)); template -ge::TensorPtr CreateTensorAndFillData(T value, - std::vector shape = {1}, +ge::TensorPtr CreateTensorAndFillData(std::vector data, + std::vector shape = {}, ge::Format format = ge::FORMAT_NCHW) { const std::type_info& info = typeid(T); ge::DataType type = ge::DT_FLOAT; @@ -55,17 +55,33 @@ ge::TensorPtr CreateTensorAndFillData(T value, } else { LOG(FATAL) << "Unknow value type " << info.name(); } + if (shape.empty()) { + shape = {static_cast(data.size())}; + } else { + int size = 1; + for (auto i : shape) { + size *= i; + } + CHECK_EQ(data.size(), size); + } ge::TensorDesc desc(ge::Shape(shape), format, type); ge::TensorPtr tensor = std::make_shared(); tensor->SetTensorDesc(desc); - int64_t data_num = 1; + tensor->SetData(reinterpret_cast(data.data()), + data.size() * sizeof(T)); + return tensor; +} + +template +ge::TensorPtr CreateTensorAndFillData(T value, + std::vector shape = {1}, + ge::Format format = ge::FORMAT_NCHW) { + int64_t size = 1; for (auto i : shape) { - data_num *= i; + size *= i; } - std::vector data_value(data_num, value); - tensor->SetData(reinterpret_cast(data_value.data()), - data_num * sizeof(T)); - return tensor; + std::vector data(size, value); + return CreateTensorAndFillData(data, shape, format); } std::shared_ptr CvtNode2Tensor(const lite::mir::Node* arg_node); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 66131e3b4d..1362a86797 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -5,6 +5,7 @@ lite_cc_library(pool_op SRCS pool_op.cc DEPS ${op_DEPS}) lite_cc_library(fc_op SRCS fc_op.cc DEPS ${op_DEPS}) lite_cc_library(relu_op SRCS relu_op.cc DEPS ${op_DEPS}) lite_cc_library(mul_op SRCS mul_op.cc DEPS ${op_DEPS}) +lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS}) lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS}) lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS}) lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS} ) @@ -81,6 +82,8 @@ lite_cc_library(is_empty SRCS is_empty_op.cc DEPS ${op_DEPS}) lite_cc_library(shape_op_lite SRCS shape_op.cc DEPS ${op_DEPS}) lite_cc_library(cast_op_lite SRCS cast_op.cc DEPS ${op_DEPS}) lite_cc_library(slice_op_lite SRCS slice_op.cc DEPS ${op_DEPS}) +lite_cc_library(squeeze_op_lite SRCS squeeze_op.cc DEPS ${op_DEPS}) +lite_cc_library(expand_op_lite SRCS expand_op.cc DEPS ${op_DEPS}) set(ops @@ -89,6 +92,7 @@ set(ops fc_op relu_op mul_op + matmul_op scale_op softmax_op reshape_op @@ -164,6 +168,8 @@ set(ops shape_op_lite cast_op_lite slice_op_lite + squeeze_op_lite + expand_op_lite CACHE INTERNAL "ops lite") if (NOT LITE_WITH_X86) diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index a0ba1d1cd4..c4f81f8b25 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -109,6 +109,7 @@ REGISTER_LITE_OP(tanh, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); #ifdef LITE_WITH_TRAIN REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc index f5c668a216..b84b4ff169 100644 --- a/lite/operators/conv_transpose_op.cc +++ b/lite/operators/conv_transpose_op.cc @@ -44,7 +44,7 @@ bool ConvTransposeOpLite::InferShape() const { std::vector output_shape; output_shape.push_back(in_dims[0]); - output_shape.push_back(filter_dims[0] * param_.groups); + output_shape.push_back(filter_dims[1] * param_.groups); for (int i = 0; i < param_.strides.size(); i++) { int kernel_extent = param_.dilations[i] * (filter_dims[i + 2] - 1) + 1; int output_len = (in_dims[i + 2] - 1) * param_.strides[i] + kernel_extent - @@ -60,10 +60,9 @@ bool ConvTransposeOpLite::InferShape() const { // TODO(Superjomn) replace framework::OpDesc with a lite one. bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { - auto X = op_desc.Input("x").front(); - auto Filter = op_desc.Input("filter").front(); - auto Out = op_desc.Output("output").front(); - + auto X = op_desc.Input("Input").front(); + auto Filter = op_desc.Input("Filter").front(); + auto Out = op_desc.Output("Output").front(); param_.x = scope->FindVar(X)->GetMutable(); param_.filter = scope->FindVar(Filter)->GetMutable(); param_.output = scope->FindVar(Out)->GetMutable(); @@ -75,9 +74,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, // optional params std::vector input_arg_names = op_desc.InputArgumentNames(); - if (std::find(input_arg_names.begin(), input_arg_names.end(), "bias") != + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != input_arg_names.end()) { - auto bias_arguments = op_desc.Input("bias"); + auto bias_arguments = op_desc.Input("Bias"); if (bias_arguments.size() > 0) { auto bias_var = scope->FindVar(bias_arguments.front()); if (bias_var != nullptr) { @@ -87,6 +86,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, } } param_.fuse_relu = op_desc.GetAttr("fuse_relu"); + return true; } } // namespace operators diff --git a/lite/operators/expand_op.cc b/lite/operators/expand_op.cc new file mode 100644 index 0000000000..656e8babc0 --- /dev/null +++ b/lite/operators/expand_op.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/expand_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ExpandOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + int expand_size = param_.expand_times.size(); + int x_dims_size = param_.X->dims().size(); + CHECK_EQ(expand_size, x_dims_size) + << "The number of expand_times size must be qual to the rank of " + "Input(X)."; + CHECK_LE(param_.X->dims().size(), 6) + << "The rank of Input(X) must not be greater than 6."; + return true; +} + +bool ExpandOpLite::InferShape() const { + DDim out_dims(param_.X->dims()); + for (size_t i = 0; i < param_.expand_times.size(); ++i) { + out_dims[i] *= param_.expand_times[i]; + } + param_.Out->Resize(out_dims); + return true; +} + +bool ExpandOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + auto X_name = opdesc.Input("X").front(); + auto Out_name = opdesc.Output("Out").front(); + param_.X = GetVar(scope, X_name); + param_.Out = GetMutableVar(scope, Out_name); + param_.expand_times = opdesc.GetAttr>("expand_times"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(expand, paddle::lite::operators::ExpandOpLite); diff --git a/lite/operators/expand_op.h b/lite/operators/expand_op.h new file mode 100644 index 0000000000..ce5dcda9e8 --- /dev/null +++ b/lite/operators/expand_op.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ExpandOpLite : public OpLite { + public: + ExpandOpLite() {} + explicit ExpandOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "expand"; } + + private: + mutable ExpandParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/interpolate_op.cc b/lite/operators/interpolate_op.cc index deb2e9e37c..b9ab9fa07e 100644 --- a/lite/operators/interpolate_op.cc +++ b/lite/operators/interpolate_op.cc @@ -67,9 +67,12 @@ bool InterpolateOp::InferShape() const { bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { auto X = op_desc.Input("X").front(); - if (op_desc.Input("OutSize").size() > 0) { - auto OutSize = op_desc.Input("OutSize").front(); - param_.OutSize = scope->FindVar(OutSize)->GetMutable(); + if (op_desc.HasInput("OutSize")) { + auto out_size_var_names = op_desc.Input("OutSize"); + if (out_size_var_names.size() > 0) { + param_.OutSize = scope->FindVar(out_size_var_names.front()) + ->GetMutable(); + } } else { param_.OutSize = nullptr; } diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc new file mode 100644 index 0000000000..90cfd1ddad --- /dev/null +++ b/lite/operators/matmul_op.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/matmul_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool MatMulOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Y); + CHECK_OR_FALSE(param_.Out); + + return true; +} + +bool MatMulOpLite::InferShape() const { + const auto x_dims = param_.X->dims(); + const auto y_dims = param_.Y->dims(); + bool x_transpose = param_.transpose_X; + bool y_transpose = param_.transpose_Y; + std::vector dim_out_vec; + + if (x_dims.size() > 2 && y_dims.size() >= 2) { + // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [B, M, K], y: [K, N], out: [B, M, N] + CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + dim_out_vec.resize(x_dims.size()); + for (size_t i = 0; i < x_dims.size() - 1; ++i) { + dim_out_vec[i] = x_dims[i]; + } + dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1]; + } else if (x_dims.size() == 2 && y_dims.size() == 2) { + // x: [M, K], y: [K, N], out: [M, N] + // x: [M, K], y: [K, N], out: [M, N] + if (!x_transpose && !y_transpose) { + CHECK_EQ(x_dims[1], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else if (!x_transpose && y_transpose) { + CHECK_EQ(x_dims[1], y_dims[1]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else if (x_transpose && !y_transpose) { + CHECK_EQ(x_dims[0], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } else { + CHECK_EQ(x_dims[0], y_dims[1]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << "), x_transpose is " << x_transpose << ", y_transpose is " + << y_transpose; + } + dim_out_vec.resize(x_dims.size()); + if (x_transpose) { + dim_out_vec[0] = x_dims[1]; + } else { + dim_out_vec[0] = x_dims[0]; + } + if (y_transpose) { + dim_out_vec[1] = y_dims[0]; + } else { + dim_out_vec[1] = y_dims[1]; + } + } else if (x_dims.size() > 2 && y_dims.size() == 1) { + // x: [B, M, K], y: [K], out: [B, M] + CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + dim_out_vec.resize(x_dims.size() - 1); + for (size_t i = 0; i < dim_out_vec.size(); ++i) { + dim_out_vec[i] = x_dims[i]; + } + } else if (x_dims.size() == 1 && y_dims.size() == 1) { // todo + // x: [K], y: [K], out: [1] + if (x_dims[0] == y_dims[0] && x_transpose == false && + y_transpose == false) { + dim_out_vec.resize(1); + dim_out_vec[0] = 1; + } + // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N] + if (x_transpose == true && y_transpose == true) { + dim_out_vec.resize(2); + dim_out_vec[0] = x_dims[0]; + dim_out_vec[1] = y_dims[0]; + } + } else { + LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + } + + DDim dim_out(dim_out_vec); + param_.Out->Resize(dim_out); + + return true; +} + +bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + CHECK(!op_desc.Input("X").empty()); + CHECK(!op_desc.Input("Y").empty()); + CHECK(!op_desc.Output("Out").empty()); + + auto X = op_desc.Input("X").front(); + auto Y = op_desc.Input("Y").front(); + auto Out = op_desc.Output("Out").front(); + + param_.X = GetVar(scope, X); + param_.Y = GetVar(scope, Y); + param_.Out = GetMutableVar(scope, Out); + param_.transpose_X = op_desc.GetAttr("transpose_X"); + param_.transpose_Y = op_desc.GetAttr("transpose_Y"); + param_.alpha = op_desc.GetAttr("alpha"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(matmul, paddle::lite::operators::MatMulOpLite); diff --git a/lite/operators/matmul_op.h b/lite/operators/matmul_op.h new file mode 100644 index 0000000000..0aa47c89dd --- /dev/null +++ b/lite/operators/matmul_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MatMulOpLite : public OpLite { + public: + MatMulOpLite() {} + + explicit MatMulOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "matmul"; } + + private: + mutable MatMulParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index b416b9e683..c011aa5e0c 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -694,6 +694,32 @@ struct SliceParam { std::vector ends{}; std::vector decrease_axis{}; }; + +/// ----------------------- shape operators ---------------------- +struct SqueezeParam { + const lite::Tensor* X{}; + lite::Tensor* Out{}; + lite::Tensor* XShape{}; + std::vector axes{}; +}; + +/// ----------------------- expand operators ---------------------- +struct ExpandParam { + const lite::Tensor* X{}; + lite::Tensor* Out{}; + std::vector expand_times{}; +}; + +/// ----------------------- matmul operators ---------------------- +struct MatMulParam { + const lite::Tensor* X{}; + const lite::Tensor* Y{}; + lite::Tensor* Out{}; + bool transpose_X{false}; + bool transpose_Y{false}; + float alpha{1.0f}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/operators/squeeze_op.cc b/lite/operators/squeeze_op.cc new file mode 100644 index 0000000000..19bd20f1ac --- /dev/null +++ b/lite/operators/squeeze_op.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/squeeze_op.h" +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +static DDim GetOutputShape(const std::vector &squeeze_dims, + const DDim &in_dims, + bool is_runtime) { + size_t num_squeeze_dims = squeeze_dims.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + // Determines number of dimensions of output tensor after squeeze. + // Mark and count the dimensions need to be squeezed + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() + : squeeze_dims[idx]; + // Check current index, the upper limit has been checked. + CHECK_GE(current, 0) + << "Invalid axis, the negative axis is out of range."; + + if (is_runtime) { + CHECK_EQ(in_dims[current], 1) << "Invalid axis index, the axis that " + "will be squeezed should be equal " + "to 1."; + } + + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } + should_squeeze[current] = true; + } + } + + // Make output dimensions + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + return DDim(output_shape); +} + +bool SqueezeOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + for (int a : param_.axes) { + CHECK_LT(a, static_cast(param_.X->dims().size())) + << "The squeeze axis should be less than input tensor's rank."; + } + return true; +} + +bool SqueezeOp::InferShape() const { + std::vector squeeze_dims = param_.axes; + DDim in_dims = param_.X->dims(); + DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true); + param_.Out->Resize(out_dim); + return true; +} + +bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + auto x_var = scope->FindVar(opdesc.Input("X").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + CHECK(x_var); + CHECK(output_var); + param_.X = const_cast(&(x_var->Get())); + param_.Out = output_var->GetMutable(); + + if (opdesc.HasAttr("axes")) { + param_.axes = opdesc.GetAttr>("axes"); + } + CHECK(param_.X) << "Input(X) of SqueezeOp should not be null."; + CHECK(param_.Out) << "Output(Out) of SqueezeOp should not be null."; + return true; +} + +bool Squeeze2Op::CheckShape() const { + SqueezeOp::CheckShape(); + CHECK_OR_FALSE(param_.XShape); + return true; +} + +bool Squeeze2Op::InferShape() const { + SqueezeOp::InferShape(); + auto x_dims = param_.X->dims(); + std::vector xshape_dims(x_dims.size() + 1, 1); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.XShape->Resize(DDim(xshape_dims)); + return true; +} + +bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + SqueezeOp::AttachImpl(opdesc, scope); + auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); + CHECK(xshape_var); + param_.XShape = xshape_var->GetMutable(); + CHECK(param_.XShape) << "Output(XShape) of ReshapeOp should not be null."; + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(squeeze, paddle::lite::operators::SqueezeOp); +REGISTER_LITE_OP(squeeze2, paddle::lite::operators::Squeeze2Op); diff --git a/lite/operators/squeeze_op.h b/lite/operators/squeeze_op.h new file mode 100644 index 0000000000..1a550c5fbe --- /dev/null +++ b/lite/operators/squeeze_op.h @@ -0,0 +1,61 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SqueezeOp : public OpLite { + public: + SqueezeOp() {} + explicit SqueezeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "squeeze"; } + + protected: + mutable SqueezeParam param_; +}; + +class Squeeze2Op : public SqueezeOp { + public: + Squeeze2Op() : SqueezeOp() {} + explicit Squeeze2Op(const std::string &op_type) : SqueezeOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "squeeze2"; } +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index eaca2e41ae..777e4408c0 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -34,4 +34,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() diff --git a/lite/tests/kernels/activation_compute_test.cc b/lite/tests/kernels/activation_compute_test.cc index f35d08ef9a..a0cc9e5775 100644 --- a/lite/tests/kernels/activation_compute_test.cc +++ b/lite/tests/kernels/activation_compute_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" @@ -29,8 +30,11 @@ enum activation_type_test { SIGMOID, TANH, SWISH, - RELU6 + RELU6, + LOG, + EXP }; + class ActivationComputeTester : public arena::TestCase { protected: // common attributes for this op. @@ -154,6 +158,18 @@ class ActivationComputeTester : public arena::TestCase { } break; } + case LOG: { + for (int i = 0; i < dims_.production(); i++) { + output_data[i] = std::log(x_data[i]); + } + break; + } + case EXP: { + for (int i = 0; i < dims_.production(); i++) { + output_data[i] = std::exp(x_data[i]); + } + break; + } default: LOG(INFO) << "the type of activation is unknow."; } @@ -182,6 +198,7 @@ class ActivationComputeTester : public arena::TestCase { std::vector data(dims_.production()); for (int i = 0; i < dims_.production(); i++) { float sign = i % 3 == 0 ? -1.0f : 1.0f; + sign = type_ == "log" ? 1 : sign; data[i] = sign * static_cast(i % 128) * 0.013f + 0.001; } SetCommonTensor(input_, dims_, data.data()); @@ -417,7 +434,7 @@ TEST(Activation_swish, precision) { } TEST(Activation_relu6, precision) { - LOG(INFO) << "test relu6 op"; + LOG(INFO) << "test relu6 op..."; #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); @@ -445,5 +462,62 @@ TEST(Activation_relu6, precision) { } #endif } + +TEST(Activation_log, precision) { + LOG(INFO) << "test log op"; +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + + for (auto n : {1, 3}) { + for (auto c : {3, 6}) { + for (auto h : {9, 18}) { + for (auto w : {9, 18}) { + std::unique_ptr tester(new ActivationComputeTester( + place, + "def", + 0.01, + 6., + "all", + 0., + DDim(std::vector({n, c, h, w})), + "log", + LOG)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } +#endif +} + +TEST(Activation_exp, precision) { + LOG(INFO) << "test exp op"; +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + + for (auto n : {1, 3}) { + for (auto c : {3, 6}) { + for (auto h : {9, 18}) { + for (auto w : {9, 18}) { + std::unique_ptr tester(new ActivationComputeTester( + place, + "def", + 0.01, + 6., + "all", + 0., + DDim(std::vector({n, c, h, w})), + "exp", + EXP)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } +#endif +} + } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/conv2d_transpose_compute_test.cc b/lite/tests/kernels/conv2d_transpose_compute_test.cc index 4faa7fcb43..c44259022d 100644 --- a/lite/tests/kernels/conv2d_transpose_compute_test.cc +++ b/lite/tests/kernels/conv2d_transpose_compute_test.cc @@ -350,22 +350,22 @@ class Conv2DTransposeComputeTester : public arena::TestCase { void PrepareOpDesc(cpp::OpDesc* op_desc) { op_desc->SetType("conv2d_transpose"); - op_desc->SetInput("x", {x_}); - op_desc->SetInput("filter", {filter_}); - op_desc->SetOutput("output", {output_}); + op_desc->SetInput("Input", {x_}); + op_desc->SetInput("Filter", {filter_}); + op_desc->SetOutput("Output", {output_}); op_desc->SetAttr("strides", strides_); op_desc->SetAttr("paddings", paddings_); op_desc->SetAttr("groups", groups_); op_desc->SetAttr("dilations", dilations_); if (flag_bias_) { - op_desc->SetInput("bias", {bias_}); + op_desc->SetInput("Bias", {bias_}); } op_desc->SetAttr("fuse_relu", flag_relu_); } void PrepareData() override { std::vector input_shape = {n_, ic_, ih_, iw_}; - std::vector filter_shape = {oc_ / groups_, ic_, ks_, ks_}; + std::vector filter_shape = {ic_, oc_ / groups_, ks_, ks_}; std::vector bias_shape = {1, oc_, 1, 1}; // x tensor diff --git a/lite/tests/kernels/expand_compute_test.cc b/lite/tests/kernels/expand_compute_test.cc new file mode 100644 index 0000000000..4ab1c15a5e --- /dev/null +++ b/lite/tests/kernels/expand_compute_test.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class ExpandComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string out_ = "Out"; + std::vector expand_times_; + DDim dims_; + + public: + ExpandComputeTester(const Place& place, + const std::string& alias, + const std::vector& expand_times, + DDim dims) + : TestCase(place, alias), expand_times_(expand_times), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + const auto* input = scope->FindTensor(x_); + CHECK(input); + auto* out = scope->NewTensor(out_); + CHECK(out); + + DDim out_shape(input->dims()); + DDim in_shape = input->dims(); + + for (size_t i = 0; i < expand_times_.size(); ++i) { + out_shape[i] *= expand_times_[i]; + } + out->Resize(out_shape); + float* out_data = out->mutable_data(); + const float* input_data = input->data(); + std::vector in_stride(in_shape.size(), 1), + out_stride(out_shape.size(), 1); + for (int i = in_shape.size() - 2; i >= 0; --i) { + in_stride[i] = in_shape[i + 1] * in_stride[i + 1]; + } + for (int i = out_shape.size() - 2; i >= 0; --i) { + out_stride[i] = out_shape[i + 1] * out_stride[i + 1]; + } + for (size_t out_id = 0; out_id < out_shape.production(); ++out_id) { + int in_id = 0; + for (int i = expand_times_.size() - 1; i >= 0; --i) { + int in_j = (out_id / out_stride[i]) % in_shape[i]; + in_id += in_j * in_stride[i]; + } + out_data[out_id] = input_data[in_id]; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("expand"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("expand_times", expand_times_); + } + + void PrepareData() override { + std::vector in_data(dims_.production()); + for (int i = 0; i < dims_.production(); ++i) { + in_data[i] = i; + } + SetCommonTensor(x_, dims_, in_data.data()); + } +}; + +void test_expand_3dim(Place place) { + for (std::vector expand_times : {std::vector({2, 3, 1}), + std::vector({2, 2, 2}), + std::vector({3, 1, 2})}) { + for (int C : {3}) { + for (int H : {2}) { + for (int W : {4}) { + std::unique_ptr tester(new ExpandComputeTester( + place, "def", expand_times, DDim({C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } +} + +void test_expand_4dim(Place place) { + for (std::vector expand_times : {std::vector({2, 3, 1, 4}), + std::vector({2, 2, 2, 2}), + std::vector({3, 1, 2, 1})}) { + for (int N : {2}) { + for (int C : {3}) { + for (int H : {2}) { + for (int W : {4}) { + std::unique_ptr tester(new ExpandComputeTester( + place, "def", expand_times, DDim({N, C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } +} + +TEST(Expand, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_expand_3dim(place); + test_expand_4dim(place); +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/matmul_compute_test.cc b/lite/tests/kernels/matmul_compute_test.cc new file mode 100644 index 0000000000..648180f832 --- /dev/null +++ b/lite/tests/kernels/matmul_compute_test.cc @@ -0,0 +1,452 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +void matrix_mul(int m_, + int k_, + int n_, + float alpha, + const float* x, + const float* y, + float* out) { + for (int m = 0; m < m_; ++m) { + for (int n = 0; n < n_; ++n) { + out[m * n_ + n] = 0; + for (int k = 0; k < k_; ++k) { + out[m * n_ + n] += x[m * k_ + k] * y[k * n_ + n] * alpha; + } + } + } +} + +void transpose(int m, int n, const float* src, float* dst) { + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + dst[j * m + i] = src[i * n + j]; + } + } +} + +void mul_low_efficiency(DDim x_dims_, + DDim y_dims_, + bool x_transpose_, + bool y_transpose_, + float alpha_, + const float* x_data, + const float* y_data, + float* out_data) { + if (!x_transpose_ && !y_transpose_) { + CHECK_EQ(x_dims_[1], y_dims_[0]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << "), x_transpose is " << x_transpose_ << ", y_transpose is " + << y_transpose_; + matrix_mul( + x_dims_[0], y_dims_[0], y_dims_[1], alpha_, x_data, y_data, out_data); + } else if (!x_transpose_ && y_transpose_) { + CHECK_EQ(x_dims_[1], y_dims_[1]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << "), x_transpose is " << x_transpose_ << ", y_transpose is " + << y_transpose_; + float* y_data_trans = + static_cast(malloc(sizeof(float) * y_dims_[0] * y_dims_[1])); + transpose(y_dims_[0], y_dims_[1], y_data, y_data_trans); + matrix_mul(x_dims_[0], + x_dims_[1], + y_dims_[0], + alpha_, + x_data, + y_data_trans, + out_data); + free(y_data_trans); + } else if (x_transpose_ && !y_transpose_) { + CHECK_EQ(x_dims_[0], y_dims_[0]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << "), x_transpose is " << x_transpose_ << ", y_transpose is " + << y_transpose_; + float* x_data_trans = + static_cast(malloc(sizeof(float) * x_dims_[0] * x_dims_[1])); + transpose(x_dims_[0], x_dims_[1], x_data, x_data_trans); + matrix_mul(x_dims_[1], + x_dims_[0], + y_dims_[1], + alpha_, + x_data_trans, + y_data, + out_data); + free(x_data_trans); + } else { + CHECK_EQ(x_dims_[0], y_dims_[1]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << "), x_transpose is " << x_transpose_ << ", y_transpose is " + << y_transpose_; + float* x_data_trans = + static_cast(malloc(sizeof(float) * x_dims_[0] * x_dims_[1])); + float* y_data_trans = + static_cast(malloc(sizeof(float) * y_dims_[0] * y_dims_[1])); + transpose(x_dims_[0], x_dims_[1], x_data, x_data_trans); + transpose(y_dims_[0], y_dims_[1], y_data, y_data_trans); + matrix_mul(x_dims_[1], + x_dims_[0], + y_dims_[0], + alpha_, + x_data_trans, + y_data_trans, + out_data); + free(x_data_trans); + free(y_data_trans); + } +} + +class MatMulComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string y_ = "Y"; + std::string out_ = "Out"; + DDim x_dims_; + DDim y_dims_; + bool x_transpose_; + bool y_transpose_; + float alpha_; + + public: + MatMulComputeTester(const Place& place, + const std::string& alias, + bool x_transpose, + bool y_transpose, + float alpha, + const DDim& x_dims, + const DDim& y_dims) + : TestCase(place, alias), + x_transpose_(x_transpose), + y_transpose_(y_transpose), + alpha_(alpha), + x_dims_(x_dims), + y_dims_(y_dims) {} + + void RunBaseline(Scope* scope) override { + auto* x = scope->FindTensor(x_); + auto* y = scope->FindTensor(y_); + CHECK(x); + CHECK(y); + const auto* x_data = x->data(); + const auto* y_data = y->data(); + auto* out = scope->NewTensor(out_); + CHECK(out); + + // todo alpha + std::vector dim_out_vec; + if (x_dims_.size() > 2 && y_dims_.size() >= 2) { + // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [B, M, K], y: [K, N], out: [B, M, N] + if (x_transpose_ || y_transpose_) { + LOG(FATAL) << "not supported transpose for x and y."; + } + CHECK_EQ(x_dims_[x_dims_.size() - 1], y_dims_[y_dims_.size() - 2]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << ")"; + dim_out_vec.resize(x_dims_.size()); + for (size_t i = 0; i < x_dims_.size() - 1; ++i) { + dim_out_vec[i] = x_dims_[i]; + } + dim_out_vec[x_dims_.size() - 1] = y_dims_[y_dims_.size() - 1]; + out->Resize(dim_out_vec); + auto* out_data = out->mutable_data(); + int x_inner = x_dims_[x_dims_.size() - 2] * x_dims_[x_dims_.size() - 1]; + + if (y_dims_.size() > 2) { + int y_inner = y_dims_[y_dims_.size() - 2] * y_dims_[y_dims_.size() - 1]; + int o_inner = x_dims_[x_dims_.size() - 2] * y_dims_[y_dims_.size() - 1]; + for (size_t i = 0; i < x_dims_.count(0, x_dims_.size() - 2); ++i) { + mul_low_efficiency( + DDim({x_dims_[x_dims_.size() - 2], x_dims_[x_dims_.size() - 1]}), + DDim({y_dims_[y_dims_.size() - 2], y_dims_[y_dims_.size() - 1]}), + x_transpose_, + y_transpose_, + alpha_, + x_data + i * x_inner, + y_data + i * y_inner, + out_data + i * o_inner); + } + } else { + int o_inner = x_dims_[x_dims_.size() - 2] * y_dims_[1]; + for (size_t i = 0; i < x_dims_.count(0, x_dims_.size() - 2); ++i) { + mul_low_efficiency( + DDim({x_dims_[x_dims_.size() - 2], x_dims_[x_dims_.size() - 1]}), + y_dims_, + x_transpose_, + y_transpose_, + alpha_, + x_data + i * x_inner, + y_data, + out_data + i * o_inner); + } + } + } else if (x_dims_.size() == 2 && y_dims_.size() == 2) { + // x: [M, K], y: [K, N], out: [M, N] + dim_out_vec.resize(x_dims_.size()); + if (x_transpose_) { + dim_out_vec[0] = x_dims_[1]; + } else { + dim_out_vec[0] = x_dims_[0]; + } + if (y_transpose_) { + dim_out_vec[1] = y_dims_[0]; + } else { + dim_out_vec[1] = y_dims_[1]; + } + out->Resize(dim_out_vec); + auto* out_data = out->mutable_data(); + mul_low_efficiency(x_dims_, + y_dims_, + x_transpose_, + y_transpose_, + alpha_, + x_data, + y_data, + out_data); + } else if (x_dims_.size() > 2 && y_dims_.size() == 1) { + // x: [B, M, K], y: [K], out: [B, M] + CHECK_EQ(x_dims_[x_dims_.size() - 1], y_dims_[0]) + << "not supported x_dims(" << x_dims_ << ") and y_dims(" << y_dims_ + << ")"; + dim_out_vec.resize(x_dims_.size() - 1); + for (size_t i = 0; i < dim_out_vec.size(); ++i) { + dim_out_vec[i] = x_dims_[i]; + } + out->Resize(dim_out_vec); + auto* out_data = out->mutable_data(); + for (size_t i = 0; i < x_dims_.count(0, x_dims_.size() - 1); ++i) { + out_data[i] = 0; + for (size_t j = 0; j < y_dims_[0]; ++j) { + out_data[i] += x_data[i * y_dims_[0] + j] * y_data[j] * alpha_; + } + } + } else if (x_dims_.size() == 1 && y_dims_.size() == 1) { // todo + // x: [K], y: [K], out: [1] + if (x_dims_[0] == y_dims_[0] && x_transpose_ == false && + y_transpose_ == false) { + dim_out_vec.resize(1); + dim_out_vec[0] = 1; + + out->Resize(dim_out_vec); + auto* out_data = out->mutable_data(); + out_data[0] = 0.f; + for (size_t i = 0; i < x_dims_[0]; ++i) { + out_data[0] += x_data[i] * y_data[i] * alpha_; + } + } + // x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N] + if (x_transpose_ == true && y_transpose_ == true) { + dim_out_vec.resize(2); + dim_out_vec[0] = x_dims_[0]; + dim_out_vec[1] = y_dims_[0]; + out->Resize(dim_out_vec); + auto* out_data = out->mutable_data(); + mul_low_efficiency(DDim({x_dims_[0], 1}), + DDim({1, y_dims_[0]}), + false, + false, + alpha_, + x_data, + y_data, + out_data); + } + } else { + LOG(FATAL) << "not supported x_dims(" << x_dims_ << ") and y_dims(" + << y_dims_ << ")"; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("matmul"); + op_desc->SetInput("X", {x_}); + op_desc->SetInput("Y", {y_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("transpose_X", x_transpose_); + op_desc->SetAttr("transpose_Y", y_transpose_); + op_desc->SetAttr("alpha", alpha_); + } + + void PrepareData() override { + std::vector x_data(x_dims_.production()); + std::vector y_data(y_dims_.production()); + + for (int i = 0; i < x_dims_.production(); ++i) { + x_data[i] = 1; // i * 1.1; + } + for (int i = 0; i < y_dims_.production(); ++i) { + y_data[i] = 1; // i * 0.9; + } + + SetCommonTensor(x_, x_dims_, x_data.data()); + SetCommonTensor(y_, y_dims_, y_data.data()); + } +}; + +void test_matmul2x2_no_transform(Place place) { + for (int m : {1, 2, 4, 8}) { + for (int k : {1, 3, 5}) { + for (int n : {1, 2, 4, 6}) { + for (float alpha : {1., 2.}) { + bool x_transform = false; + bool y_transform = false; + std::unique_ptr tester( + new MatMulComputeTester(place, + "def", + x_transform, + y_transform, + alpha, + DDim({m, k}), + DDim({k, n}))); + arena::Arena arena(std::move(tester), place, 5e-4); + arena.TestPrecision(); + } + } + } + } +} + +void test_matmul2x2_transform(Place place) { + DDim x_dim({3, 2}); + DDim y_dim({3, 2}); + float alpha = 1.f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, true, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul1x1_no_transpose(Place place) { + DDim x_dim({3}); + DDim y_dim({3}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul1x1_transpose(Place place) { + DDim x_dim({3}); + DDim y_dim({5}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", true, true, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul_nx1(Place place) { + DDim x_dim({3, 4, 2, 5}); + DDim y_dim({5}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul_nx2_1(Place place) { + DDim x_dim({3, 4, 2, 5}); + DDim y_dim({5, 1}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul_nx2_2(Place place) { + DDim x_dim({3, 4, 2, 5}); + DDim y_dim({5, 3}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +void test_matmul_nxn(Place place) { + DDim x_dim({3, 4, 2, 5}); + DDim y_dim({3, 4, 5, 2}); + float alpha = 1.5f; + std::unique_ptr tester( + new MatMulComputeTester(place, "def", false, false, alpha, x_dim, y_dim)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); +} + +TEST(Matmul2x2, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + // test_matmul2x2_transform(place); + test_matmul2x2_no_transform(place); +#endif +} + +TEST(Matmul1x1, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_matmul1x1_transpose(place); + test_matmul1x1_no_transpose(place); +#endif +} + +TEST(Matmulnx1, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_matmul_nx1(place); +#endif +} + +TEST(Matmulnx2, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_matmul_nx2_1(place); + test_matmul_nx2_2(place); +#endif +} + +TEST(Matmulnxn, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_matmul_nxn(place); +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/shape_compute_test.cc b/lite/tests/kernels/shape_compute_test.cc index e0d766c166..23eab7c94f 100644 --- a/lite/tests/kernels/shape_compute_test.cc +++ b/lite/tests/kernels/shape_compute_test.cc @@ -35,7 +35,8 @@ class ShapeComputeTester : public arena::TestCase { CHECK(input); auto* out = scope->NewTensor(out_); CHECK(out); - out->Resize(DDim({input->dims().size()})); + int64_t sz = input->dims().size(); + out->Resize(DDim({sz})); auto* out_data = out->mutable_data(); for (int i = 0; i < input->dims().size(); ++i) { out_data[i] = input->dims()[i]; diff --git a/lite/tests/kernels/squeeze_compute_test.cc b/lite/tests/kernels/squeeze_compute_test.cc new file mode 100644 index 0000000000..36efe76978 --- /dev/null +++ b/lite/tests/kernels/squeeze_compute_test.cc @@ -0,0 +1,253 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class SqueezeComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string out_ = "Out"; + std::vector axes_; + DDim dims_; + + public: + SqueezeComputeTester(const Place& place, + const std::string& alias, + const std::vector& axes, + DDim dims) + : TestCase(place, alias), axes_(axes), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + const auto* input = scope->FindTensor(x_); + CHECK(input); + auto* out = scope->NewTensor(out_); + CHECK(out); + + DDim in_dims(dims_); + size_t num_squeeze_dims = axes_.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { + int current = axes_[idx] < 0 ? axes_[idx] + in_dims.size() : axes_[idx]; + // Check current index, the upper limit has been checked. + CHECK_GE(current, 0) + << "Invalid axis, the negative axis is out of range."; + + CHECK_EQ(in_dims[current], 1) << "Invalid axis index, the axis that " + "will be squeezed should be equal " + "to 1."; + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } + should_squeeze[current] = true; + } + } + + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + + out->Resize(DDim(output_shape)); + LOG(INFO) << "baseline out size: " << out->dims(); + auto* input_data = input->data(); + auto* out_data = out->mutable_data(); + memcpy(out_data, input_data, sizeof(float) * dims_.production()); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("squeeze"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("axes", axes_); + } + + void PrepareData() override { + std::vector in_data(dims_.production()); + for (int i = 0; i < dims_.production(); ++i) { + in_data[i] = i; + } + SetCommonTensor(x_, dims_, in_data.data()); + } +}; + +class Squeeze2ComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string x_ = "X"; + std::string out_ = "Out"; + std::string xshape_ = "XShape"; + std::vector axes_; + DDim dims_; + + public: + Squeeze2ComputeTester(const Place& place, + const std::string& alias, + const std::vector& axes, + DDim dims) + : TestCase(place, alias), axes_(axes), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + const auto* input = scope->FindTensor(x_); + CHECK(input); + auto* out = scope->NewTensor(out_); + CHECK(out); + auto* xshape = scope->NewTensor(xshape_); + CHECK(xshape); + std::vector xshape_sp(dims_.size() + 1, 1); + for (size_t i = 0; i < dims_.size(); ++i) { + xshape_sp[i + 1] = dims_[i]; + } + xshape->Resize(DDim(xshape_sp)); + + DDim in_dims(dims_); + size_t num_squeeze_dims = axes_.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (size_t idx = 0; idx < num_squeeze_dims; ++idx) { + int current = axes_[idx] < 0 ? axes_[idx] + in_dims.size() : axes_[idx]; + // Check current index, the upper limit has been checked. + CHECK_GE(current, 0) + << "Invalid axis, the negative axis is out of range."; + + CHECK_EQ(in_dims[current], 1) << "Invalid axis index, the axis that " + "will be squeezed should be equal " + "to 1."; + if (!(should_squeeze[current])) { + ++cnt_squeezed_dims; + } + should_squeeze[current] = true; + } + } + + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + + out->Resize(DDim(output_shape)); + + auto* input_data = input->data(); + auto* out_data = out->mutable_data(); + auto* xshape_data = xshape->mutable_data(); + memcpy(out_data, input_data, sizeof(float) * dims_.production()); + memcpy(xshape_data, input_data, sizeof(float) * dims_.production()); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("squeeze2"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetOutput("XShape", {xshape_}); + op_desc->SetAttr("axes", axes_); + } + + void PrepareData() override { + std::vector in_data(dims_.production()); + for (int i = 0; i < dims_.production(); ++i) { + in_data[i] = i; + } + SetCommonTensor(x_, dims_, in_data.data()); + } +}; + +void test_squeeze(Place place) { + for (std::vector axes : {std::vector({}), + std::vector({0, 2}), + std::vector({0, -2})}) { + for (int N : {1}) { + for (int C : {3}) { + for (int H : {1}) { + for (int W : {5}) { + std::unique_ptr tester(new SqueezeComputeTester( + place, "def", axes, DDim({N, C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } +} + +void test_squeeze2(Place place) { + for (std::vector axes : {std::vector({}), + std::vector({0, 2}), + std::vector({0, -2})}) { + for (int N : {1}) { + for (int C : {3}) { + for (int H : {1}) { + for (int W : {5}) { + std::unique_ptr tester(new Squeeze2ComputeTester( + place, "def", axes, DDim({N, C, H, W}))); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } + } +} + +TEST(squeeze, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_squeeze(place); +#endif +} + +TEST(squeeze2, precision) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_squeeze2(place); +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tools/benchmark.sh b/lite/tools/benchmark.sh new file mode 100644 index 0000000000..66b4025f91 --- /dev/null +++ b/lite/tools/benchmark.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -e + +if [ $# -lt 2 ]; +then + echo "Input error" + echo "USAGE:" + echo " sh benchmark.sh benchmark_bin_path test_models_dir" + echo " sh benchmark.sh benchmark_bin_path test_models_dir arm_bi" + exit +fi + +BENCHMARK_BIN=$1 +MODELS_DIR=$2 +ARM_BI=$3 +ANDROID_DIR=/data/local/tmp +RESULT_FILENAME="result.txt" +WARMUP=10 +REPEATS=30 + +adb push $BENCHMARK_BIN $ANDROID_DIR/benchmark_bin +adb shell chmod 777 $ANDROID_DIR/benchmark_bin +adb push $MODELS_DIR $ANDROID_DIR + +adb shell "echo PaddleLite Benchmark > $ANDROID_DIR/$RESULT_FILENAME" +for threads in 1 2 4 +do +adb shell "echo ABI=$ARM_BI Threads=$threads Warmup=$WARMUP Repeats=$REPEATS >> $ANDROID_DIR/$RESULT_FILENAME" +for model_name in `ls $MODELS_DIR` +do + echo $model_name + adb shell "$ANDROID_DIR/benchmark_bin --model_dir=$ANDROID_DIR/${MODELS_DIR##*/}/$model_name --warmup=$WARMUP --repeats=$REPEATS --threads=$threads --result_filename=$ANDROID_DIR/$RESULT_FILENAME" +done +adb shell "echo >> $ANDROID_DIR/$RESULT_FILENAME" +done +adb pull $ANDROID_DIR/$RESULT_FILENAME . diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index 2800a4db47..54ca3427a0 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -538,6 +538,10 @@ function test_arm { echo "prepare multiclass_nms_test files..." __prepare_multiclass_nms_test_files $port + # prepare for CXXApi test + local adb="adb -s emulator-${port}" + $adb shell mkdir -p /data/local/tmp/lite_naive_model_opt + echo "test file: ${TESTS_FILE}" for _test in $(cat $TESTS_FILE); do test_arm_android $_test $port diff --git a/lite/tools/search_support_ops.py b/lite/tools/search_support_ops.py new file mode 100644 index 0000000000..43c3c0704d --- /dev/null +++ b/lite/tools/search_support_ops.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +import os +import re + + +def merge_sort_two_list(la, lb): + la.extend(lb) + la = list(set(la)) + la.sort() + return la + + +ops_file = "../api/paddle_use_ops.h" +kernels_file = "../api/paddle_use_kernels.h" +result_file = "./support_ops_list.md" + +# search ops +if os.path.exists(ops_file): + pattern = re.compile("USE_LITE_OP[(](.*?)[)]") + ops = [] + for line in open(ops_file): + if line != None and line[0:2] != "//": + op = pattern.findall(line) + ops.extend(op) + ops.sort() + # print ops + # print len(ops) +else: + print "ops_file no exist in ", ops_file + +# search kernels +if os.path.exists(kernels_file): + kernel_types = [ + "kARM, kFloat", "kARM, kInt8", "kARM, kAny", "kX86, kFloat", + "kX86, kInt8", "kX86, kAny", "kOpenCL, kFloat", "kOpenCL, kInt8", + "kOpenCL, kAny" + ] + patterns = [] + for type in kernel_types: + pat_str = "USE_LITE_KERNEL[(](.*?), " + type + patterns.append(re.compile(pat_str)) + + kernels = [[] for i in range(len(kernel_types))] + for line in open(kernels_file): + if line != None and line[0:2] != "//": + for i in range(len(kernel_types)): + kl = patterns[i].findall(line) + kernels[i].extend(kl) +else: + print "kernels_file no exist in ", kernels_file + +# write out +if os.path.exists(result_file): + os.remove(result_file) +out = open(result_file, "w") +out.write("# PaddleLite support ops and kernels\n") +out.write("## ops\n") +for op in ops: + out.write("- " + op + "\n") + +out.write("## kernels\n") +for i in range(len(kernel_types) / 3): + for j in range(2): + out.write("### " + kernel_types[3 * i + j] + "\n") + for kl in merge_sort_two_list(kernels[3 * i + j], kernels[3 * i + 2]): + out.write("- " + kl + "\n") diff --git a/lite/utils/io.h b/lite/utils/io.h index 0946c39936..72f00bd1ca 100644 --- a/lite/utils/io.h +++ b/lite/utils/io.h @@ -35,8 +35,9 @@ static bool IsFileExists(const std::string& path) { // ARM mobile not support mkdir in C++ static void MkDirRecur(const std::string& path) { #ifndef LITE_WITH_ARM - CHECK_EQ(system(string_format("mkdir -p %s", path.c_str()).c_str()), 0) - << "Cann't mkdir " << path; + if(system(string_format("mkdir -p %s", path.c_str()).c_str()) != 0) { + LOG(ERROR) << "Cann't mkdir " << path; + } #else // On ARM CHECK_NE(mkdir(path.c_str(), S_IRWXU), -1) << "Cann't mkdir " << path; #endif -- GitLab