提交 606cb22e 编写于 作者: Z zhxfl

Merge remote-tracking branch 'upstream/develop' into fix-661

...@@ -15,13 +15,12 @@ from multiprocessing import Manager, Process ...@@ -15,13 +15,12 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager from data_utils.util import CriticalException, ForceExitWrapper
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
class SampleInfo(object): class SampleInfo(object):
"""SampleInfo holds the necessary information to load a sample from disk. """SampleInfo holds the necessary information to load a sample from disk.
Args: Args:
feature_bin_path (str): File containing the feature data. feature_bin_path (str): File containing the feature data.
feature_start (int): Start position of the sample's feature data. feature_start (int): Start position of the sample's feature data.
...@@ -54,6 +53,7 @@ class SampleInfoBucket(object): ...@@ -54,6 +53,7 @@ class SampleInfoBucket(object):
data, sample start position, sample byte number etc.) to access samples' data, sample start position, sample byte number etc.) to access samples'
feature data and the same with the label description file. SampleInfoBucket feature data and the same with the label description file. SampleInfoBucket
is the minimum unit to do shuffle. is the minimum unit to do shuffle.
Args: Args:
feature_bin_paths (list|tuple): Files containing the binary feature feature_bin_paths (list|tuple): Files containing the binary feature
data. data.
...@@ -67,8 +67,8 @@ class SampleInfoBucket(object): ...@@ -67,8 +67,8 @@ class SampleInfoBucket(object):
split_sentence_threshold(int): Sentence whose length larger than split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation. the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \ (split_sub_sentence_len
rand() % split_perturb). + rand() % split_perturb).
""" """
def __init__(self, def __init__(self,
...@@ -160,9 +160,14 @@ class SampleInfoBucket(object): ...@@ -160,9 +160,14 @@ class SampleInfoBucket(object):
return sample_info_list return sample_info_list
class EpochEndSignal():
pass
class AsyncDataReader(object): class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including """DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation. data loading and data augmentation.
Args: Args:
feature_file_list (str): File containing paths of feature data file and feature_file_list (str): File containing paths of feature data file and
corresponding description file. corresponding description file.
...@@ -206,17 +211,12 @@ class AsyncDataReader(object): ...@@ -206,17 +211,12 @@ class AsyncDataReader(object):
self.generate_bucket_list(True) self.generate_bucket_list(True)
self._order_id = 0 self._order_id = 0
self._manager = Manager() self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False)) self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
# buffer queue
self._sample_info_queue = self._manager.Queue(sample_info_buffer_size)
self._sample_queue = self._manager.Queue(sample_buffer_size)
self._batch_queue = self._manager.Queue(batch_buffer_size)
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
...@@ -250,21 +250,13 @@ class AsyncDataReader(object): ...@@ -250,21 +250,13 @@ class AsyncDataReader(object):
def set_transformers(self, transformers): def set_transformers(self, transformers):
self._transformers = transformers self._transformers = transformers
def recycle(self, *args): def _sample_generator(self):
for shared_ndarray in args: sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
if not isinstance(shared_ndarray, SharedNDArray): sample_queue = self._manager.Queue(self._sample_buffer_size)
raise Value("Only support recycle SharedNDArray object.")
shared_ndarray.recycle(self._pool_manager.pool)
def _start_async_processing(self):
self._order_id = 0 self._order_id = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue): def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
for sample_info_bucket in self._bucket_list: for sample_info_bucket in self._bucket_list:
try: try:
sample_info_list = \ sample_info_list = \
...@@ -277,14 +269,13 @@ class AsyncDataReader(object): ...@@ -277,14 +269,13 @@ class AsyncDataReader(object):
sample_info_queue.put((sample_info, self._order_id)) sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1 self._order_id += 1
for i in xrange(self._sample_proc_num): for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal()) sample_info_queue.put(EpochEndSignal())
feeding_proc = DaemonProcessGroup( feeding_thread = Thread(
proc_num=1, target=ordered_feeding_task, args=(sample_info_queue, ))
target=ordered_feeding_task, feeding_thread.daemon = True
args=(self._sample_info_queue, )) feeding_thread.start()
feeding_proc.start_all()
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order): def ordered_processing_task(sample_info_queue, sample_queue, out_order):
...@@ -312,11 +303,12 @@ class AsyncDataReader(object): ...@@ -312,11 +303,12 @@ class AsyncDataReader(object):
sample_info.feature_size) sample_info.feature_size)
assert sample_info.feature_frame_num \ assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \ * sample_info.feature_dim * 4 \
(sample_info.feature_bin_path, == len(feature_bytes), \
sample_info.feature_frame_num, (sample_info.feature_bin_path,
sample_info.feature_dim, sample_info.feature_frame_num,
len(feature_bytes)) sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path, label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start, sample_info.label_start,
...@@ -360,83 +352,83 @@ class AsyncDataReader(object): ...@@ -360,83 +352,83 @@ class AsyncDataReader(object):
sample_queue.put(EpochEndSignal()) sample_queue.put(EpochEndSignal())
out_order = self._manager.list([0]) out_order = self._manager.list([0])
args = (self._sample_info_queue, self._sample_queue, out_order) args = (sample_info_queue, sample_queue, out_order)
sample_proc = DaemonProcessGroup( workers = [
proc_num=self._sample_proc_num, Process(
target=ordered_processing_task, target=ordered_processing_task, args=args)
args=args) for _ in xrange(self._proc_num)
sample_proc.start_all() ]
def batch_iterator(self, batch_size, minimum_batch_size): for w in workers:
@suppress_complaints(verbose=self._verbose, notify=self._force_exit) w.daemon = True
def batch_assembling_task(sample_queue, batch_queue, pool): w.start()
def conv_to_shared(ndarray):
while self._force_exit == False:
try:
(name, shared_ndarray) = pool.popitem()
except Exception as e:
time.sleep(0.001)
else:
shared_ndarray.copy(ndarray)
return shared_ndarray
if self._verbose == 0: finished_proc_num = 0
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
batch_samples = [] while self._force_exit == False:
lod = [0] try:
done_num = 0 sample = sample_queue.get_nowait()
while done_num < self._sample_proc_num: except Queue.Empty:
sample = sample_queue.get() time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal): if isinstance(sample, EpochEndSignal):
done_num += 1 finished_proc_num += 1
else: if finished_proc_num >= self._proc_num:
batch_samples.append(sample) break
lod.append(lod[-1] + sample[0].shape[0]) else:
if len(batch_samples) == batch_size: continue
feature, label = batch_to_ndarray(batch_samples, lod)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod)) yield sample
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size: def batch_iterator(self, batch_size, minimum_batch_size):
(feature, label) = batch_to_ndarray(batch_samples, lod) def batch_to_ndarray(batch_samples, lod):
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
feature = conv_to_shared(feature) @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
label = conv_to_shared(label) def batch_assembling_task(sample_generator, batch_queue):
lod = conv_to_shared(np.array(lod).astype('int64')) batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
(batch_feature, batch_label) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]
batch_queue.put((feature, label, lod)) if len(batch_samples) >= minimum_batch_size:
(batch_feature, batch_label) = batch_to_ndarray(batch_samples,
lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_queue.put(EpochEndSignal()) batch_queue.put(EpochEndSignal())
self._start_async_processing() batch_queue = Queue.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size * assembling_thread = Thread(
3, self._manager)
assembling_proc = DaemonProcessGroup(
proc_num=1,
target=batch_assembling_task, target=batch_assembling_task,
args=(self._sample_queue, self._batch_queue, args=(self._sample_generator, batch_queue))
self._pool_manager.pool)) assembling_thread.daemon = True
assembling_proc.start_all() assembling_thread.start()
while self._force_exit == False: while self._force_exit == False:
try: try:
batch_data = self._batch_queue.get_nowait() batch_data = batch_queue.get_nowait()
except Queue.Empty: except Queue.Empty:
time.sleep(0.001) time.sleep(0.001)
else: else:
if isinstance(batch_data, EpochEndSignal): if isinstance(batch_data, EpochEndSignal):
break break
yield batch_data yield batch_data
# clean the shared memory
del self._pool_manager
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys, time import sys
from six import reraise from six import reraise
from tblib import Traceback from tblib import Traceback
from multiprocessing import Manager, Process
import posix_ipc, mmap
import numpy as np import numpy as np
...@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor): ...@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod() return ret, lod_tensor.lod()
def batch_to_ndarray(batch_samples, lod):
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
def split_infer_result(infer_seq, lod): def split_infer_result(infer_seq, lod):
infer_batch = [] infer_batch = []
for i in xrange(0, len(lod[0]) - 1): for i in xrange(0, len(lod[0]) - 1):
...@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod): ...@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod):
return infer_batch return infer_batch
class DaemonProcessGroup(object):
def __init__(self, proc_num, target, args):
self._proc_num = proc_num
self._workers = [
Process(
target=target, args=args) for _ in xrange(self._proc_num)
]
def start_all(self):
for w in self._workers:
w.daemon = True
w.start()
@property
def proc_num(self):
return self._proc_num
class EpochEndSignal(object):
pass
class CriticalException(Exception): class CriticalException(Exception):
pass pass
class SharedNDArray(object):
"""SharedNDArray utilizes shared memory to avoid data serialization when
data object shared among different processes. We can reconstruct the
`ndarray` when memory address, shape and dtype provided.
Args:
name (str): Address name of shared memory.
whether_verify (bool): Whether to validate the writing operation.
"""
def __init__(self, name, whether_verify=False):
self._name = name
self._shm = None
self._buf = None
self._array = np.zeros(1, dtype=np.float32)
self._inited = False
self._whether_verify = whether_verify
def zeros_like(self, shape, dtype):
size = int(np.prod(shape)) * np.dtype(dtype).itemsize
if self._inited:
self._shm = posix_ipc.SharedMemory(self._name)
else:
self._shm = posix_ipc.SharedMemory(
self._name, posix_ipc.O_CREAT, size=size)
self._buf = mmap.mmap(self._shm.fd, size)
self._array = np.ndarray(shape, dtype, self._buf, order='C')
def copy(self, ndarray):
size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
self.zeros_like(ndarray.shape, ndarray.dtype)
self._array[:] = ndarray
self._buf.flush()
self._inited = True
if self._whether_verify:
shm = posix_ipc.SharedMemory(self._name)
buf = mmap.mmap(shm.fd, size)
array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
np.testing.assert_array_equal(array, ndarray)
@property
def ndarray(self):
return self._array
def recycle(self, pool):
self._buf.close()
self._shm.close_fd()
self._inited = False
pool[self._name] = self
def __getstate__(self):
return (self._name, self._array.shape, self._array.dtype, self._inited,
self._whether_verify)
def __setstate__(self, state):
self._name = state[0]
self._inited = state[3]
self.zeros_like(state[1], state[2])
self._whether_verify = state[4]
class SharedMemoryPoolManager(object):
"""SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
All available addresses are allocated once and will be reused. Though this
class is not process-safe, the pool can be shared between processes. All
shared memory should be unlinked before the main process exited.
Args:
pool_size (int): Size of shared memory pool.
manager (dict): A multiprocessing.Manager object, the pool is
maintained by the proxy process.
name_prefix (str): Address prefix of shared memory.
"""
def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
self._names = []
self._dict = manager.dict()
self._time_prefix = time.strftime('%Y%m%d%H%M%S')
for i in xrange(pool_size):
name = name_prefix + '_' + self._time_prefix + '_' + str(i)
self._dict[name] = SharedNDArray(name)
self._names.append(name)
@property
def pool(self):
return self._dict
def __del__(self):
for name in self._names:
# have to unlink the shared memory
posix_ipc.unlink_shared_memory(name)
def suppress_signal(signo, stack_frame): def suppress_signal(signo, stack_frame):
pass pass
......
...@@ -21,14 +21,15 @@ using fst::StdArc; ...@@ -21,14 +21,15 @@ using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename, Decoder::Decoder(std::string word_syms_filename,
std::string fst_in_filename, std::string fst_in_filename,
std::string logprior_rxfilename) { std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale) {
const char* usage = const char* usage =
"Decode, reading log-likelihoods (of transition-ids or whatever symbol " "Decode, reading log-likelihoods (of transition-ids or whatever symbol "
"is on the graph) as matrices."; "is on the graph) as matrices.";
kaldi::ParseOptions po(usage); kaldi::ParseOptions po(usage);
binary = true; binary = true;
acoustic_scale = 1.5; this->acoustic_scale = acoustic_scale;
allow_partial = true; allow_partial = true;
kaldi::FasterDecoderOptions decoder_opts; kaldi::FasterDecoderOptions decoder_opts;
decoder_opts.Register(&po, true); // true == include obscure settings. decoder_opts.Register(&po, true); // true == include obscure settings.
......
...@@ -29,7 +29,8 @@ class Decoder { ...@@ -29,7 +29,8 @@ class Decoder {
public: public:
Decoder(std::string word_syms_filename, Decoder(std::string word_syms_filename,
std::string fst_in_filename, std::string fst_in_filename,
std::string logprior_rxfilename); std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale);
~Decoder(); ~Decoder();
// Interface to accept the scores read from specifier and return // Interface to accept the scores read from specifier and return
......
...@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) { ...@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) {
m.doc() = "Decoder for Deep ASR model"; m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder") py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string>()) .def(py::init<std::string, std::string, std::string, kaldi::BaseFloat>())
.def("decode", .def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) & (std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode, Decoder::decode,
......
...@@ -8,7 +8,7 @@ import paddle.fluid as fluid ...@@ -8,7 +8,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray from data_utils.util import lodtensor_to_ndarray
from data_utils.util import split_infer_result from data_utils.util import split_infer_result
...@@ -79,12 +79,13 @@ def infer(args): ...@@ -79,12 +79,13 @@ def infer(args):
trans_splice.TransSplice() trans_splice.TransSplice()
] ]
infer_data_reader = reader.DataReader(args.infer_feature_lst, infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst) args.infer_label_lst)
infer_data_reader.set_transformers(ltrans) infer_data_reader.set_transformers(ltrans)
feature_t = fluid.LoDTensor() feature_t = fluid.LoDTensor()
one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next() one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next()
(features, labels, lod) = one_batch (features, labels, lod) = one_batch
feature_t.set(features, place) feature_t.set(features, place)
feature_t.set_lod([lod]) feature_t.set_lod([lod])
......
...@@ -106,6 +106,11 @@ def parse_args(): ...@@ -106,6 +106,11 @@ def parse_args():
type=str, type=str,
default="./decoder/logprior", default="./decoder/logprior",
help="The log prior probs for training data. (default: %(default)s)") help="The log prior probs for training data. (default: %(default)s)")
parser.add_argument(
'--acoustic_scale',
type=float,
default=0.2,
help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -143,6 +148,10 @@ def infer_from_ckpt(args): ...@@ -143,6 +148,10 @@ def infer_from_ckpt(args):
# load checkpoint. # load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint) fluid.io.load_persistables(exe, args.checkpoint)
# init decoder
decoder = Decoder(args.vocabulary, args.graphs, args.log_prior,
args.acoustic_scale)
ltrans = [ ltrans = [
trans_add_delta.TransAddDelta(2, 2), trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var), trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
...@@ -162,12 +171,10 @@ def infer_from_ckpt(args): ...@@ -162,12 +171,10 @@ def infer_from_ckpt(args):
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
infer_data_reader.recycle(features, labels, lod)
results = exe.run(infer_program, results = exe.run(infer_program,
feed={"feature": feature_t, feed={"feature": feature_t,
...@@ -181,7 +188,7 @@ def infer_from_ckpt(args): ...@@ -181,7 +188,7 @@ def infer_from_ckpt(args):
infer_batch = split_infer_result(probs, lod) infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch): for index, sample in enumerate(infer_batch):
key = "utter#%d" % (batch_id * args.batch_size + index) key = "utter#%d" % (batch_id * args.batch_size + index)
print(key, ": ", decoder.decode(key, sample), "\n") print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
print(np.mean(infer_costs), np.mean(infer_accs)) print(np.mean(infer_costs), np.mean(infer_accs))
......
...@@ -169,14 +169,12 @@ def profile(args): ...@@ -169,14 +169,12 @@ def profile(args):
frames_seen = 0 frames_seen = 0
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
frames_seen += lod.ndarray[-1] frames_seen += lod[-1]
data_reader.recycle(features, labels, lod)
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t, feed={"feature": feature_t,
......
...@@ -193,12 +193,10 @@ def train(args): ...@@ -193,12 +193,10 @@ def train(args):
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
test_data_reader.recycle(features, labels, lod)
cost, acc = exe.run(test_program, cost, acc = exe.run(test_program,
feed={"feature": feature_t, feed={"feature": feature_t,
...@@ -221,12 +219,10 @@ def train(args): ...@@ -221,12 +219,10 @@ def train(args):
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod])
label_t.set(labels.ndarray, place) label_t.set(labels, place)
label_t.set_lod([lod.ndarray]) label_t.set_lod([lod])
train_data_reader.recycle(features, labels, lod)
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0) to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册