tensor.cpp 27.8 KB
Newer Older
1 2 3 4
/**
 * \file src/core/impl/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8 9 10 11 12 13 14
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/tensor.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
M
Megvii Engine Team 已提交
15
#include "megbrain/opr/param_defs.h"
16 17 18 19 20 21

#include "megdnn/oprs.h"

#include <thread>

#include <cmath>
M
Megvii Engine Team 已提交
22
#include <cstring>
23 24 25 26 27

using namespace mgb;

namespace {

//! implement non-contiguous d2d copy
//! \param contig_dest whether dest layout is contiguous
//! \param contig_src whether src layout is contiguous
void noncont_tensor_copy(
        const DeviceTensorND& dest, const DeviceTensorND& src, bool contig_dest,
        bool contig_src) {
    auto src_cn = src.comp_node();
    auto dst_cn = dest.comp_node();
    if (src_cn.device_type() == dst_cn.device_type()) {
        // perform relayout op for better performance when src and dst are
        // placed on comp nodes with the same device type
        auto&& src_env = CompNodeEnv::from_comp_node(src.comp_node());
        auto relayout = opr::intl::get_megdnn_global_opr<megdnn::Relayout>(dst_cn);
        // activate the destination comp node before issuing the kernel
        dst_cn.activate();
        relayout->exec(
                const_cast<DeviceTensorND&>(src).as_megdnn(), dest.as_megdnn(),
                MegDNNHandle::get(src_env).handle());
    } else {
        // different device types: stage through a contiguous temporary
        if (contig_src) {
            mgb_assert(!contig_dest);
            // dest is non-contiguous: copy to a contiguous temporary on
            // dest's comp node, then relayout locally
            DeviceTensorND tmp{dst_cn};
            tmp.copy_from(src);
            dest.copy_from_fixlayout(tmp);
            return;
        }
        // src is non-contiguous: make a contiguous copy first (tmp adopts
        // src's comp node), then copy into dest's fixed layout
        DeviceTensorND tmp;
        tmp.copy_from(src);
        dest.copy_from_fixlayout(tmp);
    }
}

M
Megvii Engine Team 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
//! implement non-contiguous h2h copy
void noncont_tensor_copy(
        const HostTensorND& dest, const HostTensorND& src, bool, bool) {
    auto opr =
            opr::intl::get_megdnn_global_opr<megdnn::Relayout>(CompNode::default_cpu());

    opr->exec(const_cast<HostTensorND&>(src).as_megdnn(), dest.as_megdnn());
}

//! implement non-contiguous d2h copy
void noncont_tensor_copy(
        const HostTensorND& dest, const DeviceTensorND& src, bool contig_dest,
        bool contig_src) {
    if (contig_src) {
        mgb_assert(!contig_dest);
        // dest is non-contiguous: download into a contiguous host temporary,
        // then relayout on host
        HostTensorND tmp;
        tmp.copy_from(src).sync();
        dest.copy_from_fixlayout(tmp);  // sync not needed for h2h copy
        return;
    }
    // src is non-contiguous: make it contiguous on device first, then the
    // recursive copy_from_fixlayout performs a contiguous D2H transfer
    DeviceTensorND tmp;
    tmp.copy_from(src);
    dest.copy_from_fixlayout(tmp);
}

//! implement non-contiguous h2d copy
void noncont_tensor_copy(
        const DeviceTensorND& dest, const HostTensorND& src, bool contig_dest,
        bool contig_src) {
    if (contig_src) {
        mgb_assert(!contig_dest);
        DeviceTensorND tmp;
        // no need to sync because device free is async-safe with respect to
        // host thread
        tmp.copy_from(src);
        dest.copy_from_fixlayout(tmp);
        return;
    }
    // src is non-contiguous: relayout into a contiguous host temporary first
    HostTensorND tmp;
    tmp.copy_from(src);
    // NOTE(review): the sync presumably keeps the temporary host buffer alive
    // until the async H2D copy finishes — confirm against CompNode semantics
    dest.copy_from_fixlayout(tmp).sync();
}
}  // anonymous namespace
100 101 102 103

/* ============= Slice and SubTensorSpec ============= */

