diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py
index 4edcc96035905664c595be08285d8d7d3b9c68ab..7324f48d142c13b39958e77b677b2c3a3f6a93ea 100644
--- a/lite/pylite/megenginelite/network.py
+++ b/lite/pylite/megenginelite/network.py
@@ -235,15 +235,29 @@ class LiteNetworkIO(object):
 
 
 LiteAsyncCallback = CFUNCTYPE(c_int)
+LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
+LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
+
+
+def wrap_async_callback(func):
+    global wrapper
+
+    @CFUNCTYPE(c_int)
+    def wrapper():
+        return func()
+
+    return wrapper
 
 
 def start_finish_callback(func):
+    global wrapper
+
     @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
     def wrapper(c_ios, c_tensors, size):
         ios = {}
         for i in range(size):
             tensor = LiteTensor()
-            tensor._tensor = c_tensors[i]
+            tensor._tensor = c_void_p(c_tensors[i])
             tensor.update()
             io = c_ios[i]
             ios[io] = tensor
@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase):
         ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
         ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
         ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
-        ("LITE_set_start_callback", [_Cnetwork]),
-        ("LITE_set_finish_callback", [_Cnetwork]),
+        ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
+        ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
         ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
     ]
 
@@ -482,8 +496,8 @@ class LiteNetwork(object):
         self._api.LITE_share_runtime_memroy(self._network, src_network._network)
 
     def async_with_callback(self, async_callback):
-        async_callback = LiteAsyncCallback(async_callback)
-        self._api.LITE_set_async_callback(self._network, async_callback)
+        callback = wrap_async_callback(async_callback)
+        self._api.LITE_set_async_callback(self._network, callback)
 
     def set_start_callback(self, start_callback):
         """
         when LITE_set_start_callback is called, when LITE_load_model is called, the start_callback
         with param mapping from LiteIO to the corresponding LiteTensor
         """
-        self._api.LITE_set_start_callback(self._network, start_callback)
+        callback = start_finish_callback(start_callback)
+        self._api.LITE_set_start_callback(self._network, callback)
 
     def set_finish_callback(self, finish_callback):
         """
         when LITE_set_finish_callback is called, when LITE_load_model is called, the finish_callback
         with param mapping from LiteIO to the corresponding LiteTensor
         """
-        self._api.LITE_set_finish_callback(self._network, finish_callback)
+        callback = start_finish_callback(finish_callback)
+        self._api.LITE_set_finish_callback(self._network, callback)
 
     def enable_profile_performance(self, profile_file):
         c_file = profile_file.encode("utf-8")
diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py
index f445a35fae922c76e3c4eff2df49167b9d3d97a4..6bb8c979d47ae1fc481a6ad78caeb370d7903d45 100644
--- a/lite/pylite/test/test_network.py
+++ b/lite/pylite/test/test_network.py
@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet):
         self.do_forward(src_network)
         self.do_forward(new_network)
 
-    # def test_network_async(self):
-    #     count = 0
-    #     finished = False
-    #
-    #     def async_callback():
-    #         nonlocal finished
-    #         finished = True
-    #         return 0
-    #
-    #     option = LiteOptions()
-    #     option.var_sanity_check_first_run = 0
-    #     config = LiteConfig(option=option)
-    #
-    #     network = LiteNetwork(config=config)
-    #     network.load(self.model_path)
-    #
-    #     network.async_with_callback(async_callback)
-    #
-    #     input_tensor = network.get_io_tensor(network.get_input_name(0))
-    #     output_tensor = network.get_io_tensor(network.get_output_name(0))
-    #
-    #     input_tensor.set_data_by_share(self.input_data)
-    #     network.forward()
-    #
-    #     while not finished:
-    #         count += 1
-    #
-    #     assert count > 0
-    #     output_data = output_tensor.to_numpy()
-    #     self.check_correct(output_data)
-    #
-    # def test_network_start_callback(self):
-    #     network = LiteNetwork()
-    #     network.load(self.model_path)
-    #     start_checked = False
-    #
-    #     @start_finish_callback
-    #     def start_callback(ios):
-    #         nonlocal start_checked
-    #         start_checked = True
-    #         assert len(ios) == 1
-    #         for key in ios:
-    #             io = key
-    #             data = ios[key].to_numpy().flatten()
-    #             input_data = self.input_data.flatten()
-    #             assert data.size == input_data.size
-    #             assert io.name.decode("utf-8") == "data"
-    #             for i in range(data.size):
-    #                 assert data[i] == input_data[i]
-    #         return 0
-    #
-    #     network.set_start_callback(start_callback)
-    #     self.do_forward(network, 1)
-    #     assert start_checked == True
-    #
-    # def test_network_finish_callback(self):
-    #     network = LiteNetwork()
-    #     network.load(self.model_path)
-    #     finish_checked = False
-    #
-    #     @start_finish_callback
-    #     def finish_callback(ios):
-    #         nonlocal finish_checked
-    #         finish_checked = True
-    #         assert len(ios) == 1
-    #         for key in ios:
-    #             io = key
-    #             data = ios[key].to_numpy().flatten()
-    #             output_data = self.correct_data.flatten()
-    #             assert data.size == output_data.size
-    #             for i in range(data.size):
-    #                 assert data[i] == output_data[i]
-    #         return 0
-    #
-    #     network.set_finish_callback(finish_callback)
-    #     self.do_forward(network, 1)
-    #     assert finish_checked == True
+    def test_network_async(self):
+        count = 0
+        finished = False
+
+        def async_callback():
+            nonlocal finished
+            finished = True
+            return 0
+
+        option = LiteOptions()
+        option.var_sanity_check_first_run = 0
+        config = LiteConfig(option=option)
+
+        network = LiteNetwork(config=config)
+        network.load(self.model_path)
+
+        network.async_with_callback(async_callback)
+
+        input_tensor = network.get_io_tensor(network.get_input_name(0))
+        output_tensor = network.get_io_tensor(network.get_output_name(0))
+
+        input_tensor.set_data_by_share(self.input_data)
+        network.forward()
+
+        while not finished:
+            count += 1
+
+        assert count > 0
+        output_data = output_tensor.to_numpy()
+        self.check_correct(output_data)
+
+    def test_network_start_callback(self):
+        network = LiteNetwork()
+        network.load(self.model_path)
+        start_checked = False
+
+        def start_callback(ios):
+            nonlocal start_checked
+            start_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                input_data = self.input_data.flatten()
+                assert data.size == input_data.size
+                assert io.name.decode("utf-8") == "data"
+                for i in range(data.size):
+                    assert data[i] == input_data[i]
+            return 0
+
+        network.set_start_callback(start_callback)
+        self.do_forward(network, 1)
+        assert start_checked == True
+
+    def test_network_finish_callback(self):
+        network = LiteNetwork()
+        network.load(self.model_path)
+        finish_checked = False
+
+        def finish_callback(ios):
+            nonlocal finish_checked
+            finish_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                output_data = self.correct_data.flatten()
+                assert data.size == output_data.size
+                for i in range(data.size):
+                    assert data[i] == output_data[i]
+            return 0
+
+        network.set_finish_callback(finish_callback)
+        self.do_forward(network, 1)
+        assert finish_checked == True
 
     def test_enable_profile(self):
         network = LiteNetwork()
diff --git a/lite/pylite/test/test_network_cuda.py b/lite/pylite/test/test_network_cuda.py
index a5b2ac839dea007335d24b32ecdd7a6b24f4ea73..56e74b247d923367a3d44bbb9e8e4d91634cb2d8 100644
--- a/lite/pylite/test/test_network_cuda.py
+++ b/lite/pylite/test/test_network_cuda.py
@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda):
         self.do_forward(src_network)
         self.do_forward(new_network)
 
+    @require_cuda
+    def test_network_start_callback(self):
+        config = LiteConfig()
+        config.device = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config)
+        network.load(self.model_path)
+        start_checked = False
+
+        def start_callback(ios):
+            nonlocal start_checked
+            start_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                input_data = self.input_data.flatten()
+                assert data.size == input_data.size
+                assert io.name.decode("utf-8") == "data"
+                for i in range(data.size):
+                    assert data[i] == input_data[i]
+            return 0
+
+        network.set_start_callback(start_callback)
+        self.do_forward(network, 1)
+        assert start_checked == True
+
+    @require_cuda
+    def test_network_finish_callback(self):
+        config = LiteConfig()
+        config.device = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config)
+        network.load(self.model_path)
+        finish_checked = False
+
+        def finish_callback(ios):
+            nonlocal finish_checked
+            finish_checked = True
+            assert len(ios) == 1
+            for key in ios:
+                io = key
+                data = ios[key].to_numpy().flatten()
+                output_data = self.correct_data.flatten()
+                assert data.size == output_data.size
+                for i in range(data.size):
+                    assert data[i] == output_data[i]
+            return 0
+
+        network.set_finish_callback(finish_callback)
+        self.do_forward(network, 1)
+        assert finish_checked == True
+
     @require_cuda()
     def test_enable_profile(self):
         config = LiteConfig()
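
Usage sketch (illustrative, not part of the patch): after this change, set_start_callback and set_finish_callback accept a plain Python callable and wrap it with start_finish_callback internally, and async_with_callback does the same through wrap_async_callback, which is why the re-enabled tests above no longer apply the @start_finish_callback decorator themselves. A minimal call pattern could look like the following; the model path, the input shape, and the forward/wait pairing are assumptions made only for illustration.

# Illustrative sketch, not part of the diff. Assumes a megenginelite install
# and a placeholder model file "./shufflenet.mge" with one input named "data".
import numpy as np
from megenginelite import LiteNetwork

network = LiteNetwork()
network.load("./shufflenet.mge")  # hypothetical model path

def on_start(ios):
    # ios maps each LiteIO to its LiteTensor, as built by start_finish_callback
    for io, tensor in ios.items():
        print("start:", io.name.decode("utf-8"), tensor.to_numpy().shape)
    return 0  # the underlying C callback type returns c_int

def on_finish(ios):
    for io, tensor in ios.items():
        print("finish:", io.name.decode("utf-8"), tensor.to_numpy().shape)
    return 0

# Plain Python callables; the binding converts them to ctypes callbacks.
network.set_start_callback(on_start)
network.set_finish_callback(on_finish)

inp = network.get_io_tensor(network.get_input_name(0))
inp.set_data_by_share(np.random.rand(1, 3, 224, 224).astype(np.float32))
network.forward()
network.wait()  # block until this forward finishes (the async test spins on a flag instead)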