diff --git a/deep_speech_2/cloud/pcloud_submit.sh b/deep_speech_2/cloud/pcloud_submit.sh
index 2fb80d667722c31fefb8c0b552d967e3af5dc64f..3a64f32e245d90eb88af1c36cc2726551e03a9f7 100644
--- a/deep_speech_2/cloud/pcloud_submit.sh
+++ b/deep_speech_2/cloud/pcloud_submit.sh
@@ -7,7 +7,7 @@ MEAN_STD_FILE="../mean_std.npz"
 CLOUD_DATA_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/data"
 CLOUD_MODEL_DIR="/pfs/dlnel/home/sunxinghai@baidu.com/deepspeech2/model"
 # Configure cloud resources
-NUM_CPU=12
+NUM_CPU=8
 NUM_GPU=8
 NUM_NODE=1
 MEMORY="10Gi"
diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py
index e2f2903b6ecff653c1dd032308c1cdd7eb4a175d..99412e595af43fa3af16cb7d09274bf19c473eca 100644
--- a/deep_speech_2/model.py
+++ b/deep_speech_2/model.py
@@ -46,6 +46,7 @@ class DeepSpeech2Model(object):
               gradient_clipping,
               num_passes,
               output_model_dir,
+              is_local=True,
               num_iterations_print=100):
         """Train the model.
 
@@ -65,6 +66,8 @@ class DeepSpeech2Model(object):
         :param num_iterations_print: Number of training iterations for printing
                                      a training loss.
         :type rnn_iteratons_print: int
+        :param is_local: Set to False when running with a pserver on multiple nodes.
+        :type is_local: bool
         :param output_model_dir: Directory for saving the model (every pass).
         :type output_model_dir: basestring
         """
@@ -79,7 +82,8 @@ class DeepSpeech2Model(object):
         trainer = paddle.trainer.SGD(
             cost=self._loss,
             parameters=self._parameters,
-            update_equation=optimizer)
+            update_equation=optimizer,
+            is_local=is_local)
 
         # create event handler
         def event_handler(event):
diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py
index 379e364c9266dd3fd6aab72a1cf49b1b491c23e9..262d8bf0125bec3b225337c3f5de299be3f44ba1 100644
--- a/deep_speech_2/train.py
+++ b/deep_speech_2/train.py
@@ -179,15 +179,13 @@ def train():
         gradient_clipping=400,
         num_passes=args.num_passes,
         num_iterations_print=args.num_iterations_print,
-        output_model_dir=args.output_model_dir)
+        output_model_dir=args.output_model_dir,
+        is_local=args.is_local)
 
 
 def main():
     utils.print_arguments(args)
-    paddle.init(
-        use_gpu=args.use_gpu,
-        trainer_count=args.trainer_count,
-        is_local=args.is_local)
+    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
     train()
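
For context, below is a minimal usage sketch (not part of the patch) of how a caller drives training after this change: is_local is dropped from paddle.init() and instead passed to DeepSpeech2Model.train(), which forwards it to paddle.trainer.SGD. The DeepSpeech2Model constructor arguments and the keyword names not visible in the diff (train_batch_reader, dev_batch_reader, feeding_dict, learning_rate, adam_learning_rate, init_model_path) are assumptions based on the surrounding code in train.py, not confirmed by this diff.

# Minimal sketch of the new call pattern; names not shown in the diff are assumed.
import paddle.v2 as paddle

from model import DeepSpeech2Model


def train_with_args(args, train_reader, dev_reader, feeding):
    # After this change, paddle.init() only configures devices; is_local is
    # no longer passed here.
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    # Constructor arguments below are assumed from train.py and may differ.
    ds2_model = DeepSpeech2Model(
        vocab_size=args.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        pretrained_model_path=args.init_model_path)

    # is_local now flows through DeepSpeech2Model.train(), which hands it to
    # paddle.trainer.SGD(..., is_local=is_local). Pass False when training
    # against a parameter server across multiple nodes.
    ds2_model.train(
        train_batch_reader=train_reader,
        dev_batch_reader=dev_reader,
        feeding_dict=feeding,
        learning_rate=args.adam_learning_rate,
        gradient_clipping=400,
        num_passes=args.num_passes,
        num_iterations_print=args.num_iterations_print,
        output_model_dir=args.output_model_dir,
        is_local=args.is_local)

Since is_local defaults to True in DeepSpeech2Model.train(), existing single-machine callers keep working unchanged; only multi-node pserver jobs need to pass is_local=False explicitly.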