提交 9caa75ae 编写于 作者: C Chunwei

enhance gen code

上级 c8bb0af7
......@@ -104,7 +104,7 @@ file(WRITE ${offline_lib_registry_file} "") # clean
# LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
# HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
function(lite_cc_library TARGET)
set(options "")
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS
HVY_DEPS ARGS)
......@@ -120,8 +120,11 @@ function(lite_cc_library TARGET)
LIGHT_DEPS ${args_LIGHT_DEPS}
HVY_DEPS ${args_HVY_DEPS}
)
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
if (${args_SHARED} OR ${args_shared})
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS} SHARED)
else()
cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
endif()
# collect targets need to compile for lite
add_dependencies(lite_compile_deps ${TARGET})
......
......@@ -100,14 +100,12 @@ lite_cc_test(test_apis_lite SRCS apis_test.cc
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
lite_cc_library(cxx_api_impl_lite SRCS cxx_api_impl.cc DEPS cxx_api_lite)
lite_cc_library(light_api_impl_lite SRCS light_api_impl.cc DEPS light_api_lite)
lite_cc_library(paddle_api_lite SRCS paddle_api.cc DEPS op_params_lite)
lite_cc_library(paddle_api_full SRCS paddle_api.cc DEPS cxx_api_impl_lite light_api_impl_lite)
lite_cc_library(paddle_api_light SRCS paddle_api.cc DEPS light_api_impl_lite)
lite_cc_library(paddle_api_full SRCS cxx_api_impl.cc DEPS cxx_api_lite paddle_api_lite light_api_lite)
lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api_lite paddle_api_lite)
lite_cc_test(test_paddle_api_lite SRCS paddle_api_test.cc DEPS cxx_api_lite light_api_lite paddle_api_full
lite_cc_test(test_paddle_api_lite SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light
${ops_lite}
ARM_DEPS ${arm_kernels}
X86_DEPS ${x86_kernels}
......@@ -116,7 +114,9 @@ if (WITH_TESTING)
add_dependencies(test_paddle_api_lite extern_lite_download_lite_naive_model_tar_gz)
endif()
#lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
#X86_DEPS operator
#DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
#ARM_DEPS ${arm_kernels})
lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc DEPS paddle_api_full)
# lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
# X86_DEPS operator
# DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
# ARM_DEPS ${arm_kernels})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/api/paddle_use_kernels.h"
#include "paddle/fluid/lite/api/paddle_use_ops.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/utils/string.h"
DEFINE_string(model_dir, "", "path of the model");
DEFINE_string(optimize_out, "", "path of the output optimized model");
DEFINE_string(valid_targets, "ARM",
"The targets this model optimized for, should be one of (arm, "
"opencl, x86), splitted by space");
DEFINE_bool(int8_mode, false, "Support Int8 quantitative mode");
namespace paddle {
namespace lite_api {
void Main() {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
std::vector<Place> valid_places;
auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
for (auto& target_repr : target_reprs) {
if (target_repr == "arm") {
valid_places.emplace_back(TARGET(kARM));
} else if (target_repr == "opencl") {
valid_places.emplace_back(TARGET(kOpenCL));
} else if (target_repr == "x86") {
valid_places.emplace_back(TARGET(kX86));
} else {
LOG(FATAL) << lite::string_format(
"Wrong target '%s' found, please check the command flag "
"'valid_targets'",
target_repr.c_str());
}
}
CHECK(!valid_places.empty())
<< "At least one target should be set, should set the "
"command argument 'valid_targets'";
if (FLAGS_int8_mode) {
LOG(WARNING) << "Int8 mode is only support by ARM target";
valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)});
config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)});
}
config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(config);
predictor->SaveOptimizedModel(FLAGS_optimize_out);
}
} // namespace lite_api
} // namespace paddle
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
paddle::lite_api::Main();
return 0;
}
......@@ -56,6 +56,7 @@ TEST(CxxApi, run) {
predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2");
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(LightApi, run) {
lite_api::MobileConfig config;
config.set_model_dir(FLAGS_model_dir + ".opt2");
......@@ -79,6 +80,7 @@ TEST(LightApi, run) {
EXPECT_NEAR(out[0], 50.2132, 1e-3);
EXPECT_NEAR(out[1], -28.8729, 1e-3);
}
#endif
} // namespace lite_api
} // namespace paddle
......@@ -83,7 +83,7 @@ struct Place {
int16_t device{0}; // device ID
Place() = default;
Place(TargetType target, PrecisionType precision,
Place(TargetType target, PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW), int16_t device = 0)
: target(target), precision(precision), layout(layout), device(device) {}
......
......@@ -31,7 +31,7 @@ cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
lite_cc_library(program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite ${ops_lite}
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
......
......@@ -26,3 +26,5 @@ if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
add_dependencies(__generated_code__ test_gen_code_lite)
add_dependencies(__generated_code__ extern_lite_download_lite_naive_model_tar_gz)
endif()
lite_cc_binary(paddle_code_generator SRCS paddle_code_generator.cc DEPS model_parser_lite gen_code_lite)
......@@ -111,6 +111,15 @@ void Module::AddOpDescHelper(const std::string &op_id,
return std::to_string(desc.GetAttr<bool>(name));
case AttrType::STRING:
return "\"" + desc.GetAttr<std::string>(name) + "\"";
case AttrType::FLOATS: {
auto vals = desc.GetAttr<std::vector<float>>(name);
return "{" + Join(vals, ",") + "}";
}
case AttrType::INTS: {
auto vals = desc.GetAttr<std::vector<int>>(name);
return "{" + Join(vals, ",") + "}";
}
case AttrType::STRINGS: {
std::vector<std::string> tmp;
auto vals = desc.GetAttr<std::vector<std::string>>(name);
......@@ -137,8 +146,12 @@ void Module::AddOpDescHelper(const std::string &op_id,
return "bool";
case AttrType::STRING:
return "std::string";
case AttrType::FLOATS:
return "std::vector<float>";
case AttrType::STRINGS:
return "std::vector<std::string>";
case AttrType::INTS:
return "std::vector<int>";
default:
LOG(FATAL) << "Unsupported attribute type: " << static_cast<int>(type);
}
......@@ -160,6 +173,8 @@ void Module::AddOp(const cpp::OpDesc &op) {
auto op_name = OpUniqueName();
AddOpDescHelper(op_name, op);
LOG(INFO) << "add op " << op_name;
Line(string_format("// Create Op: %s", op.Type().c_str()));
Line(string_format("auto %s = lite::LiteOpRegistry::Global().Create(\"%s\");",
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include "paddle/fluid/lite/gen_code/gen_code.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
DEFINE_string(optimized_model, "", "");
DEFINE_string(generated_code_file, "__generated_code__.cc", "");
namespace paddle {
namespace lite {
namespace gencode {
void GenCode(const std::string& model_dir, const std::string& out_file) {
lite::Scope scope;
framework::proto::ProgramDesc desc;
LoadModel(model_dir, &scope, &desc);
ProgramCodeGenerator codegen(desc, scope);
std::ofstream file(out_file);
file << codegen.GenCode();
file.close();
}
} // namespace gencode
} // namespace lite
} // namespace paddle
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, false);
paddle::lite::gencode::GenCode(FLAGS_optimized_model,
FLAGS_generated_code_file);
return 0;
}
......@@ -4,6 +4,7 @@ set -ex
TESTS_FILE="./lite_tests.txt"
LIBS_FILE="./lite_libs.txt"
readonly ADB_WORK_DIR="/data/local/tmp"
readonly common_flags="-DWITH_LITE=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_PYTHON=OFF -DWITH_TESTING=ON -DLITE_WITH_ARM=OFF"
NUM_CORES_FOR_COMPILE=8
......@@ -183,7 +184,36 @@ function test_arm_model {
adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}"
local adb_model_path="${adb_work_dir}/`basename ${model_dir}`"
adb -s emulator-${port} shell "${adb_work_dir}/${test_name} --model_dir=$adb_model_path"
}
function _test_model_optimize_tool {
local port=$1
local remote_model_path=$ADB_WORK_DIR/lite_naive_model
local remote_test=$ADB_WORK_DIR/model_optimize_tool
local adb="adb -s emulator-${port}"
make model_optimize_tool -j$NUM_CORES_FOR_COMPILE
local test_path=$(find . -name model_optimize_tool)
local model_path=$(find . -name lite_naive_model)
$adb push ${test_path} ${ADB_WORK_DIR}
$adb shell mkdir -p $remote_model_path
$adb push $model_path/* $remote_model_path
$adb shell $remote_test --model_dir $remote_model_path --optimize_out ${remote_model_path}.opt \
--valid_targets "arm"
}
function _test_paddle_code_generator {
local port=$1
local test_name=paddle_code_generator
local remote_test=$ADB_WORK_DIR/$test_name
local remote_model=$ADB_WORK_DIR/lite_naive_model.opt
local adb="adb -s emulator-${port}"
make paddle_code_generator -j$NUM_CORES_FOR_COMPILE
local test_path=$(find . -name $test_name)
$adb push $test_path $remote_test
$adb shell $remote_test --optimized_model $remote_model --generated_code_file $ADB_WORK_DIR/gen_code.cc
}
function cmake_arm {
......@@ -273,6 +303,9 @@ function test_arm {
# test finally
test_arm_api $port
_test_model_optimize_tool $port
_test_paddle_code_generator $port
}
function prepare_emulator {
......
......@@ -52,8 +52,8 @@ static std::string to_string_with_precision(const T& v, const int n = 6) {
return ss.str();
}
static std::string Join(const std::vector<std::string>& vec,
const std::string& delim) {
template <typename T>
std::string Join(const std::vector<T>& vec, const std::string& delim) {
if (vec.empty()) return "";
std::stringstream ss;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册