Unverified commit a5241e1d, authored by Yan Chunwei, committed by GitHub

refine program (#17726)

Parent 01e1cdac
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/op_desc.h"
+#include <glog/logging.h>
 #include <algorithm>
 #include <functional>
 #include <mutex>  // NOLINT
 #include <string>
 #include <unordered_map>
-#include "glog/logging.h"
+#include <utility>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
......
@@ -64,7 +64,7 @@ class LightPredictor {
  private:
  void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) {
-    std::vector<Instruct> insts;
+    std::vector<Instruction> insts;
    // 1. Create op first
    Program program(prog, scope_, {});

@@ -72,7 +72,7 @@ class LightPredictor {
    // Create the kernels of the target places, and filter out the specific
    // kernel with the target alias.
-    for (auto& op : program.ops) {
+    for (auto& op : program.ops_) {
      lite::pb::OpDesc desc(op->op_info()->desc());
      auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
      std::string op_type, alias;

@@ -89,8 +89,8 @@ class LightPredictor {
      insts.emplace_back(op, std::move(*it));
    }
    program_.reset(new RuntimeProgram(std::move(insts)));
-    CHECK(program.exec_scope);
-    program_->set_exec_scope(program.exec_scope);
+    CHECK(program.exec_scope_);
+    program_->set_exec_scope(program.exec_scope_);
  }

 private:
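For context on the Instruction list built here: `insts.emplace_back(op, std::move(*it))` moves the matching kernel, a std::unique_ptr<KernelBase>, out of the candidate container instead of copying it. A self-contained illustration of that ownership transfer, using toy types rather than the real OpLite/KernelBase:

#include <cassert>
#include <memory>
#include <vector>

struct Kernel { int id; };

int main() {
  std::vector<std::unique_ptr<Kernel>> candidates;
  candidates.emplace_back(new Kernel{1});

  auto it = candidates.begin();                     // the kernel whose alias matched
  std::unique_ptr<Kernel> picked = std::move(*it);  // steal ownership, no copy
  assert(*it == nullptr);                           // moved-from slot is now empty
  assert(picked->id == 1);
}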
......
@@ -30,7 +30,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
-cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
+cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite)
 cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)

 add_subdirectory(mir)
......
@@ -41,7 +41,7 @@ class GenerateProgramPass : public ProgramPass {
  }

 private:
-  std::vector<Instruct> insts_;
+  std::vector<Instruction> insts_;
};

}  // namespace mir
......
@@ -94,7 +94,7 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
 }

 void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
-  for (const auto &name : program.tmp_vars) {
+  for (const auto &name : program.tmp_vars()) {
    CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
                                   << name;
    VLOG(5) << "create arg node " << name;

@@ -107,7 +107,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
 void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
  // create weight nodes.
-  for (const auto &name : program.weights) {
+  for (const auto &name : program.weights()) {
    CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
                                   << name;
    VLOG(5) << "create arg node " << name;

@@ -140,7 +140,7 @@ void SSAGraph::Build(const Program &program,
  GraphCreateWeightVarNodes(program);
  CHECK(CheckNodesRoleSet());
-  for (auto &op : program.ops) {
+  for (auto &op : program.ops()) {
    auto *op_node = GraphCreateInstructNode(program, op, valid_places);
    for (const std::string &name : op->op_info()->input_names()) {
      auto *arg = Argument(name);
......
@@ -77,7 +77,7 @@ class SSAGraph : GraphBase {
  bool CheckLinksRoleSet();

  void MarkArgumentWeights(const Program &program) {
-    for (const auto &name : program.weights) {
+    for (const auto &name : program.weights()) {
      arguments_[name]->AsArg().is_weight = true;
    }
  }
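With tmp_vars, weights, and ops now behind accessors, graph construction reads Program through const references only. A minimal sketch of the resulting call pattern, assembled from names that appear in this diff (not verbatim repo code):

void DumpProgram(const Program& program) {
  // Read-only views into the variable lists and the op list.
  for (const auto& name : program.weights()) LOG(INFO) << "weight: " << name;
  for (const auto& name : program.tmp_vars()) LOG(INFO) << "tmp var: " << name;
  for (const auto& op : program.ops()) {
    LOG(INFO) << "op: " << op->op_info()->Type();
  }
}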
......
@@ -147,7 +147,7 @@ class OpLite : public Registry {
 class OpInfo : public cpp::OpDesc {
 public:
  OpInfo(const OpInfo &) = default;
-  OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
+  explicit OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}

  // Collect all the input variable's name.
  std::vector<std::string> input_names() const {
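Marking the converting constructor explicit forbids silent cpp::OpDesc-to-OpInfo conversions at call sites; the construction must now be spelled out. A generic, runnable illustration of the rule with toy types:

struct Desc {};

struct Info {
  explicit Info(const Desc& d) {}  // no implicit Desc -> Info conversion
};

void Consume(const Info& info) {}

int main() {
  Desc d;
  // Consume(d);     // error once `explicit` is added: implicit conversion removed
  Consume(Info(d));  // OK: the conversion is written out
}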
......
@@ -64,7 +64,7 @@ class Optimizer {
      RunPasses(passes);
    }
#endif
-    exec_scope_ = program.exec_scope;
+    exec_scope_ = program.exec_scope();
  }

  void KernelPickPreferPlace(const Place& place) {
......
@@ -62,5 +62,45 @@ void RuntimeProgram::SaveParams(const std::string &dir,
  }
 }

+void Program::Build(const framework::proto::ProgramDesc &program) {
+  CHECK(ops_.empty()) << "Executor duplicate Build found";
+  // Create operators.
+  for (const auto &proto_op_desc : program.blocks(0).ops()) {
+    lite::OpDesc op_desc_dummy(proto_op_desc);
+    cpp::OpDesc op_desc;
+    TransformOpDescPbToCpp(op_desc_dummy, &op_desc);
+    auto op_type = op_desc.Type();
+    // if (op_type == "feed" || op_type == "fetch") continue;
+    VLOG(4) << "create Op [" << op_type << "]";
+    LOG(INFO) << "create Op [" << op_type << "]";
+    auto op = LiteOpRegistry::Global().Create(op_type);
+    CHECK(op) << "no Op found for " << op_type;
+    ops_.emplace_back(std::move(op));
+    ops_.back()->Attach(op_desc, exec_scope_);
+  }
+}
+
+void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) {
+  CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
+  exec_scope_ = &scope_->NewScope();
+  // Create Feed and Fetch var.
+  scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
+  scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
+  tmp_vars_.push_back("feed");
+  tmp_vars_.push_back("fetch");
+
+  CHECK(!program.blocks().empty());
+  for (auto proto_var_desc : program.blocks(0).vars()) {
+    lite::VarDesc var_desc(proto_var_desc);
+    if (!var_desc.Persistable()) {
+      tmp_vars_.push_back(var_desc.Name());
+      exec_scope_->Var(var_desc.Name());
+    } else {
+      if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
+      weights_.push_back(var_desc.Name());
+    }
+  }
+}
+
 }  // namespace lite
 }  // namespace paddle
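PrepareWorkspace partitions block-0 variables by persistability: non-persistable names become temporaries created in the fresh execution scope, while persistable names (other than "feed"/"fetch") are recorded as weights. A hedged sketch of how a caller would then supply an input through the "feed" slot created above; everything beyond the Var/GetMutable calls shown in the diff is an assumption:

// Assumes `scope` is the root scope handed to Program.
auto* feeds = scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
feeds->resize(1);                    // one input slot
lite::Tensor& input = feeds->at(0);
// ... fill `input` and run the program; outputs land in "fetch" ...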
@@ -37,78 +37,47 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
 // - main block, which is a list of OpLite
 // - scope: which contains all the weights
 struct Program {
-  std::list<std::string> tmp_vars;
-  std::list<std::string> weights;
-  std::list<std::shared_ptr<OpLite>> ops;
-  // the scope to run the kernels, NOTE this is the execution scope.
-  std::shared_ptr<lite::Scope> scope;
-  std::vector<Place> valid_places;
-  // Runtime scope.
-  lite::Scope* exec_scope{};
-  const framework::proto::ProgramDesc desc;
-
-  explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
+ public:
+  explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; }
  Program(const framework::proto::ProgramDesc& desc,
          const std::shared_ptr<Scope>& root,
          const std::vector<Place>& valid_places)
-      : scope(root), valid_places(valid_places), desc(desc) {
-    CHECK(scope) << "scope should be init first";
+      : scope_(root), valid_places_(valid_places), desc_(desc) {
+    CHECK(scope_) << "scope should be init first";
    PrepareWorkspace(desc);
    Build(desc);
  }

  std::unique_ptr<Program> Clone() const {
-    std::unique_ptr<Program> res(new Program(desc, scope, valid_places));
+    std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_));
    return res;
  }

+  const std::list<std::string>& weights() const { return weights_; }
+  const std::list<std::string>& tmp_vars() const { return tmp_vars_; }
+  const std::list<std::shared_ptr<OpLite>>& ops() const { return ops_; }
+  lite::Scope* exec_scope() { return exec_scope_; }
+
 private:
  // Build from a program and scope.
-  void Build(const framework::proto::ProgramDesc& program) {
-    CHECK(ops.empty()) << "Executor duplicate Build found";
-    // Create operators.
-    for (const auto& proto_op_desc : program.blocks(0).ops()) {
-      pb::OpDesc op_desc(proto_op_desc);
-      auto op_type = op_desc.Type();
-      // if (op_type == "feed" || op_type == "fetch") continue;
-      VLOG(4) << "create Op [" << op_type << "]";
-      LOG(INFO) << "create Op [" << op_type << "]";
-      auto op = LiteOpRegistry::Global().Create(op_type);
-      CHECK(op) << "no Op found for " << op_type;
-      ops.emplace_back(std::move(op));
-      cpp::OpDesc cpp_op_desc;
-      TransformOpDescPbToCpp(op_desc, &cpp_op_desc);
-      ops.back()->Attach(cpp_op_desc, exec_scope);
-    }
-  }
+  void Build(const framework::proto::ProgramDesc& program);
  // Create temporary variables.
-  void PrepareWorkspace(const framework::proto::ProgramDesc& program) {
-    CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
-    exec_scope = &scope->NewScope();
-    // Create Feed and Fetch var.
-    scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
-    scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
-
-    tmp_vars.push_back("feed");
-    tmp_vars.push_back("fetch");
-    CHECK(!program.blocks().empty());
-    for (auto proto_var_desc : program.blocks(0).vars()) {
-      lite::VarDesc var_desc(proto_var_desc);
-      if (!var_desc.Persistable()) {
-        tmp_vars.push_back(var_desc.Name());
-        exec_scope->Var(var_desc.Name());
-      } else {
-        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
-        weights.push_back(var_desc.Name());
-      }
-    }
-  }
+  void PrepareWorkspace(const framework::proto::ProgramDesc& program);
+
+ private:
+  std::list<std::string> tmp_vars_;
+  std::list<std::string> weights_;
+  std::list<std::shared_ptr<OpLite>> ops_;
+  // the scope to run the kernels, NOTE this is the execution scope.
+  std::shared_ptr<lite::Scope> scope_;
+  std::vector<Place> valid_places_;
+  // Runtime scope.
+  lite::Scope* exec_scope_{};
+  const framework::proto::ProgramDesc desc_;
};

-struct Instruct {
-  Instruct(const std::shared_ptr<OpLite>& op,
-           std::unique_ptr<KernelBase>&& kernel)
+struct Instruction {
+  Instruction(const std::shared_ptr<OpLite>& op,
+              std::unique_ptr<KernelBase>&& kernel)
      : op_(op), kernel_(std::move(kernel)) {
#ifdef LITE_WITH_PROFILE

@@ -132,7 +101,7 @@ struct Instruct {
    kernel_->Launch();
  }

-  friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
+  friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
    os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
    return os;
  }

@@ -156,7 +125,7 @@ struct Instruct {
 */
class RuntimeProgram {
 public:
-  explicit RuntimeProgram(std::vector<Instruct>&& insts)
+  explicit RuntimeProgram(std::vector<Instruction>&& insts)
      : instructions_(std::move(insts)) {
    if (instructions_.empty()) {
      LOG(FATAL) << "no instructions";

@@ -186,7 +155,7 @@ class RuntimeProgram {
 private:
  RuntimeProgram(const RuntimeProgram&) = delete;

-  std::vector<Instruct> instructions_;
+  std::vector<Instruction> instructions_;
  lite::Scope* exec_scope_{};
};
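With the fields private, downstream code builds a Program through its constructor and reads it back only via the new accessors. A hedged end-to-end sketch assembled from names in this header and in the LightPredictor hunk above; KernelForOp is a hypothetical stand-in for the kernel-picking loop, and Run() is an assumed entry point:

// Hypothetical: choose one kernel for an op (alias matching elided).
std::unique_ptr<KernelBase> KernelForOp(const std::shared_ptr<OpLite>& op);

void BuildAndRun(const framework::proto::ProgramDesc& desc,
                 const std::shared_ptr<Scope>& root,
                 const std::vector<Place>& places) {
  Program program(desc, root, places);  // runs PrepareWorkspace + Build
  std::vector<Instruction> insts;
  for (const auto& op : program.ops()) {
    insts.emplace_back(op, KernelForOp(op));
  }
  RuntimeProgram runtime(std::move(insts));
  runtime.set_exec_scope(program.exec_scope());
  runtime.Run();  // assumed: executes each Instruction in order
}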
......
@@ -33,9 +33,9 @@ Program FakeProgram() {
    std::string w1 = "w" + std::to_string(id);
    std::string b1 = "b" + std::to_string(id);
    std::string out1 = "out" + std::to_string(id);
-    auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
-    auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
-    auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
+    auto w1v = program.scope_->Var(w1)->GetMutable<lite::Tensor>();
+    auto b1v = program.scope_->Var(b1)->GetMutable<lite::Tensor>();
+    auto out1v = program.scope_->Var(out1)->GetMutable<lite::Tensor>();

    lite::OpDesc desc;
    desc.SetInput("Input", {x});

@@ -46,12 +46,12 @@ Program FakeProgram() {
    desc.SetAttr("in_num_col_dims", 1);

    // add to input
-    program.tmp_vars.push_back(w1);
-    program.tmp_vars.push_back(b1);
+    program.tmp_vars_.push_back(w1);
+    program.tmp_vars_.push_back(b1);

    auto fc_op = LiteOpRegistry::Global().Create("fc");
-    fc_op->Attach(desc, program.scope.get());
-    program.ops.emplace_back(std::move(fc_op));
+    fc_op->Attach(desc, program.scope_.get());
+    program.ops_.emplace_back(std::move(fc_op));

    w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
    b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));

@@ -64,8 +64,8 @@ Program FakeProgram() {
  // out1, w2, b2 -fc-> out2

  std::string x = "x";
-  program.tmp_vars.push_back(x);
-  auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
+  program.tmp_vars_.push_back(x);
+  auto* xv = program.scope_->Var(x)->GetMutable<lite::Tensor>();
  xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));

  for (int i = 0; i < 3; i++) {
......
@@ -13,7 +13,8 @@
 // limitations under the License.

 #include "paddle/fluid/lite/model_parser/compatible_pb.h"
-#include "compatible_pb.h"
+#include <string>
+#include <vector>

 namespace paddle {
 namespace lite {
......
@@ -61,6 +61,7 @@ static std::string Join(const std::vector<std::string>& vec,

  if (!vec.empty()) {
    ss << vec.back();
  }
+
  return ss.str();
}
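Join, per the hunk above, streams each element followed by a separator and appends the last element without a trailing one, then returns the built string. A usage sketch; the second parameter is assumed to be the delimiter, inferred from the truncated signature:

std::vector<std::string> names{"fc", "relu", "softmax"};
std::string joined = Join(names, ", ");  // expected: "fc, relu, softmax"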
......