未验证 提交 d0219002 编写于 作者: J Jiabin Yang 提交者: GitHub

Cherry pick install check for multi gpu (#18245)

* test=develop, add add_multi_gpu_install_check (#18157)

* test=develop, add add_multi_gpu_install_check

* test=develop, refine warning doc

* test=develop, refine warning doc

* test=develop, refine warning doc

* test=develop, support multi cpu

* test=release/1.5, cherry-picked from develop
上级 0648376c
...@@ -12,15 +12,50 @@ ...@@ -12,15 +12,50 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .framework import Program, program_guard, unique_name, default_startup_program import os
from . import core
def process_env():
env = os.environ
device_list = []
if env.get('CUDA_VISIBLE_DEVICES') is not None:
cuda_devices = env['CUDA_VISIBLE_DEVICES']
if cuda_devices == "" or len(cuda_devices) == 0:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
device_list = [0, 1]
elif len(cuda_devices) == 1:
device_list.append(0)
elif len(cuda_devices) > 1:
for i in range(len(cuda_devices.split(","))):
device_list.append(i)
return device_list
else:
if core.get_cuda_device_count() > 1:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
return [0, 1]
else:
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
return [0]
device_list = []
if core.is_compiled_with_cuda():
device_list = process_env()
else:
device_list = [0, 1] # for CPU 0,1
from .framework import Program, program_guard, unique_name
from .param_attr import ParamAttr from .param_attr import ParamAttr
from .initializer import Constant from .initializer import Constant
from . import layers from . import layers
from . import backward from . import backward
from .dygraph import Layer, nn from .dygraph import Layer, nn
from . import executor from . import executor
from . import optimizer
from . import core from . import core
from . import compiler
import logging
import numpy as np import numpy as np
__all__ = ['run_check'] __all__ = ['run_check']
...@@ -45,25 +80,94 @@ def run_check(): ...@@ -45,25 +80,94 @@ def run_check():
This func should not be called only if you need to verify installation This func should not be called only if you need to verify installation
''' '''
print("Running Verify Fluid Program ... ") print("Running Verify Fluid Program ... ")
prog = Program() use_cuda = False if not core.is_compiled_with_cuda() else True
place = core.CPUPlace() if not core.is_compiled_with_cuda(
) else core.CUDAPlace(0)
np_inp_single = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
inp = []
for i in range(len(device_list)):
inp.append(np_inp_single)
np_inp_muti = np.array(inp)
np_inp_muti = np_inp_muti.reshape(len(device_list), 2, 2)
def test_parallerl_exe():
train_prog = Program()
startup_prog = Program() startup_prog = Program()
scope = core.Scope() scope = core.Scope()
if not use_cuda:
os.environ['CPU_NUM'] = "2"
with executor.scope_guard(scope): with executor.scope_guard(scope):
with program_guard(prog, startup_prog): with program_guard(train_prog, startup_prog):
with unique_name.guard(): with unique_name.guard():
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) places = []
inp = layers.data( build_strategy = compiler.BuildStrategy()
name="inp", shape=[2, 2], append_batch_size=False) build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
inp = layers.data(name="inp", shape=[2, 2])
simple_layer = SimpleLayer("simple_layer") simple_layer = SimpleLayer("simple_layer")
out = simple_layer(inp) out = simple_layer(inp)
exe = executor.Executor(place)
if use_cuda:
for i in device_list:
places.append(core.CUDAPlace(i))
else:
places = [core.CPUPlace(), core.CPUPlace()]
loss = layers.mean(out)
loss.persistable = True
optimizer.SGD(learning_rate=0.01).minimize(loss)
startup_prog.random_seed = 1
compiled_prog = compiler.CompiledProgram(
train_prog).with_data_parallel(
build_strategy=build_strategy,
loss_name=loss.name,
places=places)
exe.run(startup_prog)
exe.run(compiled_prog,
feed={inp.name: np_inp_muti},
fetch_list=[loss.name])
def test_simple_exe():
train_prog = Program()
startup_prog = Program()
scope = core.Scope()
if not use_cuda:
os.environ['CPU_NUM'] = "1"
with executor.scope_guard(scope):
with program_guard(train_prog, startup_prog):
with unique_name.guard():
inp0 = layers.data(
name="inp", shape=[2, 2], append_batch_size=False)
simple_layer0 = SimpleLayer("simple_layer")
out0 = simple_layer0(inp0)
param_grads = backward.append_backward( param_grads = backward.append_backward(
out, parameter_list=[simple_layer._fc1._w.name])[0] out0, parameter_list=[simple_layer0._fc1._w.name])[0]
exe = executor.Executor(core.CPUPlace( exe0 = executor.Executor(core.CPUPlace()
) if not core.is_compiled_with_cuda() else core.CUDAPlace(0)) if not core.is_compiled_with_cuda()
exe.run(default_startup_program()) else core.CUDAPlace(0))
exe.run(feed={inp.name: np_inp}, exe0.run(startup_prog)
fetch_list=[out.name, param_grads[1].name]) exe0.run(feed={inp0.name: np_inp_single},
fetch_list=[out0.name, param_grads[1].name])
test_simple_exe()
print("Your Paddle Fluid works well on SINGLE GPU or CPU.")
try:
test_parallerl_exe()
print("Your Paddle Fluid works well on MUTIPLE GPU or CPU.")
print( print(
"Your Paddle Fluid is installed successfully! Let's start deep Learning with Paddle Fluid now" "Your Paddle Fluid is installed successfully! Let's start deep Learning with Paddle Fluid now"
) )
except Exception as e:
logging.warning(
"Your Paddle Fluid has some problem with multiple GPU. This may be caused by:"
"\n 1. There is only 1 GPU visible on your Device;"
"\n 2. No.1 or No.2 GPU or both of them are occupied now"
"\n 3. Wrong installation of NVIDIA-NCCL2, please follow instruction on https://github.com/NVIDIA/nccl-tests "
"\n to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html"
)
print("\n Original Error is: {}".format(e))
print(
"Your Paddle Fluid is installed successfully ONLY for SINGLE GPU or CPU! "
"\n Let's start deep Learning with Paddle Fluid now")
...@@ -116,6 +116,7 @@ list(REMOVE_ITEM TEST_OPS test_imperative_mnist) ...@@ -116,6 +116,7 @@ list(REMOVE_ITEM TEST_OPS test_imperative_mnist)
list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer) list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer)
list(REMOVE_ITEM TEST_OPS test_layers) list(REMOVE_ITEM TEST_OPS test_layers)
list(REMOVE_ITEM TEST_OPS test_imperative_ocr_attention_model) list(REMOVE_ITEM TEST_OPS test_imperative_ocr_attention_model)
list(REMOVE_ITEM TEST_OPS test_install_check)
# Some ops need to check results when gc is enabled # Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test # Currently, only ops that register NoNeedBufferVarsInference need to do this test
...@@ -172,6 +173,9 @@ py_test_modules(test_imperative_mnist_sorted_gradient MODULES test_imperative_mn ...@@ -172,6 +173,9 @@ py_test_modules(test_imperative_mnist_sorted_gradient MODULES test_imperative_mn
py_test_modules(test_imperative_se_resnext MODULES test_imperative_se_resnext ENVS py_test_modules(test_imperative_se_resnext MODULES test_imperative_se_resnext ENVS
FLAGS_cudnn_deterministic=1 SERIAL) FLAGS_cudnn_deterministic=1 SERIAL)
set_tests_properties(test_imperative_se_resnext PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_imperative_se_resnext PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
py_test_modules(test_install_check MODULES test_install_check ENVS
FLAGS_cudnn_deterministic=1 SERIAL)
set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST")
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train) py_test_modules(test_dist_train MODULES test_dist_train)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册