提交 21f5a7fc 编写于 作者: M Megvii Engine Team

fix(subgraph): fix device recognition and scalar propagate

GitOrigin-RevId: fd2fe8bec9d9730e8689dbad314e54a7ecbc8bde
上级 27346b0b
......@@ -230,7 +230,7 @@ for name, mode in [
def subgraph(
name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
):
if device.physical_name.startswith("cpu"):
if not device.physical_name.startswith("gpu"):
gopt_level = None # disable jit and compile
jit_fusion = False
......@@ -370,7 +370,15 @@ def subgraph_fn(
jit_fusion=jit_fusion,
custom_grad=custom_grad,
)(func)
return lambda *args: apply(op(), *args)
def wrapped_func(*args):
if custom_grad:
outputs = op()(*args)
else:
outputs = apply(op(), *args)
return outputs
return wrapped_func
else:
return interpret_subgraph(func, dtype, device)
......
......@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None):
device=device,
nr_inputs=1,
jit_fusion=True,
# gopt_level=0,
custom_grad=True,
)
def softplus(inputs, f, c):
......
......@@ -18,14 +18,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
get_device,
isscalar,
setscalar,
subgraph_fn,
)
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
......@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
where = _get_where_op(dtype=dtype, device=device)
(oup,) = where(mask, x, y)
if isscalar(mask):
setscalar(oup)
return oup
......
......@@ -67,7 +67,7 @@ void init_common(py::module m) {
[](const CompNode& cn) { return cn.to_string_logical(); })
.def_property_readonly(
"physical_name",
[](const CompNode& cn) { return cn.to_string(); })
[](const CompNode& cn) { return cn.to_string_physical(); })
.def_property_readonly(
"get_mem_status_bytes",
[](const CompNode& cn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册