Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6975542a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6975542a
编写于
4月 26, 2023
作者:
S
sneaxiy
提交者:
GitHub
4月 26, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick] Optimize c_embedding op in deterministic mode (#53203)
上级
4236351c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
456 addition
and
44 deletion
+456
-44
paddle/fluid/operators/collective/c_embedding_op.cu
paddle/fluid/operators/collective/c_embedding_op.cu
+59
-28
paddle/phi/core/flags.cc
paddle/phi/core/flags.cc
+7
-5
paddle/phi/kernels/funcs/embedding_grad.h
paddle/phi/kernels/funcs/embedding_grad.h
+167
-0
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+10
-11
python/paddle/fluid/tests/unittests/test_embedding_deterministic.py
...dle/fluid/tests/unittests/test_embedding_deterministic.py
+213
-0
未找到文件。
paddle/fluid/operators/collective/c_embedding_op.cu
浏览文件 @
6975542a
...
...
@@ -18,8 +18,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
DECLARE_
bool
(
cudnn
_deterministic
);
DECLARE_
int64
(
embedding
_deterministic
);
namespace
paddle
{
namespace
operators
{
...
...
@@ -154,7 +155,6 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
ids_t
->
numel
();
const
int64_t
end_idx
=
start_idx
+
N
;
auto
limit
=
K
*
D
;
int
blocks
=
NumBlocks
(
limit
);
int
threads
=
kNumCUDAThreads
;
...
...
@@ -166,33 +166,64 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
const
auto
&
index_type
=
framework
::
TransToProtoVarType
(
ids_t
->
dtype
());
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
blocks
=
1
;
}
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CEmbeddingGrad
<
T
,
int32_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int32_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
CEmbeddingGrad
<
T
,
int64_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int64_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
if
(
FLAGS_embedding_deterministic
==
1
)
{
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
phi
::
funcs
::
LaunchEmbeddingGradDeterministicKernel
<
T
,
int32_t
>
(
dev_ctx
,
ids_t
->
data
<
int32_t
>
(),
d_output
,
d_table
,
N
,
D
,
K
,
start_idx
);
return
;
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
phi
::
funcs
::
LaunchEmbeddingGradDeterministicKernel
<
T
,
int64_t
>
(
dev_ctx
,
ids_t
->
data
<
int64_t
>
(),
d_output
,
d_table
,
N
,
D
,
K
,
start_idx
);
return
;
}
}
else
{
if
(
FLAGS_embedding_deterministic
>
1
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
blocks
=
1
;
}
const
int64_t
end_idx
=
start_idx
+
N
;
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
CEmbeddingGrad
<
T
,
int32_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int32_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
return
;
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
CEmbeddingGrad
<
T
,
int64_t
>
<<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids_t
->
data
<
int64_t
>
(),
K
,
D
,
N
,
start_idx
,
end_idx
,
limit
);
return
;
}
}
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"The data type of Input(Ids) must be int32 or int64."
));
}
};
...
...
paddle/phi/core/flags.cc
浏览文件 @
6975542a
...
...
@@ -236,17 +236,19 @@ PADDLE_DEFINE_EXPORTED_bool(
* CUDA related FLAG
* Name: FLAGS_embedding_deterministic
* Since Version: 2.5
* Value Range:
bool, default=false
* Value Range:
int64, default=0
* Example:
* Note: whether to use deterministic algorithm in embedding op.
* If true, it will use deterministic CUDA kernel in embedding op.
* If it is 1, it will use the optimized deterministic CUDA kernel in
* embedding op. If it is 2, it will use the legacy deterministic
* CUDA kernel in embedding op.
*/
PADDLE_DEFINE_EXPORTED_
bool
(
PADDLE_DEFINE_EXPORTED_
int64
(
embedding_deterministic
,
false
,
0
,
"Whether allow using an deterministic algorithm for embedding "
"operator. The deterministic algorithm may be slower. If "
"
true
, the algorithm is deterministic."
);
"
it is larger than 0
, the algorithm is deterministic."
);
/**
* CUDNN related FLAG
...
...
paddle/phi/kernels/funcs/embedding_grad.h
0 → 100644
浏览文件 @
6975542a
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace
phi
{
namespace
funcs
{
template
<
typename
T
,
typename
IdT
,
int
WarpSize
,
int
BlockDimY
,
bool
UseLimit
>
__global__
void
EmbeddingGradDeterministicKernel
(
T
*
table
,
const
T
*
output
,
const
IdT
*
ids
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
start_idx
,
const
int64_t
end_idx
)
{
using
MT
=
typename
dtype
::
MPTypeTrait
<
T
>::
Type
;
constexpr
int64_t
kInvalidId
=
-
1
;
extern
__shared__
char
buf
[];
MT
*
smem
=
reinterpret_cast
<
MT
*>
(
buf
);
MT
*
my_s
=
smem
+
WarpSize
*
threadIdx
.
y
;
IdT
*
indices_batch
=
reinterpret_cast
<
IdT
*>
(
buf
+
sizeof
(
MT
)
*
WarpSize
*
BlockDimY
);
const
int
stride
=
static_cast
<
int
>
(
D
);
const
int
feature
=
threadIdx
.
x
+
blockIdx
.
x
*
WarpSize
;
// To ensure determinism. If any other warps pulled grad data targeting
// dst_row, we elect the first warp in each matching group as the leader.
// Each leader warp serializes the accumulates targeting dst_row in shared
// memory, then adding the accumulated buffer to dst_row in table.
for
(
int
batch_start
=
0
;
batch_start
<
K
;
batch_start
+=
WarpSize
*
BlockDimY
)
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
WarpSize
;
if
(
batch_start
+
tid
<
K
)
{
int64_t
cur_id
=
static_cast
<
int64_t
>
(
ids
[
batch_start
+
tid
]);
if
(
UseLimit
)
{
if
(
cur_id
>=
start_idx
&&
cur_id
<
end_idx
)
{
cur_id
-=
start_idx
;
}
else
{
cur_id
=
kInvalidId
;
}
}
indices_batch
[
tid
]
=
cur_id
;
}
int
batch_end
=
min
(
static_cast
<
int64_t
>
(
batch_start
+
WarpSize
*
BlockDimY
),
K
);
// Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
for
(
int
chunk_start
=
batch_start
;
chunk_start
<
batch_end
;
chunk_start
+=
BlockDimY
)
{
// This sync makes sure that indices_batch is ready and match-group
// leaders are done with their accumulates before other warps start
// loading again.
__syncthreads
();
int
n_this_chunk
=
min
(
batch_end
-
chunk_start
,
BlockDimY
);
int64_t
src_row
=
static_cast
<
int64_t
>
(
chunk_start
+
threadIdx
.
y
);
int64_t
dst_row
=
indices_batch
[
src_row
-
batch_start
];
if
(
src_row
<
K
&&
feature
<
stride
)
{
if
(
UseLimit
&&
dst_row
==
kInvalidId
)
{
my_s
[
threadIdx
.
x
]
=
static_cast
<
MT
>
(
0
);
}
else
{
my_s
[
threadIdx
.
x
]
=
static_cast
<
MT
>
(
output
[
src_row
*
D
+
feature
]);
}
}
__syncthreads
();
if
(
src_row
<
K
)
{
int
match_found_this_thread
=
0
;
if
(
threadIdx
.
x
<
n_this_chunk
&&
(
!
UseLimit
||
dst_row
!=
kInvalidId
))
{
match_found_this_thread
=
(
dst_row
==
indices_batch
[
chunk_start
-
batch_start
+
threadIdx
.
x
]);
}
#ifdef PADDLE_WITH_HIP
unsigned
long
long
int
matchmask
=
// NOLINT
__ballot
(
match_found_this_thread
);
// NOLINT
int
first_remaining_peer
=
__ffsll
(
matchmask
)
-
1
;
#else
// If and only if match_found_this_thread of the Nth thread is non-zero,
// set the Nth bit of matchmask to 1.
unsigned
int
matchmask
=
__ballot_sync
(
0xffffffff
,
match_found_this_thread
);
// Find the position of the first bit set to 1 in matchmask.
int
first_remaining_peer
=
__ffs
(
matchmask
)
-
1
;
#endif
// select lowest-indexed warp as the leader
if
(
threadIdx
.
y
==
first_remaining_peer
)
{
// Set the first bit 1 in matchmask to 0.
matchmask
^=
(
1
<<
first_remaining_peer
);
while
(
matchmask
)
{
#ifdef PADDLE_WITH_HIP
first_remaining_peer
=
__ffsll
(
matchmask
)
-
1
;
#else
first_remaining_peer
=
__ffs
(
matchmask
)
-
1
;
#endif
my_s
[
threadIdx
.
x
]
+=
smem
[
threadIdx
.
x
+
WarpSize
*
first_remaining_peer
];
matchmask
^=
(
1
<<
first_remaining_peer
);
}
if
(
feature
<
stride
&&
(
!
UseLimit
||
dst_row
!=
kInvalidId
))
{
auto
table_idx
=
dst_row
*
D
+
feature
;
table
[
table_idx
]
=
static_cast
<
T
>
(
static_cast
<
MT
>
(
table
[
table_idx
])
+
my_s
[
threadIdx
.
x
]);
}
}
}
}
}
}
template
<
typename
T
,
typename
IdT
>
void
LaunchEmbeddingGradDeterministicKernel
(
const
GPUContext
&
ctx
,
const
IdT
*
ids
,
const
T
*
d_out
,
T
*
d_table
,
int64_t
N
,
int64_t
D
,
int64_t
K
,
int64_t
start_idx
=
-
1
)
{
#ifdef PADDLE_WITH_HIP
constexpr
int
kWarpSize
=
64
;
constexpr
int
kBlockDimY
=
16
;
#else
constexpr
int
kWarpSize
=
32
;
constexpr
int
kBlockDimY
=
32
;
#endif
dim3
threads
(
kWarpSize
,
kBlockDimY
);
dim3
grids
(
static_cast
<
int
>
((
D
+
kWarpSize
-
1
)
/
kWarpSize
));
using
MT
=
typename
dtype
::
MPTypeTrait
<
T
>::
Type
;
constexpr
auto
kSharedMemSize
=
sizeof
(
MT
)
*
kWarpSize
*
kBlockDimY
+
sizeof
(
IdT
)
*
kWarpSize
*
kBlockDimY
;
if
(
start_idx
<
0
)
{
EmbeddingGradDeterministicKernel
<
T
,
IdT
,
kWarpSize
,
kBlockDimY
,
false
>
<<<
grids
,
threads
,
kSharedMemSize
,
ctx
.
stream
()
>>>
(
d_table
,
d_out
,
ids
,
K
,
D
,
-
1
,
-
1
);
}
else
{
int64_t
end_idx
=
start_idx
+
N
;
EmbeddingGradDeterministicKernel
<
T
,
IdT
,
kWarpSize
,
kBlockDimY
,
true
>
<<<
grids
,
threads
,
kSharedMemSize
,
ctx
.
stream
()
>>>
(
d_table
,
d_out
,
ids
,
K
,
D
,
start_idx
,
end_idx
);
}
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
6975542a
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
...
...
@@ -26,7 +27,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_
bool
(
embedding_deterministic
);
DECLARE_
int64
(
embedding_deterministic
);
namespace
phi
{
...
...
@@ -198,20 +199,18 @@ struct EmbeddingGradCUDAFunctor {
cudaMemsetAsync
(
d_table
,
0
,
N
*
D
*
sizeof
(
T
),
dev_ctx_
.
stream
()));
#endif
if
(
FLAGS_embedding_deterministic
)
{
dim3
threads
(
WARP_SIZE
,
BLOCKDIMY
);
dim3
grids
(
static_cast
<
int
>
((
D
+
WARP_SIZE
-
1
)
/
WARP_SIZE
));
using
MT
=
typename
dtype
::
MPTypeTrait
<
T
>::
Type
;
EmbeddingGradDeterministic
<
T
,
IdT
>
<<<
grids
,
threads
,
sizeof
(
MT
)
*
WARP_SIZE
*
BLOCKDIMY
+
sizeof
(
IdT
)
*
WARP_SIZE
*
BLOCKDIMY
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
K
,
D
);
if
(
FLAGS_embedding_deterministic
==
1
)
{
phi
::
funcs
::
LaunchEmbeddingGradDeterministicKernel
<
T
,
IdT
>
(
dev_ctx_
,
ids
,
d_output
,
d_table
,
N
,
D
,
K
);
}
else
{
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
dim3
threads
(
128
,
8
);
dim3
grids
(
gridx
,
1
);
if
(
FLAGS_embedding_deterministic
>
1
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
grids
.
x
=
1
;
threads
.
y
=
1
;
}
EmbeddingGrad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
...
...
python/paddle/fluid/tests/unittests/test_embedding_deterministic.py
0 → 100644
浏览文件 @
6975542a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
random
import
sys
import
unittest
import
numpy
as
np
import
paddle
from
paddle.distributed.fleet.layers.mpu.mp_ops
import
_c_lookup_table
@
contextlib
.
contextmanager
def
deterministic_guard
(
value
):
flag_name
=
'FLAGS_embedding_deterministic'
old_value
=
paddle
.
get_flags
(
flag_name
)[
flag_name
]
paddle
.
set_flags
({
flag_name
:
value
})
assert
paddle
.
get_flags
(
flag_name
)[
flag_name
]
==
value
yield
paddle
.
set_flags
({
flag_name
:
old_value
})
assert
paddle
.
get_flags
(
flag_name
)[
flag_name
]
==
old_value
def
to_numpy
(
tensor
):
if
tensor
.
dtype
in
[
paddle
.
float16
,
paddle
.
bfloat16
]:
tensor
=
tensor
.
astype
(
paddle
.
float32
)
return
tensor
.
numpy
()
def
clone_weight
(
weight
):
if
weight
.
dtype
==
paddle
.
bfloat16
:
weight
=
weight
.
astype
(
paddle
.
float32
).
numpy
()
weight
=
paddle
.
to_tensor
(
weight
,
dtype
=
paddle
.
float32
).
astype
(
paddle
.
bfloat16
)
else
:
weight
=
paddle
.
to_tensor
(
weight
.
numpy
())
weight
.
stop_gradient
=
False
return
weight
def
embedding
(
ids
,
weight
,
out_grad
,
deterministic_level
=
0
,
rank
=
None
):
weight
=
clone_weight
(
weight
)
with
deterministic_guard
(
deterministic_level
):
if
rank
is
not
None
:
vocab_size
,
_
=
weight
.
shape
start_idx
=
vocab_size
*
rank
out
=
_c_lookup_table
(
weight
,
ids
,
start_index
=
start_idx
)
else
:
out
=
paddle
.
nn
.
functional
.
embedding
(
ids
,
weight
)
out
.
backward
(
out_grad
.
clone
())
return
to_numpy
(
out
),
to_numpy
(
weight
.
grad
)
def
embedding_ground_truth
(
ids
,
weight
,
out_grad
,
rank
=
None
):
weight
=
clone_weight
(
weight
.
astype
(
paddle
.
float32
))
out_grad
=
out_grad
.
astype
(
paddle
.
float32
)
return
embedding
(
ids
,
weight
,
out_grad
,
deterministic_level
=
2
,
rank
=
rank
)
def
generate_input_data
(
ids_shape
,
vocab_size
,
hidden_size
,
weight_dtype
,
ids_dtype
,
allow_duplicate_id
=
True
,
rank
=
None
,
nranks
=
None
,
allow_pure_random
=
False
,
):
max_id
=
vocab_size
if
rank
is
None
else
vocab_size
*
nranks
if
allow_duplicate_id
:
ids
=
np
.
random
.
randint
(
low
=
0
,
high
=
max_id
,
size
=
ids_shape
)
else
:
sequence
=
list
(
range
(
max_id
))
numel
=
int
(
np
.
prod
(
ids_shape
))
if
len
(
sequence
)
<
numel
:
return
None
,
None
,
None
ids
=
np
.
array
(
random
.
sample
(
sequence
,
numel
)).
reshape
(
ids_shape
)
ids
=
paddle
.
to_tensor
(
ids
).
astype
(
ids_dtype
)
ids
.
stop_gradient
=
True
weight
=
paddle
.
randn
([
vocab_size
,
hidden_size
]).
astype
(
weight_dtype
)
weight
.
stop_gradient
=
False
out_grad_shape
=
list
(
ids_shape
)
+
[
hidden_size
]
if
allow_duplicate_id
and
not
allow_pure_random
:
out_grad
=
paddle
.
randint
(
low
=-
10
,
high
=
10
,
shape
=
out_grad_shape
)
else
:
out_grad
=
paddle
.
randn
(
out_grad_shape
)
out_grad
=
out_grad
.
astype
(
weight
.
dtype
)
return
ids
,
weight
,
out_grad
def
get_all_dtypes
():
if
not
paddle
.
is_compiled_with_cuda
()
or
paddle
.
is_compiled_with_rocm
():
return
[]
dtypes
=
[
paddle
.
float32
,
paddle
.
float16
]
if
'A100'
in
paddle
.
device
.
cuda
.
get_device_properties
().
name
:
dtypes
.
append
(
paddle
.
bfloat16
)
return
dtypes
class
TestEmbeddingBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
ids_shape
=
[
32
,
3
]
self
.
vocab_size
=
128
self
.
hidden_size
=
1024
self
.
nranks
=
8
def
check_main
(
self
,
weight_dtype
,
ids_dtype
,
deterministic_level
=
0
,
rank
=
None
,
allow_duplicate_id
=
True
,
allow_pure_random
=
False
,
):
if
sys
.
platform
==
'win32'
and
rank
is
not
None
:
return
ids
,
weight
,
out_grad
=
generate_input_data
(
ids_shape
=
self
.
ids_shape
,
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
weight_dtype
=
weight_dtype
,
ids_dtype
=
ids_dtype
,
allow_duplicate_id
=
allow_duplicate_id
,
rank
=
rank
,
nranks
=
self
.
nranks
,
allow_pure_random
=
allow_pure_random
,
)
if
ids
is
None
:
return
if
allow_pure_random
:
out_1
,
weight_grad_1
=
embedding_ground_truth
(
ids
,
weight
,
out_grad
,
rank
)
out_2
,
weight_grad_2
=
embedding_ground_truth
(
ids
,
weight
,
out_grad
,
rank
)
else
:
out_1
,
weight_grad_1
=
embedding_ground_truth
(
ids
,
weight
,
out_grad
,
rank
)
out_2
,
weight_grad_2
=
embedding
(
ids
,
weight
,
out_grad
,
deterministic_level
=
deterministic_level
,
rank
=
rank
,
)
np
.
testing
.
assert_equal
(
out_1
,
out_2
)
np
.
testing
.
assert_equal
(
weight_grad_1
,
weight_grad_2
)
def
test_main
(
self
):
weight_dtypes
=
get_all_dtypes
()
ids_dtypes
=
[
paddle
.
int64
,
paddle
.
int32
]
deterministic_levels
=
[
0
,
1
]
ranks
=
[
None
,
0
,
2
,
4
,
8
]
allow_duplicate_ids
=
[
False
,
True
]
allow_pure_randoms
=
[
False
,
True
]
for
weight_dtype
in
weight_dtypes
:
for
ids_dtype
in
ids_dtypes
:
for
deterministic_level
in
deterministic_levels
:
for
rank
in
ranks
:
for
allow_duplicate_id
in
allow_duplicate_ids
:
for
allow_pure_random
in
allow_pure_randoms
:
self
.
check_main
(
weight_dtype
,
ids_dtype
,
deterministic_level
,
rank
,
allow_duplicate_id
,
allow_pure_random
,
)
class
TestEmbedding2
(
TestEmbeddingBase
):
def
setUp
(
self
):
self
.
ids_shape
=
[
32
,
16
]
self
.
vocab_size
=
128
self
.
hidden_size
=
1024
self
.
nranks
=
8
class
TestEmbeddingDeterministic
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
ids_shape
=
[
32
,
16
]
self
.
vocab_size
=
128
self
.
hidden_size
=
1024
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录