提交 70e77f98 编写于 作者: X Xin Pan

fix lm

上级 3eacba37
......@@ -21,6 +21,11 @@ def parse_args():
action='store_true',
help='If set, run \
the task with continuous evaluation logs.')
parser.add_argument(
'--num_devices',
type=int,
default=1,
help='Number of GPU devices')
args = parser.parse_args()
return args
......@@ -165,13 +170,13 @@ def train(train_reader,
print("finish training")
def get_cards(enable_ce):
if enable_ce:
def get_cards(args):
if args.enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
else:
return fluid.core.get_cuda_device_count()
return args.num_devices
def train_net():
......@@ -179,7 +184,7 @@ def train_net():
batch_size = 20
args = parse_args()
vocab, train_reader, test_reader = utils.prepare_data(
batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \
batch_size=batch_size * get_cards(args), buffer_size=1000, \
word_freq_threshold=0, enable_ce = args.enable_ce)
train(
train_reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册