提交 f23bfe0d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1202 Fix tensor print order

Merge pull request !1202 from zjun/fix_tensor_print
......@@ -50,6 +50,7 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
if (tensor_shape == nullptr) {
return false;
}
MS_EXCEPTION_IF_NULL(dims);
std::string shape_str = input_shape_str;
if (shape_str.size() <= 2) {
return false;
......@@ -71,6 +72,8 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *const print_tensor,
const size_t &memory_size) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(print_tensor);
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true));
MS_EXCEPTION_IF_NULL(tensor_data_ptr);
auto cp_ret =
......@@ -83,55 +86,57 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
}
template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) {
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
buf_scalar << *data_ptr;
std::cout << buf_scalar.str() << std::endl;
*buf << "Tensor shape:[1] " << tensor_type;
*buf << "\nval:";
*buf << *data_ptr << "\n";
}
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) {
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
if (*data_ptr == true) {
buf_scalar << "True";
*buf << "Tensor shape:[1] " << tensor_type;
*buf << "\nval:";
if (*data_ptr) {
*buf << "True\n";
} else {
buf_scalar << "False";
*buf << "False\n";
}
std::cout << buf_scalar.str() << std::endl;
}
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) {
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
auto type_iter = print_type_map.find(tensor_type);
auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, tensor_type);
PrintScalarToBoolString(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, tensor_type);
PrintScalarToString<int8_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt8) {
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type);
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt16) {
PrintScalarToString<int16_t>(str_data_ptr, tensor_type);
PrintScalarToString<int16_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt16) {
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type);
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt32) {
PrintScalarToString<int32_t>(str_data_ptr, tensor_type);
PrintScalarToString<int32_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt32) {
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type);
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt64) {
PrintScalarToString<int64_t>(str_data_ptr, tensor_type);
PrintScalarToString<int64_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt64) {
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type);
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat16) {
PrintScalarToString<float16>(str_data_ptr, tensor_type);
PrintScalarToString<float16>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat32) {
PrintScalarToString<float>(str_data_ptr, tensor_type);
PrintScalarToString<float>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat64) {
PrintScalarToString<double>(str_data_ptr, tensor_type);
PrintScalarToString<double>(str_data_ptr, tensor_type, buf);
} else {
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
}
......@@ -142,11 +147,7 @@ bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
if (type_iter == type_size_map.end()) {
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
}
if (str_len != type_iter->second) {
return false;
}
return true;
return str_len == type_iter->second;
}
#ifndef NO_DLIB
......@@ -166,7 +167,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
}
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_);
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
continue;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册