未验证 提交 da130862 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework] Remove program_desc from Program and Update Clone method (#3976)

上级 7fedd6f8
...@@ -49,18 +49,33 @@ class LITE_API Predictor { ...@@ -49,18 +49,33 @@ class LITE_API Predictor {
program_desc_ = std::make_shared<cpp::ProgramDesc>(); program_desc_ = std::make_shared<cpp::ProgramDesc>();
} }
// Create a predictor with the weight variable scope set. ///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. Create a predictor with the
// weight variable scope set given.
///////////////////////////////////////////////////////////////////
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope) explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {} : scope_(root_scope) {}
///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. This constructor function can
// only be called in Predictor->Clone. This Function will create
// a predictor from existed ProgramDesc, Scope and RuntimeProgram.
///////////////////////////////////////////////////////////////////
Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc, Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope, const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {}) const std::vector<std::string>& var_names = {})
: program_desc_(program_desc), scope_(root_scope) { : program_desc_(program_desc), scope_(root) {
Program program(program_desc_, scope_, valid_places, vars_to_clone); // step1. Create a Program to construct the exec_scope and ops
optimizer_ = Optimizer(std::move(program), valid_places); Program program(program_desc_, scope_, valid_places, var_names);
exec_scope_ = optimizer_.exec_scope(); exec_scope_ = program.exec_scope();
valid_places_ = valid_places; valid_places_ = valid_places;
// step3. Create the RuntimeProgram.
program_.reset(
new RuntimeProgram(program_desc_, exec_scope_, kRootBlockIdx));
program_generated_ = true;
} }
// Build from a model, with places set for hardware config. // Build from a model, with places set for hardware config.
...@@ -83,26 +98,58 @@ class LITE_API Predictor { ...@@ -83,26 +98,58 @@ class LITE_API Predictor {
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}); const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const { //////////////////////////////////////////////////////////
return std::make_shared<Predictor>(program_desc_, scope_, valid_places_); // Function: Clone
// Usage: Create a Predictor from an existed one,
// the cloned predictor will share persistable variables
// in scope_ with the original predictor.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone() {
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor friom current program_desc_ and
// runtime_program.
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
// step3. Return the result
return predictor;
} }
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone( // Function: Clone(var_names)
const std::vector<std::string>& vars_to_clone) const { // Usage: Create a Predictor from an existed one,
// the cloned predictor will share persistable variables
// but persistable variables of name var_names will not
// be shared.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone(const std::vector<std::string>& var_names) {
CHECK(program_desc_) << "Both program and scope of current predicotr " CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode."; "should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be " CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode."; "not be nullptr in Clone mode.";
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor friom current program_desc_ and
// runtime_program.
auto predictor = std::make_shared<Predictor>( auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, vars_to_clone); program_desc_, scope_, valid_places_, var_names);
// step3. Copy some persistable variables into private scope.
for (auto var_name : vars_to_clone) { for (auto var_name : var_names) {
predictor->exec_scope_->LocalVar(var_name); predictor->exec_scope_->LocalVar(var_name);
auto* tensor = predictor->scope_->Var(var_name)->GetMutable<Tensor>(); auto* tensor =
predictor->scope_->Var(var_name)->GetMutable<lite::Tensor>();
auto* sub_tensor = auto* sub_tensor =
predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>(); predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor); sub_tensor->CopyDataFrom(*tensor);
} }
// step4. Return the result
return predictor; return predictor;
} }
...@@ -161,7 +208,7 @@ class LITE_API Predictor { ...@@ -161,7 +208,7 @@ class LITE_API Predictor {
std::shared_ptr<cpp::ProgramDesc> program_desc_; std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
Scope* exec_scope_; Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_; std::shared_ptr<RuntimeProgram> program_;
bool program_generated_{false}; bool program_generated_{false};
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::vector<std::string> output_names_; std::vector<std::string> output_names_;
......
...@@ -47,21 +47,18 @@ struct Program { ...@@ -47,21 +47,18 @@ struct Program {
Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc, Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope, const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {}) const std::vector<std::string>& var_names = {})
: scope_(root_scope), : scope_(root_scope), valid_places_(valid_places) {
valid_places_(valid_places),
program_desc_(program_desc) {
CHECK(scope_) << "scope should be init first"; CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work"; VLOG(4) << "prepare work";
PrepareWorkspace(program_desc_, vars_to_clone); PrepareWorkspace(program_desc, var_names);
VLOG(4) << "build desc"; VLOG(4) << "build desc";
Build(program_desc_); Build(program_desc);
VLOG(4) << "build desc finished"; VLOG(4) << "build desc finished";
} }
std::unique_ptr<Program> Clone() const { std::unique_ptr<Program> Clone() const {
return std::unique_ptr<Program>( return std::unique_ptr<Program>(new Program(scope_));
new Program(program_desc_, scope_, valid_places_));
} }
const std::list<std::string>& weights() const { return weights_; } const std::list<std::string>& weights() const { return weights_; }
...@@ -83,8 +80,6 @@ struct Program { ...@@ -83,8 +80,6 @@ struct Program {
Scope* exec_scope() { return exec_scope_; } Scope* exec_scope() { return exec_scope_; }
Scope* scope() { return scope_.get(); } Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return program_desc_.get(); }
const std::map<std::string, const Type*>& var_type_map() const { const std::map<std::string, const Type*>& var_type_map() const {
return var_type_map_; return var_type_map_;
} }
...@@ -106,7 +101,6 @@ struct Program { ...@@ -106,7 +101,6 @@ struct Program {
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
// Runtime scope. // Runtime scope.
Scope* exec_scope_{}; Scope* exec_scope_{};
std::shared_ptr<cpp::ProgramDesc> program_desc_;
}; };
struct Instruction { struct Instruction {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册