SubTensorSpec SubTensorSpec::make_from_offset_elem(
        const TensorLayout& layout, ptrdiff_t offset_elem) {
    // a valid dtype and a non-empty layout are both required
    mgb_assert(layout.dtype.valid() && layout.ndim);
    return {layout, offset_elem};
}

//! apply this slice to \p layout along \p axis, returning the resulting
//! sub-tensor spec; handles Python-style negative indices and negative steps
SubTensorSpec Slice::apply(TensorLayout layout, int axis) const {
    mgb_assert(layout.ndim > 0 && layout.dtype.valid());
    if (axis == megdnn::param::OptionalAxisV1::INVALID_AXIS) {
        // INVALID_AXIS means "slice the flattened tensor"; this requires the
        // layout to collapse to a single contiguous dimension
        axis = 0;
        layout = layout.collapse_contiguous();
        mgb_assert(
                layout.ndim == 1,
                "apply Slice with axis==INVALID_AXIS on non-contig layout");
    }
    // axis in [-ndim, ndim) is available
    if (axis < 0)
        axis += layout.ndim;
    mgb_assert(
            axis >= 0 && static_cast<size_t>(axis) < layout.ndim,
            "invalid axis: %d; ndim=%zu", axis, layout.ndim);

    ptrdiff_t size_ax = layout.shape[axis];
    ptrdiff_t begin, end, step = m_step.val_with_default(1);
    mgb_assert(step, "Slice step can not be zero");

    // stringify a Maybe<> index for error messages
    auto tostr = [](const Maybe<ptrdiff_t>& v) -> std::string {
        if (!v.valid())
            return "None";
        return std::to_string(v.val());
    };
    // map a possibly-negative index into [0, size_ax) (Python semantics)
    auto mod_size = [size_ax](ptrdiff_t v) -> ptrdiff_t {
        if (size_ax == 0)
            return 0;
        return v < 0 ? v + size_ax : v;
    };
    MGB_MARK_USED_VAR(tostr);

// bounds-check helper; error message differs for scalar vs range indexing
#define CHECK(cond)                                                               \
    if (m_is_scalar_idx) {                                                        \
        mgb_assert(                                                               \
                cond, "index out of bound: layout=%s; request index=%s, axis=%d", \
                layout.to_string().c_str(), tostr(m_begin).c_str(), axis);        \
    } else {                                                                      \
        mgb_assert(                                                               \
                cond,                                                             \
                "index out of bound: layout=%s; request begin=%s end=%s step=%s " \
                "axis=%d",                                                        \
                layout.to_string().c_str(), tostr(m_begin).c_str(),               \
                tostr(m_end).c_str(), tostr(m_step).c_str(), axis);               \
    }

    if (step > 0) {
        begin = mod_size(m_begin.val_with_default(0));
        end = mod_size(m_end.val_with_default(size_ax));
        if (!m_is_scalar_idx) {
            // range indexing clamps out-of-bound begin/end instead of failing
            end = std::min(end, size_ax);
            begin = std::min(begin, end);
        }
        CHECK(begin >= 0 && end >= begin && end <= size_ax)
    } else {
        // negative step: default begin is the last element; end defaults to -1
        // (one before the first element)
        begin = mod_size(m_begin.val_with_default(size_ax - 1));
        end = m_end.valid() ? mod_size(m_end.val()) : -1;
        if (!m_is_scalar_idx) {
            begin = std::min(begin, std::max<ptrdiff_t>(size_ax - 1, 0));
            end = std::min(end, begin);
        }
        CHECK(step < 0 && begin >= 0 && end <= begin && begin < size_ax && end >= -1)
    }
    auto step_abs = std::abs(step);
    // number of elements selected, rounded up
    layout.shape[axis] = (std::abs(end - begin) + step_abs - 1) / step_abs;
    auto orig_stride = layout.stride[axis];
    layout.stride[axis] *= step;

    // make stride as contiguous as possible
    if (layout.shape[axis] != 1 && axis)
        --axis;
    if (layout.shape[axis] == 1) {
        // rewrite strides of leading size-1 axes so the layout can still be
        // recognized as contiguous
        auto stride = layout.stride[axis] =
                axis + 1 < static_cast<int>(layout.ndim)
                        ? layout.stride[axis + 1] * layout.shape[axis + 1]
                        : 1;

        for (int i = axis - 1; i >= 0; --i) {
            if (layout.shape[i] == 1) {
                layout.stride[i] = stride;
            } else {
                break;
            }
        }
    }

    // empty result keeps offset 0; otherwise offset uses the pre-scaled stride
    auto offset_elem = layout.is_empty() ? 0 : orig_stride * begin;
    return SubTensorSpec::make_from_offset_elem(layout, offset_elem);

#undef CHECK
}

