# job_function_util.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import oneflow as flow
from optimizer_util import gen_model_update_conf


def _default_config(args):
    """Create the OneFlow function config shared by the training and validation jobs.

    The config uses a consistent-view distribute strategy and ``flow.float`` as
    the default data type. When ``args.use_fp16`` is truthy, automatic mixed
    precision is also enabled.

    Args:
        args: option namespace (presumably from argparse — confirm against the
            caller); only ``args.use_fp16`` is read here.

    Returns:
        The object returned by ``flow.function_config()``, with the settings
        above applied.
    """
    func_cfg = flow.function_config()
    func_cfg.default_distribute_strategy(flow.scope.consistent_view())
    func_cfg.default_data_type(flow.float)

    # Mixed precision is opt-in; without the flag the config is returned as-is.
    if not args.use_fp16:
        return func_cfg

    func_cfg.enable_auto_mixed_precision(True)
    return func_cfg

def get_train_config(args):
    """Build the OneFlow function config for the training job.

    Starts from ``_default_config(args)`` and then layers on the training-only
    settings. These are: the primary learning rate, all-reduce grouping
    options, parallel-cast op pruning, the model-update conf produced by
    ``gen_model_update_conf(args)``, and in-place ops.

    Args:
        args: option namespace; ``args.learning_rate`` is read directly, and
            ``args`` is also passed on to ``_default_config`` and
            ``gen_model_update_conf``.

    Returns:
        The configured function config object.
    """
    cfg = _default_config(args)
    cfg.train.primary_lr(args.learning_rate)

    # All-reduce settings: the "disable" flag is explicitly set to False, and
    # the grouping values are hard-coded constants.
    # NOTE(review): consider exposing these via args if they need tuning.
    cfg.disable_all_reduce_sequence(False)
    cfg.all_reduce_group_min_mbyte(8)
    cfg.all_reduce_group_num(128)

    cfg.prune_parallel_cast_ops(True)

    # Model-update conf comes from optimizer_util; presumably this carries the
    # optimizer / LR-schedule settings — confirm in gen_model_update_conf.
    cfg.train.model_update_conf(gen_model_update_conf(args))

    cfg.enable_inplace(True)
    return cfg

def get_val_config(args):
    """Build the OneFlow function config for the validation job.

    Validation applies no settings beyond the shared defaults, so this is just
    ``_default_config(args)``.

    Args:
        args: option namespace forwarded unchanged to ``_default_config``.

    Returns:
        The default function config object.
    """
    val_config = _default_config(args)
    return val_config