From 52f86cc39ab29fbf8682dbb1ce6fdff1aa8e45bb Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Sat, 28 Dec 2019 10:17:56 +0800 Subject: [PATCH] Upgrade of Model_optimize_tool (#2624) --- lite/api/CMakeLists.txt | 2 +- lite/api/cxx_api.cc | 6 +- lite/api/model_optimize_tool.cc | 220 +++++++++++++++++- lite/core/CMakeLists.txt | 10 +- lite/kernels/arm/CMakeLists.txt | 2 +- lite/kernels/cuda/CMakeLists.txt | 2 +- lite/kernels/fpga/CMakeLists.txt | 2 +- lite/kernels/opencl/CMakeLists.txt | 10 +- lite/operators/compare_op.cc | 2 +- .../create_fake_kernel_registry.py | 3 + .../cmake_tools/parse_kernel_registry.py | 4 + lite/tools/cmake_tools/parse_op_registry.py | 4 + .../cmake_tools/record_supported_kernel_op.py | 129 ++++++++++ 13 files changed, 380 insertions(+), 16 deletions(-) create mode 100644 lite/tools/cmake_tools/record_supported_kernel_op.py diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index de1a76c9c3..84f8a09860 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -277,7 +277,7 @@ if (LITE_ON_MODEL_OPTIMIZE_TOOL) message(STATUS "Compiling model_optimize_tool") lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc DEPS gflags kernel op optimizer mir_passes utils) - add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc) + add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h) endif(LITE_ON_MODEL_OPTIMIZE_TOOL) lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 990d08f18f..c1e9fc4224 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -201,7 +201,11 @@ void Predictor::Build(const lite_api::CxxConfig &config, const std::string &model_file = config.model_file(); const std::string ¶m_file = config.param_file(); const bool model_from_memory = config.model_from_memory(); - LOG(INFO) << "load from memory " << model_from_memory; + if (model_from_memory) { + LOG(INFO) << "Load model from memory."; + } else { + LOG(INFO) << "Load model from file."; + } Build(model_path, model_file, diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc index b678c7ecd2..fc23e0b54b 100644 --- a/lite/api/model_optimize_tool.cc +++ b/lite/api/model_optimize_tool.cc @@ -16,8 +16,9 @@ #ifdef PADDLE_WITH_TESTING #include #endif -// "all_kernel_faked.cc" and "kernel_src_map.h" are created automatically during -// model_optimize_tool's compiling period +// "supported_kernel_op_info.h", "all_kernel_faked.cc" and "kernel_src_map.h" +// are created automatically during model_optimize_tool's compiling period +#include #include "all_kernel_faked.cc" // NOLINT #include "kernel_src_map.h" // NOLINT #include "lite/api/cxx_api.h" @@ -25,8 +26,11 @@ #include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_passes.h" #include "lite/core/op_registry.h" +#include "lite/model_parser/compatible_pb.h" +#include "lite/model_parser/pb/program_desc.h" #include "lite/utils/cp_logging.h" #include "lite/utils/string.h" +#include "supported_kernel_op_info.h" // NOLINT DEFINE_string(model_dir, "", @@ -62,10 +66,16 @@ DEFINE_string(valid_targets, "The targets this model optimized for, should be one of (arm, " "opencl, x86), splitted by space"); DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); +DEFINE_bool(print_supported_ops, + false, + "Print supported operators on the inputed target"); +DEFINE_bool(print_all_ops, + false, + "Print all the valid operators of Paddle-Lite"); +DEFINE_bool(print_model_ops, false, "Print operators in the input model"); namespace paddle { namespace lite_api { - //! Display the kernel information. void DisplayKernels() { LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString(); @@ -130,9 +140,7 @@ void RunOptimize(const std::string& model_dir, config.set_model_dir(model_dir); config.set_model_file(model_file); config.set_param_file(param_file); - config.set_valid_places(valid_places); - auto predictor = lite_api::CreatePaddlePredictor(config); LiteModelType model_type; @@ -168,6 +176,202 @@ void CollectModelMetaInfo(const std::string& output_dir, lite::WriteLines(std::vector(total.begin(), total.end()), output_path); } +void PrintOpsInfo(std::set valid_ops = {}) { + std::vector targets = {"kHost", + "kX86", + "kCUDA", + "kARM", + "kOpenCL", + "kFPGA", + "kNPU", + "kXPU", + "kAny", + "kUnk"}; + int maximum_optype_length = 0; + for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) { + maximum_optype_length = it->first.size() > maximum_optype_length + ? it->first.size() + : maximum_optype_length; + } + std::cout << std::setiosflags(std::ios::internal); + std::cout << std::setw(maximum_optype_length) << "OP_name"; + for (int i = 0; i < targets.size(); i++) { + std::cout << std::setw(10) << targets[i].substr(1); + } + std::cout << std::endl; + if (valid_ops.empty()) { + for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) { + std::cout << std::setw(maximum_optype_length) << it->first; + auto ops_valid_places = it->second; + for (int i = 0; i < targets.size(); i++) { + if (std::find(ops_valid_places.begin(), + ops_valid_places.end(), + targets[i]) != ops_valid_places.end()) { + std::cout << std::setw(10) << "Y"; + } else { + std::cout << std::setw(10) << " "; + } + } + std::cout << std::endl; + } + } else { + for (auto op = valid_ops.begin(); op != valid_ops.end(); op++) { + std::cout << std::setw(maximum_optype_length) << *op; + // Check: If this kernel doesn't match any operator, we will skip it. + if (supported_ops.find(*op) == supported_ops.end()) { + continue; + } + // Print OP info. + auto ops_valid_places = supported_ops.at(*op); + for (int i = 0; i < targets.size(); i++) { + if (std::find(ops_valid_places.begin(), + ops_valid_places.end(), + targets[i]) != ops_valid_places.end()) { + std::cout << std::setw(10) << "Y"; + } else { + std::cout << std::setw(10) << " "; + } + } + std::cout << std::endl; + } + } +} +/// Print help information +void PrintHelpInfo() { + // at least one argument should be inputed + const char help_info[] = + "At least one argument should be inputed. Valid arguments are listed " + "below:\n" + " Arguments of model optimization:\n" + " `--model_dir=`\n" + " `--model_file=`\n" + " `--param_file=`\n" + " `--optimize_out_type=(protobuf|naive_buffer)`\n" + " `--optimize_out=`\n" + " `--valid_targets=(arm|opencl|x86|npu|xpu)`\n" + " `--prefer_int8_kernel=(true|false)`\n" + " `--record_tailoring_info=(true|false)`\n" + " Arguments of model checking and ops information:\n" + " `--print_all_ops=true` Display all the valid operators of " + "Paddle-Lite\n" + " `--print_supported_ops=true " + "--valid_targets=(arm|opencl|x86|npu|xpu)`" + " Display valid operators of input targets\n" + " `--print_model_ops=true --model_dir= " + "--valid_targets=(arm|opencl|x86|npu|xpu)`" + " Display operators in the input model\n"; + std::cout << help_info << std::endl; + exit(1); +} + +// Parse Input command +void ParseInputCommand() { + if (FLAGS_print_all_ops) { + std::cout << "All OPs supported by Paddle-Lite: " << supported_ops.size() + << " ops in total." << std::endl; + PrintOpsInfo(); + exit(1); + } else if (FLAGS_print_supported_ops) { + auto valid_places = paddle::lite_api::ParserValidPlaces(); + // get valid_targets string + std::vector target_types = {}; + for (int i = 0; i < valid_places.size(); i++) { + target_types.push_back(valid_places[i].target); + } + std::string targets_str = TargetToStr(target_types[0]); + for (int i = 1; i < target_types.size(); i++) { + targets_str = targets_str + TargetToStr(target_types[i]); + } + + std::cout << "Supported OPs on '" << targets_str << "': " << std::endl; + target_types.push_back(TARGET(kHost)); + target_types.push_back(TARGET(kUnk)); + + std::set valid_ops; + for (int i = 0; i < target_types.size(); i++) { + auto ops = supported_ops_target[static_cast(target_types[i])]; + valid_ops.insert(ops.begin(), ops.end()); + } + PrintOpsInfo(valid_ops); + exit(1); + } +} +// test whether this model is supported +void CheckIfModelSupported() { + // 1. parse valid places and valid targets + auto valid_places = paddle::lite_api::ParserValidPlaces(); + // set valid_ops + auto valid_ops = supported_ops_target[static_cast(TARGET(kHost))]; + auto valid_unktype_ops = supported_ops_target[static_cast(TARGET(kUnk))]; + valid_ops.insert( + valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end()); + for (int i = 0; i < valid_places.size(); i++) { + auto target = valid_places[i].target; + auto ops = supported_ops_target[static_cast(target)]; + valid_ops.insert(valid_ops.end(), ops.begin(), ops.end()); + } + // get valid ops + std::set valid_ops_set(valid_ops.begin(), valid_ops.end()); + + // 2.Load model into program to get ops in model + std::string prog_path = FLAGS_model_dir + "/__model__"; + if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) { + prog_path = FLAGS_model_file; + } + lite::cpp::ProgramDesc cpp_prog; + framework::proto::ProgramDesc pb_proto_prog = + *lite::LoadProgram(prog_path, false); + lite::pb::ProgramDesc pb_prog(&pb_proto_prog); + // Transform to cpp::ProgramDesc + lite::TransformProgramDescAnyToCpp(pb_prog, &cpp_prog); + + std::set unsupported_ops; + std::set input_model_ops; + for (int index = 0; index < cpp_prog.BlocksSize(); index++) { + auto current_block = cpp_prog.GetBlock(index); + for (size_t i = 0; i < current_block->OpsSize(); ++i) { + auto& op_desc = *current_block->GetOp(i); + auto op_type = op_desc.Type(); + input_model_ops.insert(op_type); + if (valid_ops_set.count(op_type) == 0) { + unsupported_ops.insert(op_type); + } + } + } + // 3. Print ops_info of input model and check if this model is supported + if (FLAGS_print_model_ops) { + std::cout << "OPs in the input model include:\n"; + PrintOpsInfo(input_model_ops); + } + if (!unsupported_ops.empty()) { + std::string unsupported_ops_str = *unsupported_ops.begin(); + for (auto op_str = ++unsupported_ops.begin(); + op_str != unsupported_ops.end(); + op_str++) { + unsupported_ops_str = unsupported_ops_str + ", " + *op_str; + } + std::vector targets = {}; + for (int i = 0; i < valid_places.size(); i++) { + targets.push_back(valid_places[i].target); + } + std::sort(targets.begin(), targets.end()); + targets.erase(unique(targets.begin(), targets.end()), targets.end()); + std::string targets_str = TargetToStr(targets[0]); + for (int i = 1; i < targets.size(); i++) { + targets_str = targets_str + "," + TargetToStr(targets[i]); + } + + LOG(ERROR) << "Error: This model is not supported, because " + << unsupported_ops.size() << " ops are not supported on '" + << targets_str << "'. These unsupported ops are: '" + << unsupported_ops_str << "'."; + exit(1); + } + if (FLAGS_print_model_ops) { + std::cout << "Paddle-Lite supports this model!" << std::endl; + exit(1); + } +} void Main() { if (FLAGS_display_kernels) { @@ -241,7 +445,13 @@ void Main() { } // namespace paddle int main(int argc, char** argv) { + // If there is none input argument, print help info. + if (argc < 2) { + paddle::lite_api::PrintHelpInfo(); + } google::ParseCommandLineFlags(&argc, &argv, false); + paddle::lite_api::ParseInputCommand(); + paddle::lite_api::CheckIfModelSupported(); paddle::lite_api::Main(); return 0; } diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt index 34d9deff6a..8fda0a12fd 100644 --- a/lite/core/CMakeLists.txt +++ b/lite/core/CMakeLists.txt @@ -95,7 +95,15 @@ add_custom_command( add_custom_target(op_list_h DEPENDS ops.h) add_custom_target(kernel_list_h DEPENDS kernels.h) add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc) - +# create headfile to restore ops info sorted by suppported platforms +add_custom_command( + COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/record_supported_kernel_op.py + ${kernels_src_list} + ${ops_src_list} + ${CMAKE_BINARY_DIR}/supported_kernel_op_info.h + OUTPUT supported_kernel_op_info.h # not a real path to the output to force it execute every time. + ) + add_custom_target(supported_kernel_op_info_h DEPENDS supported_kernel_op_info.h) #----------------------------------------------- NOT CHANGE ----------------------------------------------- lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index ce8b8365a8..74b86c519e 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -1,6 +1,6 @@ # NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered # to the model_optimize_tool. -if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) +if((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))) return() endif() diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index bf59d02726..2df00f00a4 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -1,4 +1,4 @@ -if(NOT LITE_WITH_CUDA) +if((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT LITE_WITH_CUDA)) return() endif() diff --git a/lite/kernels/fpga/CMakeLists.txt b/lite/kernels/fpga/CMakeLists.txt index 7c47e72872..f6c3a39949 100755 --- a/lite/kernels/fpga/CMakeLists.txt +++ b/lite/kernels/fpga/CMakeLists.txt @@ -1,4 +1,4 @@ -if (NOT LITE_WITH_FPGA) +if ((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT LITE_WITH_FPGA)) return() endif() diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 3423b1e920..f4d3254a7b 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -14,7 +14,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps}) add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps}) add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) -add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps}) +#add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps}) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) @@ -49,12 +49,14 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc - DEPS conv2d_1x1_opencl cl_image_converter op_registry program context - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +#lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc +# DEPS conv2d_1x1_opencl cl_image_converter op_registry program context +# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc DEPS reshape_opencl cl_image_converter op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/operators/compare_op.cc b/lite/operators/compare_op.cc index 3210520cd5..aa500ba35c 100644 --- a/lite/operators/compare_op.cc +++ b/lite/operators/compare_op.cc @@ -54,7 +54,7 @@ bool CompareOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { } // namespace paddle REGISTER_LITE_OP(equal, paddle::lite::operators::CompareOp); -REGISTER_LITE_OP(notequal, paddle::lite::operators::CompareOp); +REGISTER_LITE_OP(not_equal, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(less_than, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(less_equal, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(greater_than, paddle::lite::operators::CompareOp); diff --git a/lite/tools/cmake_tools/create_fake_kernel_registry.py b/lite/tools/cmake_tools/create_fake_kernel_registry.py index 140d773207..35012d5b16 100644 --- a/lite/tools/cmake_tools/create_fake_kernel_registry.py +++ b/lite/tools/cmake_tools/create_fake_kernel_registry.py @@ -18,6 +18,9 @@ import logging from ast import RegisterLiteKernelParser from utils import * +if len(sys.argv) != 4: + print("Error: create_fake_kernel_registry.py requires three inputs!") + exit(1) ops_list_path = sys.argv[1] dest_path = sys.argv[2] kernelmap_path = sys.argv[3] diff --git a/lite/tools/cmake_tools/parse_kernel_registry.py b/lite/tools/cmake_tools/parse_kernel_registry.py index f4f0b95483..6c020ec438 100644 --- a/lite/tools/cmake_tools/parse_kernel_registry.py +++ b/lite/tools/cmake_tools/parse_kernel_registry.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import sys import logging from ast import RegisterLiteKernelParser +if len(sys.argv) != 5: + print("Error: parse_kernel_registry.py requires four inputs!") + exit(1) ops_list_path = sys.argv[1] dest_path = sys.argv[2] minkernels_list_path = sys.argv[3] diff --git a/lite/tools/cmake_tools/parse_op_registry.py b/lite/tools/cmake_tools/parse_op_registry.py index db58c455a9..7eb3337ed8 100644 --- a/lite/tools/cmake_tools/parse_op_registry.py +++ b/lite/tools/cmake_tools/parse_op_registry.py @@ -13,10 +13,14 @@ # limitations under the License. ''' Collect op registry information. ''' +from __future__ import print_function import sys import logging from ast import RegisterLiteOpParser +if len(sys.argv) != 5: + print("Error: parse_op_registry.py requires four inputs!") + exit(1) ops_list_path = sys.argv[1] dest_path = sys.argv[2] minops_list_path = sys.argv[3] diff --git a/lite/tools/cmake_tools/record_supported_kernel_op.py b/lite/tools/cmake_tools/record_supported_kernel_op.py new file mode 100644 index 0000000000..f6a3af6bd3 --- /dev/null +++ b/lite/tools/cmake_tools/record_supported_kernel_op.py @@ -0,0 +1,129 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +import logging +from ast import RegisterLiteKernelParser +from ast import RegisterLiteOpParser + +if len(sys.argv) != 4: + print("Error: record_supported_kernel_op.py requires three inputs!") + exit(1) +kernels_list_path = sys.argv[1] +ops_list_path = sys.argv[2] +kernel_op_map_dest_path = sys.argv[3] + + +out_lines = [ +''' +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +const std::vector> supported_ops_target = { +''' +] + +ops_lines=[] + +# valid targets and valid_ops +valid_targets = ["kUnk", "kHost", "kX86", "kCUDA", "kARM", "kOpenCL", "kAny", "kFPGA", "kNPU", "kXPU"] +valid_ops = [[],[],[],[],[],[],[],[],[],[]] +class TargetType: + kUnk = 0 + kHost = 1 + kX86 = 2 + kCUDA = 3 + kARM = 4 + kOpenCL = 5 + kFPGA = 7 + kNPU = 8 + kXPU = 9 + kAny = 6 # any target + +# record op_info of valid kernels into `valid_ops` according to different target type +with open(kernels_list_path) as f: + paths = set([path for path in f]) + for path in paths: + with open(path.strip()) as g: + c = g.read() + kernel_parser = RegisterLiteKernelParser(c) + kernel_parser.parse() + for k in kernel_parser.kernels: + if hasattr(TargetType, k.target): + index=getattr(TargetType, k.target) + valid_ops[index].append(k.op_type) + +# clear the repeated ops +for target in valid_targets: + index = getattr(TargetType, target) + valid_ops[index] = list(set(valid_ops[index])) + +paths = set() +with open(ops_list_path) as f: + paths = set([path for path in f]) + for path in paths: + str_info = open(path.strip()).read() + op_parser = RegisterLiteOpParser(str_info) + ops = op_parser.parse() + for op in ops: + if "_grad" in op: + continue + out = ' {"%s", { "' % op + op_targets = [] + for target in valid_targets: + if op in valid_ops[getattr(TargetType, target)]: + op_targets.append(target) + if len(op_targets) > 0: + out = out +'", "'.join(op_targets)+ '" }}' + else: + # unknow type op: kUnk = 0 + valid_ops[0].append(op) + out = out +'kUnk" }}' + ops_lines.append(out) + +with open(kernel_op_map_dest_path, 'w') as f: + logging.info("write kernel list to %s" % kernel_op_map_dest_path) + f.write('\n'.join(out_lines)) + # write kernels into head file + for target in valid_targets: + if len(valid_ops[getattr(TargetType, target)]) == 0 : + f.write("\n // %s_OPS: " %target) + f.write('\n {},') + else: + f.write("\n // %s_OPS: " %target) + f.write('\n {"') + f.write('","'.join(valid_ops[getattr(TargetType, target)])) + f.write('"},\n') + f.write('};') + # write op info into head file + f.write('\nconst std::map> supported_ops={\n') + f.write(',\n'.join(ops_lines)) + f.write('\n};') -- GitLab