提交 a0692c17 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/models into parallel_executor

And some fixes:
1. Remove some arguments
2. Rename 'inference' to 'infer'
3. Refine documentation
.DS_Store .DS_Store
*.pyc *.pyc
.*~
...@@ -17,7 +17,7 @@ addons: ...@@ -17,7 +17,7 @@ addons:
- python-pip - python-pip
- python2.7-dev - python2.7-dev
- clang-format-3.8 - clang-format-3.8
ssh_known_hosts: 52.76.173.135 ssh_known_hosts: 13.229.163.131
before_install: before_install:
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- sudo pip install -U virtualenv pre-commit pip - sudo pip install -U virtualenv pre-commit pip
......
...@@ -15,9 +15,7 @@ from multiprocessing import Manager, Process ...@@ -15,9 +15,7 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager from data_utils.util import CriticalException, ForceExitWrapper
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
class SampleInfo(object): class SampleInfo(object):
...@@ -32,11 +30,12 @@ class SampleInfo(object): ...@@ -32,11 +30,12 @@ class SampleInfo(object):
label_bin_path (str): File containing the label data. label_bin_path (str): File containing the label data.
label_size (int): Byte count of the sample's label data. label_size (int): Byte count of the sample's label data.
label_frame_num (int): Label number of the sample. label_frame_num (int): Label number of the sample.
sample_name (str): Key of the sample
""" """
def __init__(self, feature_bin_path, feature_start, feature_size, def __init__(self, feature_bin_path, feature_start, feature_size,
feature_frame_num, feature_dim, label_bin_path, label_start, feature_frame_num, feature_dim, label_bin_path, label_start,
label_size, label_frame_num): label_size, label_frame_num, sample_name):
self.feature_bin_path = feature_bin_path self.feature_bin_path = feature_bin_path
self.feature_start = feature_start self.feature_start = feature_start
self.feature_size = feature_size self.feature_size = feature_size
...@@ -47,6 +46,7 @@ class SampleInfo(object): ...@@ -47,6 +46,7 @@ class SampleInfo(object):
self.label_start = label_start self.label_start = label_start
self.label_size = label_size self.label_size = label_size
self.label_frame_num = label_frame_num self.label_frame_num = label_frame_num
self.sample_name = sample_name
class SampleInfoBucket(object): class SampleInfoBucket(object):
...@@ -69,8 +69,8 @@ class SampleInfoBucket(object): ...@@ -69,8 +69,8 @@ class SampleInfoBucket(object):
split_sentence_threshold(int): Sentence whose length larger than split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation. the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \ (split_sub_sentence_len
rand() % split_perturb). + rand() % split_perturb).
""" """
def __init__(self, def __init__(self,
...@@ -104,24 +104,33 @@ class SampleInfoBucket(object): ...@@ -104,24 +104,33 @@ class SampleInfoBucket(object):
feature_bin_path = self._feature_bin_paths[block_idx] feature_bin_path = self._feature_bin_paths[block_idx]
feature_desc_path = self._feature_desc_paths[block_idx] feature_desc_path = self._feature_desc_paths[block_idx]
label_desc_lines = open(label_desc_path).readlines()
feature_desc_lines = open(feature_desc_path).readlines() feature_desc_lines = open(feature_desc_path).readlines()
sample_num = int(label_desc_lines[0].split()[1]) label_desc_lines = []
assert sample_num == int(feature_desc_lines[0].split()[1]) if label_desc_path != "":
label_desc_lines = open(label_desc_path).readlines()
sample_num = int(feature_desc_lines[0].split()[1])
if label_desc_path != "":
assert sample_num == int(label_desc_lines[0].split()[1])
for i in xrange(sample_num): for i in xrange(sample_num):
feature_desc_split = feature_desc_lines[i + 1].split() feature_desc_split = feature_desc_lines[i + 1].split()
sample_name = feature_desc_split[0]
feature_start = int(feature_desc_split[2]) feature_start = int(feature_desc_split[2])
feature_size = int(feature_desc_split[3]) feature_size = int(feature_desc_split[3])
feature_frame_num = int(feature_desc_split[4]) feature_frame_num = int(feature_desc_split[4])
feature_dim = int(feature_desc_split[5]) feature_dim = int(feature_desc_split[5])
label_desc_split = label_desc_lines[i + 1].split() label_start = -1
label_start = int(label_desc_split[2]) label_size = -1
label_size = int(label_desc_split[3]) label_frame_num = feature_frame_num
label_frame_num = int(label_desc_split[4]) if label_desc_path != "":
assert feature_frame_num == label_frame_num label_desc_split = label_desc_lines[i + 1].split()
label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
assert feature_frame_num == label_frame_num
if self._split_sentence_threshold == -1 or \ if self._split_sentence_threshold == -1 or \
self._split_perturb == -1 or \ self._split_perturb == -1 or \
...@@ -131,7 +140,7 @@ class SampleInfoBucket(object): ...@@ -131,7 +140,7 @@ class SampleInfoBucket(object):
SampleInfo(feature_bin_path, feature_start, SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim, feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size, label_bin_path, label_start, label_size,
label_frame_num)) label_frame_num, sample_name))
#split sentence #split sentence
else: else:
cur_frame_pos = 0 cur_frame_pos = 0
...@@ -152,16 +161,19 @@ class SampleInfoBucket(object): ...@@ -152,16 +161,19 @@ class SampleInfoBucket(object):
* feature_dim * 4, cur_frame_len * feature_dim * * feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path, 4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len * label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len)) 4, cur_frame_len, sample_name))
remain_frame_num -= cur_frame_len remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len cur_frame_pos += cur_frame_len
if remain_frame_num <= 0: if remain_frame_num <= 0:
break break
return sample_info_list return sample_info_list
class EpochEndSignal():
pass
class AsyncDataReader(object): class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including """DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation. data loading and data augmentation.
...@@ -190,7 +202,7 @@ class AsyncDataReader(object): ...@@ -190,7 +202,7 @@ class AsyncDataReader(object):
def __init__(self, def __init__(self,
feature_file_list, feature_file_list,
label_file_list, label_file_list="",
drop_frame_len=512, drop_frame_len=512,
proc_num=10, proc_num=10,
sample_buffer_size=1024, sample_buffer_size=1024,
...@@ -213,25 +225,30 @@ class AsyncDataReader(object): ...@@ -213,25 +225,30 @@ class AsyncDataReader(object):
self._sample_info_buffer_size = sample_info_buffer_size self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False)) self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
block_feature_info_lines = open(self._feature_file_list).readlines() block_feature_info_lines = open(self._feature_file_list).readlines()
block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(block_label_info_lines)
self._block_info_list = [] self._block_info_list = []
for i in xrange(0, len(block_feature_info_lines), 2): if self._label_file_list != "":
block_info = (block_feature_info_lines[i], block_label_info_lines = open(self._label_file_list).readlines()
block_feature_info_lines[i + 1], assert len(block_feature_info_lines) == len(
block_label_info_lines[i], block_label_info_lines)
block_label_info_lines[i + 1]) for i in xrange(0, len(block_feature_info_lines), 2):
self._block_info_list.append( block_info = (block_feature_info_lines[i],
map(lambda line: line.strip(), block_info)) block_feature_info_lines[i + 1],
block_label_info_lines[i],
block_label_info_lines[i + 1])
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
else:
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1], "", "")
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
if is_shuffle: if is_shuffle:
self._rng.shuffle(self._block_info_list) self._rng.shuffle(self._block_info_list)
...@@ -251,23 +268,13 @@ class AsyncDataReader(object): ...@@ -251,23 +268,13 @@ class AsyncDataReader(object):
def set_transformers(self, transformers): def set_transformers(self, transformers):
self._transformers = transformers self._transformers = transformers
def recycle(self, *args): def _sample_generator(self):
for shared_ndarray in args:
if not isinstance(shared_ndarray, SharedNDArray):
raise Value("Only support recycle SharedNDArray object.")
shared_ndarray.recycle(self._pool_manager.pool)
def _start_async_processing(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size) sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size) sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0 self._order_id = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue): def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
for sample_info_bucket in self._bucket_list: for sample_info_bucket in self._bucket_list:
try: try:
sample_info_list = \ sample_info_list = \
...@@ -280,12 +287,13 @@ class AsyncDataReader(object): ...@@ -280,12 +287,13 @@ class AsyncDataReader(object):
sample_info_queue.put((sample_info, self._order_id)) sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1 self._order_id += 1
for i in xrange(self._sample_proc_num): for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal()) sample_info_queue.put(EpochEndSignal())
feeding_proc = DaemonProcessGroup( feeding_thread = Thread(
proc_num=1, target=ordered_feeding_task, args=(sample_info_queue, )) target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_proc.start_all() feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order): def ordered_processing_task(sample_info_queue, sample_queue, out_order):
...@@ -313,25 +321,32 @@ class AsyncDataReader(object): ...@@ -313,25 +321,32 @@ class AsyncDataReader(object):
sample_info.feature_size) sample_info.feature_size)
assert sample_info.feature_frame_num \ assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \ * sample_info.feature_dim * 4 \
(sample_info.feature_bin_path, == len(feature_bytes), \
sample_info.feature_frame_num, (sample_info.feature_bin_path,
sample_info.feature_dim, sample_info.feature_frame_num,
len(feature_bytes)) sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start, label_data = None
sample_info.label_size) if sample_info.label_bin_path != "":
label_bytes = read_bytes(sample_info.label_bin_path,
assert sample_info.label_frame_num * 4 == len(label_bytes), ( sample_info.label_start,
sample_info.label_bin_path, sample_info.label_array, sample_info.label_size)
len(label_bytes))
assert sample_info.label_frame_num * 4 == len(
label_array = struct.unpack('I' * sample_info.label_frame_num, label_bytes), (sample_info.label_bin_path,
label_bytes) sample_info.label_array,
label_data = np.array( len(label_bytes))
label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1)) label_array = struct.unpack(
'I' * sample_info.label_frame_num, label_bytes)
label_data = np.array(
label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1))
else:
label_data = np.zeros(
(sample_info.label_frame_num, 1), dtype='int64')
feature_frame_num = sample_info.feature_frame_num feature_frame_num = sample_info.feature_frame_num
feature_dim = sample_info.feature_dim feature_dim = sample_info.feature_dim
...@@ -341,12 +356,11 @@ class AsyncDataReader(object): ...@@ -341,12 +356,11 @@ class AsyncDataReader(object):
feature_data = np.array( feature_data = np.array(
feature_array, dtype='float32').reshape(( feature_array, dtype='float32').reshape((
sample_info.feature_frame_num, sample_info.feature_dim)) sample_info.feature_frame_num, sample_info.feature_dim))
sample_data = (feature_data, label_data,
sample_data = (feature_data, label_data) sample_info.sample_name)
for transformer in self._transformers: for transformer in self._transformers:
# @TODO(pkuyym) to make transformer only accept feature_data # @TODO(pkuyym) to make transformer only accept feature_data
sample_data = transformer.perform_trans(sample_data) sample_data = transformer.perform_trans(sample_data)
while order_id != out_order[0]: while order_id != out_order[0]:
time.sleep(0.001) time.sleep(0.001)
...@@ -362,74 +376,77 @@ class AsyncDataReader(object): ...@@ -362,74 +376,77 @@ class AsyncDataReader(object):
out_order = self._manager.list([0]) out_order = self._manager.list([0])
args = (sample_info_queue, sample_queue, out_order) args = (sample_info_queue, sample_queue, out_order)
sample_proc = DaemonProcessGroup( workers = [
proc_num=self._sample_proc_num, Process(
target=ordered_processing_task, target=ordered_processing_task, args=args)
args=args) for _ in xrange(self._proc_num)
sample_proc.start_all() ]
return sample_queue for w in workers:
w.daemon = True
w.start()
def batch_iterator(self, batch_size, minimum_batch_size): finished_proc_num = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool): while self._force_exit == False:
def conv_to_shared(ndarray): try:
while self._force_exit == False: sample = sample_queue.get_nowait()
try: except Queue.Empty:
(name, shared_ndarray) = pool.popitem() time.sleep(0.001)
except Exception as e: else:
time.sleep(0.001) if isinstance(sample, EpochEndSignal):
finished_proc_num += 1
if finished_proc_num >= self._proc_num:
break
else: else:
shared_ndarray.copy(ndarray) continue
return shared_ndarray
if self._verbose == 0: yield sample
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal) def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
name_lst = []
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
name_lst.append(sample[2])
return (batch_feature, batch_label, name_lst)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = [] batch_samples = []
lod = [0] lod = [0]
done_num = 0 for sample in sample_generator():
while done_num < self._sample_proc_num: batch_samples.append(sample)
sample = sample_queue.get() lod.append(lod[-1] + sample[0].shape[0])
if isinstance(sample, EpochEndSignal): if len(batch_samples) == batch_size:
done_num += 1 (batch_feature, batch_label, name_lst) = batch_to_ndarray(
else: batch_samples, lod)
batch_samples.append(sample) batch_queue.put((batch_feature, batch_label, lod, name_lst))
lod.append(lod[-1] + sample[0].shape[0]) batch_samples = []
if len(batch_samples) == batch_size: lod = [0]
feature, label = batch_to_ndarray(batch_samples, lod)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod))
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size: if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod) (batch_feature, batch_label, name_lst) = batch_to_ndarray(
batch_samples, lod)
feature = conv_to_shared(feature) batch_queue.put((batch_feature, batch_label, lod, name_lst))
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod))
batch_queue.put(EpochEndSignal()) batch_queue.put(EpochEndSignal())
sample_queue = self._start_async_processing() batch_queue = Queue.Queue(self._batch_buffer_size)
batch_queue = self._manager.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size * assembling_thread = Thread(
3, self._manager)
assembling_proc = DaemonProcessGroup(
proc_num=1,
target=batch_assembling_task, target=batch_assembling_task,
args=(sample_queue, batch_queue, self._pool_manager.pool)) args=(self._sample_generator, batch_queue))
assembling_proc.start_all() assembling_thread.daemon = True
assembling_thread.start()
while self._force_exit == False: while self._force_exit == False:
try: try:
...@@ -440,6 +457,3 @@ class AsyncDataReader(object): ...@@ -440,6 +457,3 @@ class AsyncDataReader(object):
if isinstance(batch_data, EpochEndSignal): if isinstance(batch_data, EpochEndSignal):
break break
yield batch_data yield batch_data
# clean the shared memory
del self._pool_manager
...@@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase):
feature = np.zeros((2, 120), dtype="float32") feature = np.zeros((2, 120), dtype="float32")
feature.fill(1) feature.fill(1)
trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path) trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path)
(feature1, label1) = trans.perform_trans((feature, None)) (feature1, label1, name) = trans.perform_trans((feature, None, None))
(mean, var) = trans.get_mean_var() (mean, var) = trans.get_mean_var()
feature_flat1 = feature1.flatten() feature_flat1 = feature1.flatten()
feature_flat = feature.flatten() feature_flat = feature.flatten()
...@@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase):
feature[2, 0:40].fill(3) feature[2, 0:40].fill(3)
feature[3, 0:40].fill(4) feature[3, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta() trans = trans_add_delta.TransAddDelta()
(feature, label) = trans.perform_trans((feature, None)) (feature, label, name) = trans.perform_trans((feature, None, None))
self.assertAlmostEqual(feature.shape[0], 4) self.assertAlmostEqual(feature.shape[0], 4)
self.assertAlmostEqual(feature.shape[1], 120) self.assertAlmostEqual(feature.shape[1], 120)
self.assertAlmostEqual(1.0, feature[0][0]) self.assertAlmostEqual(1.0, feature[0][0])
...@@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase): ...@@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase):
feature[i, :].fill(i) feature[i, :].fill(i)
trans = trans_splice.TransSplice() trans = trans_splice.TransSplice()
(feature, label) = trans.perform_trans((feature, None)) (feature, label, name) = trans.perform_trans((feature, None, None))
self.assertEqual(feature.shape[1], 110) self.assertEqual(feature.shape[1], 110)
for i in xrange(8): for i in xrange(8):
......
...@@ -32,9 +32,9 @@ class TransAddDelta(object): ...@@ -32,9 +32,9 @@ class TransAddDelta(object):
Args: Args:
sample(object,tuple): contain feature numpy and label numpy sample(object,tuple): contain feature numpy and label numpy
Returns: Returns:
(feature, label) (feature, label, name)
""" """
(feature, label) = sample (feature, label, name) = sample
frame_dim = feature.shape[1] frame_dim = feature.shape[1]
d_frame_dim = frame_dim * 3 d_frame_dim = frame_dim * 3
head_filled = 5 head_filled = 5
...@@ -64,7 +64,7 @@ class TransAddDelta(object): ...@@ -64,7 +64,7 @@ class TransAddDelta(object):
start * d_frame_dim + 2 * frame_dim, frame_dim, nframe, start * d_frame_dim + 2 * frame_dim, frame_dim, nframe,
d_frame_dim) d_frame_dim)
mat.shape = tmp_shape mat.shape = tmp_shape
return (mat[head_filled:mat.shape[0] - tail_filled, :], label) return (mat[head_filled:mat.shape[0] - tail_filled, :], label, name)
def _regress(self, data_in, start_in, data_out, start_out, size, n, step): def _regress(self, data_in, start_in, data_out, start_out, size, n, step):
""" regress """ regress
......
...@@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object): ...@@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object):
Args: Args:
sample(object):input sample, contain feature numpy and label numpy sample(object):input sample, contain feature numpy and label numpy
Returns: Returns:
(feature, label) (feature, label, name)
""" """
(feature, label) = sample (feature, label, name) = sample
shape = feature.shape shape = feature.shape
assert len(shape) == 2 assert len(shape) == 2
nfeature_len = shape[0] * shape[1] nfeature_len = shape[0] * shape[1]
...@@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object): ...@@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object):
feature[ncur_idx:ncur_idx + self._nLen] = block feature[ncur_idx:ncur_idx + self._nLen] = block
ncur_idx += self._nLen ncur_idx += self._nLen
feature = feature.reshape(shape) feature = feature.reshape(shape)
return (feature, label) return (feature, label, name)
...@@ -30,9 +30,9 @@ class TransSplice(object): ...@@ -30,9 +30,9 @@ class TransSplice(object):
Args: Args:
sample(object): input sample(feature, label) sample(object): input sample(feature, label)
Return: Return:
(feature, label) (feature, label, name)
""" """
(feature, label) = sample (feature, label, name) = sample
nframe_num = feature.shape[0] nframe_num = feature.shape[0]
nframe_dim = feature.shape[1] nframe_dim = feature.shape[1]
nnew_frame_dim = nframe_dim * ( nnew_frame_dim = nframe_dim * (
...@@ -61,4 +61,4 @@ class TransSplice(object): ...@@ -61,4 +61,4 @@ class TransSplice(object):
np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim], np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim],
mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim]) mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim])
ret = ret.reshape((nframe_num, nnew_frame_dim)) ret = ret.reshape((nframe_num, nnew_frame_dim))
return (ret, label) return (ret, label, name)
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys, time import sys
from six import reraise from six import reraise
from tblib import Traceback from tblib import Traceback
from multiprocessing import Manager, Process
import posix_ipc, mmap
import numpy as np import numpy as np
...@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor): ...@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod() return ret, lod_tensor.lod()
def batch_to_ndarray(batch_samples, lod):
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
def split_infer_result(infer_seq, lod): def split_infer_result(infer_seq, lod):
infer_batch = [] infer_batch = []
for i in xrange(0, len(lod[0]) - 1): for i in xrange(0, len(lod[0]) - 1):
...@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod): ...@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod):
return infer_batch return infer_batch
class DaemonProcessGroup(object):
def __init__(self, proc_num, target, args):
self._proc_num = proc_num
self._workers = [
Process(
target=target, args=args) for _ in xrange(self._proc_num)
]
def start_all(self):
for w in self._workers:
w.daemon = True
w.start()
@property
def proc_num(self):
return self._proc_num
class EpochEndSignal(object):
pass
class CriticalException(Exception): class CriticalException(Exception):
pass pass
class SharedNDArray(object):
"""SharedNDArray utilizes shared memory to avoid data serialization when
data object shared among different processes. We can reconstruct the
`ndarray` when memory address, shape and dtype provided.
Args:
name (str): Address name of shared memory.
whether_verify (bool): Whether to validate the writing operation.
"""
def __init__(self, name, whether_verify=False):
self._name = name
self._shm = None
self._buf = None
self._array = np.zeros(1, dtype=np.float32)
self._inited = False
self._whether_verify = whether_verify
def zeros_like(self, shape, dtype):
size = int(np.prod(shape)) * np.dtype(dtype).itemsize
if self._inited:
self._shm = posix_ipc.SharedMemory(self._name)
else:
self._shm = posix_ipc.SharedMemory(
self._name, posix_ipc.O_CREAT, size=size)
self._buf = mmap.mmap(self._shm.fd, size)
self._array = np.ndarray(shape, dtype, self._buf, order='C')
def copy(self, ndarray):
size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
self.zeros_like(ndarray.shape, ndarray.dtype)
self._array[:] = ndarray
self._buf.flush()
self._inited = True
if self._whether_verify:
shm = posix_ipc.SharedMemory(self._name)
buf = mmap.mmap(shm.fd, size)
array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
np.testing.assert_array_equal(array, ndarray)
@property
def ndarray(self):
return self._array
def recycle(self, pool):
self._buf.close()
self._shm.close_fd()
self._inited = False
pool[self._name] = self
def __getstate__(self):
return (self._name, self._array.shape, self._array.dtype, self._inited,
self._whether_verify)
def __setstate__(self, state):
self._name = state[0]
self._inited = state[3]
self.zeros_like(state[1], state[2])
self._whether_verify = state[4]
class SharedMemoryPoolManager(object):
"""SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
All available addresses are allocated once and will be reused. Though this
class is not process-safe, the pool can be shared between processes. All
shared memory should be unlinked before the main process exited.
Args:
pool_size (int): Size of shared memory pool.
manager (dict): A multiprocessing.Manager object, the pool is
maintained by the proxy process.
name_prefix (str): Address prefix of shared memory.
"""
def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
self._names = []
self._dict = manager.dict()
self._time_prefix = time.strftime('%Y%m%d%H%M%S')
for i in xrange(pool_size):
name = name_prefix + '_' + self._time_prefix + '_' + str(i)
self._dict[name] = SharedNDArray(name)
self._names.append(name)
@property
def pool(self):
return self._dict
def __del__(self):
for name in self._names:
# have to unlink the shared memory
posix_ipc.unlink_shared_memory(name)
def suppress_signal(signo, stack_frame): def suppress_signal(signo, stack_frame):
pass pass
......
...@@ -21,14 +21,15 @@ using fst::StdArc; ...@@ -21,14 +21,15 @@ using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename, Decoder::Decoder(std::string word_syms_filename,
std::string fst_in_filename, std::string fst_in_filename,
std::string logprior_rxfilename) { std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale) {
const char* usage = const char* usage =
"Decode, reading log-likelihoods (of transition-ids or whatever symbol " "Decode, reading log-likelihoods (of transition-ids or whatever symbol "
"is on the graph) as matrices."; "is on the graph) as matrices.";
kaldi::ParseOptions po(usage); kaldi::ParseOptions po(usage);
binary = true; binary = true;
acoustic_scale = 1.5; this->acoustic_scale = acoustic_scale;
allow_partial = true; allow_partial = true;
kaldi::FasterDecoderOptions decoder_opts; kaldi::FasterDecoderOptions decoder_opts;
decoder_opts.Register(&po, true); // true == include obscure settings. decoder_opts.Register(&po, true); // true == include obscure settings.
......
...@@ -29,7 +29,8 @@ class Decoder { ...@@ -29,7 +29,8 @@ class Decoder {
public: public:
Decoder(std::string word_syms_filename, Decoder(std::string word_syms_filename,
std::string fst_in_filename, std::string fst_in_filename,
std::string logprior_rxfilename); std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale);
~Decoder(); ~Decoder();
// Interface to accept the scores read from specifier and return // Interface to accept the scores read from specifier and return
......
...@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) { ...@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) {
m.doc() = "Decoder for Deep ASR model"; m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder") py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string>()) .def(py::init<std::string, std::string, std::string, kaldi::BaseFloat>())
.def("decode", .def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) & (std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode, Decoder::decode,
......
...@@ -8,7 +8,7 @@ import paddle.fluid as fluid ...@@ -8,7 +8,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray from data_utils.util import lodtensor_to_ndarray
from data_utils.util import split_infer_result from data_utils.util import split_infer_result
...@@ -79,12 +79,13 @@ def infer(args): ...@@ -79,12 +79,13 @@ def infer(args):
trans_splice.TransSplice() trans_splice.TransSplice()
] ]
infer_data_reader = reader.DataReader(args.infer_feature_lst, infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst) args.infer_label_lst)
infer_data_reader.set_transformers(ltrans) infer_data_reader.set_transformers(ltrans)
feature_t = fluid.LoDTensor() feature_t = fluid.LoDTensor()
one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next() one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next()
(features, labels, lod) = one_batch (features, labels, lod) = one_batch
feature_t.set(features, place) feature_t.set(features, place)
feature_t.set_lod([lod]) feature_t.set_lod([lod])
......
...@@ -106,6 +106,11 @@ def parse_args(): ...@@ -106,6 +106,11 @@ def parse_args():
type=str, type=str,
default="./decoder/logprior", default="./decoder/logprior",
help="The log prior probs for training data. (default: %(default)s)") help="The log prior probs for training data. (default: %(default)s)")
parser.add_argument(
'--acoustic_scale',
type=float,
default=0.2,
help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -143,6 +148,10 @@ def infer_from_ckpt(args): ...@@ -143,6 +148,10 @@ def infer_from_ckpt(args):
# load checkpoint. # load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint) fluid.io.load_persistables(exe, args.checkpoint)
# init decoder
decoder = Decoder(args.vocabulary, args.graphs, args.log_prior,
args.acoustic_scale)
ltrans = [ ltrans = [
trans_add_delta.TransAddDelta(2, 2), trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var), trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
...@@ -162,12 +171,10 @@ def infer_from_ckpt(args): ...@@ -162,12 +171,10 @@ def infer_from_ckpt(args):
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
infer_data_reader.recycle(features, labels, lod)
results = exe.run(infer_program, results = exe.run(infer_program,
feed={"feature": feature_t, feed={"feature": feature_t,
...@@ -181,7 +188,7 @@ def infer_from_ckpt(args): ...@@ -181,7 +188,7 @@ def infer_from_ckpt(args):
infer_batch = split_infer_result(probs, lod) infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch): for index, sample in enumerate(infer_batch):
key = "utter#%d" % (batch_id * args.batch_size + index) key = "utter#%d" % (batch_id * args.batch_size + index)
print(key, ": ", decoder.decode(key, sample), "\n") print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
print(np.mean(infer_costs), np.mean(infer_accs)) print(np.mean(infer_costs), np.mean(infer_accs))
......
...@@ -168,15 +168,13 @@ def profile(args): ...@@ -168,15 +168,13 @@ def profile(args):
start_time = time.time() start_time = time.time()
frames_seen = 0 frames_seen = 0
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod, _) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
frames_seen += lod.ndarray[-1] frames_seen += lod[-1]
data_reader.recycle(features, labels, lod)
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t, feed={"feature": feature_t,
......
...@@ -192,13 +192,11 @@ def train(args): ...@@ -192,13 +192,11 @@ def train(args):
test_data_reader.batch_iterator(args.batch_size, test_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod, _) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
test_data_reader.recycle(features, labels, lod)
cost, acc = exe.run(test_program, cost, acc = exe.run(test_program,
feed={"feature": feature_t, feed={"feature": feature_t,
...@@ -212,6 +210,7 @@ def train(args): ...@@ -212,6 +210,7 @@ def train(args):
# train data reader # train data reader
train_data_reader = reader.AsyncDataReader(args.train_feature_lst, train_data_reader = reader.AsyncDataReader(args.train_feature_lst,
args.train_label_lst, -1) args.train_label_lst, -1)
train_data_reader.set_transformers(ltrans) train_data_reader.set_transformers(ltrans)
# train # train
for pass_id in xrange(args.pass_num): for pass_id in xrange(args.pass_num):
...@@ -220,13 +219,11 @@ def train(args): ...@@ -220,13 +219,11 @@ def train(args):
train_data_reader.batch_iterator(args.batch_size, train_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod, name_lst) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
train_data_reader.recycle(features, labels, lod)
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0) to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
......
...@@ -4,10 +4,109 @@ The minimum PaddlePaddle version needed for the code sample in this directory is ...@@ -4,10 +4,109 @@ The minimum PaddlePaddle version needed for the code sample in this directory is
# Advbox # Advbox
Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. Advbox is a toolbox to generate adversarial examples that fool neural networks and Advbox can benchmark the robustness of machine learning models.
## How to use The Advbox is based on [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) Fluid and is under continual development, always welcoming contributions of the latest method of adversarial attacks and defenses.
1. train a model and save it's parameters. (like fluid_mnist.py)
2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py) ## Overview
3. use advbox to generate the adversarial sample. [Szegedy et al.](https://arxiv.org/abs/1312.6199) were the first to discover intriguing properties of deep neural networks in the context of image classification. They showed that even state-of-the-art deep networks are surprisingly susceptible to adversarial attacks in the form of small perturbations to images that remain (almost) imperceptible to the human visual system. These perturbations are found by optimizing the input to maximize the prediction error, and the images modified by these perturbations are called `adversarial examples`. The profound implications of these results triggered wide interest among researchers in adversarial attacks and their defenses for deep learning in general.
Advbox is similar to [Foolbox](https://github.com/bethgelab/foolbox) and [CleverHans](https://github.com/tensorflow/cleverhans). CleverHans only supports TensorFlow framework while foolbox interfaces with many popular machine learning frameworks such as PyTorch, Keras, TensorFlow, Theano, Lasagne and MXNet. However, these two great libraries don't support PaddlePaddle, an easy-to-use, efficient, flexible and scalable deep learning platform which is originally developed by Baidu scientists and engineers for the purpose of applying deep learning to many products at Baidu.
## Usage
Advbox provides many stable reference implementations of modern methods to generate adversarial examples, such as FGSM, DeepFool and JSMA. When you want to benchmark the robustness of your neural networks, you can use Advbox to generate some adversarial examples and benchmark the networks. Some tips for using Advbox:
1. Train a model and save the parameters.
2. Load the parameters which have been trained, then reconstruct the model.
3. Use advbox to generate the adversarial samples.
#### Dependencies
* PaddlePaddle: [the latest develop branch](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html)
* Python 2.x
#### Structure
Network models, attack method implementations and the criterion that defines adversarial examples are the three essential elements for generating adversarial examples. Misclassification is adopted as the adversarial criterion in Advbox for brevity.
The structure of the Advbox module is as follows:
.
├── advbox
| ├── __init__.py
| ├── attacks
| ├── __init__.py
| ├── base.py
| ├── deepfool.py
| ├── gradient_method.py
| ├── lbfgs.py
| └── saliency.py
| ├── models
| ├── __init__.py
| ├── base.py
| └── paddle.py
| └── adversary.py
├── tutorials
| ├── __init__.py
| ├── mnist_model.py
| ├── mnist_tutorial_lbfgs.py
| ├── mnist_tutorial_fgsm.py
| ├── mnist_tutorial_bim.py
| ├── mnist_tutorial_ilcm.py
| ├── mnist_tutorial_mifgsm.py
| ├── mnist_tutorial_jsma.py
| └── mnist_tutorial_deepfool.py
└── README.md
**advbox.attacks**
Advbox implements several popular adversarial attacks which search for adversarial examples. Each attack method uses a distance measure (L1, L2, etc.) to quantify the size of adversarial perturbations. Advbox makes it easy to craft adversarial examples, as some attack methods can perform internal hyperparameter tuning to find the minimum perturbation.
**advbox.models**
Advbox implements interfaces to PaddlePaddle. Additionally, other deep learning frameworks such as TensorFlow can also be defined and employed. The module is used to compute predictions and gradients for given inputs in a specific framework.
**advbox.adversary**
Adversary contains the original object, the target and the adversarial examples. It uses misclassification as the criterion for accepting an adversarial example.
## Tutorials
The `./tutorials/` folder provides some tutorials to generate adversarial examples on the MNIST dataset. You can slightly modify the code to apply it to other datasets. These attack methods are supported in Advbox:
* [L-BFGS](https://arxiv.org/abs/1312.6199)
* [FGSM](https://arxiv.org/abs/1412.6572)
* [BIM](https://arxiv.org/abs/1607.02533)
* [ILCM](https://arxiv.org/abs/1607.02533)
* [MI-FGSM](https://arxiv.org/pdf/1710.06081.pdf)
* [JSMA](https://arxiv.org/pdf/1511.07528)
* [DeepFool](https://arxiv.org/abs/1511.04599)
## Testing
Benchmarks on a vanilla CNN model.
> MNIST
| adversarial attacks | fooling rate (non-targeted) | fooling rate (targeted) | max_epsilon | iterations | Strength |
|:-----:| :----: | :---: | :----: | :----: | :----: |
|L-BFGS| --- | 89.2% | --- | One shot | *** |
|FGSM| 57.8% | 26.55% | 0.3 | One shot| *** |
|BIM| 97.4% | --- | 0.1 | 100 | **** |
|ILCM| --- | 100.0% | 0.1 | 100 | **** |
|MI-FGSM| 94.4% | 100.0% | 0.1 | 100 | **** |
|JSMA| 96.8% | 90.4%| 0.1 | 2000 | *** |
|DeepFool| 97.7% | 51.3% | --- | 100 | **** |
* The strength (higher for more asterisks) is based on the impression from the reviewed literature.
---
## References
* [Intriguing properties of neural networks](https://arxiv.org/abs/1312.6199), C. Szegedy et al., arxiv 2014
* [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572), I. Goodfellow et al., ICLR 2015
* [Adversarial Examples In The Physical World](https://arxiv.org/pdf/1607.02533v3.pdf), A. Kurakin et al., ICLR workshop 2017
* [Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081), Yinpeng Dong et al., arxiv 2018
* [The Limitations of Deep Learning in Adversarial Settings](https://arxiv.org/abs/1511.07528), N. Papernot et al., EuroS&P 2016
* [DeepFool: a simple and accurate method to fool deep neural networks](https://arxiv.org/abs/1511.04599), S. Moosavi-Dezfooli et al., CVPR 2016
* [Foolbox: A Python toolbox to benchmark the robustness of machine learning models](https://arxiv.org/abs/1707.04131), Jonas Rauber et al., arxiv 2018
* [CleverHans: An adversarial example library for constructing attacks, building defenses, and benchmarking both](https://github.com/tensorflow/cleverhans#setting-up-cleverhans)
* [Threat of Adversarial Attacks on Deep Learning in Computer Vision: A Survey](https://arxiv.org/abs/1801.00553), Naveed Akhtar, Ajmal Mian, arxiv 2018
...@@ -14,7 +14,8 @@ __all__ = [ ...@@ -14,7 +14,8 @@ __all__ = [
'GradientMethodAttack', 'FastGradientSignMethodAttack', 'FGSM', 'GradientMethodAttack', 'FastGradientSignMethodAttack', 'FGSM',
'FastGradientSignMethodTargetedAttack', 'FGSMT', 'FastGradientSignMethodTargetedAttack', 'FGSMT',
'BasicIterativeMethodAttack', 'BIM', 'BasicIterativeMethodAttack', 'BIM',
'IterativeLeastLikelyClassMethodAttack', 'ILCM' 'IterativeLeastLikelyClassMethodAttack', 'ILCM', 'MomentumIteratorAttack',
'MIFGSM'
] ]
...@@ -32,7 +33,12 @@ class GradientMethodAttack(Attack): ...@@ -32,7 +33,12 @@ class GradientMethodAttack(Attack):
super(GradientMethodAttack, self).__init__(model) super(GradientMethodAttack, self).__init__(model)
self.support_targeted = support_targeted self.support_targeted = support_targeted
def _apply(self, adversary, norm_ord=np.inf, epsilons=0.01, steps=100): def _apply(self,
adversary,
norm_ord=np.inf,
epsilons=0.01,
steps=1,
epsilon_steps=100):
""" """
Apply the gradient attack method. Apply the gradient attack method.
:param adversary(Adversary): :param adversary(Adversary):
...@@ -41,8 +47,11 @@ class GradientMethodAttack(Attack): ...@@ -41,8 +47,11 @@ class GradientMethodAttack(Attack):
Order of the norm, such as np.inf, 1, 2, etc. It can't be 0. Order of the norm, such as np.inf, 1, 2, etc. It can't be 0.
:param epsilons(list|tuple|int): :param epsilons(list|tuple|int):
Attack step size (input variation). Attack step size (input variation).
Largest step size if epsilons is not iterable.
:param steps: :param steps:
The number of iterator steps. The number of attack iteration.
:param epsilon_steps:
The number of Epsilons' iteration for each attack iteration.
:return: :return:
adversary(Adversary): The Adversary object. adversary(Adversary): The Adversary object.
""" """
...@@ -55,7 +64,7 @@ class GradientMethodAttack(Attack): ...@@ -55,7 +64,7 @@ class GradientMethodAttack(Attack):
"This attack method doesn't support targeted attack!") "This attack method doesn't support targeted attack!")
if not isinstance(epsilons, Iterable): if not isinstance(epsilons, Iterable):
epsilons = np.linspace(epsilons, epsilons + 1e-10, num=steps) epsilons = np.linspace(0, epsilons, num=epsilon_steps)
pre_label = adversary.original_label pre_label = adversary.original_label
min_, max_ = self.model.bounds() min_, max_ = self.model.bounds()
...@@ -65,30 +74,33 @@ class GradientMethodAttack(Attack): ...@@ -65,30 +74,33 @@ class GradientMethodAttack(Attack):
self.model.channel_axis() == adversary.original.shape[0] or self.model.channel_axis() == adversary.original.shape[0] or
self.model.channel_axis() == adversary.original.shape[-1]) self.model.channel_axis() == adversary.original.shape[-1])
step = 1 for epsilon in epsilons[:]:
adv_img = adversary.original step = 1
for epsilon in epsilons[:steps]: adv_img = adversary.original
if epsilon == 0.0: if epsilon == 0.0:
continue continue
if adversary.is_targeted_attack: for i in range(steps):
gradient = -self.model.gradient(adv_img, adversary.target_label) if adversary.is_targeted_attack:
else: gradient = -self.model.gradient(adv_img,
gradient = self.model.gradient(adv_img, adversary.target_label)
adversary.original_label) else:
if norm_ord == np.inf: gradient = self.model.gradient(adv_img,
gradient_norm = np.sign(gradient) adversary.original_label)
else: if norm_ord == np.inf:
gradient_norm = gradient / self._norm(gradient, ord=norm_ord) gradient_norm = np.sign(gradient)
else:
adv_img = adv_img + epsilon * gradient_norm * (max_ - min_) gradient_norm = gradient / self._norm(
adv_img = np.clip(adv_img, min_, max_) gradient, ord=norm_ord)
adv_label = np.argmax(self.model.predict(adv_img))
logging.info('step={}, epsilon = {:.5f}, pre_label = {}, ' adv_img = adv_img + epsilon * gradient_norm * (max_ - min_)
'adv_label={}'.format(step, epsilon, pre_label, adv_img = np.clip(adv_img, min_, max_)
adv_label)) adv_label = np.argmax(self.model.predict(adv_img))
if adversary.try_accept_the_example(adv_img, adv_label): logging.info('step={}, epsilon = {:.5f}, pre_label = {}, '
return adversary 'adv_label={}'.format(step, epsilon, pre_label,
step += 1 adv_label))
if adversary.try_accept_the_example(adv_img, adv_label):
return adversary
step += 1
return adversary return adversary
@staticmethod @staticmethod
...@@ -113,7 +125,7 @@ class FastGradientSignMethodTargetedAttack(GradientMethodAttack): ...@@ -113,7 +125,7 @@ class FastGradientSignMethodTargetedAttack(GradientMethodAttack):
Paper link: https://arxiv.org/abs/1412.6572 Paper link: https://arxiv.org/abs/1412.6572
""" """
def _apply(self, adversary, epsilons=0.03): def _apply(self, adversary, epsilons=0.01):
return GradientMethodAttack._apply( return GradientMethodAttack._apply(
self, self,
adversary=adversary, adversary=adversary,
...@@ -144,7 +156,7 @@ class IterativeLeastLikelyClassMethodAttack(GradientMethodAttack): ...@@ -144,7 +156,7 @@ class IterativeLeastLikelyClassMethodAttack(GradientMethodAttack):
Paper link: https://arxiv.org/abs/1607.02533 Paper link: https://arxiv.org/abs/1607.02533
""" """
def _apply(self, adversary, epsilons=0.001, steps=1000): def _apply(self, adversary, epsilons=0.01, steps=1000):
return GradientMethodAttack._apply( return GradientMethodAttack._apply(
self, self,
adversary=adversary, adversary=adversary,
...@@ -164,7 +176,103 @@ class BasicIterativeMethodAttack(IterativeLeastLikelyClassMethodAttack): ...@@ -164,7 +176,103 @@ class BasicIterativeMethodAttack(IterativeLeastLikelyClassMethodAttack):
super(BasicIterativeMethodAttack, self).__init__(model, False) super(BasicIterativeMethodAttack, self).__init__(model, False)
class MomentumIteratorAttack(GradientMethodAttack):
    """
    The Momentum Iterative Fast Gradient Sign Method (Dong et al. 2017).
    This method won the first places in NIPS 2017 Non-targeted Adversarial
    Attacks and Targeted Adversarial Attacks. The original paper used
    hard labels for this attack; no label smoothing. inf norm.
    Paper link: https://arxiv.org/pdf/1710.06081.pdf
    """

    def __init__(self, model, support_targeted=True):
        """
        :param model(model): The model to be attacked.
        :param support_targeted(bool): Does this attack method support targeted.
        """
        super(MomentumIteratorAttack, self).__init__(model)
        self.support_targeted = support_targeted

    def _apply(self,
               adversary,
               norm_ord=np.inf,
               epsilons=0.1,
               steps=100,
               epsilon_steps=100,
               decay_factor=1):
        """
        Apply the momentum iterative gradient attack method.

        For every non-zero epsilon the attack restarts from the original
        input with zero momentum and runs up to ``steps`` iterations; it
        returns as soon as the adversary accepts an example.

        :param adversary(Adversary):
            The Adversary object.
        :param norm_ord(int):
            Order of the norm, such as np.inf, 1, 2, etc. It can't be 0.
        :param epsilons(list|tuple|float):
            Attack step size (input variation).
            Largest step size if epsilons is not iterable; in that case
            ``epsilon_steps`` evenly spaced values in [0, epsilons] are tried.
        :param epsilon_steps:
            The number of Epsilons' iteration for each attack iteration.
        :param steps:
            The number of attack iteration.
        :param decay_factor:
            The decay factor for the momentum term.
        :return:
            adversary(Adversary): The Adversary object.
        :raises ValueError:
            If norm_ord is 0, or a targeted attack is requested while
            support_targeted is False.
        """
        if norm_ord == 0:
            raise ValueError("L0 norm is not supported!")

        if not self.support_targeted:
            if adversary.is_targeted_attack:
                raise ValueError(
                    "This attack method doesn't support targeted attack!")

        assert self.model.channel_axis() == adversary.original.ndim
        assert (self.model.channel_axis() == 1 or
                self.model.channel_axis() == adversary.original.shape[0] or
                self.model.channel_axis() == adversary.original.shape[-1])

        if not isinstance(epsilons, Iterable):
            epsilons = np.linspace(0, epsilons, num=epsilon_steps)

        min_, max_ = self.model.bounds()
        pre_label = adversary.original_label

        for epsilon in epsilons[:]:
            # A zero step size cannot change the input.
            if epsilon == 0.0:
                continue

            # Each epsilon gets a fresh run: original input, no momentum.
            step = 1
            adv_img = adversary.original
            momentum = 0
            for i in range(steps):
                if adversary.is_targeted_attack:
                    # Negated so the update moves towards the target label.
                    gradient = -self.model.gradient(adv_img,
                                                    adversary.target_label)
                else:
                    gradient = self.model.gradient(adv_img, pre_label)

                # normalize gradient
                velocity = gradient / self._norm(gradient, ord=1)
                momentum = decay_factor * momentum + velocity
                if norm_ord == np.inf:
                    normalized_grad = np.sign(momentum)
                else:
                    # Divide by the norm so the step is a direction; using
                    # the bare norm here would add one scalar to every
                    # element regardless of the gradient.
                    normalized_grad = momentum / self._norm(
                        momentum, ord=norm_ord)
                perturbation = epsilon * normalized_grad
                adv_img = adv_img + perturbation
                adv_img = np.clip(adv_img, min_, max_)
                adv_label = np.argmax(self.model.predict(adv_img))
                logging.info(
                    'step={}, epsilon = {:.5f}, pre_label = {}, adv_label={}'
                    .format(step, epsilon, pre_label, adv_label))
                if adversary.try_accept_the_example(adv_img, adv_label):
                    return adversary
                step += 1

        return adversary
FGSM = FastGradientSignMethodAttack FGSM = FastGradientSignMethodAttack
FGSMT = FastGradientSignMethodTargetedAttack FGSMT = FastGradientSignMethodTargetedAttack
BIM = BasicIterativeMethodAttack BIM = BasicIterativeMethodAttack
ILCM = IterativeLeastLikelyClassMethodAttack ILCM = IterativeLeastLikelyClassMethodAttack
MIFGSM = MomentumIteratorAttack
"""
FGSM demos on mnist using advbox tool.
"""
import matplotlib.pyplot as plt
import paddle.v2 as paddle
import paddle.fluid as fluid
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.models.paddle import PaddleModel
def cnn_model(img):
    """
    Build the MNIST CNN used by this tutorial.

    Two conv+pool stages (20 filters, then 50 filters; 5x5 kernels,
    2x2 pooling with stride 2, ReLU) feed a 10-way softmax FC layer.

    Args:
        img(Variable): the input image to be recognized
    Returns:
        Variable: the label prediction
    """
    features = img
    # The stages differ only in their filter count, so build them in a loop.
    for num_filters in (20, 50):
        features = fluid.nets.simple_img_conv_pool(
            input=features,
            num_filters=num_filters,
            filter_size=5,
            pool_size=2,
            pool_stride=2,
            act='relu')

    return fluid.layers.fc(input=features, size=10, act='softmax')
def main():
    """
    Advbox demo which demonstrates how to use advbox.

    Builds the MNIST CNN, loads its parameters from "./mnist/", then runs
    the FGSM attack on training samples one at a time and displays the
    first adversarial example that is accepted.
    """
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # One sample per batch: each Adversary wraps a single image/label pair.
    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=500),
        batch_size=BATCH_SIZE)
    feeder = fluid.DataFeeder(
        feed_list=[IMG_NAME, LABEL_NAME],
        place=place,
        program=fluid.default_main_program())

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
                    logits.name, avg_cost.name, (-1, 1))
    att = FGSM(m)

    for data in train_reader():
        # fgsm attack
        adversary = att(Adversary(data[0][0], data[0][1]))
        if adversary.is_successful():
            # Show the generated adversarial example. `adversary.target` is
            # the target of a targeted attack, which this non-targeted demo
            # never sets.
            # NOTE(review): imshow needs 2-D data; confirm the example's
            # shape and reshape to (28, 28) if required.
            plt.imshow(adversary.adversarial_example, cmap='Greys_r')
            plt.show()
            # np.save('adv_img', adversary.adversarial_example)
            break


if __name__ == '__main__':
    main()
"""
FGSM demos on mnist using advbox tool.
"""
import matplotlib.pyplot as plt
import paddle.v2 as paddle
import paddle.fluid as fluid
import numpy as np
from advbox import Adversary
from advbox.attacks.saliency import SaliencyMapAttack
from advbox.models.paddle import PaddleModel
def cnn_model(img):
    """
    Build the MNIST CNN used by this tutorial.

    The network is conv+pool (20 filters), conv+pool (50 filters) and a
    10-way fully connected layer with softmax activation.

    Args:
        img(Variable): the input image to be recognized
    Returns:
        Variable: the label prediction
    """
    # Settings shared by both conv+pool stages.
    stage_args = dict(filter_size=5, pool_size=2, pool_stride=2, act='relu')

    stage_1 = fluid.nets.simple_img_conv_pool(
        input=img, num_filters=20, **stage_args)
    stage_2 = fluid.nets.simple_img_conv_pool(
        input=stage_1, num_filters=50, **stage_args)

    prediction = fluid.layers.fc(input=stage_2, size=10, act='softmax')
    return prediction
def main():
    """
    Advbox demo which demonstrates how to use advbox.

    Builds the MNIST CNN, loads its parameters from "./mnist/", then runs
    SaliencyMapAttack on training samples one at a time, printing a running
    total/success count, and stops after 100 samples.
    """
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # One sample per batch: each Adversary wraps a single image/label pair.
    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=500),
        batch_size=BATCH_SIZE)
    # NOTE(review): `feeder` is not referenced again in this function.
    feeder = fluid.DataFeeder(
        feed_list=[IMG_NAME, LABEL_NAME],
        place=place,
        program=fluid.default_main_program())

    # Load the previously saved parameters from "./mnist/".
    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    # NOTE(review): (-1, 1) is presumably the model's input value bounds;
    # confirm against PaddleModel.
    m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
                    logits.name, avg_cost.name, (-1, 1))
    attack = SaliencyMapAttack(m)
    total_num = 0
    success_num = 0
    for data in train_reader():
        total_num += 1
        # To run a targeted attack, keep the Adversary in a variable and
        # call set_target on it first:
        # adversary.set_target(True, target_label=target_label)
        jsma_attack = attack(Adversary(data[0][0], data[0][1]))
        # The None check guards against the attack returning no adversary.
        if jsma_attack is not None and jsma_attack.is_successful():
            # plt.imshow(jsma_attack.target, cmap='Greys_r')
            # plt.show()
            success_num += 1
            print('original_label=%d, adversary examples label =%d' %
                  (data[0][1], jsma_attack.adversarial_label))
            # np.save('adv_img', jsma_attack.adversarial_example)
        # Running summary, printed for every attacked sample.
        print('total num = %d, success num = %d ' % (total_num, success_num))
        # Stop after 100 samples.
        if total_num == 100:
            break


if __name__ == '__main__':
    main()
"""
A set of tutorials for generating adversarial examples with advbox.
"""
\ No newline at end of file
...@@ -30,8 +30,9 @@ def mnist_cnn_model(img): ...@@ -30,8 +30,9 @@ def mnist_cnn_model(img):
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
act='relu') act='relu')
fc = fluid.layers.fc(input=conv_pool_2, size=50, act='relu')
logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') logits = fluid.layers.fc(input=fc, size=10, act='softmax')
return logits return logits
...@@ -60,7 +61,10 @@ def main(): ...@@ -60,7 +61,10 @@ def main():
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
# use CPU
place = fluid.CPUPlace() place = fluid.CPUPlace()
# use GPU
# place = fluid.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place) feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -74,9 +78,11 @@ def main(): ...@@ -74,9 +78,11 @@ def main():
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[avg_cost, batch_acc, batch_size]) fetch_list=[avg_cost, batch_acc, batch_size])
pass_acc.add(value=acc, weight=b_size) pass_acc.add(value=acc, weight=b_size)
pass_acc_val = pass_acc.eval()[0]
print("pass_id=" + str(pass_id) + " acc=" + str(acc[0]) + print("pass_id=" + str(pass_id) + " acc=" + str(acc[0]) +
" pass_acc=" + str(pass_acc.eval()[0])) " pass_acc=" + str(pass_acc_val))
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: if loss < LOSS_THRESHOLD and pass_acc_val > ACC_THRESHOLD:
# early stop
break break
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc.eval()[ print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc.eval()[
......
"""
BIM tutorial on mnist using advbox tool.
BIM method iteratively take multiple small steps while adjusting the direction after each step.
It only supports non-targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import BIM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def _report_fooling_rate(attack, attack_config, reader, dataset_tag,
                         total_num):
    """
    Attack samples from `reader` one by one and print the fooling rate.

    Prints one line per attacked sample and, once `total_num` samples have
    been processed, a "[<dataset_tag>]" summary line before stopping.
    """
    total_count = 0
    fooling_count = 0
    for data in reader():
        total_count += 1
        adversary = Adversary(data[0][0], data[0][1])

        # BIM non-targeted attack
        adversary = attack(adversary, **attack_config)

        if adversary.is_successful():
            fooling_count += 1
            print(
                'attack success, original_label=%d, adversarial_label=%d, count=%d'
                % (data[0][1], adversary.adversarial_label, total_count))
        else:
            print('attack failed, original_label=%d, count=%d' %
                  (data[0][1], total_count))

        if total_count >= total_num:
            print(
                "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                % (dataset_tag, fooling_count, total_count,
                   float(fooling_count) / total_count))
            break


