Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9d439ae6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9d439ae6
编写于
5月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(misc): correct docstring format broadly
GitOrigin-RevId: 45234ca07ed66408d77741de307cc44c59b9f3da
上级
651c4e9a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
63 addition
and
48 deletion
+63
-48
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+2
-1
imperative/python/megengine/data/collator.py
imperative/python/megengine/data/collator.py
+1
-1
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+23
-32
imperative/python/megengine/data/sampler.py
imperative/python/megengine/data/sampler.py
+32
-9
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+1
-1
imperative/python/megengine/module/conv.py
imperative/python/megengine/module/conv.py
+1
-1
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+2
-2
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+1
-1
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
9d439ae6
...
...
@@ -27,7 +27,8 @@ from .core import TensorBase
def
set_priority_to_id
(
dest_vars
):
"""
For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec
=
[]
...
...
imperative/python/megengine/data/collator.py
浏览文件 @
9d439ae6
...
...
@@ -40,7 +40,7 @@ class Collator:
def
apply
(
self
,
inputs
):
"""
:param input: sequence_N(tuple(CHW, C, CK)).
:param input
s
: sequence_N(tuple(CHW, C, CK)).
:return: tuple(NCHW, NC, NCK).
"""
elem
=
inputs
[
0
]
...
...
imperative/python/megengine/data/dataloader.py
浏览文件 @
9d439ae6
...
...
@@ -43,8 +43,29 @@ def raise_timeout_error():
class
DataLoader
:
r
"""
Provides a convenient way to iterate on a given dataset.
r
"""Provides a convenient way to iterate on a given dataset.
DataLoader combines a dataset with
:class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
make it flexible to get minibatch continually from a dataset.
:param dataset: dataset from which to load the minibatch.
:param sampler: defines the strategy to sample data from the dataset.
:param transform: defined the transforming strategy for a sampled batch.
Default: None
:param collator: defined the merging strategy for a transformed batch.
Default: None
:param num_workers: the number of sub-process to load, transform and collate
the batch. ``0`` means using single-process. Default: 0
:param timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0
:param timeout_event: callback function triggered by timeout, default to raise
runtime error.
:param divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces parallelly. ``False`` means
different sub-process will process different batch. Default: False
"""
__initialized
=
False
...
...
@@ -59,36 +80,6 @@ class DataLoader:
timeout_event
:
Callable
=
raise_timeout_error
,
divide
:
bool
=
False
,
):
r
"""
`DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
make it flexible to get minibatch continually from a dataset.
:type dataset: Dataset
:param dataset: dataset from which to load the minibatch.
:type sampler: Sampler
:param sampler: defines the strategy to sample data from the dataset.
:type transform: Transform
:param transform: defined the transforming strategy for a sampled batch.
Default: None
:type collator: Collator
:param collator: defined the merging strategy for a transformed batch.
Default: None
:type num_workers: int
:param num_workers: the number of sub-process to load, transform and collate
the batch. ``0`` means using single-process. Default: 0
:type timeout: int
:param timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0
:type timeout_event: Callable
:param timeout_event: callback function triggered by timeout, default to raise
runtime error.
:type divide: bool
:param divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces parallelly. ``False`` means
different sub-process will process different batch. Default: False
"""
if
num_workers
<
0
:
raise
ValueError
(
"num_workers should not be negative"
)
...
...
imperative/python/megengine/data/sampler.py
浏览文件 @
9d439ae6
...
...
@@ -30,22 +30,15 @@ class MapSampler(Sampler):
r
"""
Sampler for map dataset.
:type dataset: `dataset`
:param dataset: dataset to sample from.
:type batch_size: positive integer
:param batch_size: batch size for batch method.
:type drop_last: bool
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:type num_samples: positive integer
:param num_samples: number of samples assigned to one rank.
:type world_size: positive integer
:param world_size: number of ranks.
:type rank: non-negative integer within 0 and world_size
:param rank: rank id, non-negative interger within 0 and ``world_size``.
:type seed: non-negative integer
:param seed: seed for random operators.
"""
...
...
@@ -166,7 +159,7 @@ class StreamSampler(Sampler):
different data. But this class cannot do it yet, please build your own
dataset and sampler to achieve this goal.
Usually,
meth:
:`~.StreamDataset.__iter__` can return different iterator by
Usually,
:meth
:`~.StreamDataset.__iter__` can return different iterator by
``rank = dist.get_rank()``. So that they will get different data.
"""
...
...
@@ -184,6 +177,16 @@ class StreamSampler(Sampler):
class
SequentialSampler
(
MapSampler
):
r
"""
Sample elements sequentially.
:param dataset: dataset to sample from.
:param batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:param indices: indice of samples.
:param world_size: number of ranks.
:param rank: rank id, non-negative interger within 0 and ``world_size``.
"""
def
__init__
(
...
...
@@ -216,6 +219,17 @@ class SequentialSampler(MapSampler):
class
RandomSampler
(
MapSampler
):
r
"""
Sample elements randomly without replacement.
:param dataset: dataset to sample from.
:param batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:param indices: indice of samples.
:param world_size: number of ranks.
:param rank: rank id, non-negative interger within 0 and ``world_size``.
:param seed: seed for random operators.
"""
def
__init__
(
...
...
@@ -247,8 +261,17 @@ class ReplacementSampler(MapSampler):
r
"""
Sample elements randomly with replacement.
:type weights: List
:param dataset: dataset to sample from.
:param batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:param num_samples: number of samples assigned to one rank.
:param weights: weights for sampling indices, it could be unnormalized weights.
:param world_size: number of ranks.
:param rank: rank id, non-negative interger within 0 and ``world_size``.
:param seed: seed for random operators.
"""
def
__init__
(
...
...
imperative/python/megengine/functional/vision.py
浏览文件 @
9d439ae6
...
...
@@ -224,7 +224,7 @@ def nms(
:param scores: tensor of shape `(N,)`, the score of boxes.
:param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS.
:return: indices of the elements that have been kept by NMS
, sorted by scores
.
Examples:
...
...
imperative/python/megengine/module/conv.py
浏览文件 @
9d439ae6
...
...
@@ -409,7 +409,7 @@ class Conv3d(_ConvNd):
For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
this layer generates an output of the size
:math:`(N, C_{\text{out}}, T_{\text{out}}
}, H_{\text{out}}}, W_{\text{out}
}})` through the
:math:`(N, C_{\text{out}}, T_{\text{out}}
, H_{\text{out}}, W_{\text{out
}})` through the
process described as below:
.. math::
...
...
imperative/python/megengine/module/module.py
浏览文件 @
9d439ae6
...
...
@@ -91,7 +91,7 @@ class Module(metaclass=ABCMeta):
def
__init__
(
self
,
name
=
None
):
"""
:param name: module's name, can be initialized by the ``kwargs`` parameter
of child class.
of child class.
"""
self
.
_modules
=
[]
...
...
@@ -122,7 +122,7 @@ class Module(metaclass=ABCMeta):
Registers a hook to handle forward inputs. `hook` should be a function.
:param hook: a function that receive `module` and `inputs`, then return
a modified `inputs` or `None`.
a modified `inputs` or `None`.
:return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
"""
return
HookHandler
(
self
.
_forward_pre_hooks
,
hook
)
...
...
imperative/python/megengine/tensor.py
浏览文件 @
9d439ae6
...
...
@@ -174,7 +174,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
def
set_value
(
self
,
value
):
self
.
_reset
(
value
)
@
deprecated
(
version
=
"1.0"
,
reason
=
"use
*= 0
instead"
)
@
deprecated
(
version
=
"1.0"
,
reason
=
"use
``*= 0``
instead"
)
def
reset_zero
(
self
):
self
*=
0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录