提交 43036686 编写于 作者: S Superjomn

init io complement pass

上级 64f10504
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include "paddle/fluid/lite/core/op_executor.h" #include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/model_parser/model_parser.h" #include "paddle/fluid/lite/model_parser/model_parser.h"
namespace paddle { namespace paddle {
...@@ -28,34 +30,23 @@ class Predictor { ...@@ -28,34 +30,23 @@ class Predictor {
void Build(const std::string& model_path, void Build(const std::string& model_path,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
CHECK(!executor_.get()) << "duplicate build found";
CHECK(!scope_.get()) << "duplicate build found"; CHECK(!scope_.get()) << "duplicate build found";
framework::proto::ProgramDesc prog; framework::proto::ProgramDesc prog;
LoadModel(model_path, scope_.get(), &prog); LoadModel(model_path, scope_.get(), &prog);
framework::ProgramDesc prog_desc(prog); framework::ProgramDesc prog_desc(prog);
executor_.reset(new Executor(prog_desc, scope_.get(), valid_places)); Program program(prog_desc, scope_, valid_places);
}
// Get a tensor for input from scope directly.
Tensor* GetInputTensor(const std::string& name) {
auto* var = executor_->exec_scope()->FindVar(name);
CHECK(var) << "no tensor called " << name << " exists";
return var->GetMutable<Tensor>();
}
// Get a tensor for output from scope directly. Optimizer optimizer;
const Tensor* GetOutputTensor(const std::string& name) { optimizer.Run(std::move(program), valid_places);
auto* var = executor_->exec_scope()->FindVar(name); program_ = optimizer.GenRuntimeProgram();
CHECK(var) << "no tensor called " << name << " exists";
return &var->Get<Tensor>();
} }
void Run() { executor_->Run(); } void Run() { program_->Run(); }
private: private:
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<lite::Executor> executor_; std::unique_ptr<RuntimeProgram> program_;
}; };
} // namespace lite } // namespace lite
......
...@@ -177,7 +177,6 @@ class ParamTypeRegistry { ...@@ -177,7 +177,6 @@ class ParamTypeRegistry {
const ParamType* Retrieve(const Place& place, const std::string& op_type, const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) { const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name}; KernelIdTy key{op_type, place, io, arg_name};
LOG(INFO) << "Looking for " << key;
auto it = types_.find(key); auto it = types_.find(key);
if (it == types_.end()) return nullptr; if (it == types_.end()) return nullptr;
return &it->second; return &it->second;
......
...@@ -5,10 +5,10 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph) ...@@ -5,10 +5,10 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_passes cc_library(mir_passes
SRCS static_kernel_pick_pass.cc SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc
io_complement_pass.cc io_complement_pass.cc
graph_visualize_pass.cc graph_visualize_pass.cc
generate_program_pass.cc generate_program_pass.cc
variable_place_inference_pass.cc
demo_pass.cc demo_pass.cc
DEPS mir_pass types_lite) DEPS mir_pass types_lite)
......
...@@ -23,7 +23,8 @@ void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -23,7 +23,8 @@ void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
for (auto& item : graph->InstructTopologicalOrder()) { for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) { if (item->IsInstruct()) {
auto& instruct = item->AsInstruct(); auto& instruct = item->AsInstruct();
kernels_.emplace_back(std::move(instruct.valid_kernels.front())); insts_.emplace_back(instruct.op,
std::move(instruct.valid_kernels.front()));
} }
} }
} }
......
...@@ -30,10 +30,14 @@ class GenerateProgramPass : public ProgramPass { ...@@ -30,10 +30,14 @@ class GenerateProgramPass : public ProgramPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
std::list<std::unique_ptr<KernelBase>>& kernels() { return kernels_; } std::unique_ptr<RuntimeProgram> GenProgram() {
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
private: private:
std::list<std::unique_ptr<KernelBase>> kernels_; std::vector<Instruction> insts_;
}; };
} // namespace mir } // namespace mir
......
...@@ -19,8 +19,28 @@ namespace paddle { ...@@ -19,8 +19,28 @@ namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) { void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// Start from inputs of the graph, those should should have place set. // Start from inputs of the graph, those should have place set.
// Walk every instruction node and, for each of its input arguments, compare
// the place the registered kernel expects against the place the instruction
// runs at. A mismatch presumably marks where an IO-copy op must be inserted
// (this pass is named "io complement") — for now it is only logged.
for (auto& node : graph->mutable_nodes()) {
  if (!node.IsInstruct()) continue;  // skip argument (variable) nodes
  auto& inst = node.AsInstruct();
  // inputs
  for (auto* in : node.inlinks) {
    CHECK(in->IsArgument());  // inlinks of an instruct node are arguments
    auto name = in->AsArgument().name;
    std::string tmp;
    // Reverse-map the variable name to its argument slot (e.g. "X"). Note
    // the lookup runs inside CHECK; glog CHECK is evaluated in release
    // builds too, so the side effect is safe.
    CHECK(inst.op_info->GetInputArgname(name, &tmp));
    auto type =
        ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
            inst.place, inst.op_type, tmp);
    CHECK(type) << "no param type found for " << inst.op_type << ":" << name
                << " " << inst.place;
    if (type->tensor_place != inst.place) {
      // TODO(review): insert the compensating IO-copy instruction here
      // instead of just logging the mismatch.
      LOG(INFO) << "found IO unmatched tensor";
    }
  }
}
} }
} // namespace mir } // namespace mir
......
...@@ -48,6 +48,11 @@ class Node { ...@@ -48,6 +48,11 @@ class Node {
std::shared_ptr<OpInfo> op_info; std::shared_ptr<OpInfo> op_info;
// TODO(Superjomn) make this a shared_ptr for resource safety. // TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape std::shared_ptr<OpLite> op; // we hold op to run InferShape
// Returns the kernel selected for this instruction — by convention the
// first entry of valid_kernels (presumably narrowed to one candidate by an
// earlier kernel-pick pass; confirm against the pass pipeline).
// Aborts if no candidate kernel has been populated.
KernelBase& picked_kernel() {
  CHECK(!valid_kernels.empty());
  return *valid_kernels.front();
}
}; };
struct Argument { struct Argument {
......
...@@ -11,3 +11,9 @@ ...@@ -11,3 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/op_executor.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/scope.h"
namespace paddle { namespace paddle {
...@@ -50,5 +51,18 @@ class Executor { ...@@ -50,5 +51,18 @@ class Executor {
std::unique_ptr<Program> program_; std::unique_ptr<Program> program_;
}; };
/*
 * Thin execution wrapper around a RuntimeProgram.
 * Does NOT own the program; the caller must keep it alive for the
 * lifetime of this executor.
 */
class RuntimeExecutor {
 public:
  // `explicit` prevents an accidental implicit conversion from a raw
  // RuntimeProgram* at call sites (single-argument constructor).
  explicit RuntimeExecutor(RuntimeProgram* program) : program_(program) {}

