提交 48afe86f 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix dygpaph training speed

上级 6cae5aaf
......@@ -21,6 +21,7 @@ import time
from collections import OrderedDict
import paddle
import paddle.fluid as fluid
from ppcls.optimizer import LearningRateBuilder
......@@ -280,7 +281,7 @@ def mixed_precision_optimizer(config, optimizer):
def create_feeds(batch, use_mix):
image = to_variable(batch[0].numpy().astype("float32"))
image = batch[0]
if use_mix:
y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1))
y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1))
......
......@@ -57,13 +57,14 @@ def main(args):
with fluid.dygraph.guard(place):
net = program.create_model(config.ARCHITECTURE, config.classes_num)
if config["use_data_parallel"]:
strategy = fluid.dygraph.parallel.prepare_context()
net = fluid.dygraph.parallel.DataParallel(net, strategy)
optimizer = program.create_optimizer(
config, parameter_list=net.parameters())
if config["use_data_parallel"]:
strategy = fluid.dygraph.parallel.prepare_context()
net = fluid.dygraph.parallel.DataParallel(net, strategy)
# load model from checkpoint or pretrained model
init_model(config, net, optimizer)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册