未验证 提交 3cdb8562 编写于 作者: H hutuxian 提交者: GitHub

DIN: fix bug in using place (#4386)

上级 e041a960
...@@ -80,7 +80,7 @@ def infer(): ...@@ -80,7 +80,7 @@ def infer():
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
feed_list=[inference_program.block(0).var(e) for e in feed_target_names], capacity=10000, iterable=True) feed_list=[inference_program.block(0).var(e) for e in feed_target_names], capacity=10000, iterable=True)
loader.set_sample_list_generator(data_reader, places=fluid.cuda_places()) loader.set_sample_list_generator(data_reader, places=place)
loss_sum = 0.0 loss_sum = 0.0
score = [] score = []
......
...@@ -96,7 +96,7 @@ def train(): ...@@ -96,7 +96,7 @@ def train():
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list, capacity=10000, iterable=True) feed_list=feed_list, capacity=10000, iterable=True)
loader.set_sample_list_generator(data_reader, places=fluid.cuda_places()) loader.set_sample_list_generator(data_reader, places=place)
if use_parallel: if use_parallel:
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=avg_cost.name) use_cuda=use_cuda, loss_name=avg_cost.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册