未验证 提交 f85a6c20 编写于 作者: L LielinJiang 提交者: GitHub

fix distributed bug (#160)

上级 438deb1d
......@@ -64,7 +64,10 @@ dataset:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
- name: SplitPairedImage
key: pair
paired_keys: [A, B]
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
......
......@@ -144,9 +144,9 @@ class Trainer:
self.best_metric = {}
def distributed_data_parallel(self):
strategy = paddle.distributed.prepare_context()
paddle.distributed.init_parallel_env()
for net_name, net in self.model.nets.items():
self.model.nets[net_name] = paddle.DataParallel(net, strategy)
self.model.nets[net_name] = paddle.DataParallel(net)
def learning_rate_scheduler_step(self):
if isinstance(self.model.lr_scheduler, dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册