未验证 提交 24d37695 编写于 作者: H huzhiqiang 提交者: GitHub

[Parl] Add CxxPredictor->Clone() method (#3759)

上级 9f767f94
...@@ -31,14 +31,14 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -31,14 +31,14 @@ void Predictor::SaveModel(const std::string &dir,
if (!program_) { if (!program_) {
GenRuntimeProgram(); GenRuntimeProgram();
} }
program_->SaveOpInfosToProgram(&program_desc_); program_->SaveOpInfosToProgram(program_desc_.get());
program_->UpdateVarsOfProgram(&program_desc_); program_->UpdateVarsOfProgram(program_desc_.get());
switch (model_type) { switch (model_type) {
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
SaveModelPb(dir, *program_->exec_scope(), program_desc_, true); SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true);
break; break;
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer:
SaveModelNaive(dir, *program_->exec_scope(), program_desc_); SaveModelNaive(dir, *program_->exec_scope(), *program_desc_.get());
break; break;
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
...@@ -232,9 +232,8 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const { ...@@ -232,9 +232,8 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
#endif #endif
const cpp::ProgramDesc &Predictor::program_desc() const { const cpp::ProgramDesc &Predictor::program_desc() const {
return program_desc_; return *program_desc_.get();
} }
const RuntimeProgram &Predictor::runtime_program() const { return *program_; } const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
void Predictor::Build(const lite_api::CxxConfig &config, void Predictor::Build(const lite_api::CxxConfig &config,
...@@ -276,14 +275,14 @@ void Predictor::Build(const std::string &model_path, ...@@ -276,14 +275,14 @@ void Predictor::Build(const std::string &model_path,
model_file, model_file,
param_file, param_file,
scope_.get(), scope_.get(),
&program_desc_, program_desc_.get(),
combined_param, combined_param,
model_from_memory); model_from_memory);
} break; } break;
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty()) CHECK(!model_path.empty())
<< "NaiveBuffer backend only supported combined param"; << "NaiveBuffer backend only supported combined param";
LoadModelNaiveFromFile(model_path, scope_.get(), &program_desc_); LoadModelNaiveFromFile(model_path, scope_.get(), program_desc_.get());
break; break;
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
...@@ -291,7 +290,7 @@ void Predictor::Build(const std::string &model_path, ...@@ -291,7 +290,7 @@ void Predictor::Build(const std::string &model_path,
Build(program_desc_, valid_places, passes); Build(program_desc_, valid_places, passes);
} }
void Predictor::Build(const cpp::ProgramDesc &desc, void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
const std::vector<Place> &valid_places, const std::vector<Place> &valid_places,
const std::vector<std::string> &passes) { const std::vector<std::string> &passes) {
program_desc_ = desc; program_desc_ = desc;
...@@ -313,9 +312,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -313,9 +312,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
"fake_dequantize_max_abs", "fake_dequantize_max_abs",
"fake_channel_wise_dequantize_max_abs"}; "fake_channel_wise_dequantize_max_abs"};
bool is_quantized_model = false; bool is_quantized_model = false;
for (size_t i = 0; i < program_desc_.BlocksSize() && !is_quantized_model; for (size_t i = 0; i < program_desc_->BlocksSize() && !is_quantized_model;
++i) { ++i) {
auto *block_desc = program_desc_.GetBlock<cpp::BlockDesc>(i); auto *block_desc = program_desc_->GetBlock<cpp::BlockDesc>(i);
for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) { for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) {
auto *op_desc = block_desc->GetOp<cpp::OpDesc>(j); auto *op_desc = block_desc->GetOp<cpp::OpDesc>(j);
std::string op_type = op_desc->Type(); std::string op_type = op_desc->Type();
...@@ -333,7 +332,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -333,7 +332,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
#endif #endif
} }
Program program(desc, scope_, inner_places); Program program(*desc.get(), scope_, inner_places);
valid_places_ = inner_places;
core::KernelPickFactor factor; core::KernelPickFactor factor;
factor.ConsiderTarget(); factor.ConsiderTarget();
......
...@@ -42,11 +42,24 @@ static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list"; ...@@ -42,11 +42,24 @@ static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
class LITE_API Predictor { class LITE_API Predictor {
public: public:
// Create an empty predictor. // Create an empty predictor.
Predictor() { scope_ = std::make_shared<Scope>(); } Predictor() {
scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
}
// Create a predictor with the weight variable scope set. // Create a predictor with the weight variable scope set.
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
Predictor(const std::shared_ptr<cpp::ProgramDesc>& desc,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {})
: program_desc_(desc), scope_(root) {
Program program(*desc.get(), scope_, valid_places, var_names);
optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope();
valid_places_ = valid_places;
}
// Build from a model, with places set for hardware config. // Build from a model, with places set for hardware config.
void Build( void Build(
...@@ -64,10 +77,35 @@ class LITE_API Predictor { ...@@ -64,10 +77,35 @@ class LITE_API Predictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false); bool memory_from_memory = false);
void Build(const cpp::ProgramDesc& desc, void Build(const std::shared_ptr<cpp::ProgramDesc>& desc,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}); const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const {
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
return predictor;
}
std::shared_ptr<Predictor> Clone(
const std::vector<std::string>& var_names) const {
CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode.";
auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, var_names);
for (auto i : var_names) {
predictor->exec_scope_->LocalVar(i);
auto* tensor = predictor->scope_->Var(i)->GetMutable<lite::Tensor>();
auto* sub_tensor =
predictor->exec_scope_->Var(i)->GetMutable<lite::Tensor>();
sub_tensor->CopyDataFrom(*tensor);
}
return predictor;
}
void GenRuntimeProgram(); void GenRuntimeProgram();
// Run the predictor for a single batch of data. // Run the predictor for a single batch of data.
...@@ -119,18 +157,26 @@ class LITE_API Predictor { ...@@ -119,18 +157,26 @@ class LITE_API Predictor {
private: private:
Optimizer optimizer_; Optimizer optimizer_;
cpp::ProgramDesc program_desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
const Scope* exec_scope_; Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
bool program_generated_{false}; bool program_generated_{false};
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
std::vector<Place> valid_places_;
}; };
class CxxPaddleApiImpl : public lite_api::PaddlePredictor { class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
public: public:
CxxPaddleApiImpl() {} CxxPaddleApiImpl() {
raw_predictor_ = std::make_shared<Predictor>();
status_is_cloned_ = false;
}
explicit CxxPaddleApiImpl(const std::shared_ptr<Predictor>& raw_predictor)
: raw_predictor_(raw_predictor) {
status_is_cloned_ = true;
}
/// Create a new predictor from a config. /// Create a new predictor from a config.
void Init(const lite_api::CxxConfig& config); void Init(const lite_api::CxxConfig& config);
...@@ -143,6 +189,9 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -143,6 +189,9 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
std::shared_ptr<lite_api::PaddlePredictor> Clone() override; std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone(
const std::vector<std::string>& var_names) override;
std::string GetVersion() const override; std::string GetVersion() const override;
// get inputs names and get outputs names // get inputs names and get outputs names
...@@ -168,9 +217,10 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -168,9 +217,10 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
bool record_info = false) override; bool record_info = false) override;
private: private:
Predictor raw_predictor_; std::shared_ptr<Predictor> raw_predictor_;
lite_api::CxxConfig config_; lite_api::CxxConfig config_;
std::mutex mutex_; std::mutex mutex_;
bool status_is_cloned_;
}; };
/* /*
......
...@@ -34,87 +34,102 @@ namespace lite { ...@@ -34,87 +34,102 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config; config_ = config;
auto places = config.valid_places(); if (!status_is_cloned_) {
std::vector<std::string> passes = config.get_passes_internal(); auto places = config.valid_places();
std::vector<std::string> passes = config.get_passes_internal();
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// if kCUDA is included in valid places, it should be initialized first, // if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step. // otherwise skip this step.
for (auto &p : places) { for (auto &p : places) {
if (p.target == TARGET(kCUDA)) { if (p.target == TARGET(kCUDA)) {
Env<TARGET(kCUDA)>::Init(); Env<TARGET(kCUDA)>::Init();
if (config_.multi_stream()) { if (config_.multi_stream()) {
passes = {"multi_stream_analysis_pass"}; passes = {"multi_stream_analysis_pass"};
VLOG(3) << "add pass: " << passes[0]; VLOG(3) << "add pass: " << passes[0];
}
break;
} }
break;
} }
}
#endif #endif
#ifdef LITE_WITH_MLU #ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init(); Env<TARGET(kMLU)>::Init();
lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(), lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(), config.mlu_core_number(),
config.mlu_use_first_conv(), config.mlu_use_first_conv(),
config.mlu_first_conv_mean(), config.mlu_first_conv_mean(),
config.mlu_first_conv_std(), config.mlu_first_conv_std(),
config.mlu_input_layout()); config.mlu_input_layout());
#endif // LITE_WITH_MLU #endif // LITE_WITH_MLU
auto use_layout_preprocess_pass = auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS"); config.model_dir().find("OPENCL_PRE_PRECESS");
VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass; VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
if (places[0].target == TARGET(kOpenCL) && if (places[0].target == TARGET(kOpenCL) &&
use_layout_preprocess_pass != std::string::npos) { use_layout_preprocess_pass != std::string::npos) {
passes = {"type_layout_cast_preprocess_pass"}; passes = {"type_layout_cast_preprocess_pass"};
VLOG(1) << "add pass:" << passes[0]; VLOG(1) << "add pass:" << passes[0];
}
raw_predictor_->Build(config, places, passes);
} else {
raw_predictor_->PrepareFeedFetch();
CHECK(raw_predictor_) << "The Predictor can not be nullptr in Clone mode.";
} }
raw_predictor_.Build(config, places, passes);
mode_ = config.power_mode(); mode_ = config.power_mode();
threads_ = config.threads(); threads_ = config.threads();
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__) !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
int num_threads = config.x86_math_library_num_threads(); int num_threads = config.x86_math_library_num_threads();
int real_num_threads = num_threads > 1 ? num_threads : 1; int real_num_threads = num_threads > 1 ? num_threads : 1;
paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads); paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(real_num_threads); omp_set_num_threads(real_num_threads);
VLOG(3) << "set_x86_math_library_math_threads() is set successfully and the " VLOG(3) << "set_x86_math_library_math_threads() is set successfully and the "
"number of threads is:" "number of threads is:"
<< num_threads; << real_num_threads;
#endif #endif
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
auto *x = raw_predictor_.GetInput(i); auto *x = raw_predictor_->GetInput(i);
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput( std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
int i) const { int i) const {
const auto *x = raw_predictor_.GetOutput(i); const auto *x = raw_predictor_->GetOutput(i);
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::vector<std::string> CxxPaddleApiImpl::GetInputNames() { std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames(); return raw_predictor_->GetInputNames();
} }
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() { std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames(); return raw_predictor_->GetParamNames();
} }
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames(); return raw_predictor_->GetOutputNames();
} }
void CxxPaddleApiImpl::Run() { void CxxPaddleApiImpl::Run() {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_); lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif #endif
raw_predictor_.Run(); raw_predictor_->Run();
} }
std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() { std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto predictor = std::make_shared<lite::CxxPaddleApiImpl>(); auto predictor =
std::make_shared<lite::CxxPaddleApiImpl>(raw_predictor_->Clone());
predictor->Init(config_);
return predictor;
}
std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone(
const std::vector<std::string> &var_names) {
std::lock_guard<std::mutex> lock(mutex_);
auto predictor = std::make_shared<lite::CxxPaddleApiImpl>(
raw_predictor_->Clone(var_names));
predictor->Init(config_); predictor->Init(config_);
return predictor; return predictor;
} }
...@@ -123,26 +138,26 @@ std::string CxxPaddleApiImpl::GetVersion() const { return version(); } ...@@ -123,26 +138,26 @@ std::string CxxPaddleApiImpl::GetVersion() const { return version(); }
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor( std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
const std::string &name) const { const std::string &name) const {
auto *x = raw_predictor_.GetTensor(name); auto *x = raw_predictor_->GetTensor(name);
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor( std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetMutableTensor(
const std::string &name) { const std::string &name) {
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetMutableTensor(name))); new lite_api::Tensor(raw_predictor_->GetMutableTensor(name)));
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName( std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
const std::string &name) { const std::string &name) {
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(raw_predictor_.GetInputByName(name))); new lite_api::Tensor(raw_predictor_->GetInputByName(name)));
} }
void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir, void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir,
lite_api::LiteModelType model_type, lite_api::LiteModelType model_type,
bool record_info) { bool record_info) {
raw_predictor_.SaveModel(model_dir, model_type, record_info); raw_predictor_->SaveModel(model_dir, model_type, record_info);
} }
} // namespace lite } // namespace lite
......
...@@ -53,6 +53,44 @@ TEST(CXXApi, save_model) { ...@@ -53,6 +53,44 @@ TEST(CXXApi, save_model) {
lite_api::LiteModelType::kNaiveBuffer); lite_api::LiteModelType::kNaiveBuffer);
} }
TEST(CXXApi, clone_predictor) {
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, "", "", valid_places);
auto cloned_predictor = predictor.Clone();
// primary predicotr
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(std::vector<int64_t>({1, 100}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100; i++) {
data[i] = 1;
}
predictor.Run();
auto* output_tensor = predictor.GetOutput(0);
auto output_shape = output_tensor->dims().Vectorize();
ASSERT_EQ(output_shape.size(), 2);
ASSERT_EQ(output_shape[0], 1);
ASSERT_EQ(output_shape[1], 500);
// cloned predictor
auto* cloned_input_tensor = cloned_predictor->GetInput(0);
cloned_input_tensor->Resize(std::vector<int64_t>({1, 100}));
auto* cloned_data = cloned_input_tensor->mutable_data<float>();
for (int i = 0; i < 100; i++) {
cloned_data[i] = 1;
}
cloned_predictor->Run();
auto* cloned_output_tensor = cloned_predictor->GetOutput(0);
int step = 50;
for (int i = 0; i < output_tensor->data_size(); i += step) {
EXPECT_NEAR(output_tensor->data<float>()[i],
cloned_output_tensor->data<float>()[i],
1e-6);
}
}
/*TEST(CXXTrainer, train) { /*TEST(CXXTrainer, train) {
Place place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}); Place place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)});
std::vector<Place> valid_places({place}); std::vector<Place> valid_places({place});
......
...@@ -114,7 +114,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { ...@@ -114,7 +114,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor {
void Run() override; void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override; std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone(
const std::vector<std::string>& var_names) override;
std::string GetVersion() const override; std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override; std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override; std::vector<std::string> GetOutputNames() override;
......
...@@ -66,6 +66,12 @@ std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() { ...@@ -66,6 +66,12 @@ std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() {
return nullptr; return nullptr;
} }
std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone(
const std::vector<std::string>& var_names) {
LOG(FATAL) << "The Clone API is not supported in LigthPredictor";
return nullptr;
}
std::string LightPredictorImpl::GetVersion() const { return lite::version(); } std::string LightPredictorImpl::GetVersion() const { return lite::version(); }
std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor( std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor(
......
...@@ -79,6 +79,8 @@ class LITE_API PaddlePredictor { ...@@ -79,6 +79,8 @@ class LITE_API PaddlePredictor {
virtual void Run() = 0; virtual void Run() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone() = 0; virtual std::shared_ptr<PaddlePredictor> Clone() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone(
const std::vector<std::string>& var_names) = 0;
virtual std::string GetVersion() const = 0; virtual std::string GetVersion() const = 0;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_manager.h"
...@@ -37,6 +38,21 @@ namespace lite { ...@@ -37,6 +38,21 @@ namespace lite {
*/ */
class Optimizer { class Optimizer {
public: public:
Optimizer() {}
Optimizer(Program&& program, const std::vector<Place>& valid_places) {
program_ = &program;
valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set";
core::KernelPickFactor factor;
factor.ConsiderTarget();
factor.ConsiderPrecision();
factor.ConsiderDataLayout();
Run(std::move(program), valid_places, factor, {});
}
void Run(Program&& program, void Run(Program&& program,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor, core::KernelPickFactor kernel_pick_factor,
......
...@@ -216,7 +216,8 @@ void Program::Build(const cpp::ProgramDesc& prog) { ...@@ -216,7 +216,8 @@ void Program::Build(const cpp::ProgramDesc& prog) {
} }
} }
void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) { void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
const std::vector<std::string>& var_names) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found"; CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope(); exec_scope_ = &scope_->NewScope();
// Create Feed and Fetch var. // Create Feed and Fetch var.
...@@ -274,6 +275,13 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) { ...@@ -274,6 +275,13 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
} }
} }
} }
for (auto i : var_names) {
exec_scope_->LocalVar(i);
auto* tensor = scope_->Var(i)->GetMutable<lite::Tensor>();
auto* sub_tensor = exec_scope_->Var(i)->GetMutable<lite::Tensor>();
sub_tensor->CopyDataFrom(*tensor);
}
} }
void Instruction::Run() { void Instruction::Run() {
......
...@@ -41,11 +41,12 @@ struct Program { ...@@ -41,11 +41,12 @@ struct Program {
explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; } explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; }
Program(const cpp::ProgramDesc& desc, Program(const cpp::ProgramDesc& desc,
const std::shared_ptr<Scope>& root, const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places) const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {})
: scope_(root), valid_places_(valid_places), desc_(desc) { : scope_(root), valid_places_(valid_places), desc_(desc) {
CHECK(scope_) << "scope should be init first"; CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work"; VLOG(4) << "prepare work";
PrepareWorkspace(desc); PrepareWorkspace(desc, var_names);
VLOG(4) << "build desc"; VLOG(4) << "build desc";
Build(desc); Build(desc);
VLOG(4) << "build desc finished"; VLOG(4) << "build desc finished";
...@@ -75,7 +76,8 @@ struct Program { ...@@ -75,7 +76,8 @@ struct Program {
// Build from a program and scope. // Build from a program and scope.
void Build(const cpp::ProgramDesc& program); void Build(const cpp::ProgramDesc& program);
// Create temporary variables. // Create temporary variables.
void PrepareWorkspace(const cpp::ProgramDesc& program); void PrepareWorkspace(const cpp::ProgramDesc& program,
const std::vector<std::string>& var_names = {});
private: private:
std::map<std::string, PrecisionType> var_data_type_; std::map<std::string, PrecisionType> var_data_type_;
......
...@@ -13,11 +13,16 @@ ...@@ -13,11 +13,16 @@
// limitations under the License. // limitations under the License.
#include "lite/core/scope.h" #include "lite/core/scope.h"
#define SCOPE_KIDS_READER_LOCK lite::fluid::AutoRDLock auto_lock(kids_lock_);
#define SCOPE_KIDS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(kids_lock_);
#define SCOPE_VARS_READER_LOCK lite::fluid::AutoRDLock auto_lock(vars_lock_);
#define SCOPE_VARS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(vars_lock_);
namespace paddle { namespace paddle {
namespace lite { namespace lite {
Scope::~Scope() { Scope::~Scope() {
SCOPE_KIDS_WRITER_LOCK
for (auto *x : kids_) { for (auto *x : kids_) {
if (x) { if (x) {
delete x; delete x;
...@@ -26,15 +31,25 @@ Scope::~Scope() { ...@@ -26,15 +31,25 @@ Scope::~Scope() {
} }
Scope &Scope::NewScope() const { Scope &Scope::NewScope() const {
SCOPE_KIDS_WRITER_LOCK
kids_.push_back(new Scope); kids_.push_back(new Scope);
kids_.back()->parent_ = this; kids_.back()->parent_ = this;
return *kids_.back(); return *kids_.back();
} }
Variable *Scope::Var(const std::string &name) { Variable *Scope::Var(const std::string &name) {
SCOPE_VARS_WRITER_LOCK
auto *var = FindVar(name); auto *var = FindVar(name);
if (var) return var; if (var) return var;
// create a new variable.
vars_.emplace(name, std::unique_ptr<Variable>(new Variable));
return vars_[name].get();
}
Variable *Scope::LocalVar(const std::string &name) {
SCOPE_VARS_WRITER_LOCK
auto *var = FindLocalVar(name);
if (var) return var;
// create a new variable. // create a new variable.
vars_.emplace(name, std::unique_ptr<Variable>(new Variable)); vars_.emplace(name, std::unique_ptr<Variable>(new Variable));
return vars_[name].get(); return vars_[name].get();
...@@ -44,19 +59,23 @@ Variable *Scope::FindVar(const std::string &name) const { ...@@ -44,19 +59,23 @@ Variable *Scope::FindVar(const std::string &name) const {
Variable *var{nullptr}; Variable *var{nullptr};
var = FindLocalVar(name); var = FindLocalVar(name);
const Scope *cur_scope = this; const Scope *cur_scope = this;
rwlock_->RDLock();
while (!var && cur_scope->parent()) { while (!var && cur_scope->parent()) {
cur_scope = cur_scope->parent(); cur_scope = cur_scope->parent();
var = cur_scope->FindLocalVar(name); var = cur_scope->FindLocalVar(name);
} }
rwlock_->UNLock();
return var; return var;
} }
Variable *Scope::FindLocalVar(const std::string &name) const { Variable *Scope::FindLocalVar(const std::string &name) const {
rwlock_->RDLock();
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
rwlock_->UNLock();
return it->second.get(); return it->second.get();
} }
rwlock_->UNLock();
return nullptr; return nullptr;
} }
...@@ -85,8 +104,12 @@ std::vector<std::string> Scope::AttributeVarNames() const { ...@@ -85,8 +104,12 @@ std::vector<std::string> Scope::AttributeVarNames() const {
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys; std::vector<std::string> keys;
for (const auto &item : vars_) { {
keys.push_back(item.first); rwlock_->RDLock();
for (const auto &item : vars_) {
keys.push_back(item.first);
}
rwlock_->UNLock();
} }
return keys; return keys;
} }
......
...@@ -20,13 +20,18 @@ ...@@ -20,13 +20,18 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/variable.h" #include "lite/core/variable.h"
#include "lite/fluid/rw_lock.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
class Scope final { class Scope final {
public: public:
Scope() {} Scope() {
kids_lock_ = new lite::fluid::RWLock;
vars_lock_ = new lite::fluid::RWLock;
rwlock_.reset(new lite::fluid::RWLock);
}
// delete below two functions to allow pybind to recognise it cannot make a // delete below two functions to allow pybind to recognise it cannot make a
// copy // copy
// link: // link:
...@@ -39,6 +44,8 @@ class Scope final { ...@@ -39,6 +44,8 @@ class Scope final {
Variable* Var(const std::string& name); Variable* Var(const std::string& name);
Variable* LocalVar(const std::string& name);
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
Variable* FindLocalVar(const std::string& name) const; Variable* FindLocalVar(const std::string& name) const;
...@@ -75,6 +82,9 @@ class Scope final { ...@@ -75,6 +82,9 @@ class Scope final {
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr}; const Scope* parent_{nullptr};
std::map<std::string, std::unique_ptr<Variable>> vars_; std::map<std::string, std::unique_ptr<Variable>> vars_;
lite::fluid::RWLock* kids_lock_{nullptr};
lite::fluid::RWLock* vars_lock_{nullptr};
std::unique_ptr<lite::fluid::RWLock> rwlock_{nullptr};
}; };
} // namespace lite } // namespace lite
......
...@@ -375,6 +375,7 @@ function make_x86 { ...@@ -375,6 +375,7 @@ function make_x86 {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \
-DLITE_WITH_ARM=OFF \ -DLITE_WITH_ARM=OFF \
-DWITH_GPU=OFF \ -DWITH_GPU=OFF \
-DLITE_SHUTDOWN_LOG=ON \
-DLITE_WITH_PYTHON=${BUILD_PYTHON} \ -DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON \ -DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_LOG=${WITH_LOG} \ -DLITE_WITH_LOG=${WITH_LOG} \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册