提交 3e11efa7 编写于 作者: L LDOUBLEV

fix mem leak

上级 6887f457
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import numpy as np import numpy as np
import os import os
import random import random
import traceback
from paddle.io import Dataset from paddle.io import Dataset
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -46,7 +45,6 @@ class SimpleDataSet(Dataset): ...@@ -46,7 +45,6 @@ class SimpleDataSet(Dataset):
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.check_data()
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle: if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
...@@ -103,18 +101,25 @@ class SimpleDataSet(Dataset): ...@@ -103,18 +101,25 @@ class SimpleDataSet(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data = self.data_lines[file_idx] data_line = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
data['ext_data'] = self.get_ext_data() data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
except: except Exception as e:
error_meg = traceback.format_exc()
self.logger.error( self.logger.error(
"When parsing file {} and label {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data['img_path'],data['label'], error_meg)) data_line, e))
outs = None outs = None
if outs is None: if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation. # during evaluation, we should fix the idx to get same results for many times of evaluation.
...@@ -125,17 +130,3 @@ class SimpleDataSet(Dataset): ...@@ -125,17 +130,3 @@ class SimpleDataSet(Dataset):
def __len__(self): def __len__(self):
return len(self.data_idx_order_list) return len(self.data_idx_order_list)
def check_data(self):
new_data_lines = []
for data_line in self.data_lines:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if os.path.exists(img_path):
new_data_lines.append({'img_path': img_path, 'label': label})
else:
self.logger.info("{} does not exist!".format(img_path))
self.data_lines = new_data_lines
\ No newline at end of file
...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): ...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained") pretrained = model_config.pop("pretrained")
model = BaseModel(model_config) model = BaseModel(model_config)
if pretrained is not None: if pretrained is not None:
model = load_pretrained_params(model, pretrained) load_pretrained_params(model, pretrained)
if freeze_params: if freeze_params:
for param in model.parameters(): for param in model.parameters():
param.trainable = False param.trainable = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册