提交 80dcd70e 编写于 作者: Z zhaoying 提交者: jackzhang235

(bugfix): args for some op(for example, conv2d) does not has type attr

上级 f07892b6
......@@ -104,47 +104,49 @@ void Int64ToInt32Pass::ChangeInt64ToInt32IfNeeded(Node* inst_node) {
for (auto* in : inst_node->inlinks) {
CHECK(in->IsRoleSet());
CHECK(in->IsArg());
CHECK(in->AsArg().type);
auto in_arg_name = in->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto rt_precision = in->AsArg().type->precision();
// ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << in_arg_name;
VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32";
in->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(in->AsArg().type->target(),
PRECISION(kInt32),
in->AsArg().type->layout()));
if (in->AsArg().type) {
auto in_arg_name = in->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto rt_precision = in->AsArg().type->precision();
// ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << in_arg_name;
VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32";
in->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(in->AsArg().type->target(),
PRECISION(kInt32),
in->AsArg().type->layout()));
}
}
}
for (auto* out : inst_node->outlinks) {
CHECK(out->IsRoleSet());
CHECK(out->IsArg());
CHECK(out->AsArg().type);
auto out_arg_name = out->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetOutputArgname(out_arg_name, &tmp));
auto rt_precision = out->AsArg().type->precision();
// ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << out_arg_name;
VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32";
out->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(out->AsArg().type->target(),
PRECISION(kInt32),
out->AsArg().type->layout()));
if (out->AsArg().type) {
auto out_arg_name = out->AsArg().name;
std::string tmp;
CHECK(inst.op_info()->GetOutputArgname(out_arg_name, &tmp));
auto rt_precision = out->AsArg().type->precision();
// ================== DEBUG INFO ===================
VLOG(6) << "op :" << op_type;
VLOG(6) << "arg name :" << out_arg_name;
VLOG(6) << "arg :" << tmp;
VLOG(6) << "runtime precision :" << PrecisionToStr(rt_precision);
// ================== DEBUG END ===================
if (rt_precision == PRECISION(kInt64)) {
VLOG(6) << "change precison from int64 to int32";
out->AsArg().type =
const_cast<Type*>(Type::GetTensorTy(out->AsArg().type->target(),
PRECISION(kInt32),
out->AsArg().type->layout()));
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册