def main():
    """
    Advbox demo which demonstrates how to use advbox.

    Builds the MNIST CNN, loads its parameters from "./mnist/", then runs
    the non-targeted BIM attack on 500 training samples and 500 test
    samples, printing the fooling rate for each dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    # One sample per batch: each Adversary wraps a single image/label pair.
    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = BIM(m)
    attack_config = {"epsilons": 0.1, "steps": 100}

    # use train data to generate adversarial examples
    _report_fooling_rate(attack, attack_config, train_reader,
                         "TRAIN_DATASET", TOTAL_NUM)

    # use test data to generate adversarial examples
    _report_fooling_rate(attack, attack_config, test_reader, "TEST_DATASET",
                         TOTAL_NUM)
    print("bim attack done")


if __name__ == '__main__':
    main()
"""
DeepFool tutorial on mnist using advbox tool.
Deepfool is a simple and accurate adversarial attack method.
It supports both targeted attack and non-targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.deepfool import DeepFoolAttack
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the DeepFool attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader,
    printing a per-sample result and, once TOTAL_NUM samples of a dataset
    have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = DeepFoolAttack(m)
    attack_config = {"iterations": 100, "overshoot": 9}

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)

            # DeepFool non-targeted attack
            adversary = attack(adversary, **attack_config)

            # DeepFool targeted attack
            # tlabel = 0
            # adversary.set_target(is_targeted_attack=True, target_label=tlabel)
            # adversary = attack(adversary, **attack_config)

            if adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("deepfool attack done")


