提交 f1ca00a4 编写于 作者: S Superjomn

rename some concepts

Instruction to Stmt
上级 ec27aa46
......@@ -104,12 +104,16 @@ class KernelBase {
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_{};
// The extra identity to help defficiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel.
std::string alias_{};
};
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
// TODO(Superjomn) Consider to add a Platform type to differentiate CUDNN,
// MKLDNN, plain CUDA C implementations.
template <TargetType Target, PrecisionType Precision,
DataLayoutType DataLayout = DataLayoutType::kNCHW>
class OpKernel : public KernelBase {
......
......@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
LOG(INFO) << "== Argument types ==";
for (auto& node : graph->mutable_nodes()) {
if (!node.IsArgument()) continue;
if (!node.IsArg()) continue;
auto* type = node.AsArgument().type;
auto* type = node.AsArg().type;
if (type) {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: " << *type;
LOG(INFO) << "* ARG " << node.AsArg().name << " type: " << *type;
} else {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: UNK";
LOG(INFO) << "* ARG " << node.AsArg().name << " type: UNK";
}
}
LOG(INFO) << "---------------------";
......
......@@ -23,11 +23,10 @@ namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
LOG(INFO) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) {
auto& instruct = item->AsInstruct();
LOG(INFO) << instruct;
insts_.emplace_back(instruct.op,
std::move(instruct.valid_kernels.front()));
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front()));
}
}
}
......
......@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
std::string key;
if (node.IsArgument()) {
key = node.AsArgument().name;
if (node.IsArg()) {
key = node.AsArg().name;
} else {
key = node.AsInstruct().op_type + std::to_string(id++);
key = node.AsStmt().op_type + std::to_string(id++);
}
if (node.IsInstruct()) {
if (node.IsStmt()) {
dot.AddNode(key, {Dot::Attr("shape", "box")});
for (auto& x : node.inlinks) {
auto name = x->AsArgument().name;
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
......@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
exists_args.insert(name);
}
for (auto& x : node.outlinks) {
auto name = x->AsArgument().name;
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
......
......@@ -19,20 +19,20 @@ namespace paddle {
namespace lite {
namespace mir {
class IoCopyKernelPickPass : public InstructionPass {
class IoCopyKernelPickPass : public StmtPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& inst = node.AsInstruct();
if (!node.IsStmt()) continue;
auto& inst = node.AsStmt();
if (inst.op_type != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsInstruct().valid_kernels;
auto& kernels = node.AsStmt().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArgument().type;
const auto* outy = node.outlinks.front()->AsArgument().type;
const auto* inty = node.inlinks.front()->AsArg().type;
const auto* outy = node.outlinks.front()->AsArg().type;
LOG(INFO) << "input type " << *inty;
LOG(INFO) << "output type " << *outy;
......
......@@ -34,15 +34,15 @@ class Node {
Node() = default;
enum class Role {
kArgument = 0,
kInstruct,
kArg = 0,
kStmt,
kNumRoles, /*should be last*/
kUnk,
};
struct Instruct {
struct Stmt {
std::string op_type;
// The kernel instances this Instruct contains.
// The kernel instances this Statement contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape
......@@ -62,13 +62,13 @@ class Node {
return *valid_kernels.front();
}
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
os << "Instruct " << other.op_type << " " << other.place();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Statement " << other.op_type << " " << other.place();
return os;
}
};
struct Argument {
struct Arg {
std::string name;
const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly
......@@ -76,16 +76,16 @@ class Node {
bool is_weight{false};
};
Argument& AsArgument(const std::string& name) {
auto& x = AsArgument();
Arg& AsArg(const std::string& name) {
auto& x = AsArg();
x.name = name;
return x;
}
Instruct& AsInstruct(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) {
auto& x = AsInstruct();
Stmt& AsStmt(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) {
auto& x = AsStmt();
x.op_type = op_type;
x.op = op;
x.valid_kernels = std::move(kernels);
......@@ -93,23 +93,23 @@ class Node {
}
// Set roles.
Argument& AsArgument() {
Arg& AsArg() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kArgument);
return *argument_;
CHECK(role_ == Role::kArg);
return *arg_;
}
role_ = Role::kArgument;
argument_.reset(new Argument);
return *argument_;
role_ = Role::kArg;
arg_.reset(new Arg);
return *arg_;
}
Instruct& AsInstruct() {
Stmt& AsStmt() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kInstruct);
return *instruct_;
CHECK(role_ == Role::kStmt);
return *stmt_;
}
role_ = Role::kInstruct;
instruct_.reset(new Instruct);
return *instruct_;
role_ = Role::kStmt;
stmt_.reset(new Stmt);
return *stmt_;
}
friend std::ostream& operator<<(std::ostream& os, Node& other) {
......@@ -117,26 +117,26 @@ class Node {
if (!other.IsRoleSet()) {
os << "Unk role node";
}
if (other.IsArgument()) {
auto& arg = other.AsArgument();
if (other.IsArg()) {
auto& arg = other.AsArg();
os << "Argument " << arg.name;
}
if (other.IsInstruct()) {
auto& arg = other.AsInstruct();
os << "Instruct " << arg.op_type;
if (other.IsStmt()) {
auto& arg = other.AsStmt();
os << "Statement " << arg.op_type;
}
return os;
}
// Check roles.
bool IsRoleSet() const { return role_ != Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; }
bool IsArgument() const { return role_ == Role::kArgument; }
bool IsStmt() const { return role_ == Role::kStmt; }
bool IsArg() const { return role_ == Role::kArg; }
private:
// Either instruct_ or argument_ is used.
std::unique_ptr<Instruct> instruct_;
std::unique_ptr<Argument> argument_;
// Either stmt_ or argument_ is used.
std::unique_ptr<Stmt> stmt_;
std::unique_ptr<Arg> arg_;
Role role_{Role::kUnk};
};
......
......@@ -26,8 +26,8 @@ class Pass {
enum class Kind {
// Will modify the program/graph topology.
kProgramWise = 0,
// Will modify the instruction, with the graph topology fixed.
kInstructionWise,
// Will modify the statement, with the graph topology fixed.
kStmtWise,
// Will not modify the IR, just collect information or visualization.
kDebug,
};
......@@ -45,7 +45,7 @@ class Pass {
Kind kind() const { return kind_; }
bool is_debug_pass() const { return kind_ == Kind::kDebug; }
bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
bool is_instruction_pass() const { return kind_ == Kind::kInstructionWise; }
bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; }
virtual ~Pass() = default;
......@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
ProgramPass() : Pass(Kind::kProgramWise) {}
};
class InstructionPass : public Pass {
class StmtPass : public Pass {
public:
InstructionPass() : Pass(Kind::kInstructionWise) {}
StmtPass() : Pass(Kind::kStmtWise) {}
};
class DebugPass : public Pass {
......
......@@ -19,7 +19,7 @@ namespace paddle {
namespace lite {
namespace mir {
class RuntimeContextAssignPass : public InstructionPass {
class RuntimeContextAssignPass : public StmtPass {
public:
RuntimeContextAssignPass() {
#ifdef LITE_WITH_CUDA
......@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
if (!node.IsStmt()) continue;
auto& inst = node.AsInstruct();
auto& inst = node.AsStmt();
switch (inst.picked_kernel().target()) {
case TARGET(kHost):
......
......@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std::map<mir::Node *, std::set<mir::Node *>> adj_list;
for (auto &n : mutable_nodes()) {
if (!n.IsInstruct()) continue;
if (!n.IsStmt()) continue;
if (adj_list.find(&n) == adj_list.end()) {
adj_list[&n] = std::set<mir::Node *>();
}
std::vector<mir::Node *> nodes;
for (auto &var : n.inlinks) {
for (auto &adj_n : var->inlinks) {
PADDLE_ENFORCE(adj_n->IsInstruct());
PADDLE_ENFORCE(adj_n->IsStmt());
nodes.push_back(adj_n);
}
}
......@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
......@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
......@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op);
node_storage_.back().AsStmt(op->op_type_, std::move(kernels), op);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
......@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
bool SSAGraph::CheckLinksRoleSet() {
for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet());
if (!node.IsInstruct()) continue;
if (!node.IsStmt()) continue;
for (auto *x : node.inlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
CHECK_OR_FALSE(x->IsArg());
}
for (auto *x : node.outlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
CHECK_OR_FALSE(x->IsArg());
}
}
return true;
......@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
arguments_[name] = &node_storage_.back();
node_storage_.back().AsArgument(name);
node_storage_.back().AsArg(name);
return &node_storage_.back();
}
......
......@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
arguments_[name]->AsArgument().is_weight = true;
arguments_[name]->AsArg().is_weight = true;
}
}
......@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
// instr -> output argument
if (a->IsInstruct() && b->IsArgument()) {
auto &inst = a->AsInstruct();
auto &output = b->AsArgument();
if (a->IsStmt() && b->IsArg()) {
auto &inst = a->AsStmt();
auto &output = b->AsArg();
if (!output.type) {
output.type = inst.picked_kernel().GetOutputDeclType(arg_name);
......@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
}
// input argument -> instr
if (a->IsArgument() && b->IsInstruct()) {
auto &input = a->AsArgument();
auto &inst = b->AsInstruct();
if (a->IsArg() && b->IsStmt()) {
auto &input = a->AsArg();
auto &inst = b->AsStmt();
if (!input.type) {
input.type = inst.picked_kernel().GetInputDeclType(arg_name);
}
......
......@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK(graph) << "graph not valid";
// sort kernels by the factors.
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& instruct = node.AsInstruct();
if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) {
size_t score = KernelGrade(*kernel);
......
......@@ -33,7 +33,7 @@ namespace mir {
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
*/
class StaticKernelPickPass : public mir::InstructionPass {
class StaticKernelPickPass : public mir::StmtPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
......
......@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK(!valid_places_.empty());
for (auto& node : nodes) {
if (!node->IsInstruct()) continue;
if (!node->IsStmt()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
......@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return;
CHECK(inst_node->IsInstruct());
auto& inst = inst_node->AsInstruct();
CHECK(inst_node->IsStmt());
auto& inst = inst_node->AsStmt();
CHECK(in->IsRoleSet());
CHECK(in->IsArgument());
auto in_arg_name = in->AsArgument().name;
CHECK(in->IsArg());
auto in_arg_name = in->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArgument().type);
if (!TargetCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArgument().name
CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op->DebugString() << " "
<< *in->AsArgument().type << " -> " << *decl_arg_type;
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArgument().type, *decl_arg_type, in->AsArgument().name,
graph, inst_node, valid_places_);
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph,
inst_node, valid_places_);
}
}
......@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node* inst_node, const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node.
// So there will be a new Argument node and a new IoCopy Statement Node.
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
......@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
inst_node->AsStmt().op->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
lite::OpDesc op_desc;
......@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc.SetInput("Input", {var});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope());
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsInstruct("io_copy", std::move(kernels), io_copy_op);
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
// Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto& inst = inst_node->AsInstruct();
auto& inst = inst_node->AsStmt();
auto inst_program_desc = inst.op_info()->desc();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
......@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc();
auto desc_dummy = inst_node->AsStmt().op->op_info()->desc();
UpdateInputTo(&desc_dummy, var, io_copy_output_name);
lite::OpDesc desc_fake(desc_dummy);
inst_node->AsInstruct().op->Attach(desc_fake,
inst_node->AsInstruct().op->scope());
inst_node->AsStmt().op->Attach(desc_fake, inst_node->AsStmt().op->scope());
std::string tmp;
if (inst_node->AsInstruct().op_info()->GetInputArgname("a", &tmp)) {
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp;
}
for (auto& kernel : inst_node->AsInstruct().valid_kernels) {
inst_node->AsInstruct().op->AttachKernel(kernel.get());
for (auto& kernel : inst_node->AsStmt().valid_kernels) {
inst_node->AsStmt().op->AttachKernel(kernel.get());
}
graph->CheckValid();
......
......@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK(!graph->inputs().empty()) << "graph's inputs should be set";
for (const auto& v : graph->inputs()) {
// the feed op might in the inputs
if (v->IsInstruct()) {
LOG(INFO) << "found kernel in inputs " << v->AsInstruct().op_type;
if (v->IsStmt()) {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
continue;
}
......@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsArgument()) {
CHECK(node.AsArgument().type) << "node " << node.AsArgument().name
<< " type not determined, " << &node;
if (node.IsArg()) {
CHECK(node.AsArg().type) << "node " << node.AsArg().name
<< " type not determined, " << &node;
}
}
}
......@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
void InferenceArgumentPlace(SSAGraph* graph) {
VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct();
auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue;
// LOG(INFO) << "- inferencing type " <<
......@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
auto& arg_node = node->AsArg();
if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type;
......@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
auto& arg_node = node->AsArg();
if (!arg_node.type) {
node->AsArgument().type = type;
node->AsArg().type = type;
VLOG(3) << "set type " << *type;
}
}
......
......@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
optimizer.Run(std::move(program), places);
auto runtime_program = optimizer.GenRuntimeProgram();
LOG(INFO) << "num instructions " << runtime_program->num_instructions();
LOG(INFO) << "num statements " << runtime_program->num_instructions();
}
} // namespace lite
......
......@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another.
// is is possible to add a statement to transform a type to another.
virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
template <bool is_unknown, bool is_tensor = true,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册