提交 f23bfe0d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1202 Fix tensor print order

Merge pull request !1202 from zjun/fix_tensor_print
...@@ -50,6 +50,7 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons ...@@ -50,6 +50,7 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
if (tensor_shape == nullptr) { if (tensor_shape == nullptr) {
return false; return false;
} }
MS_EXCEPTION_IF_NULL(dims);
std::string shape_str = input_shape_str; std::string shape_str = input_shape_str;
if (shape_str.size() <= 2) { if (shape_str.size() <= 2) {
return false; return false;
...@@ -71,6 +72,8 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons ...@@ -71,6 +72,8 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *const print_tensor, bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *const print_tensor,
const size_t &memory_size) { const size_t &memory_size) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(print_tensor);
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true)); auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true));
MS_EXCEPTION_IF_NULL(tensor_data_ptr); MS_EXCEPTION_IF_NULL(tensor_data_ptr);
auto cp_ret = auto cp_ret =
...@@ -83,55 +86,57 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co ...@@ -83,55 +86,57 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
} }
template <typename T> template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) { void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr); const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
std::ostringstream buf_scalar; *buf << "Tensor shape:[1] " << tensor_type;
buf_scalar << "Tensor shape :1 " << tensor_type; *buf << "\nval:";
buf_scalar << "\nval:"; *buf << *data_ptr << "\n";
buf_scalar << *data_ptr;
std::cout << buf_scalar.str() << std::endl;
} }
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) { void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr); const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
std::ostringstream buf_scalar; *buf << "Tensor shape:[1] " << tensor_type;
buf_scalar << "Tensor shape :1 " << tensor_type; *buf << "\nval:";
buf_scalar << "\nval:"; if (*data_ptr) {
if (*data_ptr == true) { *buf << "True\n";
buf_scalar << "True";
} else { } else {
buf_scalar << "False"; *buf << "False\n";
} }
std::cout << buf_scalar.str() << std::endl;
} }
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) { void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
auto type_iter = print_type_map.find(tensor_type); auto type_iter = print_type_map.find(tensor_type);
auto type_id = type_iter->second; auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) { if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, tensor_type); PrintScalarToBoolString(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt8) { } else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, tensor_type); PrintScalarToString<int8_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt8) { } else if (type_id == TypeId::kNumberTypeUInt8) {
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type); PrintScalarToString<uint8_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt16) { } else if (type_id == TypeId::kNumberTypeInt16) {
PrintScalarToString<int16_t>(str_data_ptr, tensor_type); PrintScalarToString<int16_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt16) { } else if (type_id == TypeId::kNumberTypeUInt16) {
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type); PrintScalarToString<uint16_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt32) { } else if (type_id == TypeId::kNumberTypeInt32) {
PrintScalarToString<int32_t>(str_data_ptr, tensor_type); PrintScalarToString<int32_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt32) { } else if (type_id == TypeId::kNumberTypeUInt32) {
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type); PrintScalarToString<uint32_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeInt64) { } else if (type_id == TypeId::kNumberTypeInt64) {
PrintScalarToString<int64_t>(str_data_ptr, tensor_type); PrintScalarToString<int64_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeUInt64) { } else if (type_id == TypeId::kNumberTypeUInt64) {
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type); PrintScalarToString<uint64_t>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat16) { } else if (type_id == TypeId::kNumberTypeFloat16) {
PrintScalarToString<float16>(str_data_ptr, tensor_type); PrintScalarToString<float16>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat32) { } else if (type_id == TypeId::kNumberTypeFloat32) {
PrintScalarToString<float>(str_data_ptr, tensor_type); PrintScalarToString<float>(str_data_ptr, tensor_type, buf);
} else if (type_id == TypeId::kNumberTypeFloat64) { } else if (type_id == TypeId::kNumberTypeFloat64) {
PrintScalarToString<double>(str_data_ptr, tensor_type); PrintScalarToString<double>(str_data_ptr, tensor_type, buf);
} else { } else {
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
} }
...@@ -142,11 +147,7 @@ bool judgeLengthValid(const size_t str_len, const string &tensor_type) { ...@@ -142,11 +147,7 @@ bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
if (type_iter == type_size_map.end()) { if (type_iter == type_size_map.end()) {
MS_LOG(EXCEPTION) << "type of scalar to print is not support."; MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
} }
return str_len == type_iter->second;
if (str_len != type_iter->second) {
return false;
}
return true;
} }
#ifndef NO_DLIB #ifndef NO_DLIB
...@@ -166,7 +167,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { ...@@ -166,7 +167,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
} }
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_); convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
continue; continue;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册