Commit b20cda6b authored by Megvii Engine Team

fix(pylite): fix pylite callback test bug

GitOrigin-RevId: f4bd153950766dd1a523eec117158d4932b745a5
Parent: c361b193
@@ -235,15 +235,29 @@ class LiteNetworkIO(object):
 LiteAsyncCallback = CFUNCTYPE(c_int)
+LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
+LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
+
+
+def wrap_async_callback(func):
+    global wrapper
+
+    @CFUNCTYPE(c_int)
+    def wrapper():
+        return func()
+
+    return wrapper
 
 
 def start_finish_callback(func):
+    global wrapper
+
     @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
     def wrapper(c_ios, c_tensors, size):
         ios = {}
         for i in range(size):
             tensor = LiteTensor()
-            tensor._tensor = c_tensors[i]
+            tensor._tensor = c_void_p(c_tensors[i])
             tensor.update()
             io = c_ios[i]
             ios[io] = tensor

@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase):
         ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
         ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
         ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
-        ("LITE_set_start_callback", [_Cnetwork]),
-        ("LITE_set_finish_callback", [_Cnetwork]),
+        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
+        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
         ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
     ]

@@ -482,8 +496,8 @@ class LiteNetwork(object):
         self._api.LITE_share_runtime_memroy(self._network, src_network._network)
 
     def async_with_callback(self, async_callback):
-        async_callback = LiteAsyncCallback(async_callback)
-        self._api.LITE_set_async_callback(self._network, async_callback)
+        callback = wrap_async_callback(async_callback)
+        self._api.LITE_set_async_callback(self._network, callback)
 
     def set_start_callback(self, start_callback):
         """
@@ -491,7 +505,8 @@ class LiteNetwork(object):
         the start_callback with param mapping from LiteIO to the corresponding
         LiteTensor
         """
-        self._api.LITE_set_start_callback(self._network, start_callback)
+        callback = start_finish_callback(start_callback)
+        self._api.LITE_set_start_callback(self._network, callback)
 
     def set_finish_callback(self, finish_callback):
         """
@@ -499,7 +514,8 @@ class LiteNetwork(object):
         the finish_callback with param mapping from LiteIO to the corresponding
         LiteTensor
         """
-        self._api.LITE_set_finish_callback(self._network, finish_callback)
+        callback = start_finish_callback(finish_callback)
+        self._api.LITE_set_finish_callback(self._network, callback)
 
     def enable_profile_performance(self, profile_file):
         c_file = profile_file.encode("utf-8")
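Note on the new wrap_async_callback / start_finish_callback helpers above: ctypes requires the Python side to keep a reference to a CFUNCTYPE object for as long as native code may still invoke it, and the global wrapper statements appear to serve exactly that purpose. The snippet below is a minimal, self-contained sketch of that pattern using plain ctypes only; make_callback and _kept_callbacks are illustrative names, not part of megenginelite.

from ctypes import CFUNCTYPE, c_int

# Same prototype as LiteAsyncCallback in the diff above: an int(void) callback.
AsyncCallbackType = CFUNCTYPE(c_int)

# Module-level storage keeps the ctypes callback objects alive. Without a live
# Python reference they can be garbage-collected while C code still holds the
# function pointer, which typically crashes at the next invocation.
_kept_callbacks = []


def make_callback(func):
    # Wrap a plain Python function into a ctypes callback and pin it.
    c_callback = AsyncCallbackType(func)
    _kept_callbacks.append(c_callback)
    return c_callback


# The wrapped object is callable from Python too, which makes the pattern easy to test.
cb = make_callback(lambda: 0)
assert cb() == 0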
@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet):
         self.do_forward(src_network)
         self.do_forward(new_network)
 
-    # def test_network_async(self):
-    #     count = 0
-    #     finished = False
-    #
-    #     def async_callback():
-    #         nonlocal finished
-    #         finished = True
-    #         return 0
-    #
-    #     option = LiteOptions()
-    #     option.var_sanity_check_first_run = 0
-    #     config = LiteConfig(option=option)
-    #
-    #     network = LiteNetwork(config=config)
-    #     network.load(self.model_path)
-    #
-    #     network.async_with_callback(async_callback)
-    #
-    #     input_tensor = network.get_io_tensor(network.get_input_name(0))
-    #     output_tensor = network.get_io_tensor(network.get_output_name(0))
-    #
-    #     input_tensor.set_data_by_share(self.input_data)
-    #     network.forward()
-    #
-    #     while not finished:
-    #         count += 1
-    #
-    #     assert count > 0
-    #     output_data = output_tensor.to_numpy()
-    #     self.check_correct(output_data)
-    #
-    # def test_network_start_callback(self):
-    #     network = LiteNetwork()
-    #     network.load(self.model_path)
-    #     start_checked = False
-    #
-    #     @start_finish_callback
-    #     def start_callback(ios):
-    #         nonlocal start_checked
-    #         start_checked = True
-    #         assert len(ios) == 1
-    #         for key in ios:
-    #             io = key
-    #             data = ios[key].to_numpy().flatten()
-    #             input_data = self.input_data.flatten()
-    #             assert data.size == input_data.size
-    #             assert io.name.decode("utf-8") == "data"
-    #             for i in range(data.size):
-    #                 assert data[i] == input_data[i]
-    #         return 0
-    #
-    #     network.set_start_callback(start_callback)
-    #     self.do_forward(network, 1)
-    #     assert start_checked == True
-    #
-    # def test_network_finish_callback(self):
-    #     network = LiteNetwork()
-    #     network.load(self.model_path)
-    #     finish_checked = False
-    #
-    #     @start_finish_callback
-    #     def finish_callback(ios):
-    #         nonlocal finish_checked
-    #         finish_checked = True
-    #         assert len(ios) == 1
-    #         for key in ios:
-    #             io = key
-    #             data = ios[key].to_numpy().flatten()
-    #             output_data = self.correct_data.flatten()
-    #             assert data.size == output_data.size
-    #             for i in range(data.size):
-    #                 assert data[i] == output_data[i]
-    #         return 0
-    #
-    #     network.set_finish_callback(finish_callback)
-    #     self.do_forward(network, 1)
-    #     assert finish_checked == True
+    def test_network_async(self):
+        count = 0
+        finished = False
+
+        def async_callback():
+            nonlocal finished
+            finished = True
+            return 0
+
+        option = LiteOptions()
+        option.var_sanity_check_first_run = 0
+        config = LiteConfig(option=option)
+
+        network = LiteNetwork(config=config)
+        network.load(self.model_path)
+
+        network.async_with_callback(async_callback)
+
+        input_tensor = network.get_io_tensor(network.get_input_name(0))
+        output_tensor = network.get_io_tensor(network.get_output_name(0))
+
+        input_tensor.set_data_by_share(self.input_data)
+        network.forward()
+
+        while not finished:
+            count += 1
+
+        assert count > 0
+        output_data = output_tensor.to_numpy()
+        self.check_correct(output_data)
+
+    def test_network_start_callback(self):
+        network = LiteNetwork()
+        network.load(self.model_path)
+        start_checked = False
+
+        def start_callback(ios):
+            nonlocal start_checked
+            start_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                input_data = self.input_data.flatten()
+                assert data.size == input_data.size
+                assert io.name.decode("utf-8") == "data"
+                for i in range(data.size):
+                    assert data[i] == input_data[i]
+            return 0
+
+        network.set_start_callback(start_callback)
+        self.do_forward(network, 1)
+        assert start_checked == True
+
+    def test_network_finish_callback(self):
+        network = LiteNetwork()
+        network.load(self.model_path)
+        finish_checked = False
+
+        def finish_callback(ios):
+            nonlocal finish_checked
+            finish_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                output_data = self.correct_data.flatten()
+                assert data.size == output_data.size
+                for i in range(data.size):
+                    assert data[i] == output_data[i]
+            return 0
+
+        network.set_finish_callback(finish_callback)
+        self.do_forward(network, 1)
+        assert finish_checked == True
 
     def test_enable_profile(self):
         network = LiteNetwork()
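The re-enabled test_network_async above detects completion by busy-waiting on a flag, which is convenient for a test but wasteful in application code. A rough usage sketch with a threading.Event instead, assuming the megenginelite package exposes the classes used by these tests (LiteNetwork, LiteConfig, LiteOptions); the model path and input array are placeholders.

import threading

import numpy as np
from megenginelite import LiteConfig, LiteNetwork, LiteOptions

model_path = "./shufflenet.mge"  # placeholder path to a .mge model
input_data = np.random.rand(1, 3, 224, 224).astype("float32")  # placeholder input

done = threading.Event()


def async_callback():
    # Invoked once the asynchronous forward pass has finished.
    done.set()
    return 0


option = LiteOptions()
option.var_sanity_check_first_run = 0
network = LiteNetwork(config=LiteConfig(option=option))
network.load(model_path)
network.async_with_callback(async_callback)

input_tensor = network.get_io_tensor(network.get_input_name(0))
output_tensor = network.get_io_tensor(network.get_output_name(0))
input_tensor.set_data_by_share(input_data)

network.forward()
done.wait()  # block until async_callback has fired
result = output_tensor.to_numpy()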
@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda):
         self.do_forward(src_network)
         self.do_forward(new_network)
 
+    @require_cuda
+    def test_network_start_callback(self):
+        config = LiteConfig()
+        config.device = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config)
+        network.load(self.model_path)
+        start_checked = False
+
+        def start_callback(ios):
+            nonlocal start_checked
+            start_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                input_data = self.input_data.flatten()
+                assert data.size == input_data.size
+                assert io.name.decode("utf-8") == "data"
+                for i in range(data.size):
+                    assert data[i] == input_data[i]
+            return 0
+
+        network.set_start_callback(start_callback)
+        self.do_forward(network, 1)
+        assert start_checked == True
+
+    @require_cuda
+    def test_network_finish_callback(self):
+        config = LiteConfig()
+        config.device = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config)
+        network.load(self.model_path)
+        finish_checked = False
+
+        def finish_callback(ios):
+            nonlocal finish_checked
+            finish_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                output_data = self.correct_data.flatten()
+                assert data.size == output_data.size
+                for i in range(data.size):
+                    assert data[i] == output_data[i]
+            return 0
+
+        network.set_finish_callback(finish_callback)
+        self.do_forward(network, 1)
+        assert finish_checked == True
+
     @require_cuda()
     def test_enable_profile(self):
         config = LiteConfig()
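Taken together, the library change and the re-enabled tests mean set_start_callback and set_finish_callback now accept plain Python functions receiving a dict that maps each LiteIO to its LiteTensor, with the ctypes wrapping handled inside the network module. A hedged user-level sketch under the same megenginelite import assumption; model_path, log_start, and log_finish are placeholders, and, as exercised by the tests above, the start callback sees the input tensors while the finish callback sees the outputs.

from megenginelite import LiteNetwork

model_path = "./shufflenet.mge"  # placeholder path to a .mge model


def log_start(ios):
    # ios maps each input LiteIO to the LiteTensor about to be consumed.
    for io, tensor in ios.items():
        print("start:", io.name, tensor.to_numpy().shape)
    return 0  # the C side expects an integer status code


def log_finish(ios):
    # ios maps each output LiteIO to the LiteTensor produced by the forward pass.
    for io, tensor in ios.items():
        print("finish:", io.name, tensor.to_numpy().shape)
    return 0


network = LiteNetwork()
network.load(model_path)
network.set_start_callback(log_start)
network.set_finish_callback(log_finish)
# A subsequent forward pass invokes log_start before and log_finish after execution.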