提交 c765794e 编写于 作者: M Michał Gallus 提交者: Tao Luo

Add MKL-DNN Benchmarking to CRNN-CTC (#1046)

* Add MKL-DNN Benchmarking to CRNN-CTC

* Make crnn-ctc scripts more portable

* Add documentation for cycle to crnn-ctc-reader

* Update crnn-ctc readme for CPU execution

* Merge CRNN-CTC train & inference scripts

* Fix mnist model & ce, kaffe graph yapf issues

* Remove LD_LIBRARY_PATH from crnn-ctc scripts

* CRNN-CTC scripts: set parallel to true

Abort script if batch_size is lower than num of cores

* CRNN-CTC scripts: limit mode options in infer

* CRNN-CTC scripts: set mkldnn parallel to False
上级 53937db0
......@@ -122,8 +122,8 @@ class Graph(object):
def compute_output_shapes(self):
sorted_nodes = self.topologically_sorted()
for node in sorted_nodes:
node.output_shape = make_tensor(*NodeKind.compute_output_shape(
node))
node.output_shape = make_tensor(
*NodeKind.compute_output_shape(node))
def replaced(self, new_nodes):
return Graph(nodes=new_nodes, name=self.name, trace=self.output_trace)
......
......@@ -19,6 +19,7 @@ tracking_kpis = [
train_duration_kpi,
]
def parse_log(log):
'''
This method should be implemented by model developers.
......@@ -37,7 +38,7 @@ def parse_log(log):
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print (fs)
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
......@@ -50,12 +51,11 @@ def log_to_ce(log):
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print (kpi_name, kpi_value)
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
log_to_ce(log)
......@@ -183,10 +183,10 @@ def run_benchmark(model, args):
(pass_end - pass_start)))
#Note: The following logs are special for CE monitoring.
#Other situations do not need to care about these logs.
print ("kpis train_acc %f" % train_avg_acc)
print ("kpis train_cost %f" % train_avg_loss)
print ("kpis test_acc %f" % test_avg_acc)
print ("kpis train_duration %f" % (pass_end - pass_start))
print("kpis train_acc %f" % train_avg_acc)
print("kpis train_cost %f" % train_avg_loss)
print("kpis test_acc %f" % test_avg_acc)
print("kpis train_duration %f" % (pass_end - pass_start))
if __name__ == '__main__':
......
......@@ -113,6 +113,10 @@ data/test_images/00003.jpg
```
env CUDA_VISIABLE_DEVICES=0 python ctc_train.py
```
使用默认数据在CPU上训练:
```
env OMP_NUM_THREADS=<num_of_physical_cores> python ctc_train.py --use_gpu False --parallel=False
```
使用默认数据在GPU多卡上训练:
......
......@@ -12,7 +12,8 @@ def conv_bn_pool(input,
bias=None,
param_0=None,
is_test=False,
pooling=True):
pooling=True,
use_cudnn=False):
tmp = input
for i in xrange(group):
tmp = fluid.layers.conv2d(
......@@ -22,7 +23,7 @@ def conv_bn_pool(input,
padding=1,
param_attr=param if param_0 is None else param_0,
act=None, # LinearActivation
use_cudnn=True)
use_cudnn=use_cudnn)
tmp = fluid.layers.batch_norm(
input=tmp,
act=act,
......@@ -35,13 +36,17 @@ def conv_bn_pool(input,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
use_cudnn=use_cudnn,
ceil_mode=True)
return tmp
def ocr_convs(input, regularizer=None, gradient_clip=None, is_test=False):
def ocr_convs(input,
regularizer=None,
gradient_clip=None,
is_test=False,
use_cudnn=False):
b = fluid.ParamAttr(
regularizer=regularizer,
gradient_clip=gradient_clip,
......@@ -56,12 +61,36 @@ def ocr_convs(input, regularizer=None, gradient_clip=None, is_test=False):
initializer=fluid.initializer.Normal(0.0, 0.01))
tmp = input
tmp = conv_bn_pool(
tmp, 2, [16, 16], param=w1, bias=b, param_0=w0, is_test=is_test)
tmp,
2, [16, 16],
param=w1,
bias=b,
param_0=w0,
is_test=is_test,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(
tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test, pooling=False)
tmp,
2, [32, 32],
param=w1,
bias=b,
is_test=is_test,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(
tmp,
2, [64, 64],
param=w1,
bias=b,
is_test=is_test,
use_cudnn=use_cudnn)
tmp = conv_bn_pool(
tmp,
2, [128, 128],
param=w1,
bias=b,
is_test=is_test,
pooling=False,
use_cudnn=use_cudnn)
return tmp
......@@ -70,12 +99,14 @@ def encoder_net(images,
rnn_hidden_size=200,
regularizer=None,
gradient_clip=None,
is_test=False):
is_test=False,
use_cudnn=False):
conv_features = ocr_convs(
images,
regularizer=regularizer,
gradient_clip=gradient_clip,
is_test=is_test)
is_test=is_test,
use_cudnn=use_cudnn)
sliced_feature = fluid.layers.im2sequence(
input=conv_features,
stride=[1, 1],
......@@ -142,7 +173,11 @@ def ctc_train_net(images, label, args, num_classes):
learning_rate_decay = None
regularizer = fluid.regularizer.L2Decay(L2_RATE)
fc_out = encoder_net(images, num_classes, regularizer=regularizer)
fc_out = encoder_net(
images,
num_classes,
regularizer=regularizer,
use_cudnn=True if args.use_gpu else False)
cost = fluid.layers.warpctc(
input=fc_out, label=label, blank=num_classes, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
......@@ -166,19 +201,18 @@ def ctc_train_net(images, label, args, num_classes):
if args.average_window > 0:
model_average = fluid.optimizer.ModelAverage(
args.average_window,
params_grads,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
return sum_cost, error_evaluator, inference_program, model_average
def ctc_infer(images, num_classes):
fc_out = encoder_net(images, num_classes, is_test=True)
def ctc_infer(images, num_classes, use_cudnn):
fc_out = encoder_net(images, num_classes, is_test=True, use_cudnn=use_cudnn)
return fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
def ctc_eval(images, label, num_classes):
fc_out = encoder_net(images, num_classes, is_test=True)
def ctc_eval(images, label, num_classes, use_cudnn):
fc_out = encoder_net(images, num_classes, is_test=True, use_cudnn=use_cudnn)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
......
......@@ -25,7 +25,7 @@ class DataGenerator(object):
def __init__(self):
pass
def train_reader(self, img_root_dir, img_label_list, batchsize):
def train_reader(self, img_root_dir, img_label_list, batchsize, cycle):
'''
Reader interface for training.
......@@ -35,6 +35,10 @@ class DataGenerator(object):
:param img_label_list: The path of the <image_name, label> file for training.
:type img_label_list: str
:param cycle: If number of iterations is greater than dataset_size / batch_size
it reiterates dataset over as many times as necessary.
:type cycle: bool
'''
img_label_lines = []
......@@ -65,24 +69,29 @@ class DataGenerator(object):
def reader():
sizes = len(img_label_lines) / batchsize
for i in range(sizes):
result = []
sz = [0, 0]
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
yield result
if sizes == 0:
raise ValueError('Batch size is bigger than the dataset size.')
while True:
for i in range(sizes):
result = []
sz = [0, 0]
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
yield result
if not cycle:
break
return reader
......@@ -111,7 +120,7 @@ class DataGenerator(object):
return reader
def infer_reader(self, img_root_dir=None, img_label_list=None):
def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
'''A reader interface for inference.
:param img_root_dir: The root path of the images for training.
......@@ -122,11 +131,15 @@ class DataGenerator(object):
was None. If img_label_list was set to None, it will read image path
from stdin.
:type img_root_dir: str
:param cycle: If number of iterations is greater than dataset_size /
batch_size it reiterates dataset over as many times as necessary.
:type cycle: bool
'''
def reader():
if img_label_list is not None:
for line in open(img_label_list):
def yield_img_and_label(lines):
for line in lines:
if img_root_dir is not None:
# h, w, img_name, labels
img_name = line.split(' ')[2]
......@@ -138,6 +151,16 @@ class DataGenerator(object):
img = img[np.newaxis, ...]
label = [int(c) for c in line.split(' ')[3].split(',')]
yield img, label
if img_label_list is not None:
lines = []
with open(img_label_list) as f:
lines = f.readlines()
for img, label in yield_img_and_label(lines):
yield img, label
while cycle:
for img, label in yield_img_and_label(lines):
yield img, label
else:
while True:
img_path = raw_input("Please input the path of image: ")
......@@ -161,14 +184,15 @@ def data_shape():
return DATA_SHAPE
def train(batch_size, train_images_dir=None, train_list_file=None):
def train(batch_size, train_images_dir=None, train_list_file=None, cycle=False):
generator = DataGenerator()
if train_images_dir is None:
data_dir = download_data()
train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
if train_list_file is None:
train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
return generator.train_reader(train_images_dir, train_list_file, batch_size)
return generator.train_reader(train_images_dir, train_list_file, batch_size,
cycle)
def test(batch_size=1, test_images_dir=None, test_list_file=None):
......@@ -182,10 +206,14 @@ def test(batch_size=1, test_images_dir=None, test_list_file=None):
generator.test_reader(test_images_dir, test_list_file), batch_size)
def inference(infer_images_dir=None, infer_list_file=None):
def inference(batch_size=1,
infer_images_dir=None,
infer_list_file=None,
cycle=False):
generator = DataGenerator()
return paddle.batch(
generator.infer_reader(infer_images_dir, infer_list_file), 1)
generator.infer_reader(infer_images_dir, infer_list_file, cycle),
batch_size)
def download_data():
......
"""Trainer for OCR CTC model."""
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
import ctc_reader
......@@ -14,7 +15,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('total_step', int, 720000, "Number of training iterations.")
add_arg('total_step', int, 720000, "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
......@@ -25,6 +26,9 @@ add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test', bool, False, "Whether to skip test phase.")
# yapf: enable
......@@ -49,7 +53,8 @@ def train(args, data_reader=ctc_reader):
train_reader = data_reader.train(
args.batch_size,
train_images_dir=train_images,
train_list_file=train_list)
train_list_file=train_list,
cycle=args.total_step > 0)
test_reader = data_reader.test(
test_images_dir=test_images, test_list_file=test_list)
......@@ -74,7 +79,7 @@ def train(args, data_reader=ctc_reader):
error_evaluator.reset(exe)
if args.parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=sum_cost.name)
use_cuda=True if args.use_gpu else False, loss_name=sum_cost.name)
fetch_vars = [sum_cost] + error_evaluator.metrics
......@@ -85,8 +90,8 @@ def train(args, data_reader=ctc_reader):
feed=get_feeder_data(data, place))
results = [np.array(result).sum() for result in results]
else:
results = exe.run(feed=get_feeder_data(data, place),
fetch_list=fetch_vars)
results = train_exe.run(feed=get_feeder_data(data, place),
fetch_list=fetch_vars)
results = [result[0] for result in results]
return results
......@@ -105,17 +110,29 @@ def train(args, data_reader=ctc_reader):
print "Saved model to: %s/%s." % (args.save_model_dir, filename)
iter_num = 0
while True:
stop = False
while not stop:
total_loss = 0.0
total_seq_error = 0.0
batch_times = []
# train a pass
for data in train_reader():
iter_num += 1
if iter_num > args.total_step:
return
if args.total_step > 0 and iter_num == args.total_step + args.skip_batch_num:
stop = True
break
if iter_num < args.skip_batch_num:
print("Warm-up iteration")
if iter_num == args.skip_batch_num:
profiler.reset_profiler()
start = time.time()
results = train_one_batch(data)
batch_time = time.time() - start
fps = args.batch_size / batch_time
batch_times.append(batch_time)
total_loss += results[0]
total_seq_error += results[2]
iter_num += 1
# training log
if iter_num % args.log_period == 0:
print "\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
......@@ -127,7 +144,7 @@ def train(args, data_reader=ctc_reader):
total_seq_error = 0.0
# evaluate
if iter_num % args.eval_period == 0:
if not args.skip_test and iter_num % args.eval_period == 0:
if model_average:
with model_average.apply(exe):
test(iter_num)
......@@ -141,12 +158,35 @@ def train(args, data_reader=ctc_reader):
save_model(args, exe, iter_num)
else:
save_model(args, exe, iter_num)
# Postprocess benchmark data
latencies = batch_times[args.skip_batch_num:]
latency_avg = np.average(latencies)
latency_pc99 = np.percentile(latencies, 99)
fpses = np.divide(args.batch_size, latencies)
fps_avg = np.average(fpses)
fps_pc99 = np.percentile(fpses, 1)
# Benchmark output
print('\nTotal examples (incl. warm-up): %d' %
(iter_num * args.batch_size))
print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
latency_pc99))
print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg,
fps_pc99))
def main():
args = parser.parse_args()
print_arguments(args)
train(args, data_reader=ctc_reader)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(args, data_reader=ctc_reader)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(args, data_reader=ctc_reader)
else:
train(args, data_reader=ctc_reader)
if __name__ == "__main__":
......
import paddle.v2 as paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
import numpy as np
......@@ -7,6 +8,7 @@ import ctc_reader
import argparse
import functools
import os
import time
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -16,6 +18,10 @@ add_arg('input_images_dir', str, None, "The directory of images.")
add_arg('input_images_list', str, None, "The list file of images.")
add_arg('dict', str, None, "The dictionary. The result of inference will be index sequence if the dictionary was None.")
add_arg('use_gpu', bool, True, "Whether use GPU to infer.")
add_arg('iterations', int, 0, "The number of iterations. Zero or less means whole test set. More than 0 means the test set might be looped until # of iterations is reached.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('batch_size', int, 1, "The minibatch size.")
# yapf: enable
......@@ -25,11 +31,14 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
sequence = infer(images, num_classes)
sequence = infer(
images, num_classes, use_cudnn=True if args.use_gpu else False)
# data reader
infer_reader = data_reader.inference(
batch_size=args.batch_size,
infer_images_dir=args.input_images_dir,
infer_list_file=args.input_images_list)
infer_list_file=args.input_images_list,
cycle=True if args.iterations > 0 else False)
# prepare environment
place = fluid.CPUPlace()
if args.use_gpu:
......@@ -56,23 +65,67 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
print "Init model from: %s." % args.model_path
batch_times = []
iters = 0
for data in infer_reader():
if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
break
if iters < args.skip_batch_num:
print("Warm-up itaration")
if iters == args.skip_batch_num:
profiler.reset_profiler()
start = time.time()
result = exe.run(fluid.default_main_program(),
feed=get_feeder_data(
data, place, need_label=False),
fetch_list=[sequence],
return_numpy=False)
batch_time = time.time() - start
fps = args.batch_size / batch_time
batch_times.append(batch_time)
indexes = np.array(result[0]).flatten()
if dict_map is not None:
print "result: %s" % ([dict_map[index] for index in indexes], )
print "Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
iters,
batch_time,
fps,
[dict_map[index] for index in indexes], )
else:
print "result: %s" % (indexes, )
print "Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
iters,
batch_time,
fps,
indexes, )
iters += 1
latencies = batch_times[args.skip_batch_num:]
latency_avg = np.average(latencies)
latency_pc99 = np.percentile(latencies, 99)
fpses = np.divide(args.batch_size, latencies)
fps_avg = np.average(fpses)
fps_pc99 = np.percentile(fpses, 1)
# Benchmark output
print('\nTotal examples (incl. warm-up): %d' % (iters * args.batch_size))
print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
latency_pc99))
print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))
def main():
args = parser.parse_args()
print_arguments(args)
inference(args, data_reader=ctc_reader)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
inference(args, data_reader=ctc_reader)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
inference(args, data_reader=ctc_reader)
else:
inference(args, data_reader=ctc_reader)
if __name__ == "__main__":
......
## Introduction
Scripts enclosed in the folder serve as examples of commands that start training
and inference of a model, and are subject to further customisation.
# Running with MKL-DNN
In order to run training or inference using MKL-DNN library, please use
`FLAGS_use_mkldnn=1` environmental variable.
## Prerequisites
In order to run the training and inference, no special requirements are posed.
## Training
To run training on *CPU*, please execute:
```sh
source train.sh CPU
```
To run training on *CPU* with MKL-DNN, please execute:
```sh
source train.sh MKLDNN
```
To run training on *GPU*, please execute:
```sh
source train.sh GPU
```
## Inference
To perform inference on the trained model using *CPU*, please run:
```sh
source infer.sh CPU
```
To perform inference on the trained model using *CPU* with MKL-DNN, please run:
```sh
source infer.sh MKLDNN
```
To perform inference on the trained model using *GPU*, please run:
```sh
source infer.sh GPU
```
#!/bin/bash
export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1
mode=$1 # gpu, cpu, mkldnn
if [ "$mode" = "CPU" ]; then
use_gpu="False"
model_path="cpu_model"
elif [ "$mode" = "GPU" ]; then
use_gpu="True"
model_path="gpu_model"
elif [ "$mode" = "MKLDNN" ]; then
use_gpu="False"
model_path="mkldnn_model"
export FLAGS_use_mkldnn=1
else
echo "Invalid mode provided. Please use one of {GPU, CPU, MKLDNN}"
exit 1
fi
ht=`lscpu |grep "per core"|awk -F':' '{print $2}'|xargs`
if [ $ht -eq 1 ]; then # HT is OFF
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,0,0"
fi
if [ -z "$OMP_DYNAMIC" ]; then
export OMP_DYNAMIC="FALSE"
fi
else # HT is ON
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,1,0"
fi
fi
python ../infer.py \
--model_path $model_path/model_00001 \
--input_images_list ~/.cache/paddle/dataset/ctc_data/data/test.list \
--input_images_dir ~/.cache/paddle/dataset/ctc_data/data/test_images \
--use_gpu $use_gpu \
--batch_size 32 \
--iterations 5 \
--skip_batch_num 2
#!/bin/bash
export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1
batch_size=32
core_num=`lscpu |grep -m1 "CPU(s)"|awk -F':' '{print $2}'|xargs`
mode=$1 # gpu, cpu, mkldnn
if [ "$mode" = "CPU" ]; then
if [ $core_num -gt $batch_size ]; then
echo "Batch size should be greater or equal to the number of
available cores, when parallel mode is set to True."
fi
use_gpu="False"
save_model_dir="cpu_model"
parallel="True"
elif [ "$mode" = "GPU" ]; then
use_gpu="True"
save_model_dir="gpu_model"
parallel="True"
elif [ "$mode" = "MKLDNN" ]; then
if [ $core_num -gt $batch_size ]; then
echo "Batch size should be greater or equal to the number of
available cores, when parallel mode is set to True."
fi
use_gpu="False"
save_model_dir="mkldnn_model"
parallel="False"
export FLAGS_use_mkldnn=1
else
echo "Invalid mode provided. Please use one of {GPU, CPU, MKLDNN}"
exit 1
fi
ht=`lscpu |grep "per core"|awk -F':' '{print $2}'|xargs`
if [ $ht -eq 1 ]; then # HT is OFF
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,0,0"
fi
if [ -z "$OMP_DYNAMIC" ]; then
export OMP_DYNAMIC="FALSE"
fi
else # HT is ON
if [ -z "$KMP_AFFINITY" ]; then
export KMP_AFFINITY="granularity=fine,compact,1,0"
fi
fi
python ../ctc_train.py \
--use_gpu $use_gpu \
--parallel $parallel \
--batch_size $batch_size \
--save_model_period 1 \
--total_step 1 \
--save_model_dir $save_model_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册