Unverified commit 8903c795, authored by Wilber, committed by GitHub

Fix the type_target_cast pass: copy a multiply-used argument only once. test=develop (#2572)

For a parameter consumed by multiple instructions, perform the io_copy only once. A minimal sketch of the caching pattern follows this header, and a before/after view of the resulting graph rewrite follows the full diff.
Parent 7ef0e7fe
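The heart of the change is a per-argument cache: the first consumer of a target-mismatched argument triggers creation of an io_copy op plus its output node, and that output node is recorded in copied_nodes under the argument's name; every later consumer of the same argument is rewired to the cached node instead of receiving a second copy. Below is a minimal standalone sketch of that caching pattern, assuming a hypothetical Node type and a hypothetical GetOrCreateCopyNode helper (not the Paddle-Lite classes):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a graph argument node.
struct Node {
  std::string name;
};

// Returns the io_copy output node for `arg_name`, creating it only on the
// first call; later calls for the same argument reuse the cached node.
Node* GetOrCreateCopyNode(
    const std::string& arg_name,
    std::unordered_map<std::string, std::unique_ptr<Node>>* copied_nodes) {
  auto it = copied_nodes->find(arg_name);
  if (it != copied_nodes->end()) return it->second.get();
  auto node = std::unique_ptr<Node>(new Node{arg_name + "/target_trans"});
  Node* raw = node.get();
  (*copied_nodes)[arg_name] = std::move(node);
  return raw;
}

int main() {
  std::unordered_map<std::string, std::unique_ptr<Node>> copied_nodes;
  // Two consumers of the same argument "x" share a single copy node.
  Node* first = GetOrCreateCopyNode("x", &copied_nodes);
  Node* second = GetOrCreateCopyNode("x", &copied_nodes);
  std::cout << first->name << " reused: " << (first == second) << std::endl;
  return 0;
}

In the real pass the map stores raw Node* owned by the SSAGraph; the unique_ptr ownership above only keeps the sketch self-contained.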
@@ -16,6 +16,7 @@
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
@@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
CHECK(!valid_places_.empty());
// record the copied node.
std::unordered_map<std::string, Node*> copied_nodes;
for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
ComplementInputs(graph.get(), node, in, &copied_nodes);
}
}
}
void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in) {
void TypeTargetTransformPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* copied_nodes) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
@@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with the target device.
AddIoCopyInst(
*in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_);
AddIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
}
}
@@ -78,128 +89,132 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_places should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg());
// auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
// string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
// TODO(MyPandaShaoxiang): should this be set to the same place as the input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for the io_copy_output_arg node: the target should equal
// to.target(); the precision and layout should equal from.precision() and
// from.layout().
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
io_copy_output_arg->AsArg().is_persist = in_persist;
// create Op and kernels.
auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
CHECK(io_copy_op) << "create op [" << io_copy_type << "] failed";
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
cpp::OpDesc op_desc;
op_desc.SetType(io_copy_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
// fix(MyPandaShaoxiang): select a kernel whose input decl type matches in.type
bool is_found = false;
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
VLOG(4) << "------ kernel info -------";
VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
VLOG(4) << "from(last kernel output):" << from;
VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
VLOG(4) << "to:" << to << "\n";
// kernel selection branch for the opencl backend:
// check whether inst's target is kOpenCL.
// Note: to == *decl_arg_type == input of inst, not the output of the last
// inst.
// Skip the [layout check] between [to] and [from]: none of the original
// opencl insts in the model use the default NCHW layout, so the layout
// check is not applicable.
// Detailed node info:
// [*in->AsArg().type] -> [from]: output of inst's previous kernel
// [*decl_arg_type] -> [to]: input of inst, not the output of the last inst
// [in_arg_ty]: input of io_copy
// [out_arg_ty]: output of io_copy
//
// note: the LITE_WITH_OPENCL macro is replaced by checking the input and
// output targets of io_copy.
if ((in_arg_ty->target() == TARGET(kOpenCL) ||
out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first
(TargetCompatibleTo(*in_arg_ty, from) &&
PrecisionCompatibleTo(*in_arg_ty, from) &&
DeviceCompatibleTo(*in_arg_ty, from) &&
TargetCompatibleTo(*out_arg_ty, to))) {
VLOG(4) << "picked, opencl found";
is_found = true;
} else if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->target() == to.target()) {
VLOG(4) << "picked";
is_found = true;
}
if (is_found) {
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
io_copy_inst->AsStmt(
io_copy_type, std::move(selected_kernels), io_copy_op);
break;
if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, newarg->inst
DirectedLink(copied_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
UpdateInstNode(in, graph, inst_node, io_copy_output_name);
} else {
// TODO(MyPandaShaoxiang): should this be set to the same place as the input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for the io_copy_output_arg node: the target should equal
// to.target(); the precision and layout should equal from.precision() and
// from.layout().
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode();
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
io_copy_output_arg->AsArg().is_persist = in_persist;
// create Op and kernels.
auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
CHECK(io_copy_op) << "create op [" << io_copy_type << "] failed";
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
cpp::OpDesc op_desc;
op_desc.SetType(io_copy_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
// fix(MyPandaShaoxiang): select a kernel whose input decl type matches in.type
bool is_found = false;
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
VLOG(4) << "------ kernel info -------";
VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
VLOG(4) << "from(last kernel output):" << from;
VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
VLOG(4) << "to:" << to << "\n";
// kernel selection branch for the opencl backend:
// check whether inst's target is kOpenCL.
// Note: to == *decl_arg_type == input of inst, not the output of the last
// inst.
// Skip the [layout check] between [to] and [from]: none of the original
// opencl insts in the model use the default NCHW layout, so the layout
// check is not applicable.
// Detailed node info:
// [*in->AsArg().type] -> [from]: output of inst's previous kernel
// [*decl_arg_type] -> [to]: input of inst, not the output of the last inst
// [in_arg_ty]: input of io_copy
// [out_arg_ty]: output of io_copy
//
// note: the LITE_WITH_OPENCL macro is replaced by checking the input and
// output targets of io_copy.
if ((in_arg_ty->target() == TARGET(kOpenCL) ||
out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first
(TargetCompatibleTo(*in_arg_ty, from) &&
PrecisionCompatibleTo(*in_arg_ty, from) &&
DeviceCompatibleTo(*in_arg_ty, from) &&
TargetCompatibleTo(*out_arg_ty, to))) {
VLOG(4) << "picked, opencl found";
is_found = true;
} else if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->target() == to.target()) {
VLOG(4) << "picked";
is_found = true;
}
if (is_found) {
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
io_copy_inst->AsStmt(
io_copy_type, std::move(selected_kernels), io_copy_op);
(*copied_nodes)[in->AsArg().name] = io_copy_output_arg;
break;
}
VLOG(4) << "not picked";
}
VLOG(4) << "not picked";
}
CHECK(is_found) << "Can't find an io_copy kernel for io_copy op: " << from
<< ":" << in->AsArg().name << " -> " << to << ":"
<< inst_node->AsStmt().op_info()->Type();
// Remove the old link
RemoveDirectedLink(in, inst_node);
CHECK(is_found) << "Can't find an io_copy kernel for io_copy op: " << from
<< ":" << in->AsArg().name << " -> " << to << ":"
<< inst_node->AsStmt().op_info()->Type();
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel]
DirectedLink(
io_copy_inst,
io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output
DirectedLink(io_copy_output_arg,
inst_node); // [io_copy kernel]'s output -> [current kernel]
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in,
io_copy_inst); // [last kernel]'s output -> [io_copy kernel]
DirectedLink(
io_copy_inst,
io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output
DirectedLink(io_copy_output_arg,
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() changes the Stmt's op_info_ value; afterwards the old
// op_info_ value is nullified. So we can't pass
// `*inst_node->AsStmt().op_info()` into ResetOp() directly;
// `update_op_info` is a copy of `*inst_node->AsStmt().op_info()`.
// Whenever we update the op_info of a stmt, we should call its ResetOp().
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
UpdateInstNode(in, graph, inst_node, io_copy_output_name);
}
std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
@@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces(
valid_places_ = valid_places;
}
void TypeTargetTransformPass::UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name) {
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() changes the Stmt's op_info_ value; afterwards the old
// op_info_ value is nullified. So we can't pass
// `*inst_node->AsStmt().op_info()` into ResetOp() directly;
// `update_op_info` is a copy of `*inst_node->AsStmt().op_info()`.
// Whenever we update the op_info of a stmt, we should call its ResetOp().
// (A standalone miniature of this aliasing hazard follows this file.)
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
}
} // namespace mir
} // namespace lite
} // namespace paddle
......
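One subtlety in the file above deserves a standalone illustration: the comment around ResetOp() warns that ResetOp() nullifies the Stmt's old op_info_, so the caller must take a copy (update_op_info) before the call rather than pass a reference to the live value. Here is a miniature of that aliasing hazard, assuming a hypothetical Stmt that is not the Paddle-Lite class:

#include <iostream>
#include <string>

// Hypothetical miniature of the Stmt::ResetOp() aliasing hazard: the callee
// first nullifies its stored value, so a parameter that aliases that value
// would be read only after it has been cleared.
struct Stmt {
  std::string op_info_;
  void ResetOp(const std::string& info) {
    op_info_.clear();  // the old op_info_ value is nullified here...
    op_info_ += info;  // ...so an aliasing `info` now contributes nothing
  }
};

int main() {
  Stmt stmt{"conv2d"};
  // Unsafe: stmt.ResetOp(stmt.op_info_) would leave op_info_ empty.
  // Safe, mirroring the pass: copy first, then reset with the copy.
  std::string update_op_info = stmt.op_info_;
  stmt.ResetOp(update_op_info);
  std::cout << stmt.op_info_ << std::endl;  // prints "conv2d"
  return 0;
}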
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
@@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* copied_nodes);
void AddIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
@@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass {
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
void UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name);
std::vector<Place> valid_places_;
};
......
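In summary, for a graph where argument x feeds two instructions, the commit changes the rewrite as follows (a hand sketch derived from the diff, not part of the commit):

before, each consumer received its own copy, duplicating the transfer:
  x -> io_copy -> x/target_trans -> inst1
  x -> io_copy -> x/target_trans -> inst2

after, the first consumer creates the copy and the second is rewired to the cached x/target_trans node via copied_nodes:
  x -> io_copy -> x/target_trans -> inst1
                  x/target_trans -> inst2

For a weight or persistent tensor, the diff selects io_copy_once instead of io_copy, whose name suggests the transfer itself also executes a single time rather than on every run.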