if __name__ == '__main__':
    main()
"""
FGSM tutorial on mnist using advbox tool.
FGSM method is non-targeted attack while FGSMT is targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.attacks.gradient_method import FGSMT
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the FGSM attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader,
    printing a per-sample result and, once TOTAL_NUM samples of a dataset
    have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = FGSM(m)
    # attack = FGSMT(m)
    attack_config = {"epsilons": 0.3}

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)

            # FGSM non-targeted attack
            adversary = attack(adversary, **attack_config)

            # FGSMT targeted attack
            # tlabel = 0
            # adversary.set_target(is_targeted_attack=True, target_label=tlabel)
            # adversary = attack(adversary, **attack_config)

            if adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("fgsm attack done")


if __name__ == '__main__':
    main()
"""
ILCM tutorial on mnist using advbox tool.
ILCM method extends "BIM" to support targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import ILCM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the ILCM targeted attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader with
    target label 0, printing a per-sample result and, once TOTAL_NUM samples
    of a dataset have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = ILCM(m)
    attack_config = {"epsilons": 0.1, "steps": 100}

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        tlabel = 0
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)
            adversary.set_target(is_targeted_attack=True, target_label=tlabel)

            # ILCM targeted attack
            adversary = attack(adversary, **attack_config)

            if adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("ilcm attack done")


