提交 dde80670 编写于 作者: W Wilber 提交者: GitHub

fix type_target_cast pass. support only copy once for multiple use arg. test=develop (#2572)

For multiple-use parameters, only copy once
上级 8b84b728
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/graph_visualize_pass.h"
...@@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
CHECK(!valid_places_.empty()); CHECK(!valid_places_.empty());
// record the copied node.
std::unordered_map<std::string, Node*> copied_nodes;
for (auto& node : nodes) { for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in, &copied_nodes);
} }
} }
} }
void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, void TypeTargetTransformPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node, Node* inst_node,
Node* in) { Node* in,
std::unordered_map<std::string, Node*>* copied_nodes) {
// If this input is out of date. // If this input is out of date.
if (inst_node->inlinks.end() == if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
...@@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, ...@@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
<< " for kernel " << inst.op()->DebugString() << " " << " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist. // Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst( AddIoCopyInst(*in->AsArg().type,
*in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_); *decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
} }
} }
...@@ -78,21 +89,37 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -78,21 +89,37 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node* in, Node* in,
SSAGraph* graph, SSAGraph* graph,
Node* inst_node, Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node. // So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg()); CHECK(in->IsArg());
// auto node_id = [&] { return graph->nodes().size(); }; // auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str()); string_format("%s/target_trans", in->AsArg().name.c_str());
// string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); // string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, newarg->inst
DirectedLink(copied_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
UpdateInstNode(in, graph, inst_node, io_copy_output_name);
} else {
// TODO(MyPandaShaoxiang) should set same place with input? // TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to // Set the place for io_copy_output_arg node, the target should be equal to
// to.target() // to.target()
// The precision and layout should be equal to from.precision(), from.layout() // The precision and layout should be equal to from.precision(),
// from.layout()
io_copy_output_arg->AsArg().type = io_copy_output_arg->AsArg().type =
LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
auto* io_copy_inst = graph->NewInstructNode(); auto* io_copy_inst = graph->NewInstructNode();
...@@ -162,6 +189,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -162,6 +189,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// we pick the kernel // we pick the kernel
io_copy_inst->AsStmt( io_copy_inst->AsStmt(
io_copy_type, std::move(selected_kernels), io_copy_op); io_copy_type, std::move(selected_kernels), io_copy_op);
(*copied_nodes)[in->AsArg().name] = io_copy_output_arg;
break; break;
} }
...@@ -177,29 +205,16 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -177,29 +205,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Update the original instruction OpDesc. // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel] DirectedLink(in,
io_copy_inst); // [last kernel]'s output -> [io_copy kernel]
DirectedLink( DirectedLink(
io_copy_inst, io_copy_inst,
io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output
DirectedLink(io_copy_output_arg, DirectedLink(io_copy_output_arg,
inst_node); // [io_copy kernel]'s output -> [current kernel] inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information UpdateInstNode(in, graph, inst_node, io_copy_output_name);
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), }
in->AsArg().name,
io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() will change the Stmt op_info_ value,
// after that the old op_info_ value will be nullified.
// So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp.
// `update_op_info` is the copy of `*inst_node->AsStmt().op_info().
// Whenever update the op_info of a stmt, we should call its ResetOp().
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
std::string tmp; std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
...@@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces( ...@@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces(
valid_places_ = valid_places; valid_places_ = valid_places;
} }
void TypeTargetTransformPass::UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name) {
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
// ResetOp() will change the Stmt op_info_ value,
// after that the old op_info_ value will be nullified.
// So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp.
// `update_op_info` is the copy of `*inst_node->AsStmt().op_info().
// Whenever update the op_info of a stmt, we should call its ResetOp().
inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
inst_node->AsStmt().kernels().clear();
inst_node->AsStmt().kernels().emplace_back(
std::move(original_selected_kernel));
}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass { ...@@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass {
public: public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override; void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* copied_nodes);
void AddIoCopyInst(const Type& from, void AddIoCopyInst(const Type& from,
const Type& to, const Type& to,
Node* in, Node* in,
SSAGraph* graph, SSAGraph* graph,
Node* inst_node, Node* inst_node,
std::unordered_map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places); const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places); void SetValidPlaces(const std::vector<Place>& valid_places);
...@@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass { ...@@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass {
const std::vector<Place>& valid_places() const { return valid_places_; } const std::vector<Place>& valid_places() const { return valid_places_; }
private: private:
void UpdateInstNode(Node* in,
SSAGraph* graph,
Node* inst_node,
std::string io_copy_output_name);
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册