提交 0a191f8d 编写于 作者: M Megvii Engine Team

fix(misc): some fixes to make faster_rcnn run in no-symbolic-trace

GitOrigin-RevId: 0fc40ad9c8ad02a8d2ae3e41bcf36f73dccff71e
上级 f95bb7b7
......@@ -234,15 +234,10 @@ def setitem(tensor, index, value):
try_result = try_condtake(tensor, index)
if len(try_result) == 2:
index = try_result[1]
if index.shape[0] == 0:
return tensor
tensor = tensor.reshape(-1)
if not isinstance(value, Tensor):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
for v in tensors:
if len(v.shape) > 0 and v.shape[0] == 0:
return tensor
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
......@@ -250,19 +245,17 @@ def setitem(tensor, index, value):
(tmp_result,) = apply(op, tensor, *tensors)
# XXX: broadcast can always be applied even if shapes are equal
if make_shape_tuple(value.shape) != make_shape_tuple(tmp_result.shape):
for i in range(min(len(value.shape), len(tmp_result.shape))):
if (
value.shape[-i - 1] != 1
and value.shape[-i - 1] != tmp_result.shape[-i - 1]
):
raise ValueError(
"cannot copy tensor with shape {} to subtensor with shape {}".format(
value.shape, tmp_result.shape
)
for i in range(min(len(value.shape), len(tmp_result.shape))):
if (value.shape[-i - 1] != 1) & (
value.shape[-i - 1] != tmp_result.shape[-i - 1]
):
raise ValueError(
"cannot copy tensor with shape {} to subtensor with shape {}".format(
value.shape, tmp_result.shape
)
value = value._broadcast(tmp_result.shape)
)
value = value._broadcast(tmp_result.shape)
if use_subtensor:
op = builtin.SetSubtensor(items=items)
else:
......
......@@ -644,12 +644,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
v0, index0 = cond_take(mask, x)
v1, index1 = cond_take(~mask, y)
if v0.shape == (0,):
out = v1
elif v1.shape == (0,):
out = v0
else:
out = concat([v0, v1])
out = concat([v0, v1])
out[index0] = v0
out[index1] = v1
......
......@@ -85,7 +85,8 @@ public:
var->m_comp_node = dev_tensor.comp_node();
var->m_shape = dev_tensor.shape();
var->m_dev_tensor = dev_tensor;
var->reset_dev_tensor_from_tensor(dev_tensor);
var->m_mem_plan.reset_from_owner_var().chunk()
.mem_alloc_status.set_from_owner_var();
return var;
}
......@@ -560,7 +561,11 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) {
mgb_assert(var->comp_node() == tensor->comp_node() &&
var->shape().eq_shape(layout) &&
var->dtype() == layout.dtype);
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
if (!tensor->layout().is_empty()) {
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
} else {
var->m_dev_tensor.storage({var->comp_node()});
}
++ j;
}
chk.mem_alloc_status.set_from_owner_var();
......
......@@ -365,6 +365,9 @@ WARN(IndexingIncrMultiAxisVec);
template <class Opr>
void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
if (output(0)->layout().is_empty()) {
return;
}
auto inp = input(0)->dev_tensor();
inp = inp.sub(fancy_indexing_make_sub_spec(inp.layout()));
auto &&index_desc = make_megdnn_index_desc(
......
......@@ -81,6 +81,11 @@ void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) {
void ReadonlyFwdHelper::mixin_rofwd_execute(OperatorNodeBase &opr) {
mgb_assert(m_rofwd_subspec.layout().ndim, "rofwd uninitialized");
if (m_rofwd_subspec.layout().is_empty()) {
mgb_assert(opr.output(0)->shape().is_empty(), "output layout mismatch");
return;
}
auto &&out = opr.output(0)->dev_tensor(),
&&inp = opr.input(0)->dev_tensor();
if (m_mem_fwd_success) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册