if __name__ == '__main__':
    main()
"""
JSMA tutorial on mnist using advbox tool.
JSMA method supports both targeted attack and non-targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.saliency import JSMA
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the JSMA attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader,
    printing a per-sample result and, once TOTAL_NUM samples of a dataset
    have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = JSMA(m)
    attack_config = {
        "max_iter": 2000,
        "theta": 0.1,
        "max_perturbations_per_pixel": 7
    }

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)

            # JSMA non-targeted attack
            adversary = attack(adversary, **attack_config)

            # JSMA targeted attack
            # tlabel = 0
            # adversary.set_target(is_targeted_attack=True, target_label=tlabel)
            # adversary = attack(adversary, **attack_config)

            # JSMA may return None
            if adversary is not None and adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("jsma attack done")


if __name__ == '__main__':
    main()
"""
LBFGS tutorial on mnist using advbox tool.
LBFGS method only supports targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.lbfgs import LBFGS
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the LBFGS targeted attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader with
    target label 0, printing a per-sample result and, once TOTAL_NUM samples
    of a dataset have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = LBFGS(m)
    attack_config = {"epsilon": 0.001, }

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        tlabel = 0
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)

            # LBFGS targeted attack
            adversary.set_target(is_targeted_attack=True, target_label=tlabel)
            adversary = attack(adversary, **attack_config)

            if adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("lbfgs attack done")


if __name__ == '__main__':
    main()
"""
MIFGSM tutorial on mnist using advbox tool.
MIFGSM is a broad class of momentum iterative gradient-based methods based on FGSM.
It supports non-targeted attack and targeted attack.
"""
import sys
sys.path.append("..")
import matplotlib.pyplot as plt
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import MIFGSM
from advbox.models.paddle import PaddleModel
from tutorials.mnist_model import mnist_cnn_model
def main():
    """
    Advbox demo: run the MIFGSM attack against the MNIST CNN model.

    Builds the network, loads its parameters from "./mnist/", then attacks
    shuffled samples from the train reader and from the test reader,
    printing a per-sample result and, once TOTAL_NUM samples of a dataset
    have been processed, the fooling rate for that dataset.
    """
    TOTAL_NUM = 500
    IMG_NAME = 'img'
    LABEL_NAME = 'label'

    img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
    # gradient should flow
    img.stop_gradient = False
    label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
    logits = mnist_cnn_model(img)
    cost = fluid.layers.cross_entropy(input=logits, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # use CPU
    place = fluid.CPUPlace()
    # use GPU
    # place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)

    BATCH_SIZE = 1
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.test(), buf_size=128 * 10),
        batch_size=BATCH_SIZE)

    fluid.io.load_params(
        exe, "./mnist/", main_program=fluid.default_main_program())

    # advbox demo
    m = PaddleModel(
        fluid.default_main_program(),
        IMG_NAME,
        LABEL_NAME,
        logits.name,
        avg_cost.name, (-1, 1),
        channel_axis=1)
    attack = MIFGSM(m)
    attack_config = {
        "norm_ord": np.inf,
        "epsilons": 0.1,
        "steps": 100,
        "decay_factor": 1
    }

    def run_attack(reader, dataset_tag):
        """Attack samples from `reader` until TOTAL_NUM are processed, then
        print the fooling rate labelled with `dataset_tag`."""
        total_count = 0
        fooling_count = 0
        for data in reader():
            total_count += 1
            # BATCH_SIZE is 1, so data[0] is the only (image, label) sample
            original_label = data[0][1]
            adversary = Adversary(data[0][0], original_label)

            # MIFGSM non-targeted attack
            adversary = attack(adversary, **attack_config)

            # MIFGSM targeted attack
            # tlabel = 0
            # adversary.set_target(is_targeted_attack=True, target_label=tlabel)
            # adversary = attack(adversary, **attack_config)

            if adversary.is_successful():
                fooling_count += 1
                print(
                    'attack success, original_label=%d, adversarial_label=%d, count=%d'
                    % (original_label, adversary.adversarial_label,
                       total_count))
                # plt.imshow(adversary.target, cmap='Greys_r')
                # plt.show()
                # np.save('adv_img', adversary.target)
            else:
                print('attack failed, original_label=%d, count=%d' %
                      (original_label, total_count))

            if total_count >= TOTAL_NUM:
                print(
                    "[%s]: fooling_count=%d, total_count=%d, fooling_rate=%f"
                    % (dataset_tag, fooling_count, total_count,
                       float(fooling_count) / total_count))
                break

    # use train data to generate adversarial examples
    run_attack(train_reader, "TRAIN_DATASET")
    # use test data to generate adversarial examples
    run_attack(test_reader, "TEST_DATASET")
    print("mifgsm attack done")


if __name__ == '__main__':
    main()
