From 2a5b5313dca2465e87a7f71e4fef475c48a18ab6 Mon Sep 17 00:00:00 2001 From: zhangwenhui03 Date: Thu, 14 May 2020 14:10:53 +0800 Subject: [PATCH] add gpu config --- core/trainers/transpiler_trainer.py | 4 ++++ models/multitask/esmm/config.yaml | 1 + models/multitask/mmoe/config.yaml | 1 + models/multitask/share-bottom/config.yaml | 1 + models/recall/gru4rec/config.yaml | 1 + models/recall/ssr/config.yaml | 1 + 6 files changed, 9 insertions(+) diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 3dc8bfd6..fa0d4540 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -28,6 +28,10 @@ from paddlerec.core.utils import dataloader_instance class TranspileTrainer(Trainer): def __init__(self, config=None): Trainer.__init__(self, config) + device = envs.get_global_env("train.device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + self._exe = fluid.Executor(self._place) self.processor_register() self.model = None self.inference_models = [] diff --git a/models/multitask/esmm/config.yaml b/models/multitask/esmm/config.yaml index a83c145c..18b47f89 100644 --- a/models/multitask/esmm/config.yaml +++ b/models/multitask/esmm/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.esmm" + device: cpu reader: batch_size: 2 diff --git a/models/multitask/mmoe/config.yaml b/models/multitask/mmoe/config.yaml index 4f79fbd8..e537b81e 100644 --- a/models/multitask/mmoe/config.yaml +++ b/models/multitask/mmoe/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.mmoe" + device: cpu reader: batch_size: 2 diff --git a/models/multitask/share-bottom/config.yaml b/models/multitask/share-bottom/config.yaml index f0ace882..64d61ed4 100644 --- a/models/multitask/share-bottom/config.yaml +++ b/models/multitask/share-bottom/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.share-bottom" + device: cpu reader: batch_size: 2 diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 71c212ad..2668fb9a 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -20,6 +20,7 @@ train: epochs: 3 workspace: "paddlerec.models.recall.gru4rec" + device: cpu reader: batch_size: 5 class: "{workspace}/rsc15_reader.py" diff --git a/models/recall/ssr/config.yaml b/models/recall/ssr/config.yaml index b6bcbffc..0682c652 100644 --- a/models/recall/ssr/config.yaml +++ b/models/recall/ssr/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.recall.ssr" + device: cpu reader: batch_size: 5 -- GitLab