提交 37aa4b98 编写于 作者: Q qijun

refine unittest

上级 c2631ebf
...@@ -60,16 +60,6 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T, ...@@ -60,16 +60,6 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix"); "The input and output of matmul be matrix");
if (!in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else if (in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]);
} else if (!in1_T && in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]);
}
PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
platform::is_cpu_place(in2.place()) && platform::is_cpu_place(in2.place()) &&
platform::is_cpu_place(out->place()), platform::is_cpu_place(out->place()),
...@@ -77,7 +67,7 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T, ...@@ -77,7 +67,7 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
int M = out_dim[0]; int M = out_dim[0];
int N = out_dim[1]; int N = out_dim[1];
int K = in1_dim[1]; int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
...@@ -100,16 +90,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1, ...@@ -100,16 +90,6 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
PADDLE_ENFORCE( PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix"); "The input and output of matmul be matrix");
if (!in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else if (in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]);
} else if (!in1_T && in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]);
}
PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
platform::is_cpu_place(in2.place()) && platform::is_cpu_place(in2.place()) &&
platform::is_cpu_place(out->place()), platform::is_cpu_place(out->place()),
...@@ -117,7 +97,7 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1, ...@@ -117,7 +97,7 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
int M = out_dim[0]; int M = out_dim[0];
int N = out_dim[1]; int N = out_dim[1];
int K = in1_dim[1]; int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
......
...@@ -71,15 +71,6 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T, ...@@ -71,15 +71,6 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
PADDLE_ENFORCE( PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix"); "The input and output of matmul be matrix");
if (!in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else if (in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]);
} else if (!in1_T && in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]);
}
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
platform::is_gpu_place(in2.place()) && platform::is_gpu_place(in2.place()) &&
...@@ -88,7 +79,7 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T, ...@@ -88,7 +79,7 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
int M = out_dim[0]; int M = out_dim[0];
int N = out_dim[1]; int N = out_dim[1];
int K = in1_dim[1]; int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
...@@ -111,16 +102,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1, ...@@ -111,16 +102,6 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
PADDLE_ENFORCE( PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix"); "The input and output of matmul be matrix");
if (!in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else if (in1_T && !in2_T) {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]);
} else if (!in1_T && in2_T) {
PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]);
} else {
PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]);
}
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
platform::is_gpu_place(in2.place()) && platform::is_gpu_place(in2.place()) &&
platform::is_gpu_place(out->place()), platform::is_gpu_place(out->place()),
...@@ -128,7 +109,7 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1, ...@@ -128,7 +109,7 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
int M = out_dim[0]; int M = out_dim[0];
int N = out_dim[1]; int N = out_dim[1];
int K = in1_dim[1]; int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册