M
Megvii Engine Team 已提交
201 202 203
void SubTensorSpec::merge_with(const SubTensorSpec& rhs) {
    mgb_assert(
            m_layout.dtype.valid() && m_layout.dtype == rhs.m_layout.dtype &&
204 205 206 207 208 209 210 211
            rhs.m_layout.ndim);
    m_offset_elem += rhs.m_offset_elem;
    m_layout = rhs.m_layout;
}

/* ===================== TensorStorage ===================== */

//! allocation policy used by TensorStorage for host-side buffers
class mgb::HostTensorStorageTrait {
public:
    static void* alloc(CompNode node, size_t size) { return node.alloc_host(size); }
    static void free(CompNode node, void* data) { node.free_host(data); }
};

//! allocation policy used by TensorStorage for device-side buffers
class mgb::DeviceTensorStorageTrait {
public:
    static void* alloc(CompNode node, size_t size) { return node.alloc_device(size); }
    static void free(CompNode node, void* data) { node.free_device(data); }
};

template <class Trait>
TensorStorage<Trait>& TensorStorage<Trait>::operator=(const TensorStorage& rhs) {
    // if rhs has a pending lazy resize, force the allocation now (ptr() has
    // that side effect) so the copied fields describe a real buffer
    if (rhs.m_size > rhs.m_capacity) {
        rhs.ptr();
    }
    m_allow_realloc = rhs.m_allow_realloc;
    m_comp_node = rhs.m_comp_node;
    m_size = rhs.m_size;
    m_capacity = rhs.m_capacity;
    m_offset = rhs.m_offset;
    m_data = rhs.m_data;
    m_ref_ptr = rhs.m_ref_ptr;
    return *this;
}

template <class Trait>
TensorStorage<Trait>& TensorStorage<Trait>::ensure_size(size_t sz) {
    // only record the new size here; the actual allocation is deferred until
    // ptr() / apply_lazy_and_get_ptr() is called
    if (sz > m_size) {
        mgb_throw_if(
                !m_allow_realloc || m_offset, MegBrainError,
                "can not grow a tensor that does not allow realloc");
        check_comp_node_valid();
    }
    m_size = sz;
    return *this;
}

template <class Trait>
TensorStorage<Trait> TensorStorage<Trait>::sub(ptrdiff_t offset) const {
    ptr();  // apply lazy resize
    ptrdiff_t toff = offset + m_offset;
    // an offset exactly one-past-the-end yields an empty storage rather than
    // an error
    if (offset == static_cast<ptrdiff_t>(m_size)) {
        return {false, m_comp_node, 0, 0, 0, RawStorage{}};
    }
    mgb_assert(
            toff >= 0 && offset < static_cast<ptrdiff_t>(m_size),
            "bad subtensor: offset=%td m_offset=%zu m_size=%zu", offset, m_offset,
            m_size);
    // the sub-storage shares m_data/m_ref_ptr and forbids realloc (first
    // ctor argument false means m_allow_realloc = false)
    return {false,
            m_comp_node,
            m_size - offset,
            m_capacity - offset,
            static_cast<size_t>(toff),
            m_data,
            m_ref_ptr};
}

//! perform any pending lazy reallocation and return the data pointer
//! (adjusted by m_offset)
template <class Trait>
dt_byte* TensorStorage<Trait>::apply_lazy_and_get_ptr() {
    check_comp_node_valid();
    if (m_size > m_capacity) {
        mgb_assert(m_allow_realloc && !m_offset);
        m_data.reset();  // free old ptr
        m_capacity = 0;  // to be exception safe
        auto ptr = static_cast<dt_byte*>(Trait::alloc(m_comp_node, m_size));
        mgb_throw_if(!ptr, SystemError, "failed to allocate memory");
        CompNode cn = m_comp_node;
        // deleter captures the comp node by value so release works even after
        // this storage object is gone
        m_data.reset(ptr, [cn](void* p) { Trait::free(cn, p); });
        // a fresh ref-ptr cell: old aliases must not observe the new buffer
        m_ref_ptr = std::make_shared<void*>(static_cast<void*>(nullptr));
        m_capacity = m_size;
    }
    // keep the shared ref-ptr cell in sync with the current buffer address
    *m_ref_ptr = static_cast<void*>(m_data.get());
    return m_data.get() + m_offset;
}

