未验证 提交 e4dd68e8 编写于 作者: Y Yi Liu 提交者: GitHub

update shard_index_cn.rst (#1300)

上级 f31732a6
=====
API
=====
.. toctree::
:maxdepth: 1
../api_guides/index_cn.rst
api_tree_cn.rst
...@@ -5,50 +5,47 @@ shard_index ...@@ -5,50 +5,47 @@ shard_index
.. py:function:: paddle.fluid.layers.shard_index(input, index_num, nshards, shard_id, ignore_value=-1) .. py:function:: paddle.fluid.layers.shard_index(input, index_num, nshards, shard_id, ignore_value=-1)
该层为输入创建碎片化索引,通常在模型和数据并行混合训练时使用,索引数据(通常是标签)应该在每一个trainer里面被计算,通过 该函数对输入的索引根据分片(shard)的偏移量重新计算。
索引长度被均分为N个分片,如果输入索引所在的分片跟分片ID对应,则该索引以分片的偏移量为界重新计算,否则更新为默认值(ignore_value)。具体计算为:
:: ::
assert index_num % nshards == 0 每个分片的长度为
shard_size = (index_num + nshards - 1) // nshards
shard_size = index_num / nshards 如果 shard_id == input // shard_size
则 output = input % shard_size
否则 output = ignore_value
如果 x / shard_size == shard_id 注意:若索引长度不能被分片数整除,则最后一个分片长度不足shard_size。
y = x % shard_size 示例:
::
否则
y = ignore_value
我们使用分布式 ``one-hot`` 表示来展示该层如何使用, 分布式的 ``one-hot`` 表示被分割为多个碎片, 碎片索引里不为1的都使用0来填充。为了在每一个trainer里面创建碎片化的表示,原始的索引应该先进行计算(i.e. sharded)。我们来看个例子:
.. code-block:: text
X 是一个整形张量
X.shape = [4, 1]
X.data = [[1], [6], [12], [19]]
假设 index_num = 20 并且 nshards = 2, 我们可以得到 shard_size = 10 输入:
input.shape = [4, 1]
input.data = [[1], [6], [12], [19]]
index_num = 20
nshards = 2
ignore_value=-1
如果 shard_id == 0, 我们得到输出: 如果 shard_id == 0, 输出:
Out.shape = [4, 1] output.shape = [4, 1]
Out.data = [[1], [6], [-1], [-1]] output.data = [[1], [6], [-1], [-1]]
如果 shard_id == 1, 我们得到输出:
Out.shape = [4, 1]
Out.data = [[-1], [-1], [2], [9]]
上面的例子中默认 ignore_value = -1 如果 shard_id == 1, 输出:
output.shape = [4, 1]
output.data = [[-1], [-1], [2], [9]]
参数: 参数:
- **input** (Variable)- 输入的索引,最后的维度应该为1 - **input** (Variable)- 输入的索引
- **index_num** (scalar) - 定义索引长度的整形参数 - **index_num** (scalar) - 索引长度
- **nshards** (scalar) - shards数量 - **nshards** (scalar) - 分片数量
- **shard_id** (scalar) - 当前碎片的索引 - **shard_id** (scalar) - 当前分片ID
- **ignore_value** (scalar) - 超出碎片索引范围的整型 - **ignore_value** (scalar) - 超出分片索引范围的默认
返回: 输入的碎片索引 返回:更新后的索引值
返回类型: Variable 返回类型:Variable
**代码示例:** **代码示例:**
...@@ -60,8 +57,3 @@ shard_index ...@@ -60,8 +57,3 @@ shard_index
index_num=20, index_num=20,
nshards=2, nshards=2,
shard_id=0) shard_id=0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册