提交 64f10504 编写于 作者: S superjomn

add program

上级 18959c09
......@@ -24,16 +24,17 @@ struct Config {};
class Predictor {
public:
Predictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_path,
const std::vector<Place>& valid_places) {
CHECK(!executor_.get()) << "duplicate build found";
CHECK(!scope_.get()) << "duplicate build found";
framework::proto::ProgramDesc prog;
LoadModel(model_path, &scope_, &prog);
LoadModel(model_path, scope_.get(), &prog);
framework::ProgramDesc prog_desc(prog);
executor_.reset(new Executor(&scope_, valid_places));
executor_->PrepareWorkspace(prog_desc);
executor_->Build(prog_desc);
executor_.reset(new Executor(prog_desc, scope_.get(), valid_places));
}
// Get a tensor for input from scope directly.
......@@ -53,7 +54,7 @@ class Predictor {
void Run() { executor_->Run(); }
private:
Scope scope_;
std::shared_ptr<Scope> scope_;
std::unique_ptr<lite::Executor> executor_;
};
......
......@@ -26,8 +26,7 @@ TEST(CXXApi, raw) {
LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
framework::ProgramDesc prog_desc(prog);
lite::Executor executor(&scope,
{OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
lite::Executor executor(&scope, {Place{TARGET(kHost), PRECISION(kFloat)}});
auto x = scope.Var("a")->GetMutable<Tensor>();
x->Resize({100, 100});
......@@ -41,7 +40,7 @@ TEST(CXXApi, raw) {
TEST(CXXApi, test) {
lite::Predictor predictor;
predictor.Build("/home/chunwei/project2/models/model2",
{OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
{Place{TARGET(kHost), PRECISION(kFloat)}});
auto* x = predictor.GetInputTensor("a");
x->Resize({100, 200});
x->mutable_data<float>();
......
......@@ -18,6 +18,7 @@ cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
ops_lite
host_kernels
)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
......
......@@ -22,22 +22,12 @@
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
lite::Scope *scope{};
};
// An Graph for MIR. It is built from a list of Op and a scope.
class GraphBase {};
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle {
......@@ -24,41 +25,16 @@ namespace lite {
// The Executor is used to run the operators.
class Executor {
public:
Executor(lite::Scope* scope, const std::vector<Place>& valid_places)
: scope_(scope), valid_places_(valid_places) {}
// Create temporary variables.
void PrepareWorkspace(framework::ProgramDesc& program) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
auto* var = exec_scope_->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
}
}
}
// Build from a program and scope.
void Build(framework::ProgramDesc& program) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type();
if (op_type == "feed" || op_type == "fetch") continue;
LOG(INFO) << "create Op [" << op_type << "]";
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops_.back()->PickKernel(valid_places_);
ops_.back()->Attach(*op_desc, exec_scope_);
}
Executor(const framework::ProgramDesc& desc,
const std::shared_ptr<lite::Scope>& scope,
const std::vector<Place>& valid_places)
: valid_places_(valid_places) {
program_.reset(new Program(desc, scope, valid_places));
}
// Run the program.
void Run() {
for (auto& op : ops_) {
for (auto& op : program_->ops) {
LOG(INFO) << op->DebugString();
// TODO(Superjomn) check only once
op->CheckShape();
......@@ -67,14 +43,11 @@ class Executor {
}
}
lite::Scope* scope() { return scope_; }
lite::Scope* exec_scope() { return exec_scope_; }
const Program& program() const { return *program_; }
private:
std::vector<std::unique_ptr<OpLite>> ops_;
lite::Scope* scope_{};
std::vector<Place> valid_places_;
lite::Scope* exec_scope_{};
std::unique_ptr<Program> program_;
};
} // namespace lite
......
......@@ -27,7 +27,7 @@ namespace lite {
*/
class Optimizer {
public:
void Run(mir::Program&& program, const std::vector<Place>& valid_places,
void Run(Program&& program, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}) {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
......@@ -36,8 +36,8 @@ class Optimizer {
}
// Generate a new program based on the mir graph.
std::unique_ptr<mir::Program> GenProgram() {
std::unique_ptr<mir::Program> res;
std::unique_ptr<Program> GenProgram() {
std::unique_ptr<Program> res;
return res;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/program.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
// the scope to run the kernels, NOTE not the root scope.
std::shared_ptr<lite::Scope> scope;
// Runtime scope.
lite::Scope* exec_scope{};
explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
Program(const framework::ProgramDesc& desc,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places) {
scope = root;
PrepareWorkspace(desc);
Build(desc, valid_places);
}
std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(scope));
res->tmp_vars = tmp_vars;
res->weights = weights;
res->ops = ops;
return res;
}
private:
// Build from a program and scope.
void Build(const framework::ProgramDesc& program,
const std::vector<Place>& valid_places) {
CHECK(ops.empty()) << "Executor duplicate Build found";
// Create operators.
for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type();
if (op_type == "feed" || op_type == "fetch") continue;
LOG(INFO) << "create Op [" << op_type << "]";
ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops.back()->PickKernel(valid_places);
ops.back()->Attach(*op_desc, exec_scope);
}
}
// Create temporary variables.
void PrepareWorkspace(const framework::ProgramDesc& program) {
CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
exec_scope = &scope->NewScope();
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
auto* var = exec_scope->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
}
}
}
};
} // namespace lite
} // namespace paddle
......@@ -20,9 +20,8 @@
namespace paddle {
namespace lite {
mir::Program FakeProgram() {
mir::Program program;
program.scope = new lite::Scope;
Program FakeProgram() {
Program program(std::make_shared<lite::Scope>());
auto add_fc = [&](int id, std::string x) {
// create variables
......@@ -48,7 +47,7 @@ mir::Program FakeProgram() {
auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope);
fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100});
......
......@@ -39,6 +39,15 @@ class Scope final {
const Scope* parent() const { return parent_; }
// Following the legacy scope interface.
std::vector<std::string> LocalVarNames() const {
std::vector<std::string> keys;
for (const auto& item : vars_) {
keys.push_back(item.first);
}
return keys;
}
private:
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册