Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2a06e307
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2a06e307
编写于
4月 25, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix batch_gemm bugs
stride should be int64_t, not int
上级
bfbbe19f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
21 addition
and
12 deletion
+21
-12
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+7
-3
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+10
-6
paddle/fluid/operators/math/math_function.h
paddle/fluid/operators/math/math_function.h
+4
-3
未找到文件。
paddle/fluid/operators/math/math_function.cc
浏览文件 @
2a06e307
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
float16
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
PADDLE_THROW
(
"float16 batched_gemm not supported on CPU"
);
}
...
...
@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
float
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
...
...
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
double
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
2a06e307
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
...
...
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float16
alpha
,
const
float16
*
A
,
const
float16
*
B
,
const
float16
beta
,
float16
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
float16
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
...
...
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
const
int
strideC
=
M
*
N
;
const
int
64_t
strideC
=
M
*
N
;
const
half
h_alpha
=
static_cast
<
const
half
>
(
alpha
);
const
half
h_beta
=
static_cast
<
const
half
>
(
beta
);
...
...
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
float
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
...
...
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
const
int
strideC
=
M
*
N
;
const
int
64_t
strideC
=
M
*
N
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmStridedBatched
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
...
...
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
)
{
double
*
C
,
const
int
batchCount
,
const
int64_t
strideA
,
const
int64_t
strideB
)
{
#if CUDA_VERSION >= 8000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
...
...
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
const
int
strideC
=
M
*
N
;
const
int
64_t
strideC
=
M
*
N
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemmStridedBatched
(
context
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
...
...
paddle/fluid/operators/math/math_function.h
浏览文件 @
2a06e307
...
...
@@ -26,7 +26,7 @@ limitations under the License. */
#ifndef LAPACK_FOUND
extern
"C"
{
#include <cblas.h>
#include <cblas.h>
// NOLINT
int
LAPACKE_sgetrf
(
int
matrix_layout
,
int
m
,
int
n
,
float
*
a
,
int
lda
,
int
*
ipiv
);
int
LAPACKE_dgetrf
(
int
matrix_layout
,
int
m
,
int
n
,
double
*
a
,
int
lda
,
...
...
@@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#endif
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
...
...
@@ -78,8 +79,8 @@ template <typename DeviceContext, typename T>
void
batched_gemm
(
const
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
);
const
T
beta
,
T
*
C
,
const
int
batchCount
,
const
int
64_t
strideA
,
const
int64_t
strideB
);
template
<
typename
DeviceContext
,
typename
T
>
void
gemv
(
const
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录