diff --git a/core/trainers/transpiler_trainer.py b/core/trainers/transpiler_trainer.py index 3dc8bfd6e04fb3da2b217f84126fac8101102945..fa0d454067fecdfb94d24f30594fe61548150cf1 100755 --- a/core/trainers/transpiler_trainer.py +++ b/core/trainers/transpiler_trainer.py @@ -28,6 +28,10 @@ from paddlerec.core.utils import dataloader_instance class TranspileTrainer(Trainer): def __init__(self, config=None): Trainer.__init__(self, config) + device = envs.get_global_env("train.device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + self._exe = fluid.Executor(self._place) self.processor_register() self.model = None self.inference_models = [] diff --git a/models/multitask/esmm/config.yaml b/models/multitask/esmm/config.yaml index a83c145c9779661a6bb35fb177ec5a53aa4ad56e..18b47f893089badf28841814d5ef367121b1a46e 100644 --- a/models/multitask/esmm/config.yaml +++ b/models/multitask/esmm/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.esmm" + device: cpu reader: batch_size: 2 diff --git a/models/multitask/mmoe/config.yaml b/models/multitask/mmoe/config.yaml index 4f79fbd8981346e59c1904d80fcc9aa4c67bc3ac..e537b81e01174af36a449e36b5bac2412f06d9d5 100644 --- a/models/multitask/mmoe/config.yaml +++ b/models/multitask/mmoe/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.mmoe" + device: cpu reader: batch_size: 2 diff --git a/models/multitask/share-bottom/config.yaml b/models/multitask/share-bottom/config.yaml index f0ace882c3531ce05bee8af954220b84070f918b..64d61ed4e3b003f47ad78330362ffac0707ff50d 100644 --- a/models/multitask/share-bottom/config.yaml +++ b/models/multitask/share-bottom/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.multitask.share-bottom" + device: cpu reader: batch_size: 2 diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 71c212adbb6c200d74d121a517af17d287e2c17d..2668fb9a55efa3ec411c92c770acc1b8158e7e88 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -20,6 +20,7 @@ train: epochs: 3 workspace: "paddlerec.models.recall.gru4rec" + device: cpu reader: batch_size: 5 class: "{workspace}/rsc15_reader.py" diff --git a/models/recall/ssr/config.yaml b/models/recall/ssr/config.yaml index b6bcbffce6e144c14c22b5a72064887be1a7e025..0682c652a0e8c3cd6912afa081474fa9fa0bb8dd 100644 --- a/models/recall/ssr/config.yaml +++ b/models/recall/ssr/config.yaml @@ -19,6 +19,7 @@ train: epochs: 3 workspace: "paddlerec.models.recall.ssr" + device: cpu reader: batch_size: 5