From e7a033d5bea525a530c19426b0bf6f4e6593b9f0 Mon Sep 17 00:00:00 2001 From: ranqiu Date: Tue, 31 Oct 2017 19:11:48 +0800 Subject: [PATCH] Fix bugs of DSSM --- dssm/infer.py | 5 +++-- dssm/train.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dssm/infer.py b/dssm/infer.py index dc5595ab..f0c65e44 100644 --- a/dssm/infer.py +++ b/dssm/infer.py @@ -1,5 +1,6 @@ import argparse import itertools +import distutils.util import reader import paddle.v2 as paddle @@ -56,12 +57,12 @@ parser.add_argument( (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE)) parser.add_argument( '--share_network_between_source_target', - type=bool, + type=distutils.util.strtobool, default=False, help="whether to share network parameters between source and target") parser.add_argument( '--share_embed', - type=bool, + type=distutils.util.strtobool, default=False, help="whether to share word embedding between source and target") parser.add_argument( diff --git a/dssm/train.py b/dssm/train.py index bc7685ab..a7694877 100644 --- a/dssm/train.py +++ b/dssm/train.py @@ -1,4 +1,5 @@ import argparse +import distutils.util import paddle.v2 as paddle from network_conf import DSSM @@ -35,8 +36,8 @@ parser.add_argument( '-b', '--batch_size', type=int, - default=10, - help="size of mini-batch (default:10)") + default=32, + help="size of mini-batch (default:32)") parser.add_argument( '-p', '--num_passes', @@ -62,12 +63,12 @@ parser.add_argument( (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE)) parser.add_argument( '--share_network_between_source_target', - type=bool, + type=distutils.util.strtobool, default=False, help="whether to share network parameters between source and target") parser.add_argument( '--share_embed', - type=bool, + type=distutils.util.strtobool, default=False, help="whether to share word embedding between source and target") parser.add_argument( @@ -80,7 +81,7 @@ parser.add_argument( '--num_workers', type=int, default=1, help="num worker threads, default 1") parser.add_argument( '--use_gpu', - type=bool, + type=distutils.util.strtobool, default=False, help="whether to use GPU devices (default: False)") parser.add_argument( -- GitLab