提交 7fa0e3de 编写于 作者: G Guo Sheng 提交者: pkpk

Add cuda check in Transformer. (#2728)

上级 9a5809fa
...@@ -4,12 +4,14 @@ import multiprocessing ...@@ -4,12 +4,14 @@ import multiprocessing
import numpy as np import numpy as np
import os import os
import sys import sys
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/") sys.path.append("../../models/neural_machine_translation/transformer/")
from functools import partial from functools import partial
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from models.model_check import check_cuda
import reader import reader
from config import * from config import *
from desc import * from desc import *
...@@ -217,6 +219,7 @@ def fast_infer(args): ...@@ -217,6 +219,7 @@ def fast_infer(args):
fluid.memory_optimize(infer_program) fluid.memory_optimize(infer_program)
if InferTaskConfig.use_gpu: if InferTaskConfig.use_gpu:
check_cuda(InferTaskConfig.use_gpu)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count() dev_count = fluid.core.get_cuda_device_count()
else: else:
......
...@@ -6,12 +6,14 @@ import multiprocessing ...@@ -6,12 +6,14 @@ import multiprocessing
import os import os
import six import six
import sys import sys
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/") sys.path.append("../../models/neural_machine_translation/transformer/")
import time import time
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from models.model_check import check_cuda
import reader import reader
from config import * from config import *
from desc import * from desc import *
...@@ -663,6 +665,7 @@ def train(args): ...@@ -663,6 +665,7 @@ def train(args):
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else: else:
check_cuda(TrainTaskConfig.use_gpu)
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) place = fluid.CUDAPlace(gpu_id)
dev_count = get_device_num() dev_count = get_device_num()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册