提交 f07892b6 编写于 作者: Z zhaoying 提交者: jackzhang235

(bugfix): change cast dtype to int32 only when dtype is int64

上级 41c745da
......@@ -44,10 +44,12 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
/*
some op decide data type beside input or output tensor from op_param:
3. fillconstant
4. FillConstantBatchSiz
5. uniformrandom
ops which has datatype param
1. cast
2. fillconstant
3. FillConstantBatchSiz
4. uniformrandom
5. assign
int64 input or output from arm kernels
1. argmax:
......@@ -61,9 +63,16 @@ void Int64ToInt32Pass::Apply(const std::unique_ptr<SSAGraph>& graph) {
9. compare
10. ctc
may support int64
1. cast
2. concat
int64 input or output from x86 kernels
1. gather:
2. lookup_table
3. reshape
4. sequence_reshape
5. sequence_reverse
6. sequence_unpad
BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
SIZE_T = 19;UINT8 = 20;INT8 = 21;
*/
void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
......@@ -76,24 +85,22 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
auto out_dtype = inst.op_info()->GetAttr<int>("out_dtype");
VLOG(6) << "in_dtype : " << in_dtype;
VLOG(6) << "out_dtype : " << out_dtype;
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
cpp::OpDesc* cast_opdesc = const_cast<OpInfo*>(inst.op_info());
cast_opdesc->SetAttr<int>("out_dtype", 2);
cast_opdesc->SetAttr<int>("in_dtype", 2);
if (in_dtype == 3) { // INT64
cast_opdesc->SetAttr<int>("in_dtype", 2); // INT32
}
if (out_dtype == 3) { // INT64
cast_opdesc->SetAttr<int>("out_dtype", 2); // INT32
}
}
if (op_type == "fill_constant") {
CHECK(0) << "int64_to_int32 pass do not expect fill_constant op for now";
} else if (op_type == "uniform_random") {
CHECK(0) << "int64_to_int32 pass do not expect uniform_random op for now";
// auto dtype = opdesc.GetAttr<int>("dtype");
// if (dtype == static_cast<int32_t>(lite::core::FluidType::INT64)) {
// opdesc.SetAttr<int>("dtype",static_cast<int32_t>(lite::core::FluidType::INT32);
// }
} else if (op_type == "fill_constant_batch_size_like") {
CHECK(0) << "int64_to_int32 pass do not expect "
"fill_constant_batch_size_like op for now";
if (op_type == "fill_constant" ||
op_type == "fill_constant_batch_size_like" ||
op_type == "uniform_random" || op_type == "assign") {
CHECK(inst.op_info()->GetAttr<int>("dtype") != 3)
<< "mlu does not expect int64 " << op_type << " op";
}
for (auto* in : inst_node->inlinks) {
CHECK(in->IsRoleSet());
CHECK(in->IsArg());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册