提交 7bf8d269 编写于 作者: Z zjun

fix scalar print

上级 831ceba6
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
namespace mindspore { namespace mindspore {
const char kShapeSeperator[] = ","; const char kShapeSeperator[] = ",";
const char kShapeScalar[] = "[0]"; const char kShapeScalar[] = "[0]";
const char kShapeNone[] = "[]";
static std::map<std::string, TypeId> print_type_map = { static std::map<std::string, TypeId> print_type_map = {
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8},
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16},
...@@ -163,9 +164,9 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { ...@@ -163,9 +164,9 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
} }
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
MS_EXCEPTION_IF_NULL(str_data_ptr); MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorShape_ == kShapeScalar) { if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
} }
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf); convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
continue; continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册