提交 b48b902a 编写于 作者: G guosheng

Refine train.py in Transformer

上级 05403680
...@@ -32,7 +32,6 @@ class TrainTaskConfig(object): ...@@ -32,7 +32,6 @@ class TrainTaskConfig(object):
start_step = 0 start_step = 0
# the frequency to save trained models. # the frequency to save trained models.
save_freq = 10000 save_freq = 10000
profile=True
class InferTaskConfig(object): class InferTaskConfig(object):
......
import argparse import argparse
import ast import ast
import copy
import logging
import multiprocessing import multiprocessing
import os import os
import six import six
import sys
import time import time
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.profiler as profiler from paddle.fluid.transpiler.details import program_to_code
import reader import reader
from config import * from config import *
from model import transformer, position_encoding_init from model import transformer, position_encoding_init
from paddle.fluid.transpiler.details import program_to_code
import logging
import sys
import copy
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.") parser = argparse.ArgumentParser("Training for Transformer.")
...@@ -120,7 +117,7 @@ def parse_args(): ...@@ -120,7 +117,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--use_mem_opt", "--use_mem_opt",
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=True,
help="The flag indicating whether to use memory optimization.") help="The flag indicating whether to use memory optimization.")
parser.add_argument( parser.add_argument(
"--use_py_reader", "--use_py_reader",
...@@ -128,10 +125,10 @@ def parse_args(): ...@@ -128,10 +125,10 @@ def parse_args():
default=True, default=True,
help="The flag indicating whether to use py_reader.") help="The flag indicating whether to use py_reader.")
parser.add_argument( parser.add_argument(
"--fetch_steps", type=int, default=100, help="Fetch outputs steps.") "--fetch_steps",
type=int,
#parser.add_argument( default=100,
# '--profile', action='store_true', help='If set, profile a few steps.') help="The frequency to fetch and print output.")
args = parser.parse_args() args = parser.parse_args()
# Append args related to dict # Append args related to dict
...@@ -476,12 +473,7 @@ def train_loop(exe, ...@@ -476,12 +473,7 @@ def train_loop(exe,
# Since the token number differs among devices, customize gradient scale to # Since the token number differs among devices, customize gradient scale to
# use token average cost among multi-devices. and the gradient scale is # use token average cost among multi-devices. and the gradient scale is
# `1 / token_number` for average cost. # `1 / token_number` for average cost.
# build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
#build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
exec_strategy = fluid.ExecutionStrategy()
#if args.update_method == "nccl2":
exec_strategy.num_threads = 1
logging.info("begin executor") logging.info("begin executor")
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
...@@ -517,32 +509,16 @@ def train_loop(exe, ...@@ -517,32 +509,16 @@ def train_loop(exe,
data_generator = train_data() data_generator = train_data()
batch_id = 0 batch_id = 0
avg_batch_time = time.time()
while True: while True:
try: try:
feed_dict_list = prepare_feed_dict_list(data_generator, feed_dict_list = prepare_feed_dict_list(data_generator,
init_flag, dev_count) init_flag, dev_count)
if TrainTaskConfig.profile and batch_id == 5:
logging.info("begin profiler")
profiler.start_profiler("All")
profiler.reset_profiler()
elif TrainTaskConfig.profile and batch_id == 10:
logging.info("end profiler")
#logging.info("profiling total time: ", time.time() - start_time)
profiler.stop_profiler(
"total", "./transformer_local_profile_{}_pass{}".format(
batch_id, pass_id))
sys.exit(0)
logging.info("batch_id:{}".format(batch_id))
outs = train_exe.run( outs = train_exe.run(
fetch_list=[sum_cost.name, token_num.name] fetch_list=[sum_cost.name, token_num.name]
if (batch_id % args.fetch_steps == 0 or if step_idx % args.fetch_steps == 0 else [],
TrainTaskConfig.profile) else [],
feed=feed_dict_list) feed=feed_dict_list)
if (batch_id % args.fetch_steps == 0 and batch_id > 0): if step_idx % args.fetch_steps == 0:
sum_cost_val, token_num_val = np.array(outs[0]), np.array( sum_cost_val, token_num_val = np.array(outs[0]), np.array(
outs[1]) outs[1])
# sum the cost from multi-devices # sum the cost from multi-devices
...@@ -550,16 +526,25 @@ def train_loop(exe, ...@@ -550,16 +526,25 @@ def train_loop(exe,
total_token_num = token_num_val.sum() total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num total_avg_cost = total_sum_cost / total_token_num
logging.info( if step_idx == 0:
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " logging.info(
"normalized loss: %f, ppl: %f, speed: %.2f step/s" % "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
(step_idx, pass_id, batch_id, total_avg_cost, "normalized loss: %f, ppl: %f" %
total_avg_cost - loss_normalizer, (step_idx, pass_id, batch_id, total_avg_cost,
np.exp([min(total_avg_cost, 100)]), total_avg_cost - loss_normalizer,
args.fetch_steps / (time.time() - avg_batch_time))) np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
if step_idx % int(TrainTaskConfig. else:
save_freq) == TrainTaskConfig.save_freq - 1: logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.fetch_steps / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
fluid.io.save_persistables( fluid.io.save_persistables(
exe, exe,
os.path.join(TrainTaskConfig.ckpt_dir, os.path.join(TrainTaskConfig.ckpt_dir,
...@@ -569,8 +554,7 @@ def train_loop(exe, ...@@ -569,8 +554,7 @@ def train_loop(exe,
os.path.join(TrainTaskConfig.model_dir, os.path.join(TrainTaskConfig.model_dir,
"iter_" + str(step_idx) + ".infer.model"), "iter_" + str(step_idx) + ".infer.model"),
train_prog) train_prog)
if batch_id % args.fetch_steps == 0 and batch_id > 0:
avg_batch_time = time.time()
init_flag = False init_flag = False
batch_id += 1 batch_id += 1
step_idx += 1 step_idx += 1
...@@ -591,10 +575,12 @@ def train_loop(exe, ...@@ -591,10 +575,12 @@ def train_loop(exe,
time_consumed)) time_consumed))
else: else:
logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed)) logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
fluid.io.save_persistables( if not args.enable_ce:
exe, fluid.io.save_persistables(
os.path.join(TrainTaskConfig.ckpt_dir, exe,
"pass_" + str(pass_id) + ".checkpoint"), train_prog) os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"),
train_prog)
if args.enable_ce: # For CE if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
...@@ -697,8 +683,7 @@ def train(args): ...@@ -697,8 +683,7 @@ def train(args):
append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint) append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint)
train_loop(exe, train_loop(exe,
fluid.default_main_program(), dev_count, sum_cost, fluid.default_main_program(), dev_count, sum_cost,
avg_cost, lr_scheduler, token_num, predict, trainers_num, avg_cost, token_num, predict, trainers_num, trainer_id)
trainer_id)
return return
port = os.getenv("PADDLE_PORT", "6174") port = os.getenv("PADDLE_PORT", "6174")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册