未验证 提交 8fd56a45 编写于 作者: L littletomatodonkey 提交者: GitHub

fix static train (#478)

上级 29b305d2
......@@ -86,16 +86,21 @@ def create_model(architecture, image, classes_num, config, is_train):
use_pure_fp16 = config.get("use_pure_fp16", False)
name = architecture["name"]
params = architecture.get("params", {})
data_format = config.get("data_format", "NCHW")
data_format = "NCHW"
if "data_format" in config:
params["data_format"] = config["data_format"]
data_format = config["data_format"]
input_image_channel = config.get('image_shape', [3, 224, 224])[0]
if input_image_channel != 3:
logger.warning(
"Input image channel is changed to {}, maybe for better speed-up".
format(input_image_channel))
params["input_image_channel"] = input_image_channel
if "is_test" in params:
params['is_test'] = not is_train
model = architectures.__dict__[name](
class_dim=classes_num,
input_image_channel=input_image_channel,
data_format=data_format,
**params)
model = architectures.__dict__[name](class_dim=classes_num, **params)
if use_pure_fp16 and not config.get("use_dali", False):
image = image.astype('float16')
if data_format == "NHWC":
......@@ -352,7 +357,10 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
and config.get("use_dali", False):
image_dtype = "float16"
feeds = create_feeds(
config.image_shape, use_mix=use_mix, use_dali=use_dali, dtype = image_dtype)
config.image_shape,
use_mix=use_mix,
use_dali=use_dali,
dtype=image_dtype)
if use_dali and use_mix:
import dali
feeds = dali.mix(feeds, config, is_train)
......@@ -395,9 +403,11 @@ def compile(config, program, loss_name=None, share_prog=None):
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get('use_pure_fp16', False) else 10
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
'use_pure_fp16', False) else 10
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
False)
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
......
......@@ -65,10 +65,7 @@ def main(args):
if config.get("is_distributed", True):
fleet.init(is_collective=True)
# assign the place
use_gpu = config.get("use_gpu", False)
assert use_gpu is True, "gpu must be true in static mode!"
place = paddle.set_device("gpu")
use_gpu = config.get("use_gpu", True)
# amp related config
use_amp = config.get('use_amp', False)
use_pure_fp16 = config.get('use_pure_fp16', False)
......@@ -122,7 +119,7 @@ def main(args):
exe = paddle.static.Executor(place)
# Parameter initialization
exe.run(startup_prog)
if config.get("use_pure_fp16", False):
if config.get("use_pure_fp16", False):
cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
# load pretrained models or checkpoints
init_model(config, train_prog, exe)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册