提交 75129cf0 编写于 作者: M Megvii Engine Team

chore(mge): clean up before merge to dev

* remove dead test
* clean some codes
* fix test_fake_quant due to change of implementation

GitOrigin-RevId: f030a9966d1664cd4a75cc9a1a992174610b70c0
上级 aba0acc7
...@@ -16,7 +16,6 @@ from typing import Dict, List, Union ...@@ -16,7 +16,6 @@ from typing import Dict, List, Union
import numpy as np import numpy as np
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
from .. import _imperative_rt from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
...@@ -26,6 +25,19 @@ from ..ops.builtin import OpDef ...@@ -26,6 +25,19 @@ from ..ops.builtin import OpDef
from .core import OpBase, TensorBase from .core import OpBase, TensorBase
def set_priority_to_id(dest_vars):
"""
For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, _imperative_rt.VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
class Graph(_imperative_rt.ComputingGraph): class Graph(_imperative_rt.ComputingGraph):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -46,8 +58,8 @@ class Graph(_imperative_rt.ComputingGraph): ...@@ -46,8 +58,8 @@ class Graph(_imperative_rt.ComputingGraph):
cache[obj] = wrapper(obj) cache[obj] = wrapper(obj)
return cache[obj] return cache[obj]
def set_priority_to_id(self, dest_vars): def _set_priority_to_id(self, dest_vars):
_set_priority_to_id(_unwrap(dest_vars)) set_priority_to_id(_unwrap(dest_vars))
def compile(self, *args): def compile(self, *args):
self._function = super().compile(_unwrap(args)) self._function = super().compile(_unwrap(args))
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
import functools import functools
import multiprocessing as mp import multiprocessing as mp
from ..core._imperative_rt import sync from ..core._imperative_rt.core2 import sync
from .group import group_barrier, init_process_group from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork from .helper import get_device_count_by_fork
from .server import Server from .server import Server
......
...@@ -367,7 +367,7 @@ class trace: ...@@ -367,7 +367,7 @@ class trace:
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
else: else:
lazy_eval_graph.options.graph_opt_level = 2 lazy_eval_graph.options.graph_opt_level = 2
lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers]) lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
lazy_eval_graph.compile(*lazy_eval_links, *readers) lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph() lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors): for r, x in zip(readers, lazy_eval_tensors):
...@@ -618,7 +618,7 @@ class trace: ...@@ -618,7 +618,7 @@ class trace:
graph.options.graph_opt_level = self._graph_opt_level graph.options.graph_opt_level = self._graph_opt_level
else: else:
graph.options.graph_opt_level = 2 graph.options.graph_opt_level = 2
graph.set_priority_to_id([*readers, *in_out_links, *io_links]) graph._set_priority_to_id([*readers, *in_out_links, *io_links])
graph.compile(*readers, *in_out_links, *io_links) graph.compile(*readers, *in_out_links, *io_links)
def _reset_exec_env(self): def _reset_exec_env(self):
......
...@@ -13,6 +13,7 @@ import numpy ...@@ -13,6 +13,7 @@ import numpy
from ..core import _imperative_rt from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor from ..tensor import Tensor
__all__ = [ __all__ = [
...@@ -271,19 +272,6 @@ def replace_oprs( ...@@ -271,19 +272,6 @@ def replace_oprs(
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def set_priority_to_id(dest_vars):
"""
For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]: def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
""" """
Loads a serialized computing graph and run inference with input data. Loads a serialized computing graph and run inference with input data.
......
...@@ -32,7 +32,7 @@ struct GradSlotWeakPtr { ...@@ -32,7 +32,7 @@ struct GradSlotWeakPtr {
size_t idx; size_t idx;
}; };
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject { struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override { std::shared_ptr<void> on_comp_node_finalize() override {
clear(); clear();
return {}; return {};
...@@ -56,7 +56,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph( ...@@ -56,7 +56,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
} }
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) && mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size)); bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
size_t key = XXHash{}.update(buf, buf_size).digest(); uint64_t key = XXHash{}.update(buf, buf_size).digest();
auto&& iter = backward_graph_cache.find(key); auto&& iter = backward_graph_cache.find(key);
if (iter != backward_graph_cache.end()) { if (iter != backward_graph_cache.end()) {
......
...@@ -32,7 +32,7 @@ inline bool _is_quantize(PyArray_Descr* dtype) { ...@@ -32,7 +32,7 @@ inline bool _is_quantize(PyArray_Descr* dtype) {
PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { PyObject* _get_mgb_dtype(PyArray_Descr* dtype) {
// Return value: New reference. // Return value: New reference.
if (!_is_quantize(dtype)) { if (!_is_quantize(dtype)) {
throw py::type_error("expact quantize dtype"); throw py::type_error("expect quantize dtype");
} }
PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype");
if (!PyDict_CheckExact(ob)) { if (!PyDict_CheckExact(ob)) {
......
...@@ -143,8 +143,10 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ...@@ -143,8 +143,10 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
// return nullptr; // return nullptr;
// } // }
if (!nargs) { if (nargs < 2) {
PyErr_SetString(PyExc_TypeError, "expect Op"); PyErr_SetString(PyExc_TypeError,
"py_apply expects one Op and at least one tensor "
"as argument");
return nullptr; return nullptr;
} }
...@@ -227,7 +229,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -227,7 +229,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
} }
} }
} else { } else {
py::detail::loader_life_support life_sup; // required to cast DType py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
auto data = tup[0].cast<py::array>(); auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>(); DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>(); CompNode cn = tup[2].cast<CompNode>();
...@@ -298,7 +300,6 @@ PyObject* TensorWrapper::handle() { ...@@ -298,7 +300,6 @@ PyObject* TensorWrapper::handle() {
void TensorWrapper::set_handle(PyObject* dest) { void TensorWrapper::set_handle(PyObject* dest) {
auto py_dest = py::reinterpret_borrow<py::object>(dest); auto py_dest = py::reinterpret_borrow<py::object>(dest);
SharedHandle real_dest = py_dest.cast<SharedHandle>(); SharedHandle real_dest = py_dest.cast<SharedHandle>();
auto&& t = std::move(m_tensor->m_handle);
m_tensor->m_handle = std::move(real_dest); m_tensor->m_handle = std::move(real_dest);
} }
...@@ -617,7 +618,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { ...@@ -617,7 +618,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
} }
} }
if (!valid) { if (!valid) {
mgb_assert(0, "expact at least 1 device"); mgb_assert(0, "expect at least 1 device");
} }
Py_DECREF(tuple); Py_DECREF(tuple);
return cn; return cn;
......
...@@ -88,6 +88,7 @@ def test_dist_grad(): ...@@ -88,6 +88,7 @@ def test_dist_grad():
worker() worker()
def test_grad(): def test_grad():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np) x = as_tensor(x_np)
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest
# from megengine.core.interpreter.hints import function
@pytest.mark.skip(reason="under rewrite")
def test_1():
@function
def f(x, p):
x = x + 1
if p:
return x * x
return x * 2
x = Tensor(0)
for _ in range(5):
assert f(x, 0).numpy() == 2
assert f(x, 1).numpy() == 1
...@@ -83,7 +83,7 @@ def test_TQT(): ...@@ -83,7 +83,7 @@ def test_TQT():
def _save_to(self, name="grad"): def _save_to(self, name="grad"):
def callback(tensor, grad): def callback(grad):
setattr(self, name, grad) setattr(self, name, grad)
return callback return callback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册