提交 5474b000 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix convert_inputs before apply

GitOrigin-RevId: ab41974a1f3f60f3261a5f34d76ff858c8ccd07b
上级 a5fad7d0
......@@ -146,6 +146,7 @@ def conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -209,6 +210,7 @@ def conv_transpose2d(
dilate_w=dilate_w,
strategy=get_conv_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
......@@ -243,6 +245,7 @@ def local_conv2d(
dilate_w=dilate_w,
# strategy=get_conv_execution_strategy(),
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -900,6 +903,7 @@ def warp_perspective(
op = builtin.WarpPerspective(
imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val
)
inp, M = utils.convert_inputs(inp, M)
(result,) = apply(op, inp, M, Tensor(dsize))
return result
......@@ -1004,6 +1008,7 @@ def matmul(
format=format,
)
inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2)
if shp is not None:
result = result.reshape(shp)
......@@ -1327,6 +1332,7 @@ def roi_pooling(
output_shape = (output_shape, output_shape)
op = builtin.ROIPooling(mode=mode, scale=scale)
inp, rois = utils.convert_inputs(inp, rois)
result, _ = apply(
op, inp, rois, Tensor(output_shape, dtype="int32", device=inp.device)
)
......@@ -1374,6 +1380,7 @@ def roi_align(
sample_height=sample_height,
sample_width=sample_width,
)
input, rois = utils.convert_inputs(input, rois)
result, *_ = apply(op, input, rois)
return result
......
......@@ -104,6 +104,10 @@ struct Dispatcher {
auto& frame = stack.back();
auto& mro = *frame.mro;
auto& i = frame.mro_offset;
if (!mro.size()) {
PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher");
return nullptr;
}
for (; i < mro.size(); ++i) {
if (mro[i]->enabled) {
auto ret = caller(mro[i]->func);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册