提交 bc816035，作者：Yu Yang

Fix compile: use the fully qualified paddle::platform::float16 in the cuBLAS fp16 GEMM tests

上级 a6edeb39
......@@ -279,8 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
paddle::platform::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
4, float16(1), c + 1, 4);
GetBlas<paddle::platform::float16>(context).GEMM(
false, false, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
b + 1, 4, static_cast<paddle::platform::float16>(1), c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -388,12 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
paddle::platform::float16* c =
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
3, float16(1), c + 1, 4);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(
context, false, true, m, n, k, paddle::platform::float16(1), a, 3, b + 3,
3, paddle::platform::float16(1), c + 1, 4);
GetBlas<paddle::platform::float16>(context).GEMM(
false, true, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
b + 3, 3, static_cast<paddle::platform::float16>(1), c + 1, 4);
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册