提交 f632706c 编写于 作者: D Dong Zhihong

fix based on comment

上级 52200523
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/operators/nccl/nccl_gpu_common.h" #include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -482,6 +483,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -482,6 +483,7 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(m); BindOpDesc(m);
m.def("op_support_gpu", OpSupportGPU); m.def("op_support_gpu", OpSupportGPU);
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
return m.ptr(); return m.ptr();
} }
......
...@@ -5,11 +5,10 @@ from paddle.v2.framework.op import Operator ...@@ -5,11 +5,10 @@ from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from op_test import OpTest, create_op, set_input from op_test import OpTest, create_op, set_input
gpu_list = "0,1,2,3" if not core.is_compile_gpu():
if not core.is_compile_gpu() or not gpu_list:
exit(0) exit(0)
gpu_count = core.get_cuda_device_count
g_scope = core.Scope() g_scope = core.Scope()
g_ctx = core.DeviceContext.create(core.CPUPlace()) g_ctx = core.DeviceContext.create(core.CPUPlace())
...@@ -17,7 +16,7 @@ g_ctx = core.DeviceContext.create(core.CPUPlace()) ...@@ -17,7 +16,7 @@ g_ctx = core.DeviceContext.create(core.CPUPlace())
class TestNCCLInit(unittest.TestCase): class TestNCCLInit(unittest.TestCase):
def test_init(self): def test_init(self):
self.op_type = "ncclInit" self.op_type = "ncclInit"
self.gpus = [int(g) for g in gpu_list.split(",")] self.gpus = [int(g) for g in range(gpu_count)]
self.inputs = {} self.inputs = {}
self.attrs = {"gpus": self.gpus} self.attrs = {"gpus": self.gpus}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册