From f07892b6bb908cfa179cee2372a1764fdcaf91ee Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 8 Jun 2020 03:40:54 +0000 Subject: [PATCH] (bugfix): change cast dtype to int32 only when dtype is int64 --- lite/core/mir/int64_to_int32_pass.cc | 51 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/lite/core/mir/int64_to_int32_pass.cc b/lite/core/mir/int64_to_int32_pass.cc index a08329993f..ca5c069aab 100644 --- a/lite/core/mir/int64_to_int32_pass.cc +++ b/lite/core/mir/int64_to_int32_pass.cc @@ -44,10 +44,12 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr& graph) { } /* - some op decide data type beside input or output tensor from op_param: - 3. fillconstant - 4. FillConstantBatchSiz - 5. uniformrandom + ops which has datatype param + 1. cast + 2. fillconstant + 3. FillConstantBatchSiz + 4. uniformrandom + 5. assign int64 input or output from arm kernels 1. argmax: @@ -61,9 +63,16 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr& graph) { 9. compare 10. ctc - may support int64 - 1. cast - 2. concat + int64 input or output from x86 kernels + 1. gather: + 2. lookup_table + 3. reshape + 4. sequence_reshape + 5. sequence_reverse + 6. sequence_unpad + + BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6; + SIZE_T = 19;UINT8 = 20;INT8 = 21; */ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) { @@ -76,24 +85,22 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) { auto out_dtype = inst.op_info()->GetAttr("out_dtype"); VLOG(6) << "in_dtype : " << in_dtype; VLOG(6) << "out_dtype : " << out_dtype; - // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6; - // SIZE_T = 19;UINT8 = 20;INT8 = 21; cpp::OpDesc* cast_opdesc = const_cast(inst.op_info()); - cast_opdesc->SetAttr("out_dtype", 2); - cast_opdesc->SetAttr("in_dtype", 2); + if (in_dtype == 3) { // INT64 + cast_opdesc->SetAttr("in_dtype", 2); // INT32 + } + if (out_dtype == 3) { // INT64 + cast_opdesc->SetAttr("out_dtype", 2); // INT32 + } } - if (op_type == "fill_constant") { - CHECK(0) << "int64_to_int32 pass do not expect fill_constant op for now"; - } else if (op_type == "uniform_random") { - CHECK(0) << "int64_to_int32 pass do not expect uniform_random op for now"; - // auto dtype = opdesc.GetAttr("dtype"); - // if (dtype == static_cast(lite::core::FluidType::INT64)) { - // opdesc.SetAttr("dtype",static_cast(lite::core::FluidType::INT32); - // } - } else if (op_type == "fill_constant_batch_size_like") { - CHECK(0) << "int64_to_int32 pass do not expect " - "fill_constant_batch_size_like op for now"; + + if (op_type == "fill_constant" || + op_type == "fill_constant_batch_size_like" || + op_type == "uniform_random" || op_type == "assign") { + CHECK(inst.op_info()->GetAttr("dtype") != 3) + << "mlu does not expect int64 " << op_type << " op"; } + for (auto* in : inst_node->inlinks) { CHECK(in->IsRoleSet()); CHECK(in->IsArg()); -- GitLab