提交 13858cf6 编写于 作者: W Wilber 提交者: GitHub

for multiple-use args, only cast once. test=develop (#3404)

for multiple-use args, only cast once
上级 30145270
...@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name, ...@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name,
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name); auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name);
auto size = input_or_output_names.size(); auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size()); CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) { if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i]; *scale = input_or_output_scales[i];
return true; return true;
...@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes.push_back(node); nodes.push_back(node);
} }
// record the copied node.
std::unordered_map<std::string, Node*> cast_nodes;
for (auto& node : nodes) { for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in, &cast_nodes);
} }
} }
} }
void PrecisionCastPass::ComplementInputs(SSAGraph* graph, void PrecisionCastPass::ComplementInputs(
Node* inst_node, SSAGraph* graph,
Node* in) { Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes) {
// If this input is out of date. // If this input is out of date.
if (inst_node->inlinks.end() == if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
...@@ -184,16 +189,19 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph, ...@@ -184,16 +189,19 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
in, in,
graph, graph,
inst_node, inst_node,
cast_nodes,
graph->valid_places()); graph->valid_places());
} }
} }
void PrecisionCastPass::AddCastInst(const Type& from, void PrecisionCastPass::AddCastInst(
const Type& to, const Type& from,
Node* in, const Type& to,
SSAGraph* graph, Node* in,
Node* inst_node, SSAGraph* graph,
const std::vector<Place>& valid_places) { Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
...@@ -203,66 +211,80 @@ void PrecisionCastPass::AddCastInst(const Type& from, ...@@ -203,66 +211,80 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name = in->AsArg().name + "/precision_trans"; auto cast_op_output_name = in->AsArg().name + "/precision_trans";
// in->AsArg().name + "/precision_trans/" + // in->AsArg().name + "/precision_trans/" +
// paddle::lite::to_string(node_id()); // paddle::lite::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name); if (cast_nodes->count(in->AsArg().name)) {
cast_op_output_arg->AsArg().type = // Remove the old link
LiteType::GetTensorTy(from.target(), to.precision(), from.layout()); RemoveDirectedLink(in, inst_node);
auto* cast_inst = graph->NewInstructNode(); // Update the original instruction OpDesc.
// Update its input to the cast_op_output_name
// Add new link, newarg->inst
DirectedLink(cast_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
} else {
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels. // create Op and kernels.
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist; bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string cast_type = in_persist ? "calib_once" : "calib"; std::string cast_type = in_persist ? "calib_once" : "calib";
cast_op_output_arg->AsArg().is_persist = in_persist; cast_op_output_arg->AsArg().is_persist = in_persist;
auto cast_op = LiteOpRegistry::Global().Create(cast_type); auto cast_op = LiteOpRegistry::Global().Create(cast_type);
CHECK(cast_op) << "create op [" << cast_op << "] failed"; CHECK(cast_op) << "create op [" << cast_op << "] failed";
// Create the new var manually. // Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(cast_op_output_name); inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
// Create Calib Instruction. // Create Calib Instruction.
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
op_desc.SetType(cast_type); op_desc.SetType(cast_type);
op_desc.SetInput("Input", {in->AsArg().name}); op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name}); op_desc.SetOutput("Out", {cast_op_output_name});
float scale; float scale;
if (InferScale(in, inst_node, &scale)) { if (InferScale(in, inst_node, &scale)) {
op_desc.SetAttr("scale", scale); op_desc.SetAttr("scale", scale);
} }
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places); auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels; std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false; bool is_found = false;
for (auto& kernel : kernels) { for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input"); const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) && if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->precision() == to.precision()) { out_arg_ty->precision() == to.precision()) {
is_found = true; is_found = true;
selected_kernels.emplace_back(std::move(kernel)); selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel // we pick the kernel
cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op); cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
break; (*cast_nodes)[in->AsArg().name] = cast_op_output_arg;
break;
}
} }
}
CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":" CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
<< in->AsArg().name << "->" << to << ":" << in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type(); << inst_node->AsStmt().op_info()->Type();
// Remove the old link // Remove the old link
RemoveDirectedLink(in, inst_node); RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc. // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, cast_inst); DirectedLink(in, cast_inst);
DirectedLink(cast_inst, cast_op_output_arg); DirectedLink(cast_inst, cast_op_output_arg);
DirectedLink(cast_op_output_arg, inst_node); DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information // reset opdesc and update kernel information
UpdateInputs( UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name); inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
}
// recreate the op // recreate the op
auto original_selected_kernel = auto original_selected_kernel =
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass { ...@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass {
public: public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override; void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes);
void AddCastInst(const Type& from, void AddCastInst(const Type& from,
const Type& to, const Type& to,
Node* in, Node* in,
SSAGraph* graph, SSAGraph* graph,
Node* inst_node, Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places); const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places); void SetValidPlaces(const std::vector<Place>& valid_places);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册