diff --git a/paddle/fluid/operators/shard_index_op.cu b/paddle/fluid/operators/shard_index_op.cu
index 08503e3e1a8fe66b20f1e23012c584f9e32b4a01..db29b73f9eae463c977e96293d870bdf77addce9 100644
--- a/paddle/fluid/operators/shard_index_op.cu
+++ b/paddle/fluid/operators/shard_index_op.cu
@@ -26,7 +26,7 @@ __global__ void ShardIndexInner(const T* in_data, T* out_data,
                                 const int64_t numel, const int index_num,
                                 const int nshards, const int shard_id,
                                 const int ignore_value) {
-  int shard_size = index_num / nshards;
+  int shard_size = (index_num + nshards - 1) / nshards;
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < numel) {
     assert(in_data[idx] >= 0 && in_data[idx] < index_num);
diff --git a/paddle/fluid/operators/shard_index_op.h b/paddle/fluid/operators/shard_index_op.h
index f060b3fdf182a2bf7fe03b1d86db41c4d1cfb340..f943de586bc242324596853455cf94cd80837953 100644
--- a/paddle/fluid/operators/shard_index_op.h
+++ b/paddle/fluid/operators/shard_index_op.h
@@ -34,7 +34,7 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
                    "shard_id(%d) is not in range [0, %d)", shard_id, nshards);

-    int shard_size = index_num / nshards;
+    int shard_size = (index_num + nshards - 1) / nshards;
     out->Resize(in->dims());
     out->set_lod(in->lod());
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 5a68fe449ce54ff506c8833642dfe0599b168823..6d36016c5b1b6a5c2cdd21b4a6870d2a535ec280 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1250,6 +1250,35 @@ class Variable(object):
         """
         self.error_clip = error_clip

+    def _set_info(self, key, value):
+        """
+        Set key-value information for this variable.
+
+        Args:
+            key(str): Key for this information.
+            value(object): The value associated to the key.
+
+        Returns:
+            None
+        """
+        if not hasattr(self, "_info"):
+            self._info = {}
+        self._info[key] = value
+
+    def _get_info(self, key):
+        """
+        Get the information of this variable corresponding to key.
+
+        Args:
+            key(str): Key for this information.
+
+        Returns:
+            object
+        """
+        if hasattr(self, "_info") and key in self._info:
+            return self._info[key]
+        return None
+
     def _slice_indices(self, slice, length):
         """
         Reference implementation for the slice.indices method.
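Note on the kernel change above: shard_size is now computed with ceiling division, so index_num no longer has to be evenly divisible by nshards; the last shard simply holds the remainder. A minimal NumPy sketch of the resulting mapping (the index_num/nshards/shard_id values below are illustrative, not taken from the patch):

    import numpy as np

    # Illustrative values only: 20 indices split across 3 shards.
    index_num, nshards, shard_id, ignore_value = 20, 3, 1, -1
    shard_size = (index_num + nshards - 1) // nshards   # ceiling division -> 7

    labels = np.array([0, 6, 7, 13, 14, 19])
    out = np.where(labels // shard_size == shard_id,
                   labels - shard_id * shard_size,       # local offset inside the shard
                   ignore_value)
    print(out)                                           # [-1 -1  0  6 -1 -1]

This mirrors the per-element rule used by both the CUDA and CPU kernels: an index that falls into the requested shard is rewritten to its offset inside that shard, everything else becomes ignore_value.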
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 4c62a4764ac21e57a94a96ea372cc1a2d1096c9a..b0e3e71cdb6550177872482b327d76d1ce541b4e 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -17655,10 +17655,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
     """
     op_type = 'shard_index'
     helper = LayerHelper(op_type, **locals())
-    if index_num % nshards != 0:
-        raise ValueError(
-            'The index_num(%d) cannot be evenly divided by nshards(%d)' %
-            (index_num, nshards))
     if shard_id < 0 or shard_id >= nshards:
         raise ValueError('The shard_id(%d) should be in [0, %d)' %
                          (shard_id, nshards))
diff --git a/python/paddle/fluid/tests/unittests/test_shard_index_op.py b/python/paddle/fluid/tests/unittests/test_shard_index_op.py
index fd3c0a5458ab8cc675b4de43516164b6386a4882..9ccf1f254a5566bfebce1d18873b76f5961ff65b 100644
--- a/python/paddle/fluid/tests/unittests/test_shard_index_op.py
+++ b/python/paddle/fluid/tests/unittests/test_shard_index_op.py
@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value):
     x = [np.random.randint(0, index_num - 1) for i in range(N)]
     x = np.array(x).astype('int32').reshape([N, 1])

-    shard_size = index_num // nshards
+    shard_size = (index_num + nshards - 1) // nshards
     out = np.zeros(shape=x.shape).astype('int32')
     for i in range(N):
         if x[i] // shard_size == shard_id:
diff --git a/python/paddle/fluid/tests/unittests/test_var_info.py b/python/paddle/fluid/tests/unittests/test_var_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..0683eeb64145dc22f331274b6d70f7ccf2856f9a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_var_info.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TestCases for Variable._set_info/_get_info,
+covering setting, getting, and missing keys.
+"""
+
+from __future__ import print_function
+import paddle.fluid as fluid
+import numpy as np
+import os
+import shutil
+import unittest
+
+
+class TestVarInfo(unittest.TestCase):
+    """ TestCases for the Variable info interface. """
+
+    def test_var_info(self):
+        """ Testcase for get and set info for variable. """
+        value = np.random.randn(1)
+        var = fluid.layers.create_global_var([1], value, "float32")
+        var._set_info("name", "test")
+        ret = var._get_info("name")
+        assert ret == "test"
+        ret = var._get_info("not_exist")
+        assert ret is None
+
+
+if __name__ == '__main__':
+    unittest.main()
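With the divisibility check removed from fluid.layers.shard_index, calls where index_num is not a multiple of nshards are now accepted. A usage sketch (the layer name and the 20/3 values are illustrative, and it assumes the call is built under a default fluid program):

    import paddle.fluid as fluid

    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # 20 % 3 != 0 previously raised "The index_num(20) cannot be evenly
    # divided by nshards(3)"; with this patch it is sharded with shard_size = 7.
    shard_label = fluid.layers.shard_index(
        input=label, index_num=20, nshards=3, shard_id=1, ignore_value=-1)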