/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/tensor.h"

#include <atomic>
#include <functional>
#include <numeric>
#include <vector>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <iomanip>
#include <algorithm>
#include <type_traits>
#include <typeinfo>

#include "abstract/abstract_value.h"

namespace mindspore {
namespace tensor {
constexpr auto kEllipsis = "...";
constexpr auto kThreshold = 6;

constexpr auto kThreshold1DFloat = kThreshold * 2;
constexpr auto kThreshold1DInt = kThreshold * 4;
constexpr auto kThreshold1DBool = kThreshold * 2;

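// Generates a process-unique tensor id of the form "T<n>".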
static std::string MakeId() {
  // Use atomic to make id generator thread safe.
  static std::atomic<uint64_t> last_id{1};
  return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}

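// Returns the type id of `data_type`, or `defaultTypeId` when `data_type` is null.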
static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
  return data_type ? data_type->type_id() : defaultTypeId;
}

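// Returns the number of elements implied by `shape` (the product of all dimensions).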
static size_t SizeOf(const ShapeVector &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}

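// Formats a shape as "[d0, d1, ...]".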
static std::string ShapeToString(const ShapeVector &shape) {
  std::string str = "[";
  const size_t count = shape.size();
  for (size_t i = 0; i < count; ++i) {
    if (i > 0) {
      str.append(", ");
    }
    str.append(std::to_string(shape[i]));
  }
  return str.append("]");
}

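// Copies `size` elements from `input` into a newly allocated T[] buffer,
// converting element types when T and U differ.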
template <typename T, typename U>
std::unique_ptr<T[]> NewData(const U *input, size_t size) {
  if (input == nullptr || size == 0) {
    return nullptr;
  }
  auto data = std::make_unique<T[]>(size);
  if constexpr (!std::is_same<T, U>::value && (std::is_same<T, float16>::value || std::is_same<U, float16>::value)) {
    // float16 does not support implicit casts from/to other types, so we can not
    // use std::copy() on arrays of float16; convert element by element instead.
    for (size_t i = 0; i < size; ++i) {
      data[i] = static_cast<T>(input[i]);
    }
  } else {
    // Otherwise, use std::copy() for better performance.
    std::copy(input, input + size, data.get());
  }
  return data;
}

template <typename T, typename Scalar>
std::unique_ptr<T[]> NewData(Scalar scalar) {
  auto data = std::make_unique<T[]>(1);
  data[0] = static_cast<T>(scalar);
  return data;
}

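// Copies raw `data`, interpreted as an array of `data_type` elements,
// into a newly allocated T[] buffer of SizeOf(shape) elements.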
template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
  const size_t size = SizeOf(shape);
  switch (data_type) {
    case kNumberTypeBool: {
      auto buf = static_cast<bool *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt8: {
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt8: {
      auto buf = static_cast<int8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt16: {
      auto buf = static_cast<int16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt32: {
      auto buf = static_cast<int32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt64: {
      auto buf = static_cast<int64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt16: {
      auto buf = static_cast<uint16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt32: {
      auto buf = static_cast<uint32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt64: {
      auto buf = static_cast<uint64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat16: {
      auto buf = static_cast<float16 *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat32: {
      auto buf = static_cast<float *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat64: {
      auto buf = static_cast<double *>(data);
      return NewData<T>(buf, size);
    }
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}

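// Copies raw `data` of `data_len` bytes into a newly allocated T[] buffer;
// `data_len` must equal SizeOf(shape) * sizeof(T).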
template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
  size_t size = SizeOf(shape);
  if (size * sizeof(T) != data_len) {
    MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
                      << ", item size " << sizeof(T) << ".";
  }
  auto buf = static_cast<T *>(data);
  return NewData<T>(buf, size);
}

// Tensor data implementation.
template <typename T>
class TensorDataImpl : public TensorData {
 public:
  explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
  ~TensorDataImpl() = default;

  TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}

  TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}

  template <typename U>
  TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(input, size)) {}

  template <typename Scalar>
  TensorDataImpl(const ShapeVector &shape, Scalar scalar)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(scalar)) {}

  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }

  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }

  ssize_t nbytes() const override { return size() * itemsize(); }

  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }

  void *data() override {
    if (data_ == nullptr) {
      // Lazy allocation.
      data_ = std::make_unique<T[]>(data_size_);
    }
    return data_.get();
  }

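  // Two tensor data objects are equal when they wrap the same element type,
  // have identical dimensions and element count, and all elements compare equal.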
  bool equals(const TensorData &other) const override {
    auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
    if (ptr == nullptr) {
      return false;
    }
    if (ptr == this) {
      return true;
    }
    if (data_ == nullptr || ptr->data_ == nullptr) {
      return false;
    }
    return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
           std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
  }

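  // Renders the tensor data as a (possibly abbreviated) printable string;
  // oversize dimensions are summarized with "..." around the leading and trailing elements.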
  std::string ToString(const TypeId type, const ShapeVector &shape) const override {
    constexpr auto valid =
      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
      std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
      std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
    static_assert(valid, "Type is invalid");
    if (data_size_ == 0) {
      return "";
    }
    if (data_ == nullptr) {
      return "<uninitialized>";
    }

    std::ostringstream ss;
    if (data_size_ == 1 && ndim_ == 0) {  // Scalar
      OutputDataString(ss, 0, 0, 1);
      return ss.str();
    }
    ssize_t cursor = 0;
    SummaryStringRecursive(ss, shape, &cursor, 0);
    return ss.str();
  }

 private:
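  // Prints elements data_[cursor + start] .. data_[cursor + end - 1] with
  // per-type width, precision and line-feed rules.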
  void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const {
    const bool isScalar = ndim_ == 0 && end - start == 1;
    constexpr auto isFloat =
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
    constexpr auto isBool = std::is_same<T, bool>::value;
    constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
    for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
      const auto value = data_[cursor + i];
      if constexpr (isFloat) {
        if (isScalar) {
          ss << value;
        } else {
          if constexpr (std::is_same<T, float16>::value) {
            ss << std::setw(11) << std::setprecision(4) << std::setiosflags(std::ios::scientific | std::ios::right)
               << value;
          } else {
            ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
               << value;
          }
        }
      } else if (std::is_same<T, bool>::value) {
        if (isScalar) {
          ss << (value ? "True" : "False");
        } else {
          ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
        }
      } else {
        constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
                                  std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
        if constexpr (isSigned) {
          // Add a leading space for non-negative values to align them with negative ones.
          if (!isScalar && static_cast<int64_t>(value) >= 0) {
            ss << ' ';
          }
        }

        // Set width and indent for different integer types:
        //
        //   int8/uint8 width:   3
        //   int16/uint16 width: 5
        //   int32/uint32 width: 10
        //   int64/uint64 width: NOT SET
        if constexpr (std::is_same<T, int8_t>::value) {
          ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value);
        } else if constexpr (std::is_same<T, uint8_t>::value) {
          ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<uint16_t>(value);
        } else if constexpr (std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value) {
          ss << std::setw(5) << std::setiosflags(std::ios::right) << value;
        } else if constexpr (std::is_same<T, int32_t>::value || std::is_same<T, uint32_t>::value) {
          ss << std::setw(10) << std::setiosflags(std::ios::right) << value;
        } else {
          ss << value;
        }
      }
      if (!isScalar && i != end - 1) {
        ss << ' ';
      }
      if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
        // Add a line feed after every per-type threshold of elements for a 1-D tensor.
        ss << '\n' << ' ';
      }
    }
  }

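  // Recursively prints one dimension per call, eliding the middle of any
  // dimension longer than kThreshold and advancing *cursor past skipped data.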
  void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth) const {
    if (depth >= static_cast<ssize_t>(ndim_)) {
      return;
    }
    ss << '[';
    if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
      ssize_t num = shape[depth];
      if (num > kThreshold && ndim_ > 1) {
        OutputDataString(ss, *cursor, 0, kThreshold / 2);
        ss << ' ' << kEllipsis << ' ';
        OutputDataString(ss, *cursor, num - kThreshold / 2, num);
      } else {
        OutputDataString(ss, *cursor, 0, num);
      }
      *cursor += num;
    } else {  // Middle dimension
      ssize_t num = shape[depth];
      // Handle the first half.
      for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
        if (i > 0) {
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
        }
        SummaryStringRecursive(ss, shape, cursor, depth + 1);
      }
      // Handle the ignored part.
      if (num > kThreshold) {
        ss << '\n';
        ss << std::setw(depth + 1) << ' ';  // Add the indent.
        ss << kEllipsis;
        // Count the elements ignored at this layer.
        ssize_t ignored = shape[depth + 1];
        for (ssize_t i = depth + 2; i < static_cast<ssize_t>(ndim_); i++) {
          ignored *= shape[i];
        }
        // Multiply by the number of sub-arrays ignored at this layer.
        ignored *= num - kThreshold;

        *cursor += ignored;
      }
      // Handle the second half.
      if (num > kThreshold / 2) {
        for (ssize_t i = num - kThreshold / 2; i < num; i++) {
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
          SummaryStringRecursive(ss, shape, cursor, depth + 1);
        }
      }
    }
    ss << ']';
  }

  size_t ndim_{0};
  size_t data_size_{0};
  std::unique_ptr<T[]> data_;
};

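// Dispatches on the runtime `data_type` to construct a TensorDataImpl<T> of the
// matching element type, forwarding `args` to its constructor.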
template <typename... Args>
TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const Args... args) {
  switch (data_type) {
    case kNumberTypeBool:
      return std::make_shared<TensorDataImpl<bool>>(shape, args...);
    case kNumberTypeUInt8:
      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
    case kNumberTypeInt8:
      return std::make_shared<TensorDataImpl<int8_t>>(shape, args...);
    case kNumberTypeInt16:
      return std::make_shared<TensorDataImpl<int16_t>>(shape, args...);
    case kNumberTypeInt32:
      return std::make_shared<TensorDataImpl<int32_t>>(shape, args...);
    case kNumberTypeInt64:
      return std::make_shared<TensorDataImpl<int64_t>>(shape, args...);
    case kNumberTypeUInt16:
      return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...);
    case kNumberTypeUInt32:
      return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...);
    case kNumberTypeUInt64:
      return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...);
    case kNumberTypeFloat16:
      return std::make_shared<TensorDataImpl<float16>>(shape, args...);
    case kNumberTypeFloat32:
      return std::make_shared<TensorDataImpl<float>>(shape, args...);
    case kNumberTypeFloat64:
      return std::make_shared<TensorDataImpl<double>>(shape, args...);
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}

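// Illustrative usage of the constructors below (a sketch; assumes the public
// declarations in ir/tensor.h are in scope):
//
//   Tensor a(kNumberTypeFloat32, ShapeVector{2, 3});   // uninitialized 2x3 float32 tensor
//   std::vector<int64_t> v{1, 2, 3};
//   Tensor b(v, nullptr);                              // 1-D tensor, defaults to int32 (via TypeIdOf)
//   Tensor c(b, kNumberTypeFloat64);                   // copy of b converted to float64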
Tensor::Tensor(const Tensor &tensor)
    : MetaTensor(tensor),
      init_flag_(tensor.init_flag_),
      data_(tensor.data_),
      dirty_(tensor.dirty_),
      id_(tensor.id_),
      device_sync_(tensor.device_sync_),
      padding_type_(tensor.padding_type()) {}

Tensor::Tensor(const Tensor &tensor, TypeId data_type)
    : MetaTensor(data_type, tensor.shape_),
      init_flag_(tensor.init_flag_),
      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
      dirty_(tensor.dirty_),
      id_(tensor.id_),
      device_sync_(tensor.device_sync_),
      padding_type_(tensor.padding_type()) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
    : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}

Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(int64_t input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

Tensor::Tensor(double input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

// Identity equality: true when the two tensors share the same data object.
bool Tensor::operator==(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
}

// Value equality: true when meta information matches and all elements compare equal.
bool Tensor::ValueEqual(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
// Assign the value of another tensor to this tensor.
Tensor &Tensor::AssignValue(const Tensor &tensor) {
  if (this != &tensor) {
    MetaTensor::operator=(tensor);
    dirty_ = tensor.dirty_;
    device_sync_ = tensor.device_sync_;
    data_ = tensor.data_;
    id_ = tensor.id_;
    padding_type_ = tensor.padding_type_;
  }
  return *this;
}

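// Builds the abstract value (type/shape, plus value or reference) used during
// inference; parameters become AbstractRef, constants keep their value.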
abstract::AbstractBasePtr Tensor::ToAbstract() {
  auto tens = shared_from_base<Tensor>();
  auto dtype = tens->Dtype();
  if (!IsSubType(dtype, kNumber)) {
    MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << ".";
  }
  auto tensor_shape = tens->shape();
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
  // A parameter tensor is abstracted as a reference and never carries a value.
  if (is_parameter_) {
    auto param_name = param_info_->name();
    auto ref_key = std::make_shared<RefKey>(param_name);
    auto abs_ref_key = ref_key->ToAbstract();
    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
  } else {
    abs_tensor->set_value(shared_from_base<Tensor>());
  }
  return abs_tensor;
}

std::string Tensor::GetShapeAndDataTypeInfo() const {
  std::ostringstream buf;
  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
  return buf.str();
}

std::string Tensor::ToString() const {
  constexpr int small_tensor_size = 30;
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  data_sync();
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ',';
  if (DataSize() < small_tensor_size) {
    // Only print data for small tensors.
    buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')';
  } else {
    buf << " [...])";
  }
  return buf.str();
}

std::string Tensor::ToStringRepr() const {
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  data_sync();
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ','
      << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')';
  return buf.str();
}

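// Synchronizes tensor data from the bound device address back to host memory,
// if a device address is attached.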
void Tensor::data_sync() const {
  if (device_sync_ != nullptr) {
    if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
      MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
    }
  }
}

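// Changes the element type of the tensor, converting the existing data
// to the new type when it differs from the current one.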
TypeId Tensor::set_data_type(const TypeId data_type) {
  if (data_type != data_type_) {
    data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_);
    return MetaTensor::set_data_type(data_type);
  }
  return data_type;
}
}  // namespace tensor
}  // namespace mindspore