Commit 239d716b authored by S superjomn

init optimizer and kernel_executor

Parent commit: cdb12e59
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite executor_lite host_kernels ops_lite) cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite)
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite) cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/lite/core/executor.h" #include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/model_parser/model_parser.h" #include "paddle/fluid/lite/model_parser/model_parser.h"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/lite/core/executor.h" #include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
namespace paddle { namespace paddle {
......
...@@ -5,16 +5,18 @@ cc_library(variable_lite SRCS variable.cc) ...@@ -5,16 +5,18 @@ cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc) cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc) cc_library(scope_lite SRCS scope.cc)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
cc_library(executor_lite SRCS executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework #TODO(Superjomn) remove these dependencies from original framework
proto_desc) proto_desc)
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(type_system SRCS type_system.cc DEPS tensor_lite) cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc) cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_executor_lite SRCS executor_test.cc DEPS executor_lite ops_lite host_kernels) cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system) cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
add_subdirectory(mir) add_subdirectory(mir)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel_executor.h"
// Placeholder translation unit for the kernel executor: it currently holds no
// definitions (the interface is declared in kernel_executor.h). It exists so
// the kernel_executor_lite library target has a source file to compile.
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
\ No newline at end of file
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
/*
 * KernelExecutorBase is the common interface for executors that run the list
 * of kernels contained in a mir::Program. Concrete subclasses decide the
 * execution strategy (serial vs. stream-parallel).
 */
class KernelExecutorBase {
 public:
  // Takes ownership of the program whose kernels will be executed.
  // Marked explicit so a unique_ptr<mir::Program> cannot silently convert
  // into an executor (single-argument constructors should not be implicit).
  explicit KernelExecutorBase(std::unique_ptr<mir::Program>&& program);

  // Prepare runtime context.
  // NOTE(review): presumably this creates the temporary variables in
  // exec_scope_ before the first Run() — confirm against the .cc once
  // implemented.
  void PrepareWorkspace();

  // Execute the kernels.
  void Run();

 private:
  // Non-owning scope pointers; the weight scope and the execution scope.
  lite::Scope* scope_{};
  lite::Scope* exec_scope_{};
};
/*
 * SerialKernelExecutor executes the kernels one by one without concurrency;
 * intended for the X86 (host) place.
 * NOTE(review): no constructor is declared here and the base class has no
 * default constructor, so this class is not constructible yet — presumably a
 * stub; confirm a forwarding constructor is added when implemented.
 */
class SerialKernelExecutor : public KernelExecutorBase {};
/*
 * StreamKernelExecutor executes the kernels with CUDA-like stream parallel
 * support; intended for CUDA-like devices.
 * NOTE(review): like SerialKernelExecutor, this declares no constructor while
 * the base class lacks a default one, so it is not constructible yet —
 * presumably a stub awaiting implementation.
 */
