提交 41c745da 编写于 作者: - --get 提交者: jackzhang235

(feat): add int64 to int32 pass

上级 ce7548d9
...@@ -46,6 +46,8 @@ USE_MIR_PASS(multi_stream_analysis_pass); ...@@ -46,6 +46,8 @@ USE_MIR_PASS(multi_stream_analysis_pass);
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass) USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass); USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass); USE_MIR_PASS(xpu_subgraph_pass);
USE_MIR_PASS(int64_to_int32_pass);
// USE_MIR_PASS(identity_cast_eliminate_pass);
USE_MIR_PASS(mlu_subgraph_pass); USE_MIR_PASS(mlu_subgraph_pass);
USE_MIR_PASS(mlu_postprocess_pass); USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(weight_quantization_preprocess_pass);
......
...@@ -35,6 +35,7 @@ lite_cc_library(mir_passes ...@@ -35,6 +35,7 @@ lite_cc_library(mir_passes
generate_program_pass.cc generate_program_pass.cc
argument_type_display_pass.cc argument_type_display_pass.cc
demo_pass.cc demo_pass.cc
int64_to_int32_pass.cc
runtime_context_assign_pass.cc runtime_context_assign_pass.cc
memory_optimize_pass.cc memory_optimize_pass.cc
multi_stream_analysis_pass.cc multi_stream_analysis_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/int64_to_int32_pass.h"
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/operators/subgraph_op.h"
namespace paddle {
namespace lite {
namespace mir {
void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  // Snapshot the statement nodes in topological order first: mutating the
  // argument types while iterating must not depend on graph internals
  // being stable during traversal.
  std::list<Node*> nodes;
  for (auto& node : graph->StmtTopologicalOrder()) {
    nodes.push_back(node);
  }
  for (auto& node : nodes) {
    if (!node->IsStmt()) continue;
    // Skip control-flow and graph-boundary ops: their int64 tensors are
    // part of the external contract and must not be demoted.
    const std::string& op_type = node->AsStmt().op_type();
    if (op_type == "while" || op_type == "feed" || op_type == "fetch") {
      continue;
    }
    ChangeInt64ToInt32IfNeeded(node);
  }
}
/*
Some ops decide their data type from op_param rather than from input or
output tensors:
  1. fill_constant
  2. fill_constant_batch_size_like
  3. uniform_random
Ops with int64 input or output in ARM kernels:
  1. argmax
  2. beam_search
  3. gather
  4. lookup_table
  5. read_from_array
  6. topk
  7. write_to_array
  8. feed
  9. compare
  10. ctc
Ops that may support int64:
  1. cast
  2. concat
*/
// Demotes the int64 inputs/outputs of one statement node to int32.
// For "cast" ops the dtype attributes are patched as well; ops that derive
// their dtype purely from op_param (fill_constant, uniform_random,
// fill_constant_batch_size_like) are rejected because demoting them is not
// implemented yet.
void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
  CHECK(inst_node->IsStmt());
  auto& inst = inst_node->AsStmt();
  std::string op_type = inst.op_info()->Type();

  // TODO(zhaoying): support more op
  if (op_type == "cast") {
    auto in_dtype = inst.op_info()->GetAttr<int>("in_dtype");
    auto out_dtype = inst.op_info()->GetAttr<int>("out_dtype");
    VLOG(6) << "in_dtype : " << in_dtype;
    VLOG(6) << "out_dtype : " << out_dtype;
    // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
    // SIZE_T = 19;UINT8 = 20;INT8 = 21;
    // Only rewrite an attribute when it is INT64 (3). The previous code
    // unconditionally forced both dtypes to INT32 (2), which silently
    // corrupted casts between unrelated types (e.g. FP32 -> FP32 or
    // BOOL -> FP32) into int32 casts.
    cpp::OpDesc* cast_opdesc = const_cast<OpInfo*>(inst.op_info());
    if (in_dtype == 3) {
      cast_opdesc->SetAttr<int>("in_dtype", 2);
    }
    if (out_dtype == 3) {
      cast_opdesc->SetAttr<int>("out_dtype", 2);
    }
  }

  // These ops pick their output dtype from op_param, so changing only the
  // argument types below would be insufficient; bail out loudly until the
  // pass learns to patch their attributes.
  if (op_type == "fill_constant") {
    CHECK(0) << "int64_to_int32 pass do not expect fill_constant op for now";
  } else if (op_type == "uniform_random") {
    CHECK(0) << "int64_to_int32 pass do not expect uniform_random op for now";
  } else if (op_type == "fill_constant_batch_size_like") {
    CHECK(0) << "int64_to_int32 pass do not expect "
                "fill_constant_batch_size_like op for now";
  }

  // Demote every int64 input argument of this statement to int32.
  for (auto* in : inst_node->inlinks) {
    CHECK(in->IsRoleSet());
    CHECK(in->IsArg());
    CHECK(in->AsArg().type);
    auto in_arg_name = in->AsArg().name;
    std::string tmp;
    CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
    auto rt_precision = in->AsArg().type->precision();
    // ================== DEBUG INFO ===================
    VLOG(6) << "op :" << op_type;
    VLOG(6) << "arg name :" << in_arg_name;
    VLOG(6) << "arg :" << tmp;
    VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
    // ================== DEBUG END ===================
    if (rt_precision == PRECISION(kInt64)) {
      VLOG(6) << "change precison from int64 to int32";
      // Target and layout are preserved; only the precision changes.
      in->AsArg().type =
          const_cast<Type*>(Type::GetTensorTy(in->AsArg().type->target(),
                                              PRECISION(kInt32),
                                              in->AsArg().type->layout()));
    }
  }

  // Demote every int64 output argument of this statement to int32.
  for (auto* out : inst_node->outlinks) {
    CHECK(out->IsRoleSet());
    CHECK(out->IsArg());
    CHECK(out->AsArg().type);
    auto out_arg_name = out->AsArg().name;
    std::string tmp;
    CHECK(inst.op_info()->GetOutputArgname(out_arg_name, &tmp));
    auto rt_precision = out->AsArg().type->precision();
    // ================== DEBUG INFO ===================
    VLOG(6) << "op :" << op_type;
    VLOG(6) << "arg name :" << out_arg_name;
    VLOG(6) << "arg :" << tmp;
    VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
    // ================== DEBUG END ===================
    if (rt_precision == PRECISION(kInt64)) {
      VLOG(6) << "change precison from int64 to int32";
      // Target and layout are preserved; only the precision changes.
      out->AsArg().type =
          const_cast<Type*>(Type::GetTensorTy(out->AsArg().type->target(),
                                              PRECISION(kInt32),
                                              out->AsArg().type->layout()));
    }
  }
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Register the pass under the name used in the optimizer's pass list;
// it is bound to (and only runs for) the MLU target.
REGISTER_MIR_PASS(int64_to_int32_pass, paddle::lite::mir::Int64ToInt32Pass)
    .BindTargets({TARGET(kMLU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
// Graph pass that demotes int64 tensors to int32 (used for targets such as
// MLU that lack int64 kernel support — TODO confirm exact motivation with
// the pass's BindTargets registration).
class Int64ToInt32Pass : public ProgramPass {
 public:
  // Walks all statement nodes of the graph in topological order (skipping
  // while/feed/fetch) and applies ChangeInt64ToInt32IfNeeded to each.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
  // Rewrites the precision of a single statement node's int64 input/output
  // arguments to int32; also patches the dtype attributes of "cast" ops.
  void ChangeInt64ToInt32IfNeeded(Node* inst_node);
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -90,6 +90,8 @@ class Optimizer { ...@@ -90,6 +90,8 @@ class Optimizer {
"xpu_subgraph_pass", "xpu_subgraph_pass",
"bm_subgraph_pass", "bm_subgraph_pass",
"rknpu_subgraph_pass", "rknpu_subgraph_pass",
"int64_to_int32_pass",
// "identity_cast_eliminate_pass",
"mlu_subgraph_pass", "mlu_subgraph_pass",
"static_kernel_pick_pass", // pick original kernel from graph "static_kernel_pick_pass", // pick original kernel from graph
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册