提交 f632706c 编写于 作者: D Dong Zhihong

fix based on comment

上级 52200523
......@@ -34,6 +34,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
#endif
namespace paddle {
......@@ -482,6 +483,7 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(m);
m.def("op_support_gpu", OpSupportGPU);
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
return m.ptr();
}
......
......@@ -5,11 +5,10 @@ from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
from op_test import OpTest, create_op, set_input
gpu_list = "0,1,2,3"
if not core.is_compile_gpu() or not gpu_list:
if not core.is_compile_gpu():
exit(0)
gpu_count = core.get_cuda_device_count
g_scope = core.Scope()
g_ctx = core.DeviceContext.create(core.CPUPlace())
......@@ -17,7 +16,7 @@ g_ctx = core.DeviceContext.create(core.CPUPlace())
class TestNCCLInit(unittest.TestCase):
def test_init(self):
self.op_type = "ncclInit"
self.gpus = [int(g) for g in gpu_list.split(",")]
self.gpus = [int(g) for g in range(gpu_count)]
self.inputs = {}
self.attrs = {"gpus": self.gpus}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册