未验证 提交 eb4f381c 编写于 作者: W Wilber 提交者: GitHub

For args consumed by multiple ops, insert the precision cast only once. test=develop (#3404)

For args consumed by multiple ops, insert the precision cast only once.
上级 fbe0799e
...@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name, ...@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name,
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name); auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name);
auto size = input_or_output_names.size(); auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size()); CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) { if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i]; *scale = input_or_output_scales[i];
return true; return true;
...@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes.push_back(node); nodes.push_back(node);
} }
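// Maps an input arg's name to the cast output node already created for it, so later consumers of the same arg reuse that node instead of inserting another cast.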
std::unordered_map<std::string, Node*> cast_nodes;
for (auto& node : nodes) { for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in, &cast_nodes);
} }
} }
} }
void PrecisionCastPass::ComplementInputs(SSAGraph* graph, void PrecisionCastPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node, Node* inst_node,
Node* in) { Node* in,
std::unordered_map<std::string, Node*>* cast_nodes) {
// If this input is out of date. // If this input is out of date.
if (inst_node->inlinks.end() == if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
...@@ -184,15 +189,18 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph, ...@@ -184,15 +189,18 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
in, in,
graph, graph,
inst_node, inst_node,
cast_nodes,
graph->valid_places()); graph->valid_places());
} }
} }
void PrecisionCastPass::AddCastInst(const Type& from, void PrecisionCastPass::AddCastInst(
const Type& from,
const Type& to, const Type& to,
Node* in, Node* in,
SSAGraph* graph, SSAGraph* graph,
Node* inst_node, Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
...@@ -203,6 +211,18 @@ void PrecisionCastPass::AddCastInst(const Type& from, ...@@ -203,6 +211,18 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name = in->AsArg().name + "/precision_trans"; auto cast_op_output_name = in->AsArg().name + "/precision_trans";
// in->AsArg().name + "/precision_trans/" + // in->AsArg().name + "/precision_trans/" +
// paddle::lite::to_string(node_id()); // paddle::lite::to_string(node_id());
if (cast_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the cast_op_output_name
// Add new link, newarg->inst
DirectedLink(cast_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
} else {
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name); auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type = cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout()); LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
...@@ -241,6 +261,7 @@ void PrecisionCastPass::AddCastInst(const Type& from, ...@@ -241,6 +261,7 @@ void PrecisionCastPass::AddCastInst(const Type& from,
selected_kernels.emplace_back(std::move(kernel)); selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel // we pick the kernel
cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op); cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
(*cast_nodes)[in->AsArg().name] = cast_op_output_arg;
break; break;
} }
} }
...@@ -263,6 +284,7 @@ void PrecisionCastPass::AddCastInst(const Type& from, ...@@ -263,6 +284,7 @@ void PrecisionCastPass::AddCastInst(const Type& from,
// reset opdesc and update kernel information // reset opdesc and update kernel information
UpdateInputs( UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name); inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
}
// recreate the op // recreate the op
auto original_selected_kernel = auto original_selected_kernel =
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass { ...@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass {
public: public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override; void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes);
void AddCastInst(const Type& from, void AddCastInst(const Type& from,
const Type& to, const Type& to,
Node* in, Node* in,
SSAGraph* graph, SSAGraph* graph,
Node* inst_node, Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places); const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places); void SetValidPlaces(const std::vector<Place>& valid_places);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册