未验证 提交 eb4f381c 编写于 作者: W Wilber 提交者: GitHub

for multiple-use args, only cast once. test=develop (#3404)

for multiple-use args, only cast once
上级 fbe0799e
......@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name,
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name);
auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i];
return true;
......@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes.push_back(node);
}
// record the copied node.
std::unordered_map<std::string, Node*> cast_nodes;
for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
ComplementInputs(graph.get(), node, in, &cast_nodes);
}
}
}
void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
void PrecisionCastPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node,
Node* in) {
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
......@@ -184,15 +189,18 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
in,
graph,
inst_node,
cast_nodes,
graph->valid_places());
}
}
void PrecisionCastPass::AddCastInst(const Type& from,
void PrecisionCastPass::AddCastInst(
const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
......@@ -203,6 +211,18 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name = in->AsArg().name + "/precision_trans";
// in->AsArg().name + "/precision_trans/" +
// paddle::lite::to_string(node_id());
if (cast_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the cast_op_output_name
// Add new link, newarg->inst
DirectedLink(cast_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
} else {
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
......@@ -241,6 +261,7 @@ void PrecisionCastPass::AddCastInst(const Type& from,
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
(*cast_nodes)[in->AsArg().name] = cast_op_output_arg;
break;
}
}
......@@ -263,6 +284,7 @@ void PrecisionCastPass::AddCastInst(const Type& from,
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
}
// recreate the op
auto original_selected_kernel =
......
......@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
......@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes);
void AddCastInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册