提交 2a5b5313 编写于 作者: Z zhangwenhui03

add gpu config

上级 07dbb36d
...@@ -28,6 +28,10 @@ from paddlerec.core.utils import dataloader_instance ...@@ -28,6 +28,10 @@ from paddlerec.core.utils import dataloader_instance
class TranspileTrainer(Trainer): class TranspileTrainer(Trainer):
def __init__(self, config=None): def __init__(self, config=None):
Trainer.__init__(self, config) Trainer.__init__(self, config)
device = envs.get_global_env("train.device")
if device == 'gpu':
self._place = fluid.CUDAPlace(0)
self._exe = fluid.Executor(self._place)
self.processor_register() self.processor_register()
self.model = None self.model = None
self.inference_models = [] self.inference_models = []
......
...@@ -19,6 +19,7 @@ train: ...@@ -19,6 +19,7 @@ train:
epochs: 3 epochs: 3
workspace: "paddlerec.models.multitask.esmm" workspace: "paddlerec.models.multitask.esmm"
device: cpu
reader: reader:
batch_size: 2 batch_size: 2
......
...@@ -19,6 +19,7 @@ train: ...@@ -19,6 +19,7 @@ train:
epochs: 3 epochs: 3
workspace: "paddlerec.models.multitask.mmoe" workspace: "paddlerec.models.multitask.mmoe"
device: cpu
reader: reader:
batch_size: 2 batch_size: 2
......
...@@ -19,6 +19,7 @@ train: ...@@ -19,6 +19,7 @@ train:
epochs: 3 epochs: 3
workspace: "paddlerec.models.multitask.share-bottom" workspace: "paddlerec.models.multitask.share-bottom"
device: cpu
reader: reader:
batch_size: 2 batch_size: 2
......
...@@ -20,6 +20,7 @@ train: ...@@ -20,6 +20,7 @@ train:
epochs: 3 epochs: 3
workspace: "paddlerec.models.recall.gru4rec" workspace: "paddlerec.models.recall.gru4rec"
device: cpu
reader: reader:
batch_size: 5 batch_size: 5
class: "{workspace}/rsc15_reader.py" class: "{workspace}/rsc15_reader.py"
......
...@@ -19,6 +19,7 @@ train: ...@@ -19,6 +19,7 @@ train:
epochs: 3 epochs: 3
workspace: "paddlerec.models.recall.ssr" workspace: "paddlerec.models.recall.ssr"
device: cpu
reader: reader:
batch_size: 5 batch_size: 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册