Commit b20cda6b authored by Megvii Engine Team

fix(pylite): fix pylite callback test bug

GitOrigin-RevId: f4bd153950766dd1a523eec117158d4932b745a5
Parent: c361b193
......@@ -235,15 +235,29 @@ class LiteNetworkIO(object):
# ctypes function-pointer prototypes for callbacks registered with the C
# Lite runtime.  The async callback takes no arguments; the start/finish
# callbacks receive parallel C arrays of LiteIO descriptors and raw tensor
# handles together with their common length (c_size_t).  Each returns c_int.
LiteAsyncCallback = CFUNCTYPE(c_int)
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
def wrap_async_callback(func):
    """Wrap the Python callable *func* into a ``CFUNCTYPE(c_int)`` trampoline.

    A strong reference to every trampoline created here is appended to
    ``wrap_async_callback._keep_alive``: relying only on the module-level
    ``wrapper`` global is unsafe, because a later registration rebinds that
    global and the earlier ctypes trampoline can be garbage-collected while
    the C runtime still holds its function pointer (crash on invocation).

    :param func: zero-argument callable returning an int status code.
    :return: the ctypes trampoline to hand to ``LITE_set_async_callback``.
    """
    # keep the historical module-level binding for backward compatibility
    global wrapper

    @CFUNCTYPE(c_int)
    def wrapper():
        # delegate to the user callable; its int result crosses into C
        return func()

    wrap_async_callback._keep_alive.append(wrapper)
    return wrapper


# strong references preventing GC of registered trampolines
wrap_async_callback._keep_alive = []
def start_finish_callback(func):
    """Wrap *func* into the C start/finish callback prototype.

    The C runtime hands over parallel arrays of ``LiteIO`` descriptors and
    raw tensor handles; they are re-packed here into a ``{LiteIO: LiteTensor}``
    mapping (presumably passed on to *func* — the continuation of this
    wrapper is not visible in this view).
    """
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        ios = {}
        for i in range(size):
            tensor = LiteTensor()
            # wrap the raw C handle so ctypes treats it as an opaque pointer
            # (a bare int handle here was the pre-fix bug)
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()
            io = c_ios[i]
            ios[io] = tensor
......@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
("LITE_set_start_callback", [_Cnetwork]),
("LITE_set_finish_callback", [_Cnetwork]),
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
]
......@@ -482,8 +496,8 @@ class LiteNetwork(object):
self._api.LITE_share_runtime_memroy(self._network, src_network._network)
def async_with_callback(self, async_callback):
    """Register *async_callback* to be invoked by the Lite runtime
    when an asynchronous forward completes.

    The callable is wrapped via ``wrap_async_callback`` so a strong
    reference to the ctypes trampoline outlives this call.
    """
    callback = wrap_async_callback(async_callback)
    self._api.LITE_set_async_callback(self._network, callback)
def set_start_callback(self, start_callback):
    """Register *start_callback* to run when the network starts forward.

    The callback is invoked with a mapping from each ``LiteIO`` to its
    corresponding ``LiteTensor``; it is wrapped via
    ``start_finish_callback`` so the ctypes trampoline stays referenced.
    """
    callback = start_finish_callback(start_callback)
    self._api.LITE_set_start_callback(self._network, callback)
def set_finish_callback(self, finish_callback):
    """Register *finish_callback* to run when the network finishes forward.

    The callback is invoked with a mapping from each ``LiteIO`` to its
    corresponding ``LiteTensor``; it is wrapped via
    ``start_finish_callback`` so the ctypes trampoline stays referenced.
    """
    callback = start_finish_callback(finish_callback)
    self._api.LITE_set_finish_callback(self._network, callback)
