diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 816e8a94a27045714c4cd4d517da5e13392a388f..b6a1f0c266b46acae992e5d3c85a506bc2a85768 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         gopt_level = None  # disable jit and compile

     binary_ops = {
-        "+": builtin.Elemwise(mode="add"),
-        "-": builtin.Elemwise(mode="sub"),
-        "*": builtin.Elemwise(mode="mul"),
-        "/": builtin.Elemwise(mode="true_div"),
-        "//": builtin.Elemwise(mode="floor_div"),
-        "**": builtin.Elemwise(mode="pow"),
-        "√": builtin.Elemwise(mode="expm1"),
-        "max": builtin.Elemwise(mode="max"),
-        "additive": builtin.Elemwise(mode="add"),
+        "+": lambda: builtin.Elemwise(mode="add"),
+        "-": lambda: builtin.Elemwise(mode="sub"),
+        "*": lambda: builtin.Elemwise(mode="mul"),
+        "/": lambda: builtin.Elemwise(mode="true_div"),
+        "//": lambda: builtin.Elemwise(mode="floor_div"),
+        "**": lambda: builtin.Elemwise(mode="pow"),
+        "√": lambda: builtin.Elemwise(mode="expm1"),
+        "max": lambda: builtin.Elemwise(mode="max"),
+        "additive": lambda: builtin.Elemwise(mode="add"),
     }

     unary_ops = {
-        "-": builtin.Elemwise(mode="negate"),
+        "-": lambda: builtin.Elemwise(mode="negate"),
     }

     def decorator(func):
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_expr(op, *args):
             if isinstance(op, str):
                 if len(args) == 2:
-                    op = binary_ops[op]
+                    op = binary_ops[op]()
                 elif len(args) == 1:
-                    op = unary_ops[op]
+                    op = unary_ops[op]()
             return builder.apply(op, args, 1)[0]

         def apply_const(value, dtype=dtype, device=device):
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         builder.outputs(outputs)
         builder.outputs_has_grad(outputs_has_grad)
         if gopt_level is None:
-            return builder.get()
+            return lambda: builder.get()
         else:
-            return builder.compile(gopt_level)
+            return lambda: builder.compile(gopt_level)

     return decorator
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index eadc6ec859a773e2fbfca0d27603ae915f976c8d..e8117ab131ff57c0fc70e375a1b9475dc69ba928 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
     return result


+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
 @lru_cache(maxsize=None)
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
         result_shape = f(GetVarShape(), result)
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)

@@ -1051,9 +1064,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
         return result
     else:  # dispath to BatchedMatrixMul
         extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
@@ -1065,9 +1078,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
         return result


diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index d7c7a64b837e5ca80a46531b3b7f065970435c7d..2f27f858c7ef63c3f153d3e292b9f8fc842aa800 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
         syncbn_split_stats,
     ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)

-    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)

     eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)

@@ -1338,19 +1338,28 @@ def sync_batch_norm(
     if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
+            (stat,) = apply(
+                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
+            )
             stat = all_reduce_sum(stat, group)
-            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
         outvar, channel_mean, *_ = apply(
-            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+            syncbn_stage1(),
+            inp,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            eps,
+            weight,
+            bias,
         )
     else:
         assert running_var is not None and running_mean is not None
         channel_mean = running_mean
         channel_var = running_var
         outvar, *_ = apply(
-            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
         )

     # outvar = output * weight + bias
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
     if training and running_var is not None and running_mean is not None:
         momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
         running_mean[...], running_var[...] = apply(
-            syncbn_stage2,
+            syncbn_stage2(),
             running_mean,
             running_var,
             momentum,
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 9ca8c812b794bb7419715b12eab4b393bcd0370d..d821894f9aff033687cd2ffc86b658fadc41f1f6 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{name}{}
        std::string name;
-       Subgraph graph;
+       std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
+       std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
+       Subgraph& graph = *graph_storage;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
+
+       std::shared_ptr<OpDef> build() const {
+           return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+       }
     };

     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
            self.output_grad_mask = outputs_has_grad;
        })
        .def("get", [](PySubgraphBuilder& self){
-           return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+           return (std::shared_ptr<OpDef>)self.build();
        })
        .def("compile", [](PySubgraphBuilder& self, int gopt_level){
-           auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
-           return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+           return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
        });
 }
diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp
index 21104a726890b2b25c97a777a0152281e8a7e93b..73445bfe8135fc56c299b743ea13fd0917572657 100644
--- a/imperative/src/impl/ops/utility.cpp
+++ b/imperative/src/impl/ops/utility.cpp
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
 namespace { namespace subgraph {

 EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
-    return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
+    return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph);
 }

 EncodedSubraph make_backward_graph(
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
         }
     }
     auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
+    return EncodedSubraph::make_single(
+            SubgraphOp::make(op.name + "Grad",
+                    std::make_shared<Subgraph>(bgraph.graph)),
+            bgraph.input_mask, bgraph.output_mask);
 }

 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     return {
         {"name", op.name},
-        {"inputs", mgb::imperative::to_string(op.graph.inputs)},
-        {"exprs", mgb::imperative::to_string(op.graph.exprs)},
-        {"outputs", mgb::imperative::to_string(op.graph.outputs)},
+        {"inputs", mgb::imperative::to_string(op.graph->inputs)},
+        {"exprs", mgb::imperative::to_string(op.graph->exprs)},
+        {"outputs", mgb::imperative::to_string(op.graph->outputs)},
     };
 }

@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
 auto hash(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     if (!op.graph_key) {
-        return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
+        return (size_t)reinterpret_cast<uintptr_t>(op.graph.get());
     }
     return op.graph_key->hash();
 }
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
     if (has_graph_key) {
         graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
     } else {
-        graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
+        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
     }
     return graph_same;
 }
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
+    auto& op = def.cast_final_safe<CompiledOp>();
+    op.op->set_scope(op.scope());
+    return OpDef::apply_on_var_node(*op.op, inputs);
 }

 auto infer_output_attrs_fallible(
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
     if (backward_graph.graph.is_single()) {
         bgraph_op = backward_graph.graph.as_single();
     } else {
-        bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
+        bgraph_op = SubgraphOp::make(
+                name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
+                grad_outputs_has_grad, key);
     }
     auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
     auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
         .fallback();
 }}

+MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);

 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
diff --git a/imperative/src/impl/subgraph_detail.cpp b/imperative/src/impl/subgraph_detail.cpp
index ee4971645026acd5e7e4f229895006fd6d87e525..c345a3077d745a32ef7a9fe702a14c2cf496e283 100644
--- a/imperative/src/impl/subgraph_detail.cpp
+++ b/imperative/src/impl/subgraph_detail.cpp
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
     for (auto&& input: inputs) {
         input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
     }
-    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        op->set_scope(def.scope());
         return OpDef::apply_on_var_node(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) {
diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h
index 5fe810d21fe32f5b3c4cc6352aba7d78f0be1e0d..e5bbb14c44e2f1fab1a3c11a7a88c5df91b950e4 100644
--- a/imperative/src/include/megbrain/imperative/ops/utility.h
+++ b/imperative/src/include/megbrain/imperative/ops/utility.h
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };

+struct UniqueKey final: Hashable {
+public:
+    size_t hash() const override {
+        return reinterpret_cast<size_t>(this);
+    }
+protected:
+    bool is_same_st(const Hashable& rhs) const override {
+        return this == &rhs.cast_final_safe<UniqueKey>();
+    }
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
     std::string name;
-    Subgraph graph;
+    std::shared_ptr<Subgraph> graph;
     SmallVector<bool> output_grad_mask;
     std::shared_ptr<Hashable> graph_key;
     SubgraphOp() = default;
-    SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
+    SubgraphOp(std::string name, std::shared_ptr<Subgraph> graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
         : name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
         if (this->output_grad_mask.empty()) {
-            this->output_grad_mask.resize(graph.outputs.size(), true);
+            this->output_grad_mask.resize(graph->outputs.size(), true);
         }
     }
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
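
Illustrative sketch (not part of the patch): the Python-side changes above replace cached op instances with zero-argument factories, and they add _Hashable so the execution strategy can pass through functools.lru_cache before being unwrapped with ".value". The self-contained sketch below shows only that pattern; ExecStrategy, make_op and cached_builder are hypothetical stand-ins, not MegEngine APIs.

# Sketch only: ExecStrategy, make_op and cached_builder are hypothetical
# stand-ins, not MegEngine APIs.  The pattern mirrors the diff:
#   (1) dict values are zero-arg factories, so each call site builds a fresh op;
#   (2) an unhashable argument is wrapped so functools.lru_cache can key on it,
#       and the real value is recovered with ".value" inside the cached function.
from functools import lru_cache


class ExecStrategy:
    # hypothetical config object that cannot be hashed (like a mutable setting)
    __hash__ = None

    def __init__(self, flags):
        self.flags = flags


class _Hashable:
    # same idea as the wrapper added to functional/math.py in the diff
    def __init__(self, value):
        self.value = value

    def __hash__(self):
        return hash(str(self.value))

    def __eq__(self, other):
        return isinstance(other, _Hashable) and self.value == other.value


def make_op(mode):
    # stands in for builtin.Elemwise(mode=...); returns a new object every time
    return {"type": "Elemwise", "mode": mode}


# Factories instead of shared instances: nothing is constructed until lookup time.
binary_ops = {
    "+": lambda: make_op("add"),
    "*": lambda: make_op("mul"),
}


@lru_cache(maxsize=None)
def cached_builder(mode, strategy):
    # strategy arrives wrapped; unwrap it where the underlying value is needed,
    # mirroring "strategy=strategy.value" in the diff
    return make_op(mode), strategy.value


if __name__ == "__main__":
    op1 = binary_ops["+"]()  # called at use time, like "binary_ops[op]()" in the diff
    op2 = binary_ops["+"]()
    assert op1 is not op2    # no instance is shared between call sites

    strategy = _Hashable(ExecStrategy(flags="heuristic"))
    cached_builder("mul", strategy)  # cacheable even though ExecStrategy is unhashable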