Unverified commit f85a6c20, authored by LielinJiang, committed by GitHub

fix distributed bug (#160)

Parent 438deb1d
@@ -64,7 +64,10 @@ dataset:
     preprocess:
       - name: LoadImageFromFile
         key: pair
+      - name: SplitPairedImage
+        key: pair
+        paired_keys: [A, B]
       - name: Transforms
         input_keys: [A, B]
         pipeline:
           - name: Resize
...
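The added SplitPairedImage step takes the image loaded under the pair key and splits it into the A and B entries that the following Transforms step consumes via input_keys. Below is a minimal sketch of what such a preprocess step could look like; the class body is an assumption inferred only from the config keys shown in this diff (key, paired_keys), not PaddleGAN's actual implementation.

# Hypothetical sketch of a SplitPairedImage-style preprocess step.
# The real PaddleGAN transform may differ; behaviour here is assumed
# from the config keys in this diff (key: pair, paired_keys: [A, B]).
import numpy as np


class SplitPairedImage:
    """Split a horizontally concatenated paired image into two keyed halves."""

    def __init__(self, key='pair', paired_keys=('A', 'B')):
        self.key = key
        self.paired_keys = paired_keys

    def __call__(self, datas):
        # datas[self.key] is assumed to be an H x (2W) x C array produced by
        # LoadImageFromFile; the left half becomes A, the right half becomes B.
        pair = datas[self.key]
        width = pair.shape[1] // 2
        datas[self.paired_keys[0]] = pair[:, :width, :]
        datas[self.paired_keys[1]] = pair[:, width:, :]
        return datas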
@@ -144,9 +144,9 @@ class Trainer:
         self.best_metric = {}
 
     def distributed_data_parallel(self):
-        strategy = paddle.distributed.prepare_context()
+        paddle.distributed.init_parallel_env()
         for net_name, net in self.model.nets.items():
-            self.model.nets[net_name] = paddle.DataParallel(net, strategy)
+            self.model.nets[net_name] = paddle.DataParallel(net)
 
     def learning_rate_scheduler_step(self):
         if isinstance(self.model.lr_scheduler, dict):
...
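This hunk drops the strategy-based setup (paddle.distributed.prepare_context) in favor of Paddle 2.x's dynamic-graph data-parallel API: paddle.distributed.init_parallel_env() is called once per process, and each network is wrapped in paddle.DataParallel with no strategy argument. The following is a minimal standalone sketch of that pattern, with a placeholder model and training step rather than PaddleGAN's actual Trainer.

# Minimal sketch of the data-parallel setup this commit switches to
# (Paddle 2.x dynamic graph). Model and training step are placeholders.
import paddle
import paddle.nn as nn
import paddle.distributed as dist


def train():
    # Replaces the old prepare_context() call: initializes the parallel
    # environment (process group, device) for the current process.
    dist.init_parallel_env()

    net = nn.Linear(10, 1)
    # DataParallel no longer takes a strategy argument in Paddle 2.x.
    net = paddle.DataParallel(net)

    opt = paddle.optimizer.Adam(parameters=net.parameters())
    data = paddle.randn([4, 10])
    label = paddle.randn([4, 1])

    loss = nn.functional.mse_loss(net(data), label)
    loss.backward()
    opt.step()
    opt.clear_grad()


if __name__ == '__main__':
    # One process per GPU; `python -m paddle.distributed.launch` works as well.
    dist.spawn(train, nprocs=2)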