M
Megvii Engine Team 已提交
290
template <class Trait>
291 292 293 294 295 296 297 298 299 300 301 302 303
TensorStorage<Trait>& TensorStorage<Trait>::comp_node(
        CompNode node, bool allow_mem_node_change) {
    mgb_assert(node.valid());
    if (m_comp_node.valid() && node.mem_node() != m_comp_node.mem_node()) {
        mgb_assert(allow_mem_node_change);
        m_allow_realloc = true;
        m_size = m_capacity = m_offset = 0;
        m_data.reset();
    }
    m_comp_node = node;
    return *this;
}

//! take ownership of an externally-provided raw buffer
template <class Trait>
void TensorStorage<Trait>::reset(CompNode node, size_t size, RawStorage data) {
    mgb_assert(m_allow_realloc);
    m_comp_node = node;
    m_size = size;
    m_capacity = size;
    m_offset = 0;
    m_data = std::move(data);
    // new ref-ptr cell pointing at the new buffer; must be done after m_data
    // is assigned
    m_ref_ptr = std::make_shared<void*>(static_cast<void*>(m_data.get()));
}

//! like reset(), but keeps the existing ref-ptr cell and updates it in place,
//! so aliases sharing m_ref_ptr observe the new buffer address
template <class Trait>
void TensorStorage<Trait>::only_reset_raw_storage(
        CompNode node, size_t size, RawStorage data, size_t offset) {
    mgb_assert(m_allow_realloc);
    m_comp_node = node;
    m_size = size;
    m_capacity = size;
    m_offset = offset;
    m_data = std::move(data);
    *m_ref_ptr = static_cast<void*>(m_data.get());
}

//! create a storage of a different trait that shares \p src's buffer; only
//! valid when src lives on the default-CPU memory node
template <class Trait>
template <class RTrait, typename>
TensorStorage<Trait> TensorStorage<Trait>::make_proxy(
        const TensorStorage<RTrait>& src) {
    mgb_assert(
            src.comp_node().mem_node() == CompNode::default_cpu().mem_node(),
            "proxy source should be on CPU; got %s",
            src.comp_node().to_string().c_str());
    // force src's lazy allocation so the shared fields are valid
    src.ptr();
    return {true,         src.m_comp_node, src.m_size,   src.m_capacity,
            src.m_offset, src.m_data,      src.m_ref_ptr};
}

M
Megvii Engine Team 已提交
340
template <class Trait>
341
void TensorStorage<Trait>::on_invalid_comp_node() {
M
Megvii Engine Team 已提交
342 343 344
    mgb_throw(
            MegBrainError,
            "trying to acccess TensorStorage with invalid "
345 346 347 348 349 350
            "comp node");
}

namespace mgb {

// host to host: both buffers are CPU-accessible, so a plain memcpy suffices
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
        const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    memcpy(ptr(), src.ptr(), size);
}

// device to host
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
        const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
    bool need_sync = false;
    mgb_assert(size <= this->size() && size <= src.size());
    if (m_comp_node != src.comp_node()) {
        auto default_cpu = CompNode::default_cpu();
        // a cross-comp-node D2H copy is only allowed when the host side is
        // the synchronous default CPU node
        if (src.comp_node() != default_cpu) {
            mgb_assert(
                    m_comp_node == default_cpu,
                    "inconsistent D2H copy:"
                    " copy from device to host using different comp nodes:"
                    " device_node=%s host_node=%s",
                    src.comp_node().to_string().c_str(),
                    m_comp_node.to_string().c_str());
            // copy_from() should use m_comp_node, and default_cpu is
            // synchronous with current thread, so this copy has no
            // synchronizing ambiguity and we only need to sync on host
            need_sync = true;
        }
    }
    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
    // the device comp node drives the D2H transfer
    src.comp_node().copy_to_host_ref(dst_ptr, src_ptr, size);
    if (need_sync)
        src.comp_node().sync();
}

