提交 4d9491b5 编写于 作者: B buxue

fix bug of auto cast when there is scalar

上级 871d6524
......@@ -65,21 +65,9 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg
}
}
}
bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type,
const size_t &s_type_number) {
if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) {
if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) {
return t_type_number >= s_type_number;
}
return false;
}
return true;
}
void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
const size_t type_number) {
void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) {
*max_type_id = type_id;
*max_type = type;
*max_type_number = type_number;
}
......@@ -118,7 +106,6 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
const std::set<size_t> &write_indices) {
TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown;
size_t max_type_number = 0;
bool has_int8 = false;
for (const auto &index : indices) {
......@@ -128,6 +115,9 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
continue;
}
if (arg_type != kObjectTypeTensorType) {
continue;
}
auto it = type_map.find(arg_type_id);
if (it == type_map.end()) {
continue;
......@@ -136,24 +126,11 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
has_int8 = true;
}
if (max_type_id == kTypeUnknown) {
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
continue;
}
if (max_type == arg_type) {
if (it->second > max_type_number) {
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
} else {
if (arg_type == kObjectTypeTensorType) {
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
} else {
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
}
if (it->second > max_type_number) {
SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册