Commit 1ae37919 authored by D dyning

trans to paddle-rc

Parent fa675f89
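The diff below migrates the dygraph code from the Paddle 1.x-style API to the 2.0-rc API. As a quick orientation, here is a minimal sketch (not part of the commit; it assumes a working paddle 2.0-rc install and uses toy shapes) of the renames the diff applies:

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    paddle.seed(2)                                      # was: paddle.manual_seed(2)
    conv = nn.Conv2D(3, 8, kernel_size=3, padding=1)    # was: nn.Conv2d(...)
    pool = nn.AdaptiveAvgPool2D(1)                      # was: nn.Pool2D(pool_type="avg", global_pooling=True)
    x = paddle.rand([1, 3, 32, 32], dtype='float32')
    y = conv(x)
    y = F.upsample(y, scale_factor=2, mode="nearest")   # was: F.resize_nearest(y, scale=2)
    y = paddle.add(y, y)                                # was: paddle.elementwise_add(y, y)
    s = pool(y)                                         # global average pooling, now adaptive pooling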
@@ -37,6 +37,7 @@ from ppocr.data.lmdb_dataset import LMDBDateSet
__all__ = ['build_dataloader', 'transform', 'create_operators']


def term_mp(sig_num, frame):
    """ kill all child processes
    """
@@ -45,24 +46,27 @@ def term_mp(sig_num, frame):
    print("main proc {} exit, kill process group " "{}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)


-def build_dataloader(config, mode, device):
+def build_dataloader(config, mode, device, logger):
    config = copy.deepcopy(config)
    support_dict = ['SimpleDataSet', 'LMDBDateSet']
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(support_dict))
-    assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
+    assert mode in ['Train', 'Eval', 'Test'
+                    ], "Mode should be Train, Eval or Test."
-    dataset = eval(module_name)(config, mode)
+    dataset = eval(module_name)(config, mode, logger)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    num_workers = loader_config['num_workers']
    if mode == "Train":
        #Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(
@@ -76,14 +80,13 @@ def build_dataloader(config, mode, device):
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=drop_last)
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True)
    return data_loader
-    #return data_loader, _dataset.info_dict
\ No newline at end of file
@@ -22,37 +22,26 @@ import lmdb
import cv2

from .imaug import transform, create_operators
-from ppocr.utils.logging import get_logger
-logger = get_logger()


class LMDBDateSet(Dataset):
-    def __init__(self, config, mode):
+    def __init__(self, config, mode, logger):
        super(LMDBDateSet, self).__init__()
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        batch_size = loader_config['batch_size_per_card']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'], global_config)

-        # # for rec
-        # character = ''
-        # for op in self.ops:
-        #     if hasattr(op, 'character'):
-        #         character = getattr(op, 'character')
-        # self.info_dict = {'character': character}

    def load_hierarchical_lmdb_dataset(self, data_dir):
        lmdb_sets = {}
        dataset_idx = 0
@@ -71,7 +60,7 @@ class LMDBDateSet(Dataset):
                "txn":txn, "num_samples":num_samples}
            dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
@@ -88,7 +77,7 @@ class LMDBDateSet(Dataset):
            data_idx_order_list[beg_idx:end_idx, 1] += 1
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """get_img_data"""
        if not value:
@@ -110,15 +99,15 @@ class LMDBDateSet(Dataset):
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
-        sample_info = self.get_lmdb_sample_info(
-            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
+                                                file_idx)
        if sample_info is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops)
@@ -128,4 +117,3 @@ class LMDBDateSet(Dataset):

    def __len__(self):
        return self.data_idx_order_list.shape[0]
@@ -20,18 +20,17 @@ from paddle.io import Dataset
import time

from .imaug import transform, create_operators
-from ppocr.utils.logging import get_logger
-logger = get_logger()


class SimpleDataSet(Dataset):
-    def __init__(self, config, mode):
+    def __init__(self, config, mode, logger):
        super(SimpleDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        batch_size = loader_config['batch_size_per_card']

        self.delimiter = dataset_config.get('delimiter', '\t')
        label_file_list = dataset_config.pop('label_file_list')
        data_source_num = len(label_file_list)
@@ -39,19 +38,21 @@ class SimpleDataSet(Dataset):
            ratio_list = [1.0]
        else:
            ratio_list = dataset_config.pop('ratio_list')

        assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
-        assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        assert len(
+            ratio_list
+        ) == data_source_num, "The length of ratio_list should be the same as the file_list."

        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        logger.info("Initialize indexs of datasets:%s" % label_file_list)
        self.data_lines_list, data_num_list = self.get_image_info_list(
            label_file_list)
        self.data_idx_order_list = self.dataset_traversal(
            data_num_list, ratio_list, batch_size)
        self.shuffle_data_random()

        self.ops = create_operators(dataset_config['transforms'], global_config)

    def get_image_info_list(self, file_list):
@@ -65,7 +66,7 @@ class SimpleDataSet(Dataset):
            data_lines_list.append(lines)
            data_num_list.append(len(lines))
        return data_lines_list, data_num_list

    def dataset_traversal(self, data_num_list, ratio_list, batch_size):
        select_num_list = []
        dataset_num = len(data_num_list)
@@ -87,8 +88,7 @@ class SimpleDataSet(Dataset):
                        cur_index = cur_index_sets[dataset_idx]
                        if cur_index >= data_num_list[dataset_idx]:
                            break
-                        data_idx_order_list.append((
-                            dataset_idx, cur_index))
+                        data_idx_order_list.append((dataset_idx, cur_index))
                        cur_index_sets[dataset_idx] += 1
            if finish_read_num == dataset_num:
                break
@@ -99,7 +99,7 @@ class SimpleDataSet(Dataset):
        for dno in range(len(self.data_lines_list)):
            random.shuffle(self.data_lines_list[dno])
        return

    def __getitem__(self, idx):
        dataset_idx, file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines_list[dataset_idx][file_idx]
@@ -119,4 +119,3 @@ class SimpleDataSet(Dataset):

    def __len__(self):
        return len(self.data_idx_order_list)
@@ -158,7 +158,7 @@ class ConvBNLayer(nn.Layer):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
-        self.conv = nn.Conv2d(
+        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
@@ -183,7 +183,7 @@ class ConvBNLayer(nn.Layer):
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "hard_swish":
-                x = F.hard_swish(x)
+                x = F.activation.hard_swish(x)
            else:
                print("The activation function is selected incorrectly.")
                exit()
@@ -242,16 +242,15 @@ class ResidualUnit(nn.Layer):
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
-            x = paddle.elementwise_add(inputs, x)
+            x = paddle.add(inputs, x)
        return x


class SEModule(nn.Layer):
    def __init__(self, in_channels, reduction=4, name=""):
        super(SEModule, self).__init__()
-        self.avg_pool = nn.Pool2D(
-            pool_type="avg", global_pooling=True, use_cudnn=False)
-        self.conv1 = nn.Conv2d(
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // reduction,
            kernel_size=1,
@@ -259,7 +258,7 @@ class SEModule(nn.Layer):
            padding=0,
            weight_attr=ParamAttr(name=name + "_1_weights"),
            bias_attr=ParamAttr(name=name + "_1_offset"))
-        self.conv2 = nn.Conv2d(
+        self.conv2 = nn.Conv2D(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
@@ -273,5 +272,5 @@ class SEModule(nn.Layer):
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
-        outputs = F.hard_sigmoid(outputs)
+        outputs = F.activation.hard_sigmoid(outputs)
        return inputs * outputs
\ No newline at end of file
@@ -127,7 +127,7 @@ class MobileNetV3(nn.Layer):
            act='hard_swish',
            name='conv_last')
-        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
...
@@ -33,7 +33,7 @@ def get_bias_attr(k, name):
class Head(nn.Layer):
    def __init__(self, in_channels, name_list):
        super(Head, self).__init__()
-        self.conv1 = nn.Conv2d(
+        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // 4,
            kernel_size=3,
@@ -51,14 +51,14 @@ class Head(nn.Layer):
            moving_mean_name=name_list[1] + '.w_1',
            moving_variance_name=name_list[1] + '.w_2',
            act='relu')
-        self.conv2 = nn.ConvTranspose2d(
+        self.conv2 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
            out_channels=in_channels // 4,
            kernel_size=2,
            stride=2,
            weight_attr=ParamAttr(
                name=name_list[2] + '.w_0',
-                initializer=paddle.nn.initializer.MSRA(uniform=False)),
+                initializer=paddle.nn.initializer.KaimingNormal()),
            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
        self.conv_bn2 = nn.BatchNorm(
            num_channels=in_channels // 4,
@@ -71,14 +71,14 @@ class Head(nn.Layer):
            moving_mean_name=name_list[3] + '.w_1',
            moving_variance_name=name_list[3] + '.w_2',
            act="relu")
-        self.conv3 = nn.ConvTranspose2d(
+        self.conv3 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
            out_channels=1,
            kernel_size=2,
            stride=2,
            weight_attr=ParamAttr(
                name=name_list[4] + '.w_0',
-                initializer=paddle.nn.initializer.MSRA(uniform=False)),
+                initializer=paddle.nn.initializer.KaimingNormal()),
            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
        )
...
@@ -26,37 +26,37 @@ class DBFPN(nn.Layer):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(DBFPN, self).__init__()
        self.out_channels = out_channels
-        weight_attr = paddle.nn.initializer.MSRA(uniform=False)
+        weight_attr = paddle.nn.initializer.KaimingNormal()

-        self.in2_conv = nn.Conv2d(
+        self.in2_conv = nn.Conv2D(
            in_channels=in_channels[0],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(
                name='conv2d_51.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.in3_conv = nn.Conv2d(
+        self.in3_conv = nn.Conv2D(
            in_channels=in_channels[1],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(
                name='conv2d_50.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.in4_conv = nn.Conv2d(
+        self.in4_conv = nn.Conv2D(
            in_channels=in_channels[2],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(
                name='conv2d_49.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.in5_conv = nn.Conv2d(
+        self.in5_conv = nn.Conv2D(
            in_channels=in_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(
                name='conv2d_48.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.p5_conv = nn.Conv2d(
+        self.p5_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
@@ -64,7 +64,7 @@ class DBFPN(nn.Layer):
            weight_attr=ParamAttr(
                name='conv2d_52.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.p4_conv = nn.Conv2d(
+        self.p4_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
@@ -72,7 +72,7 @@ class DBFPN(nn.Layer):
            weight_attr=ParamAttr(
                name='conv2d_53.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.p3_conv = nn.Conv2d(
+        self.p3_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
@@ -80,7 +80,7 @@ class DBFPN(nn.Layer):
            weight_attr=ParamAttr(
                name='conv2d_54.w_0', initializer=weight_attr),
            bias_attr=False)
-        self.p2_conv = nn.Conv2d(
+        self.p2_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
@@ -97,17 +97,17 @@ class DBFPN(nn.Layer):
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)

-        out4 = in4 + F.resize_nearest(in5, scale=2)  # 1/16
-        out3 = in3 + F.resize_nearest(out4, scale=2)  # 1/8
-        out2 = in2 + F.resize_nearest(out3, scale=2)  # 1/4
+        out4 = in4 + F.upsample(in5, scale_factor=2, mode="nearest")  # 1/16
+        out3 = in3 + F.upsample(out4, scale_factor=2, mode="nearest")  # 1/8
+        out2 = in2 + F.upsample(out3, scale_factor=2, mode="nearest")  # 1/4

        p5 = self.p5_conv(in5)
        p4 = self.p4_conv(out4)
        p3 = self.p3_conv(out3)
        p2 = self.p2_conv(out2)
-        p5 = F.resize_nearest(p5, scale=8)
-        p4 = F.resize_nearest(p4, scale=4)
-        p3 = F.resize_nearest(p3, scale=2)
+        p5 = F.upsample(p5, scale_factor=8, mode="nearest")
+        p4 = F.upsample(p4, scale_factor=4, mode="nearest")
+        p3 = F.upsample(p3, scale_factor=2, mode="nearest")

        fuse = paddle.concat([p5, p4, p3, p2], axis=1)
        return fuse
@@ -50,9 +50,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
    # step3 build optimizer
    optim_name = config.pop('name')
-    # Regularization is invalid. The bug will be fixed in paddle-rc. The param is
-    # weight_decay.
    optim = getattr(optimizer, optim_name)(learning_rate=lr,
-                                           regularization=reg,
+                                           weight_decay=reg,
                                           **config)
    return optim(parameters), lr
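The deleted comment above referred to the regularizer keyword; as a hedged illustration (not from the repo, with a toy model standing in for the OCR network), the 2.0-rc optimizer API takes the regularizer as weight_decay:

    import paddle

    model = paddle.nn.Linear(10, 10)
    optim = paddle.optimizer.Adam(
        learning_rate=0.001,
        weight_decay=paddle.regularizer.L2Decay(1e-4),  # previously passed as regularization=...
        parameters=model.parameters())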
@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

-from paddle.optimizer import lr_scheduler
+from paddle.optimizer import lr as lr_scheduler


class Linear(object):
...
@@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
...
@@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger):
            raise OSError('Failed to mkdir {}'.format(path))


-def load_dygraph_pretrain(
-        model,
-        logger,
-        path=None,
-        load_static_weights=False):
+def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    if load_static_weights:
-        pre_state_dict = paddle.io.load_program_state(path)
+        pre_state_dict = paddle.static.load_program_state(path)
        param_state_dict = {}
        model_dict = model.state_dict()
        for key in model_dict.keys():
@@ -113,15 +109,11 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
        if not isinstance(pretrained_model, list):
            pretrained_model = [pretrained_model]
        if not isinstance(load_static_weights, list):
-            load_static_weights = [load_static_weights] * len(
-                pretrained_model)
+            load_static_weights = [load_static_weights] * len(pretrained_model)
        for idx, pretrained in enumerate(pretrained_model):
            load_static = load_static_weights[idx]
-            load_dygraph_pretrain(
-                model,
-                logger,
-                path=pretrained,
-                load_static_weights=load_static)
+            load_dygraph_pretrain(
+                model, logger, path=pretrained, load_static_weights=load_static)
        logger.info("load pretrained model from {}".format(
            pretrained_model))
    else:
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from paddle.jit import to_static
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from tools.program import load_config
from tools.program import merge_config
def parse_args():
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="configuration file to use")
    parser.add_argument(
        "-o", "--output_path", type=str, default='./output/infer/')
    return parser.parse_args()


class Model(paddle.nn.Layer):
    def __init__(self, model):
        super(Model, self).__init__()
        self.pre_model = model

    # Please modify the 'shape' according to actual needs
    @to_static(input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 32, None], dtype='float32')
    ])
    def forward(self, inputs):
        x = self.pre_model(inputs)
        return x


def main():
    FLAGS = parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            config['Global'])
    # build model
    #for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])
    init_model(config, model, logger)
    model.eval()

    model = Model(model)
    paddle.jit.save(model, FLAGS.output_path)


if __name__ == "__main__":
    main()
@@ -33,6 +33,7 @@ from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np


class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
@@ -185,7 +186,7 @@ def train(config,
    for epoch in range(start_epoch, epoch_num):
        if epoch > 0:
            train_loader = build_dataloader(config, 'Train', device)
        for idx, batch in enumerate(train_dataloader):
            if idx >= len(train_dataloader):
                break
@@ -196,12 +197,7 @@ def train(config,
            preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
-            if config['Global']['distributed']:
-                avg_loss = model.scale_loss(avg_loss)
-                avg_loss.backward()
-                model.apply_collective_grads()
-            else:
-                avg_loss.backward()
+            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            if not isinstance(lr_scheduler, float):
@@ -227,7 +223,8 @@ def train(config,
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
-            if global_step > 0 and global_step % print_batch_step == 0:
+            if dist.get_rank(
+            ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
                    epoch, epoch_num, global_step, logs, train_batch_elapse)
@@ -235,8 +232,8 @@ def train(config,
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
-                cur_metirc = eval(model, valid_dataloader,
-                                  post_process_class, eval_class, logger, print_batch_step)
+                cur_metirc = eval(model, valid_dataloader, post_process_class,
+                                  eval_class, logger, print_batch_step)
                cur_metirc_str = 'cur metirc, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
                logger.info(cur_metirc_str)
@@ -298,18 +295,17 @@ def train(config,
    return


-def eval(model, valid_dataloader,
-         post_process_class, eval_class,
-         logger, print_batch_step):
+def eval(model, valid_dataloader, post_process_class, eval_class, logger,
+         print_batch_step):
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        # pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
        for idx, batch in enumerate(valid_dataloader):
            if idx >= len(valid_dataloader):
                break
-            images = paddle.to_variable(batch[0])
+            images = paddle.to_tensor(batch[0])
            start = time.time()
            preds = model(images)
@@ -319,13 +315,14 @@ def eval(model, valid_dataloader,
            total_time += time.time() - start
            # Evaluate the results of the current batch
            eval_class(post_result, batch)
            # pbar.update(1)
            total_frame += len(images)
-            if idx % print_batch_step == 0:
+            if idx % print_batch_step == 0 and dist.get_rank() == 0:
                logger.info('tackling images for eval: {}/{}'.format(
                    idx, len(valid_dataloader)))
        # Get final metirc,eg. acc or hmean
        metirc = eval_class.get_metric()
    # pbar.close()
    model.train()
    metirc['fps'] = total_frame / total_time
@@ -348,16 +345,15 @@ def preprocess():
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)
    config['Global']['distributed'] = dist.get_world_size() != 1
-    paddle.disable_static(device)

    # save_config
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)
    with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
        yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)

    logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
    if config['Global']['use_visualdl']:
        from visualdl import LogWriter
...
@@ -27,9 +27,8 @@ import yaml
import paddle
import paddle.distributed as dist

-paddle.manual_seed(2)
+paddle.seed(2)

-from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
@@ -49,18 +48,18 @@ def main(config, device, logger, vdl_writer):
    dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
-    train_dataloader = build_dataloader(config, 'Train', device)
+    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if config['Eval']:
-        valid_dataloader = build_dataloader(config, 'Eval', device)
+        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
-    post_process_class = build_post_process(
-        config['PostProcess'], global_config)
+    post_process_class = build_post_process(config['PostProcess'],
+                                            global_config)

    # build model
    #for rec algorithm
    if hasattr(post_process_class, 'character'):
@@ -72,38 +71,29 @@ def main(config, device, logger, vdl_writer):
    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
-    optimizer, lr_scheduler = build_optimizer(config['Optimizer'],
+    optimizer, lr_scheduler = build_optimizer(
+        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])

    # load pretrain model
    pre_best_model_dict = init_model(config, model, logger, optimizer)

    # start train
-    program.train(config,
-                  train_dataloader,
-                  valid_dataloader,
-                  device,
-                  model,
-                  loss_class,
-                  optimizer,
-                  lr_scheduler,
-                  post_process_class,
-                  eval_class,
-                  pre_best_model_dict,
-                  logger,
-                  vdl_writer)
+    program.train(config, train_dataloader, valid_dataloader, device, model,
+                  loss_class, optimizer, lr_scheduler, post_process_class,
+                  eval_class, pre_best_model_dict, logger, vdl_writer)


def test_reader(config, device, logger):
    loader = build_dataloader(config, 'Train', device)
    # loader = build_dataloader(config, 'Eval', device)
    import time
    starttime = time.time()
    count = 0
@@ -113,11 +103,13 @@ def test_reader(config, device, logger):
            if count % 1 == 0:
                batch_time = time.time() - starttime
                starttime = time.time()
-                logger.info("reader: {}, {}, {}".format(count, len(data), batch_time))
+                logger.info("reader: {}, {}, {}".format(count,
                                                         len(data), batch_time))
    except Exception as e:
        logger.info(e)
    logger.info("finish reader: {}, Success!".format(count))


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main(config, device, logger, vdl_writer)
...
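Putting the training-side changes together, the dygraph step that the updated train loop performs reduces to roughly the following self-contained sketch (model, optimizer and batch are toy stand-ins, not the repo's objects):

    import numpy as np
    import paddle

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    batch = [np.random.rand(8, 4).astype('float32')]

    images = paddle.to_tensor(batch[0])   # was: paddle.to_variable(batch[0])
    preds = model(images)
    avg_loss = paddle.mean(preds)         # stand-in for loss_class(preds, batch)['loss']
    avg_loss.backward()                   # scale_loss / apply_collective_grads are no longer needed
    optimizer.step()
    optimizer.clear_grad()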