提交 21f5a7fc 编写于 作者: M Megvii Engine Team

fix(subgraph): fix device recognition and scalar propagation

GitOrigin-RevId: fd2fe8bec9d9730e8689dbad314e54a7ecbc8bde
上级 27346b0b
...@@ -230,7 +230,7 @@ for name, mode in [ ...@@ -230,7 +230,7 @@ for name, mode in [
def subgraph( def subgraph(
name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
): ):
if device.physical_name.startswith("cpu"): if not device.physical_name.startswith("gpu"):
gopt_level = None # disable jit and compile gopt_level = None # disable jit and compile
jit_fusion = False jit_fusion = False
...@@ -370,7 +370,15 @@ def subgraph_fn( ...@@ -370,7 +370,15 @@ def subgraph_fn(
jit_fusion=jit_fusion, jit_fusion=jit_fusion,
custom_grad=custom_grad, custom_grad=custom_grad,
)(func) )(func)
return lambda *args: apply(op(), *args)
def wrapped_func(*args):
if custom_grad:
outputs = op()(*args)
else:
outputs = apply(op(), *args)
return outputs
return wrapped_func
else: else:
return interpret_subgraph(func, dtype, device) return interpret_subgraph(func, dtype, device)
......
...@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None): ...@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None):
device=device, device=device,
nr_inputs=1, nr_inputs=1,
jit_fusion=True, jit_fusion=True,
# gopt_level=0,
custom_grad=True, custom_grad=True,
) )
def softplus(inputs, f, c): def softplus(inputs, f, c):
......
...@@ -18,14 +18,7 @@ from ..core.ops import builtin ...@@ -18,14 +18,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import ( from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
astensor1d,
convert_inputs,
get_device,
isscalar,
setscalar,
subgraph_fn,
)
from ..device import get_default_device from ..device import get_default_device
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import ceil from .elemwise import ceil
...@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
where = _get_where_op(dtype=dtype, device=device) where = _get_where_op(dtype=dtype, device=device)
(oup,) = where(mask, x, y) (oup,) = where(mask, x, y)
if isscalar(mask):
setscalar(oup)
return oup return oup
......
...@@ -67,7 +67,7 @@ void init_common(py::module m) { ...@@ -67,7 +67,7 @@ void init_common(py::module m) {
[](const CompNode& cn) { return cn.to_string_logical(); }) [](const CompNode& cn) { return cn.to_string_logical(); })
.def_property_readonly( .def_property_readonly(
"physical_name", "physical_name",
[](const CompNode& cn) { return cn.to_string(); }) [](const CompNode& cn) { return cn.to_string_physical(); })
.def_property_readonly( .def_property_readonly(
"get_mem_status_bytes", "get_mem_status_bytes",
[](const CompNode& cn) { [](const CompNode& cn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册