Commit 64f10504 authored by superjomn

add program

Parent commit: 18959c09
...@@ -24,16 +24,17 @@ struct Config {}; ...@@ -24,16 +24,17 @@ struct Config {};
class Predictor { class Predictor {
public: public:
Predictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_path, void Build(const std::string& model_path,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
CHECK(!executor_.get()) << "duplicate build found"; CHECK(!executor_.get()) << "duplicate build found";
CHECK(!scope_.get()) << "duplicate build found";
framework::proto::ProgramDesc prog; framework::proto::ProgramDesc prog;
LoadModel(model_path, &scope_, &prog); LoadModel(model_path, scope_.get(), &prog);
framework::ProgramDesc prog_desc(prog); framework::ProgramDesc prog_desc(prog);
executor_.reset(new Executor(&scope_, valid_places)); executor_.reset(new Executor(prog_desc, scope_.get(), valid_places));
executor_->PrepareWorkspace(prog_desc);
executor_->Build(prog_desc);
} }
// Get a tensor for input from scope directly. // Get a tensor for input from scope directly.
...@@ -53,7 +54,7 @@ class Predictor { ...@@ -53,7 +54,7 @@ class Predictor {
void Run() { executor_->Run(); } void Run() { executor_->Run(); }
private: private:
Scope scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<lite::Executor> executor_; std::unique_ptr<lite::Executor> executor_;
}; };
......
...@@ -26,8 +26,7 @@ TEST(CXXApi, raw) { ...@@ -26,8 +26,7 @@ TEST(CXXApi, raw) {
LoadModel("/home/chunwei/project2/models/model2", &scope, &prog); LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
framework::ProgramDesc prog_desc(prog); framework::ProgramDesc prog_desc(prog);
lite::Executor executor(&scope, lite::Executor executor(&scope, {Place{TARGET(kHost), PRECISION(kFloat)}});
{OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
auto x = scope.Var("a")->GetMutable<Tensor>(); auto x = scope.Var("a")->GetMutable<Tensor>();
x->Resize({100, 100}); x->Resize({100, 100});
...@@ -41,7 +40,7 @@ TEST(CXXApi, raw) { ...@@ -41,7 +40,7 @@ TEST(CXXApi, raw) {
TEST(CXXApi, test) { TEST(CXXApi, test) {
lite::Predictor predictor; lite::Predictor predictor;
predictor.Build("/home/chunwei/project2/models/model2", predictor.Build("/home/chunwei/project2/models/model2",
{OpLite::Place{TARGET(kHost), PRECISION(kFloat)}}); {Place{TARGET(kHost), PRECISION(kFloat)}});
auto* x = predictor.GetInputTensor("a"); auto* x = predictor.GetInputTensor("a");
x->Resize({100, 200}); x->Resize({100, 200});
x->mutable_data<float>(); x->mutable_data<float>();
......
...@@ -18,6 +18,7 @@ cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph ...@@ -18,6 +18,7 @@ cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
ops_lite ops_lite
host_kernels host_kernels
) )
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
......
...@@ -22,22 +22,12 @@ ...@@ -22,22 +22,12 @@
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
lite::Scope *scope{};
};
// An Graph for MIR. It is built from a list of Op and a scope. // An Graph for MIR. It is built from a list of Op and a scope.
class GraphBase {}; class GraphBase {};
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/scope.h"
namespace paddle { namespace paddle {
...@@ -24,41 +25,16 @@ namespace lite { ...@@ -24,41 +25,16 @@ namespace lite {
// The Executor is used to run the operators. // The Executor is used to run the operators.
class Executor { class Executor {
public: public:
Executor(lite::Scope* scope, const std::vector<Place>& valid_places) Executor(const framework::ProgramDesc& desc,
: scope_(scope), valid_places_(valid_places) {} const std::shared_ptr<lite::Scope>& scope,
const std::vector<Place>& valid_places)
// Create temporary variables. : valid_places_(valid_places) {
void PrepareWorkspace(framework::ProgramDesc& program) { program_.reset(new Program(desc, scope, valid_places));
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
auto* var = exec_scope_->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
}
}
}
// Build from a program and scope.
void Build(framework::ProgramDesc& program) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type();
if (op_type == "feed" || op_type == "fetch") continue;
LOG(INFO) << "create Op [" << op_type << "]";
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops_.back()->PickKernel(valid_places_);
ops_.back()->Attach(*op_desc, exec_scope_);
}
} }
// Run the program. // Run the program.
void Run() { void Run() {
for (auto& op : ops_) { for (auto& op : program_->ops) {
LOG(INFO) << op->DebugString(); LOG(INFO) << op->DebugString();
// TODO(Superjomn) check only once // TODO(Superjomn) check only once
op->CheckShape(); op->CheckShape();
...@@ -67,14 +43,11 @@ class Executor { ...@@ -67,14 +43,11 @@ class Executor {
} }
} }
lite::Scope* scope() { return scope_; } const Program& program() const { return *program_; }
lite::Scope* exec_scope() { return exec_scope_; }
private: private:
std::vector<std::unique_ptr<OpLite>> ops_;
lite::Scope* scope_{};
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; std::unique_ptr<Program> program_;
}; };
} // namespace lite } // namespace lite
......
...@@ -27,7 +27,7 @@ namespace lite { ...@@ -27,7 +27,7 @@ namespace lite {
*/ */
class Optimizer { class Optimizer {
public: public:
void Run(mir::Program&& program, const std::vector<Place>& valid_places, void Run(Program&& program, const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}) { const std::vector<std::string>& passes = {}) {
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
...@@ -36,8 +36,8 @@ class Optimizer { ...@@ -36,8 +36,8 @@ class Optimizer {
} }
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<mir::Program> GenProgram() { std::unique_ptr<Program> GenProgram() {
std::unique_ptr<mir::Program> res; std::unique_ptr<Program> res;
return res; return res;
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/program.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
// A program is used to represent a code program. In Paddle, a program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
  // Names of the non-persistable (temporary) variables of the main block.
  // NOTE(review): neither list is populated by this struct itself — presumably
  // filled in by callers or later passes; confirm against the call sites.
  std::list<std::string> tmp_vars;
  // Names of the persistable variables (weights).
  std::list<std::string> weights;
  // The operators of the main block, in execution order. Shared so that a
  // Clone() can reuse the already-built (and attached) operators.
  std::list<std::shared_ptr<OpLite>> ops;
  // the scope to run the kernels, NOTE not the root scope.
  std::shared_ptr<lite::Scope> scope;
  // Runtime scope: a child of `scope` holding the temporary variables.
  // Owned by `scope` (created via Scope::NewScope), so no delete here.
  lite::Scope* exec_scope{};

  // Construct an empty program that only carries the root scope; no ops are
  // created and exec_scope stays null. Used by Clone().
  explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }

  // Build a runnable program from a ProgramDesc.
  // @param desc          the deserialized model description (block 0 is used).
  // @param root          the root scope holding the weights.
  // @param valid_places  the places the kernels are allowed to run on.
  // NOTE the order matters: PrepareWorkspace must run first so that
  // exec_scope exists before the ops are attached to it in Build.
  Program(const framework::ProgramDesc& desc,
          const std::shared_ptr<Scope>& root,
          const std::vector<Place>& valid_places) {
    scope = root;
    PrepareWorkspace(desc);
    Build(desc, valid_places);
  }

  // Shallow copy: shares the root scope and the OpLite instances.
  // NOTE(review): the clone's exec_scope is left null and its ops remain
  // attached to *this* program's exec_scope — confirm that is intended.
  std::unique_ptr<Program> Clone() const {
    std::unique_ptr<Program> res(new Program(scope));
    res->tmp_vars = tmp_vars;
    res->weights = weights;
    res->ops = ops;
    return res;
  }

 private:
  // Build from a program and scope: instantiate one OpLite per op in block 0
  // (feed/fetch are skipped), pick an initial kernel among valid_places, and
  // attach each op to its inputs/outputs in exec_scope.
  void Build(const framework::ProgramDesc& program,
             const std::vector<Place>& valid_places) {
    // Stale message fixed: this check used to live in Executor::Build.
    CHECK(ops.empty()) << "Program duplicate Build found";
    // Create operators.
    for (auto* op_desc : program.Block(0).AllOps()) {
      auto op_type = op_desc->Type();
      if (op_type == "feed" || op_type == "fetch") continue;
      LOG(INFO) << "create Op [" << op_type << "]";
      ops.emplace_back(LiteOpRegistry::Global().Create(op_type));
      // pick initial kernel
      ops.back()->PickKernel(valid_places);
      ops.back()->Attach(*op_desc, exec_scope);
    }
  }

  // Create temporary variables: make exec_scope a child of the root scope and
  // declare every non-persistable variable of block 0 in it. Persistable
  // variables (weights) are expected to already live in the root scope.
  void PrepareWorkspace(const framework::ProgramDesc& program) {
    CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
    exec_scope = &scope->NewScope();

    for (auto var_desc : program.Block(0).AllVars()) {
      if (!var_desc->Persistable()) {
        auto* var = exec_scope->Var(var_desc->Name());
        LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
      }
    }
  }
};
} // namespace lite
} // namespace paddle
...@@ -20,9 +20,8 @@ ...@@ -20,9 +20,8 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
mir::Program FakeProgram() { Program FakeProgram() {
mir::Program program; Program program(std::make_shared<lite::Scope>());
program.scope = new lite::Scope;
auto add_fc = [&](int id, std::string x) { auto add_fc = [&](int id, std::string x) {
// create variables // create variables
...@@ -48,7 +47,7 @@ mir::Program FakeProgram() { ...@@ -48,7 +47,7 @@ mir::Program FakeProgram() {
auto fc_op = LiteOpRegistry::Global().Create("fc"); auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope); fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op)); program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100}); w1v->Resize({100, 100});
......
...@@ -39,6 +39,15 @@ class Scope final { ...@@ -39,6 +39,15 @@ class Scope final {
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
// Following the legacy scope interface.
// Following the legacy scope interface: return the names of all variables
// declared directly in this scope (parent scopes are not included).
std::vector<std::string> LocalVarNames() const {
  std::vector<std::string> keys;
  keys.reserve(vars_.size());  // one allocation instead of repeated growth
  for (const auto& item : vars_) {
    keys.push_back(item.first);
  }
  return keys;
}
private: private:
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
......
Markdown is supported.
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.