未验证 提交 4300ef75 编写于 作者: H huzhiqiang 提交者: GitHub

Upgrade of Model_optimize_tool (#2624)

上级 86762e1e
...@@ -277,7 +277,7 @@ if (LITE_ON_MODEL_OPTIMIZE_TOOL) ...@@ -277,7 +277,7 @@ if (LITE_ON_MODEL_OPTIMIZE_TOOL)
message(STATUS "Compiling model_optimize_tool") message(STATUS "Compiling model_optimize_tool")
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc
DEPS gflags kernel op optimizer mir_passes utils) DEPS gflags kernel op optimizer mir_passes utils)
add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc) add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h)
endif(LITE_ON_MODEL_OPTIMIZE_TOOL) endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light
......
...@@ -201,7 +201,11 @@ void Predictor::Build(const lite_api::CxxConfig &config, ...@@ -201,7 +201,11 @@ void Predictor::Build(const lite_api::CxxConfig &config,
const std::string &model_file = config.model_file(); const std::string &model_file = config.model_file();
const std::string &param_file = config.param_file(); const std::string &param_file = config.param_file();
const bool model_from_memory = config.model_from_memory(); const bool model_from_memory = config.model_from_memory();
LOG(INFO) << "load from memory " << model_from_memory; if (model_from_memory) {
LOG(INFO) << "Load model from memory.";
} else {
LOG(INFO) << "Load model from file.";
}
Build(model_path, Build(model_path,
model_file, model_file,
......
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
#include <gtest/gtest.h> #include <gtest/gtest.h>
#endif #endif
// "all_kernel_faked.cc" and "kernel_src_map.h" are created automatically during // "supported_kernel_op_info.h", "all_kernel_faked.cc" and "kernel_src_map.h"
// model_optimize_tool's compiling period // are created automatically during model_optimize_tool's compiling period
#include <iomanip>
#include "all_kernel_faked.cc" // NOLINT #include "all_kernel_faked.cc" // NOLINT
#include "kernel_src_map.h" // NOLINT #include "kernel_src_map.h" // NOLINT
#include "lite/api/cxx_api.h" #include "lite/api/cxx_api.h"
...@@ -25,8 +26,11 @@ ...@@ -25,8 +26,11 @@
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h" #include "lite/api/paddle_use_passes.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/model_parser/compatible_pb.h"
#include "lite/model_parser/pb/program_desc.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/string.h" #include "lite/utils/string.h"
#include "supported_kernel_op_info.h" // NOLINT
DEFINE_string(model_dir, DEFINE_string(model_dir,
"", "",
...@@ -62,10 +66,16 @@ DEFINE_string(valid_targets, ...@@ -62,10 +66,16 @@ DEFINE_string(valid_targets,
"The targets this model optimized for, should be one of (arm, " "The targets this model optimized for, should be one of (arm, "
"opencl, x86), splitted by space"); "opencl, x86), splitted by space");
DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
DEFINE_bool(print_supported_ops,
false,
"Print supported operators on the inputed target");
DEFINE_bool(print_all_ops,
false,
"Print all the valid operators of Paddle-Lite");
DEFINE_bool(print_model_ops, false, "Print operators in the input model");
namespace paddle { namespace paddle {
namespace lite_api { namespace lite_api {
//! Display the kernel information. //! Display the kernel information.
void DisplayKernels() { void DisplayKernels() {
LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString(); LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString();
...@@ -130,9 +140,7 @@ void RunOptimize(const std::string& model_dir, ...@@ -130,9 +140,7 @@ void RunOptimize(const std::string& model_dir,
config.set_model_dir(model_dir); config.set_model_dir(model_dir);
config.set_model_file(model_file); config.set_model_file(model_file);
config.set_param_file(param_file); config.set_param_file(param_file);
config.set_valid_places(valid_places); config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(config); auto predictor = lite_api::CreatePaddlePredictor(config);
LiteModelType model_type; LiteModelType model_type;
...@@ -168,6 +176,202 @@ void CollectModelMetaInfo(const std::string& output_dir, ...@@ -168,6 +176,202 @@ void CollectModelMetaInfo(const std::string& output_dir,
lite::WriteLines(std::vector<std::string>(total.begin(), total.end()), lite::WriteLines(std::vector<std::string>(total.begin(), total.end()),
output_path); output_path);
} }
void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
std::vector<std::string> targets = {"kHost",
"kX86",
"kCUDA",
"kARM",
"kOpenCL",
"kFPGA",
"kNPU",
"kXPU",
"kAny",
"kUnk"};
int maximum_optype_length = 0;
for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) {
maximum_optype_length = it->first.size() > maximum_optype_length
? it->first.size()
: maximum_optype_length;
}
std::cout << std::setiosflags(std::ios::internal);
std::cout << std::setw(maximum_optype_length) << "OP_name";
for (int i = 0; i < targets.size(); i++) {
std::cout << std::setw(10) << targets[i].substr(1);
}
std::cout << std::endl;
if (valid_ops.empty()) {
for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) {
std::cout << std::setw(maximum_optype_length) << it->first;
auto ops_valid_places = it->second;
for (int i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(),
ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) {
std::cout << std::setw(10) << "Y";
} else {
std::cout << std::setw(10) << " ";
}
}
std::cout << std::endl;
}
} else {
for (auto op = valid_ops.begin(); op != valid_ops.end(); op++) {
std::cout << std::setw(maximum_optype_length) << *op;
// Check: If this kernel doesn't match any operator, we will skip it.
if (supported_ops.find(*op) == supported_ops.end()) {
continue;
}
// Print OP info.
auto ops_valid_places = supported_ops.at(*op);
for (int i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(),
ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) {
std::cout << std::setw(10) << "Y";
} else {
std::cout << std::setw(10) << " ";
}
}
std::cout << std::endl;
}
}
}
/// Print help information
void PrintHelpInfo() {
// at least one argument should be inputed
const char help_info[] =
"At least one argument should be inputed. Valid arguments are listed "
"below:\n"
" Arguments of model optimization:\n"
" `--model_dir=<model_param_dir>`\n"
" `--model_file=<model_path>`\n"
" `--param_file=<param_path>`\n"
" `--optimize_out_type=(protobuf|naive_buffer)`\n"
" `--optimize_out=<output_optimize_model_dir>`\n"
" `--valid_targets=(arm|opencl|x86|npu|xpu)`\n"
" `--prefer_int8_kernel=(true|false)`\n"
" `--record_tailoring_info=(true|false)`\n"
" Arguments of model checking and ops information:\n"
" `--print_all_ops=true` Display all the valid operators of "
"Paddle-Lite\n"
" `--print_supported_ops=true "
"--valid_targets=(arm|opencl|x86|npu|xpu)`"
" Display valid operators of input targets\n"
" `--print_model_ops=true --model_dir=<model_param_dir> "
"--valid_targets=(arm|opencl|x86|npu|xpu)`"
" Display operators in the input model\n";
std::cout << help_info << std::endl;
exit(1);
}
// Parse Input command
void ParseInputCommand() {
if (FLAGS_print_all_ops) {
std::cout << "All OPs supported by Paddle-Lite: " << supported_ops.size()
<< " ops in total." << std::endl;
PrintOpsInfo();
exit(1);
} else if (FLAGS_print_supported_ops) {
auto valid_places = paddle::lite_api::ParserValidPlaces();
// get valid_targets string
std::vector<TargetType> target_types = {};
for (int i = 0; i < valid_places.size(); i++) {
target_types.push_back(valid_places[i].target);
}
std::string targets_str = TargetToStr(target_types[0]);
for (int i = 1; i < target_types.size(); i++) {
targets_str = targets_str + TargetToStr(target_types[i]);
}
std::cout << "Supported OPs on '" << targets_str << "': " << std::endl;
target_types.push_back(TARGET(kHost));
target_types.push_back(TARGET(kUnk));
std::set<std::string> valid_ops;
for (int i = 0; i < target_types.size(); i++) {
auto ops = supported_ops_target[static_cast<int>(target_types[i])];
valid_ops.insert(ops.begin(), ops.end());
}
PrintOpsInfo(valid_ops);
exit(1);
}
}
// test whether this model is supported
void CheckIfModelSupported() {
// 1. parse valid places and valid targets
auto valid_places = paddle::lite_api::ParserValidPlaces();
// set valid_ops
auto valid_ops = supported_ops_target[static_cast<int>(TARGET(kHost))];
auto valid_unktype_ops = supported_ops_target[static_cast<int>(TARGET(kUnk))];
valid_ops.insert(
valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end());
for (int i = 0; i < valid_places.size(); i++) {
auto target = valid_places[i].target;
auto ops = supported_ops_target[static_cast<int>(target)];
valid_ops.insert(valid_ops.end(), ops.begin(), ops.end());
}
// get valid ops
std::set<std::string> valid_ops_set(valid_ops.begin(), valid_ops.end());
// 2.Load model into program to get ops in model
std::string prog_path = FLAGS_model_dir + "/__model__";
if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
prog_path = FLAGS_model_file;
}
lite::cpp::ProgramDesc cpp_prog;
framework::proto::ProgramDesc pb_proto_prog =
*lite::LoadProgram(prog_path, false);
lite::pb::ProgramDesc pb_prog(&pb_proto_prog);
// Transform to cpp::ProgramDesc
lite::TransformProgramDescAnyToCpp(pb_prog, &cpp_prog);
std::set<std::string> unsupported_ops;
std::set<std::string> input_model_ops;
for (int index = 0; index < cpp_prog.BlocksSize(); index++) {
auto current_block = cpp_prog.GetBlock<lite::cpp::BlockDesc>(index);
for (size_t i = 0; i < current_block->OpsSize(); ++i) {
auto& op_desc = *current_block->GetOp<lite::cpp::OpDesc>(i);
auto op_type = op_desc.Type();
input_model_ops.insert(op_type);
if (valid_ops_set.count(op_type) == 0) {
unsupported_ops.insert(op_type);
}
}
}
// 3. Print ops_info of input model and check if this model is supported
if (FLAGS_print_model_ops) {
std::cout << "OPs in the input model include:\n";
PrintOpsInfo(input_model_ops);
}
if (!unsupported_ops.empty()) {
std::string unsupported_ops_str = *unsupported_ops.begin();
for (auto op_str = ++unsupported_ops.begin();
op_str != unsupported_ops.end();
op_str++) {
unsupported_ops_str = unsupported_ops_str + ", " + *op_str;
}
std::vector<TargetType> targets = {};
for (int i = 0; i < valid_places.size(); i++) {
targets.push_back(valid_places[i].target);
}
std::sort(targets.begin(), targets.end());
targets.erase(unique(targets.begin(), targets.end()), targets.end());
std::string targets_str = TargetToStr(targets[0]);
for (int i = 1; i < targets.size(); i++) {
targets_str = targets_str + "," + TargetToStr(targets[i]);
}
LOG(ERROR) << "Error: This model is not supported, because "
<< unsupported_ops.size() << " ops are not supported on '"
<< targets_str << "'. These unsupported ops are: '"
<< unsupported_ops_str << "'.";
exit(1);
}
if (FLAGS_print_model_ops) {
std::cout << "Paddle-Lite supports this model!" << std::endl;
exit(1);
}
}
void Main() { void Main() {
if (FLAGS_display_kernels) { if (FLAGS_display_kernels) {
...@@ -241,7 +445,13 @@ void Main() { ...@@ -241,7 +445,13 @@ void Main() {
} // namespace paddle } // namespace paddle
int main(int argc, char** argv) { int main(int argc, char** argv) {
// If there is none input argument, print help info.
if (argc < 2) {
paddle::lite_api::PrintHelpInfo();
}
google::ParseCommandLineFlags(&argc, &argv, false); google::ParseCommandLineFlags(&argc, &argv, false);
paddle::lite_api::ParseInputCommand();
paddle::lite_api::CheckIfModelSupported();
paddle::lite_api::Main(); paddle::lite_api::Main();
return 0; return 0;
} }
...@@ -95,7 +95,15 @@ add_custom_command( ...@@ -95,7 +95,15 @@ add_custom_command(
add_custom_target(op_list_h DEPENDS ops.h) add_custom_target(op_list_h DEPENDS ops.h)
add_custom_target(kernel_list_h DEPENDS kernels.h) add_custom_target(kernel_list_h DEPENDS kernels.h)
add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc) add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc)
# create headfile to restore ops info sorted by suppported platforms
add_custom_command(
COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/record_supported_kernel_op.py
${kernels_src_list}
${ops_src_list}
${CMAKE_BINARY_DIR}/supported_kernel_op_info.h
OUTPUT supported_kernel_op_info.h # not a real path to the output to force it execute every time.
)
add_custom_target(supported_kernel_op_info_h DEPENDS supported_kernel_op_info.h)
#----------------------------------------------- NOT CHANGE ----------------------------------------------- #----------------------------------------------- NOT CHANGE -----------------------------------------------
lite_cc_library(kernel SRCS kernel.cc lite_cc_library(kernel SRCS kernel.cc
DEPS context type_system target_wrapper any op_params tensor DEPS context type_system target_wrapper any op_params tensor
......
# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered # NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered
# to the model_optimize_tool. # to the model_optimize_tool.
if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) if((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)))
return() return()
endif() endif()
......
if(NOT LITE_WITH_CUDA) if((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT LITE_WITH_CUDA))
return() return()
endif() endif()
......
if (NOT LITE_WITH_FPGA) if ((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT LITE_WITH_FPGA))
return() return()
endif() endif()
......
...@@ -14,7 +14,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps}) ...@@ -14,7 +14,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps})
add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps}) add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps})
add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps})
add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps}) #add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps})
add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps})
add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
...@@ -49,12 +49,14 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc ...@@ -49,12 +49,14 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc #lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc
DEPS conv2d_1x1_opencl cl_image_converter op_registry program context # DEPS conv2d_1x1_opencl cl_image_converter op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) # ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc
DEPS reshape_opencl cl_image_converter op_registry program context DEPS reshape_opencl cl_image_converter op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc
DEPS conv_opencl op_registry program context DEPS conv_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
......
...@@ -54,7 +54,7 @@ bool CompareOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -54,7 +54,7 @@ bool CompareOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(equal, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(equal, paddle::lite::operators::CompareOp);
REGISTER_LITE_OP(notequal, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(not_equal, paddle::lite::operators::CompareOp);
REGISTER_LITE_OP(less_than, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(less_than, paddle::lite::operators::CompareOp);
REGISTER_LITE_OP(less_equal, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(less_equal, paddle::lite::operators::CompareOp);
REGISTER_LITE_OP(greater_than, paddle::lite::operators::CompareOp); REGISTER_LITE_OP(greater_than, paddle::lite::operators::CompareOp);
......
...@@ -18,6 +18,9 @@ import logging ...@@ -18,6 +18,9 @@ import logging
from ast import RegisterLiteKernelParser from ast import RegisterLiteKernelParser
from utils import * from utils import *
if len(sys.argv) != 4:
print("Error: create_fake_kernel_registry.py requires three inputs!")
exit(1)
ops_list_path = sys.argv[1] ops_list_path = sys.argv[1]
dest_path = sys.argv[2] dest_path = sys.argv[2]
kernelmap_path = sys.argv[3] kernelmap_path = sys.argv[3]
......
...@@ -12,10 +12,14 @@ ...@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import sys import sys
import logging import logging
from ast import RegisterLiteKernelParser from ast import RegisterLiteKernelParser
if len(sys.argv) != 5:
print("Error: parse_kernel_registry.py requires four inputs!")
exit(1)
ops_list_path = sys.argv[1] ops_list_path = sys.argv[1]
dest_path = sys.argv[2] dest_path = sys.argv[2]
minkernels_list_path = sys.argv[3] minkernels_list_path = sys.argv[3]
......
...@@ -13,10 +13,14 @@ ...@@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
''' Collect op registry information. ''' ''' Collect op registry information. '''
from __future__ import print_function
import sys import sys
import logging import logging
from ast import RegisterLiteOpParser from ast import RegisterLiteOpParser
if len(sys.argv) != 5:
print("Error: parse_op_registry.py requires four inputs!")
exit(1)
ops_list_path = sys.argv[1] ops_list_path = sys.argv[1]
dest_path = sys.argv[2] dest_path = sys.argv[2]
minops_list_path = sys.argv[3] minops_list_path = sys.argv[3]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import logging
from ast import RegisterLiteKernelParser
from ast import RegisterLiteOpParser
if len(sys.argv) != 4:
print("Error: record_supported_kernel_op.py requires three inputs!")
exit(1)
kernels_list_path = sys.argv[1]
ops_list_path = sys.argv[2]
kernel_op_map_dest_path = sys.argv[3]
out_lines = [
'''
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include<vector>
#include<map>
#include<string>
const std::vector<std::vector<std::string>> supported_ops_target = {
'''
]
ops_lines=[]
# valid targets and valid_ops
valid_targets = ["kUnk", "kHost", "kX86", "kCUDA", "kARM", "kOpenCL", "kAny", "kFPGA", "kNPU", "kXPU"]
valid_ops = [[],[],[],[],[],[],[],[],[],[]]
class TargetType:
kUnk = 0
kHost = 1
kX86 = 2
kCUDA = 3
kARM = 4
kOpenCL = 5
kFPGA = 7
kNPU = 8
kXPU = 9
kAny = 6 # any target
# record op_info of valid kernels into `valid_ops` according to different target type
with open(kernels_list_path) as f:
paths = set([path for path in f])
for path in paths:
with open(path.strip()) as g:
c = g.read()
kernel_parser = RegisterLiteKernelParser(c)
kernel_parser.parse()
for k in kernel_parser.kernels:
if hasattr(TargetType, k.target):
index=getattr(TargetType, k.target)
valid_ops[index].append(k.op_type)
# clear the repeated ops
for target in valid_targets:
index = getattr(TargetType, target)
valid_ops[index] = list(set(valid_ops[index]))
paths = set()
with open(ops_list_path) as f:
paths = set([path for path in f])
for path in paths:
str_info = open(path.strip()).read()
op_parser = RegisterLiteOpParser(str_info)
ops = op_parser.parse()
for op in ops:
if "_grad" in op:
continue
out = ' {"%s", { "' % op
op_targets = []
for target in valid_targets:
if op in valid_ops[getattr(TargetType, target)]:
op_targets.append(target)
if len(op_targets) > 0:
out = out +'", "'.join(op_targets)+ '" }}'
else:
# unknow type op: kUnk = 0
valid_ops[0].append(op)
out = out +'kUnk" }}'
ops_lines.append(out)
with open(kernel_op_map_dest_path, 'w') as f:
logging.info("write kernel list to %s" % kernel_op_map_dest_path)
f.write('\n'.join(out_lines))
# write kernels into head file
for target in valid_targets:
if len(valid_ops[getattr(TargetType, target)]) == 0 :
f.write("\n // %s_OPS: " %target)
f.write('\n {},')
else:
f.write("\n // %s_OPS: " %target)
f.write('\n {"')
f.write('","'.join(valid_ops[getattr(TargetType, target)]))
f.write('"},\n')
f.write('};')
# write op info into head file
f.write('\nconst std::map<std::string, std::vector<std::string>> supported_ops={\n')
f.write(',\n'.join(ops_lines))
f.write('\n};')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册