提交 e80e19d0 编写于 作者: H He Wei

Refactor Tensor bool type and ToString()

1. Use C++ native bool type for tensor:
Since we changed tensor data from vector<T> to unique_ptr<T[]>,
a boolean tensor can now be implemented using the native C++ bool type.

2. Improve ToString():
Try to make the tensor string representation look better.
上级 857c0301
......@@ -54,6 +54,18 @@ static size_t SizeOf(const std::vector<int> &shape) {
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}
static std::string ShapeToString(const std::vector<int> &shape) {
  // Render a tensor shape as a bracketed, comma-separated list, e.g. "[2, 3, 4]".
  // An empty shape yields "[]".
  std::string result{"["};
  bool first = true;
  for (const int dim : shape) {
    if (!first) {
      result += ", ";
    }
    result += std::to_string(dim);
    first = false;
  }
  result += "]";
  return result;
}
template <typename T, typename U>
std::unique_ptr<T[]> NewData(const U *input, size_t size) {
if (input == nullptr || size == 0) {
......@@ -84,7 +96,10 @@ template <typename T>
std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, TypeId data_type) {
const size_t size = SizeOf(shape);
switch (data_type) {
case kNumberTypeBool:
case kNumberTypeBool: {
auto buf = static_cast<bool *>(data);
return NewData<T>(buf, size);
}
case kNumberTypeUInt8: {
auto buf = static_cast<uint8_t *>(data);
return NewData<T>(buf, size);
......@@ -200,7 +215,7 @@ class TensorDataImpl : public TensorData {
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
constexpr auto valid =
std::is_same<T, Bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
......@@ -214,27 +229,28 @@ class TensorDataImpl : public TensorData {
std::ostringstream ss;
if (data_size_ == 1 && ndim_ == 0) { // Scalar
OutputDataString(ss, type, 0, 0, 1);
OutputDataString(ss, 0, 0, 1);
return ss.str();
}
ssize_t cursor = 0;
SummaryStringRecursive(ss, type, shape, &cursor, 0);
SummaryStringRecursive(ss, shape, &cursor, 0);
return ss.str();
}
private:
void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
bool isScalar = ndim_ == 0 && end - start == 1;
int linefeedThreshold;
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const {
const bool isScalar = ndim_ == 0 && end - start == 1;
constexpr auto isFloat =
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
constexpr auto isBool = std::is_same<T, bool>::value;
constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
const auto value = data_[cursor + i];
if constexpr (isFloat) {
if (isScalar) {
ss << value;
} else {
if (std::is_same<T, float16>::value) {
if constexpr (std::is_same<T, float16>::value) {
ss << std::setw(11) << std::setprecision(4) << std::setiosflags(std::ios::scientific | std::ios::right)
<< value;
} else {
......@@ -242,14 +258,12 @@ class TensorDataImpl : public TensorData {
<< value;
}
}
linefeedThreshold = kThreshold1DFloat;
} else if (type == kNumberTypeBool) {
} else if (std::is_same<T, bool>::value) {
if (isScalar) {
ss << (value == 0 ? "False" : "True");
ss << (value ? "True" : "False");
} else {
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
}
linefeedThreshold = kThreshold1DBool;
} else {
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
......@@ -276,7 +290,6 @@ class TensorDataImpl : public TensorData {
} else {
ss << value;
}
linefeedThreshold = kThreshold1DInt;
}
if (!isScalar && i != end - 1) {
ss << ' ';
......@@ -288,7 +301,7 @@ class TensorDataImpl : public TensorData {
}
}
void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector<int> &shape, ssize_t *cursor,
void SummaryStringRecursive(std::ostringstream &ss, const std::vector<int> &shape, ssize_t *cursor,
ssize_t depth) const {
if (depth >= static_cast<ssize_t>(ndim_)) {
return;
......@@ -297,11 +310,11 @@ class TensorDataImpl : public TensorData {
if (depth == static_cast<ssize_t>(ndim_) - 1) { // Bottom dimension
ssize_t num = shape[depth];
if (num > kThreshold && ndim_ > 1) {
OutputDataString(ss, type, *cursor, 0, kThreshold / 2);
OutputDataString(ss, *cursor, 0, kThreshold / 2);
ss << ' ' << kEllipsis << ' ';
OutputDataString(ss, type, *cursor, num - kThreshold / 2, num);
OutputDataString(ss, *cursor, num - kThreshold / 2, num);
} else {
OutputDataString(ss, type, *cursor, 0, num);
OutputDataString(ss, *cursor, 0, num);
}
*cursor += num;
} else { // Middle dimension
......@@ -312,7 +325,7 @@ class TensorDataImpl : public TensorData {
ss << '\n';
ss << std::setw(depth + 1) << ' '; // Add the indent.
}
SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
SummaryStringRecursive(ss, shape, cursor, depth + 1);
}
// Handle the ignored part.
if (num > kThreshold) {
......@@ -334,7 +347,7 @@ class TensorDataImpl : public TensorData {
for (ssize_t i = num - kThreshold / 2; i < num; i++) {
ss << '\n';
ss << std::setw(depth + 1) << ' '; // Add the indent.
SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
SummaryStringRecursive(ss, shape, cursor, depth + 1);
}
}
}
......@@ -350,6 +363,7 @@ template <typename... Args>
TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, const Args... args) {
switch (data_type) {
case kNumberTypeBool:
return std::make_shared<TensorDataImpl<bool>>(shape, args...);
case kNumberTypeUInt8:
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
case kNumberTypeInt8:
......@@ -463,31 +477,35 @@ std::string Tensor::GetShapeAndDataTypeInfo() const {
}
std::string Tensor::ToString() const {
const int small_tensor_size = 30;
constexpr int small_tensor_size = 30;
std::ostringstream buf;
auto dtype = Dtype();
MS_EXCEPTION_IF_NULL(dtype);
data_sync();
buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
// only print small tensor
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ",\n";
if (DataSize() < small_tensor_size) {
buf << ", value:" << data().ToString(data_type_, shape());
// Only print data for small tensor.
buf << data().ToString(data_type_, shape_) << ')';
} else {
buf << "[...])";
}
return buf.str();
}
std::string Tensor::ToStringRepr() const {
std::ostringstream buf;
auto type_ptr = this->Dtype();
MS_EXCEPTION_IF_NULL(type_ptr);
auto dtype = Dtype();
MS_EXCEPTION_IF_NULL(dtype);
data_sync();
buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString();
buf << "\nvalue:" << data().ToString(data_type_, shape());
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ",\n"
<< data().ToString(data_type_, shape_) << ')';
return buf.str();
}
void Tensor::data_sync() const {
if (device_sync_ != nullptr) {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy.";
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册