Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4a863160
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4a863160
编写于
12月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/data): fix weighted random sampler
GitOrigin-RevId: d09cbbfffd2434dea3f56b3db2f642b5c236141a
上级
68c5e766
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
32 addition
and
6 deletion
+32
-6
imperative/python/megengine/data/sampler.py
imperative/python/megengine/data/sampler.py
+4
-4
imperative/python/test/unit/data/test_sampler.py
imperative/python/test/unit/data/test_sampler.py
+28
-2
未找到文件。
imperative/python/megengine/data/sampler.py
浏览文件 @
4a863160
...
...
@@ -297,10 +297,10 @@ class ReplacementSampler(MapSampler):
def
sample
(
self
)
->
List
:
n
=
len
(
self
.
dataset
)
i
f
self
.
weights
is
None
:
return
self
.
rng
.
randint
(
n
,
size
=
self
.
num_samples
).
tolist
()
else
:
return
self
.
rng
.
multinomial
(
n
,
self
.
weights
,
self
.
num_samples
)
.
tolist
()
i
ndices
=
self
.
rng
.
choice
(
n
,
size
=
self
.
num_samples
,
replace
=
True
,
p
=
self
.
weights
)
return
indices
.
tolist
()
class
Infinite
(
MapSampler
):
...
...
imperative/python/test/unit/data/test_sampler.py
浏览文件 @
4a863160
...
...
@@ -58,8 +58,34 @@ def test_random_sampler_seed():
def
test_ReplacementSampler
():
num_samples
=
30
indices
=
list
(
range
(
20
))
weights
=
list
(
range
(
20
))
num_data
=
20
indices
=
list
(
range
(
num_data
))
sampler
=
ReplacementSampler
(
ArrayDataset
(
indices
),
num_samples
=
num_samples
,
weights
=
None
)
assert
len
(
list
(
each
[
0
]
for
each
in
sampler
))
==
num_samples
num_data
=
8
weights
=
list
(
range
(
num_data
))
indices
=
list
(
range
(
num_data
))
sampler
=
ReplacementSampler
(
ArrayDataset
(
indices
),
num_samples
=
num_samples
,
weights
=
weights
)
assert
len
(
list
(
each
[
0
]
for
each
in
sampler
))
==
num_samples
iter
=
1000
hist
=
[
0
for
_
in
range
(
num_data
)]
for
_
in
range
(
iter
):
for
index
in
sampler
:
index
=
index
[
0
]
hist
[
index
]
+=
1
actual_weights
=
np
.
array
(
hist
)
/
sum
(
hist
)
desired_weights
=
np
.
array
(
weights
)
/
sum
(
weights
)
np
.
testing
.
assert_allclose
(
actual_weights
,
desired_weights
,
rtol
=
8e-2
)
num_data
=
50000
num_samples
=
50000
*
30
weights
=
list
(
range
(
num_data
))
indices
=
list
(
range
(
num_data
))
sampler
=
ReplacementSampler
(
ArrayDataset
(
indices
),
num_samples
=
num_samples
,
weights
=
weights
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录