Commit 43036686 authored by: S Superjomn

init io complement pass

Parent 64f10504
......@@ -15,6 +15,8 @@
#pragma once
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace paddle {
......@@ -28,34 +30,23 @@ class Predictor {
void Build(const std::string& model_path,
const std::vector<Place>& valid_places) {
CHECK(!executor_.get()) << "duplicate build found";
CHECK(!scope_.get()) << "duplicate build found";
framework::proto::ProgramDesc prog;
LoadModel(model_path, scope_.get(), &prog);
framework::ProgramDesc prog_desc(prog);
executor_.reset(new Executor(prog_desc, scope_.get(), valid_places));
}
// Get a tensor for input from scope directly.
Tensor* GetInputTensor(const std::string& name) {
auto* var = executor_->exec_scope()->FindVar(name);
CHECK(var) << "no tensor called " << name << " exists";
return var->GetMutable<Tensor>();
}
Program program(prog_desc, scope_, valid_places);
// Get a tensor for output from scope directly.
const Tensor* GetOutputTensor(const std::string& name) {
auto* var = executor_->exec_scope()->FindVar(name);
CHECK(var) << "no tensor called " << name << " exists";
return &var->Get<Tensor>();
Optimizer optimizer;
optimizer.Run(std::move(program), valid_places);
program_ = optimizer.GenRuntimeProgram();
}
void Run() { executor_->Run(); }
void Run() { program_->Run(); }
private:
std::shared_ptr<Scope> scope_;
std::unique_ptr<lite::Executor> executor_;
std::unique_ptr<RuntimeProgram> program_;
};
} // namespace lite
......
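For orientation, a minimal usage sketch of the reworked Predictor: Build() now loads the model, runs the Optimizer over it, and materializes a RuntimeProgram that Run() executes. The header path, model path, and Place values below are placeholders, not part of this commit.

#include "paddle/fluid/lite/api/cxx_api.h"  // assumed header location for Predictor

void Demo() {
  paddle::lite::Predictor predictor;
  // Hypothetical valid places; assumed Place constructor with defaulted
  // layout/device. Actual targets depend on the build configuration.
  std::vector<paddle::lite::Place> valid_places{
      paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}};
  predictor.Build("/path/to/model", valid_places);  // placeholder path
  predictor.Run();  // executes the optimized RuntimeProgram
}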
......@@ -177,7 +177,6 @@ class ParamTypeRegistry {
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
LOG(INFO) << "Looking for " << key;
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
......
......@@ -5,10 +5,10 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc
io_complement_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
variable_place_inference_pass.cc
demo_pass.cc
DEPS mir_pass types_lite)
......
......@@ -23,7 +23,8 @@ void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) {
auto& instruct = item->AsInstruct();
kernels_.emplace_back(std::move(instruct.valid_kernels.front()));
insts_.emplace_back(instruct.op,
std::move(instruct.valid_kernels.front()));
}
}
}
......
......@@ -30,10 +30,14 @@ class GenerateProgramPass : public ProgramPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
std::list<std::unique_ptr<KernelBase>>& kernels() { return kernels_; }
std::unique_ptr<RuntimeProgram> GenProgram() {
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
private:
std::list<std::unique_ptr<KernelBase>> kernels_;
std::vector<Instruction> insts_;
};
} // namespace mir
......
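The pass now collects (op, kernel) pairs as Instructions instead of bare kernels, so each runtime step can call the op's InferShape() before its kernel runs. A hedged sketch of how a caller consumes it, using the pass name registered in optimizer.h below:

auto* pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
    "generate_program_pass");
pass->Apply(graph);  // graph: std::unique_ptr<mir::SSAGraph>, already optimized
std::unique_ptr<RuntimeProgram> program = pass->GenProgram();
program->Run();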
......@@ -19,8 +19,28 @@ namespace paddle {
namespace lite {
namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {
// Start from inputs of the graph, those should should have place set.
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& inst = node.AsInstruct();
// inputs
for (auto* in : node.inlinks) {
CHECK(in->IsArgument());
auto name = in->AsArgument().name;
std::string tmp;
CHECK(inst.op_info->GetInputArgname(name, &tmp));
auto type =
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place;
if (type->tensor_place != inst.place) {
LOG(INFO) << "found IO unmatched tensor";
}
}
}
}
} // namespace mir
......
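For each instruction node, the pass maps every input variable back to its argument slot (via OpInfo::GetInputArgname, added later in this commit), retrieves the kernel's declared ParamType from the registry, and compares the declared tensor place against the instruction's place. For now a mismatch is only logged; presumably a follow-up commit inserts the actual copy instruction. A hedged, stand-alone sketch of the lookup, with made-up op type and argument name:

const auto* type =
    ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
        inst_place, "fc", "Input");  // hypothetical op/argument
if (type && type->tensor_place != inst_place) {
  // IO mismatch: a copy (the "io complement") would be needed here.
}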
......@@ -48,6 +48,11 @@ class Node {
std::shared_ptr<OpInfo> op_info;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape
KernelBase& picked_kernel() {
CHECK(!valid_kernels.empty());
return *valid_kernels.front();
}
};
struct Argument {
......
......@@ -11,3 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_executor.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
......@@ -17,6 +17,7 @@
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle {
......@@ -50,5 +51,18 @@ class Executor {
std::unique_ptr<Program> program_;
};
class RuntimeExecutor {
public:
RuntimeExecutor(RuntimeProgram* program) : program_(program) {}
void Run() {
CHECK(program_);
program_->Run();
}
private:
RuntimeProgram* program_{};
};
} // namespace lite
} // namespace paddle
......@@ -61,5 +61,25 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
return AttachImpl(opdesc, scope);
}
bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
for (auto &item : input_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
return true;
}
}
return false;
}
bool OpInfo::GetOutputArgname(const std::string &value_name, std::string *out) {
for (auto &item : output_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = item.first;
return true;
}
}
return false;
}
} // namespace lite
} // namespace paddle
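These two helpers reverse-map a variable name to the operator argument slot it was bound to during Attach(). A small hedged example with made-up names:

// Suppose the op was attached with input argument "X" -> {"x0", "x1"}.
std::string argname;
if (op_info.GetInputArgname("x1", &argname)) {
  CHECK_EQ(argname, "X");  // found: "x1" belongs to argument slot "X"
}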
......@@ -151,6 +151,8 @@ class OpInfo {
const std::map<std::string, std::list<std::string>> &output_argument() {
return output_argument_;
}
bool GetInputArgname(const std::string &value_name, std::string *out);
bool GetOutputArgname(const std::string &value_name, std::string *out);
const std::list<std::string> &input_argnames() const {
return input_argnames_;
......
......@@ -15,8 +15,10 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
......@@ -36,9 +38,11 @@ class Optimizer {
}
// Generate a new program based on the mir graph.
std::unique_ptr<Program> GenProgram() {
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
std::unique_ptr<Program> res;
return res;
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
return pass->GenProgram();
}
// Generate C++ code which combines the inference program, model and weights.
......
......@@ -37,12 +37,8 @@ TEST(Optimizer, test) {
.ConsiderPrecision();
optimizer.Run(std::move(program), places);
auto* program_pass =
mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
auto& kernels = program_pass->kernels();
LOG(INFO) << "get kernels: " << kernels.size();
auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num instructions " << runtime_program->num_instructions();
}
} // namespace lite
......
......@@ -18,6 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
......@@ -86,5 +87,43 @@ struct Program {
}
};
struct Instruction {
Instruction(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
: op_(op), kernel_(std::move(kernel)) {}
void Run() {
CHECK(op_);
CHECK(kernel_);
op_->InferShape();
kernel_->Run();
}
private:
std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_;
};
/*
* A program contains kernels for runtime.
*/
class RuntimeProgram {
public:
explicit RuntimeProgram(std::vector<Instruction>&& instruction)
: instructions_(std::move(instruction)) {}
void Run() {
for (auto& inst : instructions_) {
inst.Run();
}
}
size_t num_instructions() const { return instructions_.size(); }
private:
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_;
};
} // namespace lite
} // namespace paddle
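A minimal sketch of assembling and running a RuntimeProgram by hand; normally GenerateProgramPass produces it, and the op/kernel objects here are assumed to be already created and attached:

std::vector<Instruction> insts;
insts.emplace_back(op, std::move(kernel));  // op: shared_ptr<OpLite>, kernel: unique_ptr<KernelBase>
RuntimeProgram program(std::move(insts));
program.Run();  // each step runs op->InferShape() then kernel->Run()
LOG(INFO) << "steps: " << program.num_instructions();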
......@@ -84,6 +84,8 @@ struct Place {
layout == other.layout && device == other.device;
}
bool operator!=(const Place& other) const { return !(*this == other); }
friend bool operator<(const Place& a, const Place& b) {
if (a.target != b.target) return a.target < b.target;
if (a.precision != b.precision) return a.precision < b.precision;
......@@ -92,6 +94,11 @@ struct Place {
return true;
}
friend std::ostream& operator<<(std::ostream& os, const Place& other) {
os << other.DebugString();
return os;
}
std::string DebugString() const {
std::stringstream os;
os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
......
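With operator< and operator<< defined, Place can key ordered containers (which KernelIdTy lookups already need) and print directly into LOG streams. A short sketch; the constructor arguments are assumptions:

std::map<Place, int> counts;  // operator< makes Place a valid map key
Place host_fp32{TARGET(kHost), PRECISION(kFloat)};  // assumed ctor/defaults
++counts[host_fp32];
LOG(INFO) << "place: " << host_fp32;  // streams DebugString(), e.g. "host/float/..."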