def enable_profile_performance(self, profile_file):
c_file = profile_file.encode("utf-8")
......
......@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet):
self.do_forward(src_network)
self.do_forward(new_network)
def test_network_async(self):
    """Async forward: the registered callback must flip ``finished``
    and the output must still be numerically correct."""
    count = 0
    finished = False

    def async_callback():
        nonlocal finished
        finished = True
        return 0

    option = LiteOptions()
    # skip the first-run sanity pass so the async path is exercised directly
    option.var_sanity_check_first_run = 0
    config = LiteConfig(option=option)

    network = LiteNetwork(config=config)
    network.load(self.model_path)

    network.async_with_callback(async_callback)

    input_tensor = network.get_io_tensor(network.get_input_name(0))
    output_tensor = network.get_io_tensor(network.get_output_name(0))

    input_tensor.set_data_by_share(self.input_data)
    network.forward()

    # busy-wait until the callback marks completion
    while not finished:
        count += 1

    # NOTE(review): fragile — count stays 0 if the callback fires before
    # the loop starts; asserting ``finished`` would be the robust check
    assert count > 0
    output_data = output_tensor.to_numpy()
    self.check_correct(output_data)
def test_network_start_callback(self):
    """The start callback must see exactly one input IO, named "data",
    whose tensor matches the shared input buffer element-wise."""
    network = LiteNetwork()
    network.load(self.model_path)
    start_checked = False

    def start_callback(ios):
        nonlocal start_checked
        start_checked = True
        assert len(ios) == 1
        expected = self.input_data.flatten()
        for io, tensor in ios.items():
            got = tensor.to_numpy().flatten()
            assert got.size == expected.size
            assert io.name.decode("utf-8") == "data"
            for a, b in zip(got, expected):
                assert a == b
        return 0

    network.set_start_callback(start_callback)
    self.do_forward(network, 1)
    assert start_checked
def test_network_finish_callback(self):
    """The finish callback must see exactly one output IO whose tensor
    matches the reference output element-wise."""
    network = LiteNetwork()
    network.load(self.model_path)
    finish_checked = False

    def finish_callback(ios):
        nonlocal finish_checked
        finish_checked = True
        assert len(ios) == 1
        expected = self.correct_data.flatten()
        for _io, tensor in ios.items():
            got = tensor.to_numpy().flatten()
            assert got.size == expected.size
            for a, b in zip(got, expected):
                assert a == b
        return 0

    network.set_finish_callback(finish_callback)
    self.do_forward(network, 1)
    assert finish_checked
def test_enable_profile(self):
network = LiteNetwork()
......
......@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda):
self.do_forward(src_network)
self.do_forward(new_network)
@require_cuda
def test_network_start_callback(self):
    """CUDA variant: the start callback must see exactly one input IO,
    named "data", whose tensor matches the shared input buffer."""
    config = LiteConfig()
    config.device = LiteDeviceType.LITE_CUDA
    network = LiteNetwork(config)
    network.load(self.model_path)
    start_checked = False

    def start_callback(ios):
        nonlocal start_checked
        start_checked = True
        assert len(ios) == 1
        expected = self.input_data.flatten()
        for io, tensor in ios.items():
            got = tensor.to_numpy().flatten()
            assert got.size == expected.size
            assert io.name.decode("utf-8") == "data"
            for a, b in zip(got, expected):
                assert a == b
        return 0

    network.set_start_callback(start_callback)
    self.do_forward(network, 1)
    assert start_checked
@require_cuda
def test_network_finish_callback(self):
    """CUDA variant: the finish callback must see exactly one output IO
    whose tensor matches the reference output element-wise."""
    config = LiteConfig()
    config.device = LiteDeviceType.LITE_CUDA
    network = LiteNetwork(config)
    network.load(self.model_path)
    finish_checked = False

    def finish_callback(ios):
        nonlocal finish_checked
        finish_checked = True
        assert len(ios) == 1
        expected = self.correct_data.flatten()
        for _io, tensor in ios.items():
            got = tensor.to_numpy().flatten()
            assert got.size == expected.size
            for a, b in zip(got, expected):
                assert a == b
        return 0

    network.set_finish_callback(finish_callback)
    self.do_forward(network, 1)
    assert finish_checked
@require_cuda()
def test_enable_profile(self):
config = LiteConfig()
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.