  // Runs the wrapped program; aborts if none was supplied.
  void Run() {
    CHECK(program_);
    program_->Run();
  }

 private:
  RuntimeProgram* program_{};  // non-owning; null-checked in Run()
};
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -61,5 +61,25 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { ...@@ -61,5 +61,25 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
return AttachImpl(opdesc, scope); return AttachImpl(opdesc, scope);
} }
// Reverse-maps a value (variable) name to the input argument slot that
// lists it. On success writes the argument name to *out and returns true;
// otherwise returns false and leaves *out untouched.
bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
  for (const auto &arg : input_argument_) {
    for (const auto &value : arg.second) {
      if (value == value_name) {
        *out = arg.first;
        return true;
      }
    }
  }
  return false;
}
// Mirror of GetInputArgname for outputs: finds which output argument slot
// lists `value_name`. Writes the slot name to *out and returns true on a
// hit; returns false (with *out untouched) otherwise.
bool OpInfo::GetOutputArgname(const std::string &value_name, std::string *out) {
  for (const auto &arg : output_argument_) {
    for (const auto &value : arg.second) {
      if (value == value_name) {
        *out = arg.first;
        return true;
      }
    }
  }
  return false;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -151,6 +151,8 @@ class OpInfo { ...@@ -151,6 +151,8 @@ class OpInfo {
const std::map<std::string, std::list<std::string>> &output_argument() { const std::map<std::string, std::list<std::string>> &output_argument() {
return output_argument_; return output_argument_;
} }
bool GetInputArgname(const std::string &value_name, std::string *out);
bool GetOutputArgname(const std::string &value_name, std::string *out);
const std::list<std::string> &input_argnames() const { const std::list<std::string> &input_argnames() const {
return input_argnames_; return input_argnames_;
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -36,9 +38,11 @@ class Optimizer { ...@@ -36,9 +38,11 @@ class Optimizer {
} }
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<Program> GenProgram() { std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
std::unique_ptr<Program> res; std::unique_ptr<Program> res;
return res; auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
return pass->GenProgram();
} }
// Generate C++ code which combines the inference program, model and weights. // Generate C++ code which combines the inference program, model and weights.
......
...@@ -37,12 +37,8 @@ TEST(Optimizer, test) { ...@@ -37,12 +37,8 @@ TEST(Optimizer, test) {
.ConsiderPrecision(); .ConsiderPrecision();
optimizer.Run(std::move(program), places); optimizer.Run(std::move(program), places);
auto runtime_program = optimizer.GenRuntimeProgram();
auto* program_pass = LOG(INFO) << "num instructions " << runtime_program->num_instructions();
mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
auto& kernels = program_pass->kernels();
LOG(INFO) << "get kernels: " << kernels.size();
} }
} // namespace lite } // namespace lite
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -86,5 +87,43 @@ struct Program { ...@@ -86,5 +87,43 @@ struct Program {
} }
}; };
// One executable step of a RuntimeProgram: an op paired with the kernel
// selected for it. The op is held via shared_ptr (shared with the graph);
// the kernel's ownership is transferred into the instruction.
struct Instruction {
  Instruction(const std::shared_ptr<OpLite>& op,
              std::unique_ptr<KernelBase>&& kernel)
      : op_(op), kernel_(std::move(kernel)) {}

  // Re-runs the op's shape inference and then launches the kernel.
  // Aborts if either member is null.
  void Run() {
    CHECK(op_);
    CHECK(kernel_);
    op_->InferShape();
    kernel_->Run();
  }

 private:
  std::shared_ptr<OpLite> op_;          // kept to call InferShape each step
  std::unique_ptr<KernelBase> kernel_;  // owned execution kernel
};
/*
 * A program contains kernels for runtime: the ordered instruction list
 * produced by the optimizer, ready to execute.
 */
class RuntimeProgram {
 public:
  // Takes ownership of the instruction list; `explicit` avoids accidental
  // conversion from a bare vector.
  explicit RuntimeProgram(std::vector<Instruction>&& instruction)
      : instructions_(std::move(instruction)) {}

  // Move-only type: each Instruction owns its kernel via unique_ptr, so a
  // copy is meaningless. Deleting both copy operations publicly (instead of
  // a private copy ctor only) yields a clear diagnostic at the call site
  // and also covers copy-assignment.
  RuntimeProgram(const RuntimeProgram&) = delete;
  RuntimeProgram& operator=(const RuntimeProgram&) = delete;

  // Executes every instruction in order.
  void Run() {
    for (auto& inst : instructions_) {
      inst.Run();
    }
  }

  size_t num_instructions() const { return instructions_.size(); }

 private:
  std::vector<Instruction> instructions_;
};
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -84,6 +84,8 @@ struct Place { ...@@ -84,6 +84,8 @@ struct Place {
layout == other.layout && device == other.device; layout == other.layout && device == other.device;
} }
// Defined strictly as the negation of operator== so the two always agree.
bool operator!=(const Place& other) const { return !(*this == other); }
friend bool operator<(const Place& a, const Place& b) { friend bool operator<(const Place& a, const Place& b) {
if (a.target != b.target) return a.target < b.target; if (a.target != b.target) return a.target < b.target;
if (a.precision != b.precision) return a.precision < b.precision; if (a.precision != b.precision) return a.precision < b.precision;
...@@ -92,6 +94,11 @@ struct Place { ...@@ -92,6 +94,11 @@ struct Place {
return true; return true;
} }
// Streams the human-readable form produced by DebugString().
friend std::ostream& operator<<(std::ostream& os, const Place& other) {
  return os << other.DebugString();
}
std::string DebugString() const { std::string DebugString() const {
std::stringstream os; std::stringstream os;
os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/" os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册