diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 45fda4dbfe218f74da72236ce5e378a681c424cc..960512a48d200bbc1e218a5282ab0f1873039119 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -149,9 +149,13 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { const_cast(value).reset(value.storage(), layout); } auto info = alloc(); - init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()}); + constexpr int size_threshold = TensorShape::MAX_NDIM; + init(info, {value.layout(), value.comp_node()}); + if (value.layout().total_nr_elems() <= size_threshold) { + info->h_value = value; + info->desc.value = value.proxy_to_default_cpu(); + } info->mem_desc.id = StorageIdentifier::make(++m_storage_id); - info->h_value = value; m_buffer.enqueue(Put{info, value, no_cache}); if (m_async_level == 0) { sync_impl(); diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index 12df60bf472cb405ffe3db1ee13d278372d89326..84035f85817029b5632ea78670e8a4c75268457b 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -130,7 +130,10 @@ Tensor::Tensor( : m_layout(layout), m_blob(std::move(blob)), m_offset(offset), m_value(hv) {} Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { - m_value = hv; + constexpr int size_threshold = TensorShape::MAX_NDIM; + if (hv.layout().total_nr_elems() <= size_threshold) { + m_value = hv; + } MGB_RECORD_EVENT( profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), dev_tensor().raw_ptr());