未验证 提交 eb4f381c 编写于 作者: W Wilber 提交者: GitHub

for multiple-use args, only cast once. test=develop (#3404)

for multiple-use args, only cast once
上级 fbe0799e
......@@ -80,7 +80,7 @@ static bool InferScaleFromSubgraph(std::string var_name,
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(attr_name);
auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i];
return true;
......@@ -137,18 +137,23 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
nodes.push_back(node);
}
// record the copied node.
std::unordered_map<std::string, Node*> cast_nodes;
for (auto& node : nodes) {
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
ComplementInputs(graph.get(), node, in, &cast_nodes);
}
}
}
void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in) {
void PrecisionCastPass::ComplementInputs(
SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
......@@ -184,16 +189,19 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
in,
graph,
inst_node,
cast_nodes,
graph->valid_places());
}
}
void PrecisionCastPass::AddCastInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
const std::vector<Place>& valid_places) {
void PrecisionCastPass::AddCastInst(
const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
......@@ -203,66 +211,80 @@ void PrecisionCastPass::AddCastInst(const Type& from,
auto cast_op_output_name = in->AsArg().name + "/precision_trans";
// in->AsArg().name + "/precision_trans/" +
// paddle::lite::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
auto* cast_inst = graph->NewInstructNode();
if (cast_nodes->count(in->AsArg().name)) {
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the cast_op_output_name
// Add new link, newarg->inst
DirectedLink(cast_nodes->at(in->AsArg().name),
inst_node); // [io_copy kernel]'s output -> [current kernel]
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
} else {
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
auto* cast_inst = graph->NewInstructNode();
// create Op and kernels.
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string cast_type = in_persist ? "calib_once" : "calib";
cast_op_output_arg->AsArg().is_persist = in_persist;
auto cast_op = LiteOpRegistry::Global().Create(cast_type);
CHECK(cast_op) << "create op [" << cast_op << "] failed";
// create Op and kernels.
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string cast_type = in_persist ? "calib_once" : "calib";
cast_op_output_arg->AsArg().is_persist = in_persist;
auto cast_op = LiteOpRegistry::Global().Create(cast_type);
CHECK(cast_op) << "create op [" << cast_op << "] failed";
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
// Create the new var manually.
inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
// Create Calib Instruction.
cpp::OpDesc op_desc;
op_desc.SetType(cast_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
float scale;
if (InferScale(in, inst_node, &scale)) {
op_desc.SetAttr("scale", scale);
}
// Create Calib Instruction.
cpp::OpDesc op_desc;
op_desc.SetType(cast_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
float scale;
if (InferScale(in, inst_node, &scale)) {
op_desc.SetAttr("scale", scale);
}
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->precision() == to.precision()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
break;
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
bool is_found = false;
for (auto& kernel : kernels) {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TypeCompatible(*in_arg_ty, from) &&
out_arg_ty->precision() == to.precision()) {
is_found = true;
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
(*cast_nodes)[in->AsArg().name] = cast_op_output_arg;
break;
}
}
}
CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
<< in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type();
CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
<< in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type();
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Remove the old link
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, cast_inst);
DirectedLink(cast_inst, cast_op_output_arg);
DirectedLink(cast_op_output_arg, inst_node);
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(in, cast_inst);
DirectedLink(cast_inst, cast_op_output_arg);
DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
// reset opdesc and update kernel information
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
}
// recreate the op
auto original_selected_kernel =
......
......@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
......@@ -34,13 +35,17 @@ class PrecisionCastPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void ComplementInputs(SSAGraph* graph,
Node* inst_node,
Node* in,
std::unordered_map<std::string, Node*>* cast_nodes);
void AddCastInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::unordered_map<std::string, Node*>* cast_nodes,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册