提交 0b3b1744 编写于 作者: Z Zhang Qinghua

Optimize the code of Tensor::ToString().

上级 7f37bfbb
......@@ -36,6 +36,10 @@ namespace tensor {
// Placeholder text written where values or whole sub-tensors are left out of the printed summary.
constexpr auto kEllipsis = "...";
// A dimension longer than this is abbreviated: only kThreshold / 2 leading and kThreshold / 2 trailing
// entries are printed, with kEllipsis in between.
constexpr auto kThreshold = 6;
// Values per output line when printing a 1-D tensor, chosen by element kind (floating point, integer, bool).
// OutputDataString inserts a line feed after every such group.
constexpr auto kThreshold1DFloat = kThreshold * 2;
constexpr auto kThreshold1DInt = kThreshold * 4;
constexpr auto kThreshold1DBool = kThreshold * 2;
static std::string MakeId() {
// Use atomic to make id generator thread safe.
static std::atomic<uint64_t> last_id{1};
......@@ -120,22 +124,21 @@ std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_l
template <typename T>
class TensorDataImpl : public TensorData {
public:
// Constructors. Every overload stores the rank (shape.size()) in ndim_ and SizeOf(shape) in data_size_;
// data_size_ is later used as the element count when data_ is allocated.
// NOTE(review): this span appears to come from a rendered diff, so each initializer list is shown twice,
// once ending in shape_(shape) and once without it. The two variants are not meant to coexist.

// Shape only: data_ is not filled here and stays empty until it is allocated on demand.
explicit TensorDataImpl(const std::vector<int> &shape)
: ndim_(shape.size()), data_size_(SizeOf(shape)), shape_(shape) {}
explicit TensorDataImpl(const std::vector<int> &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
// Shape plus a raw buffer and its length: data_ is built by CopyData<T>(shape, data, data_len).
TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)), shape_(shape) {}
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
// Shape plus a raw buffer tagged with its element type: data_ is built by CopyData<T>(shape, data, data_type).
TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)), shape_(shape) {}
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}
// Shape plus an iterator range: data_ is constructed from [first, last).
template <typename InputIt>
TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last), shape_(shape) {}
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {}
// Shape plus a single scalar: data_ holds exactly one element, the scalar converted to T.
template <typename Scalar>
TensorDataImpl(const std::vector<int> &shape, Scalar scalar)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast<T>(scalar)}), shape_(shape) {}
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast<T>(scalar)}) {}
// Returns data_size_ converted to ssize_t.
ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
......@@ -151,12 +154,13 @@ class TensorDataImpl : public TensorData {
// Prevent null pointer for empty shape.
return empty_data.data();
}
CheckDataSafe();
// Lazy allocation.
if (data_.empty()) {
data_.resize(data_size_);
}
return data_.data();
}
// Returns a copy of the stored shape_ vector.
std::vector<int> shape() const { return shape_; }
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
if (ptr) {
......@@ -165,99 +169,101 @@ class TensorDataImpl : public TensorData {
return false;
}
// Prepare for lazy allocation.
// Makes sure the backing vector exists: while data_ is still empty it is resized to data_size_ elements
// (std::vector::resize value-initializes them). Once data_ is non-empty this is a no-op.
void CheckDataSafe() {
// Lazy allocation.
if (data_.empty()) {
data_.resize(data_size_);
}
}
// ToString() for lazy allocation.
// Non-const variant: allocates data_ on demand through CheckDataSafe() and then delegates to ToString().
std::string ToStringSafe() {
CheckDataSafe();
return ToString();
}
// Renders the element data as a nested-bracket, possibly abbreviated, string.
// `type` and `shape` (in the two-argument form) are forwarded to SummaryStringRecursive, which uses
// them to choose the element formatting and the extent of each dimension.
// NOTE(review): this span appears to come from a rendered diff. Two signatures, two element-type checks
// (a runtime MS_LOG(EXCEPTION) and a static_assert) and two returns for the empty-data_ case are shown
// side by side; they are not meant to coexist.
std::string ToString() const override {
std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
// T must be Bool, a fixed-width signed/unsigned integer, float16, float or double.
constexpr auto valid =
std::is_same<T, Bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
if (!valid) {
MS_LOG(EXCEPTION) << "Type is invalid, T: " << typeid(T).name();
}
static_assert(valid, "Type is invalid");
// Nothing to print when there are no elements.
if (data_size_ == 0) {
return "";
}
// Elements are expected but data_ has not been allocated yet (it is allocated lazily elsewhere).
if (data_.empty()) {
MS_LOG(ERROR) << "data_ is empty, data_size_: " << data_size_;
return "";
return "<uninitialized>";
}
std::ostringstream ss;
// cursor is the flat index into data_ of the next element to print; the recursion starts at dimension 0.
ssize_t cursor = 0;
SummaryStringRecursive(ss, &cursor, 0);
SummaryStringRecursive(ss, type, shape, &cursor, 0);
return ss.str();
}
private:
// Writes the elements data_[cursor + start] .. data_[cursor + end - 1] to ss, stopping early if
// cursor + i reaches data_size_.
//   ss     - output stream that receives the formatted values.
//   type   - element type id (two-signature form only); used to detect kNumberTypeBool.
//   cursor - flat index in data_ of the first element of the current innermost row.
//   start, end - half-open range of offsets, relative to cursor, to print.
// Values are separated by a single space. For 1-D tensors a line feed followed by one space is written
// after every linefeedThreshold values.
// NOTE(review): this span appears to come from a rendered diff. Two signatures and two versions of the
// float / integer formatting are shown side by side; they are not meant to coexist.
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const {
void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
// Assigned on every loop iteration before it is read.
int linefeedThreshold;
constexpr auto isFloat =
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
if (isFloat) {
const auto value = data_[cursor + i];
if constexpr (isFloat) {
// Floating point: scientific notation, precision 8, right-aligned in a field of width 15.
ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
<< data_[cursor + i];
<< value;
linefeedThreshold = kThreshold1DFloat;
} else if (type == kNumberTypeBool) {
// Bool: printed as text, right-aligned in a field of width 5.
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
linefeedThreshold = kThreshold1DBool;
} else {
// Integers: a non-negative signed value gets a leading space so it lines up with negative ones.
if (isSigned && static_cast<int64_t>(data_[cursor + i]) >= 0) {
ss << ' ';
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
if constexpr (isSigned) {
if (static_cast<int64_t>(value) >= 0) {
ss << ' ';
}
}
// 8-bit integers are widened to 16 bits so the stream prints a number instead of a character.
if constexpr (std::is_same<T, int8_t>::value) {
ss << static_cast<int16_t>(value);
} else if constexpr (std::is_same<T, uint8_t>::value) {
ss << static_cast<uint16_t>(value);
} else {
ss << value;
}
ss << data_[cursor + i];
linefeedThreshold = kThreshold1DInt;
}
// Separator between values; none after the last value of the requested range.
if (i != end - 1) {
ss << ' ';
}
if (ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {  // Add a line feed every {threshold of type} for 1D tensor.
ss << '\n' << ' ';
}
}
}
// Recursively prints dimension `depth` of the tensor into ss.
//   cursor - in/out flat index into data_ of the next element to print.
//   depth  - index of the dimension being printed; the call returns immediately once depth >= ndim_.
// Innermost dimension: the values are written by OutputDataString. When the dimension is longer than
// kThreshold, only the first and last kThreshold / 2 values are written with kEllipsis between them.
// *cursor is then advanced by the full length of the dimension.
// Outer dimensions: up to kThreshold / 2 leading sub-tensors are printed, each on its own line indented by
// depth + 1 columns. When the dimension is longer than kThreshold, kEllipsis is written and the number of
// elements covered by the omitted sub-tensors is computed.
// NOTE(review): this span appears to come from a rendered diff. Pre- and post-change lines are shown side
// by side (shape_ vs. shape, calls with and without `type`), and part of the body is elided at the hunk
// marker further down, so the closing bracket output and the use of `ignored` are not visible here.
void SummaryStringRecursive(std::ostringstream &ss, ssize_t *cursor, ssize_t depth) const {
void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector<int> &shape, ssize_t *cursor,
ssize_t depth) const {
if (depth >= static_cast<ssize_t>(ndim_)) {
return;
}
ss << '[';
if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
ssize_t num = shape_[depth];
if (num > kThreshold) {
OutputDataString(ss, *cursor, 0, kThreshold / 2);
ssize_t num = shape[depth];
// In this variant a 1-D tensor (ndim_ == 1) is never abbreviated.
if (num > kThreshold && ndim_ > 1) {
OutputDataString(ss, type, *cursor, 0, kThreshold / 2);
ss << ' ' << kEllipsis << ' ';
OutputDataString(ss, *cursor, num - kThreshold / 2, num);
OutputDataString(ss, type, *cursor, num - kThreshold / 2, num);
} else {
OutputDataString(ss, *cursor, 0, num);
OutputDataString(ss, type, *cursor, 0, num);
}
// Skip past the whole row, including any values that were not printed.
*cursor += num;
} else {  // Middle dimension
ssize_t num = shape_[depth];
ssize_t num = shape[depth];
// Handle the first half.
for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
if (i > 0) {
ss << '\n';
ss << std::setw(depth + 1) << ' ';  // Add the indent.
}
SummaryStringRecursive(ss, cursor, depth + 1);
SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
}
// Handle the ignored part.
if (num > kThreshold) {
ss << '\n';
ss << std::setw(depth + 1) << ' ';  // Add the indent.
ss << kEllipsis << '\n';
ss << kEllipsis;
// Ignored at this layer.
// Elements in one sub-tensor at depth + 1: the product of all remaining dimension sizes.
ssize_t ignored = shape_[depth + 1];
ssize_t ignored = shape[depth + 1];
for (ssize_t i = depth + 2; i < static_cast<ssize_t>(ndim_); i++) {
ignored *= shape_[i];
ignored *= shape[i];
}
// Multiple with ignored layers number.
ignored *= num - kThreshold;
......@@ -269,7 +275,7 @@ class TensorDataImpl : public TensorData {
// Trailing kThreshold / 2 sub-tensors, each on its own indented line.
for (ssize_t i = num - kThreshold / 2; i < num; i++) {
ss << '\n';
ss << std::setw(depth + 1) << ' ';  // Add the indent.
SummaryStringRecursive(ss, cursor, depth + 1);
SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
}
}
}
......@@ -279,7 +285,6 @@ class TensorDataImpl : public TensorData {
// Rank of the tensor; the constructors set it to shape.size().
size_t ndim_{0};
// Result of SizeOf(shape); used as the element count when data_ is allocated and as the print bound.
size_t data_size_{0};
// Element storage. May be empty until it is allocated on demand.
std::vector<T> data_;
// Copy of the shape passed to the constructor.
std::vector<int> shape_;
};
template <typename... Args>
......@@ -404,7 +409,7 @@ std::string Tensor::ToString() const {
buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
// only print small tensor
if (DataSize() < small_tensor_size) {
buf << ", value:" << data().ToString();
buf << ", value:" << data().ToString(data_type_, shape());
}
return buf.str();
}
......@@ -414,20 +419,10 @@ std::string Tensor::ToStringRepr() const {
auto type_ptr = this->Dtype();
MS_EXCEPTION_IF_NULL(type_ptr);
buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString();
buf << "\nvalue:" << data().ToString();
buf << "\nvalue:" << data().ToString(data_type_, shape());
return buf.str();
}
// Non-const counterpart of ToString(): calls data().CheckDataSafe() first and then returns ToString().
std::string Tensor::ToStringSafe() {
data().CheckDataSafe();
return ToString();
}
// Non-const counterpart of ToStringRepr(): calls data().CheckDataSafe() first and then returns ToStringRepr().
std::string Tensor::ToStringReprSafe() {
data().CheckDataSafe();
return ToStringRepr();
}
void Tensor::data_sync() const {
if (device_address_ != nullptr) {
if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
......
......@@ -54,16 +54,10 @@ class TensorData {
virtual ssize_t ndim() const = 0;
/// Data pointer.
virtual void *data() = 0;
/// Shape of data.
virtual std::vector<int> shape() const = 0;
/// Is data equals.
virtual bool equals(const TensorData &other) const = 0;
/// Check for lazy allocation.
virtual void CheckDataSafe() = 0;
/// To string for lazy allocation.
virtual std::string ToStringSafe() = 0;
/// To string.
virtual std::string ToString() const = 0;
virtual std::string ToString(const TypeId type, const std::vector<int> &shape) const = 0;
};
using TensorDataPtr = std::shared_ptr<TensorData>;
......@@ -222,12 +216,6 @@ class Tensor : public MetaTensor {
std::string ToStringRepr() const;
/// To string for lazy allocation.
std::string ToStringSafe();
/// To string for lazy allocation.
std::string ToStringReprSafe();
bool is_init() { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; }
......
......@@ -351,8 +351,8 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter")
.def("__str__", &Tensor::ToStringSafe)
.def("__repr__", &Tensor::ToStringReprSafe)
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle(
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册