提交 3c8dab9b 编写于 作者: D DannyIsFunny

test=develop

上级 b66052dd
...@@ -30,14 +30,14 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -30,14 +30,14 @@ void Predictor::SaveModel(const std::string &dir,
if (!program_) { if (!program_) {
GenRuntimeProgram(); GenRuntimeProgram();
} }
program_->SaveOpInfosToProgram(&program_desc_); program_->SaveOpInfosToProgram(program_desc_.get());
program_->UpdateVarsOfProgram(&program_desc_); program_->UpdateVarsOfProgram(program_desc_.get());
switch (model_type) { switch (model_type) {
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
SaveModelPb(dir, *program_->exec_scope(), program_desc_, true); SaveModelPb(dir, *program_->exec_scope(), *program_desc_.get(), true);
break; break;
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer:
SaveModelNaive(dir, *program_->exec_scope(), program_desc_); SaveModelNaive(dir, *program_->exec_scope(), *program_desc_.get());
break; break;
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
...@@ -226,7 +226,7 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const { ...@@ -226,7 +226,7 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
#endif #endif
const cpp::ProgramDesc &Predictor::program_desc() const { const cpp::ProgramDesc &Predictor::program_desc() const {
return program_desc_; return *program_desc_.get();
} }
const RuntimeProgram &Predictor::runtime_program() const { return *program_; } const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
...@@ -269,14 +269,14 @@ void Predictor::Build(const std::string &model_path, ...@@ -269,14 +269,14 @@ void Predictor::Build(const std::string &model_path,
model_file, model_file,
param_file, param_file,
scope_.get(), scope_.get(),
&program_desc_, program_desc_.get(),
combined_param, combined_param,
model_from_memory); model_from_memory);
} break; } break;
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer:
CHECK(!model_path.empty()) CHECK(!model_path.empty())
<< "NaiveBuffer backend only supported combined param"; << "NaiveBuffer backend only supported combined param";
LoadModelNaiveFromFile(model_path, scope_.get(), &program_desc_); LoadModelNaiveFromFile(model_path, scope_.get(), program_desc_.get());
break; break;
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
...@@ -284,7 +284,7 @@ void Predictor::Build(const std::string &model_path, ...@@ -284,7 +284,7 @@ void Predictor::Build(const std::string &model_path,
Build(program_desc_, valid_places, passes); Build(program_desc_, valid_places, passes);
} }
void Predictor::Build(const cpp::ProgramDesc &desc, void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
const std::vector<Place> &valid_places, const std::vector<Place> &valid_places,
const std::vector<std::string> &passes) { const std::vector<std::string> &passes) {
program_desc_ = desc; program_desc_ = desc;
...@@ -301,9 +301,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -301,9 +301,9 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
"fake_dequantize_max_abs", "fake_dequantize_max_abs",
"fake_channel_wise_dequantize_max_abs"}; "fake_channel_wise_dequantize_max_abs"};
bool is_quantized_model = false; bool is_quantized_model = false;
for (size_t i = 0; i < program_desc_.BlocksSize() && !is_quantized_model; for (size_t i = 0; i < program_desc_->BlocksSize() && !is_quantized_model;
++i) { ++i) {
auto *block_desc = program_desc_.GetBlock<cpp::BlockDesc>(i); auto *block_desc = program_desc_->GetBlock<cpp::BlockDesc>(i);
for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) { for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) {
auto *op_desc = block_desc->GetOp<cpp::OpDesc>(j); auto *op_desc = block_desc->GetOp<cpp::OpDesc>(j);
std::string op_type = op_desc->Type(); std::string op_type = op_desc->Type();
...@@ -318,7 +318,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -318,7 +318,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
} }
Program program(desc, scope_, inner_places); Program program(*desc.get(), scope_, inner_places);
valid_places_ = inner_places; valid_places_ = inner_places;
core::KernelPickFactor factor; core::KernelPickFactor factor;
......
...@@ -42,16 +42,19 @@ static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list"; ...@@ -42,16 +42,19 @@ static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
class LITE_API Predictor { class LITE_API Predictor {
public: public:
// Create an empty predictor. // Create an empty predictor.
Predictor() { scope_ = std::make_shared<Scope>(); } Predictor() {
scope_ = std::make_shared<Scope>();
program_desc_ = std::make_shared<cpp::ProgramDesc>();
}
// Create a predictor with the weight variable scope set. // Create a predictor with the weight variable scope set.
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
Predictor(const cpp::ProgramDesc& desc, Predictor(const std::shared_ptr<cpp::ProgramDesc>& desc,
const std::shared_ptr<Scope>& root, const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places) const std::vector<Place>& valid_places)
: program_desc_(desc), scope_(root) { : program_desc_(desc), scope_(root) {
optimizer_ = Program program(*desc.get(), scope_, valid_places);
Optimizer(new Program(desc, scope_, valid_places), valid_places); optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope(); exec_scope_ = optimizer_.exec_scope();
GenRuntimeProgram(); GenRuntimeProgram();
valid_places_ = valid_places; valid_places_ = valid_places;
...@@ -74,7 +77,7 @@ class LITE_API Predictor { ...@@ -74,7 +77,7 @@ class LITE_API Predictor {
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf, lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool memory_from_memory = false); bool memory_from_memory = false);
void Build(const cpp::ProgramDesc& desc, void Build(const std::shared_ptr<cpp::ProgramDesc>& desc,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}); const std::vector<std::string>& passes = {});
...@@ -133,7 +136,7 @@ class LITE_API Predictor { ...@@ -133,7 +136,7 @@ class LITE_API Predictor {
private: private:
Optimizer optimizer_; Optimizer optimizer_;
cpp::ProgramDesc program_desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
const Scope* exec_scope_; const Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/generate_program_pass.h" #include "lite/core/mir/generate_program_pass.h"
#include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_manager.h"
...@@ -39,16 +40,18 @@ class Optimizer { ...@@ -39,16 +40,18 @@ class Optimizer {
public: public:
Optimizer() {} Optimizer() {}
Optimizer(Program* program, const std::vector<Place>& valid_places) { Optimizer(Program&& program, const std::vector<Place>& valid_places) {
program_ = program; program_ = &program;
valid_places_ = valid_places; valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); core::KernelPickFactor factor;
graph_->Build(*program, valid_places); factor.ConsiderTarget();
graph_->SetValidPlaces(valid_places); factor.ConsiderPrecision();
exec_scope_ = program->exec_scope(); factor.ConsiderDataLayout();
Run(std::move(program), valid_places, factor, {});
} }
void Run(Program&& program, void Run(Program&& program,
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
#include "lite/core/scope.h" #include "lite/core/scope.h"
#define SCOPE_KIDS_READER_LOCK lite::fluid::AutoRDLock auto_lock(&kids_lock_); #define SCOPE_KIDS_READER_LOCK lite::fluid::AutoRDLock auto_lock(kids_lock_);
#define SCOPE_KIDS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(&kids_lock_); #define SCOPE_KIDS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(kids_lock_);
#define SCOPE_VARS_READER_LOCK lite::fluid::AutoRDLock auto_lock(&vars_lock_); #define SCOPE_VARS_READER_LOCK lite::fluid::AutoRDLock auto_lock(vars_lock_);
#define SCOPE_VARS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(&vars_lock_); #define SCOPE_VARS_WRITER_LOCK lite::fluid::AutoWRLock auto_lock(vars_lock_);
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -41,7 +41,6 @@ Variable *Scope::Var(const std::string &name) { ...@@ -41,7 +41,6 @@ Variable *Scope::Var(const std::string &name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
auto *var = FindVar(name); auto *var = FindVar(name);
if (var) return var; if (var) return var;
// create a new variable. // create a new variable.
vars_.emplace(name, std::unique_ptr<Variable>(new Variable)); vars_.emplace(name, std::unique_ptr<Variable>(new Variable));
return vars_[name].get(); return vars_[name].get();
...@@ -51,31 +50,34 @@ Variable *Scope::FindVar(const std::string &name) const { ...@@ -51,31 +50,34 @@ Variable *Scope::FindVar(const std::string &name) const {
Variable *var{nullptr}; Variable *var{nullptr};
var = FindLocalVar(name); var = FindLocalVar(name);
const Scope *cur_scope = this; const Scope *cur_scope = this;
rwlock_->RDLock();
while (!var && cur_scope->parent()) { while (!var && cur_scope->parent()) {
// SCOPE_VARS_READER_LOCK
cur_scope = cur_scope->parent(); cur_scope = cur_scope->parent();
var = cur_scope->FindLocalVar(name); var = cur_scope->FindLocalVar(name);
} }
rwlock_->UNLock();
return var; return var;
} }
Variable *Scope::FindLocalVar(const std::string &name) const { Variable *Scope::FindLocalVar(const std::string &name) const {
// SCOPE_VARS_READER_LOCK rwlock_->RDLock();
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
rwlock_->UNLock();
return it->second.get(); return it->second.get();
} }
rwlock_->UNLock();
return nullptr; return nullptr;
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> keys; std::vector<std::string> keys;
{ {
// SCOPE_VARS_READER_LOCK rwlock_->RDLock();
for (const auto &item : vars_) { for (const auto &item : vars_) {
keys.push_back(item.first); keys.push_back(item.first);
} }
rwlock_->UNLock();
} }
return keys; return keys;
} }
......
...@@ -27,7 +27,11 @@ namespace lite { ...@@ -27,7 +27,11 @@ namespace lite {
class Scope final { class Scope final {
public: public:
Scope() {} Scope() {
kids_lock_ = new lite::fluid::RWLock;
vars_lock_ = new lite::fluid::RWLock;
rwlock_.reset(new lite::fluid::RWLock);
}
// delete below two functions to allow pybind to recognise it cannot make a // delete below two functions to allow pybind to recognise it cannot make a
// copy // copy
// link: // link:
...@@ -74,8 +78,9 @@ class Scope final { ...@@ -74,8 +78,9 @@ class Scope final {
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr}; const Scope* parent_{nullptr};
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
mutable lite::fluid::RWLock kids_lock_; lite::fluid::RWLock* kids_lock_{nullptr};
mutable lite::fluid::RWLock vars_lock_; lite::fluid::RWLock* vars_lock_{nullptr};
std::unique_ptr<lite::fluid::RWLock> rwlock_{nullptr};
}; };
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册