提交 f5b8fec4 编写于 作者: M Megvii Engine Team

fix(imperative): remove big tensor from host side

GitOrigin-RevId: 2047982d7331df4d0aa2f6e97a20bf7c56cdad1c
上级 68cde873
......@@ -149,9 +149,13 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
const_cast<HostTensorND&>(value).reset(value.storage(), layout);
}
auto info = alloc();
init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
constexpr int size_threshold = TensorShape::MAX_NDIM;
init(info, {value.layout(), value.comp_node()});
if (value.layout().total_nr_elems() <= size_threshold) {
info->h_value = value;
info->desc.value = value.proxy_to_default_cpu();
}
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->h_value = value;
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
sync_impl();
......
......@@ -130,7 +130,10 @@ Tensor::Tensor(
: m_layout(layout), m_blob(std::move(blob)), m_offset(offset), m_value(hv) {}
Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
m_value = hv;
constexpr int size_threshold = TensorShape::MAX_NDIM;
if (hv.layout().total_nr_elems() <= size_threshold) {
m_value = hv;
}
MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册