Commit 2775f458 authored by Megvii Engine Team

feat(subgraph): subgraph builder supports jit and custom grad

GitOrigin-RevId: e1a1ebdf1c1f8b3d7fd8b3795d618a8e71b0dcc4
Parent 3c61e0e0
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import itertools
 from typing import Iterable, Union
 import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
 )
 from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
+from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
 from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(
 _opr_map = {
     ("-", 1): builtin.Elemwise(mode="negate"),
+    ("abs", 1): builtin.Elemwise(mode="abs"),
+    ("exp", 1): builtin.Elemwise(mode="exp"),
+    ("log1p", 1): builtin.Elemwise(mode="log1p"),
+    ("relu", 1): builtin.Elemwise(mode="relu"),
+    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
     ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
     ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
+    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
 }
 for name, mode in [
@@ -209,15 +218,21 @@ for name, mode in [
     ("//", "floor_div"),
     ("**", "pow"),
     ("max", "max"),
+    ("min", "min"),
     ("additive", "add"),
     ("exp", "EXP"),
+    ("switch_gt0", "switch_gt0"),
+    ("abs_grad", "abs_grad"),
 ]:
     _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
-def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+def subgraph(
+    name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
+):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
+        jit_fusion = False
     def as_op(op, nargs):
         if isinstance(op, str):
@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
-        inputs = [builder.input() for _ in range(nr_inputs)]
-        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
-        builder.outputs(outputs)
-        builder.outputs_has_grad(outputs_has_grad)
-        if gopt_level is None:
-            return lambda: builder.get()
-        else:
-            return lambda: builder.compile(gopt_level)
+        def build(builder, outputs, outputs_has_grad):
+            builder = type(builder)(builder)
+            builder.outputs(outputs)
+            builder.outputs_has_grad(outputs_has_grad)
+            if jit_fusion:
+                assert gopt_level is None
+                op = lambda: builder.jit_fuse()
+            elif gopt_level is None:
+                op = lambda: builder.get()
+            else:
+                op = lambda: builder.compile(gopt_level)
+            return op
+        inputs = [builder.input() for _ in range(nr_inputs)]
+        if not custom_grad:
+            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+            return build(builder, outputs, outputs_has_grad)
+        else:
+            gen = func(inputs, apply_expr, apply_const)
+            outputs = gen.send(None)
+            nr_outputs = len(outputs)
+            forward_fn = build(builder, outputs, [False] * nr_outputs)
+            output_grads = [builder.input() for _ in range(nr_outputs)]
+            input_grads = gen.send(output_grads)
+            assert len(input_grads) == nr_inputs
+            input_grads_mask = [input_grad is not None for input_grad in input_grads]
+            indices = [
+                i - 1 if mask else None
+                for i, mask in zip(
+                    itertools.accumulate(input_grads_mask), input_grads_mask
+                )
+            ]
+            encoded_input_grads = [grad for grad in input_grads if grad is not None]
+            backward_fn = build(
+                builder, encoded_input_grads, [False] * len(encoded_input_grads)
+            )
+            class SubgraphOp(Function):
+                def __init__(self):
+                    self.inputs = None
+                def forward(self, *inputs):
+                    self.inputs = inputs
+                    return apply(forward_fn(), *inputs)
+                def backward(self, *output_grads):
+                    inputs = self.inputs
+                    self.inputs = None
+                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
+                    input_grads = [
+                        encoded_input_grads[i] if i is not None else None
+                        for i in indices
+                    ]
+                    return input_grads
+            gen.close()
+            return SubgraphOp
     return decorator
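
The `indices` bookkeeping above packs the non-None input gradients densely (that dense list is what the backward subgraph actually returns) while remembering where each one belongs in the full input list. A small worked example with placeholder values:

```python
import itertools

# Toy values standing in for builder vars: grads exist for inputs 0 and 2 only.
input_grads = ["g0", None, "g2"]
input_grads_mask = [g is not None for g in input_grads]          # [True, False, True]
indices = [
    i - 1 if mask else None
    for i, mask in zip(itertools.accumulate(input_grads_mask), input_grads_mask)
]
print(indices)                                                    # [0, None, 1]
encoded_input_grads = [g for g in input_grads if g is not None]   # ["g0", "g2"]
# In backward(), encoded_input_grads[indices[k]] recovers input k's gradient,
# and None marks inputs that receive no gradient.
```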
@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
             return Const(value, dtype=dtype, device=device)()[0]
         outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        outputs = [
+            output if has_grad else output.detach()
+            for output, has_grad in zip(outputs, outputs_has_grad)
+        ]
         return outputs
     return decorated_func
-def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+def subgraph_fn(
+    name,
+    dtype,
+    device,
+    nr_inputs,
+    gopt_level=None,
+    jit_fusion=False,
+    custom_grad=False,
+    *,
+    interpret=False
+):
     def decorator(func):
         if not interpret:
-            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            op = subgraph(
+                name,
+                dtype,
+                device,
+                nr_inputs,
+                gopt_level=gopt_level,
+                jit_fusion=jit_fusion,
+                custom_grad=custom_grad,
+            )(func)
             return lambda *args: apply(op(), *args)
         else:
             return interpret_subgraph(func, dtype, device)
...
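
With `custom_grad=True`, the decorated function is driven as a generator: `subgraph` sends `None` to obtain the forward outputs, then feeds back one gradient variable per output and expects the gradients of the inputs in return (`None` for inputs that need no gradient). Below is a minimal sketch of such a function, assuming `apply_expr` returns a single var when `nr_outputs` is omitted and that the `switch_gt0` entry computes "`dy` where `x > 0`, else 0"; the name `my_relu` and the usage comment are illustrative, not part of this commit.

```python
def my_relu(inputs, f, c):
    # f is apply_expr, c is apply_const; all values here are builder vars, not tensors
    (x,) = inputs
    y = f("relu", x)                  # forward graph: y = relu(x)
    (dy,) = yield [y]                 # yield outputs, receive grads w.r.t. outputs
    yield [f("switch_gt0", x, dy)]    # grads w.r.t. inputs: pass dy through where x > 0

# Hypothetical usage (dtype/device are placeholders): returns a Function subclass
# whose backward runs the second half of the generator as a separate subgraph.
# my_relu_op = subgraph("my_relu", dtype, device, nr_inputs=1, custom_grad=True)(my_relu)
```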
@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
     ExternOpr,
     RemoteRecv,
     RemoteSend,
+    set_jit_enabled,
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:
         graph = G.Graph()
+        jit_enabled = set_jit_enabled(False)
         dest_vars = self._trace.dump(
             graph,
             input_bindings,
             [*zip(self._output_bindings, output_names)],
             prefer_input_names,
         )
+        set_jit_enabled(jit_enabled)
         # dest_vars = [i._node for i in dest_vars]
...
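
The dump path temporarily turns JIT fusion off so that no `JITFusionOp` ends up in the serialized graph; judging from the diff, `set_jit_enabled(False)` returns the previous flag, which is restored afterwards. A hedged convenience wrapper for that save/restore pattern (the `jit_disabled` helper is hypothetical, not part of this commit):

```python
from contextlib import contextmanager

from megengine.core._imperative_rt.ops import set_jit_enabled


@contextmanager
def jit_disabled():
    previous = set_jit_enabled(False)  # remember the prior enabled state
    try:
        yield
    finally:
        set_jit_enabled(previous)      # restore it even if dumping raises
```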
@@ -577,21 +577,26 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name} {}
         std::string name;
-        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
-        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
-        Subgraph& graph = *graph_storage;
+        Subgraph graph;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<mgb::Hashable> key = nullptr;
-        std::shared_ptr<OpDef> build() const {
-            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        std::shared_ptr<OpDef> build() {
+            if (key == nullptr) {
+                key = std::make_shared<UniqueKey>();
+            }
+            return SubgraphOp::make(
+                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
         }
     };
     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
+           .def(py::init<PySubgraphBuilder>())
            .def("input",
                 [](PySubgraphBuilder& self) {
+                    mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
+                    mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
+                    mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
@@ -619,11 +626,13 @@ void init_ops(py::module m) {
                })
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
+                    mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                })
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
+                    mgb_assert(self.key == nullptr);
                     mgb_assert(
                             self.graph.outputs.size() == self.output_grad_mask.size());
                     self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@ void init_ops(py::module m) {
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                })
-           .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
+           .def("compile",
+                [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
+                })
+           .def("jit_fuse", [](PySubgraphBuilder& self) {
+               return (std::shared_ptr<OpDef>)CompiledOp::make(
+                       JITFusionOp::make(self.build()));
            });
+
+    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
    auto custom = submodule(m, "_custom");
    init_custom(custom);
}
...
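
On the C++ side, every mutating method of `PySubgraphBuilder` now asserts `key == nullptr`, i.e. the builder is frozen once `build()` has assigned its identity key, and the new copy constructor is what lets the Python `build()` helper snapshot a builder before sealing it, so the forward and backward graphs can share a common prefix. A hedged sketch of driving the binding directly; it assumes `apply` returns the list of newly created output vars (the tail of that binding is elided in the hunk above), and the op name "add_relu" is illustrative:

```python
from megengine.core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from megengine.core.ops import builtin

builder = _SubgraphBuilder("add_relu")                         # name of the resulting op
x = builder.input()
y = builder.input()
(s,) = builder.apply(builtin.Elemwise(mode="add"), [x, y], 1)  # s = x + y
(r,) = builder.apply(builtin.Elemwise(mode="relu"), [s], 1)    # r = relu(s)
builder.outputs([r])
builder.outputs_has_grad([True])

op = builder.get()         # plain SubgraphOp
# op = builder.compile(2)  # or a CompiledOp with graph-opt level 2
# op = builder.jit_fuse()  # or, new in this change, a JIT-fused CompiledOp
```

After any of the three build calls, further `input`/`apply`/`outputs` calls on the same builder would trip the new asserts, which is exactly why the Python wrapper copies the builder before building.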
@@ -12,6 +12,7 @@
 #include "megbrain/imperative/transformations/scalar.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"
 namespace mgb {
 namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
     }
 }
+template <typename T>
+std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
+    // TODO: add flag instead of assume
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
+    if (all_scalar) {
+        for (auto& output : outputs) {
+            output = ScalarValue::make(output);
+        }
+    }
+    return outputs;
+}
+
 struct ScalarRuleRegistry {
     ScalarRuleRegistry() {
         register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
         register_scalar_rule(broadcast_rule);
         register_scalar_rule(copy_rule);
         register_scalar_rule(inplace_add_rule);
+        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
+        register_scalar_rule(subgraph_op_rule<CompiledOp>);
     }
 } _;
 }  // namespace
...
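
`subgraph_op_rule` unwraps any `ScalarValue` inputs before applying a `SubgraphOp`/`CompiledOp` and, when every input was scalar, re-wraps each output as a `ScalarValue`. The user-visible effect should be that 0-dim tensors stay 0-dim across such ops; a hedged illustration, reusing the `op` built in the SubgraphBuilder sketch above:

```python
import megengine as mge
from megengine.core._imperative_rt.core2 import apply

x = mge.tensor(2.0)        # 0-dim tensors are carried as ScalarValue internally
y = mge.tensor(-3.0)
(r,) = apply(op, x, y)     # dispatched through the SubgraphOp / CompiledOp
assert r.ndim == 0         # all inputs were scalar, so the output is re-wrapped
```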