class StreamKernelExecutor : public KernelExecutorBase {};
} // namespace lite
} // namespace paddle
...@@ -17,4 +17,5 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS ...@@ -17,4 +17,5 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
proto_desc ops_lite proto_desc ops_lite
host_kernels host_kernels
mir_passes mir_passes
mir_pass_manager
) )
...@@ -20,6 +20,8 @@ namespace paddle { ...@@ -20,6 +20,8 @@ namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
using inference::analysis::Dot;
void GraphVisualizePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void GraphVisualizePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
Visualize(graph.get()); Visualize(graph.get());
} }
...@@ -39,7 +41,7 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -39,7 +41,7 @@ std::string Visualize(mir::SSAGraph* graph) {
} }
if (node.IsInstruct()) { if (node.IsInstruct()) {
dot.AddNode(key, {}); dot.AddNode(key, {Dot::Attr("shape", "box")});
for (auto& x : node.inlinks) { for (auto& x : node.inlinks) {
auto name = x->AsArgument().name; auto name = x->AsArgument().name;
if (!exists_args.count(name)) { if (!exists_args.count(name)) {
......
...@@ -24,5 +24,3 @@ PassManager::PassManager() {} ...@@ -24,5 +24,3 @@ PassManager::PassManager() {}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_MIR_PASS(demo);
...@@ -30,9 +30,10 @@ namespace mir { ...@@ -30,9 +30,10 @@ namespace mir {
// - main block, which is a list of OpLite // - main block, which is a list of OpLite
// - scope: which contains all the weights // - scope: which contains all the weights
struct Program { struct Program {
std::list<std::string> inputs; std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::unique_ptr<OpLite>> ops; std::list<std::unique_ptr<OpLite>> ops;
std::unique_ptr<lite::Scope> scope; lite::Scope *scope;
}; };
// An Graph for MIR. It is built from a list of Op and a scope. // An Graph for MIR. It is built from a list of Op and a scope.
...@@ -44,7 +45,7 @@ class SSAGraph : GraphBase { ...@@ -44,7 +45,7 @@ class SSAGraph : GraphBase {
// @param valid_places: the valid places user set for the system. // @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) { void Build(const Program &program, const std::vector<Place> &valid_places) {
// create inputs // create inputs
for (const auto &name : program.inputs) { for (const auto &name : program.tmp_vars) {
node_storage_.emplace_back(); node_storage_.emplace_back();
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument(); auto &arg = new_node.AsArgument();
......
...@@ -34,7 +34,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x, ...@@ -34,7 +34,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
Program FakeProgram() { Program FakeProgram() {
Program program; Program program;
program.scope.reset(new lite::Scope); program.scope = new lite::Scope;
auto add_fc = [&](int id, std::string x) { auto add_fc = [&](int id, std::string x) {
// create variables // create variables
...@@ -55,12 +55,12 @@ Program FakeProgram() { ...@@ -55,12 +55,12 @@ Program FakeProgram() {
desc.Flush(); desc.Flush();
// add to input // add to input
program.inputs.push_back(w1); program.tmp_vars.push_back(w1);
program.inputs.push_back(b1); program.tmp_vars.push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc"); auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc_op->Attach(desc, program.scope.get()); fc_op->Attach(desc, program.scope);
program.ops.emplace_back(std::move(fc_op)); program.ops.emplace_back(std::move(fc_op));
w1v->Resize({100, 100}); w1v->Resize({100, 100});
...@@ -74,7 +74,7 @@ Program FakeProgram() { ...@@ -74,7 +74,7 @@ Program FakeProgram() {
// out1, w2, b2 -fc-> out2 // out1, w2, b2 -fc-> out2
std::string x = "x"; std::string x = "x";
program.inputs.push_back(x); program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<Tensor>(); auto* xv = program.scope->Var(x)->GetMutable<Tensor>();
xv->Resize({100, 100}); xv->Resize({100, 100});
......
...@@ -24,7 +24,7 @@ namespace lite { ...@@ -24,7 +24,7 @@ namespace lite {
// The Executor is used to run the operators. // The Executor is used to run the operators.
class Executor { class Executor {
public: public:
Executor(lite::Scope* scope, const std::vector<OpLite::Place>& valid_places) Executor(lite::Scope* scope, const std::vector<Place>& valid_places)
: scope_(scope), valid_places_(valid_places) {} : scope_(scope), valid_places_(valid_places) {}
// Create temporary variables. // Create temporary variables.
...@@ -52,7 +52,7 @@ class Executor { ...@@ -52,7 +52,7 @@ class Executor {
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type)); ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel // pick initial kernel
ops_.back()->PickKernel(valid_places_); ops_.back()->PickKernel(valid_places_);
ops_.back()->AttachImpl(*op_desc, exec_scope_); ops_.back()->Attach(*op_desc, exec_scope_);
} }
} }
...@@ -73,7 +73,7 @@ class Executor { ...@@ -73,7 +73,7 @@ class Executor {
private: private:
std::vector<std::unique_ptr<OpLite>> ops_; std::vector<std::unique_ptr<OpLite>> ops_;
lite::Scope* scope_{}; lite::Scope* scope_{};
std::vector<OpLite::Place> valid_places_; std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; lite::Scope* exec_scope_{};
}; };
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/executor.h" #include "paddle/fluid/lite/core/op_executor.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
/*
* lite::Optimizer optimize a program. It utilize the mir passes to analysis the
* program and export an optimized program.
*/
class Optimizer {
public:
void Run(std::unique_ptr<mir::Program>&& program,
const std::vector<Place>& valid_places,
const std::vector<std::string>& passes = {}) {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
graph_->Build(*program, valid_places);
RunPasses();
}
// Generate a new program based on the mir graph.
std::unique_ptr<mir::Program> GenProgram() {}
// Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir);
const mir::SSAGraph& ssa_graph() const {
CHECK(graph_);
return *graph_;
}
protected:
// Run the default passes registered in the PassManager.
void RunPasses() { mir::PassManager::Global().Run(); }
// Specify the passes and run them.
void RunPasses(std::vector<std::string>& passes);
private:
std::unique_ptr<mir::SSAGraph> graph_;
};
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register