提交 5f1081d8 编写于 作者: Q qijun

fix bug in dynload

上级 c5a7471e
...@@ -13,4 +13,4 @@ else() ...@@ -13,4 +13,4 @@ else()
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
...@@ -12,15 +12,18 @@ TEST(math_function, GPU) { ...@@ -12,15 +12,18 @@ TEST(math_function, GPU) {
auto* cpu_place = new paddle::platform::CPUPlace(); auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 2}, *cpu_place); float* input1_ptr = input1.mutable_data<float>({2, 2}, *cpu_place);
float arr[4] = {0, 1, 2, 3}; float arr[4] = {0, 1, 2, 3};
memcpy(input1_ptr, arr, 4 * sizeof(int));
auto* gpu_place = new paddle::platform::GPUPlace(0); auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context = new CUDADeviceContext(gpu_place); paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place); input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place); input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.CopyFrom<float>(input1, *gpu_place); out_gpu.CopyFrom<float>(input1, *gpu_place);
matmul<paddle::platform::GPUPlace, float>(input1_gpu, false, input2_gpu, paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, false, input2_gpu,
false, 1, &out_gpu, 0, context); false, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place); out.CopyFrom<float>(out_gpu, *cpu_place);
......
...@@ -62,12 +62,12 @@ extern void *cublas_dso_handle; ...@@ -62,12 +62,12 @@ extern void *cublas_dso_handle;
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv); \ __macro(cublasSgemv_v2); \
__macro(cublasDgemv); \ __macro(cublasDgemv_v2); \
__macro(cublasSgemm); \ __macro(cublasSgemm_v2); \
__macro(cublasDgemm); \ __macro(cublasDgemm_v2); \
__macro(cublasSgeam); \ __macro(cublasSgeam_v2); \
__macro(cublasDgeam); \ __macro(cublasDgeam_v2); \
__macro(cublasCreate_v2); \ __macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \ __macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \ __macro(cublasSetStream_v2); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册