diff --git a/parl/utils/csv_logger.py b/parl/utils/csv_logger.py index e5152e599b831fab0dbed3a43d3bb40b3412bcb5..8b9045658338e63a2f2d92371e96b755d1d74442 100644 --- a/parl/utils/csv_logger.py +++ b/parl/utils/csv_logger.py @@ -19,12 +19,24 @@ __all__ = ['CSVLogger'] class CSVLogger(object): def __init__(self, output_file): - """CSV Logger which can write dict result to csv file + """CSV Logger which can write dict result to csv file. + + Args: + output_file(str): filename of the csv file. """ self.output_file = open(output_file, "w") self.csv_writer = None def log_dict(self, result): + """Ouput result to the csv file. + + Will create the header of the csv file automatically when the function is called for the first time. + Ususally, the keys of the result should be the same every time you call the function. + + Args: + result(dict) + """ + assert isinstance(result, dict), "the input should be a dict." if self.csv_writer is None: self.csv_writer = csv.DictWriter(self.output_file, result.keys()) self.csv_writer.writeheader() @@ -38,4 +50,9 @@ class CSVLogger(object): self.output_file.flush() def close(self): - self.output_file.close() + if not self.output_file.closed: + self.output_file.close() + + def __del__(self): + if not self.output_file.closed: + self.output_file.close()