提交 432fdb7e 编写于 作者: M Megvii Engine Team

feat(mgb/opr): let SetSubtensor support empty IO

GitOrigin-RevId: 13909e3b11f54a1f3ea0acc7af259a4804c2d77b
上级 e954b8f9
......@@ -608,3 +608,50 @@ def test_subtensor_on_empty_tensor(symbolic):
run_test(lambda x: x[3, 10:1:1, 5:-1])
run_test(lambda x: x[:100, :100, :100])
run_test(lambda x: x[100:200, 300:400, 500:600])
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_setsubtensor_on_empty_tensor(symbolic):
def run_test(inp_shp, fn):
np_x = np.random.randn(*inp_shp).astype(np.float32)
mge_x = megengine.tensor(np_x)
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
def test1(x):
x[1:100:2, :, :] = x[1:100:2, :, :]
return x
def test2(x):
x[-10:5:2, :, :] = x[-10:5:2, :, :]
return x
def test3(x):
x[5:1:-1, :, :] = x[5:1:-1, :, :]
return x
def test4(x):
x[3, 10:1:1, 5:-1] = x[3, 10:1:1, 5:-1]
return x
def test5(x):
x[:100, :100, :100] = x[:100, :100, :100]
return x
def test6(x):
x[100:200, 300:400, 500:600] = x[100:200, 300:400, 500:600]
return x
run_test((10, 0, 10), test1)
run_test((10, 0, 10), test2)
run_test((10, 0, 10), test3)
run_test((10, 0, 10), test4)
run_test((10, 0, 10), test5)
run_test((10, 0, 10), test6)
run_test((10, 10, 10), test4)
run_test((10, 10, 10), test5)
run_test((10, 10, 10), test6)
......@@ -134,6 +134,11 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
}
TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if (value.empty()) {
auto layout = value.layout();
layout.init_contiguous_stride();
const_cast<HostTensorND&>(value).reset(value.storage(), layout);
}
auto info = alloc();
init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
......
......@@ -819,10 +819,36 @@ void ModifySubtensorImplHelper::init_output_static_infer_desc() {
/* f{{{ ======================= SetSubtensor ======================= */
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor, "set_subtensor", true);
SetSubtensor::SetSubtensor(VarNode *inp, VarNode *value, const IndexDesc &desc,
const OperatorNodeConfig &config,
const InputTensorReplacer &input_tensor_replacer):
Super({inp->owner_graph(), config, "set_subtensor", {inp, value}},
inp, value, desc, true, input_tensor_replacer) {
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVar SetSubtensor::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc,
const OperatorNodeConfig &config,
const InputTensorReplacer &input_tensor_replacer) {
return inp.insert_single_output_opr<SetSubtensor>(
inp.node(), value.node(), desc, config, input_tensor_replacer);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SetSubtensor);
void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
sub.copy_from_fixlayout(val);
if (!val.layout().is_empty()) {
sub.copy_from_fixlayout(val);
}
}
SetSubtensor::NodeProp* SetSubtensor::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
ret->add_dep_type_existing_var(input(1),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
#if MGB_ENABLE_GRAD
......
......@@ -374,6 +374,7 @@ MGB_DEFINE_OPR_CLASS(Subtensor,
MGB_DEFINE_OPR_CLASS(SetSubtensor, intl::ModifySubtensorImplHelper) // {
void modify(DeviceTensorND &sub, const DeviceTensorND &val) override;
NodeProp* do_make_node_prop() const override;
public:
MGB_DECL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor);
......
......@@ -935,6 +935,48 @@ TEST(TestTensorManip, SubtensorEmptyIO) {
});
}
TEST(TestTensorManip, SetSubtensorEmptyIO) {
using AIdx = opr::SetSubtensor::AxisIndexer;
using IndexDesc = std::vector<AIdx>;
using IndexDescCreater = thin_function<IndexDesc(SymbolVar)>;
HostTensorGenerator<> gen;
auto run = [&](const TensorShape& inp_shp, const TensorShape& val_shp, const IndexDescCreater& c) {
auto host_x = gen(inp_shp),
host_v = gen(val_shp);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
v = opr::Host2DeviceCopy::make(*graph, host_v);
auto y = opr::SetSubtensor::make(x, v, c(x));
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(host_y.shape(), inp_shp);
};
// x.shape = {0}, v.shape = {0}, x[:0] = v
run({0}, {0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, None, x.make_scalar(0), None)};
});
// x.shape = {100, 0}, v.shape = {45, 0}, x[0:-10:2] = v
run({100, 0}, {45, 0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(0), x.make_scalar(-10), x.make_scalar(2))};
});
// x.shape = {100, 0}, v.shape = {40, 0}, x[10:-10:2, 0:0] = v
run({100, 0}, {40, 0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2)),
AIdx::make_interval(1, x.make_scalar(0), x.make_scalar(0), None)};
});
// x.shape = {10, 0, 10}, v.shape = {0, 10}, x[5, 10:-10:-2] = v
run({10, 0, 10}, {0, 10}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_index(0, x.make_scalar(5)),
AIdx::make_interval(1, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2))};
});
// x.shape = {10}, v.shape = {0}, x[100:] = v
run({10}, {0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(100), None, None)};
});
}
namespace {
void test_subtensor_fwdonly(bool dyn_inp, bool dyn_idx) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册