From 2775f4580cbe0a14f12b5da537a4671675508810 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 26 Sep 2021 19:53:02 +0800
Subject: [PATCH] feat(subgraph): subgraph builder supports jit and custom grad

GitOrigin-RevId: e1a1ebdf1c1f8b3d7fd8b3795d618a8e71b0dcc4
---
 .../python/megengine/core/tensor/utils.py  | 105 ++++++++++++++++--
 imperative/python/megengine/jit/tracing.py |   3 +
 imperative/python/src/ops.cpp              |  30 +++--
 .../src/impl/transformations/scalar.cpp    |  21 ++++
 4 files changed, 143 insertions(+), 16 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 45d934767..fdf0e3444 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import itertools
 from typing import Iterable, Union
 
 import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
 )
 from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
+from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
 from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(
 
 _opr_map = {
     ("-", 1): builtin.Elemwise(mode="negate"),
+    ("abs", 1): builtin.Elemwise(mode="abs"),
+    ("exp", 1): builtin.Elemwise(mode="exp"),
+    ("log1p", 1): builtin.Elemwise(mode="log1p"),
+    ("relu", 1): builtin.Elemwise(mode="relu"),
+    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
     ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
     ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
+    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
 }
 
 for name, mode in [
@@ -209,15 +218,21 @@
     ("//", "floor_div"),
     ("**", "pow"),
     ("max", "max"),
+    ("min", "min"),
     ("additive", "add"),
     ("exp", "EXP"),
+    ("switch_gt0", "switch_gt0"),
+    ("abs_grad", "abs_grad"),
 ]:
     _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
 
 
-def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+def subgraph(
+    name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
+):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
+        jit_fusion = False
 
     def as_op(op, nargs):
         if isinstance(op, str):
@@ -241,14 +256,64 @@
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
 
+        def build(builder, outputs, outputs_has_grad):
+            builder = type(builder)(builder)
+            builder.outputs(outputs)
+            builder.outputs_has_grad(outputs_has_grad)
+            if jit_fusion:
+                assert gopt_level is None
+                op = lambda: builder.jit_fuse()
+            elif gopt_level is None:
+                op = lambda: builder.get()
+            else:
+                op = lambda: builder.compile(gopt_level)
+            return op
+
         inputs = [builder.input() for _ in range(nr_inputs)]
-        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
-        builder.outputs(outputs)
-        builder.outputs_has_grad(outputs_has_grad)
-        if gopt_level is None:
-            return lambda: builder.get()
+        if not custom_grad:
+            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+            return build(builder, outputs, outputs_has_grad)
         else:
-            return lambda: builder.compile(gopt_level)
+            gen = func(inputs, apply_expr, apply_const)
+            outputs = gen.send(None)
+            nr_outputs = len(outputs)
+            forward_fn = build(builder, outputs, [False] * nr_outputs)
+
+            output_grads = [builder.input() for _ in range(nr_outputs)]
+            input_grads = gen.send(output_grads)
+            assert len(input_grads) == nr_inputs
+            input_grads_mask = [input_grad is not None for input_grad in input_grads]
+            indices = [
+                i - 1 if mask else None
+                for i, mask in zip(
+                    itertools.accumulate(input_grads_mask), input_grads_mask
+                )
+            ]
+            encoded_input_grads = [grad for grad in input_grads if grad is not None]
+            backward_fn = build(
+                builder, encoded_input_grads, [False] * len(encoded_input_grads)
+            )
+
+            class SubgraphOp(Function):
+                def __init__(self):
+                    self.inputs = None
+
+                def forward(self, *inputs):
+                    self.inputs = inputs
+                    return apply(forward_fn(), *inputs)
+
+                def backward(self, *output_grads):
+                    inputs = self.inputs
+                    self.inputs = None
+                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
+                    input_grads = [
+                        encoded_input_grads[i] if i is not None else None
+                        for i in indices
+                    ]
+                    return input_grads
+
+            gen.close()
+            return SubgraphOp
 
     return decorator
 
@@ -274,15 +339,37 @@
             return Const(value, dtype=dtype, device=device)()[0]
 
         outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        outputs = [
+            output if has_grad else output.detach()
+            for output, has_grad in zip(outputs, outputs_has_grad)
+        ]
         return outputs
 
     return decorated_func
 
 
-def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+def subgraph_fn(
+    name,
+    dtype,
+    device,
+    nr_inputs,
+    gopt_level=None,
+    jit_fusion=False,
+    custom_grad=False,
+    *,
+    interpret=False
+):
     def decorator(func):
         if not interpret:
-            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            op = subgraph(
+                name,
+                dtype,
+                device,
+                nr_inputs,
+                gopt_level=gopt_level,
+                jit_fusion=jit_fusion,
+                custom_grad=custom_grad,
+            )(func)
             return lambda *args: apply(op(), *args)
         else:
             return interpret_subgraph(func, dtype, device)
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index a5aa567ec..79ca807bb 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
     ExternOpr,
     RemoteRecv,
     RemoteSend,
+    set_jit_enabled,
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:
 
         graph = G.Graph()
 
+        jit_enabled = set_jit_enabled(False)
         dest_vars = self._trace.dump(
             graph,
             input_bindings,
             [*zip(self._output_bindings, output_names)],
             prefer_input_names,
         )
+        set_jit_enabled(jit_enabled)
 
         # dest_vars = [i._node for i in dest_vars]
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index a53b1027e..aea124e16 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -577,21 +577,26 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name} {}
         std::string name;
-        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
-        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
-        Subgraph& graph = *graph_storage;
+        Subgraph graph;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<UniqueKey> key = nullptr;
 
-        std::shared_ptr<SubgraphOp> build() const {
-            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        std::shared_ptr<SubgraphOp> build() {
+            if (key == nullptr) {
+                key = std::make_shared<UniqueKey>();
+            }
+            return SubgraphOp::make(
+                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
         }
     };
 
     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
             .def(py::init<std::string>())
+            .def(py::init<const PySubgraphBuilder&>())
             .def("input",
                  [](PySubgraphBuilder& self) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      self.graph.inputs.push_back(var);
                      return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
             .def("apply",
                  [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                     Subgraph::vars_t inputs, size_t nr_outputs) {
+                     mgb_assert(self.key == nullptr);
                      Subgraph::vars_t outputs;
                      for (size_t i = 0; i < nr_outputs; ++i) {
                          outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
             .def("apply_const",
                  [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                     mgb::CompNode cn) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      mgb::HostTensorND hvalue(cn);
                      npy::np2tensor(
@@ -619,11 +626,13 @@ void init_ops(py::module m) {
                  })
             .def("outputs",
                  [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
+                     mgb_assert(self.key == nullptr);
                      self.graph.outputs = outputs;
                      self.output_grad_mask.resize(outputs.size(), true);
                  })
             .def("outputs_has_grad",
                  [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
+                     mgb_assert(self.key == nullptr);
                      mgb_assert(
                              self.graph.outputs.size() == self.output_grad_mask.size());
                      self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@ void init_ops(py::module m) {
                  [](PySubgraphBuilder& self) {
                      return (std::shared_ptr<OpDef>)self.build();
                  })
-            .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
+            .def("compile",
+                 [](PySubgraphBuilder& self, int gopt_level) {
+                     return (std::shared_ptr<OpDef>)CompiledOp::make(
+                             self.build(), gopt_level);
+                 })
+            .def("jit_fuse", [](PySubgraphBuilder& self) {
                 return (std::shared_ptr<OpDef>)CompiledOp::make(
-                        self.build(), gopt_level);
+                        JITFusionOp::make(self.build()));
             });
 
+    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
+
     auto custom = submodule(m, "_custom");
     init_custom(custom);
 }
diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp
index 8daa5827e..eabd9efde 100644
--- a/imperative/src/impl/transformations/scalar.cpp
+++ b/imperative/src/impl/transformations/scalar.cpp
@@ -12,6 +12,7 @@
 #include "megbrain/imperative/transformations/scalar.h"
 
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"
 
 namespace mgb {
 namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
     }
 }
 
+template <typename T>
+std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
+    // TODO: add flag instead of assume
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
+    if (all_scalar) {
+        for (auto& output : outputs) {
+            output = ScalarValue::make(output);
+        }
+    }
+    return outputs;
+}
+
 struct ScalarRuleRegistry {
     ScalarRuleRegistry() {
         register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
         register_scalar_rule(broadcast_rule);
         register_scalar_rule(copy_rule);
         register_scalar_rule(inplace_add_rule);
+        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
+        register_scalar_rule(subgraph_op_rule<CompiledOp>);
     }
 } _;
 }  // namespace
-- 
GitLab
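
Usage note (editorial sketch, not part of the patch): the generator protocol that
subgraph() drives when custom_grad=True can be read off the decorator above. The
decorated function first yields its forward outputs, is resumed with one symbolic
gradient per output, and then yields one gradient per input (None for inputs that
need no gradient). The snippet below is a minimal, hypothetical example of that
protocol via subgraph_fn; the op name "Square", the input tensor, and the
assumption that x.dtype / x.device are acceptable dtype / device arguments are
illustrative only.

    import numpy as np
    import megengine as mge
    from megengine.core.tensor.utils import subgraph_fn

    x = mge.tensor(np.random.rand(16).astype("float32"))

    @subgraph_fn(
        "Square",
        dtype=x.dtype,
        device=x.device,  # assumed to be a CompNode exposing .physical_name
        nr_inputs=1,
        custom_grad=True,
    )
    def square(inputs, apply_expr, apply_const):
        (inp,) = inputs
        # forward: y = inp * inp, built from the ("*", 2) entry of _opr_map
        y = apply_expr("*", inp, inp)
        # first yield: hand the forward outputs to the builder; the generator
        # is resumed with one symbolic output gradient per output
        (dy,) = yield (y,)
        # backward: dx = 2 * inp * dy
        two = apply_const(2.0)
        dx = apply_expr("*", apply_expr("*", two, inp), dy)
        # second yield: one gradient per input, aligned with `inputs`
        yield (dx,)

    (y,) = square(x)

With jit_fusion=True (and gopt_level left as None), build() instead returns
builder.jit_fuse(), which wraps the same subgraph in a JITFusionOp inside a
CompiledOp; set_jit_enabled is the hook tracing.py uses to switch that fusion
off while dumping a traced graph.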