提交 f07892b6 编写于 作者: Z zhaoying 提交者: jackzhang235

(bugfix): change cast dtype to int32 only when dtype is int64

上级 41c745da
...@@ -44,10 +44,12 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -44,10 +44,12 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} }
/* /*
some op decide data type beside input or output tensor from op_param: ops which has datatype param
3. fillconstant 1. cast
4. FillConstantBatchSiz 2. fillconstant
5. uniformrandom 3. FillConstantBatchSiz
4. uniformrandom
5. assign
int64 input or output from arm kernels int64 input or output from arm kernels
1. argmax: 1. argmax:
...@@ -61,9 +63,16 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -61,9 +63,16 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) {
9. compare 9. compare
10. ctc 10. ctc
may support int64 int64 input or output from x86 kernels
1. cast 1. gather:
2. concat 2. lookup_table
3. reshape
4. sequence_reshape
5. sequence_reverse
6. sequence_unpad
BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
SIZE_T = 19;UINT8 = 20;INT8 = 21;
*/ */
void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) { void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
...@@ -76,24 +85,22 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) { ...@@ -76,24 +85,22 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
auto out_dtype = inst.op_info()->GetAttr<int>("out_dtype"); auto out_dtype = inst.op_info()->GetAttr<int>("out_dtype");
VLOG(6) << "in_dtype : " << in_dtype; VLOG(6) << "in_dtype : " << in_dtype;
VLOG(6) << "out_dtype : " << out_dtype; VLOG(6) << "out_dtype : " << out_dtype;
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
cpp::OpDesc* cast_opdesc = const_cast<OpInfo*>(inst.op_info()); cpp::OpDesc* cast_opdesc = const_cast<OpInfo*>(inst.op_info());
cast_opdesc->SetAttr<int>("out_dtype", 2); if (in_dtype == 3) { // INT64
cast_opdesc->SetAttr<int>("in_dtype", 2); cast_opdesc->SetAttr<int>("in_dtype", 2); // INT32
}
if (out_dtype == 3) { // INT64
cast_opdesc->SetAttr<int>("out_dtype", 2); // INT32
} }
if (op_type == "fill_constant") {
CHECK(0) << "int64_to_int32 pass do not expect fill_constant op for now";
} else if (op_type == "uniform_random") {
CHECK(0) << "int64_to_int32 pass do not expect uniform_random op for now";
// auto dtype = opdesc.GetAttr<int>("dtype");
// if (dtype == static_cast<int32_t>(lite::core::FluidType::INT64)) {
// opdesc.SetAttr<int>("dtype",static_cast<int32_t>(lite::core::FluidType::INT32);
// }
} else if (op_type == "fill_constant_batch_size_like") {
CHECK(0) << "int64_to_int32 pass do not expect "
"fill_constant_batch_size_like op for now";
} }
if (op_type == "fill_constant" ||
op_type == "fill_constant_batch_size_like" ||
op_type == "uniform_random" || op_type == "assign") {
CHECK(inst.op_info()->GetAttr<int>("dtype") != 3)
<< "mlu does not expect int64 " << op_type << " op";
}
for (auto* in : inst_node->inlinks) { for (auto* in : inst_node->inlinks) {
CHECK(in->IsRoleSet()); CHECK(in->IsRoleSet());
CHECK(in->IsArg()); CHECK(in->IsArg());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册