diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 2e88a558b0a4fb8b33809ccc90942a3dfb557fa4..d8cc44dd6d4bd659ec25f3dde54d90f7e7b7df84 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -16,6 +16,7 @@ import numpy as np import pyarrow import threading import time +import traceback import zmq from parl.remote import remote_constants from parl.utils import get_ip_address, logger, to_str, to_byte @@ -179,7 +180,7 @@ def remote_class(cls): except Exception as e: error_str = str(e) - logger.error(e) + logger.error(error_str) if type(e) == AttributeError: self.reply_socket.send_multipart([ @@ -197,9 +198,12 @@ def remote_class(cls): to_byte(error_str) ]) else: + traceback_str = str(traceback.format_exc()) + logger.error('traceback:\n{}'.format(traceback_str)) self.reply_socket.send_multipart([ remote_constants.EXCEPTION_TAG, - to_byte(error_str) + to_byte(error_str + '\ntraceback:\n' + + traceback_str) ]) continue diff --git a/parl/remote/tests/remote_test.py b/parl/remote/tests/remote_test.py index f36ab80feee8805f0b026dfffa67b20afe3faae5..64ba62c756f77842eaa8c46f991df4cf3ebc6fdb 100644 --- a/parl/remote/tests/remote_test.py +++ b/parl/remote/tests/remote_test.py @@ -50,6 +50,9 @@ class Simulator: value += 1 return value + def will_raise_exeception_func(self): + x = 1 / 0 + class TestRemote(unittest.TestCase): def _setUp(self, server_port): @@ -91,7 +94,8 @@ class TestRemote(unittest.TestCase): try: remote_sim.get_arg3() - except RemoteAttributeError: + except RemoteAttributeError as e: + logger.info('Expected exception: {}'.format(e)) # expected return @@ -105,7 +109,8 @@ class TestRemote(unittest.TestCase): try: remote_sim.set_arg3(3) - except RemoteAttributeError: + except RemoteAttributeError as e: + logger.info('Expected exception: {}'.format(e)) # expected return @@ -119,7 +124,8 @@ class TestRemote(unittest.TestCase): try: remote_sim.set_arg1(wrong_arg=1) - except RemoteError: + except RemoteError as e: + logger.info('Expected exception: {}'.format(e)) # expected return @@ -277,6 +283,22 @@ class TestRemote(unittest.TestCase): for t in threads: t.join() + def test_remote_object_with_call_raise_exception_function(self): + server_port = 17781 + self._setUp(server_port) + + remote_sim = self.remote_manager.get_remote() + + try: + remote_sim.will_raise_exeception_func() + except RemoteError as e: + assert 'Traceback (most recent call last)' in str(e) + logger.info('Expected exception: {}'.format(e)) + # expected + return + + assert False + def _run_remote_add(self, remote_sim): value = 0 for i in range(1000):