提交 57dd5c7b 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/models into ctc_doc

......@@ -15,9 +15,7 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
from data_utils.util import CriticalException, ForceExitWrapper
class SampleInfo(object):
......@@ -32,11 +30,12 @@ class SampleInfo(object):
label_bin_path (str): File containing the label data.
label_size (int): Byte count of the sample's label data.
label_frame_num (int): Label number of the sample.
sample_name (str): Key of the sample
"""
def __init__(self, feature_bin_path, feature_start, feature_size,
feature_frame_num, feature_dim, label_bin_path, label_start,
label_size, label_frame_num):
label_size, label_frame_num, sample_name):
self.feature_bin_path = feature_bin_path
self.feature_start = feature_start
self.feature_size = feature_size
......@@ -47,6 +46,7 @@ class SampleInfo(object):
self.label_start = label_start
self.label_size = label_size
self.label_frame_num = label_frame_num
self.sample_name = sample_name
class SampleInfoBucket(object):
......@@ -69,8 +69,8 @@ class SampleInfoBucket(object):
split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \
rand() % split_perturb).
(split_sub_sentence_len
+ rand() % split_perturb).
"""
def __init__(self,
......@@ -104,24 +104,33 @@ class SampleInfoBucket(object):
feature_bin_path = self._feature_bin_paths[block_idx]
feature_desc_path = self._feature_desc_paths[block_idx]
label_desc_lines = open(label_desc_path).readlines()
feature_desc_lines = open(feature_desc_path).readlines()
sample_num = int(label_desc_lines[0].split()[1])
assert sample_num == int(feature_desc_lines[0].split()[1])
label_desc_lines = []
if label_desc_path != "":
label_desc_lines = open(label_desc_path).readlines()
sample_num = int(feature_desc_lines[0].split()[1])
if label_desc_path != "":
assert sample_num == int(label_desc_lines[0].split()[1])
for i in xrange(sample_num):
feature_desc_split = feature_desc_lines[i + 1].split()
sample_name = feature_desc_split[0]
feature_start = int(feature_desc_split[2])
feature_size = int(feature_desc_split[3])
feature_frame_num = int(feature_desc_split[4])
feature_dim = int(feature_desc_split[5])
label_desc_split = label_desc_lines[i + 1].split()
label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
assert feature_frame_num == label_frame_num
label_start = -1
label_size = -1
label_frame_num = feature_frame_num
if label_desc_path != "":
label_desc_split = label_desc_lines[i + 1].split()
label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
assert feature_frame_num == label_frame_num
if self._split_sentence_threshold == -1 or \
self._split_perturb == -1 or \
......@@ -131,7 +140,7 @@ class SampleInfoBucket(object):
SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size,
label_frame_num))
label_frame_num, sample_name))
#split sentence
else:
cur_frame_pos = 0
......@@ -152,16 +161,19 @@ class SampleInfoBucket(object):
* feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len))
4, cur_frame_len, sample_name))
remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len
if remain_frame_num <= 0:
break
return sample_info_list
class EpochEndSignal():
pass
class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation.
......@@ -190,7 +202,7 @@ class AsyncDataReader(object):
def __init__(self,
feature_file_list,
label_file_list,
label_file_list="",
drop_frame_len=512,
proc_num=10,
sample_buffer_size=1024,
......@@ -213,25 +225,30 @@ class AsyncDataReader(object):
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
block_feature_info_lines = open(self._feature_file_list).readlines()
block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(block_label_info_lines)
self._block_info_list = []
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1],
block_label_info_lines[i],
block_label_info_lines[i + 1])
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
if self._label_file_list != "":
block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(
block_label_info_lines)
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1],
block_label_info_lines[i],
block_label_info_lines[i + 1])
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
else:
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1], "", "")
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
if is_shuffle:
self._rng.shuffle(self._block_info_list)
......@@ -251,23 +268,13 @@ class AsyncDataReader(object):
def set_transformers(self, transformers):
self._transformers = transformers
def recycle(self, *args):
for shared_ndarray in args:
if not isinstance(shared_ndarray, SharedNDArray):
raise Value("Only support recycle SharedNDArray object.")
shared_ndarray.recycle(self._pool_manager.pool)
def _start_async_processing(self):
def _sample_generator(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
for sample_info_bucket in self._bucket_list:
try:
sample_info_list = \
......@@ -280,12 +287,13 @@ class AsyncDataReader(object):
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
for i in xrange(self._sample_proc_num):
for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal())
feeding_proc = DaemonProcessGroup(
proc_num=1, target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_proc.start_all()
feeding_thread = Thread(
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
......@@ -313,25 +321,32 @@ class AsyncDataReader(object):
sample_info.feature_size)
assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(label_bytes), (
sample_info.label_bin_path, sample_info.label_array,
len(label_bytes))
label_array = struct.unpack('I' * sample_info.label_frame_num,
label_bytes)
label_data = np.array(
label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1))
* sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
label_data = None
if sample_info.label_bin_path != "":
label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(
label_bytes), (sample_info.label_bin_path,
sample_info.label_array,
len(label_bytes))
label_array = struct.unpack(
'I' * sample_info.label_frame_num, label_bytes)
label_data = np.array(
label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1))
else:
label_data = np.zeros(
(sample_info.label_frame_num, 1), dtype='int64')
feature_frame_num = sample_info.feature_frame_num
feature_dim = sample_info.feature_dim
......@@ -341,12 +356,11 @@ class AsyncDataReader(object):
feature_data = np.array(
feature_array, dtype='float32').reshape((
sample_info.feature_frame_num, sample_info.feature_dim))
sample_data = (feature_data, label_data)
sample_data = (feature_data, label_data,
sample_info.sample_name)
for transformer in self._transformers:
# @TODO(pkuyym) to make transfomer only accept feature_data
sample_data = transformer.perform_trans(sample_data)
while order_id != out_order[0]:
time.sleep(0.001)
......@@ -362,74 +376,77 @@ class AsyncDataReader(object):
out_order = self._manager.list([0])
args = (sample_info_queue, sample_queue, out_order)
sample_proc = DaemonProcessGroup(
proc_num=self._sample_proc_num,
target=ordered_processing_task,
args=args)
sample_proc.start_all()
workers = [
Process(
target=ordered_processing_task, args=args)
for _ in xrange(self._proc_num)
]
return sample_queue
for w in workers:
w.daemon = True
w.start()
def batch_iterator(self, batch_size, minimum_batch_size):
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool):
def conv_to_shared(ndarray):
while self._force_exit == False:
try:
(name, shared_ndarray) = pool.popitem()
except Exception as e:
time.sleep(0.001)
finished_proc_num = 0
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
finished_proc_num += 1
if finished_proc_num >= self._proc_num:
break
else:
shared_ndarray.copy(ndarray)
return shared_ndarray
continue
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
yield sample
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
name_lst = []
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
name_lst.append(sample[2])
return (batch_feature, batch_label, name_lst)
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
done_num = 0
while done_num < self._sample_proc_num:
sample = sample_queue.get()
if isinstance(sample, EpochEndSignal):
done_num += 1
else:
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
feature, label = batch_to_ndarray(batch_samples, lod)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod))
batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
(batch_feature, batch_label, name_lst) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod, name_lst))
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod))
(batch_feature, batch_label, name_lst) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod, name_lst))
batch_queue.put(EpochEndSignal())
sample_queue = self._start_async_processing()
batch_queue = self._manager.Queue(self._batch_buffer_size)
batch_queue = Queue.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)
assembling_proc = DaemonProcessGroup(
proc_num=1,
assembling_thread = Thread(
target=batch_assembling_task,
args=(sample_queue, batch_queue, self._pool_manager.pool))
assembling_proc.start_all()
args=(self._sample_generator, batch_queue))
assembling_thread.daemon = True
assembling_thread.start()
while self._force_exit == False:
try:
......@@ -440,6 +457,3 @@ class AsyncDataReader(object):
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
# clean the shared memory
del self._pool_manager
......@@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase):
feature = np.zeros((2, 120), dtype="float32")
feature.fill(1)
trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path)
(feature1, label1) = trans.perform_trans((feature, None))
(feature1, label1, name) = trans.perform_trans((feature, None, None))
(mean, var) = trans.get_mean_var()
feature_flat1 = feature1.flatten()
feature_flat = feature.flatten()
......@@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase):
feature[2, 0:40].fill(3)
feature[3, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta()
(feature, label) = trans.perform_trans((feature, None))
(feature, label, name) = trans.perform_trans((feature, None, None))
self.assertAlmostEqual(feature.shape[0], 4)
self.assertAlmostEqual(feature.shape[1], 120)
self.assertAlmostEqual(1.0, feature[0][0])
......@@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase):
feature[i, :].fill(i)
trans = trans_splice.TransSplice()
(feature, label) = trans.perform_trans((feature, None))
(feature, label, name) = trans.perform_trans((feature, None, None))
self.assertEqual(feature.shape[1], 110)
for i in xrange(8):
......
......@@ -32,9 +32,9 @@ class TransAddDelta(object):
Args:
sample(object,tuple): contain feature numpy and label numpy
Returns:
(feature, label)
(feature, label, name)
"""
(feature, label) = sample
(feature, label, name) = sample
frame_dim = feature.shape[1]
d_frame_dim = frame_dim * 3
head_filled = 5
......@@ -64,7 +64,7 @@ class TransAddDelta(object):
start * d_frame_dim + 2 * frame_dim, frame_dim, nframe,
d_frame_dim)
mat.shape = tmp_shape
return (mat[head_filled:mat.shape[0] - tail_filled, :], label)
return (mat[head_filled:mat.shape[0] - tail_filled, :], label, name)
def _regress(self, data_in, start_in, data_out, start_out, size, n, step):
""" regress
......
......@@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object):
Args:
sample(object):input sample, contain feature numpy and label numpy
Returns:
(feature, label)
(feature, label, name)
"""
(feature, label) = sample
(feature, label, name) = sample
shape = feature.shape
assert len(shape) == 2
nfeature_len = shape[0] * shape[1]
......@@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object):
feature[ncur_idx:ncur_idx + self._nLen] = block
ncur_idx += self._nLen
feature = feature.reshape(shape)
return (feature, label)
return (feature, label, name)
......@@ -30,9 +30,9 @@ class TransSplice(object):
Args:
sample(object): input sample(feature, label)
Return:
(feature, label)
(feature, label, name)
"""
(feature, label) = sample
(feature, label, name) = sample
nframe_num = feature.shape[0]
nframe_dim = feature.shape[1]
nnew_frame_dim = nframe_dim * (
......@@ -61,4 +61,4 @@ class TransSplice(object):
np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim],
mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim])
ret = ret.reshape((nframe_num, nnew_frame_dim))
return (ret, label)
return (ret, label, name)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys, time
import sys
from six import reraise
from tblib import Traceback
from multiprocessing import Manager, Process
import posix_ipc, mmap
import numpy as np
......@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod()
def batch_to_ndarray(batch_samples, lod):
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
def split_infer_result(infer_seq, lod):
infer_batch = []
for i in xrange(0, len(lod[0]) - 1):
......@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod):
return infer_batch
class DaemonProcessGroup(object):
def __init__(self, proc_num, target, args):
self._proc_num = proc_num
self._workers = [
Process(
target=target, args=args) for _ in xrange(self._proc_num)
]
def start_all(self):
for w in self._workers:
w.daemon = True
w.start()
@property
def proc_num(self):
return self._proc_num
class EpochEndSignal(object):
pass
class CriticalException(Exception):
pass
class SharedNDArray(object):
"""SharedNDArray utilizes shared memory to avoid data serialization when
data object shared among different processes. We can reconstruct the
`ndarray` when memory address, shape and dtype provided.
Args:
name (str): Address name of shared memory.
whether_verify (bool): Whether to validate the writing operation.
"""
def __init__(self, name, whether_verify=False):
self._name = name
self._shm = None
self._buf = None
self._array = np.zeros(1, dtype=np.float32)
self._inited = False
self._whether_verify = whether_verify
def zeros_like(self, shape, dtype):
size = int(np.prod(shape)) * np.dtype(dtype).itemsize
if self._inited:
self._shm = posix_ipc.SharedMemory(self._name)
else:
self._shm = posix_ipc.SharedMemory(
self._name, posix_ipc.O_CREAT, size=size)
self._buf = mmap.mmap(self._shm.fd, size)
self._array = np.ndarray(shape, dtype, self._buf, order='C')
def copy(self, ndarray):
size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
self.zeros_like(ndarray.shape, ndarray.dtype)
self._array[:] = ndarray
self._buf.flush()
self._inited = True
if self._whether_verify:
shm = posix_ipc.SharedMemory(self._name)
buf = mmap.mmap(shm.fd, size)
array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
np.testing.assert_array_equal(array, ndarray)
@property
def ndarray(self):
return self._array
def recycle(self, pool):
self._buf.close()
self._shm.close_fd()
self._inited = False
pool[self._name] = self
def __getstate__(self):
return (self._name, self._array.shape, self._array.dtype, self._inited,
self._whether_verify)
def __setstate__(self, state):
self._name = state[0]
self._inited = state[3]
self.zeros_like(state[1], state[2])
self._whether_verify = state[4]
class SharedMemoryPoolManager(object):
"""SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
All available addresses are allocated once and will be reused. Though this
class is not process-safe, the pool can be shared between processes. All
shared memory should be unlinked before the main process exited.
Args:
pool_size (int): Size of shared memory pool.
manager (dict): A multiprocessing.Manager object, the pool is
maintained by the proxy process.
name_prefix (str): Address prefix of shared memory.
"""
def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
self._names = []
self._dict = manager.dict()
self._time_prefix = time.strftime('%Y%m%d%H%M%S')
for i in xrange(pool_size):
name = name_prefix + '_' + self._time_prefix + '_' + str(i)
self._dict[name] = SharedNDArray(name)
self._names.append(name)
@property
def pool(self):
return self._dict
def __del__(self):
for name in self._names:
# have to unlink the shared memory
posix_ipc.unlink_shared_memory(name)
def suppress_signal(signo, stack_frame):
pass
......
......@@ -21,14 +21,15 @@ using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename) {
std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale) {
const char* usage =
"Decode, reading log-likelihoods (of transition-ids or whatever symbol "
"is on the graph) as matrices.";
kaldi::ParseOptions po(usage);
binary = true;
acoustic_scale = 1.5;
this->acoustic_scale = acoustic_scale;
allow_partial = true;
kaldi::FasterDecoderOptions decoder_opts;
decoder_opts.Register(&po, true); // true == include obscure settings.
......
......@@ -29,7 +29,8 @@ class Decoder {
public:
Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename);
std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale);
~Decoder();
// Interface to accept the scores read from specifier and return
......
......@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) {
m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string>())
.def(py::init<std::string, std::string, std::string, kaldi::BaseFloat>())
.def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode,
......
......@@ -8,7 +8,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from data_utils.util import split_infer_result
......@@ -79,12 +79,13 @@ def infer(args):
trans_splice.TransSplice()
]
infer_data_reader = reader.DataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader.set_transformers(ltrans)
feature_t = fluid.LoDTensor()
one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next()
(features, labels, lod) = one_batch
feature_t.set(features, place)
feature_t.set_lod([lod])
......
......@@ -106,6 +106,11 @@ def parse_args():
type=str,
default="./decoder/logprior",
help="The log prior probs for training data. (default: %(default)s)")
parser.add_argument(
'--acoustic_scale',
type=float,
default=0.2,
help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
args = parser.parse_args()
return args
......@@ -143,6 +148,10 @@ def infer_from_ckpt(args):
# load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint)
# init decoder
decoder = Decoder(args.vocabulary, args.graphs, args.log_prior,
args.acoustic_scale)
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
......@@ -162,12 +171,10 @@ def infer_from_ckpt(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
infer_data_reader.recycle(features, labels, lod)
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
results = exe.run(infer_program,
feed={"feature": feature_t,
......@@ -181,7 +188,7 @@ def infer_from_ckpt(args):
infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch):
key = "utter#%d" % (batch_id * args.batch_size + index)
print(key, ": ", decoder.decode(key, sample), "\n")
print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
print(np.mean(infer_costs), np.mean(infer_accs))
......
......@@ -169,14 +169,12 @@ def profile(args):
frames_seen = 0
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
frames_seen += lod.ndarray[-1]
data_reader.recycle(features, labels, lod)
frames_seen += lod[-1]
outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t,
......
......@@ -193,12 +193,10 @@ def train(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
test_data_reader.recycle(features, labels, lod)
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
cost, acc = exe.run(test_program,
feed={"feature": feature_t,
......@@ -212,6 +210,7 @@ def train(args):
# train data reader
train_data_reader = reader.AsyncDataReader(args.train_feature_lst,
args.train_label_lst, -1)
train_data_reader.set_transformers(ltrans)
# train
for pass_id in xrange(args.pass_num):
......@@ -220,13 +219,11 @@ def train(args):
train_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
train_data_reader.recycle(features, labels, lod)
(features, labels, lod, name_lst) = batch_data
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
outs = exe.run(fluid.default_main_program(),
......
import os
import numpy as np
import time
import sys
import paddle.v2 as paddle
import paddle.fluid as fluid
import reader
......@@ -65,20 +68,44 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def SE_ResNeXt(input, class_dim, infer=False):
cardinality = 64
reduction_ratio = 16
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
def SE_ResNeXt(input, class_dim, infer=False, layers=50):
supported_layers = [50, 152]
if layers not in supported_layers:
print("supported layers are", supported_layers, "but input layer is",
layers)
exit()
if layers == 50:
cardinality = 32
reduction_ratio = 16
depth = [3, 4, 6, 3]
num_filters = [128, 256, 512, 1024]
conv = conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
conv = conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 152:
cardinality = 64
reduction_ratio = 16
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
conv = conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
conv = conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
......@@ -104,7 +131,10 @@ def train(learning_rate,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True):
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
......@@ -113,36 +143,52 @@ def train(learning_rate,
if parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
with pd.do():
image_ = pd.read_input(image)
label_ = pd.read_input(label)
out = SE_ResNeXt(input=image_, class_dim=class_dim)
out = SE_ResNeXt(input=image_, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label_)
avg_cost = fluid.layers.mean(x=cost)
accuracy = fluid.layers.accuracy(input=out, label=label_)
acc_top1 = fluid.layers.accuracy(input=out, label=label_, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label_, k=5)
pd.write_output(avg_cost)
pd.write_output(accuracy)
pd.write_output(acc_top1)
pd.write_output(acc_top5)
avg_cost, accuracy = pd()
avg_cost, acc_top1, acc_top5 = pd()
avg_cost = fluid.layers.mean(x=avg_cost)
accuracy = fluid.layers.mean(x=accuracy)
acc_top1 = fluid.layers.mean(x=acc_top1)
acc_top5 = fluid.layers.mean(x=acc_top5)
else:
out = SE_ResNeXt(input=image, class_dim=class_dim)
out = SE_ResNeXt(input=image, class_dim=class_dim, layers=layers)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
accuracy = fluid.layers.accuracy(input=out, label=label)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
else:
bd = lr_strategy["bd"]
lr = lr_strategy["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opts = optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program([avg_cost, accuracy])
inference_program = fluid.io.get_inference_program(
[avg_cost, acc_top1, acc_top5])
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......@@ -156,34 +202,86 @@ def train(learning_rate,
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
for pass_id in range(num_passes):
train_info = [[], [], []]
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
loss = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
print("Pass {0}, batch {1}, loss {2}".format(pass_id, batch_id,
float(loss[0])))
total_loss = 0.0
total_acc = 0.0
total_batch = 0
t1 = time.time()
loss, acc1, acc5 = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
train_info[0].append(loss[0])
train_info[1].append(acc1[0])
train_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, loss {2}, \
acc1 {3}, acc5 {4} time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
train_loss = np.array(train_info[0]).mean()
train_acc1 = np.array(train_info[1]).mean()
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
loss, acc = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, accuracy])
total_loss += float(loss)
total_acc += float(acc)
total_batch += 1
print("End pass {0}, test_loss {1}, test_acc {2}".format(
pass_id, total_loss / total_batch, total_acc / total_batch))
t1 = time.time()
loss, acc1, acc5 = exe.run(
inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc_top1, acc_top5])
t2 = time.time()
period = t2 - t1
test_info[0].append(loss[0])
test_info[1].append(acc1[0])
test_info[2].append(acc5[0])
if batch_id % 10 == 0:
print("Pass {0},testbatch {1},loss {2}, \
acc1 {3},acc5 {4},time {5}"
.format(pass_id, \
batch_id, loss[0], acc1[0], acc5[0], \
"%2.2f sec" % period))
sys.stdout.flush()
test_loss = np.array(test_info[0]).mean()
test_acc1 = np.array(test_info[1]).mean()
test_acc5 = np.array(test_info[2]).mean()
print("End pass {0}, train_loss {1}, train_acc1 {2}, train_acc5 {3}, \
test_loss {4}, test_acc1 {5}, test_acc5 {6}"
.format(pass_id, \
train_loss, train_acc1, train_acc5, test_loss, test_acc1, \
test_acc5))
sys.stdout.flush()
model_path = os.path.join(model_save_dir, str(pass_id))
fluid.io.save_inference_model(model_path, ['image'], [out], exe)
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
epoch_points = [30, 60, 90]
total_images = 1281167
batch_size = 256
step = int(total_images / batch_size + 1)
bd = [e * step for e in epoch_points]
lr = [0.1, 0.01, 0.001, 0.0001]
lr_strategy = {"bd": bd, "lr": lr}
use_nccl = True
# layers: 50, 152
layers = 50
train(
learning_rate=0.1,
batch_size=8,
num_passes=100,
batch_size=batch_size,
num_passes=120,
init_model=None,
parallel=False)
parallel=True,
use_nccl=True,
lr_strategy=lr_strategy,
layers=layers)
......@@ -92,7 +92,9 @@ pos_enc_param_names = (
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
"src_slf_attn_bias",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
......@@ -100,6 +102,10 @@ decoder_input_data_names = (
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", )
# Names of label related data layers listed in order.
......
......@@ -27,7 +27,14 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
is_target=False,
return_pos=True,
return_attn_bias=True,
return_max_len=True)
return_max_len=False)
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
......@@ -35,8 +42,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[]] * batch_size
next_ids = [[]] * batch_size
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_map to map the instance idx in batch to beam idx, since the
# size of feeded batch is changing.
beam_map = range(batch_size)
......@@ -64,8 +71,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-1], enc_in_data[-2], 1
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
......@@ -77,15 +84,33 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(enc_output, [beam_size, 1, 1])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output = dec_in_data
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
......@@ -112,8 +137,23 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
......
......@@ -32,7 +32,9 @@ def multi_head_attention(queries,
d_value,
d_model,
n_head=1,
dropout_rate=0.):
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
......@@ -111,26 +113,16 @@ def multi_head_attention(queries,
"""
Scaled Dot-Product Attention
"""
# FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
# The current implementation of softmax_op only supports 2D tensor,
# consequently it cannot be directly used here.
# If to use the reshape_op, Besides, the shape of product inferred in
# compile-time is not the actual shape in run-time. It cann't be used
# to set the attribute of reshape_op.
# So, here define the softmax for temporary solution.
def __softmax(x, eps=1e-9):
exp_out = layers.exp(x=x)
sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = __softmax(
layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product)
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
shape=[-1, product.shape[-1]],
actual_shape=pre_softmax_shape,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
......@@ -177,7 +169,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
......@@ -195,8 +187,9 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout:
out = layers.dropout(out, dropout_prob=dropout, is_test=False)
if dropout_rate:
out = layers.dropout(
out, dropout_prob=dropout_rate, is_test=False)
return out
......@@ -210,7 +203,7 @@ def prepare_encoder(src_word,
src_emb_dim,
src_pad_idx,
src_max_len,
dropout=0.,
dropout_rate=0.,
pos_pad_idx=0,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
......@@ -235,8 +228,8 @@ def prepare_encoder(src_word,
# FIXME(guosheng): Decouple the program desc with batch_size.
enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
return layers.dropout(
enc_input, dropout_prob=dropout,
is_test=False) if dropout else enc_input
enc_input, dropout_prob=dropout_rate,
is_test=False) if dropout_rate else enc_input
prepare_encoder = partial(
......@@ -252,7 +245,9 @@ def encoder_layer(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
......@@ -260,9 +255,9 @@ def encoder_layer(enc_input,
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(enc_input, enc_input, enc_input,
attn_bias, d_key, d_value, d_model,
n_head, dropout_rate)
attn_output = multi_head_attention(
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
......@@ -277,7 +272,9 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
......@@ -291,7 +288,9 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dropout_rate,
pre_softmax_shape,
post_softmax_shape, )
enc_input = enc_output
return enc_output
......@@ -305,7 +304,11 @@ def decoder_layer(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
......@@ -320,7 +323,9 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
dropout_rate, )
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
......@@ -335,7 +340,9 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
dropout_rate, )
dropout_rate,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
......@@ -363,7 +370,11 @@ def decoder(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
......@@ -378,7 +389,11 @@ def decoder(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
dec_input = dec_output
return dec_output
......@@ -391,7 +406,9 @@ def make_inputs(input_data_names,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False):
enc_output_flag=False,
slf_attn_shape_flag=True,
src_attn_shape_flag=True):
"""
Define the input data layers for the transformer model.
"""
......@@ -429,6 +446,32 @@ def make_inputs(input_data_names,
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if slf_attn_shape_flag:
slf_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_pre_softmax_shape]
slf_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_post_softmax_shape]
if enc_output_flag:
enc_output = layers.data(
name=input_data_names[len(input_layers)],
......@@ -436,6 +479,7 @@ def make_inputs(input_data_names,
dtype="float32",
append_batch_size=False)
input_layers += [enc_output]
return input_layers
......@@ -453,8 +497,18 @@ def transformer(
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, False)
enc_input_layers = make_inputs(
encoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
enc_output = wrap_encoder(
src_vocab_size,
......@@ -470,8 +524,18 @@ def transformer(
pos_pad_idx,
enc_input_layers, )
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, True)
dec_input_layers = make_inputs(
decoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=False,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
predict = wrap_decoder(
trg_vocab_size,
......@@ -490,9 +554,19 @@ def transformer(
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
max_length, False, False, False)
cost = layers.cross_entropy(input=predict, label=gold)
gold, weights = make_inputs(
label_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=False,
slf_attn_bias_flag=False,
src_attn_bias_flag=False,
enc_output_flag=False,
slf_attn_shape_flag=False,
src_attn_shape_flag=False)
cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost), predict
......@@ -514,11 +588,22 @@ def wrap_encoder(src_vocab_size,
"""
if enc_input_layers is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_inputs(
encoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, False)
src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape = make_inputs(
encoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
else:
src_word, src_pos, src_slf_attn_bias = enc_input_layers
src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape = enc_input_layers
enc_input = prepare_encoder(
src_word,
src_pos,
......@@ -536,7 +621,9 @@ def wrap_encoder(src_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
return enc_output
......@@ -558,11 +645,26 @@ def wrap_decoder(trg_vocab_size,
"""
if dec_input_layers is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
decoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, True, True)
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
enc_output = make_inputs(
decoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
dec_input_layers
dec_input = prepare_decoder(
trg_word,
......@@ -583,13 +685,17 @@ def wrap_decoder(trg_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
# Return logits for training and probs for inference.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax")
act="softmax" if dec_input_layers is None else None)
return predict
......@@ -66,13 +66,29 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
src_slf_attn_bias.shape, dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
src_word, src_pos, src_slf_attn_bias,
src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
lbl_word, lbl_weight
]))
return input_dict
......
......@@ -216,7 +216,7 @@ def distort_image(img, settings):
def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1)
if prob < settings._hue_prob:
if prob < settings._expand_prob:
expand_ratio = random.uniform(1, settings._expand_max_ratio)
if expand_ratio - 1 >= 0.01:
height = int(img_height * expand_ratio)
......
......@@ -12,8 +12,9 @@ import functools
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
# yapf: disable
......@@ -47,26 +48,24 @@ def train(args,
locs, confs, box, box_var = mobile_net(image_, image_shape)
loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_,
box, box_var)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.reduce_sum(loss)
pd.write_output(loss)
pd.write_output(locs)
pd.write_output(confs)
pd.write_output(box)
pd.write_output(box_var)
pd.write_output(nmsed_out)
loss, locs, confs, box, box_var = pd()
loss = fluid.layers.reduce_sum(loss)
loss, nmsed_out = pd()
loss = fluid.layers.mean(loss)
else:
locs, confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, mbox_confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, mbox_confs, gt_box, gt_label,
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label,
box, box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
......@@ -98,7 +97,6 @@ def train(args,
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
#print 'test_program ', test_program
def test(pass_id):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
......@@ -109,7 +107,6 @@ def train(args,
fetch_list=[accum_map])
print("Test {0}, map {1}".format(pass_id, test_map[0]))
#print 'main_program ', fluid.default_main_program()
for pass_id in range(num_passes):
for batch_id, data in enumerate(train_reader()):
loss_v = exe.run(fluid.default_main_program(),
......@@ -143,5 +140,5 @@ if __name__ == '__main__':
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
batch_size=32,
batch_size=args.batch_size,
num_passes=300)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册