Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f873d3a1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f873d3a1
编写于
11月 23, 2021
作者:
L
lilong12
提交者:
GitHub
11月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix shard_index (#37042) (#37421)
上级
4dc426f4
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
21 addition
and
12 deletion
+21
-12
paddle/fluid/operators/shard_index_op.cc
paddle/fluid/operators/shard_index_op.cc
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+20
-11
未找到文件。
paddle/fluid/operators/shard_index_op.cc
浏览文件 @
f873d3a1
...
@@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
...
@@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
"but the value given is %d."
,
"but the value given is %d."
,
x_dims
.
size
()));
x_dims
.
size
()));
if
(
ctx
->
IsRuntime
()
||
x_dims
[
x_dims
.
size
()
-
1
]
>
0
)
{
if
(
ctx
->
IsRuntime
()
||
x_dims
[
x_dims
.
size
()
-
1
]
>
0
)
{
PADDLE_ENFORCE_
GE
(
x_dims
[
x_dims
.
size
()
-
1
],
1U
,
PADDLE_ENFORCE_
EQ
(
x_dims
[
x_dims
.
size
()
-
1
],
1U
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The last dimension of Input(X) should be 1, "
"The last dimension of Input(X) should be 1, "
"but the value given is %d."
,
"but the value given is %d."
,
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
f873d3a1
...
@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input,
...
@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input,
@deprecated(since="2.0.0", update_to="paddle.shard_index")
@deprecated(since="2.0.0", update_to="paddle.shard_index")
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
"""
Recompute the `input` indices according to the offset of the
Reset the values of `input` according to the shard it beloning to.
shard. The length of the indices is evenly divided into N shards, and if
Every value in `input` must be a non-negative integer, and
the `shard_id` matches the shard with the input index inside, the index is
the parameter `index_num` represents the integer above the maximum
recomputed on the basis of the shard offset, elsewise it is set to
value of `input`. Thus, all values in `input` must be in the range
`ignore_value`. The detail is as follows:
[0, index_num) and each value can be regarded as the offset to the beginning
of the range. The range is further split into multiple shards. Specifically,
we first compute the `shard_size` according to the following formula,
which represents the number of integers each shard can hold. So for the
i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
::
::
shard_size = (index_num + nshards - 1) // nshards
shard_size = (index_num + nshards - 1) // nshards
y = x % shard_size if x // shard_size == shard_id else ignore_value
NOTE: If the length of indices cannot be evely divided by the shard number,
For each value `v` in `input`, we reset it to a new value according to the
the size of the last shard will be less than the calculated `shard_size`
following formula:
::
v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value
That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
if it in the range. Otherwise, we reset it to be `ignore_value`.
Args:
Args:
input (Tensor): Input
indices
with data type int64 or int32. It's last dimension must be 1.
input (Tensor): Input
tensor
with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer
defining the range of the index
.
index_num (int): An integer
represents the integer above the maximum value of `input`
.
nshards (int): The number of shards.
nshards (int): The number of shards.
shard_id (int): The index of the current shard.
shard_id (int): The index of the current shard.
ignore_value (int): An integer value out of sharded index range.
ignore_value (int): An integer value out of sharded index range.
Returns:
Returns:
Tensor
: The sharded index of input
.
Tensor.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录