From 7fa0e3de1f121aa7a1d03ecc059188720c269b47 Mon Sep 17 00:00:00 2001
From: Guo Sheng
Date: Fri, 5 Jul 2019 17:13:20 +0800
Subject: [PATCH] Add cuda check in Transformer. (#2728)

---
 PaddleNLP/neural_machine_translation/transformer/infer.py | 3 +++
 PaddleNLP/neural_machine_translation/transformer/train.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/PaddleNLP/neural_machine_translation/transformer/infer.py b/PaddleNLP/neural_machine_translation/transformer/infer.py
index aaf813a5..cb40e685 100644
--- a/PaddleNLP/neural_machine_translation/transformer/infer.py
+++ b/PaddleNLP/neural_machine_translation/transformer/infer.py
@@ -4,12 +4,14 @@ import multiprocessing
 import numpy as np
 import os
 import sys
+sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")
 from functools import partial
 
 import paddle
 import paddle.fluid as fluid
+from models.model_check import check_cuda
 
 import reader
 from config import *
 from desc import *
@@ -217,6 +219,7 @@ def fast_infer(args):
         fluid.memory_optimize(infer_program)
 
     if InferTaskConfig.use_gpu:
+        check_cuda(InferTaskConfig.use_gpu)
         place = fluid.CUDAPlace(0)
         dev_count = fluid.core.get_cuda_device_count()
     else:
diff --git a/PaddleNLP/neural_machine_translation/transformer/train.py b/PaddleNLP/neural_machine_translation/transformer/train.py
index b8e6c95f..f284c9c6 100644
--- a/PaddleNLP/neural_machine_translation/transformer/train.py
+++ b/PaddleNLP/neural_machine_translation/transformer/train.py
@@ -6,12 +6,14 @@ import multiprocessing
 import os
 import six
 import sys
+sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")
 import time
 
 import numpy as np
 import paddle.fluid as fluid
+from models.model_check import check_cuda
 
 import reader
 from config import *
 from desc import *
@@ -663,6 +665,7 @@ def train(args):
         place = fluid.CPUPlace()
         dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
     else:
+        check_cuda(TrainTaskConfig.use_gpu)
         gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
         place = fluid.CUDAPlace(gpu_id)
         dev_count = get_device_num()
--
GitLab
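
Note (not part of the patch): for readers without the models package on their path, the sketch below shows what a helper like the check_cuda imported from models.model_check might look like. It is an assumption that the helper simply wraps fluid.is_compiled_with_cuda() and exits with a readable message; the actual upstream implementation may differ, and the error text here is illustrative.

# Hypothetical sketch of a check_cuda-style helper, not the exact
# models.model_check implementation.
import sys

import paddle.fluid as fluid


def check_cuda(use_cuda,
               err="\nYou cannot set use_gpu=True because this PaddlePaddle "
                   "build is not compiled with CUDA.\nPlease set "
                   "use_gpu=False, or install a GPU-enabled PaddlePaddle "
                   "build.\n"):
    # Fail fast with a plain message when GPU execution is requested but
    # the installed PaddlePaddle has no CUDA support.
    if use_cuda and not fluid.is_compiled_with_cuda():
        print(err)
        sys.exit(1)

Calling such a check before constructing fluid.CUDAPlace, as the patch does in both infer.py and train.py, turns a low-level enforce traceback on CPU-only installs into a clear, actionable message.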