/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/tensor.h"

#include <atomic>
#include <functional>
#include <numeric>
#include <vector>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <iomanip>
#include <algorithm>
#include <type_traits>
#include <typeinfo>

#include "abstract/abstract_value.h"

namespace mindspore {
namespace tensor {
constexpr auto kEllipsis = "...";
constexpr auto kThreshold = 6;

constexpr auto kThreshold1DFloat = kThreshold * 2;
constexpr auto kThreshold1DInt = kThreshold * 4;
constexpr auto kThreshold1DBool = kThreshold * 2;

static std::string MakeId() {
  // Use an atomic counter to make the id generator thread-safe.
  static std::atomic<uint64_t> last_id{1};
  return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}
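// Illustration: successive calls yield "T1", "T2", "T3", ...; the atomic
// fetch_add keeps ids unique even when tensors are created concurrently.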

static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
  return data_type ? data_type->type_id() : defaultTypeId;
}

static size_t SizeOf(const ShapeVector &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}
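// For example, SizeOf({2, 3, 4}) == 24; an empty shape (a scalar tensor)
// yields 1, because the accumulation starts from size_t(1).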

static std::string ShapeToString(const ShapeVector &shape) {
  std::string str = "[";
  const size_t count = shape.size();
  for (size_t i = 0; i < count; ++i) {
    if (i > 0) {
      str.append(", ");
    }
    str.append(std::to_string(shape[i]));
  }
  return str.append("]");
}
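// For example, ShapeToString({2, 3, 4}) returns "[2, 3, 4]", and
// ShapeToString({}) returns "[]".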

template <typename T, typename U>
std::unique_ptr<T[]> NewData(const U *input, size_t size) {
  if (input == nullptr || size == 0) {
    return nullptr;
  }
  auto data = std::make_unique<T[]>(size);
  if constexpr (!std::is_same<T, U>::value && (std::is_same<T, float16>::value || std::is_same<U, float16>::value)) {
    // Because float16 does not support implicit casts from/to other types,
    // we cannot use std::copy() on arrays of float16; use a loop here.
    for (size_t i = 0; i < size; ++i) {
      data[i] = static_cast<T>(input[i]);
    }
  } else {
    // Otherwise, use std::copy for better performance.
    std::copy(input, input + size, data.get());
  }
  return data;
}
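// Sketch of the float16 path (hypothetical caller):
//   float src[3] = {1.0f, 2.0f, 3.0f};
//   auto dst = NewData<float16>(src, 3);  // element-wise static_cast loop
// With T == U (e.g. NewData<float>(src, 3)), std::copy is used instead.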

template <typename T, typename Scalar>
std::unique_ptr<T[]> NewData(Scalar scalar) {
  auto data = std::make_unique<T[]>(1);
  data[0] = static_cast<T>(scalar);
  return data;
}

template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
  const size_t size = SizeOf(shape);
  switch (data_type) {
    case kNumberTypeBool: {
      auto buf = static_cast<bool *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt8: {
      auto buf = static_cast<uint8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt8: {
      auto buf = static_cast<int8_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt16: {
      auto buf = static_cast<int16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt32: {
      auto buf = static_cast<int32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeInt64: {
      auto buf = static_cast<int64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt16: {
      auto buf = static_cast<uint16_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt32: {
      auto buf = static_cast<uint32_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeUInt64: {
      auto buf = static_cast<uint64_t *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat16: {
      auto buf = static_cast<float16 *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat32: {
      auto buf = static_cast<float *>(data);
      return NewData<T>(buf, size);
    }
    case kNumberTypeFloat64: {
      auto buf = static_cast<double *>(data);
      return NewData<T>(buf, size);
    }
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}

template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
  size_t size = SizeOf(shape);
  if (size * sizeof(T) != data_len) {
    MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
                      << ", item size " << sizeof(T) << ".";
  }
  auto buf = static_cast<T *>(data);
  return NewData<T>(buf, size);
}
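// Example of the length check: for T = float and shape {2, 3}, data_len
// must be exactly 6 * sizeof(float) = 24 bytes, otherwise it throws.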

// Tensor data implementation.
template <typename T>
class TensorDataImpl : public TensorData {
 public:
  explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}

  ~TensorDataImpl() = default;

  TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}

  TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}

  template <typename U>
  TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(input, size)) {}

  template <typename Scalar>
  TensorDataImpl(const ShapeVector &shape, Scalar scalar)
      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData<T>(scalar)) {}

  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }

  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }

  ssize_t nbytes() const override { return size() * itemsize(); }

  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }

  void *data() override {
    if (data_ == nullptr) {
      // Lazy allocation.
      data_ = std::make_unique<T[]>(data_size_);
    }
    return data_.get();
  }

  const void *const_data() const override {
    // May return nullptr if data not initialized.
    return data_.get();
  }

  bool equals(const TensorData &other) const override {
    auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
    if (ptr == nullptr) {
      // Not the same type; compare the data byte by byte.
      return TensorData::equals(other);
    }
    if (ptr == this) {
      return true;
    }
    if (data_ == nullptr || ptr->data_ == nullptr) {
      return false;
    }
    return (ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) &&
           std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
  }

  std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
    constexpr auto valid =
      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
      std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
      std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
    static_assert(valid, "Type is invalid");
    if (data_size_ == 0) {
      return "";
    }
    if (data_ == nullptr) {
      return "<uninitialized>";
    }

    std::ostringstream ss;
    if (data_size_ == 1 && ndim_ == 0) {  // Scalar
      OutputDataString(ss, 0, 0, 1, false);
      return ss.str();
    }
    ssize_t cursor = 0;
    SummaryStringRecursive(ss, shape, &cursor, 0, use_comma);
    return ss.str();
  }

 private:
  void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma) const {
    const bool isScalar = ndim_ == 0 && end - start == 1;
    constexpr auto isFloat =
      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
    constexpr auto isBool = std::is_same<T, bool>::value;
    constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
    for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
      const auto value = data_[cursor + i];
      if constexpr (isFloat) {
        if (isScalar) {
          ss << value;
        } else {
          if constexpr (std::is_same<T, float16>::value) {
            ss << std::setw(11) << std::setprecision(4) << std::setiosflags(std::ios::scientific | std::ios::right)
               << value;
          } else {
            ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
               << value;
          }
        }
      } else if (std::is_same<T, bool>::value) {
        if (isScalar) {
          ss << (value ? "True" : "False");
        } else {
          ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
        }
      } else {
        constexpr auto isSigned = std::is_same<T, int64_t>::value;
        if constexpr (isSigned) {
          if (!isScalar && static_cast<int64_t>(value) >= 0) {
            ss << ' ';
          }
        }

        // Set the field width for each integer type so that columns align,
        // reserving a position for the sign on signed types:
        //
        //   uint8 width:  3,  [0, 255]
        //   int8 width:   4,  [-128, 127]
        //   uint16 width: 5,  [0, 65535]
        //   int16 width:  6,  [-32768, 32767]
        //   uint32 width: 10, [0, 4294967295]
        //   int32 width:  11, [-2147483648, 2147483647]
        //   uint64 width: NOT SET (20, [0, 18446744073709551615])
        //   int64 width:  NOT SET (20, [-9223372036854775808, 9223372036854775807])
        if constexpr (std::is_same<T, uint8_t>::value) {
          ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<uint16_t>(value);
        } else if constexpr (std::is_same<T, int8_t>::value) {
          ss << std::setw(4) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value);
        } else if constexpr (std::is_same<T, uint16_t>::value) {
          ss << std::setw(5) << std::setiosflags(std::ios::right) << value;
        } else if constexpr (std::is_same<T, int16_t>::value) {
          ss << std::setw(6) << std::setiosflags(std::ios::right) << value;
        } else if constexpr (std::is_same<T, uint32_t>::value) {
          ss << std::setw(10) << std::setiosflags(std::ios::right) << value;
        } else if constexpr (std::is_same<T, int32_t>::value) {
          ss << std::setw(11) << std::setiosflags(std::ios::right) << value;
        } else {
          ss << value;
        }
      }
      if (!isScalar && i != end - 1) {
        if (use_comma) {
          ss << ',';
        }
        ss << ' ';
      }
      if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
        // Add a line feed every linefeedThreshold items for a 1D tensor.
        ss << '\n' << ' ';
      }
    }
  }

  void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth,
                              bool use_comma) const {
    if (depth >= static_cast<ssize_t>(ndim_)) {
      return;
    }
    ss << '[';
    if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
      ssize_t num = shape[depth];
      if (num > kThreshold && ndim_ > 1) {
        OutputDataString(ss, *cursor, 0, kThreshold / 2, use_comma);
        ss << ' ' << kEllipsis << ' ';
        OutputDataString(ss, *cursor, num - kThreshold / 2, num, use_comma);
      } else {
        OutputDataString(ss, *cursor, 0, num, use_comma);
      }
      *cursor += num;
    } else {  // Middle dimension
      ssize_t num = shape[depth];
      // Handle the first half.
      for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
        if (i > 0) {
          if (use_comma) {
            ss << ',';
          }
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
        }
        SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma);
      }
      // Handle the ignored part.
      if (num > kThreshold) {
        if (use_comma) {
          ss << ',';
        }
        ss << '\n';
        ss << std::setw(depth + 1) << ' ';  // Add the indent.
        ss << kEllipsis;
        // Count the elements ignored at this layer.
        ssize_t ignored = shape[depth + 1];
        for (ssize_t i = depth + 2; i < static_cast<ssize_t>(ndim_); i++) {
          ignored *= shape[i];
        }
        // Multiply by the number of ignored layers.
        ignored *= num - kThreshold;

        *cursor += ignored;
      }
      // Handle the second half.
      if (num > kThreshold / 2) {
        auto continue_pos = num - kThreshold / 2;
        for (ssize_t i = continue_pos; i < num; i++) {
          if (use_comma && i != continue_pos) {
            ss << ',';
          }
          ss << '\n';
          ss << std::setw(depth + 1) << ' ';  // Add the indent.
          SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma);
        }
      }
    }
    ss << ']';
  }

  size_t ndim_{0};
  size_t data_size_{0};
  std::unique_ptr<T[]> data_;
};

template <typename... Args>
TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const Args... args) {
  switch (data_type) {
    case kNumberTypeBool:
      return std::make_shared<TensorDataImpl<bool>>(shape, args...);
    case kNumberTypeUInt8:
      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
    case kNumberTypeInt8:
      return std::make_shared<TensorDataImpl<int8_t>>(shape, args...);
    case kNumberTypeInt16:
      return std::make_shared<TensorDataImpl<int16_t>>(shape, args...);
    case kNumberTypeInt32:
      return std::make_shared<TensorDataImpl<int32_t>>(shape, args...);
    case kNumberTypeInt64:
      return std::make_shared<TensorDataImpl<int64_t>>(shape, args...);
    case kNumberTypeUInt16:
      return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...);
    case kNumberTypeUInt32:
      return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...);
    case kNumberTypeUInt64:
      return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...);
    case kNumberTypeFloat16:
      return std::make_shared<TensorDataImpl<float16>>(shape, args...);
    case kNumberTypeFloat32:
      return std::make_shared<TensorDataImpl<float>>(shape, args...);
    case kNumberTypeFloat64:
      return std::make_shared<TensorDataImpl<double>>(shape, args...);
    default:
      break;
  }
  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}
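// Usage sketch (hypothetical buffer `buf` holding 6 floats = 24 bytes):
//   auto lazy = MakeTensorData(kNumberTypeFloat32, ShapeVector{2, 3});
//   auto copy = MakeTensorData(kNumberTypeFloat32, ShapeVector{2, 3},
//                              static_cast<void *>(buf), size_t{24});
// The first allocates its storage on demand; the second copies the buffer
// and validates the byte length against the shape.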

Tensor::Tensor(const Tensor &tensor)
    : MetaTensor(tensor),
      init_flag_(tensor.init_flag_),
      data_(tensor.data_),
      id_(tensor.id_),
      event_(tensor.event_),
      sync_status_(tensor.sync_status_),
      device_sync_(tensor.device_sync_),
      padding_type_(tensor.padding_type()) {}

Tensor::Tensor(const Tensor &tensor, TypeId data_type)
    : MetaTensor(data_type, tensor.shape_),
      init_flag_(tensor.init_flag_),
      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
      id_(tensor.id_),
      event_(tensor.event_),
      sync_status_(tensor.sync_status_),
      device_sync_(tensor.device_sync_),
      padding_type_(tensor.padding_type()) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
    : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}

Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
      data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
      id_(MakeId()) {}

Tensor::Tensor(int64_t input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

Tensor::Tensor(double input, const TypePtr &data_type)
    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}),
      data_(MakeTensorData(data_type_, {}, input)),
      id_(MakeId()) {}

bool Tensor::operator==(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
}

bool Tensor::ValueEqual(const Tensor &tensor) const {
  return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
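// Note: operator== checks identity (both tensors share the same TensorData
// object), while ValueEqual compares the elements via TensorData::equals.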
// Assign value to this tensor.
Tensor &Tensor::AssignValue(const Tensor &tensor) {
  if (this != &tensor) {
    MetaTensor::operator=(tensor);
    device_sync_ = tensor.device_sync_;
    data_ = tensor.data_;
    id_ = tensor.id_;
    event_ = tensor.event_;
    sync_status_ = tensor.sync_status_;
    padding_type_ = tensor.padding_type_;
  }
  return *this;
}
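// Usage sketch (hypothetical tensors):
//   Tensor a(kNumberTypeFloat32, ShapeVector{2});
//   Tensor b(kNumberTypeFloat32, ShapeVector{2});
//   a.AssignValue(b);  // `a` now shares b's data, id and sync state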

abstract::AbstractBasePtr Tensor::ToAbstract() {
  auto tens = shared_from_base<Tensor>();
  auto dtype = tens->Dtype();
  if (!IsSubType(dtype, kNumber)) {
    MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << ".";
  }
  auto tensor_shape = tens->shape();
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
  // A parameter tensor never carries a value.
  if (is_parameter_) {
    auto param_name = param_info_->name();
    auto ref_key = std::make_shared<RefKey>(param_name);
    auto abs_ref_key = ref_key->ToAbstract();
    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
  } else {
    abs_tensor->set_value(shared_from_base<Tensor>());
  }
  return abs_tensor;
}

std::string Tensor::GetShapeAndDataTypeInfo() const {
  std::ostringstream buf;
  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
  return buf.str();
}

std::string Tensor::ToStringInternal(int limit_size) const {
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  data_sync();
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ',';
  if (limit_size <= 0 || DataSize() < limit_size) {
    // Only print data for small tensors.
    buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false) << ')';
  } else {
    buf << " [...])";
  }
  return buf.str();
}

std::string Tensor::ToString() const {
  constexpr int small_tensor_size = 30;
  return ToStringInternal(small_tensor_size);
}

std::string Tensor::ToStringNoLimit() const { return ToStringInternal(0); }
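// Illustrative output, assuming the Float32 dtype prints as "Float32":
//   Tensor(1.5, nullptr).ToString()  ->  Tensor(shape=[], dtype=Float32, 1.5)
// Tensors with 30 or more elements elide their data and print " [...]".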

std::string Tensor::ToStringRepr() const {
  std::ostringstream buf;
  auto dtype = Dtype();
  MS_EXCEPTION_IF_NULL(dtype);
  data_sync();
  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ','
      << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, true) << ')';
  return buf.str();
}

void Tensor::data_sync() const {
  Wait();
  if (device_sync_ == nullptr) {
    return;
  }
  if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
    MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
  }
}
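// data_sync() is the host-side synchronization point: the ToString family
// calls it first, so printing a device-resident tensor copies its contents
// back to host memory before formatting.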

TypeId Tensor::set_data_type(const TypeId data_type) {
  if (data_type != data_type_) {
    data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_);
    return MetaTensor::set_data_type(data_type);
  }
  return data_type;
}
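// Usage sketch (hypothetical in-place cast):
//   Tensor t(kNumberTypeFloat32, ShapeVector{2, 2});
//   t.set_data_type(kNumberTypeFloat16);  // rebuilds data_ via MakeTensorData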
}  // namespace tensor
}  // namespace mindspore