提交 49dedfad 编写于 作者: Y Yu Yang

Polish code and tests

上级 c888e016
...@@ -42,9 +42,20 @@ struct CUBlas<double> { ...@@ -42,9 +42,20 @@ struct CUBlas<double> {
template <> template <>
struct CUBlas<platform::float16> { struct CUBlas<platform::float16> {
template <typename... ARGS> using float16 = platform::float16;
static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...)); static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
const float16 *B, int ldb, const float16 *beta, float16 *C,
int ldc) {
PADDLE_ENFORCE(
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda,
reinterpret_cast<const __half *>(B), ldb,
reinterpret_cast<const __half *>(beta),
reinterpret_cast<__half *>(C), ldc));
} }
}; };
......
...@@ -14,6 +14,13 @@ ...@@ -14,6 +14,13 @@
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
// Test-local shorthand for the CPU BLAS wrapper: forwards `context` to
// paddle::operators::math::GetBlas with the device type fixed to
// CPUDeviceContext, so call sites only spell the element type, e.g.
// GetBlas<float>(context).GEMM(...).
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
GetBlas(const paddle::platform::CPUDeviceContext& context) {
  namespace math = paddle::operators::math;
  using DeviceContext = paddle::platform::CPUDeviceContext;
  return math::GetBlas<DeviceContext, T>(context);
}
TEST(math_function, gemm_notrans_cblas) { TEST(math_function, gemm_notrans_cblas) {
paddle::framework::Tensor input1; paddle::framework::Tensor input1;
paddle::framework::Tensor input2; paddle::framework::Tensor input2;
...@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) { ...@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float)); memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place); paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>( GetBlas<float>(context).GEMM(false, false, m, n, k, 1, input1_ptr, 3,
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1, input2_ptr + 1, 4, 1, input3_ptr + 1, 4);
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24); EXPECT_EQ(input3_ptr[1], 24);
...@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) { ...@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) {
memcpy(input3_ptr, arr3, 8 * sizeof(float)); memcpy(input3_ptr, arr3, 8 * sizeof(float));
paddle::platform::CPUDeviceContext context(*cpu_place); paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>( GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1, input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
input3_ptr + 1, 4);
EXPECT_EQ(input3_ptr[0], 0); EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24); EXPECT_EQ(input3_ptr[1], 24);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
const std::vector<float>& data) { const std::vector<float>& data) {
...@@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, ...@@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
} }
TEST(math_function, notrans_mul_trans_fp32) { TEST(math_function, notrans_mul_trans_fp32) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
...@@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) { ...@@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
} }
TEST(math_function, notrans_mul_trans_fp16) { TEST(math_function, notrans_mul_trans_fp16) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
...@@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) { ...@@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) {
} }
TEST(math_function, trans_mul_notrans_fp32) { TEST(math_function, trans_mul_notrans_fp32) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
...@@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) { ...@@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
} }
TEST(math_function, trans_mul_notrans_fp16) { TEST(math_function, trans_mul_notrans_fp16) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input1_gpu; Tensor input1_gpu;
...@@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) { ...@@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29); EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
} }
// Test-local shorthand for the CUDA BLAS wrapper: forwards `context` to
// paddle::operators::math::GetBlas with the device type fixed to
// CUDADeviceContext, so call sites only spell the element type, e.g.
// GetBlas<float16>(context).GEMM(...).
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
GetBlas(const paddle::platform::CUDADeviceContext& context) {
  namespace math = paddle::operators::math;
  using DeviceContext = paddle::platform::CUDADeviceContext;
  return math::GetBlas<DeviceContext, T>(context);
}
TEST(math_function, gemm_notrans_cublas_fp32) { TEST(math_function, gemm_notrans_cublas_fp32) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
...@@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) { ...@@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place); float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( GetBlas<float>(context).GEMM(false, false, m, n, k, 1, a, 3, b + 1, 4, 1,
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3); TensorCopySync(input3_gpu, cpu_place, &input3);
...@@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) { ...@@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
} }
TEST(math_function, gemm_notrans_cublas_fp16) { TEST(math_function, gemm_notrans_cublas_fp16) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
...@@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
float16* b = input2_gpu.data<float16>(); float16* b = input2_gpu.data<float16>();
float16* c = input3_gpu.mutable_data<float16>(gpu_place); float16* c = input3_gpu.mutable_data<float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>( GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
context, false, false, m, n, k, float16(1), a, 3, b + 1, 4, float16(1), 4, float16(1), c + 1, 4);
c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3); TensorCopySync(input3_gpu, cpu_place, &input3);
...@@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) { ...@@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
} }
TEST(math_function, gemm_trans_cublas_fp32) { TEST(math_function, gemm_trans_cublas_fp32) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
...@@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) { ...@@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
float* b = input2_gpu.data<float>(); float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(gpu_place); float* c = input3_gpu.mutable_data<float>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( GetBlas<float>(context).GEMM(false, true, m, n, k, 1, a, 3, b + 3, 3, 1,
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3); TensorCopySync(input3_gpu, cpu_place, &input3);
...@@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) { ...@@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
} }
TEST(math_function, gemm_trans_cublas_fp16) { TEST(math_function, gemm_trans_cublas_fp16) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor input1; Tensor input1;
Tensor input2; Tensor input2;
...@@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
float16* b = input2_gpu.data<float16>(); float16* b = input2_gpu.data<float16>();
float16* c = input3_gpu.mutable_data<float16>(gpu_place); float16* c = input3_gpu.mutable_data<float16>(gpu_place);
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>( GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
context, false, true, m, n, k, float16(1), a, 3, b + 3, 3, float16(1), 3, float16(1), c + 1, 4);
c + 1, 4);
TensorCopySync(input3_gpu, cpu_place, &input3); TensorCopySync(input3_gpu, cpu_place, &input3);
...@@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) { ...@@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
template <typename T> template <typename T>
void GemvTest(int m, int n, bool trans) { void GemvTest(int m, int n, bool trans) {
using namespace paddle::framework; using namespace paddle::framework; // NOLINT
using namespace paddle::platform; using namespace paddle::platform; // NOLINT
Tensor mat_a; Tensor mat_a;
Tensor vec_b; Tensor vec_b;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册