// host to device: the destination comp node drives the transfer
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
        const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
    m_comp_node.copy_to_device_ref(dst_ptr, src_ptr, size);
}

// device to device
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
        const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    if (src.comp_node().device_type() == CompNode::DeviceType::CPU &&
        comp_node().device_type() == CompNode::DeviceType::CUDA) {
        // current thread(i.e. cuda dispatcher thread) should wait for all
        // operations on src's comp_node to finish, otherwise a race condition
        // might occur between the worker thread of src's comp_node and the
        // thread responsible for copying pageable memory in \p src to a pinned
        // buffer, refer to
        // https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
        //
        // Note: it is highly recommended that copy tensor from cpu to cuda
        // with asynchronized dispatching (see graph option async_exec_level),
        // or main thread might be blocked by worker thread corresponding to
        // the src's comp_node, resulting in bad performance
        //
        // TODO: consider using cudaMallocHost or cudaHostRegister
        // to pin the memory of src tensor, so it does not require synchronization
        // and is more efficient
        src.comp_node().sync();
        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
        comp_node().copy_to_device_ref(dst_ptr, src_ptr, size);
    } else {
        // general case: peer copy driven by src's comp node
        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
        src.comp_node().peer_copy_to_ref(m_comp_node, dst_ptr, src_ptr, size);
    }
}

// explicit instantiations of the cross-trait proxy constructors

// proxy host to device
template TensorStorage<DeviceTensorStorageTrait> TensorStorage<
        DeviceTensorStorageTrait>::
        make_proxy<HostTensorStorageTrait, void>(
                const TensorStorage<HostTensorStorageTrait>&);

// proxy device to host
template TensorStorage<HostTensorStorageTrait> TensorStorage<HostTensorStorageTrait>::
        make_proxy<DeviceTensorStorageTrait, void>(
                const TensorStorage<DeviceTensorStorageTrait>&);

M
Megvii Engine Team 已提交
445
}  // namespace mgb
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462

/* ===================== TensorND ===================== */

// ctor def {

// DEF expands to the templated constructor head of TensorND
#define DEF                        \
    template <class TensorStorage> \
    TensorND<TensorStorage>::TensorND
DEF() = default;

DEF(CompNode node) : m_storage{node} {}

DEF(DType dtype) : m_layout{dtype} {}

DEF(CompNode node, DType dtype) : m_storage{node}, m_layout{dtype} {}

//! allocate contiguous from given comp node, shape and dtype
DEF(CompNode node, const TensorShape& shape, DType dtype)
        : m_storage{node}, m_layout{dtype} {
    resize(shape);
}

//! allocate contiguous from given comp node, shape, dtype and format
DEF(CompNode node, const TensorShape& shape, DType dtype, TensorFormat format)
        : m_storage{node}, m_layout{dtype, format} {
    resize(shape);
}

//! allocate contiguous from given comp node and layout (strides not
//! used)
DEF(CompNode node, const TensorLayout& layout)
        : TensorND(node, layout, layout.dtype, layout.format) {
    mgb_assert(
            layout.is_contiguous(),
            "non-contiguous layout used for initializing a tensor: %s",
            layout.to_string().c_str());
}

#undef DEF
// ctor def }

// def {
// DEF(name, ret) expands to a templated member-function head returning
// ChainReturnType with the given ref-qualification
#define DEF(name, ret)                                    \
    template <class TensorStorage>                        \
    typename TensorND<TensorStorage>::ChainReturnType ret \
            TensorND<TensorStorage>::name

//! resize to the given shape with contiguous strides, growing storage lazily
DEF(resize, &)(const TensorShape& shape) {
    mgb_assert(m_layout.dtype.valid());
    m_layout.init_contiguous_stride(shape);
    m_storage.ensure_size(m_layout.span().dist_byte());
    return static_cast<ChainReturnType&>(*this);
}

DEF(reset, &)(TensorStorage storage, const TensorLayout& layout) {
    //! The storage to be reset must either satisfy the layout or be empty.
    //! Empty storage is used after weight preprocess for saving memory and
    //! checking layout when running
    mgb_assert(!layout.ndim || storage.valid_span(layout.span()) || storage.empty());
    m_storage = std::move(storage);
    m_layout = layout;
    return static_cast<ChainReturnType&>(*this);
}