...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model ...@@ -18,19 +18,19 @@ This tool is used to convert a Caffe model to Fluid model
### Tested models ### Tested models
- Lenet on mnist dataset - Lenet
- ResNets:(ResNet-50, ResNet-101, ResNet-152) - ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addr: `https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777`_ [model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
- GoogleNet: - GoogleNet:
model addr: `https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034`_ [model addr](https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034)
- VGG: - VGG:
model addr: `https://gist.github.com/ksimonyan/211839e770f7b538e2d8`_ [model addr](https://gist.github.com/ksimonyan/211839e770f7b538e2d8)
- AlexNet: - AlexNet:
model addr: `https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet`_ [model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet)
### Notes ### Notes
Some of this code comes from here: https://github.com/ethereon/caffe-tensorflow Some of this code comes from here: https://github.com/ethereon/caffe-tensorflow
#!/usr/bin/python
#
#a tool to compare tensors in two files or two directories
#
import sys
import os
def walk_dir(rootdir):
    """Yield every file found under `rootdir`, recursively.

    Each file is yielded as a path relative to `rootdir`, so that
    `os.path.join(rootdir, path)` locates it even when it lives in a
    sub-directory. For files directly inside `rootdir` this is just the
    file name.
    """
    for subdir, _dirs, filenames in os.walk(rootdir):
        for filename in filenames:
            yield os.path.relpath(os.path.join(subdir, filename), rootdir)
def calc_diff(f1, f2):
    """Compare the tensors stored in two .npy files element-wise.

    Both arrays are flattened before comparison, so only their number of
    elements has to match, not their shape.

    Args:
        f1: path of the first .npy file.
        f2: path of the second .npy file.

    Returns:
        A tuple (max_df, sq_df): the maximum absolute difference and the
        mean squared difference. (-1.0, -1.0) is returned when the
        difference cannot be computed.

    Raises:
        AssertionError: if the two arrays hold a different number of
            elements.
    """
    import numpy as np

    d1 = np.load(f1).flatten()
    d2 = np.load(f2).flatten()

    if d1.size != d2.size:
        print(d1.shape)
        print(d2.shape)
        raise AssertionError("their shape is not consistent")

    try:
        df = np.abs(d1 - d2)
        max_df = np.max(df)
        sq_df = np.mean(df * df)
        return max_df, sq_df
    except Exception as e:
        # keep the best-effort sentinel for callers, but say why it happened
        print('failed to calculate diff of %s and %s: %s' % (f1, f2, e))
        return -1.0, -1.0
def compare(path1, path2):
    """Check that the tensors under two paths are numerically equal.

    When both paths name a '.npy' file, that single pair is compared.
    Otherwise every '.npy' file found under `path2` is compared with the
    file of the same relative name under `path1`.

    Args:
        path1: a .npy file, or a directory of .npy files.
        path2: a .npy file, or a directory of .npy files.

    Returns:
        0 when every compared pair is within tolerance, 1 when one of the
        paths does not exist.

    Raises:
        AssertionError: if a pair differs by more than the tolerance or its
            difference could not be calculated.
    """

    def diff(f1, f2):
        max_df, sq_df = calc_diff(f1, f2)
        print('compare %s <=> %s with result[max_df:%.4e, sq_df:%.4e]' %
              (f1, f2, max_df, sq_df))
        # calc_diff signals failure with negative values, which would
        # otherwise slip under the thresholds below and count as a pass
        assert (max_df >= 0 and sq_df >= 0), \
            'failed to calculate the difference of %s and %s' % (f1, f2)
        assert (max_df < 1e-5), \
            'max_df is too large with value[%.6e]' % (max_df)
        assert (sq_df < 1e-10), \
            'sq_df is too large with value[%.6e]' % (sq_df)

    for path in (path1, path2):
        if not os.path.exists(path):
            print('not found %s' % (path))
            return 1

    if path1.endswith('.npy') and path2.endswith('.npy'):
        diff(path1, path2)
        return 0

    for f in walk_dir(path2):
        if not f.endswith('.npy'):
            continue

        f1 = os.path.join(path1, f)
        f2 = os.path.join(path2, f)
        diff(f1, f2)

    print('all checking succeed to pass')
    return 0
if __name__ == "__main__":
    # With no arguments the default lenet result directories are compared;
    # with two arguments the given paths are compared.
    if len(sys.argv) == 1:
        path1 = 'lenet.tf/results'
        path2 = 'lenet.paddle/results'
    elif len(sys.argv) == 3:
        path1 = sys.argv[1]
        path2 = sys.argv[2]
    else:
        print('usage:')
        print(' %s [path1] [path2]' % (sys.argv[0]))
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. when run with python -S)
        sys.exit(1)

    print('compare inner result in %s %s' % (path1, path2))
    sys.exit(compare(path1, path2))
#!/bin/bash
#
#function:
#   a tool used to check the difference of models' results generated by caffe model and paddle model
#
#howto:
#   bash diff.sh resnet50 #when this has been finished, you can get the difference in precision
#
#notes:
#   0, in order to infer using caffe, we need pycaffe installed
#   1, prepare your caffe model in 'models.caffe/', eg: 'models.caffe/resnet101/resnet101.[prototxt|caffemodel]'
#   2, converted paddle model will be in 'models'
#   3, results of layers will be stored in 'results/${model_name}.[paddle|caffe]'
#   4, only the last layer will be checked by default

model_name="resnet50"
results_root="results/"

if [[ -n $1 ]];then
    if [ "$1" = "-h" ];then
        echo "usage:"
        echo "  bash $0 [model_name]"
        echo "  eg:bash $0 resnet50"
        exit 0
    fi
    model_name=$1
fi

mkdir -p "$results_root"

model_prototxt="models.caffe/$model_name/${model_name}.prototxt"
model_caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"

#1, dump layers' results from paddle
paddle_results="$results_root/${model_name}.paddle"
# expansions are quoted so an unusual model name cannot split into extra
# arguments, which matters most for the 'rm -rf' calls
rm -rf "$paddle_results"
rm -rf "results.paddle"
bash run.sh "$model_name" "./models.caffe/$model_name" "./models/$model_name"
if [[ $? -ne 0 ]] || [[ ! -e "results.paddle" ]];then
    echo "not found paddle's results, maybe failed to convert"
    exit 1
fi
mv results.paddle "$paddle_results"

#2, dump layers' results from caffe
caffe_results="$results_root/${model_name}.caffe"
rm -rf "$caffe_results"
rm -rf "results.caffe"
# NOTE(review): 'cfpython' is presumably a python interpreter with pycaffe available - confirm
cfpython ./infer.py caffe "$model_prototxt" "$model_caffemodel" "$paddle_results/data.npy"
if [[ $? -ne 0 ]] || [[ ! -e "results.caffe" ]];then
    echo "not found caffe's results, maybe failed to do inference with caffe"
    exit 1
fi
mv results.caffe "$caffe_results"

#3, extract layer names
grep name "$model_prototxt" | perl -ne 'if(/^\s*name:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names

#4, compare one by one
# only the last extracted layer name is compared (see note 4 above)
for i in $(tail -n1 ".layer_names");do
    echo "process $i"
    python compare.py "$caffe_results/${i}.npy" "$paddle_results/${i}.npy"
done
...@@ -10,8 +10,11 @@ import os ...@@ -10,8 +10,11 @@ import os
import sys import sys
import inspect import inspect
import numpy as np import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
def import_fluid():
    """Import `paddle.fluid` on demand and return the module object."""
    import paddle.fluid as fluid_module
    return fluid_module
def load_data(imgfile, shape): def load_data(imgfile, shape):
...@@ -52,8 +55,10 @@ def build_model(net_file, net_name): ...@@ -52,8 +55,10 @@ def build_model(net_file, net_name):
print(e) print(e)
return None return None
input_name = 'data' fluid = import_fluid()
input_shape = MyNet.input_shapes()[input_name] inputs_dict = MyNet.input_shapes()
input_name = inputs_dict.keys()[0]
input_shape = inputs_dict[input_name]
images = fluid.layers.data(name='image', shape=input_shape, dtype='float32') images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64') #label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -64,7 +69,7 @@ def build_model(net_file, net_name): ...@@ -64,7 +69,7 @@ def build_model(net_file, net_name):
def dump_results(results, names, root): def dump_results(results, names, root):
if os.path.exists(root) is False: if os.path.exists(root) is False:
os.path.mkdir(root) os.mkdir(root)
for i in range(len(names)): for i in range(len(names)):
n = names[i] n = names[i]
...@@ -73,9 +78,12 @@ def dump_results(results, names, root): ...@@ -73,9 +78,12 @@ def dump_results(results, names, root):
np.save(filename + '.npy', res) np.save(filename + '.npy', res)
def infer(net_file, net_name, model_file, imgfile, debug=False): def infer(net_file, net_name, model_file, imgfile, debug=True):
""" do inference using a model which consist 'xxx.py' and 'xxx.npy' """ do inference using a model which consist 'xxx.py' and 'xxx.npy'
""" """
fluid = import_fluid()
#1, build model #1, build model
net, input_shape = build_model(net_file, net_name) net, input_shape = build_model(net_file, net_name)
prediction = net.get_output() prediction = net.get_output()
...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False): ...@@ -109,34 +117,79 @@ def infer(net_file, net_name, model_file, imgfile, debug=False):
fetch_list=fetch_list_var) fetch_list=fetch_list_var)
if debug is True: if debug is True:
dump_path = 'results.layers' dump_path = 'results.paddle'
dump_results(results, fetch_list_name, dump_path) dump_results(results, fetch_list_name, dump_path)
print('all results dumped to [%s]' % (dump_path)) print('all result of layers dumped to [%s]' % (dump_path))
else: else:
result = results[0] result = results[0]
print('predicted class:', np.argmax(result)) print('predicted class:', np.argmax(result))
return 0
def caffe_infer(prototxt, caffemodel, datafile):
    """ do inference using pycaffe for debug,
        all intermediate results will be dumped to 'results.caffe'

    Args:
        prototxt: path of the caffe network definition.
        caffemodel: path of the caffe weights.
        datafile: input data; loaded with np.load when its name contains
            '.npy', otherwise passed to load_data together with the input
            shape of the network.

    Returns:
        0 after the results of all blobs have been dumped.
    """
    import caffe

    net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    # the first blob is taken as the input of the network; next(iter(...))
    # works whether keys() returns a list or a view
    input_layer = next(iter(net.blobs.keys()))
    print('got name of input layer is:%s' % (input_layer))
    input_shape = list(net.blobs[input_layer].data.shape[1:])

    if '.npy' in datafile:
        np_images = np.load(datafile)
    else:
        np_images = load_data(datafile, input_shape)

    inputs = {input_layer: np_images}
    net.forward_all(**inputs)

    suffix = '_output'
    results = []
    names = []
    for k, v in net.blobs.items():
        # drop the '_output' suffix only; str.rstrip('_output') would strip
        # any trailing run of those characters (e.g. 'input' -> 'in')
        if k.endswith(suffix):
            k = k[:-len(suffix)]
        k = k.replace('/', '_')
        names.append(k)
        results.append(v.data.copy())

    dump_path = 'results.caffe'
    dump_results(results, names, dump_path)
    print('all result of layers dumped to [%s]' % (dump_path))
    return 0
if __name__ == "__main__": if __name__ == "__main__":
""" maybe more convenient to use 'run.sh' to call this tool """ maybe more convenient to use 'run.sh' to call this tool
""" """
net_file = 'models/resnet50/resnet50.py' net_file = 'models/resnet50/resnet50.py'
weight_file = 'models/resnet50/resnet50.npy' weight_file = 'models/resnet50/resnet50.npy'
imgfile = 'data/65.jpeg' datafile = 'data/65.jpeg'
net_name = 'ResNet50' net_name = 'ResNet50'
argc = len(sys.argv) argc = len(sys.argv)
if argc == 5: if sys.argv[1] == 'caffe':
if len(sys.argv) != 5:
print('usage:')
print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' %
(sys.argv[0]))
sys.exit(1)
prototxt = sys.argv[2]
caffemodel = sys.argv[3]
datafile = sys.argv[4]
sys.exit(caffe_infer(prototxt, caffemodel, datafile))
elif argc == 5:
net_file = sys.argv[1] net_file = sys.argv[1]
weight_file = sys.argv[2] weight_file = sys.argv[2]
imgfile = sys.argv[3] datafile = sys.argv[3]
net_name = sys.argv[4] net_name = sys.argv[4]
elif argc > 1: elif argc > 1:
print('usage:') print('usage:')
print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' % print('\tpython %s [net_file] [weight_file] [datafile] [net_name]' %
(sys.argv[0])) (sys.argv[0]))
print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file, print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file,
weight_file, imgfile, net_name)) weight_file, datafile, net_name))
sys.exit(1) sys.exit(1)
infer(net_file, net_name, weight_file, imgfile) infer(net_file, net_name, weight_file, datafile)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#function: #function:
# a tool used to: # a tool used to:
# 1, convert a caffe model # 1, convert a caffe model
# 2, do inference using this model # 2, do inference(only in fluid) using this model
# #
#usage: #usage:
# bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50 # bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then ...@@ -65,7 +65,12 @@ if [[ -z $only_convert ]];then
PYTHON=`which python` PYTHON=`which python`
fi fi
imgfile="data/65.jpeg" imgfile="data/65.jpeg"
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'` #FIX ME:
# only look the first line in prototxt file for the name of this network, maybe not correct
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^\s*name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
if [[ -z $net_name ]];then
net_name="MyNet"
fi
$PYTHON ./infer.py $net_file $weight_file $imgfile $net_name $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
ret=$? ret=$?
fi fi
......
...@@ -52,7 +52,10 @@ class Graph(object): ...@@ -52,7 +52,10 @@ class Graph(object):
def __init__(self, nodes=None, name=None): def __init__(self, nodes=None, name=None):
self.nodes = nodes or [] self.nodes = nodes or []
self.node_lut = {node.name: node for node in self.nodes} self.node_lut = {node.name: node for node in self.nodes}
self.name = name if name is None or name == '':
self.name = 'MyNet'
else:
self.name = name
def add_node(self, node): def add_node(self, node):
self.nodes.append(node) self.nodes.append(node)
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
def import_fluid(): def import_fluid():
import paddle.v2.fluid as fluid import paddle.fluid as fluid
return fluid return fluid
...@@ -64,7 +64,7 @@ class Network(object): ...@@ -64,7 +64,7 @@ class Network(object):
if os.path.isdir(data_path): if os.path.isdir(data_path):
assert (exe is not None), \ assert (exe is not None), \
'must provide a executor to load fluid model' 'must provide a executor to load fluid model'
fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path) fluid.io.load_persistables(executor=exe, dirname=data_path)
return True return True
#load model from a npy file #load model from a npy file
...@@ -161,56 +161,28 @@ class Network(object): ...@@ -161,56 +161,28 @@ class Network(object):
output = fluid.layers.relu(x=input) output = fluid.layers.relu(x=input)
return output return output
def _adjust_pad_if_needed(self, i_hw, k_hw, s_hw, p_hw):
#adjust the padding if needed
i_h, i_w = i_hw
k_h, k_w = k_hw
s_h, s_w = s_hw
p_h, p_w = p_hw
def is_consistent(i, k, s, p):
o = i + 2 * p - k
if o % s == 0:
return True
else:
return False
real_p_h = 0
real_p_w = 0
if is_consistent(i_h, k_h, s_h, p_h) is False:
real_p_h = int(k_h / 2)
if is_consistent(i_w, k_w, s_w, p_w) is False:
real_p_w = int(k_w / 2)
return [real_p_h, real_p_w]
def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding):
# Get the number of channels in the input # Get the number of channels in the input
in_hw = input.shape[2:] in_hw = input.shape[2:]
k_hw = [k_h, k_w] k_hw = [k_h, k_w]
s_hw = [s_h, s_w] s_hw = [s_h, s_w]
if padding is None:
#fix bug about the difference between conv and pool
#more info: https://github.com/BVLC/caffe/issues/1318
padding = self._adjust_pad_if_needed(in_hw, k_hw, s_hw, [0, 0])
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
pool_size=k_hw, pool_size=k_hw,
pool_stride=s_hw, pool_stride=s_hw,
pool_padding=padding, pool_padding=padding,
ceil_mode=True,
pool_type=pool_type) pool_type=pool_type)
return output return output
@layer @layer
def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding) return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding)
@layer @layer
...@@ -258,7 +230,12 @@ class Network(object): ...@@ -258,7 +230,12 @@ class Network(object):
return output return output
@layer @layer
def batch_normalization(self, input, name, scale_offset=True, relu=False): def batch_normalization(self,
input,
name,
scale_offset=True,
eps=1e-5,
relu=False):
# NOTE: Currently, only inference is supported # NOTE: Currently, only inference is supported
fluid = import_fluid() fluid = import_fluid()
prefix = name + '_' prefix = name + '_'
...@@ -276,7 +253,7 @@ class Network(object): ...@@ -276,7 +253,7 @@ class Network(object):
bias_attr=bias_attr, bias_attr=bias_attr,
moving_mean_name=mean_name, moving_mean_name=mean_name,
moving_variance_name=variance_name, moving_variance_name=variance_name,
epsilon=1e-5, epsilon=eps,
act='relu' if relu is True else None) act='relu' if relu is True else None)
return output return output
......
...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper): ...@@ -142,7 +142,13 @@ class TensorFlowMapper(NodeMapper):
def map_batch_norm(self, node): def map_batch_norm(self, node):
scale_offset = len(node.data) == 4 scale_offset = len(node.data) == 4
kwargs = {} if scale_offset else {'scale_offset': False}
#this default value comes from caffe's param in batch_norm
default_eps = 1e-5
kwargs = {'scale_offset': scale_offset}
if node.parameters.eps != default_eps:
kwargs['eps'] = node.parameters.eps
return MaybeActivated( return MaybeActivated(
node, default=False)('batch_normalization', **kwargs) node, default=False)('batch_normalization', **kwargs)
...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object): ...@@ -236,7 +242,7 @@ class TensorFlowEmitter(object):
func_def = self.statement('@classmethod') func_def = self.statement('@classmethod')
func_def += self.statement('def convert(cls, npy_model, fluid_path):') func_def += self.statement('def convert(cls, npy_model, fluid_path):')
self.indent() self.indent()
func_def += self.statement('import paddle.v2.fluid as fluid') func_def += self.statement('fluid = import_fluid()')
for l in codes: for l in codes:
func_def += self.statement(l) func_def += self.statement(l)
return '\n' + func_def return '\n' + func_def
......
import os
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader
def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
...@@ -65,20 +63,44 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): ...@@ -65,20 +63,44 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
return fluid.layers.elementwise_add(x=short, y=scale, act='relu') return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def SE_ResNeXt(input, class_dim, infer=False): def SE_ResNeXt(input, class_dim, infer=False, layers=50):
cardinality = 64 supported_layers = [50, 152]
reduction_ratio = 16 if layers not in supported_layers:
depth = [3, 8, 36, 3] print("supported layers are", supported_layers, "but input layer is",
num_filters = [128, 256, 512, 1024] layers)
exit()
conv = conv_bn_layer( if layers == 50:
input=input, num_filters=64, filter_size=3, stride=2, act='relu') cardinality = 32
conv = conv_bn_layer( reduction_ratio = 16
input=conv, num_filters=64, filter_size=3, stride=1, act='relu') depth = [3, 4, 6, 3]
conv = conv_bn_layer( num_filters = [128, 256, 512, 1024]
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d( conv = conv_bn_layer(
input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 152:
cardinality = 64
reduction_ratio = 16
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
conv = conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)): for block in range(len(depth)):
for i in range(depth[block]): for i in range(depth[block]):
...@@ -97,93 +119,3 @@ def SE_ResNeXt(input, class_dim, infer=False): ...@@ -97,93 +119,3 @@ def SE_ResNeXt(input, class_dim, infer=False):
drop = pool drop = pool
out = fluid.layers.fc(input=drop, size=class_dim, act='softmax') out = fluid.layers.fc(input=drop, size=class_dim, act='softmax')
return out return out
def train(learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True):
class_dim = 1000
image_shape = [3, 224, 224]
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
image_ = pd.read_input(image)
label_ = pd.read_input(label)
out = SE_ResNeXt(input=image_, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label_)
avg_cost = fluid.layers.mean(x=cost)
accuracy = fluid.layers.accuracy(input=out, label=label_)
pd.write_output(avg_cost)
pd.write_output(accuracy)
avg_cost, accuracy = pd()
avg_cost = fluid.layers.mean(x=avg_cost)
accuracy = fluid.layers.mean(x=accuracy)
else:
out = SE_ResNeXt(input=image, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
accuracy = fluid.layers.accuracy(input=out, label=label)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opts = optimizer.minimize(avg_cost)
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program([avg_cost, accuracy])
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
for pass_id in range(num_passes):
for batch_id, data in enumerate(train_reader()):
loss = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
print("Pass {0}, batch {1}, loss {2}".format(pass_id, batch_id,
float(loss[0])))
total_loss = 0.0
total_acc = 0.0
total_batch = 0
for data in test_reader():
loss, acc = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, accuracy])
total_loss += float(loss)
total_acc += float(acc)
total_batch += 1
print("End pass {0}, test_loss {1}, test_acc {2}".format(
pass_id, total_loss / total_batch, total_acc / total_batch))
model_path = os.path.join(model_save_dir, str(pass_id))
fluid.io.save_inference_model(model_path, ['image'], [out], exe)
if __name__ == '__main__':
train(
learning_rate=0.1,
batch_size=8,
num_passes=100,
init_model=None,
parallel=False)
import os
import numpy as np
import time
import sys
import paddle.v2 as paddle
import paddle.fluid as fluid
from se_resnext import SE_ResNeXt
import reader
import argparse
import functools
from utility import add_arguments, print_arguments
# Command-line interface of the training script. Every add_arg call below
# registers one option on `parser`; the positional arguments are
# (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
# Bind the parser once so the option table below stays compact.
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('parallel_exe', bool, True, "Whether to use ParallelExecutor to train or not.")
# yapf: enable
def train_paralle_do(args,
                     learning_rate,
                     batch_size,
                     num_passes,
                     init_model=None,
                     model_save_dir='model',
                     parallel=True,
                     use_nccl=True,
                     lr_strategy=None,
                     layers=50):
    """Train SE-ResNeXt with a plain Executor, optionally inside ParallelDo.

    After every pass the whole test reader is evaluated with a for-test clone
    of the main program and all persistable variables are saved to
    ``<model_save_dir>/<pass_id>``.

    Args:
        args: Parsed command-line namespace; only ``args.with_mem_opt`` is
            read here.
        learning_rate: Constant learning rate, used when ``lr_strategy`` is
            None.
        batch_size: Mini-batch size for both the train and the test reader.
        num_passes: Number of passes over the training reader.
        init_model: Directory with persistables to load before training, or
            None to start from the startup program's initialization.
        model_save_dir: Root directory for the per-pass checkpoints.
        parallel: When True the network is built inside a ParallelDo block
            over ``fluid.layers.get_places()``.
        use_nccl: Forwarded to ``fluid.layers.ParallelDo``.
        lr_strategy: Optional dict with keys ``"bd"`` (boundaries) and
            ``"lr"`` (values) for a piecewise-decay learning rate.
        layers: Depth forwarded to ``SE_ResNeXt``.
    """
    class_dim = 1000
    image_shape = [3, 224, 224]

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    if parallel:
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)

        with pd.do():
            image_ = pd.read_input(image)
            label_ = pd.read_input(label)
            out = SE_ResNeXt(input=image_, class_dim=class_dim, layers=layers)
            cost = fluid.layers.cross_entropy(input=out, label=label_)
            avg_cost = fluid.layers.mean(x=cost)
            acc_top1 = fluid.layers.accuracy(input=out, label=label_, k=1)
            acc_top5 = fluid.layers.accuracy(input=out, label=label_, k=5)
            pd.write_output(avg_cost)
            pd.write_output(acc_top1)
            pd.write_output(acc_top5)

        # pd() returns one value per place; reduce them to scalars.
        avg_cost, acc_top1, acc_top5 = pd()
        avg_cost = fluid.layers.mean(x=avg_cost)
        acc_top1 = fluid.layers.mean(x=acc_top1)
        acc_top5 = fluid.layers.mean(x=acc_top5)
    else:
        out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    if lr_strategy is None:
        optimizer = fluid.optimizer.Momentum(
            learning_rate=learning_rate,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-4))
    else:
        bd = lr_strategy["bd"]
        lr = lr_strategy["lr"]
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-4))

    # Clone BEFORE minimize() so the test program contains no backward or
    # optimizer ops.
    inference_program = fluid.default_main_program().clone(for_test=True)

    optimizer.minimize(avg_cost)

    if args.with_mem_opt:
        fluid.memory_optimize(fluid.default_main_program())
        fluid.memory_optimize(inference_program)

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if init_model is not None:
        fluid.io.load_persistables(exe, init_model)

    train_reader = paddle.batch(reader.train(), batch_size=batch_size)
    test_reader = paddle.batch(reader.test(), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
    fetch_list = [avg_cost, acc_top1, acc_top5]

    for pass_id in range(num_passes):
        # Per-batch [loss, top-1 accuracy, top-5 accuracy] histories.
        train_info = [[], [], []]
        test_info = [[], [], []]

        for batch_id, data in enumerate(train_reader()):
            t1 = time.time()
            loss, acc1, acc5 = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(data),
                fetch_list=fetch_list)
            t2 = time.time()
            period = t2 - t1
            train_info[0].append(loss[0])
            train_info[1].append(acc1[0])
            train_info[2].append(acc5[0])
            if batch_id % 10 == 0:
                # Adjacent literals instead of a backslash continuation
                # inside the string, so source indentation can never leak
                # into the message.
                print("Pass {0}, trainbatch {1}, loss {2}, "
                      "acc1 {3}, acc5 {4} time {5}".format(
                          pass_id, batch_id, loss[0], acc1[0], acc5[0],
                          "%2.2f sec" % period))
                sys.stdout.flush()

        train_loss = np.array(train_info[0]).mean()
        train_acc1 = np.array(train_info[1]).mean()
        train_acc5 = np.array(train_info[2]).mean()

        # Enumerate the test reader as well: previously `batch_id` was the
        # stale index of the last training batch, so test logging fired
        # either on every batch or never.
        for batch_id, data in enumerate(test_reader()):
            t1 = time.time()
            loss, acc1, acc5 = exe.run(
                inference_program,
                feed=feeder.feed(data),
                fetch_list=fetch_list)
            t2 = time.time()
            period = t2 - t1
            test_info[0].append(loss[0])
            test_info[1].append(acc1[0])
            test_info[2].append(acc5[0])
            if batch_id % 10 == 0:
                print("Pass {0},testbatch {1},loss {2}, "
                      "acc1 {3},acc5 {4},time {5}".format(
                          pass_id, batch_id, loss[0], acc1[0], acc5[0],
                          "%2.2f sec" % period))
                sys.stdout.flush()

        test_loss = np.array(test_info[0]).mean()
        test_acc1 = np.array(test_info[1]).mean()
        test_acc5 = np.array(test_info[2]).mean()

        print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, "
              "test_loss {4}, test_acc1 {5}, test_acc5 {6}".format(
                  pass_id, train_loss, train_acc1, train_acc5, test_loss,
                  test_acc1, test_acc5))
        sys.stdout.flush()

        model_path = os.path.join(model_save_dir, str(pass_id))
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        fluid.io.save_persistables(exe, model_path)


