提交 cf53d9e0 编写于 作者: M Megvii Engine Team

fix(mgb/tensor): do tensor overlap check only when d2d and h2h

GitOrigin-RevId: 9125936a350baab699961b26decd314a39921545
上级 7e2b2dbf
......@@ -18,7 +18,9 @@
namespace mgb {
class AtlasCompNode final : public CompNodeImplHelper {
public:
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;
class CompNodeImpl;
class EventImpl;
......
......@@ -16,8 +16,9 @@
namespace mgb {
class CambriconCompNode final: public CompNodeImplHelper {
public:
static constexpr Flag sm_flag =
Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;
class CompNodeImpl;
class EventImpl;
......
......@@ -38,7 +38,8 @@ namespace mgb {
static constexpr Flag sm_flag =
Flag::SUPPORT_RECORDER |
Flag::RECORDER_SUPPORT_DYNAMIC_ALLOC |
Flag::EVENT_DTOR_UNSAFE;
Flag::EVENT_DTOR_UNSAFE |
Flag::SUPPORT_UNIFIED_ADDRESS;
//! base class for comp nodes that can be dispatched on CPU.
//! This is currently used by CPU, FPGA and CADENCE
......
......@@ -16,8 +16,9 @@
namespace mgb {
class CudaCompNode final: public CompNodeImplHelper {
public:
static constexpr Flag sm_flag =
Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;
class CompNodeImpl;
class EventImpl;
......
......@@ -16,7 +16,9 @@
namespace mgb {
class ROCmCompNode final : public CompNodeImplHelper {
public:
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
Flag::HAS_COPY_STREAM |
Flag::SUPPORT_UNIFIED_ADDRESS;
class CompNodeImpl;
class EventImpl;
......
......@@ -518,6 +518,60 @@ DEF(sub, )(const SubTensorSpec &spec) const {
// def }
/* ===================== TensorND::copy_from ===================== */
namespace {
/**
* \brief determine whether to check overlap of two tensors.
* \return true : when HostStorage || (DeviceStorage && SUPPORT_UNIFIED_ADDRESS)
* \note when both support unified address, we can treat them both on CPU. So,
* overlap check should be done
*/
template <typename TensorStorage, typename RStorage>
inline bool should_check_overlap(const TensorND<TensorStorage>& dst,
const TensorND<RStorage>& src) {
return true;
}
template <>
inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>(
const HostTensorND& dst, const DeviceTensorND& src) {
return src.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}
template <>
inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>(
const DeviceTensorND& dst, const HostTensorND& src) {
return dst.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}
/**
* \brief D2D tensor copy should check overlap when
* 1. They are on the same mem node. But note that the address must be logical
* comparable. i.e. the original address alloc on enflame is uncomparable.
* 2. They both support unified address, so can be treated as CPU address.
*/
template <>
inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>(
const DeviceTensorND& dst, const DeviceTensorND& src) {
bool is_same_memnode =
dst.comp_node().mem_node() == src.comp_node().mem_node();
bool unified_address = src.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
dst.comp_node().contain_flag(
CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
return is_same_memnode || unified_address;
}
/**
* \brief check overlap of two tensors. throw exception when overlapped
*/
inline void check_overlapped(const dt_byte* dst_min, const dt_byte* dst_max,
const dt_byte* src_min, const dt_byte* src_max) {
mgb_throw_if(src_min < dst_max && dst_min < src_max, TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
}
} // namespace
template<class TensorStorage>
template<class RStorage>
......@@ -539,12 +593,12 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) {
return static_cast<ChainReturnType&>(*this);
}
if (src.layout().is_physical_contiguous()) {
const dt_byte
*dst_min = m_storage.ptr(), *dst_max = dst_min + size_bytes,
*src_min = src.storage().ptr(), *src_max = src_min + size_bytes;
mgb_throw_if(src_max > dst_min && dst_max > src_min,
TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
if (should_check_overlap(*this, src)) {
check_overlapped(m_storage.ptr(),
m_storage.ptr() + size_bytes,
src.storage().ptr(),
src.storage().ptr() + size_bytes);
}
m_storage.copy_from(src.storage(), size_bytes);
return static_cast<ChainReturnType&>(*this);
}
......@@ -574,15 +628,12 @@ TensorND<TensorStorage>::copy_from_fixlayout(
src_span = src.layout().span(),
dst_span = layout().span();
const dt_byte
*src_ptr_min = src.raw_ptr() + src_span.low_byte,
*src_ptr_max = src.raw_ptr() + src_span.high_byte,
*dst_ptr_min = this->raw_ptr() + dst_span.low_byte,
*dst_ptr_max = this->raw_ptr() + dst_span.high_byte;
mgb_throw_if(src_ptr_max > dst_ptr_min && dst_ptr_max > src_ptr_min,
TensorCopyOverlapError,
"cound not perform copy between overlapped tensors");
if (should_check_overlap(*this, src)) {
check_overlapped(this->raw_ptr() + dst_span.low_byte,
this->raw_ptr() + dst_span.high_byte,
src.raw_ptr() + src_span.low_byte,
src.raw_ptr() + src_span.high_byte);
}
bool self_contig = m_layout.is_physical_contiguous(),
src_contig = src.layout().is_physical_contiguous();
......
......@@ -436,6 +436,10 @@ class CompNode {
//! MGB_HAVE_THREAD=0. Usually this means that execution on the
//! CompNode is synchronous, i.e. behaves like cpu:default
SUPPORT_NO_THREAD = 1 << 5,
//! Whether this comp node supports unified address. i.e. CPU and
//! CUDA supports unified address.
SUPPORT_UNIFIED_ADDRESS = 1 << 6,
};
bool contain_flag(Flag flag) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册