509 510 511 512 513 514 515 516 517 518 519
DEF(only_reset_raw_storage, &)(TensorStorage storage) {
    //! The storage to be reset is either satisfy the layout or empty.
    //! Empty storage is used after weight preprocess for saving memory and
    //! checking layout when running
    mgb_assert(storage.valid_span(m_layout.span()) || storage.empty());
    m_storage.only_reset_raw_storage(
            storage.comp_node(), storage.size(), storage.raw_storage(),
            storage.offset());
    return static_cast<ChainReturnType&>(*this);
}

520 521 522 523 524 525 526 527 528
DEF(comp_node, &)(CompNode comp_node, bool allow_mem_node_change) {
    auto orig_cn = m_storage.comp_node_allow_invalid();
    m_storage.comp_node(comp_node, allow_mem_node_change);
    if (orig_cn.valid() && orig_cn.mem_node() != comp_node.mem_node()) {
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

DEF(storage, &)(const TensorStorage& storage) {
    // keep the current layout only when the underlying buffer is unchanged
    if (m_storage.empty() || storage.empty() || m_storage.ptr() != storage.ptr()) {
        m_storage = storage;
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

//! change dtype; a real change invalidates the layout (ndim = 0)
DEF(dtype, &)(DType dtype) {
    if (m_layout.dtype == dtype)
        return static_cast<ChainReturnType&>(*this);
    m_layout.modify_dtype_inplace(dtype);
    m_layout.ndim = 0;
    return static_cast<ChainReturnType&>(*this);
}

//! change format; a real change invalidates the layout (ndim = 0)
DEF(format, &)(TensorFormat format) {
    if (m_layout.format == format)
        return static_cast<ChainReturnType&>(*this);
    m_layout.format = format;
    m_layout.ndim = 0;
    return static_cast<ChainReturnType&>(*this);
}

//! index with a list of slices, one per leading axis, composing the
//! resulting sub-tensor specs
DEF(operator[], )(std::initializer_list<Slice> slice) const {
    auto subspec = SubTensorSpec::make_from_offset_elem(m_layout, 0);
    size_t axis = 0;
    for (auto&& i : slice) {
        subspec.merge_with(i.apply(subspec.layout(), axis));
        axis++;
    }
    return sub(subspec);
}

//! make a sub-tensor sharing storage with this tensor, as described by \p spec
DEF(sub, )(const SubTensorSpec& spec) const {
    mgb_assert(
            spec.layout().dtype == dtype() && spec.layout().format == format(),
            "invalid subtensor spec: sub_layout=%s self=%s",
            spec.layout().to_string().c_str(), m_layout.to_string().c_str());
    ChainReturnType rst;
    rst.reset(m_storage.sub(spec.offset_byte()), spec.layout());
    return rst;
}

#undef DEF

// def }

/* ===================== TensorND::copy_from ===================== */
578 579 580 581 582 583 584 585
namespace {
/**
 * \brief determine whether to check overlap of two tensors.
 * \return true : when HostStorage || (DeviceStorage && SUPPORT_UNIFIED_ADDRESS)
 * \note when both support unified address, we can treat them both on CPU. So,
 * overlap check should be done
 */
template <typename TensorStorage, typename RStorage>
M
Megvii Engine Team 已提交
586 587
inline bool should_check_overlap(
        const TensorND<TensorStorage>& dst, const TensorND<RStorage>& src) {
588 589 590 591 592 593
    return true;
}

template <>
inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>(
        const HostTensorND& dst, const DeviceTensorND& src) {
M
Megvii Engine Team 已提交
594
    return src.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
595 596 597 598 599
}

template <>
inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>(
        const DeviceTensorND& dst, const HostTensorND& src) {
M
Megvii Engine Team 已提交
600
    return dst.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
601 602 603 604 605 606 607 608 609 610 611
}

/**
 * \brief D2D tensor copy should check overlap when
 * 1. They are on the same mem node. But note that the address must be logical
 * comparable. i.e. the original address alloc on enflame is uncomparable.
 * 2. They both support unified address, so can be treated as CPU address.
 */
template <>
inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>(
        const DeviceTensorND& dst, const DeviceTensorND& src) {
M
Megvii Engine Team 已提交
612 613 614 615
    bool is_same_memnode = dst.comp_node().mem_node() == src.comp_node().mem_node();
    bool unified_address =
            src.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
            dst.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
616 617 618 619 620 621
    return is_same_memnode || unified_address;
}

/**
 * \brief check overlap of two tensors. throw exception when overlapped
 */
M
Megvii Engine Team 已提交
622 623 624 625 626 627
inline void check_overlapped(
        const dt_byte* dst_min, const dt_byte* dst_max, const dt_byte* src_min,
        const dt_byte* src_max) {
    mgb_throw_if(
            src_min < dst_max && dst_min < src_max, TensorCopyOverlapError,
            "cound not perform copy between overlapped tensors");
628 629
}
}  // namespace
630

//! copy from another tensor, allocating a new contiguous layout with src's
//! shape; falls back to copy_from_fixlayout for non-contiguous sources
template <class TensorStorage>
template <class RStorage>
typename TensorND<TensorStorage>::ChainReturnType& TensorND<TensorStorage>::copy_from(
        const TensorND<RStorage>& src) {
    // adopt src's comp node if this tensor has none yet
    if (!m_storage.comp_node_valid())
        m_storage.comp_node(src.comp_node());

    if (m_layout.dtype.valid())
        m_layout.dtype.assert_is(src.dtype());
    else
        m_layout.dtype = src.dtype();

    // take src's shape with freshly-computed contiguous strides
    m_layout = TensorLayout(src.shape(), m_layout.dtype);
    size_t size_bytes = m_layout.span().dist_byte();
    m_storage.ensure_size(size_bytes);
    if (!size_bytes) {
        return static_cast<ChainReturnType&>(*this);
    }
    // requirement:
    // default case, physical contiguous
    // lowbit aligned, logical contiguous
    if (src.layout().is_physical_contiguous() ||
        (src.layout().format.is_lowbit_aligned() && src.layout().is_contiguous())) {
        if (should_check_overlap(*this, src)) {
            check_overlapped(
                    m_storage.ptr(), m_storage.ptr() + size_bytes, src.storage().ptr(),
                    src.storage().ptr() + size_bytes);
        }
        // both sides contiguous: raw storage copy is enough
        m_storage.copy_from(src.storage(), size_bytes);
        return static_cast<ChainReturnType&>(*this);
    }
    return const_cast<ChainReturnType&>(copy_from_fixlayout(src));
}

//! copy from another tensor of the same shape without changing this tensor's
//! layout; dispatches to raw storage copy when both sides are contiguous,
//! otherwise to the noncont_tensor_copy helpers
template <class TensorStorage>
template <class RStorage>
const typename TensorND<TensorStorage>::ChainReturnType& TensorND<
        TensorStorage>::copy_from_fixlayout(const TensorND<RStorage>& src) const {
    dtype().assert_is(src.dtype());
    mgb_assert(
            m_layout.eq_shape(src.layout()),
            "shape differs in copy_from_fixlayout: %s vs %s",
            static_cast<const TensorShape&>(m_layout).to_string().c_str(),
            static_cast<const TensorShape&>(src.layout()).to_string().c_str());

    if (src.empty()) {
        return static_cast<const ChainReturnType&>(*this);
    }

    mgb_assert(
            m_layout.is_non_overlapping_strong(),
            "copy dest must have non-overlapping layout");

    TensorLayout::Span src_span = src.layout().span(), dst_span = layout().span();

    if (should_check_overlap(*this, src)) {
        check_overlapped(
                this->raw_ptr() + dst_span.low_byte,
                this->raw_ptr() + dst_span.high_byte, src.raw_ptr() + src_span.low_byte,
                src.raw_ptr() + src_span.high_byte);
    }

    // "contiguous" here means physically contiguous, or logically contiguous
    // for lowbit-aligned formats
    bool self_contig =
                 m_layout.is_physical_contiguous() ||
                 (m_layout.format.is_lowbit_aligned() && m_layout.is_contiguous()),
         src_contig = src.layout().is_physical_contiguous() ||
                      (src.layout().format.is_lowbit_aligned() &&
                       src.layout().is_contiguous());
    if (self_contig && src_contig) {
        if ((m_layout.format.is_default() && src.layout().format.is_default()) ||
            (m_layout.format.is_lowbit_aligned() &&
             src.layout().format.is_lowbit_aligned())) {
            // matching formats: spans must agree exactly
            mgb_assert(
                    src_span.low_byte == 0 && dst_span.low_byte == 0 &&
                    src_span.high_byte == dst_span.high_byte);
            m_storage.copy_from(src.storage(), src_span.high_byte);
        } else {
            // mixed formats: copy the smaller of the two spans
            mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0);
            m_storage.copy_from(
                    src.storage(), std::min(src_span.high_byte, dst_span.high_byte));
        }
        return static_cast<const ChainReturnType&>(*this);
    }
    noncont_tensor_copy(*this, src, self_contig, src_contig);
    return static_cast<const ChainReturnType&>(*this);
}

/* =================== misc =================== */

//! fill every byte of \p tensor's storage span with \p val, dispatching to
//! the platform-specific memset of the tensor's comp node
void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
    auto&& env = CompNodeEnv::from_comp_node(tensor.comp_node());
    env.activate();
    size_t size = tensor.layout().span().dist_byte();
    switch (env.property().type) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA:
            // async memset on the comp node's stream
            MGB_CUDA_CHECK(cudaMemsetAsync(
                    tensor.raw_ptr(), val, size, env.cuda_env().stream));
            break;
#endif
#if MGB_ATLAS
        case CompNode::DeviceType::ATLAS:
            // NOTE(review): -1 is presumably passed as the max-count argument
            // of aclrtMemset(Async) to mean "unbounded" — confirm against the
            // ACL runtime API
#if MGB_USE_ATLAS_ASYNC_API
            MGB_ATLAS_CHECK(aclrtMemsetAsync(
                    tensor.raw_ptr(), -1, val, size, env.atlas_env().stream));
#else
            MGB_ATLAS_CHECK(aclrtMemset(tensor.raw_ptr(), -1, val, size));
#endif
            break;
#endif
#if MGB_CAMBRICON
        case CompNode::DeviceType::CAMBRICON:
            // cnrtMemset is synchronous; drain the queue first
            MGB_CNRT_CHECK(cnrtSyncQueue(env.cnrt_env().queue));
            MGB_CNRT_CHECK(cnrtMemset(tensor.raw_ptr(), val, size));
            break;
#endif
        case CompNode::DeviceType::CPU: {
            // dispatch the fill onto the CPU comp node's worker queue
            auto fill = [tensor, size, val]() {
                std::memset(tensor.as_megdnn().raw_ptr(), val, size);
            };
            env.cpu_env().dispatch(fill);
        } break;
        default:
            mgb_throw(
                    MegBrainError, "unhandled comp node in dev_tensor_memset: %s",
                    tensor.comp_node().to_string().c_str());
    }
}

namespace mgb {
// explicit instantiations of the storage and tensor templates for the two
// storage traits
template class TensorStorage<HostTensorStorageTrait>;
template class TensorStorage<DeviceTensorStorageTrait>;
template class TensorND<TensorStorage<HostTensorStorageTrait>>;
template class TensorND<TensorStorage<DeviceTensorStorageTrait>>;

/* ===== copy_from related ===== */

// helper macros to instantiate copy_from/copy_from_fixlayout for all four
// (host|device) x (host|device) combinations
#define HT_RAW TensorND<HostTensorStorage>
#define DT_RAW TensorND<DeviceTensorStorage>
#define HT(f)  f<HostTensorStorage>(const HT_RAW&)
#define DT(f)  f<DeviceTensorStorage>(const DT_RAW&)

#define INST(f, c)                              \
    template c HostTensorND& HT_RAW::HT(f) c;   \
    template c HostTensorND& HT_RAW::DT(f) c;   \
    template c DeviceTensorND& DT_RAW::HT(f) c; \
    template c DeviceTensorND& DT_RAW::DT(f) c

INST(copy_from, );
INST(copy_from_fixlayout, const);

#undef INST
#undef DT
#undef HT
#undef DT_RAW
#undef HT_RAW

}  // namespace mgb
789 790

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}