提交 80dcd70e 编写于 作者: Z zhaoying 提交者: jackzhang235

(bugfix): args for some op(for example, conv2d) does not has type attr

上级 f07892b6
...@@ -104,47 +104,49 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) { ...@@ -104,47 +104,49 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
for (auto* in : inst_node->inlinks) { for (auto* in : inst_node->inlinks) {
CHECK(in->IsRoleSet()); CHECK(in->IsRoleSet());
CHECK(in->IsArg()); CHECK(in->IsArg());
CHECK(in->AsArg().type); if (in->AsArg().type) {
auto in_arg_name = in->AsArg().name; auto in_arg_name = in->AsArg().name;
std::string tmp; std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto rt_precision = in->AsArg().type->precision(); auto rt_precision = in->AsArg().type->precision();
// ================== DEBUG INFO =================== // ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type; VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << in_arg_name; VLOG(6) << "arg name :" << in_arg_name;
VLOG(6) << "arg :" << tmp; VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision); VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END =================== // ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) { if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32"; VLOG(6) << "change precison from int64 to int32";
in->AsArg().type = in->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(in->AsArg().type->target(), const_cast<Type*>(Type::GetTensorTy(in->AsArg().type->target(),
PRECISION(kInt32), PRECISION(kInt32),
in->AsArg().type->layout())); in->AsArg().type->layout()));
}
} }
} }
for (auto* out : inst_node->outlinks) { for (auto* out : inst_node->outlinks) {
CHECK(out->IsRoleSet()); CHECK(out->IsRoleSet());
CHECK(out->IsArg()); CHECK(out->IsArg());
CHECK(out->AsArg().type); if (out->AsArg().type) {
auto out_arg_name = out->AsArg().name; auto out_arg_name = out->AsArg().name;
std::string tmp; std::string tmp;
CHECK(inst.op_info()->GetOutputArgname(out_arg_name, &tmp)); CHECK(inst.op_info()->GetOutputArgname(out_arg_name, &tmp));
auto rt_precision = out->AsArg().type->precision(); auto rt_precision = out->AsArg().type->precision();
// ================== DEBUG INFO =================== // ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type; VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << out_arg_name; VLOG(6) << "arg name :" << out_arg_name;
VLOG(6) << "arg :" << tmp; VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision); VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END =================== // ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) { if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32"; VLOG(6) << "change precison from int64 to int32";
out->AsArg().type = out->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(out->AsArg().type->target(), const_cast<Type*>(Type::GetTensorTy(out->AsArg().type->target(),
PRECISION(kInt32), PRECISION(kInt32),
out->AsArg().type->layout())); out->AsArg().type->layout()));
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册