Commit 49dedfad authored by Yu Yang

Polish code and tests

Parent c888e016
......@@ -42,9 +42,20 @@ struct CUBlas<double> {
template <>
struct CUBlas<platform::float16> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...));
using float16 = platform::float16;
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
const float16 *B, int ldb, const float16 *beta, float16 *C,
int ldc) {
PADDLE_ENFORCE(
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda,
reinterpret_cast<const __half *>(B), ldb,
reinterpret_cast<const __half *>(beta),
reinterpret_cast<__half *>(C), ldc));
}
};
......
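For context, the `CUBlas<platform::float16>` specialization above forwards `platform::float16` pointers to cuBLAS's half-precision GEMM by casting them to `__half` (both are 2-byte IEEE half layouts). The standalone snippet below is a minimal sketch of a direct `cublasHgemm` call, not part of this commit; the sizes, all-ones data, and handle setup are illustrative, and it assumes a recent CUDA toolkit where the host-side `__float2half` conversion is available.

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

int main() {
  const int m = 2, n = 2, k = 2;
  // Host operands in half precision (cuBLAS assumes column-major storage;
  // with all-ones data the ordering does not matter here).
  std::vector<__half> h_a(m * k, __float2half(1.f));
  std::vector<__half> h_b(k * n, __float2half(1.f));
  std::vector<__half> h_c(m * n, __float2half(0.f));

  __half *d_a, *d_b, *d_c;
  cudaMalloc(&d_a, h_a.size() * sizeof(__half));
  cudaMalloc(&d_b, h_b.size() * sizeof(__half));
  cudaMalloc(&d_c, h_c.size() * sizeof(__half));
  cudaMemcpy(d_a, h_a.data(), h_a.size() * sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), h_b.size() * sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_c, h_c.data(), h_c.size() * sizeof(__half), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  const __half alpha = __float2half(1.f), beta = __float2half(0.f);
  // C = alpha * A * B + beta * C, with every operand in fp16.
  cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, d_a, m,
              d_b, k, &beta, d_c, m);

  cudaMemcpy(h_c.data(), d_c, h_c.size() * sizeof(__half), cudaMemcpyDeviceToHost);

  cublasDestroy(handle);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
  return 0;
}
```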
......@@ -14,6 +14,13 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "gtest/gtest.h"
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
GetBlas(const paddle::platform::CPUDeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cblas) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input2;
......@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
input3_ptr + 1, 4);
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, input1_ptr, 3,
input2_ptr + 1, 4, 1, input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
......@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
input3_ptr + 1, 4);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
......
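The `GEMM` calls in this test operate on sub-blocks of the input matrices by passing offset pointers (`input2_ptr + 1`, `input3_ptr + 1`) together with leading dimensions larger than the sub-block width. For readers unfamiliar with that pattern, the snippet below is a minimal standalone sketch of the same idea using plain row-major CBLAS; it is not part of the commit, and the matrix sizes are illustrative only.

```cpp
#include <cblas.h>
#include <cstdio>

int main() {
  // A is 2x3, B is 3x4, C is 2x4, all row-major.
  float A[6] = {0, 1, 2, 3, 4, 5};
  float B[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  float C[8] = {0, 1, 2, 3, 4, 5, 6, 7};

  const int m = 2, n = 3, k = 3;
  // Multiply A (2x3) by the 3x3 sub-block of B that starts at column 1
  // (pointer B + 1, leading dimension 4), accumulating into the 2x3
  // sub-block of C that starts at column 1 (pointer C + 1, leading dimension 4).
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,
              1.0f, A, 3, B + 1, 4, 1.0f, C + 1, 4);

  for (int i = 0; i < 8; ++i) printf("%g ", C[i]);
  printf("\n");
  return 0;
}
```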
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
const std::vector<float>& data) {
......@@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
}
TEST(math_function, notrans_mul_trans_fp32) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input1_gpu;
......@@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
}
TEST(math_function, notrans_mul_trans_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input1_gpu;
......@@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) {
}
TEST(math_function, trans_mul_notrans_fp32) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input1_gpu;
......@@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
}
TEST(math_function, trans_mul_notrans_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input1_gpu;
......@@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
}
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
T>(context);
}
TEST(math_function, gemm_notrans_cublas_fp32) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input2;
......@@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, a, 3, b + 1, 4, 1,
c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
}
TEST(math_function, gemm_notrans_cublas_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input2;
......@@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
float16* b = input2_gpu.data<float16>();
float16* c = input3_gpu.mutable_data<float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
context, false, false, m, n, k, float16(1), a, 3, b + 1, 4, float16(1),
c + 1, 4);
GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
4, float16(1), c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
}
TEST(math_function, gemm_trans_cublas_fp32) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input2;
......@@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, a, 3, b + 3, 3, 1,
c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
}
TEST(math_function, gemm_trans_cublas_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor input1;
Tensor input2;
......@@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
float16* b = input2_gpu.data<float16>();
float16* c = input3_gpu.mutable_data<float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
context, false, true, m, n, k, float16(1), a, 3, b + 3, 3, float16(1),
c + 1, 4);
GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
3, float16(1), c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3);
......@@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
template <typename T>
void GemvTest(int m, int n, bool trans) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::framework; // NOLINT
using namespace paddle::platform; // NOLINT
Tensor mat_a;
Tensor vec_b;
......