提交 5a7c92df 编写于 作者: X Xinghai Sun

Fix an incorrect usage of is_local argument.

上级 dd92a02f
......@@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz"
CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data"
CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model"
# Configure cloud resources
NUM_CPU=12
NUM_CPU=8
NUM_GPU=8
NUM_NODE=1
MEMORY="10Gi"
......
......@@ -46,6 +46,7 @@ class DeepSpeech2Model(object):
gradient_clipping,
num_passes,
output_model_dir,
is_local=True,
num_iterations_print=100):
"""Train the model.
......@@ -65,6 +66,8 @@ class DeepSpeech2Model(object):
:param num_iterations_print: Number of training iterations for printing
a training loss.
:type rnn_iteratons_print: int
:param is_local: Set to False if running with pserver with multi-nodes.
:type is_local: bool
:param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring
"""
......@@ -79,7 +82,8 @@ class DeepSpeech2Model(object):
trainer = paddle.trainer.SGD(
cost=self._loss,
parameters=self._parameters,
update_equation=optimizer)
update_equation=optimizer,
is_local=is_local)
# create event handler
def event_handler(event):
......
......@@ -179,15 +179,13 @@ def train():
gradient_clipping=400,
num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir)
output_model_dir=args.output_model_dir,
is_local=args.is_local)
def main():
utils.print_arguments(args)
paddle.init(
use_gpu=args.use_gpu,
trainer_count=args.trainer_count,
is_local=args.is_local)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册