未验证 提交 da130862 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework] Remove program_desc from Program and Update Clone method (#3976)

上级 7fedd6f8
......@@ -49,18 +49,33 @@ class LITE_API Predictor {
program_desc_ = std::make_shared<cpp::ProgramDesc>();
}
// Create a predictor with the weight variable scope set.
///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. Create a predictor with the
// weight variable scope set given.
///////////////////////////////////////////////////////////////////
explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
: scope_(root_scope) {}
///////////////////////////////////////////////////////////////////
// Function: Predictor
// Usage: Constructor of Predictor. This constructor function can
// only be called in Predictor->Clone. This Function will create
// a predictor from existed ProgramDesc, Scope and RuntimeProgram.
///////////////////////////////////////////////////////////////////
Predictor(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {})
: program_desc_(program_desc), scope_(root_scope) {
Program program(program_desc_, scope_, valid_places, vars_to_clone);
optimizer_ = Optimizer(std::move(program), valid_places);
exec_scope_ = optimizer_.exec_scope();
const std::vector<std::string>& var_names = {})
: program_desc_(program_desc), scope_(root) {
// step1. Create a Program to construct the exec_scope and ops
Program program(program_desc_, scope_, valid_places, var_names);
exec_scope_ = program.exec_scope();
valid_places_ = valid_places;
// step3. Create the RuntimeProgram.
program_.reset(
new RuntimeProgram(program_desc_, exec_scope_, kRootBlockIdx));
program_generated_ = true;
}
// Build from a model, with places set for hardware config.
......@@ -83,26 +98,58 @@ class LITE_API Predictor {
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {});
std::shared_ptr<Predictor> Clone() const {
return std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
//////////////////////////////////////////////////////////
// Function: Clone
// Usage: Create a Predictor from an existed one,
// the cloned predictor will share persistable variables
// in scope_ with the original predictor.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone() {
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor friom current program_desc_ and
// runtime_program.
auto predictor =
std::make_shared<Predictor>(program_desc_, scope_, valid_places_);
// step3. Return the result
return predictor;
}
std::shared_ptr<Predictor> Clone(
const std::vector<std::string>& vars_to_clone) const {
//////////////////////////////////////////////////////////
// Function: Clone(var_names)
// Usage: Create a Predictor from an existed one,
// the cloned predictor will share persistable variables
// but persistable variables of name var_names will not
// be shared.
//////////////////////////////////////////////////////////
std::shared_ptr<Predictor> Clone(const std::vector<std::string>& var_names) {
CHECK(program_desc_) << "Both program and scope of current predicotr "
"should be not be nullptr in Clone mode.";
CHECK(scope_) << "Both program and scope of current predicotr should be "
"not be nullptr in Clone mode.";
// step 1. Generate runtime_program, update op_info and var_info in
// program_desc_
if (!program_generated_) {
GenRuntimeProgram();
}
program_->SaveToProgram(program_desc_);
// step 2. Create a predictor friom current program_desc_ and
// runtime_program.
auto predictor = std::make_shared<Predictor>(
program_desc_, scope_, valid_places_, vars_to_clone);
for (auto var_name : vars_to_clone) {
program_desc_, scope_, valid_places_, var_names);
// step3. Copy some persistable variables into private scope.
for (auto var_name : var_names) {
predictor->exec_scope_->LocalVar(var_name);
auto* tensor = predictor->scope_->Var(var_name)->GetMutable<Tensor>();
auto* tensor =
predictor->scope_->Var(var_name)->GetMutable<lite::Tensor>();
auto* sub_tensor =
predictor->exec_scope_->Var(var_name)->GetMutable<Tensor>();
sub_tensor->CopyDataFrom(*tensor);
}
// step4. Return the result
return predictor;
}
......@@ -161,7 +208,7 @@ class LITE_API Predictor {
std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_;
Scope* exec_scope_;
std::unique_ptr<RuntimeProgram> program_;
std::shared_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
......
......@@ -47,21 +47,18 @@ struct Program {
Program(const std::shared_ptr<cpp::ProgramDesc>& program_desc,
const std::shared_ptr<Scope>& root_scope,
const std::vector<Place>& valid_places,
const std::vector<std::string>& vars_to_clone = {})
: scope_(root_scope),
valid_places_(valid_places),
program_desc_(program_desc) {
const std::vector<std::string>& var_names = {})
: scope_(root_scope), valid_places_(valid_places) {
CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work";
PrepareWorkspace(program_desc_, vars_to_clone);
PrepareWorkspace(program_desc, var_names);
VLOG(4) << "build desc";
Build(program_desc_);
Build(program_desc);
VLOG(4) << "build desc finished";
}
std::unique_ptr<Program> Clone() const {
return std::unique_ptr<Program>(
new Program(program_desc_, scope_, valid_places_));
return std::unique_ptr<Program>(new Program(scope_));
}
const std::list<std::string>& weights() const { return weights_; }
......@@ -83,8 +80,6 @@ struct Program {
Scope* exec_scope() { return exec_scope_; }
Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return program_desc_.get(); }
const std::map<std::string, const Type*>& var_type_map() const {
return var_type_map_;
}
......@@ -106,7 +101,6 @@ struct Program {
std::vector<Place> valid_places_;
// Runtime scope.
Scope* exec_scope_{};
std::shared_ptr<cpp::ProgramDesc> program_desc_;
};
struct Instruction {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册