From b505ff96e41bf9b9356b2f52afd01ad6935cf947 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 19 Nov 2021 20:34:26 +0800 Subject: [PATCH] bug fix shard_index (#37042) --- paddle/fluid/operators/shard_index_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 31 +++++++++++++++--------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc index 3e5e2ad3d8c..54555e494ff 100644 --- a/paddle/fluid/operators/shard_index_op.cc +++ b/paddle/fluid/operators/shard_index_op.cc @@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel { "but the value given is %d.", x_dims.size())); if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { - PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, + PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U, platform::errors::InvalidArgument( "The last dimension of Input(X) should be 1, " "but the value given is %d.", diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index dd0abd212e8..663c394e803 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14904,28 +14904,37 @@ def deformable_roi_pooling(input, @deprecated(since="2.0.0", update_to="paddle.shard_index") def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): """ - Recompute the `input` indices according to the offset of the - shard. The length of the indices is evenly divided into N shards, and if - the `shard_id` matches the shard with the input index inside, the index is - recomputed on the basis of the shard offset, elsewise it is set to - `ignore_value`. The detail is as follows: + Reset the values of `input` according to the shard it belongs to. + Every value in `input` must be a non-negative integer, and + the parameter `index_num` represents the integer above the maximum + value of `input`. 
Thus, all values in `input` must be in the range + [0, index_num) and each value can be regarded as the offset to the beginning + of the range. The range is further split into multiple shards. Specifically, + we first compute the `shard_size` according to the following formula, + which represents the number of integers each shard can hold. So for the + i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size). :: shard_size = (index_num + nshards - 1) // nshards - y = x % shard_size if x // shard_size == shard_id else ignore_value - NOTE: If the length of indices cannot be evely divided by the shard number, - the size of the last shard will be less than the calculated `shard_size` + For each value `v` in `input`, we reset it to a new value according to the + following formula: + :: + + v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value + + That is, the value `v` is set to the new offset within the range represented by the shard `shard_id` + if it is in the range. Otherwise, we reset it to be `ignore_value`. Args: - input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1. - index_num (int): An integer defining the range of the index. + input (Tensor): Input tensor with data type int64 or int32. Its last dimension must be 1. + index_num (int): An integer representing the integer above the maximum value of `input`. nshards (int): The number of shards. shard_id (int): The index of the current shard. ignore_value (int): An integer value out of sharded index range. Returns: - Tensor: The sharded index of input. + Tensor. Examples: .. code-block:: python -- GitLab