# Correctly spelled alias; the original (misspelled) name stays available for
# existing callers.
train_parallel_do = train_paralle_do
def train_parallel_exe(args,
                       learning_rate,
                       batch_size,
                       num_passes,
                       init_model=None,
                       model_save_dir='model',
                       parallel=True,
                       use_nccl=True,
                       lr_strategy=None,
                       layers=50):
    """Train SE-ResNeXt with ``fluid.ParallelExecutor``.

    A plain Executor runs the startup program (and loads/saves persistables);
    the training and evaluation steps run through two ParallelExecutors, the
    test one sharing variables with the training one. After every pass the
    whole test reader is evaluated and all persistables are saved to
    ``<model_save_dir>/<pass_id>``.

    Args:
        args: Parsed command-line namespace; only ``args.with_mem_opt`` is
            read here.
        learning_rate: Constant learning rate, used when ``lr_strategy`` is
            None.
        batch_size: Mini-batch size for both the train and the test reader.
        num_passes: Number of passes over the training reader.
        init_model: Directory with persistables to load before training, or
            None.
        model_save_dir: Root directory for the per-pass checkpoints.
        parallel: Unused by this function; kept so both trainers share one
            call signature.
        use_nccl: Unused by this function; kept for signature parity.
        lr_strategy: Optional dict with keys ``"bd"`` (boundaries) and
            ``"lr"`` (values) for a piecewise-decay learning rate.
        layers: Depth forwarded to ``SE_ResNeXt``.
    """
    class_dim = 1000
    image_shape = [3, 224, 224]

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    avg_cost = fluid.layers.mean(x=cost)

    # Clone BEFORE minimize() so the test program contains no backward or
    # optimizer ops.
    test_program = fluid.default_main_program().clone(for_test=True)

    if lr_strategy is None:
        optimizer = fluid.optimizer.Momentum(
            learning_rate=learning_rate,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-4))
    else:
        bd = lr_strategy["bd"]
        lr = lr_strategy["lr"]
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-4))

    optimizer.minimize(avg_cost)

    if args.with_mem_opt:
        fluid.memory_optimize(fluid.default_main_program())
        fluid.memory_optimize(test_program)

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if init_model is not None:
        fluid.io.load_persistables(exe, init_model)

    train_reader = paddle.batch(reader.train(), batch_size=batch_size)
    test_reader = paddle.batch(reader.test(), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
    test_exe = fluid.ParallelExecutor(
        use_cuda=True,
        main_program=test_program,
        share_vars_from=train_exe)

    fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]

    for pass_id in range(num_passes):
        # Per-batch [loss, top-1 accuracy, top-5 accuracy] histories.
        train_info = [[], [], []]
        test_info = [[], [], []]

        for batch_id, data in enumerate(train_reader()):
            t1 = time.time()
            loss, acc1, acc5 = train_exe.run(
                fetch_list,
                feed_dict=feeder.feed(data))
            t2 = time.time()
            period = t2 - t1
            # Each fetched value is averaged down to a single scalar.
            loss = np.mean(np.array(loss))
            acc1 = np.mean(np.array(acc1))
            acc5 = np.mean(np.array(acc5))
            train_info[0].append(loss)
            train_info[1].append(acc1)
            train_info[2].append(acc5)
            if batch_id % 10 == 0:
                # Adjacent literals instead of a backslash continuation
                # inside the string, so source indentation can never leak
                # into the message.
                print("Pass {0}, trainbatch {1}, loss {2}, "
                      "acc1 {3}, acc5 {4} time {5}".format(
                          pass_id, batch_id, loss, acc1, acc5,
                          "%2.2f sec" % period))
                sys.stdout.flush()

        train_loss = np.array(train_info[0]).mean()
        train_acc1 = np.array(train_info[1]).mean()
        train_acc5 = np.array(train_info[2]).mean()

        # Enumerate the test reader as well: previously `batch_id` was the
        # stale index of the last training batch, so test logging fired
        # either on every batch or never.
        for batch_id, data in enumerate(test_reader()):
            t1 = time.time()
            loss, acc1, acc5 = test_exe.run(
                fetch_list,
                feed_dict=feeder.feed(data))
            t2 = time.time()
            period = t2 - t1
            loss = np.mean(np.array(loss))
            acc1 = np.mean(np.array(acc1))
            acc5 = np.mean(np.array(acc5))
            test_info[0].append(loss)
            test_info[1].append(acc1)
            test_info[2].append(acc5)
            if batch_id % 10 == 0:
                print("Pass {0},testbatch {1},loss {2}, "
                      "acc1 {3},acc5 {4},time {5}".format(
                          pass_id, batch_id, loss, acc1, acc5,
                          "%2.2f sec" % period))
                sys.stdout.flush()

        test_loss = np.array(test_info[0]).mean()
        test_acc1 = np.array(test_info[1]).mean()
        test_acc5 = np.array(test_info[2]).mean()

        print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, "
              "test_loss {4}, test_acc1 {5}, test_acc5 {6}".format(
                  pass_id, train_loss, train_acc1, train_acc5, test_loss,
                  test_acc1, test_acc5))
        sys.stdout.flush()

        model_path = os.path.join(model_save_dir, str(pass_id))
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
    # Script entry point: parse the CLI flags, build the piecewise
    # learning-rate schedule and launch the selected trainer.
    args = parser.parse_args()
    print_arguments(args)

    # The learning rate drops by 10x after epochs 30, 60 and 90. Boundaries
    # are expressed in mini-batch steps, so they depend on the batch size.
    epoch_points = [30, 60, 90]
    total_images = 1281167
    batch_size = args.batch_size
    step = int(total_images / batch_size + 1)
    bd = [e * step for e in epoch_points]
    lr = [0.1, 0.01, 0.001, 0.0001]
    lr_strategy = {"bd": bd, "lr": lr}

    use_nccl = True
    # layers: 50, 152
    layers = args.num_layers

    # The ParallelDo trainer is defined in this file under the name
    # `train_paralle_do`; the previously referenced `train_parallel_do` did
    # not exist and raised NameError when --parallel_exe was False.
    method = train_parallel_exe if args.parallel_exe else train_paralle_do
    method(args,
           learning_rate=0.1,
           batch_size=batch_size,
           num_passes=120,
           init_model=None,
           parallel=True,
           use_nccl=use_nccl,
           lr_strategy=lr_strategy,
           layers=layers)
