Commit 29f9aade authored by S Superjomn

add light api for mobile

Parent e7f32773
set(cxx_api_lite_deps scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host kernels_cuda optimizer_lite model_parser_lite)
if(LITE_WITH_CUDA)
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda optimizer_lite)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
else()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} target_wrapper_cuda)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
endif()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps})
set(light_api_deps
scope_lite host_kernels ops_lite target_wrapper_host model_parser_lite)
if(LITE_WITH_CUDA)
set(light_api_deps ${light_api_deps} target_wrapper_cuda)
endif()
cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps})
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels)
cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite)
......@@ -18,7 +18,7 @@
namespace paddle {
namespace lite {
void Predictor::SaveModel(const std::string &dir) {
void CxxPredictor::SaveModel(const std::string &dir) {
MkDirRecursively(dir.c_str());
program_->PersistModel(dir, program_desc_);
}
......
......@@ -25,9 +25,9 @@ namespace lite {
struct Config {};
class Predictor {
class CxxPredictor {
public:
Predictor() { scope_ = std::make_shared<Scope>(); }
CxxPredictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_path, const Place& prefer_place,
const std::vector<Place>& valid_places) {
......
......@@ -22,7 +22,7 @@ namespace paddle {
namespace lite {
TEST(CXXApi, test) {
lite::Predictor predictor;
lite::CxxPredictor predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
......@@ -60,7 +60,7 @@ TEST(CXXApi, test) {
}
TEST(CXXApi, save_model) {
lite::Predictor predictor;
lite::CxxPredictor predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
predictor.Build("/home/chunwei/project/models/model2",
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements a light-weight API which can run on mobile. We limit the
* dependencies and the runtime computation complexity.
*/
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace paddle {
namespace lite {
class CxxPredictor {
public:
CxxPredictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_dir) {
framework::proto::ProgramDesc desc;
LoadModel(model_dir, scope_.get(), &desc);
BuildRuntimeProgram(desc);
}
void Run() { program_->Run(); }
// Get the offset-th tensor of the feed list.
Tensor* GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed");
CHECK(_feed_list) << "no feed variable in exec_scope";
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
if (offset >= feed_list->size()) {
feed_list->resize(offset + 1);
}
return &feed_list->at(offset);
}
const Tensor* GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fetch variable in exec_scope";
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
return &fetch_list.at(offset);
}
private:
void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) {
std::vector<Instruct> insts;
// 1. Create op first
Program program(prog, scope_, {});
// 2. Create Instructs
// Create the kernels for the target places, and pick out the specific
// kernel that matches the target alias.
for (auto& op : program.ops) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias;
Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
auto kernels = op->CreateKernels({place});
// pick the kernel whose alias matches
auto it = std::find_if(kernels.begin(), kernels.end(),
[&](std::unique_ptr<KernelBase>& it) {
return it->alias() == alias;
});
CHECK(it != kernels.end());
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope);
program_->set_exec_scope(program.exec_scope);
}
private:
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
};
} // namespace lite
} // namespace paddle
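For orientation, here is a minimal, hypothetical sketch of the two-stage flow this commit implies: the full CXX API optimizes a model offline and persists it with SaveModel(), and the light API above then loads that directory on device. The paths and the function name below are placeholders, the op/kernel registrations shown in the test file below are still required at link time, and note that in this commit both cxx_api.h and light_api.h declare a class named CxxPredictor, so the two stages would live in separate binaries.

#include "paddle/fluid/lite/api/light_api.h"

// Offline (cxx_api.h): optimize and persist the model.
//   lite::CxxPredictor cxx_predictor;
//   cxx_predictor.Build("/path/to/raw_model",
//                       Place{TARGET(kHost), PRECISION(kFloat)},
//                       {Place{TARGET(kHost), PRECISION(kFloat)}});
//   cxx_predictor.SaveModel("/path/to/optimized_model");

// On device (light_api.h): load the optimized model and run it.
void RunOnDevice() {
  paddle::lite::CxxPredictor predictor;           // the light predictor declared above
  predictor.Build("/path/to/optimized_model");    // placeholder directory
  auto* input = predictor.GetInput(0);            // 0-th feed slot
  input->Resize({100, 100});
  auto* data = input->mutable_data<float>();
  for (int i = 0; i < 100 * 100; ++i) data[i] = 0.f;
  predictor.Run();
  const auto* output = predictor.GetOutput(0);    // 0-th fetch slot
  (void)output;                                   // read results here
}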
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
const std::string model_dir =
"/home/chunwei/project/Paddle/cmake-build-relwithdebinfo/paddle/fluid/lite/"
"api/optimized_model";
TEST(LightAPI, load) {
CxxPredictor predictor;
predictor.Build(model_dir);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
}
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(fc);
USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
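Presumably, these USE_LITE_OP / USE_LITE_KERNEL declarations force the linker to retain the statically registered ops and kernels that the optimized model relies on; since the light predictor only instantiates the kernels named in the model's __@kernel_type_attr@__ attributes, every op and kernel the model uses has to be pulled in explicitly like this.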
......@@ -21,7 +21,7 @@ cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
......
......@@ -96,14 +96,40 @@ class KernelBase {
// Generate the key of the parameter type.
std::string GenParamTypeKey() const;
std::string SerializeKernelType() const {
std::string SerializedKernelType() const {
return SerializeKernelType(op_type(), alias(), place());
}
static std::string SerializeKernelType(const std::string& op_type,
const std::string& alias,
const Place& place) {
std::stringstream ss;
ss << op_type() << "/";
ss << alias_ << "/";
ss << place();
ss << op_type << "/";
ss << alias << "/";
// We serialize the place value not the string representation here for
// easier deserialization.
ss << static_cast<int>(place.target) << "/";
ss << static_cast<int>(place.precision) << "/";
ss << static_cast<int>(place.layout);
return ss.str();
}
static void ParseKernelType(const std::string& kernel_type,
std::string* op_type, std::string* alias,
Place* place) {
std::stringstream ss(kernel_type);
std::getline(ss, *op_type, '/');
std::getline(ss, *alias, '/');
std::string target, precision, layout;
std::getline(ss, target, '/');
std::getline(ss, precision, '/');
std::getline(ss, layout, '/');
place->target = static_cast<TargetType>(std::stoi(target));
place->precision = static_cast<PrecisionType>(std::stoi(precision));
place->layout = static_cast<DataLayoutType>(std::stoi(layout));
}
virtual ~KernelBase() = default;
void Torch() {}
......
......@@ -42,6 +42,22 @@ TEST(Kernel, test) {
ASSERT_EQ(test_code, 100);
}
TEST(Kernel, kernel_type) {
const std::string op_type = "fc";
const std::string alias = "def";
Place place(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
auto kernel_type = KernelBase::SerializeKernelType(op_type, alias, place);
LOG(INFO) << "kernel_type: " << kernel_type;
ASSERT_EQ(kernel_type, "fc/def/1/1/1");
std::string op_type1, alias1;
Place place1;
KernelBase::ParseKernelType(kernel_type, &op_type1, &alias1, &place1);
ASSERT_EQ(op_type, op_type1);
ASSERT_EQ(alias, alias1);
ASSERT_EQ(place, place1);
}
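Reading SerializeKernelType and the assertion above together, the key format is op_type/alias/target/precision/layout, where the last three fields are the integer values of the Place enums; "fc/def/1/1/1" therefore decodes back to op type fc, alias def, and Place(kHost, kFloat, kNCHW), assuming those enumerators all have value 1 as the assertion implies.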
} // namespace core
} // namespace lite
} // namespace paddle
......@@ -38,7 +38,7 @@ class GenerateProgramPass : public ProgramPass {
}
private:
std::vector<Instruction> insts_;
std::vector<Instruct> insts_;
};
} // namespace mir
......
......@@ -18,53 +18,48 @@
namespace paddle {
namespace lite {
void RuntimeProgram::PersistModel(const std::string &path,
void RuntimeProgram::PersistModel(const std::string &dir,
const framework::proto::ProgramDesc &desc) {
// Persist model.
const std::string model_path = path + "/__model__";
const std::string model_path = dir + "/__model__";
std::ofstream model_ostream(model_path, std::ios_base::binary);
CHECK(model_ostream.is_open());
const std::string pb_str = SerializeModelTopology(desc);
const std::string pb_str = SerializeProgram(desc);
model_ostream.write(pb_str.c_str(), pb_str.size());
model_ostream.close();
// Persist params.
const std::string params_path = path + "/params";
CHECK(!IsFileExists(params_path)) << "file " << params_path
<< " exists, can't overwrite";
std::ofstream params_ostream(params_path, std::ios_base::binary);
CHECK(params_ostream.is_open());
framework::proto::ProgramDesc latest_program;
latest_program.ParseFromString(pb_str);
SerializeParams(params_ostream, latest_program);
SaveParams(dir, latest_program);
}
std::string RuntimeProgram::SerializeModelTopology(
std::string RuntimeProgram::SerializeProgram(
const framework::proto::ProgramDesc &desc) {
const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
auto program_dummy = desc;
program_dummy.mutable_blocks(0)->clear_ops();
for (auto &node : instructions_) {
auto desc_dummy = node.op()->op_info()->desc();
OpDesc desc(desc_dummy);
desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializeKernelType());
desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
// append new opdesc
*program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto();
}
return program_dummy.SerializeAsString();
}
void RuntimeProgram::SerializeParams(
std::ostream &os, const framework::proto::ProgramDesc &desc) {
std::vector<std::string> ws;
void RuntimeProgram::SaveParams(const std::string &dir,
const framework::proto::ProgramDesc &desc) {
CHECK(exec_scope_);
for (auto &item : desc.blocks(0).vars()) {
const std::string path = dir + "/" + item.name();
if (item.name() == "feed" || item.name() == "fetch") continue;
if (item.persistable()) {
ws.push_back(item.name());
std::ofstream file(path, std::ios::binary);
SerializeTensor(file, *exec_scope_, item.name());
file.close();
}
}
CHECK(exec_scope_);
SerializeTensors(os, *exec_scope_, ws);
}
} // namespace lite
......
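Taken together, PersistModel and SaveParams above produce an optimized-model directory of roughly this shape (the weight file names are illustrative; every persistable variable other than feed and fetch gets its own raw tensor file):

optimized_model/
  __model__       # serialized ProgramDesc with __@kernel_type_attr@__ set on every op
  fc_0.w_0        # one file per persistable variable, written by SerializeTensor
  fc_0.b_0
  ...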
......@@ -26,6 +26,8 @@
namespace paddle {
namespace lite {
static const std::string kKernelTypeAttr = "__@kernel_type_attr@__";
// A program is used to represent a code program. In Paddle, a code program
// contains:
// - main block, which is a list of OpLite
......@@ -46,8 +48,9 @@ struct Program {
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places)
: scope(root), valid_places(valid_places), desc(desc) {
CHECK(scope) << "scope should be init first";
PrepareWorkspace(desc);
Build(desc, valid_places);
Build(desc);
}
std::unique_ptr<Program> Clone() const {
......@@ -57,8 +60,7 @@ struct Program {
private:
// Build from a program and scope.
void Build(const framework::proto::ProgramDesc& program,
const std::vector<Place>& valid_places) {
void Build(const framework::proto::ProgramDesc& program) {
CHECK(ops.empty()) << "Executor duplicate Build found";
// Create operators.
......@@ -67,7 +69,9 @@ struct Program {
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(op);
ops.back()->Attach(op_desc, exec_scope);
}
}
......@@ -95,9 +99,9 @@ struct Program {
}
};
struct Instruction {
Instruction(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
struct Instruct {
Instruct(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
: op_(op), kernel_(std::move(kernel)) {}
void Run() {
......@@ -111,7 +115,7 @@ struct Instruction {
kernel_->Run();
}
friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
return os;
}
......@@ -130,7 +134,7 @@ struct Instruction {
*/
class RuntimeProgram {
public:
explicit RuntimeProgram(std::vector<Instruction>&& insts)
explicit RuntimeProgram(std::vector<Instruct>&& insts)
: instructions_(std::move(insts)) {
if (instructions_.empty()) {
LOG(FATAL) << "no instructions";
......@@ -145,7 +149,7 @@ class RuntimeProgram {
}
// Serialize the graph and save to the disk.
void PersistModel(const std::string& path,
void PersistModel(const std::string& dir,
const framework::proto::ProgramDesc& desc);
void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
......@@ -154,13 +158,13 @@ class RuntimeProgram {
size_t num_instructions() const { return instructions_.size(); }
protected:
std::string SerializeModelTopology(const framework::proto::ProgramDesc& desc);
void SerializeParams(std::ostream& os,
const framework::proto::ProgramDesc& desc);
std::string SerializeProgram(const framework::proto::ProgramDesc& desc);
void SaveParams(const std::string& dir,
const framework::proto::ProgramDesc& desc);
private:
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_;
std::vector<Instruct> instructions_;
lite::Scope* exec_scope_{};
};
......
cc_library(target_wrapper_host SRCS target_wrapper.cc)
cc_library(target_wrapper_host SRCS target_wrapper.cc DEPS target_wrapper_lite)
......@@ -6,8 +6,13 @@ else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
compatible_pb_lite)
set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite
target_wrapper_host
compatible_pb_lite
)
if (LITE_WITH_CUDA)
set(model_parser_deps ${model_parser_deps} target_wrapper_cuda)
endif()
cc_library(model_parser_lite SRCS model_parser.cc DEPS ${model_parser_deps})
add_subdirectory(pb)
......@@ -35,7 +35,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
DO(INT64, int64_t);
#undef DO
default:
LOG(FATAL) << "unknown data type";
LOG(FATAL) << "unknown data type " << type;
}
return -1;
}
......@@ -86,12 +86,12 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
void LoadLoDTensor(std::istream &is, Variable *var) {
auto *tensor = var->GetMutable<lite::Tensor>();
uint32_t version;
uint32_t version{};
is.read(reinterpret_cast<char *>(&version), sizeof(version));
LOG(INFO) << "model version " << version;
// Load LoD information
uint64_t lod_level;
uint64_t lod_level{};
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
......@@ -136,6 +136,7 @@ void LoadParams(const std::string &path) {}
// Load directly to CPU, and later transfer to other devices.
void LoadParam(const std::string &path, Variable *out) {
std::ifstream fin(path, std::ios::binary);
CHECK(fin.is_open()) << "failed to open file " << path;
LoadLoDTensor(fin, out);
}
......@@ -164,13 +165,12 @@ void LoadModel(const std::string &model_dir, Scope *scope,
}
void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
// the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
{
int size = tensor.lod().size();
uint64_t size = tensor.lod().size();
// the 2nd field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
......@@ -186,11 +186,15 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
}
}
// There are two version fields in a LoDTensor.
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::proto::VarType::TensorDesc desc;
desc.set_data_type(framework::proto::VarType_Type_LOD_TENSOR);
// TODO(Superjomn) support other data types.
desc.set_data_type(framework::proto::VarType_Type_FP32);
auto dims = tensor.dims();
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
......@@ -221,14 +225,12 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
}
}
void SerializeTensors(std::ostream &os, const lite::Scope &scope,
const std::vector<std::string> &vars) {
void SerializeTensor(std::ostream &os, const lite::Scope &scope,
const std::string &var_name) {
// Store all the persistable vars.
for (const auto &_var : vars) {
auto *var = scope.FindVar(_var);
const auto &tensor = var->Get<lite::Tensor>();
TensorToStream(os, tensor);
}
auto *var = scope.FindVar(var_name);
const auto &tensor = var->Get<lite::Tensor>();
TensorToStream(os, tensor);
}
} // namespace lite
......
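As a reference for the stream layout, TensorToStream above writes a uint32_t version, then the LoD (a uint64_t lod_level followed by, for each level, its size in bytes and its data), a second uint32_t version, an int32_t length plus the serialized TensorDesc protobuf, and presumably the raw tensor bytes in the part of the function elided by this hunk; LoadLoDTensor and TensorFromStream read the same layout back when the predictor loads parameters.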
......@@ -41,8 +41,8 @@ void LoadModel(const std::string& model_dir, Scope* scope,
framework::proto::ProgramDesc* prog);
// Serialize tensors to ostream.
void SerializeTensors(std::ostream& os, const lite::Scope& scope,
const std::vector<std::string>& vars);
void SerializeTensor(std::ostream& os, const lite::Scope& scope,
const std::string& var);
// LoDTensor to ostream
void TensorToStream(std::ostream& os, const lite::Tensor& tensor);
......
......@@ -22,7 +22,7 @@ namespace lite {
template <>
void TargetWrapper<TARGET(kX86)>::MemcpySync(void *dst, const void *src,
size_t size, IoDirection dir) {
std::copy_n(reinterpret_cast<uint8_t *>(src), size,
std::copy_n(reinterpret_cast<const uint8_t *>(src), size,
reinterpret_cast<uint8_t *>(dst));
}
......