Unverified commit 831a3e62, authored by Jiabin Yang, committed by GitHub

Add install check for multigpu (#18323)

* test=develop, add_install_check_for_multigpu

* test=develop, refine code to use cuda_devices
Parent: f88e07a0
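
The check exercised by this commit is exposed as run_check. A minimal usage sketch, assuming the patched module is importable as paddle.fluid.install_check (the file path is not shown on this page):

    import paddle.fluid as fluid

    # Verifies the installation; with a CUDA build it runs the check program
    # on all visible CUDA places (cuda_places()), otherwise on two CPU places.
    fluid.install_check.run_check()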
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .framework import Program, program_guard, unique_name, default_startup_program
+import os
+from .framework import Program, program_guard, unique_name, cuda_places, cpu_places
 from .param_attr import ParamAttr
 from .initializer import Constant
 from . import layers
@@ -24,7 +25,6 @@ from . import core
 from . import compiler
 import logging
 import numpy as np
-import os
 
 __all__ = ['run_check']
 
@@ -48,39 +48,43 @@ def run_check():
     This func should not be called only if you need to verify installation
     '''
     print("Running Verify Fluid Program ... ")
-    use_cuda = False if not core.is_compiled_with_cuda() else True
-    place = core.CPUPlace() if not core.is_compiled_with_cuda(
-    ) else core.CUDAPlace(0)
-    np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
-
-    if use_cuda:
-        if core.get_cuda_device_count() > 1:
-            os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
-        else:
-            os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+
+    device_list = []
+    if core.is_compiled_with_cuda():
+        try:
+            core.get_cuda_device_count()
+        except Exception as e:
+            logging.warning(
+                "You are using GPU version Paddle Fluid, But Your CUDA Device is not set properly"
+                "\n Original Error is {}".format(e))
+            return 0
+        device_list = cuda_places()
+    else:
+        device_list = [core.CPUPlace(), core.CPUPlace()]
+
+    np_inp_single = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+    inp = []
+    for i in range(len(device_list)):
+        inp.append(np_inp_single)
+    np_inp_muti = np.array(inp)
+    np_inp_muti = np_inp_muti.reshape(len(device_list), 2, 2)
 
     def test_parallerl_exe():
         train_prog = Program()
         startup_prog = Program()
         scope = core.Scope()
-        if not use_cuda:
-            os.environ['CPU_NUM'] = "2"
         with executor.scope_guard(scope):
             with program_guard(train_prog, startup_prog):
                 with unique_name.guard():
-                    places = []
                     build_strategy = compiler.BuildStrategy()
                     build_strategy.enable_inplace = True
                     build_strategy.memory_optimize = True
-                    inp = layers.data(
-                        name="inp", shape=[2, 2], append_batch_size=False)
+                    inp = layers.data(name="inp", shape=[2, 2])
                     simple_layer = SimpleLayer("simple_layer")
                     out = simple_layer(inp)
-                    exe = executor.Executor(place)
-                    if use_cuda:
-                        places = [core.CUDAPlace(0), core.CUDAPlace(1)]
-                    else:
-                        places = [core.CPUPlace(), core.CPUPlace()]
+                    exe = executor.Executor(
+                        core.CUDAPlace(0) if core.is_compiled_with_cuda() and
+                        (core.get_cuda_device_count() > 0) else core.CPUPlace())
                     loss = layers.mean(out)
                     loss.persistable = True
                     optimizer.SGD(learning_rate=0.01).minimize(loss)
@@ -89,19 +93,17 @@ def run_check():
                         train_prog).with_data_parallel(
                             build_strategy=build_strategy,
                             loss_name=loss.name,
-                            places=places)
+                            places=device_list)
                     exe.run(startup_prog)
                     exe.run(compiled_prog,
-                            feed={inp.name: np_inp},
+                            feed={inp.name: np_inp_muti},
                             fetch_list=[loss.name])
 
 
     def test_simple_exe():
         train_prog = Program()
         startup_prog = Program()
         scope = core.Scope()
-        if not use_cuda:
-            os.environ['CPU_NUM'] = "1"
         with executor.scope_guard(scope):
             with program_guard(train_prog, startup_prog):
                 with unique_name.guard():
@@ -111,11 +113,11 @@ def run_check():
                     out0 = simple_layer0(inp0)
                     param_grads = backward.append_backward(
                         out0, parameter_list=[simple_layer0._fc1._w.name])[0]
-                    exe0 = executor.Executor(core.CPUPlace()
-                                             if not core.is_compiled_with_cuda()
-                                             else core.CUDAPlace(0))
+                    exe0 = executor.Executor(
+                        core.CUDAPlace(0) if core.is_compiled_with_cuda() and
+                        (core.get_cuda_device_count() > 0) else core.CPUPlace())
                     exe0.run(startup_prog)
-                    exe0.run(feed={inp0.name: np_inp},
+                    exe0.run(feed={inp0.name: np_inp_single},
                              fetch_list=[out0.name, param_grads[1].name])
 
     test_simple_exe()
@@ -130,7 +132,7 @@ def run_check():
     except Exception as e:
        logging.warning(
            "Your Paddle Fluid has some problem with multiple GPU. This may be caused by:"
-           "\n 1. There is only 1 GPU visible on your Device;"
+           "\n 1. There is only 1 or 0 GPU visible on your Device;"
           "\n 2. No.1 or No.2 GPU or both of them are occupied now"
           "\n 3. Wrong installation of NVIDIA-NCCL2, please follow instruction on https://github.com/NVIDIA/nccl-tests "
           "\n to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html"
@@ -139,4 +141,4 @@ def run_check():
        print("\n Original Error is: {}".format(e))
        print(
            "Your Paddle Fluid is installed successfully ONLY for SINGLE GPU or CPU! "
-           "\n Let's start deep Learning with Paddle Fluid now!")
+           "\n Let's start deep Learning with Paddle Fluid now")