Commit 7d38d810 authored by ShawnXuan

bert: add XLA option

Parent 6673c2dc
......@@ -48,14 +48,16 @@ def get_parser(parser=None):
help='node/machine number for training')
parser.add_argument('--node_ips', type=str_list, default=['192.168.1.13', '192.168.1.14'],
                        help='node IP list for training, divided by ",", length >= num_nodes')
# train
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--weight_decay_rate", type=float, default=0.01, help="weight decay rate")
parser.add_argument("--warmup_proportion", type=float, default=0.1)
    parser.add_argument('--use_fp16', type=str2bool, nargs='?', default='False', const=True,
                        help='use fp16 or not')
parser.add_argument('--use_xla', type=str2bool, nargs='?', const=True,
                        help='whether to use xla')
    # log and restore/save
parser.add_argument("--loss_print_every_n_iter", type=int, default=10, required=False,
help="print loss every n iteration")
......@@ -68,7 +70,7 @@ def get_parser(parser=None):
help="save model snapshot for last iteration")
parser.add_argument("--model_load_dir", type=str, default=None, help="model load directory")
parser.add_argument("--log_dir", type=str, default="./output", help="log info save directory")
# bert backbone
parser.add_argument('--do_lower_case', type=str2bool, nargs='?', const=True, default='True')
parser.add_argument("--seq_length", type=int, default=512)
......@@ -81,7 +83,7 @@ def get_parser(parser=None):
parser.add_argument("--attention_probs_dropout_prob", type=float, default=0.1)
parser.add_argument("--hidden_dropout_prob", type=float, default=0.1)
parser.add_argument("--hidden_size_per_head", type=int, default=64)
return parser
......
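Note: the parser above relies on `str2bool` and `str_list` type converters that are defined elsewhere in this file and are not part of this diff. A minimal sketch of what they plausibly look like (the names come from the diff; the bodies are assumptions):

```python
# Hypothetical sketch of the argparse type converters used above;
# the real definitions live elsewhere in this file and may differ.
import argparse

def str2bool(v):
    # Accept common truthy/falsy spellings so '--use_xla', '--use_xla true',
    # and '--use_xla 0' all parse cleanly.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def str_list(x):
    # Split a comma-separated string such as '192.168.1.13,192.168.1.14'
    # into a list of stripped strings.
    return [i.strip() for i in x.split(',')]
```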
......@@ -131,7 +131,7 @@ class Metric(object):
self.metric_dict[key] = 0.0
self.metric_dict['throughput'] = 0.0
self.num_samples = 0.0
def update_and_save(self, key, value, step, **kwargs):
self.metric_dict[key] = value
if self.save_summary:
......@@ -164,14 +164,16 @@ class Metric(object):
def CreateOptimizer(args):
warmup_batches = int(args.iter_num * args.warmup_proportion)
lr_warmup = flow.optimizer.warmup.linear(warmup_batches, 0)
    lr_scheduler = flow.optimizer.PolynomialSchduler(args.learning_rate, args.iter_num, 0.0,
warmup=lr_warmup)
    return flow.optimizer.AdamW(lr_scheduler, epsilon=1e-6, weight_decay=args.weight_decay_rate,
weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
grad_clipping=flow.optimizer.grad_clipping.by_global_norm(1.0))
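For reference, a rough illustration of the learning-rate shape the linear warmup plus polynomial decay above produces. Everything here is an assumption (`iter_num` has no default shown in this diff, and the exact decay semantics of `PolynomialSchduler` may differ):

```python
# Illustrative only: assumed learning-rate curve for linear warmup followed
# by polynomial decay, with learning_rate=1e-4, iter_num=100000 and
# warmup_proportion=0.1 (so warmup_batches=10000).
def lr_at(step, base_lr=1e-4, iter_num=100000, warmup=10000, end_lr=0.0, power=1.0):
    if step < warmup:
        return base_lr * step / warmup          # linear warmup from 0
    decay = (1 - (step - warmup) / (iter_num - warmup)) ** power
    return (base_lr - end_lr) * decay + end_lr  # polynomial decay to end_lr

print(lr_at(5000))   # mid-warmup:  5e-05
print(lr_at(10000))  # warmup done: 1e-04
print(lr_at(55000))  # halfway through decay: 5e-05
```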
def GetFunctionConfig(args):
config = flow.function_config()
config.enable_auto_mixed_precision(args.use_fp16)
if args.use_xla:
config.use_xla_jit(True)
return config
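Assumed call site (outside this diff): in the OneFlow 0.x lazy API, the config returned by `GetFunctionConfig` is passed to a `@flow.global_function` job, so `--use_fp16` enables auto mixed precision and `--use_xla` enables XLA JIT compilation for the whole graph. `BuildPreTrainNet` below is a hypothetical placeholder for the model-building code:

```python
# Assumed wiring, not part of this diff: a OneFlow 0.x lazy-mode training job
# that consumes GetFunctionConfig and CreateOptimizer from above.
import oneflow as flow
import oneflow.typing as tp

@flow.global_function(type="train", function_config=GetFunctionConfig(args))
def PretrainJob() -> tp.Numpy:
    loss = BuildPreTrainNet(args)       # hypothetical model-building helper
    CreateOptimizer(args).minimize(loss)
    return loss
```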