"""A dummy reader for test.""" """Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
...@@ -13,40 +13,50 @@ ...@@ -13,40 +13,50 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np import numpy as np
import paddle.v2 as paddle from paddle.fluid import core
DATA_SHAPE = [1, 512, 512]
NUM_CLASSES = 20
def print_arguments(args):
"""Print argparse's arguments.
def _read_creater(num_sample=1024, min_seq_len=1, max_seq_len=10): Usage:
def reader():
for i in range(num_sample):
sequence_len = np.random.randint(min_seq_len, max_seq_len)
x = np.random.uniform(0.1, 1, DATA_SHAPE).astype("float32")
y = np.random.randint(0, NUM_CLASSES + 1,
[sequence_len]).astype("int32")
yield x, y
return reader .. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
def train(batch_size, num_sample=128): :param args: Input argparse.Namespace for printing.
"""Get train dataset reader.""" :type args: argparse.Namespace
return paddle.batch(_read_creater(num_sample=num_sample), batch_size) """
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def test(batch_size=1, num_sample=16): def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Get test dataset reader.""" """Add argparse's argument.
return paddle.batch(_read_creater(num_sample=num_sample), batch_size)
Usage:
def data_shape(): .. code-block:: python
"""Get image shape in CHW order."""
return DATA_SHAPE
parser = argparse.ArgumentParser()
def num_classes(): add_argument("name", str, "Jonh", "User name.", parser)
"""Get number of total classes.""" args = parser.parse_args()
return NUM_CLASSES """
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
...@@ -15,6 +15,9 @@ class TrainTaskConfig(object): ...@@ -15,6 +15,9 @@ class TrainTaskConfig(object):
# the parameters for learning rate scheduling. # the parameters for learning rate scheduling.
warmup_steps = 4000 warmup_steps = 4000
# the flag indicating to use average loss or sum loss when training.
use_avg_cost = False
# the directory for saving trained models. # the directory for saving trained models.
model_dir = "trained_models" model_dir = "trained_models"
...@@ -22,8 +25,7 @@ class TrainTaskConfig(object): ...@@ -22,8 +25,7 @@ class TrainTaskConfig(object):
class InferTaskConfig(object): class InferTaskConfig(object):
use_gpu = False use_gpu = False
# the number of examples in one run for sequence generation. # the number of examples in one run for sequence generation.
# currently the batch size can only be set to 1. batch_size = 10
batch_size = 1
# the parameters for beam search. # the parameters for beam search.
beam_size = 5 beam_size = 5
...@@ -31,37 +33,38 @@ class InferTaskConfig(object): ...@@ -31,37 +33,38 @@ class InferTaskConfig(object):
# the number of decoded sentences to output. # the number of decoded sentences to output.
n_best = 1 n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model. # the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model" model_path = "trained_models/pass_1.infer.model"
class ModelHyperParams(object): class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses # This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has # <unk> token has alreay been added. As for the <pad> token, any token
# alreay been added, but the <pad> token is not added. Transformer requires # included in dict can be used to pad, since the paddings' loss will be
# sequences in a mini-batch are padded to have the same length. A <pad> token is # masked out and make no effect on parameter gradients.
# added into the original dictionary in paddle.dateset.wmt16.
# size of source word dictionary. # size of source word dictionary.
src_vocab_size = 10000 src_vocab_size = 10000
# index for <pad> token in source language.
src_pad_idx = src_vocab_size
# size of target word dictionary # size of target word dictionary
trg_vocab_size = 10000 trg_vocab_size = 10000
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size
# index for <bos> token # index for <bos> token
bos_idx = 0 bos_idx = 0
# index for <eos> token # index for <eos> token
eos_idx = 1 eos_idx = 1
# index for <unk> token
unk_idx = 2
# position value corresponding to the <pad> token. # max length of sequences.
pos_pad_idx = 0 # The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# max length of sequences. It should plus 1 to include position # token for position encoding.
# padding token for position encoding.
max_length = 50 max_length = 50
# the dimension for word embeddings, which is also the last dimension of # the dimension for word embeddings, which is also the last dimension of
...@@ -92,7 +95,10 @@ pos_enc_param_names = ( ...@@ -92,7 +95,10 @@ pos_enc_param_names = (
encoder_input_data_names = ( encoder_input_data_names = (
"src_word", "src_word",
"src_pos", "src_pos",
"src_slf_attn_bias", ) "src_slf_attn_bias",
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
# Names of all data layers in decoder listed in order. # Names of all data layers in decoder listed in order.
decoder_input_data_names = ( decoder_input_data_names = (
...@@ -100,6 +106,11 @@ decoder_input_data_names = ( ...@@ -100,6 +106,11 @@ decoder_input_data_names = (
"trg_pos", "trg_pos",
"trg_slf_attn_bias", "trg_slf_attn_bias",
"trg_src_attn_bias", "trg_src_attn_bias",
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", ) "enc_output", )
# Names of label related data layers listed in order. # Names of label related data layers listed in order.
......
import numpy as np import numpy as np
import paddle.v2 as paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import model import model
...@@ -11,10 +11,26 @@ from config import InferTaskConfig, ModelHyperParams, \ ...@@ -11,10 +11,26 @@ from config import InferTaskConfig, ModelHyperParams, \
from train import pad_batch_data from train import pad_batch_data
def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, def translate_batch(exe,
decoder, dec_in_names, dec_out_names, beam_size, max_length, src_words,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx, encoder,
bos_idx, eos_idx): enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
""" """
Run the encoder program once and run the decoder program multiple times to Run the encoder program once and run the decoder program multiple times to
implement beam search externally. implement beam search externally.
...@@ -25,9 +41,21 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -25,9 +41,21 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
src_pad_idx, src_pad_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=True) return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder, enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)), feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0] fetch_list=enc_out_names)[0]
...@@ -35,13 +63,18 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -35,13 +63,18 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# Beam Search. # Beam Search.
# To store the beam info. # To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32") scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[]] * batch_size prev_branchs = [[] for i in range(batch_size)]
next_ids = [[]] * batch_size next_ids = [[] for i in range(batch_size)]
# Use beam_map to map the instance idx in batch to beam idx, since the # Use beam_inst_map to map beam idx to the instance idx in batch, since the
# size of feeded batch is changing. # size of feeded batch is changing.
beam_map = range(batch_size) beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to record the beams that are still alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True): def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
""" """
Decode and select n_best sequences for one instance by backtrace. Decode and select n_best sequences for one instance by backtrace.
""" """
...@@ -53,7 +86,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -53,7 +86,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
seq.append(next_ids[j][k]) seq.append(next_ids[j][k])
k = prev_branchs[j][k] k = prev_branchs[j][k]
seq = seq[::-1] seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq # Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq) seqs.append(seq)
return seqs return seqs
...@@ -64,8 +98,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -64,8 +98,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_words = np.array( trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64") [[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64") trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[ src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[-2], 1 -1], enc_in_data[2], 1
# This is used to remove attention on subsequent words. # This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len, trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len)) trg_max_len))
...@@ -75,22 +109,47 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -75,22 +109,47 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
[-1e9]).astype("float32") [-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences. # This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile( trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :], src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[beam_size, 1, trg_max_len, 1]) [1, beam_size, 1, trg_max_len, 1]).reshape([
enc_output = np.tile(enc_output, [beam_size, 1, 1]) -1, src_slf_attn_bias.shape[1], trg_max_len,
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams): def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
""" """
Update the input data of decoder mainly by slicing from the previous Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams. input data and dropping the finished instance beams.
""" """
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_cur_len = len(next_ids[0]) + 1 # include the <bos> trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array( trg_words = np.array(
[ [
beam_backtrace( beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
for beam_idx in active_beams for beam_idx in active_beams
], ],
dtype="int64") dtype="int64")
...@@ -98,6 +157,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -98,6 +157,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_pos = np.array( trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size, [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1]) dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = ( active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] + (np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten() np.array(range(beam_size))[np.newaxis, :]).flatten()
...@@ -112,8 +172,27 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -112,8 +172,27 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_src_attn_bias = np.tile(trg_src_attn_bias[ trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :], active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1]) [1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :] enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data, dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output) enc_output)
...@@ -122,13 +201,18 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -122,13 +201,18 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
feed=dict(zip(dec_in_names, dec_in_data)), feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0] fetch_list=dec_out_names)[0]
predict_all = np.log( predict_all = np.log(
predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:, predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
-1, :]) [:, -1, :])
predict_all = (predict_all + scores[beam_map].reshape( predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_map) * beam_size, -1])).reshape( [len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1]) [len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = [] active_beams = []
for inst_idx, beam_idx in enumerate(beam_map): for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :] predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten() if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:] top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
...@@ -141,13 +225,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -141,13 +225,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1]) next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx: if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx) active_beams.append(beam_idx)
beam_map = active_beams if len(active_beams) == 0:
if len(beam_map) == 0:
break break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams) dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace. # Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)] seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist() return seqs, scores[:, :n_best].tolist()
...@@ -155,29 +246,24 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, ...@@ -155,29 +246,24 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
def main(): def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# The current program desc is coupled with batch_size and the only
# supported batch size is 1 currently.
encoder_program = fluid.Program() encoder_program = fluid.Program()
model.batch_size = InferTaskConfig.batch_size
with fluid.program_guard(main_program=encoder_program): with fluid.program_guard(main_program=encoder_program):
enc_output = encoder( enc_output = encoder(
ModelHyperParams.src_vocab_size + 1, ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer, ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_key, ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_value, ModelHyperParams.d_model, ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout, ModelHyperParams.dropout)
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
decoder_program = fluid.Program() decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program): with fluid.program_guard(main_program=decoder_program):
predict = decoder( predict = decoder(
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer, ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_key, ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_value, ModelHyperParams.d_model, ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout, ModelHyperParams.dropout)
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
# Load model parameters of encoder and decoder separately from the saved # Load model parameters of encoder and decoder separately from the saved
# transformer model. # transformer model.
...@@ -214,17 +300,51 @@ def main(): ...@@ -214,17 +300,51 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict( trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True) "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
def post_process_seq(seq,
                     bos_idx=ModelHyperParams.bos_idx,
                     eos_idx=ModelHyperParams.eos_idx,
                     output_bos=InferTaskConfig.output_bos,
                     output_eos=InferTaskConfig.output_eos):
    """
    Clean up one beam-search decoded sequence.

    The sequence is cut right after its first <eos> token (it is kept whole
    when no <eos> occurs). <bos> and <eos> tokens are then dropped unless
    output_bos / output_eos request that they be kept.
    """
    # Index of the first <eos>; fall back to the last position when absent.
    cut = next((pos for pos, tok in enumerate(seq) if tok == eos_idx),
               len(seq) - 1)

    def keep(tok):
        # A special token survives only when its output flag is set.
        if not output_bos and tok == bos_idx:
            return False
        if not output_eos and tok == eos_idx:
            return False
        return True

    return filter(keep, seq[:cut + 1])
for batch_id, data in enumerate(test_data()): for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch( batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program, exe,
encoder_input_data_names, [enc_output.name], decoder_program, [item[0] for item in data],
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size, encoder_program,
InferTaskConfig.max_length, InferTaskConfig.n_best, encoder_input_data_names,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx, [enc_output.name],
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx, decoder_program,
ModelHyperParams.eos_idx) decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)): for i in range(len(batch_seqs)):
seqs = batch_seqs[i] # Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i] scores = batch_scores[i]
for seq in seqs: for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq])) print(" ".join([trg_idx2word[idx] for idx in seq]))
......
import os import os
import time
import numpy as np import numpy as np
import paddle.v2 as paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from model import transformer, position_encoding_init from model import transformer, position_encoding_init
...@@ -14,7 +15,7 @@ def pad_batch_data(insts, ...@@ -14,7 +15,7 @@ def pad_batch_data(insts,
pad_idx, pad_idx,
n_head, n_head,
is_target=False, is_target=False,
return_pos=True, is_label=False,
return_attn_bias=True, return_attn_bias=True,
return_max_len=True): return_max_len=True):
""" """
...@@ -23,14 +24,20 @@ def pad_batch_data(insts, ...@@ -23,14 +24,20 @@ def pad_batch_data(insts,
""" """
return_list = [] return_list = []
max_len = max(len(inst) for inst in insts) max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array( inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])] return_list += [inst_data.astype("int64").reshape([-1, 1])]
if return_pos: if is_label: # label weight
inst_pos = np.array([[ inst_weight = np.array(
pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst) [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
] for inst in inst_data]) return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])] return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias: if return_attn_bias:
if is_target: if is_target:
...@@ -56,7 +63,7 @@ def pad_batch_data(insts, ...@@ -56,7 +63,7 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head): n_head, d_model):
""" """
Put all padded data needed by training into a dict. Put all padded data needed by training into a dict.
""" """
...@@ -66,13 +73,40 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, ...@@ -66,13 +73,40 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False) # These shape tensors are used in reshape_op.
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
src_slf_attn_bias.shape, dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict( input_dict = dict(
zip(input_data_names, [ zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, src_word, src_pos, src_slf_attn_bias, src_data_shape,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
trg_data_shape, trg_slf_attn_pre_softmax_shape,
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
])) ]))
return input_dict return input_dict
...@@ -81,14 +115,12 @@ def main(): ...@@ -81,14 +115,12 @@ def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
cost, predict = transformer( sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1, ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1, ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_layer, ModelHyperParams.n_head, ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_key, ModelHyperParams.d_value, ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid, ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model, lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place, TrainTaskConfig.warmup_steps, place,
...@@ -98,7 +130,7 @@ def main(): ...@@ -98,7 +130,7 @@ def main():
beta1=TrainTaskConfig.beta1, beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2, beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps) epsilon=TrainTaskConfig.eps)
optimizer.minimize(cost) optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)
train_data = paddle.batch( train_data = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -110,27 +142,31 @@ def main(): ...@@ -110,27 +142,31 @@ def main():
# Program to do validation. # Program to do validation.
test_program = fluid.default_main_program().clone() test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program): with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([cost]) test_program = fluid.io.get_inference_program([avg_cost])
val_data = paddle.batch( val_data = paddle.batch(
paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size, paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size), ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size) batch_size=TrainTaskConfig.batch_size)
def test(exe): def test(exe):
test_costs = [] test_total_cost = 0
test_total_token = 0
for batch_id, data in enumerate(val_data()): for batch_id, data in enumerate(val_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.n_head) ModelHyperParams.d_model)
test_cost = exe.run(test_program, test_sum_cost, test_token_num = exe.run(
feed=data_input, test_program,
fetch_list=[cost])[0] feed=data_input,
test_costs.append(test_cost) fetch_list=[sum_cost, token_num],
return np.mean(test_costs) use_program_cache=True)
test_total_cost += test_sum_cost
test_total_token += test_token_num
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
# Initialize the parameters. # Initialize the parameters.
exe.run(fluid.framework.default_startup_program()) exe.run(fluid.framework.default_startup_program())
...@@ -142,27 +178,30 @@ def main(): ...@@ -142,27 +178,30 @@ def main():
ModelHyperParams.d_model), place) ModelHyperParams.d_model), place)
for pass_id in xrange(TrainTaskConfig.pass_num): for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != TrainTaskConfig.batch_size: if len(data) != TrainTaskConfig.batch_size:
continue continue
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.n_head) ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input) lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(), outs = exe.run(fluid.framework.default_main_program(),
feed=data_input, feed=data_input,
fetch_list=[cost], fetch_list=[sum_cost, avg_cost],
use_program_cache=True) use_program_cache=True)
cost_val = np.array(outs[0]) sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) + print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
" cost = " + str(cost_val)) (pass_id, batch_id, sum_cost_val, avg_cost_val,
np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference. # Validate and save the model for inference.
val_cost = test(exe) val_avg_cost, val_ppl = test(exe)
print("pass_id = " + str(pass_id) + " val_cost = " + str(val_cost)) pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
print("epoch: %d, val avg loss: %f, val ppl: %f, "
"consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
fluid.io.save_inference_model( fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir, os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"), "pass_" + str(pass_id) + ".infer.model"),
......
./data/pascalvoc/VOCdevkit/
data/pascalvoc/test.txt
data/pascalvoc/trainval.txt
pretrained/ssd_mobilenet_v1_coco.tar.gz
pretrained/ssd_mobilenet_v1_coco
pretrained/mobilenet_v1_imagenet.tar.gz
pretrained/mobilenet_v1_imagenet
log*
...@@ -2,7 +2,99 @@ The minimum PaddlePaddle version needed for the code sample in this directory is ...@@ -2,7 +2,99 @@ The minimum PaddlePaddle version needed for the code sample in this directory is
--- ---
# MobileNet-SSD ## SSD Object Detection
This model built with paddle fluid is still under active development and is not ### Introduction
the final version. We welcome feedbacks.
The [Single Shot MultiBox Detector (SSD)](https://arxiv.org/abs/1512.02325) framework for object detection is based on a feed-forward convolutional network. The early part of the network is a standard convolutional architecture for image classification, such as VGG, ResNet, or MobileNet, which is also called the base network. In this tutorial we use [MobileNet](https://arxiv.org/abs/1704.04861).
### Data Preparation
You can use [PASCAL VOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/) or [MS-COCO dataset](http://cocodataset.org/#download).
#### PASCAL VOC Dataset
If you want to train a model on the PASCAL VOC dataset, please download the dataset first; skip this step if you already have it.
```bash
cd data/pascalvoc
./download.sh
```
The `download.sh` script also creates the training and testing file lists.
#### MS-COCO Dataset
If you want to train a model on the MS-COCO dataset, please download the dataset first; skip this step if you already have it.
```
cd data/coco
./download.sh
```
### Train
#### Download the Pre-trained Model.
We provide two pre-trained models. The first is a MobileNet-v1 SSD trained on the COCO dataset, with the convolutional predictors for COCO removed. This model can be used to initialize training on other datasets, such as PASCAL VOC. The other pre-trained model is MobileNet-v1 trained on the ImageNet 2012 dataset, with the weights and bias of the last fully-connected layer removed.
Declaration: the MobileNet-v1 SSD model is converted from the [TensorFlow model](https://github.com/tensorflow/models/blob/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/g3doc/detection_model_zoo.md). The MobileNet-v1 model is converted from [Caffe](https://github.com/shicai/MobileNet-Caffe).
- Download MobileNet-v1 SSD:
```
./pretrained/download_coco.sh
```
- Download MobileNet-v1:
```
./pretrained/download_imagenet.sh
```
#### Train on PASCAL VOC
- Train on one device (/GPU).
```python
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```
- Train on multi devices (/GPUs).
```python
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```
#### Train on MS-COCO
- Train on one device (/GPU).
```python
env CUDA_VISIBLE_DEVICES=0 python -u train.py --parallel=False --data='coco' --pretrained_model='pretrained/mobilenet_v1_imagenet/'
```
- Train on multi devices (/GPUs).
```python
env CUDA_VISIBLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='coco' --pretrained_model='pretrained/mobilenet_v1_imagenet/'
```
TBD
### Evaluate
```python
env CUDA_VISIBLE_DEVICES=0 python eval.py --model='model/90' --test_list=''
```
TBD
### Infer and Visualize
```python
env CUDA_VISIBLE_DEVICES=0 python infer.py --batch_size=2 --model='model/90' --test_list=''
```
TBD
### Released Model
| Model | Pre-trained Model | Training data | Test data | mAP |
|:------------------------:|:------------------:|:----------------:|:------------:|:----:|
|MobileNet-v1-SSD 300x300 | COCO MobileNet SSD | VOC07+12 trainval| VOC07 test | xx% |
|MobileNet-v1-SSD 300x300 | ImageNet MobileNet | VOC07+12 trainval| VOC07 test | xx% |
|MobileNet-v1-SSD 300x300 | ImageNet MobileNet | MS-COCO trainval | MS-COCO test | xx% |
TBD
...@@ -60,4 +60,5 @@ def prepare_filelist(devkit_dir, years, output_dir): ...@@ -60,4 +60,5 @@ def prepare_filelist(devkit_dir, years, output_dir):
ftest.write(item[0] + ' ' + item[1] + '\n') ftest.write(item[0] + ' ' + item[1] + '\n')
prepare_filelist(devkit_dir, years, '.') if __name__ == '__main__':
prepare_filelist(devkit_dir, years, '.')
# Resolve the directory containing this script and work from there, so the
# archives and file lists land next to the script regardless of the caller's
# current directory.
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
# Abort instead of downloading/extracting into the wrong directory.
cd "$DIR" || exit 1

# Download the data.
echo "Downloading..."
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
# Extract the data.
echo "Extracting..."
tar -xf VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_06-Nov-2007.tar
tar -xf VOCtest_06-Nov-2007.tar

# Build the training/testing file lists from the extracted data.
echo "Creating data lists..."
python create_list.py
...@@ -85,8 +85,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): ...@@ -85,8 +85,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
return False return False
def generate_batch_samples(batch_sampler, bbox_labels, image_width, def generate_batch_samples(batch_sampler, bbox_labels):
image_height):
sampled_bbox = [] sampled_bbox = []
index = [] index = []
c = 0 c = 0
...@@ -216,9 +215,9 @@ def distort_image(img, settings): ...@@ -216,9 +215,9 @@ def distort_image(img, settings):
def expand_image(img, bbox_labels, img_width, img_height, settings): def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1) prob = random.uniform(0, 1)
if prob < settings._hue_prob: if prob < settings._expand_prob:
expand_ratio = random.uniform(1, settings._expand_max_ratio) if _expand_max_ratio - 1 >= 0.01:
if expand_ratio - 1 >= 0.01: expand_ratio = random.uniform(1, settings._expand_max_ratio)
height = int(img_height * expand_ratio) height = int(img_height * expand_ratio)
width = int(img_width * expand_ratio) width = int(img_width * expand_ratio)
h_off = math.floor(random.uniform(0, height - img_height)) h_off = math.floor(random.uniform(0, height - img_height))
...@@ -231,5 +230,5 @@ def expand_image(img, bbox_labels, img_width, img_height, settings): ...@@ -231,5 +230,5 @@ def expand_image(img, bbox_labels, img_width, img_height, settings):
expand_img = Image.fromarray(expand_img) expand_img = Image.fromarray(expand_img)
expand_img.paste(img, (int(w_off), int(h_off))) expand_img.paste(img, (int(w_off), int(h_off)))
bbox_labels = transform_labels(bbox_labels, expand_bbox) bbox_labels = transform_labels(bbox_labels, expand_bbox)
return expand_img, bbox_labels return expand_img, bbox_labels, width, height
return img, bbox_labels return img, bbox_labels, img_width, img_height
import paddle.v2 as paddle
import paddle.fluid as fluid
import numpy as np
# From npy
def load_vars():
    """
    Load the weights stored in the ssd_mobilenet_v1_coco .npy dump.

    ``names.map`` holds one ``<target name>\\t<source name>`` pair per line.
    Each source array is looked up in the dict stored in the .npy file and
    returned under its target name. 4-D arrays are transposed: axes
    (2, 3, 0, 1) when the source name contains 'depthwise', axes
    (3, 2, 0, 1) otherwise. Arrays of any other rank are passed through.
    (Presumably this converts TensorFlow filter layouts -- confirm against
    the exporter.)

    Returns:
        dict: target variable name -> numpy array.
    """
    name_map = {}
    with open('./ssd_mobilenet_v1_coco/names.map', 'r') as map_file:
        for line in map_file:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines in the map file.
                continue
            fd_name, tf_name = line.split('\t')
            name_map[fd_name] = tf_name

    # The .npy file stores a pickled dict, which recent NumPy versions refuse
    # to load unless allow_pickle is set explicitly.
    # NOTE: unpickling executes arbitrary code; only load trusted files.
    tf_vars = np.load(
        './ssd_mobilenet_v1_coco/ssd_mobilenet_v1_coco_2017_11_17.npy',
        allow_pickle=True).item()

    converted = {}
    for fd_name, tf_name in name_map.items():
        tf_var = tf_vars[tf_name]
        if len(tf_var.shape) == 4 and 'depthwise' in tf_name:
            converted[fd_name] = np.transpose(tf_var, (2, 3, 0, 1))
        elif len(tf_var.shape) == 4:
            converted[fd_name] = np.transpose(tf_var, (3, 2, 0, 1))
        else:
            converted[fd_name] = tf_var
    return converted
def load_and_set_vars(place):
    """
    Copy every array returned by load_vars() into the variable of the same
    name in the fluid global scope, on the given place.

    The shape of each existing tensor must equal the shape of the loaded
    array; a mismatch trips the assertion.
    """
    loaded = load_vars()
    for name in loaded:
        value = loaded[name]
        tensor = fluid.global_scope().find_var(name).get_tensor()
        assert np.array(tensor).shape == value.shape
        tensor.set(value, place)
# From Paddle V1
def load_paddlev1_vars(place):
    """
    Fill variables in the fluid global scope from raw parameter files under
    ./caffe2paddle/.

    ``./caffe2paddle/names.map`` holds one ``<fluid name>\\t<file name>`` pair
    per line. For each pair the file is read as float32 after skipping its
    first 16 bytes (presumably a Paddle v1 parameter header -- confirm),
    reshaped to the shape of the existing fluid tensor, and written into
    that tensor on ``place``.
    """
    # reduce is a builtin only on Python 2; functools provides it on both.
    from functools import reduce
    from operator import mul

    name_map = {}
    with open('./caffe2paddle/names.map', 'r') as map_file:
        for param in map_file:
            fd_name, v1_name = param.strip().split('\t')
            name_map[fd_name] = v1_name

    def load(file_name, shape):
        # Read one parameter file and give it the requested shape.
        with open(file_name, 'rb') as f:
            f.read(16)  # skip the 16 leading bytes before the payload
            arr = np.fromfile(f, dtype=np.float32)
        # The initial value 1 keeps the product defined for an empty shape.
        assert arr.size == reduce(mul, shape, 1)
        return arr.reshape(shape)

    for fd_name in name_map:
        v1_name = name_map[fd_name]
        t = fluid.global_scope().find_var(fd_name).get_tensor()
        shape = np.array(t).shape
        v1_var = load('./caffe2paddle/' + v1_name, shape)
        t.set(v1_var, place)
if __name__ == "__main__":
load_vars()
...@@ -27,12 +27,7 @@ def conv_bn(input, ...@@ -27,12 +27,7 @@ def conv_bn(input,
bias_attr=False) bias_attr=False)
parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA()) parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA())
bias_attr = ParamAttr(learning_rate=0.2) bias_attr = ParamAttr(learning_rate=0.2)
return fluid.layers.batch_norm( return fluid.layers.batch_norm(input=conv, act=act)
input=conv,
act=act,
epsilon=0.00001,
param_attr=parameter_attr,
bias_attr=bias_attr)
def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride, def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
...@@ -76,7 +71,7 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale): ...@@ -76,7 +71,7 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
return normal_conv return normal_conv
def mobile_net(img, img_shape, scale=1.0): def mobile_net(num_classes, img, img_shape, scale=1.0):
# 300x300 # 300x300
tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3) tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
# 150x150 # 150x150
...@@ -104,10 +99,11 @@ def mobile_net(img, img_shape, scale=1.0): ...@@ -104,10 +99,11 @@ def mobile_net(img, img_shape, scale=1.0):
module16 = extra_block(module15, 128, 256, 1, 2, scale) module16 = extra_block(module15, 128, 256, 1, 2, scale)
# 2x2 # 2x2
module17 = extra_block(module16, 64, 128, 1, 2, scale) module17 = extra_block(module16, 64, 128, 1, 2, scale)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head( mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[module11, module13, module14, module15, module16, module17], inputs=[module11, module13, module14, module15, module16, module17],
image=img, image=img,
num_classes=21, num_classes=num_classes,
min_ratio=20, min_ratio=20,
max_ratio=90, max_ratio=90,
min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0], min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
......
# Resolve the directory that contains this script. Using `&&` (not `;`) means
# a failed `cd` yields a failing substitution instead of silently reporting
# the caller's current directory.
DIR="$( cd "$(dirname "$0")" && pwd -P )" || exit 1
# Work from the script's directory so the archive is downloaded and unpacked
# there; abort rather than download into an unrelated directory.
cd "$DIR" || exit 1

