提交 dde80670 编写于 作者: W Wilber 提交者: GitHub

Fix type_target_cast pass: an argument that is used multiple times is now copied only once. test=develop (#2572)

For multiple-use parameters, only copy once
上级 8b84b728
......@@ -16,6 +16,7 @@
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
......@@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
CHECK(!valid_places_.empty());
// record the copied node.
std::unordered_map<std::string, Node*> copied_nodes;
for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
ComplementInputs(graph.get(), node, in, &copied_nodes);
}
}
}
void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
void TypeTargetTransformPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node,
Node* in) {
Node* in,
std::unordered_map<std::string, Node*>* copied_nodes) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
......@@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(
*in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_);
AddIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
}
}
......@@ -78,21 +89,37 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg());
// auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
// string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, newarg->inst
DirectedLink(copied_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
UpdateInstNode(in, graph, inst_node, io_copy_output_name);
} else {
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(), from.layout()
// The precision and layout should be equal to from.precision(),
// from.layout()
io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode();
......@@ -162,6 +189,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// we pick the kernel
io_copy_inst->AsStmt(
io_copy_type, std::move(selected_kernels), io_copy_op);
(*copied_nodes)[in->AsArg().name] = io_copy_output_arg;
break;
}
......@@ -177,29 +205,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel]
DirectedLink(in,
io_copy_inst); // [last kernel]'s output -> [io_copy kernel]
DirectedLink(
io_copy_inst,
io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output
DirectedLink(io_copy_output_arg,
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() will change the Stmt op_info_ value,
// after that the old op_info_ value will be nullified.
// So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp.
// `update_op_info` is the copy of `*inst_node->AsStmt().op_info().
// Whenever update the op_info of a stmt, we should call its ResetOp().
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
UpdateInstNode(in, graph, inst_node, io_copy_output_name);
}
std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
......@@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces(
valid_places_ = valid_places;
}
// Points `inst_node` at the io_copy output instead of the original input.
//
// Called by AddIoCopyInst both when a new io_copy is created and when an
// already-copied argument is reused. It only touches the op description and
// the kernel list of `inst_node`; graph links are handled by the caller.
//
//   in                   - the original input argument node; its name is the
//                          input name that gets replaced.
//   graph                - supplies the valid places handed to ResetOp().
//   inst_node            - the statement node whose op info is rewritten.
//   io_copy_output_name  - name of the io_copy output argument that takes the
//                          place of `in` in the op description.
void TypeTargetTransformPass::UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name) {
// Rename the input in the op description: in->AsArg().name ->
// io_copy_output_name.
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
// Take the kernel at the front of the list out before the op is reset so it
// can be restored afterwards.
// NOTE(review): front() is called without an emptiness check, so this assumes
// the statement already has at least one kernel - confirm against callers.
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
// Copy the op info by value; see the note below for why a copy is required.
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() replaces the Stmt's op_info_ value, and the old op_info_ value is
// nullified afterwards. Therefore `*inst_node->AsStmt().op_info()` must not be
// passed to ResetOp() directly; `update_op_info` is a copy of
// `*inst_node->AsStmt().op_info()` made for that purpose.
// Whenever the op_info of a stmt is updated, its ResetOp() should be called.
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
// Discard whatever kernels the list holds after ResetOp() and put back only
// the kernel saved above, so the statement ends up with exactly that one
// kernel.
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
}
} // namespace mir
} // namespace lite
} // namespace paddle
......
......@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
......@@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* copied_nodes);
void AddIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
......@@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass {
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
void UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name);
std::vector<Place> valid_places_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册