Commit b66052dd authored by D DannyIsFunny

test=develop

Parent 3896590b
@@ -228,7 +228,6 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
 const cpp::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
 }
-
 const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
 void Predictor::Build(const lite_api::CxxConfig &config,
@@ -294,7 +293,6 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
   inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
   inner_places.emplace_back(
       TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
-
   const std::vector<std::string> quant_dequant_op = {
       "fake_quantize_abs_max",
       "fake_quantize_range_abs_max",
@@ -321,6 +319,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
   }
   Program program(desc, scope_, inner_places);
+  valid_places_ = inner_places;
   core::KernelPickFactor factor;
   factor.ConsiderTarget();
......
@@ -46,6 +46,17 @@ class LITE_API Predictor {
   // Create a predictor with the weight variable scope set.
   explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
       : scope_(root_scope) {}
+  Predictor(const cpp::ProgramDesc& desc,
+            const std::shared_ptr<Scope>& root,
+            const std::vector<Place>& valid_places)
+      : program_desc_(desc), scope_(root) {
+    optimizer_ =
+        Optimizer(new Program(desc, scope_, valid_places), valid_places);
+    exec_scope_ = optimizer_.exec_scope();
+    GenRuntimeProgram();
+    valid_places_ = valid_places;
+    PrepareFeedFetch();
+  }
   // Build from a model, with places set for hardware config.
   void Build(
@@ -67,6 +78,16 @@ class LITE_API Predictor {
       const std::vector<Place>& valid_places,
       const std::vector<std::string>& passes = {});
+  std::shared_ptr<Predictor> Clone() const {
+    // CHECK(program_desc_) << "Both program and scope of the current
+    // predictor should not be nullptr in Clone mode.";
+    // CHECK(scope_) << "Both program and scope of the current predictor
+    // should not be nullptr in Clone mode.";
+    auto predictor =
+        std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
+    return predictor;
+  }
   void GenRuntimeProgram();
   // Run the predictor for a single batch of data.
@@ -119,11 +140,14 @@ class LITE_API Predictor {
   bool program_generated_{false};
   std::vector<std::string> input_names_;
   std::vector<std::string> output_names_;
+  std::vector<Place> valid_places_;
 };
 class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
  public:
   CxxPaddleApiImpl() {}
+  explicit CxxPaddleApiImpl(const std::shared_ptr<Predictor>& raw_predictor)
+      : raw_predictor_(raw_predictor) {}
   /// Create a new predictor from a config.
   void Init(const lite_api::CxxConfig& config);
@@ -155,9 +179,10 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
                         bool record_info = false) override;
  private:
-  Predictor raw_predictor_;
+  std::shared_ptr<Predictor> raw_predictor_;
   lite_api::CxxConfig config_;
   std::mutex mutex_;
+  bool status_is_cloned_{false};
 };
 /*
......
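Taken together, the header changes give Predictor a Clone() that shares the root scope (and therefore the weights), while the new three-argument constructor rebuilds the runtime program from the in-memory program desc. A minimal usage sketch, assuming an x86 build and a model directory as in the test added further below; the input shape is illustrative:

```cpp
// Sketch only: exercising Predictor::Clone() as declared above. Clone()
// forwards program_desc_, scope_ and valid_places_ into the new
// three-argument constructor, which re-optimizes and regenerates the
// runtime program for the copy.
#include "lite/api/cxx_api.h"

void BuildAndClone(const std::string& model_dir) {
  paddle::lite::Predictor predictor;
  std::vector<paddle::lite::Place> valid_places(
      {paddle::lite::Place{TARGET(kX86), PRECISION(kFloat)}});
  predictor.Build(model_dir, "", "", valid_places);

  // The clone shares the weight scope but gets its own execution scope.
  std::shared_ptr<paddle::lite::Predictor> cloned = predictor.Clone();
  auto* input = cloned->GetInput(0);
  input->Resize(std::vector<int64_t>({1, 100}));  // illustrative shape
  input->mutable_data<float>();                   // fill inputs here
  cloned->Run();
}
```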
@@ -34,17 +34,21 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
 #ifdef LITE_WITH_CUDA
   Env<TARGET(kCUDA)>::Init();
 #endif
-  auto places = config.valid_places();
-  std::vector<std::string> passes{};
-  auto use_layout_preprocess_pass =
-      config.model_dir().find("OPENCL_PRE_PRECESS");
-  VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
-  if (places[0].target == TARGET(kOpenCL) &&
-      use_layout_preprocess_pass != std::string::npos) {
-    passes = {"type_layout_cast_preprocess_pass"};
-    VLOG(1) << "add pass:" << passes[0];
+  if (!status_is_cloned_) {
+    auto places = config.valid_places();
+    std::vector<std::string> passes{};
+    auto use_layout_preprocess_pass =
+        config.model_dir().find("OPENCL_PRE_PRECESS");
+    VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
+    if (places[0].target == TARGET(kOpenCL) &&
+        use_layout_preprocess_pass != std::string::npos) {
+      passes = {"type_layout_cast_preprocess_pass"};
+      VLOG(1) << "add pass:" << passes[0];
+    }
+    raw_predictor_->Build(config, places, passes);
+  } else {
+    CHECK(raw_predictor_) << "The Predictor can not be nullptr in Clone mode.";
   }
-  raw_predictor_.Build(config, places, passes);
   mode_ = config.power_mode();
   threads_ = config.threads();
@@ -61,34 +65,36 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
 }
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
-  auto *x = raw_predictor_.GetInput(i);
+  auto *x = raw_predictor_->GetInput(i);
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 }
 std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
     int i) const {
-  const auto *x = raw_predictor_.GetOutput(i);
+  const auto *x = raw_predictor_->GetOutput(i);
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 }
 std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
-  return raw_predictor_.GetInputNames();
+  return raw_predictor_->GetInputNames();
 }
 std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
-  return raw_predictor_.GetOutputNames();
+  return raw_predictor_->GetOutputNames();
 }
 void CxxPaddleApiImpl::Run() {
 #ifdef LITE_WITH_ARM
   lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
 #endif
-  raw_predictor_.Run();
+  raw_predictor_->Run();
 }
 std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() {
   std::lock_guard<std::mutex> lock(mutex_);
-  auto predictor = std::make_shared<lite::CxxPaddleApiImpl>();
+  auto predictor =
+      std::make_shared<lite::CxxPaddleApiImpl>(raw_predictor_->Clone());
+  status_is_cloned_ = true;
   predictor->Init(config_);
   return predictor;
 }
@@ -97,20 +103,20 @@ std::string CxxPaddleApiImpl::GetVersion() const { return version(); }
 std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
     const std::string &name) const {
-  auto *x = raw_predictor_.GetTensor(name);
+  auto *x = raw_predictor_->GetTensor(name);
   return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
 }
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInputByName(
     const std::string &name) {
   return std::unique_ptr<lite_api::Tensor>(
-      new lite_api::Tensor(raw_predictor_.GetInputByName(name)));
+      new lite_api::Tensor(raw_predictor_->GetInputByName(name)));
 }
 void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir,
                                           lite_api::LiteModelType model_type,
                                           bool record_info) {
-  raw_predictor_.SaveModel(model_dir, model_type, record_info);
+  raw_predictor_->SaveModel(model_dir, model_type, record_info);
 }
 }  // namespace lite
......
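The same flow is available through the public lite_api layer; a hedged sketch, assuming the standard CxxConfig and CreatePaddlePredictor entry points from lite/api/paddle_api.h, which are not part of this diff:

```cpp
// Sketch only: cloning via the public API. CxxPaddleApiImpl::Clone() wraps
// raw_predictor_->Clone(), so the copy is rebuilt from the in-memory program
// desc instead of being re-loaded from disk.
#include "lite/api/paddle_api.h"

std::shared_ptr<paddle::lite_api::PaddlePredictor> CloneThroughApi(
    const std::string& model_dir) {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir(model_dir);
  config.set_valid_places(
      {paddle::lite_api::Place{TARGET(kX86), PRECISION(kFloat)}});

  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
  return predictor->Clone();
}
```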
@@ -53,6 +53,44 @@ TEST(CXXApi, save_model) {
                       lite_api::LiteModelType::kNaiveBuffer);
 }
+TEST(CXXApi, clone_predictor) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kX86), PRECISION(kFloat)}});
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
+  auto cloned_predictor = predictor.Clone();
+  // primary predictor
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(std::vector<int64_t>({1, 100}));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < 100; i++) {
+    data[i] = 1;
+  }
+  predictor.Run();
+  auto* output_tensor = predictor.GetOutput(0);
+  auto output_shape = output_tensor->dims().Vectorize();
+  ASSERT_EQ(output_shape.size(), 2);
+  ASSERT_EQ(output_shape[0], 1);
+  ASSERT_EQ(output_shape[1], 500);
+  // cloned predictor
+  auto* cloned_input_tensor = cloned_predictor->GetInput(0);
+  cloned_input_tensor->Resize(std::vector<int64_t>({1, 100}));
+  auto* cloned_data = cloned_input_tensor->mutable_data<float>();
+  for (int i = 0; i < 100; i++) {
+    cloned_data[i] = 1;
+  }
+  cloned_predictor->Run();
+  auto* cloned_output_tensor = cloned_predictor->GetOutput(0);
+  int step = 50;
+  for (int i = 0; i < output_tensor->data_size(); i += step) {
+    EXPECT_NEAR(output_tensor->data<float>()[i],
+                cloned_output_tensor->data<float>()[i],
+                1e-6);
+  }
+}
 /*TEST(CXXTrainer, train) {
   Place place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)});
   std::vector<Place> valid_places({place});
......
@@ -37,6 +37,20 @@ namespace lite {
 */
 class Optimizer {
  public:
+  Optimizer() {}
+  Optimizer(Program* program, const std::vector<Place>& valid_places) {
+    program_ = program;
+    valid_places_ = valid_places;
+    CHECK(!valid_places.empty()) << "At least one valid_place should be set";
+    CHECK(!graph_) << "duplicate optimize found";
+    graph_.reset(new mir::SSAGraph);
+    graph_->Build(*program, valid_places);
+    graph_->SetValidPlaces(valid_places);
+    exec_scope_ = program->exec_scope();
+  }
   void Run(Program&& program,
            const std::vector<Place>& valid_places,
            core::KernelPickFactor kernel_pick_factor,
......
@@ -13,11 +13,16 @@
 // limitations under the License.
 #include "lite/core/scope.h"
+#define SCOPE_KIDS_READER_LOCK lite::fluid::AutoRDLock auto_lock(&kids_lock_);
+#define SCOPE_KIDS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(&kids_lock_);
+#define SCOPE_VARS_READER_LOCK lite::fluid::AutoRDLock auto_lock(&vars_lock_);
+#define SCOPE_VARS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(&vars_lock_);
 namespace paddle {
 namespace lite {
 Scope::~Scope() {
+  SCOPE_KIDS_WRITER_LOCK
   for (auto *x : kids_) {
     if (x) {
       delete x;
@@ -26,12 +31,14 @@ Scope::~Scope() {
 }
 Scope &Scope::NewScope() const {
+  SCOPE_KIDS_WRITER_LOCK
   kids_.push_back(new Scope);
   kids_.back()->parent_ = this;
   return *kids_.back();
 }
 Variable *Scope::Var(const std::string &name) {
+  SCOPE_VARS_WRITER_LOCK
   auto *var = FindVar(name);
   if (var) return var;
@@ -45,6 +52,7 @@ Variable *Scope::FindVar(const std::string &name) const {
   var = FindLocalVar(name);
   const Scope *cur_scope = this;
   while (!var && cur_scope->parent()) {
+    // SCOPE_VARS_READER_LOCK
     cur_scope = cur_scope->parent();
     var = cur_scope->FindLocalVar(name);
   }
@@ -53,6 +61,7 @@ Variable *Scope::FindVar(const std::string &name) const {
 }
 Variable *Scope::FindLocalVar(const std::string &name) const {
+  // SCOPE_VARS_READER_LOCK
   auto it = vars_.find(name);
   if (it != vars_.end()) {
     return it->second.get();
@@ -62,8 +71,11 @@ Variable *Scope::FindLocalVar(const std::string &name) const {
 std::vector<std::string> Scope::LocalVarNames() const {
   std::vector<std::string> keys;
-  for (const auto &item : vars_) {
-    keys.push_back(item.first);
+  {
+    // SCOPE_VARS_READER_LOCK
+    for (const auto &item : vars_) {
+      keys.push_back(item.first);
+    }
   }
   return keys;
 }
......
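The SCOPE_*_LOCK macros expand to RAII guards from lite/fluid/rw_lock.h, so each lock is held for exactly the enclosing block; the commented-out reader locks mark read paths left unguarded in this commit. A sketch of the guard pattern the macros assume; the pthread-based implementation below is an assumption for illustration, not the actual lite/fluid code:

```cpp
// Sketch only: a reader/writer lock plus an RAII writer guard in the style
// of lite::fluid::RWLock and lite::fluid::AutoWRLock.
#include <pthread.h>

class RWLock {
 public:
  RWLock() { pthread_rwlock_init(&lock_, nullptr); }
  ~RWLock() { pthread_rwlock_destroy(&lock_); }
  void RDLock() { pthread_rwlock_rdlock(&lock_); }
  void WRLock() { pthread_rwlock_wrlock(&lock_); }
  void UNLock() { pthread_rwlock_unlock(&lock_); }

 private:
  pthread_rwlock_t lock_;
};

// What SCOPE_KIDS_WRITER_LOCK effectively declares: a local guard named
// auto_lock that takes the write lock on construction and releases it when
// the surrounding scope exits.
class AutoWRLock {
 public:
  explicit AutoWRLock(RWLock* lock) : lock_(lock) { lock_->WRLock(); }
  ~AutoWRLock() { lock_->UNLock(); }

 private:
  RWLock* lock_;
};
```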
@@ -20,6 +20,7 @@
 #include <utility>
 #include <vector>
 #include "lite/core/variable.h"
+#include "lite/fluid/rw_lock.h"
 namespace paddle {
 namespace lite {
@@ -73,6 +74,8 @@ class Scope final {
   mutable std::list<Scope*> kids_;
   const Scope* parent_{nullptr};
   std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+  mutable lite::fluid::RWLock kids_lock_;
+  mutable lite::fluid::RWLock vars_lock_;
 };
 }  // namespace lite
......
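These locks matter precisely because of the Clone() change: cloned predictors share one root Scope, so variable creation and child-scope creation can now happen concurrently. A minimal sketch of the access pattern the new members make safe; the thread bodies are illustrative:

```cpp
// Sketch only: concurrent use of a shared root scope, as cloned predictors
// may do. Var() takes SCOPE_VARS_WRITER_LOCK and NewScope() takes
// SCOPE_KIDS_WRITER_LOCK, so concurrent writers are serialized.
#include <thread>
#include "lite/core/scope.h"

void ConcurrentScopeUse(paddle::lite::Scope* root) {
  // Two writers on vars_: serialized by SCOPE_VARS_WRITER_LOCK.
  std::thread t1([root] { root->Var("var_from_thread_one"); });
  std::thread t2([root] { root->Var("var_from_thread_two"); });
  // A writer on kids_: serialized by SCOPE_KIDS_WRITER_LOCK.
  std::thread t3([root] { root->NewScope(); });
  t1.join();
  t2.join();
  t3.join();
}
```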