Crayon鑫 / Paddle (fork of PaddlePaddle / Paddle)
Commit b6a4349d (unverified)
Authored Sep 18, 2020 by wawltor; committed by GitHub on Sep 18, 2020
fix the error message for the math dir
https://github.com/PaddlePaddle/Paddle/pull/27332
Parent: ac82baa8
Showing 5 changed files with 146 additions and 44 deletions (+146 −44):

  paddle/fluid/operators/math/beam_search.cc    +4   −1
  paddle/fluid/operators/math/beam_search.cu    +4   −1
  paddle/fluid/operators/math/blas.cc           +5   −1
  paddle/fluid/operators/math/blas_impl.cu.h    +18  −8
  paddle/fluid/operators/math/blas_impl.h       +115 −33
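
All five files apply the same mechanical change: a bare PADDLE_THROW("format", args...) or a message-less PADDLE_ENFORCE_* becomes a call whose formatted message is wrapped in a typed error such as platform::errors::InvalidArgument or platform::errors::Unimplemented. A minimal self-contained sketch of that call shape follows (the errors namespace and Format helper below are illustrative stand-ins, not Paddle's real machinery):

    // Illustrative stand-ins only: Paddle's real PADDLE_THROW /
    // platform::errors machinery is richer than this sketch.
    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <utility>

    namespace errors {

    // A typed error carrying a formatted message, mirroring the shape of
    // platform::errors::InvalidArgument in the diffs below.
    class InvalidArgument : public std::runtime_error {
     public:
      explicit InvalidArgument(std::string msg)
          : std::runtime_error(std::move(msg)) {}
    };

    // Apply printf-style formatting, then construct the typed error.
    template <typename Err, typename... Args>
    Err Format(const char* fmt, Args... args) {
      char buf[512];
      std::snprintf(buf, sizeof(buf), fmt, args...);
      return Err(std::string(buf));
    }

    }  // namespace errors

    int main() {
      int dim_size = 1;
      try {
        // The old style threw a bare formatted string; the new style throws
        // a typed error object that still carries the formatted message.
        throw errors::Format<errors::InvalidArgument>(
            "The tensor dim size should be greater than 1, but received %d",
            dim_size);
      } catch (const errors::InvalidArgument& e) {
        std::printf("caught InvalidArgument: %s\n", e.what());
      }
      return 0;
    }

Typing the error up front lets callers and test harnesses match on the error category instead of parsing message strings.
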
paddle/fluid/operators/math/beam_search.cc @ b6a4349d

@@ -87,7 +87,10 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
     lod[0].assign(high_level.begin(), high_level.end());
     lod[1].assign(low_level.begin(), low_level.end());
     if (!framework::CheckLoD(lod)) {
-      PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "lod %s is not right in"
+          " beam_search, please check your code.",
+          framework::LoDToString(lod)));
     }
     selected_ids->set_lod(lod);
     selected_scores->set_lod(lod);
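
Both beam-search backends guard set_lod with framework::CheckLoD before attaching the offset table. As a rough, simplified illustration of the kind of invariant such a check covers (this is not Paddle's actual CheckLoD, which validates more than this):

    #include <cstdio>
    #include <vector>

    // Simplified stand-in for one LoD level: a vector of offsets.
    using LoDLevel = std::vector<size_t>;

    // A level is plausible if it starts at 0 and offsets never decrease.
    // (Only the core monotonicity invariant, for illustration.)
    bool CheckLevel(const LoDLevel& level) {
      if (level.empty() || level.front() != 0) return false;
      for (size_t i = 1; i < level.size(); ++i) {
        if (level[i] < level[i - 1]) return false;
      }
      return true;
    }

    int main() {
      LoDLevel good = {0, 2, 5, 7};
      LoDLevel bad = {0, 4, 3};  // offsets decrease: invalid
      std::printf("good: %d, bad: %d\n", CheckLevel(good), CheckLevel(bad));
      return 0;
    }
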
paddle/fluid/operators/math/beam_search.cu @ b6a4349d

@@ -400,7 +400,10 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
     context.Wait();
     if (!framework::CheckLoD(selected_lod)) {
-      PADDLE_THROW("lod %s is not right", framework::LoDToString(selected_lod));
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "lod %s is not right in"
+          " beam_search, please check your code.",
+          framework::LoDToString(selected_lod)));
     }
     selected_ids->set_lod(selected_lod);
paddle/fluid/operators/math/blas.cc @ b6a4349d

@@ -20,7 +20,11 @@ namespace operators {
 namespace math {

 MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim,
                                      int num_flatten_cols, bool trans) {
-  PADDLE_ENFORCE_GT(tensor_dim.size(), 1);
+  PADDLE_ENFORCE_GT(tensor_dim.size(), 1,
+                    platform::errors::InvalidArgument(
+                        "The tensor dim size should be greater "
+                        "than 1, but reveived dim size is %d",
+                        tensor_dim.size()));
   MatDescriptor retv;
   if (num_flatten_cols > 1) {
     auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols);
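
The three-argument PADDLE_ENFORCE_GT used here passes the typed error as the final argument, so the comparison and its diagnostic travel together. A hedged sketch of a comparison-enforce macro in that spirit (ENFORCE_GT below is illustrative, not Paddle's implementation, which builds a typed error object rather than aborting):

    #include <cstdio>
    #include <cstdlib>

    // Illustrative macro only: check a > b, otherwise print the formatted
    // diagnostic and abort.
    #define ENFORCE_GT(a, b, fmt, ...)                                   \
      do {                                                               \
        if (!((a) > (b))) {                                              \
          std::fprintf(stderr, "EnforceNotMet: " fmt "\n", __VA_ARGS__); \
          std::abort();                                                  \
        }                                                                \
      } while (0)

    int main() {
      int dim_size = 2;
      ENFORCE_GT(dim_size, 1,
                 "The tensor dim size should be greater than 1, "
                 "but received dim size is %d", dim_size);
      std::puts("check passed");
      return 0;
    }
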
paddle/fluid/operators/math/blas_impl.cu.h @ b6a4349d

@@ -60,7 +60,8 @@ struct CUBlas<float> {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cublasSgemmStridedBatched(args...));
 #else
-    PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "SgemmStridedBatched is not supported on cuda <= 7.5"));
 #endif
   }
@@ -85,7 +86,8 @@ struct CUBlas<float> {
                                       beta, C, Ctype, ldc));
     });
 #else
-    PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "cublasSgemmEx is not supported on cuda <= 7.5"));
 #endif
   }
@@ -146,13 +148,15 @@ struct CUBlas<double> {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cublasDgemmStridedBatched(args...));
 #else
-    PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "DgemmStridedBatched is not supported on cuda <= 7.5"));
 #endif
   }

   template <typename... ARGS>
   static void GEMM_EX(ARGS... args) {
-    PADDLE_THROW("Currently there are not cublasDgemmEx.");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Currently there are not cublasDgemmEx."));
   }

   template <typename... ARGS>
@@ -216,7 +220,8 @@ struct CUBlas<platform::float16> {
         reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
         ldc, strideC, batchCount));
 #else
-    PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "HgemmStridedBatched is not supported on cuda <= 7.5"));
 #endif
   }
@@ -247,7 +252,8 @@ struct CUBlas<platform::float16> {
                                   beta, C, Ctype, ldc, computeType, algo));
     });
 #else
-    PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "cublasGemmEx is not supported on cuda <= 7.5"));
 #endif
   }
 };
@@ -302,8 +308,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

   // TODO(kexinzhao): add processing code for compute capability < 53 case
-  PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
-                    "cublas fp16 gemm requires GPU compute capability >= 53");
+  PADDLE_ENFORCE_GE(
+      context_.GetComputeCapability(), 53,
+      platform::errors::InvalidArgument(
+          "cublas fp16 gemm requires GPU compute capability >= 53,"
+          "but received %d",
+          context_.GetComputeCapability()));

   float h_alpha = static_cast<float>(alpha);
   float h_beta = static_cast<float>(beta);
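
Every cuBLAS wrapper in this file keeps one shape: the real call compiles only under a sufficiently new CUDA_VERSION, and the fallback branch now throws a typed Unimplemented error instead of a bare string. A compile-anywhere sketch of that guard (CUDA_VERSION is stubbed here so the example builds without the CUDA toolkit, and Unimplemented is a stand-in type):

    #include <cstdio>
    #include <stdexcept>

    #ifndef CUDA_VERSION
    #define CUDA_VERSION 7050  // pretend we are on CUDA 7.5 for this sketch
    #endif

    struct Unimplemented : std::runtime_error {
      using std::runtime_error::runtime_error;
    };

    void SgemmStridedBatched() {
    #if CUDA_VERSION >= 8000
      // On a real CUDA >= 8.0 build this would forward to
      // cublasSgemmStridedBatched(...).
      std::puts("dispatching to cublasSgemmStridedBatched");
    #else
      // The commit's pattern: unsupported toolkit versions fail with a
      // typed error rather than a bare string.
      throw Unimplemented("SgemmStridedBatched is not supported on cuda <= 7.5");
    #endif
    }

    int main() {
      try {
        SgemmStridedBatched();
      } catch (const Unimplemented& e) {
        std::printf("caught Unimplemented: %s\n", e.what());
      }
      return 0;
    }
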
paddle/fluid/operators/math/blas_impl.h @ b6a4349d

@@ -29,7 +29,8 @@ template <>
 struct CBlas<int8_t> {
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
-    PADDLE_THROW("Blas VCOPY don't support int8_t");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Blas VCOPY do not supported on CPU, please check your code"));
   }
 };
@@ -347,22 +348,47 @@ struct CBlas<double> {
 template <>
 struct CBlas<platform::float16> {
-  static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
+  static void GEMM(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 GEMM not supported on CPU, please check your code"));
+  }
   static void SMM_GEMM(...) {
-    PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 SMM_GEMM not supported on CPU, please check your code"));
   }
-  static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
-  static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
-  static void VSQUARE(...) {
-    PADDLE_THROW("float16 VSQUARE not supported on CPU");
+  static void VMUL(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VMUL not supported on CPU, please check your code"));
+  }
+  static void VEXP(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VEXP not supported on CPU, please check your code"));
+  }
+  static void VSQUARE(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VSQUARE not supported on CPU, please check your code"));
   }
-  static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
-  static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
-  static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
-  static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
+  static void VPOW(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VPOW not supported on CPU, please check your code"));
+  }
+  static void DOT(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 DOT not supported on CPU, please check your code"));
+  };
+  static void SCAL(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 SCAL not supported on CPU, please check your code"));
+  };
+  static void ASUM(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 ASUM not supported on CPU, please check your code"));
+  };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
-    PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 GEMM_BATCH not supported on CPU, please check your code"));
   }
 #endif
 };
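
CBlas<platform::float16> is deliberately a specialization whose every member throws: there is no CPU float16 BLAS backend, so misuse surfaces as a loud runtime error rather than a confusing template-instantiation failure. A minimal sketch of that design (all names below are stand-ins for illustration):

    #include <cstdio>
    #include <stdexcept>

    struct Unimplemented : std::runtime_error {
      using std::runtime_error::runtime_error;
    };

    struct float16 { unsigned short bits; };  // stand-in for platform::float16

    template <typename T>
    struct CBlas;  // primary template: only specializations are defined

    template <>
    struct CBlas<float> {
      static void GEMM() { std::puts("dispatch to sgemm"); }
    };

    // The commit's pattern: the float16 specialization exists, but every
    // member throws a typed Unimplemented error at run time.
    template <>
    struct CBlas<float16> {
      static void GEMM() {
        throw Unimplemented(
            "float16 GEMM not supported on CPU, please check your code");
      }
    };

    int main() {
      CBlas<float>::GEMM();  // fine
      try {
        CBlas<float16>::GEMM();  // compiles, but fails loudly when called
      } catch (const Unimplemented& e) {
        std::printf("caught: %s\n", e.what());
      }
      return 0;
    }
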
@@ -446,11 +472,18 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
   auto dim_a = mat_a.dims();
   auto dim_b = mat_b.dims();
   auto dim_out = mat_out->dims();
-  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
-                 "The input and output of matmul be matrix");
-  PADDLE_ENFORCE(mat_a.place() == mat_b.place() &&
-                     mat_a.place() == mat_out->place(),
-                 "The places of matrices must be same");
+  PADDLE_ENFORCE_EQ(
+      dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true,
+      platform::errors::InvalidArgument(
+          "The input and output of matmul should be matrix, the dim size must "
+          "be 2,"
+          "but received dim size input_a:%d, input_b:%d, output:%d",
+          dim_a.size(), dim_b.size(), dim_out.size()));
+  PADDLE_ENFORCE_EQ(
+      mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true,
+      platform::errors::InvalidArgument("The places of matrices in the matmul "
+                                        "should be same, please check your "
+                                        "code."));

   int M = dim_out[0];
   int N = dim_out[1];
@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
...
@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
W1
,
H2
);
PADDLE_ENFORCE_EQ
(
W1
,
H2
,
platform
::
errors
::
InvalidArgument
(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d"
,
W1
,
H2
));
int
ldc
=
W2
*
head_number
;
int
ldc
=
W2
*
head_number
;
int
sub_width
=
W1
/
head_number
;
int
sub_width
=
W1
/
head_number
;
@@ -785,7 +824,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                  const framework::Tensor &mat_b,
                                  const MatDescriptor &dim_b, T alpha,
                                  framework::Tensor *mat_out, T beta) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+  PADDLE_ENFORCE_EQ(
+      dim_a.width_, dim_b.height_,
+      platform::errors::InvalidArgument(
+          "The fisrt matrix width should be same as second matrix height,"
+          "but received fisrt matrix width %d"
+          ", second matrix height %d",
+          dim_a.width_, dim_b.height_));
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
@@ -793,12 +839,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                    dim_a.width_, alpha, mat_a.data<T>(),
                    mat_b.data<T>(), beta, mat_out->data<T>());
   } else {
-    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0,
-                   "dim_a.batch_size should be equal to dim_b.batch_size, or "
-                   "one of dim_a.batch_size and dim_b.batch_size should be 0. "
-                   "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
-                   dim_a.batch_size_, dim_b.batch_size_);
+    PADDLE_ENFORCE_EQ(
+        dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
+            dim_b.batch_size_ == 0,
+        true, platform::errors::InvalidArgument(
+                  "dim_a.batch_size should be equal to dim_b.batch_size, or "
+                  "one of dim_a.batch_size and dim_b.batch_size should be 0. "
+                  "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
+                  dim_a.batch_size_, dim_b.batch_size_));
     this->template BatchedGEMM<T>(
         transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
         mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
@@ -834,15 +882,42 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
                                          int head_number,
                                          framework::Tensor *mat_out, T beta,
                                          bool mat_b_split_vertical) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
-  PADDLE_ENFORCE_GE(head_number, 1);
-  PADDLE_ENFORCE_LE(head_number, dim_a.width_);
+  PADDLE_ENFORCE_EQ(
+      dim_a.width_ % head_number, 0,
+      platform::errors::InvalidArgument(
+          "The first input width must be some times the head number"
+          "but received first input width %d"
+          ", head_number %d",
+          dim_a.width_, head_number));
+  PADDLE_ENFORCE_GE(head_number, 1,
+                    platform::errors::InvalidArgument(
+                        "The head number should be greater equal 1,"
+                        "but received head number %d",
+                        head_number));
+  PADDLE_ENFORCE_LE(
+      head_number, dim_a.width_,
+      platform::errors::InvalidArgument(
+          "The head number should be less equal first input width,"
+          "but received first input width %d"
+          ", head_number %d",
+          dim_a.width_, head_number));
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;

   if (mat_b_split_vertical) {
-    PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number);
-    PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0);
+    PADDLE_ENFORCE_EQ(
+        dim_b.height_, dim_a.width_ / head_number,
+        platform::errors::InvalidArgument(
+            "The second input height should be equal than first input width,"
+            "but received second input height %d, first input width %d",
+            dim_b.height_, dim_a.width_ / head_number));
+    PADDLE_ENFORCE_EQ(
+        dim_a.width_ % head_number, 0,
+        platform::errors::InvalidArgument(
+            "The second input width should be some times the head number"
+            "but received second input width %d"
+            ", head_number %d",
+            dim_b.width_, head_number));
   }

   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
@@ -888,9 +963,16 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
                mat_out->data<T>() + sub_matC_offset, ldc);
       }
     }
   } else {
-    PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ ||
-                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0),
-                      true);
+    PADDLE_ENFORCE_EQ(
+        (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
+         dim_b.batch_size_ == 0),
+        true,
+        platform::errors::InvalidArgument(
+            "The first input batch size should be equal than second input,"
+            "either two input batch size is 0, but received first input batch "
+            "size"
+            " %d, second input batch size %d",
+            dim_a.batch_size_, dim_b.batch_size_));
     this->template BatchedGEMMWithHead<T>(
         transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,
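
Several compound checks above migrate from PADDLE_ENFORCE(cond, msg) to PADDLE_ENFORCE_EQ(cond, true, platform::errors::InvalidArgument(...)). A small sketch of why the binary form helps diagnostics (ENFORCE_EQ below is an illustrative macro, not Paddle's, which takes a typed error as its third argument):

    #include <cstdio>
    #include <cstdlib>

    // Illustrative macro: on failure it can report both operands, which a
    // plain ENFORCE(cond) cannot do.
    #define ENFORCE_EQ(a, b, msg)                                           \
      do {                                                                  \
        if (!((a) == (b))) {                                                \
          std::fprintf(stderr, "EnforceNotMet: %s (left=%ld, right=%ld)\n", \
                       msg, static_cast<long>(a), static_cast<long>(b));    \
          std::abort();                                                     \
        }                                                                   \
      } while (0)

    int main() {
      long batch_a = 4, batch_b = 4;
      bool compatible = (batch_a == batch_b) || batch_a == 0 || batch_b == 0;
      // Boolean predicates are funneled through the EQ form by comparing
      // against true, as the commit does for the batch-size checks.
      ENFORCE_EQ(compatible, true,
                 "dim_a.batch_size should equal dim_b.batch_size, or one of "
                 "them should be 0");
      std::puts("batch sizes compatible");
      return 0;
    }
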