From dde80670e830236b10c7fb8a030a4850dc7e5d80 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 10 Dec 2019 14:10:03 +0800 Subject: [PATCH] fix type_target_cast pass. support only copy once for multiple use arg. test=develop (#2572) For multiple-use parameters, only copy once --- lite/core/mir/type_target_cast_pass.cc | 263 ++++++++++++++----------- lite/core/mir/type_target_cast_pass.h | 12 +- 2 files changed, 161 insertions(+), 114 deletions(-) diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index b008faa687..ae74bd8d4d 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include "lite/core/mir/graph_visualize_pass.h" @@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr& graph) { CHECK(!valid_places_.empty()); + // record the copied node. + std::unordered_map copied_nodes; + for (auto& node : nodes) { if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; auto inlinks = node->inlinks; for (auto* in : inlinks) { - ComplementInputs(graph.get(), node, in); + ComplementInputs(graph.get(), node, in, &copied_nodes); } } } -void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, - Node* inst_node, - Node* in) { +void TypeTargetTransformPass::ComplementInputs( + SSAGraph* graph, + Node* inst_node, + Node* in, + std::unordered_map* copied_nodes) { // If this input is out of date. if (inst_node->inlinks.end() == std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) @@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, << " for kernel " << inst.op()->DebugString() << " " << *in->AsArg().type << " -> " << *decl_arg_type; // Add an IoCopy instruction to make the input compatible with other dist. - AddIoCopyInst( - *in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_); + AddIoCopyInst(*in->AsArg().type, + *decl_arg_type, + in, + graph, + inst_node, + copied_nodes, + valid_places_); } } @@ -78,128 +89,132 @@ void TypeTargetTransformPass::AddIoCopyInst( Node* in, SSAGraph* graph, Node* inst_node, + std::unordered_map* copied_nodes, const std::vector& valid_places) { CHECK(!valid_places.empty()) << "valid_place should be set"; // var -> new_transform_op -> new_var -> inst // So there will be a new Argument node and a new IoCopy Statement Node. CHECK(in->IsArg()); + // auto node_id = [&] { return graph->nodes().size(); }; auto io_copy_output_name = string_format("%s/target_trans", in->AsArg().name.c_str()); // string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); - // TODO(MyPandaShaoxiang) should set same place with input? - auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); - // Set the place for io_copy_output_arg node, the target should be equal to - // to.target() - // The precision and layout should be equal to from.precision(), from.layout() - io_copy_output_arg->AsArg().type = - LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); - auto* io_copy_inst = graph->NewInstructNode(); - - bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; - std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy"; - io_copy_output_arg->AsArg().is_persist = in_persist; - // create Op and kernels. - auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type); - CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; - // CHECK(io_copy_op); - // Create the new var manually. - inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); - - // Create IoCopy Instruction. - cpp::OpDesc op_desc; - op_desc.SetType(io_copy_type); - op_desc.SetInput("Input", {in->AsArg().name}); - op_desc.SetOutput("Out", {io_copy_output_name}); - - io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); - auto kernels = io_copy_op->CreateKernels(valid_places); - // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type - bool is_found = false; - std::vector> selected_kernels; - for (auto& kernel : kernels) { - const Type* in_arg_ty = kernel->GetInputDeclType("Input"); - const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); - - VLOG(4) << "------ kernel info -------"; - VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty; - VLOG(4) << "from(last kernel output):" << from; - VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty; - VLOG(4) << "to:" << to << "\n"; - - // kernel choose branch for opencl backend - // judge inst's target whether is kOpenCL - // Note: to == *decl_arg_type == in of inst, not output of last inst - // ignore [layout check] for layout between [to] and [from] - // Because all of origin opencl insts in model, are not default layout - // NCHW, - // so skip layout check. - // detailed node info see below: - // [*in->AsArg().type] -> [from]: out of inst's previous kernel - // [*decl_arg_type] -> [to]: input of inst, not output of last - // [in_arg_ty]: in of io_copy - // [out_arg_ty]: out of io_copy - // - // noto: replace LITE_WITH_OPENCL macro with judge input and output target - // of io_copy - if ((in_arg_ty->target() == TARGET(kOpenCL) || - out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first - (TargetCompatibleTo(*in_arg_ty, from) && - PrecisionCompatibleTo(*in_arg_ty, from) && - DeviceCompatibleTo(*in_arg_ty, from) && - TargetCompatibleTo(*out_arg_ty, to))) { - VLOG(4) << "picked, opencl found"; - is_found = true; - } else if (TypeCompatible(*in_arg_ty, from) && - out_arg_ty->target() == to.target()) { - VLOG(4) << "picked"; - is_found = true; - } - if (is_found) { - selected_kernels.emplace_back(std::move(kernel)); - // we pick the kernel - io_copy_inst->AsStmt( - io_copy_type, std::move(selected_kernels), io_copy_op); - break; + if (copied_nodes->count(in->AsArg().name)) { + // Remove the old link + RemoveDirectedLink(in, inst_node); + + // Update the original instruction OpDesc. + // Update its input to the io_copy_output_name + // Add new link, newarg->inst + DirectedLink(copied_nodes->at(in->AsArg().name), + inst_node); // [io_copy kernel]'s output -> [current kernel] + + UpdateInstNode(in, graph, inst_node, io_copy_output_name); + } else { + // TODO(MyPandaShaoxiang) should set same place with input? + auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); + // Set the place for io_copy_output_arg node, the target should be equal to + // to.target() + // The precision and layout should be equal to from.precision(), + // from.layout() + io_copy_output_arg->AsArg().type = + LiteType::GetTensorTy(to.target(), from.precision(), from.layout()); + auto* io_copy_inst = graph->NewInstructNode(); + + bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; + std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy"; + io_copy_output_arg->AsArg().is_persist = in_persist; + // create Op and kernels. + auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type); + CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; + // CHECK(io_copy_op); + // Create the new var manually. + inst_node->AsStmt().op()->scope()->Var(io_copy_output_name); + + // Create IoCopy Instruction. + cpp::OpDesc op_desc; + op_desc.SetType(io_copy_type); + op_desc.SetInput("Input", {in->AsArg().name}); + op_desc.SetOutput("Out", {io_copy_output_name}); + + io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); + auto kernels = io_copy_op->CreateKernels(valid_places); + // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type + bool is_found = false; + std::vector> selected_kernels; + for (auto& kernel : kernels) { + const Type* in_arg_ty = kernel->GetInputDeclType("Input"); + const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); + + VLOG(4) << "------ kernel info -------"; + VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty; + VLOG(4) << "from(last kernel output):" << from; + VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty; + VLOG(4) << "to:" << to << "\n"; + + // kernel choose branch for opencl backend + // judge inst's target whether is kOpenCL + // Note: to == *decl_arg_type == in of inst, not output of last inst + // ignore [layout check] for layout between [to] and [from] + // Because all of origin opencl insts in model, are not default layout + // NCHW, + // so skip layout check. + // detailed node info see below: + // [*in->AsArg().type] -> [from]: out of inst's previous kernel + // [*decl_arg_type] -> [to]: input of inst, not output of last + // [in_arg_ty]: in of io_copy + // [out_arg_ty]: out of io_copy + // + // noto: replace LITE_WITH_OPENCL macro with judge input and output target + // of io_copy + if ((in_arg_ty->target() == TARGET(kOpenCL) || + out_arg_ty->target() == TARGET(kOpenCL)) && // judge OpenCL first + (TargetCompatibleTo(*in_arg_ty, from) && + PrecisionCompatibleTo(*in_arg_ty, from) && + DeviceCompatibleTo(*in_arg_ty, from) && + TargetCompatibleTo(*out_arg_ty, to))) { + VLOG(4) << "picked, opencl found"; + is_found = true; + } else if (TypeCompatible(*in_arg_ty, from) && + out_arg_ty->target() == to.target()) { + VLOG(4) << "picked"; + is_found = true; + } + + if (is_found) { + selected_kernels.emplace_back(std::move(kernel)); + // we pick the kernel + io_copy_inst->AsStmt( + io_copy_type, std::move(selected_kernels), io_copy_op); + (*copied_nodes)[in->AsArg().name] = io_copy_output_arg; + break; + } + + VLOG(4) << "not picked"; } - VLOG(4) << "not picked"; - } + CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from + << ":" << in->AsArg().name << " -> " << to << ":" + << inst_node->AsStmt().op_info()->Type(); + // Remove the old link + RemoveDirectedLink(in, inst_node); - CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from - << ":" << in->AsArg().name << " -> " << to << ":" - << inst_node->AsStmt().op_info()->Type(); - // Remove the old link - RemoveDirectedLink(in, inst_node); - - // Update the original instruction OpDesc. - // Update its input to the io_copy_output_name - // Add new link, var -> new_inst, new_inst->newarg, newarg->inst - DirectedLink(in, io_copy_inst); // [last kernel]'s output -> [io_copy kernel] - DirectedLink( - io_copy_inst, - io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output - DirectedLink(io_copy_output_arg, - inst_node); // [io_copy kernel]'s output -> [current kernel] + // Update the original instruction OpDesc. + // Update its input to the io_copy_output_name + // Add new link, var -> new_inst, new_inst->newarg, newarg->inst + DirectedLink(in, + io_copy_inst); // [last kernel]'s output -> [io_copy kernel] + DirectedLink( + io_copy_inst, + io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output + DirectedLink(io_copy_output_arg, + inst_node); // [io_copy kernel]'s output -> [current kernel] - // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - io_copy_output_name); - auto original_selected_kernel = - std::move(inst_node->AsStmt().kernels().front()); - auto update_op_info = *inst_node->AsStmt().op_info(); - // ResetOp() will change the Stmt op_info_ value, - // after that the old op_info_ value will be nullified. - // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. - // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). - // Whenever update the op_info of a stmt, we should call its ResetOp(). - inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); - inst_node->AsStmt().kernels().clear(); - inst_node->AsStmt().kernels().emplace_back( - std::move(original_selected_kernel)); + UpdateInstNode(in, graph, inst_node, io_copy_output_name); + } std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { @@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces( valid_places_ = valid_places; } +void TypeTargetTransformPass::UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name) { + // reset opdesc and update kernel information + UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), + in->AsArg().name, + io_copy_output_name); + auto original_selected_kernel = + std::move(inst_node->AsStmt().kernels().front()); + auto update_op_info = *inst_node->AsStmt().op_info(); + // ResetOp() will change the Stmt op_info_ value, + // after that the old op_info_ value will be nullified. + // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. + // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). + // Whenever update the op_info of a stmt, we should call its ResetOp(). + inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); + inst_node->AsStmt().kernels().clear(); + inst_node->AsStmt().kernels().emplace_back( + std::move(original_selected_kernel)); +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/lite/core/mir/type_target_cast_pass.h b/lite/core/mir/type_target_cast_pass.h index 8a8cfaf9f9..e9a275882f 100644 --- a/lite/core/mir/type_target_cast_pass.h +++ b/lite/core/mir/type_target_cast_pass.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "lite/core/mir/pass.h" #include "lite/core/op_registry.h" @@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override; - void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); + void ComplementInputs(SSAGraph* graph, + Node* inst_node, + Node* in, + std::unordered_map* copied_nodes); void AddIoCopyInst(const Type& from, const Type& to, Node* in, SSAGraph* graph, Node* inst_node, + std::unordered_map* copied_nodes, const std::vector& valid_places); void SetValidPlaces(const std::vector& valid_places); @@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass { const std::vector& valid_places() const { return valid_places_; } private: + void UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name); + std::vector valid_places_; }; -- GitLab