未验证 提交 f873d3a1 编写于 作者: L lilong12 提交者: GitHub

bug fix shard_index (#37042) (#37421)

上级 4dc426f4
...@@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel { ...@@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
"but the value given is %d.", "but the value given is %d.",
x_dims.size())); x_dims.size()));
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, " "The last dimension of Input(X) should be 1, "
"but the value given is %d.", "but the value given is %d.",
......
...@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input, ...@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input,
@deprecated(since="2.0.0", update_to="paddle.shard_index") @deprecated(since="2.0.0", update_to="paddle.shard_index")
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
""" """
Recompute the `input` indices according to the offset of the Reset the values of `input` according to the shard it beloning to.
shard. The length of the indices is evenly divided into N shards, and if Every value in `input` must be a non-negative integer, and
the `shard_id` matches the shard with the input index inside, the index is the parameter `index_num` represents the integer above the maximum
recomputed on the basis of the shard offset, elsewise it is set to value of `input`. Thus, all values in `input` must be in the range
`ignore_value`. The detail is as follows: [0, index_num) and each value can be regarded as the offset to the beginning
of the range. The range is further split into multiple shards. Specifically,
we first compute the `shard_size` according to the following formula,
which represents the number of integers each shard can hold. So for the
i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
:: ::
shard_size = (index_num + nshards - 1) // nshards shard_size = (index_num + nshards - 1) // nshards
y = x % shard_size if x // shard_size == shard_id else ignore_value
NOTE: If the length of indices cannot be evely divided by the shard number, For each value `v` in `input`, we reset it to a new value according to the
the size of the last shard will be less than the calculated `shard_size` following formula:
::
v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value
That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
if it in the range. Otherwise, we reset it to be `ignore_value`.
Args: Args:
input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1. input (Tensor): Input tensor with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer defining the range of the index. index_num (int): An integer represents the integer above the maximum value of `input`.
nshards (int): The number of shards. nshards (int): The number of shards.
shard_id (int): The index of the current shard. shard_id (int): The index of the current shard.
ignore_value (int): An integer value out of sharded index range. ignore_value (int): An integer value out of sharded index range.
Returns: Returns:
Tensor: The sharded index of input. Tensor.
Examples: Examples:
.. code-block:: python .. code-block:: python
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册