提交 75b8b0e8 编写于 作者: Z zhangwenhui03

Merge branch 'develop' into 'develop'

add gpu config

See merge request !25
......@@ -28,6 +28,10 @@ from paddlerec.core.utils import dataloader_instance
class TranspileTrainer(Trainer):
def __init__(self, config=None):
Trainer.__init__(self, config)
device = envs.get_global_env("train.device")
if device == 'gpu':
self._place = fluid.CUDAPlace(0)
self._exe = fluid.Executor(self._place)
self.processor_register()
self.model = None
self.inference_models = []
......
......@@ -19,6 +19,7 @@ train:
epochs: 3
workspace: "paddlerec.models.multitask.esmm"
device: cpu
reader:
batch_size: 2
......
......@@ -19,6 +19,7 @@ train:
epochs: 3
workspace: "paddlerec.models.multitask.mmoe"
device: cpu
reader:
batch_size: 2
......
......@@ -19,6 +19,7 @@ train:
epochs: 3
workspace: "paddlerec.models.multitask.share-bottom"
device: cpu
reader:
batch_size: 2
......
......@@ -20,6 +20,7 @@ train:
epochs: 3
workspace: "paddlerec.models.recall.gru4rec"
device: cpu
reader:
batch_size: 5
class: "{workspace}/rsc15_reader.py"
......
......@@ -19,6 +19,7 @@ train:
epochs: 3
workspace: "paddlerec.models.recall.ssr"
device: cpu
reader:
batch_size: 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册