# Download the data.
echo "Downloading..."
wget http://paddlemodels.bj.bcebos.com/ssd_mobilenet_v1_coco.tar.gz

# Unpack the downloaded archive in place.
echo "Extracting..."
tar -xf ssd_mobilenet_v1_coco.tar.gz
# Resolve the directory that contains this script. Using `&&` (not `;`) means
# a failed `cd` yields a failing substitution instead of silently reporting
# the caller's current directory.
DIR="$( cd "$(dirname "$0")" && pwd -P )" || exit 1
# Work from the script's directory so the archive is downloaded and unpacked
# there; abort rather than download into an unrelated directory.
cd "$DIR" || exit 1

# Download the data.
echo "Downloading..."
wget http://paddlemodels.bj.bcebos.com/mobilenet_v1_imagenet.tar.gz

# Unpack the downloaded archive in place.
echo "Extracting..."
tar -xf mobilenet_v1_imagenet.tar.gz
此差异已折叠。
此差异已折叠。
# OCR Model 
[toc]
This model built with paddle fluid is still under active development and is not 运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照安装文档中的说明更新PaddlePaddle安装版本。
the final version. We welcome feedbacks.
# Optical Character Recognition
这里将介绍如何在PaddlePaddle Fluid下使用CRNN-CTC 和 CRNN-Attention模型对图片中的文字内容进行识别。
## 1. CRNN-CTC
本章的任务是识别含有单行汉语字符图片,首先采用卷积将图片转为特征图, 然后使用`im2sequence op`将特征图转为序列,通过`双向GRU`学习到序列特征。训练过程选用的损失函数为CTC(Connectionist Temporal Classification) loss,最终的评估指标为样本级别的错误率。
本路径下各个文件的作用如下:
- **ctc_reader.py :** 下载、读取、处理数据。提供方法`train()` 和 `test()` 分别产生训练集和测试集的数据迭代器。
- **crnn_ctc_model.py :** 在该脚本中定义了训练网络、预测网络和evaluate网络。
- **ctc_train.py :** 用于模型的训练,可通过命令`python ctc_train.py --help` 获得使用方法。
- **infer.py :** 加载训练好的模型文件,对新数据进行预测。可通过命令`python infer.py --help` 获得使用方法。
- **eval.py :** 评估模型在指定数据集上的效果。可通过命令`python eval.py --help` 获得使用方法。
- **utility.py :** 实现的一些通用方法,包括参数配置、tensor的构造等。
### 1.1 数据
数据的下载和简单预处理都在`ctc_reader.py`中实现。
#### 1.1.1 数据格式
我们使用的训练和测试数据如`图1`所示,每张图片包含单行不定长的中文字符串,这些图片都是经过检测算法进行预框选处理的。
<p align="center">
<img src="images/demo.jpg" width="620" hspace='10'/> <br/>
<strong>图 1</strong>
</p>
在训练集中,每张图片对应的label是汉字在词典中的索引。 `图1` 对应的label如下所示:
```
3835,8371,7191,2369,6876,4162,1938,168,1517,4590,3793
```
在上边这个label中,`3835` 表示字符‘两’的索引,`4590` 表示中文字符逗号的索引。
#### 1.1.2 数据准备
**A. 训练集**
我们需要把所有参与训练的图片放入同一个文件夹,暂且记为`train_images`。然后用一个list文件存放每张图片的信息,包括图片大小、图片名称和对应的label,这里暂记该list文件为`train_list`,其格式如下所示:
```
185 48 00508_0215.jpg 7740,5332,2369,3201,4162
48 48 00197_1893.jpg 6569
338 48 00007_0219.jpg 4590,4788,3015,1994,3402,999,4553
150 48 00107_4517.jpg 5936,3382,1437,3382
...
157 48 00387_0622.jpg 2397,1707,5919,1278
```
<center>文件train_list</center>
上述文件中的每一行表示一张图片,每行被空格分为四列,前两列分别表示图片的宽和高,第三列表示图片的名称,第四列表示该图片对应的sequence label。
最终我们应有以下类似文件结构:
```
|-train_data
|- train_list
|- train_images
|- 00508_0215.jpg
|- 00197_1893.jpg
|- 00007_0219.jpg
| ...
```
在训练时,我们通过选项`--train_images` 和 `--train_list` 分别设置准备好的`train_images` 和 `train_list`。
>**注:** 如果`--train_images` 和 `--train_list`都未设置或设置为None, ctc_reader.py会自动下载使用[示例数据](http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c),并将其缓存到`$HOME/.cache/paddle/dataset/ctc_data/data/` 路径下。
**B. 测试集和评估集**
测试集、评估集的准备方式与训练集相同。
在训练阶段,测试集的路径通过ctc_train.py的选项`--test_images` 和 `--test_list` 来设置。
在评估时,评估集的路径通过eval.py的选项`--input_images_dir` 和 `--input_images_list` 来设置。
**C. 待预测数据集**
预测支持三种形式的输入:
第一种:设置`--input_images_dir` 和 `--input_images_list`, 与训练集类似, 只不过list文件中的最后一列可以放任意占位字符或字符串,如下所示:
```
185 48 00508_0215.jpg s
48 48 00197_1893.jpg s
338 48 00007_0219.jpg s
...
```
第二种:仅设置`--input_images_list`, 其中list文件中只需放图片的完整路径,如下所示:
```
data/test_images/00000.jpg
data/test_images/00001.jpg
data/test_images/00003.jpg
```
第三种:从stdin读入一张图片的path,然后进行一次inference.
### 1.2 训练
使用默认数据在GPU单卡上训练:
```
env CUDA_VISIBLE_DEVICES=0 python ctc_train.py
```
使用默认数据在GPU多卡上训练:
```
env CUDA_VISIBLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True
```
执行`python ctc_train.py --help`可查看更多使用方式和参数详细说明。
图2为使用默认参数和默认数据集训练的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为在测试集上的样本级错误率。在45轮迭代训练中,最低错误率为第42轮的21.11%.
<p align="center">
<img src="images/train.jpg" width="620" hspace='10'/> <br/>
<strong>图 2</strong>
</p>
### 1.3 评估
通过以下命令调用评估脚本用指定数据集对模型进行评估:
```
env CUDA_VISIBLE_DEVICES=0 python eval.py \
--model_path="./models/model_0" \
--input_images_dir="./eval_data/images/" \
--input_images_list="./eval_data/eval_list"
```
执行`python eval.py --help`可查看参数详细说明。
### 1.4 预测
从标准输入读取一张图片的路径,并对其进行预测:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--model_path="models/model_00044_15000"
```
执行上述命令进行预测的效果如下:
```
----------- Configuration Arguments -----------
use_gpu: True
input_images_dir: None
input_images_list: None
model_path: /home/work/models/fluid/ocr_recognition/models/model_00052_15000
------------------------------------------------
Init model from: /home/work/models/fluid/ocr_recognition/models/model_00052_15000.
Please input the path of image: /home/work/models/fluid/ocr_recognition/data/test_images/00001_0060.jpg
result: [3298 2371 4233 6514 2378 3298 2363]
Please input the path of image: /home/work/models/fluid/ocr_recognition/data/test_images/00001_0429.jpg
result: [2067 2067 8187 8477 5027 7191 2431 1462]
```
从文件中批量读取图片路径,并对其进行预测:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--model_path="models/model_00044_15000" \
--input_images_list="data/test.list"
```
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册