提交 f1ca00a4 编写于 作者: S Superjomn

rename some concepts

Instruction to Stmt
上级 ec27aa46
...@@ -104,12 +104,16 @@ class KernelBase { ...@@ -104,12 +104,16 @@ class KernelBase {
mutable operators::param_t param_; mutable operators::param_t param_;
// The corresponding op type. // The corresponding op type.
std::string op_type_{}; std::string op_type_{};
// The extra identity to help differentiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel.
std::string alias_{}; std::string alias_{};
}; };
// Light-weight kernel implementation. // Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target // The OpKernel is designed to implement the specific algorithm on a target
// device. // device.
// TODO(Superjomn) Consider adding a Platform type to differentiate CUDNN,
// MKLDNN, plain CUDA C implementations.
template <TargetType Target, PrecisionType Precision, template <TargetType Target, PrecisionType Precision,
DataLayoutType DataLayout = DataLayoutType::kNCHW> DataLayoutType DataLayout = DataLayoutType::kNCHW>
class OpKernel : public KernelBase { class OpKernel : public KernelBase {
......
...@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass { ...@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override { void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
LOG(INFO) << "== Argument types =="; LOG(INFO) << "== Argument types ==";
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsArgument()) continue; if (!node.IsArg()) continue;
auto* type = node.AsArgument().type; auto* type = node.AsArg().type;
if (type) { if (type) {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: " << *type; LOG(INFO) << "* ARG " << node.AsArg().name << " type: " << *type;
} else { } else {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: UNK"; LOG(INFO) << "* ARG " << node.AsArg().name << " type: UNK";
} }
} }
LOG(INFO) << "---------------------"; LOG(INFO) << "---------------------";
......
...@@ -23,11 +23,10 @@ namespace mir { ...@@ -23,11 +23,10 @@ namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
LOG(INFO) << "final program \n" << Visualize(graph.get()); LOG(INFO) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->InstructTopologicalOrder()) { for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) { if (item->IsStmt()) {
auto& instruct = item->AsInstruct(); auto& stmt = item->AsStmt();
LOG(INFO) << instruct; LOG(INFO) << stmt;
insts_.emplace_back(instruct.op, insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front()));
std::move(instruct.valid_kernels.front()));
} }
} }
} }
......
...@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
std::string key; std::string key;
if (node.IsArgument()) { if (node.IsArg()) {
key = node.AsArgument().name; key = node.AsArg().name;
} else { } else {
key = node.AsInstruct().op_type + std::to_string(id++); key = node.AsStmt().op_type + std::to_string(id++);
} }
if (node.IsInstruct()) { if (node.IsStmt()) {
dot.AddNode(key, {Dot::Attr("shape", "box")}); dot.AddNode(key, {Dot::Attr("shape", "box")});
for (auto& x : node.inlinks) { for (auto& x : node.inlinks) {
auto name = x->AsArgument().name; auto name = x->AsArg().name;
if (!exists_args.count(name)) { if (!exists_args.count(name)) {
dot.AddNode(name, {}); dot.AddNode(name, {});
} }
...@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
exists_args.insert(name); exists_args.insert(name);
} }
for (auto& x : node.outlinks) { for (auto& x : node.outlinks) {
auto name = x->AsArgument().name; auto name = x->AsArg().name;
if (!exists_args.count(name)) { if (!exists_args.count(name)) {
dot.AddNode(name, {}); dot.AddNode(name, {});
} }
......
...@@ -19,20 +19,20 @@ namespace paddle { ...@@ -19,20 +19,20 @@ namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
class IoCopyKernelPickPass : public InstructionPass { class IoCopyKernelPickPass : public StmtPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override { void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue; if (!node.IsStmt()) continue;
auto& inst = node.AsInstruct(); auto& inst = node.AsStmt();
if (inst.op_type != "io_copy") continue; if (inst.op_type != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel"; LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsInstruct().valid_kernels; auto& kernels = node.AsStmt().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArgument().type; const auto* inty = node.inlinks.front()->AsArg().type;
const auto* outy = node.outlinks.front()->AsArgument().type; const auto* outy = node.outlinks.front()->AsArg().type;
LOG(INFO) << "input type " << *inty; LOG(INFO) << "input type " << *inty;
LOG(INFO) << "output type " << *outy; LOG(INFO) << "output type " << *outy;
......
...@@ -34,15 +34,15 @@ class Node { ...@@ -34,15 +34,15 @@ class Node {
Node() = default; Node() = default;
enum class Role { enum class Role {
kArgument = 0, kArg = 0,
kInstruct, kStmt,
kNumRoles, /*should be last*/ kNumRoles, /*should be last*/
kUnk, kUnk,
}; };
struct Instruct { struct Stmt {
std::string op_type; std::string op_type;
// The kernel instances this Instruct contains. // The kernel instances this Statement contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels; std::vector<std::unique_ptr<KernelBase>> valid_kernels;
// TODO(Superjomn) make this a shared_ptr for resource safety. // TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape std::shared_ptr<OpLite> op; // we hold op to run InferShape
...@@ -62,13 +62,13 @@ class Node { ...@@ -62,13 +62,13 @@ class Node {
return *valid_kernels.front(); return *valid_kernels.front();
} }
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) { friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Instruct " << other.op_type << " " << other.place(); os << "Statement " << other.op_type << " " << other.place();
return os; return os;
} }
}; };
struct Argument { struct Arg {
std::string name; std::string name;
const Type* type{}; const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly // Weight is a special kind of argument, it is marked as weight explicitly
...@@ -76,16 +76,16 @@ class Node { ...@@ -76,16 +76,16 @@ class Node {
bool is_weight{false}; bool is_weight{false};
}; };
Argument& AsArgument(const std::string& name) { Arg& AsArg(const std::string& name) {
auto& x = AsArgument(); auto& x = AsArg();
x.name = name; x.name = name;
return x; return x;
} }
Instruct& AsInstruct(const std::string& op_type, Stmt& AsStmt(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels, std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) { const std::shared_ptr<OpLite>& op) {
auto& x = AsInstruct(); auto& x = AsStmt();
x.op_type = op_type; x.op_type = op_type;
x.op = op; x.op = op;
x.valid_kernels = std::move(kernels); x.valid_kernels = std::move(kernels);
...@@ -93,23 +93,23 @@ class Node { ...@@ -93,23 +93,23 @@ class Node {
} }
// Set roles. // Set roles.
Argument& AsArgument() { Arg& AsArg() {
if (role_ != Role::kUnk) { if (role_ != Role::kUnk) {
CHECK(role_ == Role::kArgument); CHECK(role_ == Role::kArg);
return *argument_; return *arg_;
} }
role_ = Role::kArgument; role_ = Role::kArg;
argument_.reset(new Argument); arg_.reset(new Arg);
return *argument_; return *arg_;
} }
Instruct& AsInstruct() { Stmt& AsStmt() {
if (role_ != Role::kUnk) { if (role_ != Role::kUnk) {
CHECK(role_ == Role::kInstruct); CHECK(role_ == Role::kStmt);
return *instruct_; return *stmt_;
} }
role_ = Role::kInstruct; role_ = Role::kStmt;
instruct_.reset(new Instruct); stmt_.reset(new Stmt);
return *instruct_; return *stmt_;
} }
friend std::ostream& operator<<(std::ostream& os, Node& other) { friend std::ostream& operator<<(std::ostream& os, Node& other) {
...@@ -117,26 +117,26 @@ class Node { ...@@ -117,26 +117,26 @@ class Node {
if (!other.IsRoleSet()) { if (!other.IsRoleSet()) {
os << "Unk role node"; os << "Unk role node";
} }
if (other.IsArgument()) { if (other.IsArg()) {
auto& arg = other.AsArgument(); auto& arg = other.AsArg();
os << "Argument " << arg.name; os << "Argument " << arg.name;
} }
if (other.IsInstruct()) { if (other.IsStmt()) {
auto& arg = other.AsInstruct(); auto& arg = other.AsStmt();
os << "Instruct " << arg.op_type; os << "Statement " << arg.op_type;
} }
return os; return os;
} }
// Check roles. // Check roles.
bool IsRoleSet() const { return role_ != Role::kUnk; } bool IsRoleSet() const { return role_ != Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; } bool IsStmt() const { return role_ == Role::kStmt; }
bool IsArgument() const { return role_ == Role::kArgument; } bool IsArg() const { return role_ == Role::kArg; }
private: private:
// Either instruct_ or argument_ is used. // Either stmt_ or argument_ is used.
std::unique_ptr<Instruct> instruct_; std::unique_ptr<Stmt> stmt_;
std::unique_ptr<Argument> argument_; std::unique_ptr<Arg> arg_;
Role role_{Role::kUnk}; Role role_{Role::kUnk};
}; };
......
...@@ -26,8 +26,8 @@ class Pass { ...@@ -26,8 +26,8 @@ class Pass {
enum class Kind { enum class Kind {
// Will modify the program/graph topology. // Will modify the program/graph topology.
kProgramWise = 0, kProgramWise = 0,
// Will modify the instruction, with the graph topology fixed. // Will modify the statement, with the graph topology fixed.
kInstructionWise, kStmtWise,
// Will not modify the IR, just collect information or visualization. // Will not modify the IR, just collect information or visualization.
kDebug, kDebug,
}; };
...@@ -45,7 +45,7 @@ class Pass { ...@@ -45,7 +45,7 @@ class Pass {
Kind kind() const { return kind_; } Kind kind() const { return kind_; }
bool is_debug_pass() const { return kind_ == Kind::kDebug; } bool is_debug_pass() const { return kind_ == Kind::kDebug; }
bool is_program_pass() const { return kind_ == Kind::kProgramWise; } bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
bool is_instruction_pass() const { return kind_ == Kind::kInstructionWise; } bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; }
virtual ~Pass() = default; virtual ~Pass() = default;
...@@ -61,9 +61,9 @@ class ProgramPass : public Pass { ...@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
ProgramPass() : Pass(Kind::kProgramWise) {} ProgramPass() : Pass(Kind::kProgramWise) {}
}; };
class InstructionPass : public Pass { class StmtPass : public Pass {
public: public:
InstructionPass() : Pass(Kind::kInstructionWise) {} StmtPass() : Pass(Kind::kStmtWise) {}
}; };
class DebugPass : public Pass { class DebugPass : public Pass {
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
class RuntimeContextAssignPass : public InstructionPass { class RuntimeContextAssignPass : public StmtPass {
public: public:
RuntimeContextAssignPass() { RuntimeContextAssignPass() {
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
...@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass { ...@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override { void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue; if (!node.IsStmt()) continue;
auto& inst = node.AsInstruct(); auto& inst = node.AsStmt();
switch (inst.picked_kernel().target()) { switch (inst.picked_kernel().target()) {
case TARGET(kHost): case TARGET(kHost):
......
...@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() { ...@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std::map<mir::Node *, std::set<mir::Node *>> adj_list; std::map<mir::Node *, std::set<mir::Node *>> adj_list;
for (auto &n : mutable_nodes()) { for (auto &n : mutable_nodes()) {
if (!n.IsInstruct()) continue; if (!n.IsStmt()) continue;
if (adj_list.find(&n) == adj_list.end()) { if (adj_list.find(&n) == adj_list.end()) {
adj_list[&n] = std::set<mir::Node *>(); adj_list[&n] = std::set<mir::Node *>();
} }
std::vector<mir::Node *> nodes; std::vector<mir::Node *> nodes;
for (auto &var : n.inlinks) { for (auto &var : n.inlinks) {
for (auto &adj_n : var->inlinks) { for (auto &adj_n : var->inlinks) {
PADDLE_ENFORCE(adj_n->IsInstruct()); PADDLE_ENFORCE(adj_n->IsStmt());
nodes.push_back(adj_n); nodes.push_back(adj_n);
} }
} }
...@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { ...@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
VLOG(5) << "create arg node " << name; VLOG(5) << "create arg node " << name;
node_storage_.emplace_back(); node_storage_.emplace_back();
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
new_node.AsArgument(name); new_node.AsArg(name);
arguments_[name] = &new_node; arguments_[name] = &new_node;
} }
} }
...@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { ...@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
VLOG(5) << "create arg node " << name; VLOG(5) << "create arg node " << name;
node_storage_.emplace_back(); node_storage_.emplace_back();
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
new_node.AsArgument(name); new_node.AsArg(name);
arguments_[name] = &new_node; arguments_[name] = &new_node;
} }
} }
...@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode( ...@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
op->SetValidPlaces(valid_places); op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back(); auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places); auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op); node_storage_.back().AsStmt(op->op_type_, std::move(kernels), op);
CHECK(new_node.inlinks.empty()) << "duplicate Build found"; CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found"; CHECK(new_node.outlinks.empty()) << "duplicate Build found";
...@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() { ...@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
bool SSAGraph::CheckLinksRoleSet() { bool SSAGraph::CheckLinksRoleSet() {
for (auto &node : mutable_nodes()) { for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet()); CHECK_OR_FALSE(node.IsRoleSet());
if (!node.IsInstruct()) continue; if (!node.IsStmt()) continue;
for (auto *x : node.inlinks) { for (auto *x : node.inlinks) {
CHECK_OR_FALSE(x->IsRoleSet()); CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument()); CHECK_OR_FALSE(x->IsArg());
} }
for (auto *x : node.outlinks) { for (auto *x : node.outlinks) {
CHECK_OR_FALSE(x->IsRoleSet()); CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument()); CHECK_OR_FALSE(x->IsArg());
} }
} }
return true; return true;
...@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) { ...@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back(); node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name; CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
arguments_[name] = &node_storage_.back(); arguments_[name] = &node_storage_.back();
node_storage_.back().AsArgument(name); node_storage_.back().AsArg(name);
return &node_storage_.back(); return &node_storage_.back();
} }
......
...@@ -76,7 +76,7 @@ class SSAGraph : GraphBase { ...@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
void MarkArgumentWeights(const Program &program) { void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) { for (const auto &name : program.weights) {
arguments_[name]->AsArgument().is_weight = true; arguments_[name]->AsArg().is_weight = true;
} }
} }
...@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) { ...@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) { static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
// instr -> output argument // instr -> output argument
if (a->IsInstruct() && b->IsArgument()) { if (a->IsStmt() && b->IsArg()) {
auto &inst = a->AsInstruct(); auto &inst = a->AsStmt();
auto &output = b->AsArgument(); auto &output = b->AsArg();
if (!output.type) { if (!output.type) {
output.type = inst.picked_kernel().GetOutputDeclType(arg_name); output.type = inst.picked_kernel().GetOutputDeclType(arg_name);
...@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) { ...@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
} }
// input argument -> instr // input argument -> instr
if (a->IsArgument() && b->IsInstruct()) { if (a->IsArg() && b->IsStmt()) {
auto &input = a->AsArgument(); auto &input = a->AsArg();
auto &inst = b->AsInstruct(); auto &inst = b->AsStmt();
if (!input.type) { if (!input.type) {
input.type = inst.picked_kernel().GetInputDeclType(arg_name); input.type = inst.picked_kernel().GetInputDeclType(arg_name);
} }
......
...@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK(graph) << "graph not valid"; CHECK(graph) << "graph not valid";
// sort kernels by the factors. // sort kernels by the factors.
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue; if (!node.IsStmt()) continue;
auto& instruct = node.AsInstruct(); auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) { for (auto&& kernel : instruct.valid_kernels) {
size_t score = KernelGrade(*kernel); size_t score = KernelGrade(*kernel);
......
...@@ -33,7 +33,7 @@ namespace mir { ...@@ -33,7 +33,7 @@ namespace mir {
* - kernel_pick_factors, the factors to consider in picking kernels. * - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass. * Set them first before execute the pass.
*/ */
class StaticKernelPickPass : public mir::InstructionPass { class StaticKernelPickPass : public mir::StmtPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
......
...@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK(!valid_places_.empty()); CHECK(!valid_places_.empty());
for (auto& node : nodes) { for (auto& node : nodes) {
if (!node->IsInstruct()) continue; if (!node->IsStmt()) continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in);
...@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, ...@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return; return;
CHECK(inst_node->IsInstruct()); CHECK(inst_node->IsStmt());
auto& inst = inst_node->AsInstruct(); auto& inst = inst_node->AsStmt();
CHECK(in->IsRoleSet()); CHECK(in->IsRoleSet());
CHECK(in->IsArgument()); CHECK(in->IsArg());
auto in_arg_name = in->AsArgument().name; auto in_arg_name = in->AsArg().name;
std::string tmp; std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArgument().type); CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArgument().type, *decl_arg_type)) { if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArgument().name LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op->DebugString() << " " << " for kernel " << inst.op->DebugString() << " "
<< *in->AsArgument().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist. // Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArgument().type, *decl_arg_type, in->AsArgument().name, AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph,
graph, inst_node, valid_places_); inst_node, valid_places_);
} }
} }
...@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node* inst_node, const std::vector<Place>& valid_places) { Node* inst_node, const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node. // So there will be a new Argument node and a new IoCopy Statement Node.
auto node_id = [&] { return graph->nodes().size(); }; auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id()); auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
...@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op); // CHECK(io_copy_op);
// Create the new var manually. // Create the new var manually.
inst_node->AsInstruct().op->scope()->Var(io_copy_output_name); inst_node->AsStmt().op->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction. // Create IoCopy Instruction.
lite::OpDesc op_desc; lite::OpDesc op_desc;
...@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc.SetInput("Input", {var}); op_desc.SetInput("Input", {var});
op_desc.SetOutput("Out", {io_copy_output_name}); op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope()); io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
auto kernels = io_copy_op->CreateKernels(valid_places); auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsInstruct("io_copy", std::move(kernels), io_copy_op); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
// Remove the old link // Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node); RemoveDirectedLink(graph->Argument(var), inst_node);
// Update the original instruction OpDesc. // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Update its input to the io_copy_output_name
auto& inst = inst_node->AsInstruct(); auto& inst = inst_node->AsStmt();
auto inst_program_desc = inst.op_info()->desc(); auto inst_program_desc = inst.op_info()->desc();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
...@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink(io_copy_output_arg, inst_node); DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information // reset opdesc and update kernel information
auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc(); auto desc_dummy = inst_node->AsStmt().op->op_info()->desc();
UpdateInputTo(&desc_dummy, var, io_copy_output_name); UpdateInputTo(&desc_dummy, var, io_copy_output_name);
lite::OpDesc desc_fake(desc_dummy); lite::OpDesc desc_fake(desc_dummy);
inst_node->AsInstruct().op->Attach(desc_fake, inst_node->AsStmt().op->Attach(desc_fake, inst_node->AsStmt().op->scope());
inst_node->AsInstruct().op->scope());
std::string tmp; std::string tmp;
if (inst_node->AsInstruct().op_info()->GetInputArgname("a", &tmp)) { if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp; CHECK(false) << "get old a " << tmp;
} }
for (auto& kernel : inst_node->AsInstruct().valid_kernels) { for (auto& kernel : inst_node->AsStmt().valid_kernels) {
inst_node->AsInstruct().op->AttachKernel(kernel.get()); inst_node->AsStmt().op->AttachKernel(kernel.get());
} }
graph->CheckValid(); graph->CheckValid();
......
...@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK(!graph->inputs().empty()) << "graph's inputs should be set"; CHECK(!graph->inputs().empty()) << "graph's inputs should be set";
for (const auto& v : graph->inputs()) { for (const auto& v : graph->inputs()) {
// the feed op might in the inputs // the feed op might in the inputs
if (v->IsInstruct()) { if (v->IsStmt()) {
LOG(INFO) << "found kernel in inputs " << v->AsInstruct().op_type; LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
continue; continue;
} }
...@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
void CheckAllArgumentTypeDetermined(SSAGraph* graph) { void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (node.IsArgument()) { if (node.IsArg()) {
CHECK(node.AsArgument().type) << "node " << node.AsArgument().name CHECK(node.AsArg().type) << "node " << node.AsArg().name
<< " type not determined, " << &node; << " type not determined, " << &node;
} }
} }
} }
...@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global(); VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) { for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct(); auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue; if (inst.op_type == "io_copy") continue;
// LOG(INFO) << "- inferencing type " << // LOG(INFO) << "- inferencing type " <<
...@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG(3) << "--- var " << arg_name; VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArg();
if (!arg_node.type) { if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node; VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type; arg_node.type = type;
...@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG(3) << "--- var " << arg_name; VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArg();
if (!arg_node.type) { if (!arg_node.type) {
node->AsArgument().type = type; node->AsArg().type = type;
VLOG(3) << "set type " << *type; VLOG(3) << "set type " << *type;
} }
} }
......
...@@ -38,7 +38,7 @@ TEST(Optimizer, test) { ...@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
optimizer.Run(std::move(program), places); optimizer.Run(std::move(program), places);
auto runtime_program = optimizer.GenRuntimeProgram(); auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num instructions " << runtime_program->num_instructions(); LOG(INFO) << "num statements " << runtime_program->num_instructions();
} }
} // namespace lite } // namespace lite
......
...@@ -152,7 +152,7 @@ class Type : public DataTypeBase { ...@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
} }
// Can cast to another type. This is heavily used in MIR, to determine whether // Can cast to another type. This is heavily used in MIR, to determine whether
// it is possible to add an instruction to transform a type to another. // it is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); } virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
template <bool is_unknown, bool is_tensor = true, template <bool is_unknown, bool is_tensor = true,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册