提交 5e7d2a91 编写于 作者: M Megvii Engine Team

refactor(mgb): add TensorND::proxy_to_default_cpu

GitOrigin-RevId: 3ab8525f1c0e04d128632982a7f3ffdb78971a23
上级 7a8a2830
......@@ -101,6 +101,14 @@ class Slice {
SubTensorSpec apply(TensorLayout layout, int axis) const;
};
//! forward declarations of the storage template, its host/device trait tags
//! and the two storage aliases, placed ahead of the TensorStorage definition
//! so that its members can name HostTensorStorageTrait and DeviceTensorStorage
template <class Trait> class TensorStorage;
class DeviceTensorStorageTrait;
class HostTensorStorageTrait;
using HostTensorStorage = TensorStorage<HostTensorStorageTrait>;
using DeviceTensorStorage = TensorStorage<DeviceTensorStorageTrait>;
/*!
* \brief manager for raw tensor memory
*
......@@ -230,6 +238,18 @@ class TensorStorage {
std::enable_if<!std::is_same<Trait, RTrait>::value>::type>
static TensorStorage make_proxy(const TensorStorage<RTrait> &src);
/*!
 * \brief make a DeviceTensorStorage on default_cpu
 *      that shares memory with this
 *
 * this must be a HostTensorStorage. Alignment not checked.
 *
 * The defaulted template parameter \p x only exists to make the enable_if
 * condition dependent on a template argument of the method itself; the
 * method is usable only when Trait is HostTensorStorageTrait.
 *
 * \return a DeviceTensorStorage built from CompNode::default_cpu() together
 *      with this storage's m_size, m_capacity, m_offset and m_data
 */
template<bool x = true, typename = std::enable_if_t<x && std::is_same<Trait, HostTensorStorageTrait>::value>>
DeviceTensorStorage proxy_to_default_cpu() const {
// NOTE(review): the result of ptr() is discarded; presumably it is called
// to force any pending lazy allocation so that m_data is valid before it
// is handed to the proxy -- confirm against ptr()'s implementation
ptr();
// NOTE(review): the meaning of the leading `true` is defined by a
// constructor that is not visible in this hunk -- confirm before changing
return {true, CompNode::default_cpu(), m_size, m_capacity, m_offset, m_data};
}
//! shortcut for raw_storage().use_count(), but won't trigger lazy alloc
size_t use_count() const {
if (m_size > m_capacity) {
......@@ -284,11 +304,12 @@ class TensorStorage {
[[noreturn]] static void on_invalid_comp_node();
};
//! trait tags and the host/device storage aliases instantiated from them
class DeviceTensorStorageTrait;
class HostTensorStorageTrait;
using HostTensorStorage = TensorStorage<HostTensorStorageTrait>;
using DeviceTensorStorage = TensorStorage<DeviceTensorStorageTrait>;
//! forward declaration of TensorND and its host/device aliases, so that
//! members of TensorND can name DeviceTensorND inside the class body
template<class TensorStorage> class TensorND;
using HostTensorND = TensorND<HostTensorStorage>;
using DeviceTensorND = TensorND<DeviceTensorStorage>;
/*!
* \brief n-dimensional tensor
......@@ -519,10 +540,15 @@ class TensorND {
ret.reset(TensorStorage::make_proxy(src.storage()), src.layout());
return ret;
}
};
using HostTensorND = TensorND<HostTensorStorage>;
using DeviceTensorND = TensorND<DeviceTensorStorage>;
/*!
 * \brief tensor-level counterpart of HostTensorStorage::proxy_to_default_cpu
 *
 * Only usable when this tensor is backed by HostTensorStorage. The result is
 * reset with the storage returned by storage().proxy_to_default_cpu() and
 * with this tensor's layout().
 *
 * \return a DeviceTensorND wrapping the proxied storage and the same layout
 */
template <
        bool x = true,
        typename = std::enable_if_t<
                x && std::is_same<TensorStorage, HostTensorStorage>::value>>
DeviceTensorND proxy_to_default_cpu() const {
    DeviceTensorND proxied;
    proxied.reset(storage().proxy_to_default_cpu(), layout());
    return proxied;
}
};
/*!
* \brief call memset in the data of a device tensor
......
......@@ -418,4 +418,13 @@ TEST(TestTensor, CpuCudaD2DCopy) {
}
}
//! a default_cpu proxy of a host tensor must report the default_cpu comp
//! node while keeping the source tensor's layout and raw data pointer
TEST(TestTensor, ProxyToDefaultCPU) {
    CompNode comp_node = CompNode::load("xpux");
    HostTensorND host(comp_node, TensorLayout({1, 2, 3}, dtype::Float32{}));
    DeviceTensorND proxy = host.proxy_to_default_cpu();

    // the proxy lives on default_cpu regardless of the host comp node
    ASSERT_EQ(proxy.comp_node(), CompNode::default_cpu());
    // layout is carried over unchanged
    ASSERT_EQ(host.layout(), proxy.layout());
    // both tensors refer to the same underlying buffer
    ASSERT_EQ(host.raw_ptr(), proxy.raw_ptr());
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册