Commit 2775f458 authored by: Megvii Engine Team

feat(subgraph): subgraph builder supports jit and custom grad

GitOrigin-RevId: e1a1ebdf1c1f8b3d7fd8b3795d618a8e71b0dcc4
Parent: 3c61e0e0
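The custom_grad=True path introduced below drives the decorated function as a generator: gen.send(None) collects the forward outputs, and a second gen.send(output_grads) must return one gradient per input (or None). A minimal sketch under those assumptions follows; the name ExpGrad, the dtype/device arguments, and the use of fma3 for the multiply are illustrative only, not part of this diff.

    # Illustrative sketch; "exp" and "fma3" are keys of _opr_map in this diff.
    @subgraph_fn("ExpGrad", dtype=np.float32, device=device, nr_inputs=1, custom_grad=True)
    def exp_with_grad(inputs, apply_expr, apply_const):
        (x,) = inputs
        y = apply_expr("exp", x)                            # forward: y = exp(x)
        (dy,) = yield [y]                                   # yield outputs, receive their grads
        dx = apply_expr("fma3", dy, y, apply_const(0.0))    # dy * y + 0, i.e. d(exp(x))/dx scaled by dy
        yield [dx]                                          # one gradient per input (None if no grad)

Calling exp_with_grad(tensor) then applies the built forward op, while the backward op is wired in through the Function subclass generated below.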
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import itertools
from typing import Iterable, Union
import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
)
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(
_opr_map = {
("-", 1): builtin.Elemwise(mode="negate"),
("abs", 1): builtin.Elemwise(mode="abs"),
("exp", 1): builtin.Elemwise(mode="exp"),
("log1p", 1): builtin.Elemwise(mode="log1p"),
("relu", 1): builtin.Elemwise(mode="relu"),
("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
}
for name, mode in [
@@ -209,15 +218,21 @@ for name, mode in [
("//", "floor_div"),
("**", "pow"),
("max", "max"),
("min", "min"),
("additive", "add"),
("exp", "EXP"),
("switch_gt0", "switch_gt0"),
("abs_grad", "abs_grad"),
]:
_opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
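For orientation (not part of the diff): apply_expr inside a subgraph function dispatches on these (name, arity) keys, so the new entries can be used as below; the slicing semantics of the two Subtensor entries are my reading of the item flags and should be treated as an assumption.

    neg   = apply_expr("-", x)              # Elemwise negate
    fused = apply_expr("fma3", a, b, c)     # a * b + c (FUSE_MUL_ADD3)
    tail  = apply_expr("[?:]", x, i)        # roughly x[i:] along axis 0 (assumed)
    head  = apply_expr("[:?]", x, i)        # roughly x[:i] along axis 0 (assumed)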
def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
def subgraph(
name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
):
if device.physical_name.startswith("cpu"):
gopt_level = None # disable jit and compile
jit_fusion = False
def as_op(op, nargs):
if isinstance(op, str):
@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
def apply_const(value, dtype=dtype, device=device):
return builder.apply_const(value, dtype, device)
def build(builder, outputs, outputs_has_grad):
builder = type(builder)(builder)
builder.outputs(outputs)
builder.outputs_has_grad(outputs_has_grad)
if jit_fusion:
assert gopt_level is None
op = lambda: builder.jit_fuse()
elif gopt_level is None:
op = lambda: builder.get()
else:
op = lambda: builder.compile(gopt_level)
return op
inputs = [builder.input() for _ in range(nr_inputs)]
outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
builder.outputs(outputs)
builder.outputs_has_grad(outputs_has_grad)
if gopt_level is None:
return lambda: builder.get()
if not custom_grad:
outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
return build(builder, outputs, outputs_has_grad)
else:
return lambda: builder.compile(gopt_level)
gen = func(inputs, apply_expr, apply_const)
outputs = gen.send(None)
nr_outputs = len(outputs)
forward_fn = build(builder, outputs, [False] * nr_outputs)
output_grads = [builder.input() for _ in range(nr_outputs)]
input_grads = gen.send(output_grads)
assert len(input_grads) == nr_inputs
input_grads_mask = [input_grad is not None for input_grad in input_grads]
indices = [
i - 1 if mask else None
for i, mask in zip(
itertools.accumulate(input_grads_mask), input_grads_mask
)
]
encoded_input_grads = [grad for grad in input_grads if grad is not None]
backward_fn = build(
builder, encoded_input_grads, [False] * len(encoded_input_grads)
)
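Worked example of the gradient encoding above (not in the diff): with nr_inputs == 3 and input_grads == [g0, None, g2], input_grads_mask is [True, False, True], itertools.accumulate yields [1, 1, 2], so indices becomes [0, None, 1] and encoded_input_grads is [g0, g2]; backward() below uses indices to scatter the dense gradient list back into per-input slots.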
class SubgraphOp(Function):
def __init__(self):
self.inputs = None
def forward(self, *inputs):
self.inputs = inputs
return apply(forward_fn(), *inputs)
def backward(self, *output_grads):
inputs = self.inputs
self.inputs = None
encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
input_grads = [
encoded_input_grads[i] if i is not None else None
for i in indices
]
return input_grads
gen.close()
return SubgraphOp
return decorator
@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
return Const(value, dtype=dtype, device=device)()[0]
outputs, outputs_has_grad = func(args, apply_expr, apply_const)
outputs = [
output if has_grad else output.detach()
for output, has_grad in zip(outputs, outputs_has_grad)
]
return outputs
return decorated_func
def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
def subgraph_fn(
name,
dtype,
device,
nr_inputs,
gopt_level=None,
jit_fusion=False,
custom_grad=False,
*,
interpret=False
):
def decorator(func):
if not interpret:
op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
op = subgraph(
name,
dtype,
device,
nr_inputs,
gopt_level=gopt_level,
jit_fusion=jit_fusion,
custom_grad=custom_grad,
)(func)
return lambda *args: apply(op(), *args)
else:
return interpret_subgraph(func, dtype, device)
......
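A usage sketch for the non-interpreted path (illustrative names and arguments; op names come from _opr_map above): the decorated function receives (inputs, apply_expr, apply_const), returns (outputs, outputs_has_grad), and the value returned by subgraph_fn applies the built, optionally jit-fused op to real tensors.

    @subgraph_fn("ClampByY", dtype=np.float32, device=device, nr_inputs=2, jit_fusion=True)
    def clamp_by_y(inputs, apply_expr, apply_const):
        x, y = inputs
        out = apply_expr("min", apply_expr("relu", x), y)   # min(relu(x), y)
        return [out], [True]

    outputs = clamp_by_y(a, b)   # a, b: tensors on a non-CPU device (jit_fusion is disabled on CPU)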
@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
ExternOpr,
RemoteRecv,
RemoteSend,
set_jit_enabled,
)
from ..core._trace_option import set_symbolic_shape
from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:
graph = G.Graph()
jit_enabled = set_jit_enabled(False)
dest_vars = self._trace.dump(
graph,
input_bindings,
[*zip(self._output_bindings, output_names)],
prefer_input_names,
)
set_jit_enabled(jit_enabled)
# dest_vars = [i._node for i in dest_vars]
......
@@ -577,21 +577,26 @@ void init_ops(py::module m) {
struct PySubgraphBuilder {
explicit PySubgraphBuilder(std::string name) : name{name} {}
std::string name;
std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
Subgraph& graph = *graph_storage;
Subgraph graph;
mgb::SmallVector<bool> output_grad_mask;
Subgraph::var_t next_var = 1;
std::shared_ptr<mgb::Hashable> key = nullptr;
std::shared_ptr<OpDef> build() const {
return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
std::shared_ptr<OpDef> build() {
if (key == nullptr) {
key = std::make_shared<UniqueKey>();
}
return SubgraphOp::make(
name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
}
};
py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
.def(py::init<std::string>())
.def(py::init<PySubgraphBuilder>())
.def("input",
[](PySubgraphBuilder& self) {
mgb_assert(self.key == nullptr);
auto var = self.next_var++;
self.graph.inputs.push_back(var);
return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
.def("apply",
[](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
Subgraph::vars_t inputs, size_t nr_outputs) {
mgb_assert(self.key == nullptr);
Subgraph::vars_t outputs;
for (size_t i = 0; i < nr_outputs; ++i) {
outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
.def("apply_const",
[](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
mgb::CompNode cn) {
mgb_assert(self.key == nullptr);
auto var = self.next_var++;
mgb::HostTensorND hvalue(cn);
npy::np2tensor(
@@ -619,11 +626,13 @@
})
.def("outputs",
[](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
mgb_assert(self.key == nullptr);
self.graph.outputs = outputs;
self.output_grad_mask.resize(outputs.size(), true);
})
.def("outputs_has_grad",
[](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
mgb_assert(self.key == nullptr);
mgb_assert(
self.graph.outputs.size() == self.output_grad_mask.size());
self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@
[](PySubgraphBuilder& self) {
return (std::shared_ptr<OpDef>)self.build();
})
.def("compile", [](PySubgraphBuilder& self, int gopt_level) {
.def("compile",
[](PySubgraphBuilder& self, int gopt_level) {
return (std::shared_ptr<OpDef>)CompiledOp::make(
self.build(), gopt_level);
})
.def("jit_fuse", [](PySubgraphBuilder& self) {
return (std::shared_ptr<OpDef>)CompiledOp::make(
self.build(), gopt_level);
JITFusionOp::make(self.build()));
});
m.def("set_jit_enabled", &JITFusionOp::set_enabled);
auto custom = submodule(m, "_custom");
init_custom(custom);
}
......
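Rough Python-level sketch of the builder lifecycle implied by these bindings (method names come from the diff; exact return values and usage are assumptions): the mutating methods assert that key is still null, so once get()/compile()/jit_fuse() has called build() and fixed the key, the builder keeps producing equivalent ops but can no longer be modified.

    builder = _SubgraphBuilder("example")
    x = builder.input()
    (y,) = builder.apply(builtin.Elemwise(mode="relu"), [x], 1)
    builder.outputs([y])
    builder.outputs_has_grad([True])
    op = builder.jit_fuse()      # or builder.get() / builder.compile(gopt_level)
    # builder.input()            # would now trip the mgb_assert(self.key == nullptr) checks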
@@ -12,6 +12,7 @@
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
namespace mgb {
namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
}
}
template <typename T>
std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
// TODO: add flag instead of assume
bool all_scalar = true;
for (auto&& input : inputs) {
if (!input.is<ScalarValue>()) {
all_scalar = false;
}
}
auto outputs = imperative::apply(op, unwrap_inputs(inputs));
if (all_scalar) {
for (auto& output : outputs) {
output = ScalarValue::make(output);
}
}
return outputs;
}
struct ScalarRuleRegistry {
ScalarRuleRegistry() {
register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
register_scalar_rule(broadcast_rule);
register_scalar_rule(copy_rule);
register_scalar_rule(inplace_add_rule);
register_scalar_rule(subgraph_op_rule<SubgraphOp>);
register_scalar_rule(subgraph_op_rule<CompiledOp>);
}
} _;
} // namespace
......