Commit 2a063f8e authored by Megvii Engine Team

fix(subgraph): fix scope mismatch of subgraph content

GitOrigin-RevId: 6e23456250aa70c4cbdd71ecd9cfa6c19270a316
Parent 3206af9d
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         gopt_level = None  # disable jit and compile
     binary_ops = {
-        "+": builtin.Elemwise(mode="add"),
-        "-": builtin.Elemwise(mode="sub"),
-        "*": builtin.Elemwise(mode="mul"),
-        "/": builtin.Elemwise(mode="true_div"),
-        "//": builtin.Elemwise(mode="floor_div"),
-        "**": builtin.Elemwise(mode="pow"),
-        "√": builtin.Elemwise(mode="expm1"),
-        "max": builtin.Elemwise(mode="max"),
-        "additive": builtin.Elemwise(mode="add"),
+        "+": lambda: builtin.Elemwise(mode="add"),
+        "-": lambda: builtin.Elemwise(mode="sub"),
+        "*": lambda: builtin.Elemwise(mode="mul"),
+        "/": lambda: builtin.Elemwise(mode="true_div"),
+        "//": lambda: builtin.Elemwise(mode="floor_div"),
+        "**": lambda: builtin.Elemwise(mode="pow"),
+        "√": lambda: builtin.Elemwise(mode="expm1"),
+        "max": lambda: builtin.Elemwise(mode="max"),
+        "additive": lambda: builtin.Elemwise(mode="add"),
     }
     unary_ops = {
-        "-": builtin.Elemwise(mode="negate"),
+        "-": lambda: builtin.Elemwise(mode="negate"),
     }

     def decorator(func):
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_expr(op, *args):
             if isinstance(op, str):
                 if len(args) == 2:
-                    op = binary_ops[op]
+                    op = binary_ops[op]()
                 elif len(args) == 1:
-                    op = unary_ops[op]
+                    op = unary_ops[op]()
             return builder.apply(op, args, 1)[0]

         def apply_const(value, dtype=dtype, device=device):
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         builder.outputs(outputs)
         builder.outputs_has_grad(outputs_has_grad)
         if gopt_level is None:
-            return builder.get()
+            return lambda: builder.get()
         else:
-            return builder.compile(gopt_level)
+            return lambda: builder.compile(gopt_level)
     return decorator
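The hunks above defer op construction: `binary_ops` and `unary_ops` now map to factories, and `apply_expr` calls `binary_ops[op]()`, so a fresh `Elemwise` is built inside whatever scope is active when the subgraph is actually constructed instead of sharing instances whose scope was fixed at import time. The same reasoning applies to `builder.get()` / `builder.compile(...)`: the decorator now returns a callable, so the resulting op is created per call site. A minimal sketch of the idea, using toy stand-ins (`_scope_stack` and the `Elemwise` class below are illustrative only, not MegEngine's real scope machinery):

```python
# Toy model: scope is captured when an op object is constructed.
_scope_stack = ["<root>"]

class Elemwise:
    def __init__(self, mode):
        self.mode = mode
        self.scope = _scope_stack[-1]  # captured at construction time

BINARY_OPS_EAGER = {"+": Elemwise("add")}         # built once at import time
BINARY_OPS_LAZY = {"+": lambda: Elemwise("add")}  # built on demand

_scope_stack.append("my_subgraph")
print(BINARY_OPS_EAGER["+"].scope)   # "<root>"      -> scope mismatch
print(BINARY_OPS_LAZY["+"]().scope)  # "my_subgraph" -> matches the current scope
_scope_stack.pop()
```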
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
     return result


+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
 @lru_cache(maxsize=None)
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
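`_get_extentedMatrixMulOp` is wrapped in `functools.lru_cache`, which requires every argument to be hashable, and the execution-strategy object passed from `matmul` is not guaranteed to be. `_Hashable` wraps such a value so it can serve as a cache key (hashing its string form, comparing by value), and the callee unwraps it via `strategy.value`. A small illustration, with a hypothetical `build_matmul_subgraph` standing in for the real builder:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def build_matmul_subgraph(strategy):
    # `strategy` arrives as a _Hashable wrapper; unwrap the real value inside.
    return ("subgraph built with strategy", strategy.value)

# Equal wrapped values hash and compare equal, so the second call is a cache hit.
a = build_matmul_subgraph(_Hashable("HEURISTIC"))
b = build_matmul_subgraph(_Hashable("HEURISTIC"))
assert a is b
```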
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
         result_shape = f(GetVarShape(), result)
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
@@ -1051,9 +1064,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
         return result
     else:  # dispath to BatchedMatrixMul
         extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
@@ -1065,9 +1078,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
         return result
...
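Note the calling convention change in `matmul`: the `lru_cache`d helpers now return a zero-argument factory, and the call site invokes it (`extentedMatrixMulOp()`), so the op itself is created fresh under the caller's current scope while the expensive subgraph construction stays cached. A rough sketch of that split, with toy names (`OpDef` and `get_op_factory` here are illustrative only):

```python
from functools import lru_cache

class OpDef:
    def __init__(self, tag):
        self.tag = tag

@lru_cache(maxsize=None)
def get_op_factory(tag):
    print("expensive build for", tag)  # runs once per distinct tag
    return lambda: OpDef(tag)          # cheap: a new op object per invocation

f = get_op_factory("matmul")
op1, op2 = f(), f()                    # cache hit on the factory, two fresh ops
assert op1 is not op2
```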
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
         syncbn_split_stats,
     ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)

-    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)

     eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
     if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
+            (stat,) = apply(
+                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
+            )
             stat = all_reduce_sum(stat, group)
-            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
         outvar, channel_mean, *_ = apply(
-            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+            syncbn_stage1(),
+            inp,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            eps,
+            weight,
+            bias,
         )
     else:
         assert running_var is not None and running_mean is not None
         channel_mean = running_mean
         channel_var = running_var
         outvar, *_ = apply(
-            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
         )

     # outvar = output * weight + bias
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
     if training and running_var is not None and running_mean is not None:
         momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
         running_mean[...], running_var[...] = apply(
-            syncbn_stage2,
+            syncbn_stage2(),
             running_mean,
             running_var,
             momentum,
...
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name}{}
         std::string name;
-        Subgraph graph;
+        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
+        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
+        Subgraph& graph = *graph_storage;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<OpDef> build() const {
+            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        }
     };

     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
             self.output_grad_mask = outputs_has_grad;
         })
         .def("get", [](PySubgraphBuilder& self){
-            return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+            return (std::shared_ptr<OpDef>)self.build();
         })
         .def("compile", [](PySubgraphBuilder& self, int gopt_level){
-            auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
-            return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+            return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
         });
 }
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
 namespace { namespace subgraph {

 EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
-    return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
+    return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph);
 }

 EncodedSubraph make_backward_graph(
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
         }
     }
     auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
+    return EncodedSubraph::make_single(
+            SubgraphOp::make(op.name + "Grad",
+                             std::make_shared<Subgraph>(bgraph.graph)),
+            bgraph.input_mask, bgraph.output_mask);
 }

 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     return {
         {"name", op.name},
-        {"inputs", mgb::imperative::to_string(op.graph.inputs)},
-        {"exprs", mgb::imperative::to_string(op.graph.exprs)},
-        {"outputs", mgb::imperative::to_string(op.graph.outputs)},
+        {"inputs", mgb::imperative::to_string(op.graph->inputs)},
+        {"exprs", mgb::imperative::to_string(op.graph->exprs)},
+        {"outputs", mgb::imperative::to_string(op.graph->outputs)},
     };
 }
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
 auto hash(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     if (!op.graph_key) {
-        return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
+        return (size_t)reinterpret_cast<uintptr_t>(op.graph.get());
     }
     return op.graph_key->hash();
 }
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
     if (has_graph_key) {
         graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
     } else {
-        graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
+        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
     }
     return graph_same;
 }
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
+    auto& op = def.cast_final_safe<CompiledOp>();
+    op.op->set_scope(op.scope());
+    return OpDef::apply_on_var_node(*op.op, inputs);
 }

 auto infer_output_attrs_fallible(
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
     if (backward_graph.graph.is_single()) {
         bgraph_op = backward_graph.graph.as_single();
     } else {
-        bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
+        bgraph_op = SubgraphOp::make(
+                name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
+                grad_outputs_has_grad, key);
     }
     auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
     auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
         .fallback();
 }}

+MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
...
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
     for (auto&& input: inputs) {
         input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
     }
-    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        op->set_scope(def.scope());
         return OpDef::apply_on_var_node(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) {
...
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };

+struct UniqueKey final: Hashable {
+public:
+    size_t hash() const override {
+        return reinterpret_cast<uintptr_t>(this);
+    }
+protected:
+    bool is_same_st(const Hashable& rhs) const override {
+        return this == &rhs.cast_final_safe<UniqueKey>();
+    }
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
     std::string name;
-    Subgraph graph;
+    std::shared_ptr<Subgraph> graph;
     SmallVector<bool> output_grad_mask;
     std::shared_ptr<Hashable> graph_key;
     SubgraphOp() = default;
-    SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
+    SubgraphOp(std::string name, std::shared_ptr<Subgraph> graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
         : name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
         if (this->output_grad_mask.empty()) {
-            this->output_grad_mask.resize(graph.outputs.size(), true);
+            this->output_grad_mask.resize(graph->outputs.size(), true);
         }
     }
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
...
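The new `UniqueKey` hashes by object address and compares equal only to itself, so each `PySubgraphBuilder` gets a stable, collision-free identity in op-level caches even when two builders hold similar-looking graphs. An illustrative Python analogue of that identity-key behaviour (not MegEngine API):

```python
class UniqueKey:
    """Identity-based key: equal only to itself."""

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self is other


cache = {}
k1, k2 = UniqueKey(), UniqueKey()
cache[k1] = "subgraph A"
cache[k2] = "subgraph B"
assert cache[k1] == "subgraph A" and cache[k2] == "subgraph B"  # no collision
```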