Commit cf53d9e0 authored by Megvii Engine Team

fix(mgb/tensor): do tensor overlap check only when d2d and h2h

GitOrigin-RevId: 9125936a350baab699961b26decd314a39921545
Parent 7e2b2dbf
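Put differently, the overlap check is only meaningful when the source and destination pointers live in a comparable address space, which is what the helpers added in the diff below decide. A rough standalone sketch of that decision, using illustrative names that are not part of the actual patch:

    // Illustrative only: when is an overlap check between the two sides of a
    // tensor copy meaningful?
    enum class Side { Host, Device };

    bool overlap_check_meaningful(Side dst, Side src, bool dst_unified,
                                  bool src_unified, bool same_mem_node) {
        if (dst == Side::Host && src == Side::Host)
            return true;                          // H2H: plain CPU pointers
        if (dst == Side::Device && src == Side::Device)
            return same_mem_node ||               // D2D: same allocator, or
                   (dst_unified && src_unified);  // both use unified addressing
        // H2D / D2H: comparable only if the device side uses unified addressing
        return dst == Side::Device ? dst_unified : src_unified;
    }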
@@ -18,7 +18,9 @@
 namespace mgb {
 class AtlasCompNode final : public CompNodeImplHelper {
 public:
-    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
+    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
+                                    Flag::HAS_COPY_STREAM |
+                                    Flag::SUPPORT_UNIFIED_ADDRESS;
     class CompNodeImpl;
     class EventImpl;
...
@@ -16,8 +16,9 @@
 namespace mgb {
 class CambriconCompNode final: public CompNodeImplHelper {
 public:
-    static constexpr Flag sm_flag =
-            Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
+    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
+                                    Flag::HAS_COPY_STREAM |
+                                    Flag::SUPPORT_UNIFIED_ADDRESS;
     class CompNodeImpl;
     class EventImpl;
...
@@ -38,7 +38,8 @@ namespace mgb {
     static constexpr Flag sm_flag =
            Flag::SUPPORT_RECORDER |
            Flag::RECORDER_SUPPORT_DYNAMIC_ALLOC |
-           Flag::EVENT_DTOR_UNSAFE;
+           Flag::EVENT_DTOR_UNSAFE |
+           Flag::SUPPORT_UNIFIED_ADDRESS;

     //! base class for comp nodes that can be dispatched on CPU.
     //! This is currently used by CPU, FPGA and CADENCE
...
@@ -16,8 +16,9 @@
 namespace mgb {
 class CudaCompNode final: public CompNodeImplHelper {
 public:
-    static constexpr Flag sm_flag =
-            Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
+    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
+                                    Flag::HAS_COPY_STREAM |
+                                    Flag::SUPPORT_UNIFIED_ADDRESS;
     class CompNodeImpl;
     class EventImpl;
...
@@ -16,7 +16,9 @@
 namespace mgb {
 class ROCmCompNode final : public CompNodeImplHelper {
 public:
-    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;
+    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED |
+                                    Flag::HAS_COPY_STREAM |
+                                    Flag::SUPPORT_UNIFIED_ADDRESS;
     class CompNodeImpl;
     class EventImpl;
...
@@ -518,6 +518,60 @@ DEF(sub, )(const SubTensorSpec &spec) const {
 // def }
 /* ===================== TensorND::copy_from ===================== */
+namespace {
+
+/**
+ * \brief determine whether the overlap of two tensors should be checked
+ * \return true for HostStorage, or for DeviceStorage with
+ *         SUPPORT_UNIFIED_ADDRESS
+ * \note when both comp nodes support unified address, both tensors can be
+ *       treated as if they were on CPU, so the overlap check should be done
+ */
+template <typename TensorStorage, typename RStorage>
+inline bool should_check_overlap(const TensorND<TensorStorage>& dst,
+                                 const TensorND<RStorage>& src) {
+    return true;
+}
+
+template <>
+inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>(
+        const HostTensorND& dst, const DeviceTensorND& src) {
+    return src.comp_node().contain_flag(
+            CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
+}
+
+template <>
+inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>(
+        const DeviceTensorND& dst, const HostTensorND& src) {
+    return dst.comp_node().contain_flag(
+            CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
+}
+
+/**
+ * \brief a D2D tensor copy should check overlap when
+ * 1. the tensors are on the same mem node; note that the addresses must be
+ *    logically comparable (e.g. the raw addresses allocated on Enflame are
+ *    not comparable), or
+ * 2. both comp nodes support unified address, so the addresses can be
+ *    treated as CPU addresses
+ */
+template <>
+inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>(
+        const DeviceTensorND& dst, const DeviceTensorND& src) {
+    bool is_same_memnode =
+            dst.comp_node().mem_node() == src.comp_node().mem_node();
+    bool unified_address = src.comp_node().contain_flag(
+                                   CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
+                           dst.comp_node().contain_flag(
+                                   CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
+    return is_same_memnode || unified_address;
+}
+
+/**
+ * \brief check whether two tensors overlap; throw TensorCopyOverlapError if
+ *        they do
+ */
+inline void check_overlapped(const dt_byte* dst_min, const dt_byte* dst_max,
+                             const dt_byte* src_min, const dt_byte* src_max) {
+    mgb_throw_if(src_min < dst_max && dst_min < src_max, TensorCopyOverlapError,
+                 "could not perform copy between overlapped tensors");
+}
+
+}  // namespace
 template<class TensorStorage>
 template<class RStorage>
@@ -539,12 +593,12 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) {
         return static_cast<ChainReturnType&>(*this);
     }
     if (src.layout().is_physical_contiguous()) {
-        const dt_byte
-                *dst_min = m_storage.ptr(), *dst_max = dst_min + size_bytes,
-                *src_min = src.storage().ptr(), *src_max = src_min + size_bytes;
-        mgb_throw_if(src_max > dst_min && dst_max > src_min,
-                TensorCopyOverlapError,
-                "cound not perform copy between overlapped tensors");
+        if (should_check_overlap(*this, src)) {
+            check_overlapped(m_storage.ptr(),
+                             m_storage.ptr() + size_bytes,
+                             src.storage().ptr(),
+                             src.storage().ptr() + size_bytes);
+        }
         m_storage.copy_from(src.storage(), size_bytes);
         return static_cast<ChainReturnType&>(*this);
     }
@@ -574,15 +628,12 @@ TensorND<TensorStorage>::copy_from_fixlayout(
             src_span = src.layout().span(),
             dst_span = layout().span();
-    const dt_byte
-            *src_ptr_min = src.raw_ptr() + src_span.low_byte,
-            *src_ptr_max = src.raw_ptr() + src_span.high_byte,
-            *dst_ptr_min = this->raw_ptr() + dst_span.low_byte,
-            *dst_ptr_max = this->raw_ptr() + dst_span.high_byte;
-
-    mgb_throw_if(src_ptr_max > dst_ptr_min && dst_ptr_max > src_ptr_min,
-            TensorCopyOverlapError,
-            "cound not perform copy between overlapped tensors");
+    if (should_check_overlap(*this, src)) {
+        check_overlapped(this->raw_ptr() + dst_span.low_byte,
+                         this->raw_ptr() + dst_span.high_byte,
+                         src.raw_ptr() + src_span.low_byte,
+                         src.raw_ptr() + src_span.high_byte);
+    }
     bool self_contig = m_layout.is_physical_contiguous(),
          src_contig = src.layout().is_physical_contiguous();
...
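The predicate inside check_overlapped is the usual half-open interval test: the byte ranges [dst_min, dst_max) and [src_min, src_max) intersect exactly when each range starts before the other one ends. A minimal self-contained sketch of the same test on plain pointers (no MegBrain types involved):

    #include <cassert>

    // [dst_min, dst_max) and [src_min, src_max) overlap iff each range starts
    // before the other one ends.
    static bool ranges_overlap(const unsigned char* dst_min, const unsigned char* dst_max,
                               const unsigned char* src_min, const unsigned char* src_max) {
        return src_min < dst_max && dst_min < src_max;
    }

    int main() {
        unsigned char buf[16];
        assert(ranges_overlap(buf, buf + 8, buf + 4, buf + 12));   // partial overlap
        assert(!ranges_overlap(buf, buf + 8, buf + 8, buf + 16));  // adjacent ranges do not overlap
        return 0;
    }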
@@ -436,6 +436,10 @@ class CompNode {
         //! MGB_HAVE_THREAD=0. Usually this means that execution on the
         //! CompNode is synchronous, i.e. behaves like cpu:default
         SUPPORT_NO_THREAD = 1 << 5,
+
+        //! whether this comp node supports unified address, e.g. CPU and
+        //! CUDA comp nodes both support unified address
+        SUPPORT_UNIFIED_ADDRESS = 1 << 6,
     };

     bool contain_flag(Flag flag) {
...
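For code outside this patch, the new flag is queried through the same CompNode API used by the helpers in the tensor copy code above. A minimal sketch that mirrors the rules of should_check_overlap against that public interface; the helper names and the include path are assumptions, not part of the diff:

    #include "megbrain/comp_node.h"  // assumed include path

    // Hypothetical helpers mirroring should_check_overlap's rules.
    static bool h2d_overlap_checkable(mgb::CompNode device_cn) {
        // a host pointer and a device pointer are only comparable when the
        // device comp node shares the host address space
        return device_cn.contain_flag(
                mgb::CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
    }

    static bool d2d_overlap_checkable(mgb::CompNode dst_cn, mgb::CompNode src_cn) {
        // same mem node means the same allocator, so addresses are comparable;
        // otherwise both sides must support unified addressing
        return dst_cn.mem_node() == src_cn.mem_node() ||
               (dst_cn.contain_flag(mgb::CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
                src_cn.contain_flag(mgb::CompNode::Flag::SUPPORT_UNIFIED_ADDRESS));
    }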