提交 4f77509e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/opr): allow empty ImmutableTensor

Fixes MGE-675.

GitOrigin-RevId: c6771740fc48226f1b7c79d519de61445e671290
上级 6a7e7ce1
......@@ -761,7 +761,8 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
if (is_const_var(m_const_var_type, opr)) {
auto sz = var_mem_size(opr->output(0));
mgb_assert(sz);
mgb_assert(sz || opr->output(0)->contain_flag(
VarNode::Flag::ALLOW_EMPTY_SHAPE));
info.is_const = true;
info.max_size = sz;
return make_ret();
......
......@@ -382,7 +382,7 @@ class ImmutableTensor::Value {
void setup(CompNode cn, const HostTensorND &val);
bool initialized() const {
return !m_dev.empty();
return m_dev.shape_valid();
}
//! value on comp node
......@@ -400,8 +400,9 @@ class ImmutableTensor::Value {
};
void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) {
mgb_assert(m_dev.empty() && !val.empty());
mgb_assert(m_dev.empty() && !m_dev.shape_valid());
m_dev.comp_node(cn).copy_from(val).sync();
mgb_assert(val.empty() == m_dev.empty());
auto one_elem = [](const TensorShape& shape) {
for (size_t i = 0; i < shape.ndim; ++i) {
......@@ -446,6 +447,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
HostTensorND m_val_ref;
const dt_byte* val_ptr() const {
mgb_assert(m_trait.size_bytes);
return m_val.empty() ? m_val_ref.raw_ptr() : m_val.data();
}
......@@ -454,9 +456,8 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
TensorKey(const HostTensorND &v):
m_val_ref{v}
{
mgb_assert(v.layout().is_contiguous());
mgb_assert(v.layout().is_contiguous() || v.layout().is_empty());
m_trait.size_bytes = v.layout().span().high_byte;
mgb_assert(m_trait.size_bytes);
auto &&layout = m_trait.layout;
// zero to enable byte-comparison
......@@ -467,15 +468,19 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
layout.shape[i] = v.layout().shape[i];
layout.stride[i] = v.layout().stride[i];
}
m_trait.hash = XXHash{}.
update(v.raw_ptr(), m_trait.size_bytes).
update(&m_trait.layout, sizeof(m_trait.layout)).
digest();
XXHash hasher;
if (!v.empty()) {
hasher.update(v.raw_ptr(), m_trait.size_bytes);
}
hasher.update(&m_trait.layout, sizeof(m_trait.layout));
m_trait.hash = hasher.digest();
}
bool operator == (const TensorKey &rhs) const {
return !memcmp(&m_trait, &rhs.m_trait, sizeof(Trait)) &&
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes);
((m_trait.size_bytes == 0 &&
rhs.m_trait.size_bytes == 0) ||
!memcmp(val_ptr(), rhs.val_ptr(), m_trait.size_bytes));
}
size_t hash() const {
......@@ -485,6 +490,7 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
//! copy from m_val_ref to m_val, to avoid refed value being
//! modified
void copy_val_permanent() {
if (m_trait.size_bytes == 0) return;
mgb_assert(m_val.empty());
m_val.resize(m_trait.size_bytes);
memcpy(m_val.data(), m_val_ref.raw_ptr(), m_trait.size_bytes);
......@@ -544,7 +550,6 @@ class ImmutableTensor::DevValueCache final: public UserDataContainer::UserData {
}
const Value& get(const HostTensorND &tensor) {
mgb_assert(!tensor.empty());
if (tensor.shape().is_scalar()) {
return get(DTypeScalar::make_from_raw(
tensor.dtype(), tensor.raw_ptr()));
......@@ -595,6 +600,7 @@ ImmutableTensor::ImmutableTensor(ComputingGraph &graph,
add_output(value.dev().dtype());
add_equivalence_component<ScalarHash<const void*>>(&value);
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
ImmutableTensor::~ImmutableTensor() noexcept = default;
......
......@@ -177,6 +177,17 @@ TEST(TestOprIO, ImmutableTensorLarge) {
}
}
TEST(TestOprIO, ImmutableTensorEmpty) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen({1, 9, 1, 9, 8, 1, 0});
auto x = opr::ImmutableTensor::make(*graph, *host_x);
HostTensorND host_x2;
auto func = graph->compile({make_callback_copy(x, host_x2)});
func->execute();
ASSERT_TRUE(host_x2.shape().is_empty());
}
TEST(TestOprIO, SharedDeviceTensor) {
HostTensorGenerator<> gen;
auto hv = gen({123});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册