未验证 提交 e06f4439 编写于 作者: L lilong12 提交者: GitHub

add the framework support for distfc (#21197) (#21463)

* add the framework support for distfc and ut, test=develop
* fix the implementation of shard_index_op, test=develop
上级 9c63b7c1
...@@ -26,7 +26,7 @@ __global__ void ShardIndexInner(const T* in_data, T* out_data, ...@@ -26,7 +26,7 @@ __global__ void ShardIndexInner(const T* in_data, T* out_data,
const int64_t numel, const int index_num, const int64_t numel, const int index_num,
const int nshards, const int shard_id, const int nshards, const int shard_id,
const int ignore_value) { const int ignore_value) {
int shard_size = index_num / nshards; int shard_size = (index_num + nshards - 1) / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) { if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_num); assert(in_data[idx] >= 0 && in_data[idx] < index_num);
......
...@@ -34,7 +34,7 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> { ...@@ -34,7 +34,7 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
"shard_id(%d) is not in range [0, %d)", shard_id, nshards); "shard_id(%d) is not in range [0, %d)", shard_id, nshards);
int shard_size = index_num / nshards; int shard_size = (index_num + nshards - 1) / nshards;
out->Resize(in->dims()); out->Resize(in->dims());
out->set_lod(in->lod()); out->set_lod(in->lod());
......
...@@ -1250,6 +1250,35 @@ class Variable(object): ...@@ -1250,6 +1250,35 @@ class Variable(object):
""" """
self.error_clip = error_clip self.error_clip = error_clip
def _set_info(self, key, value):
"""
Set key-value information for this variable.
Args:
key(str): Key for this information.
value(object): The value associated to the key.
Returns:
None
"""
if not hasattr(self, "_info"):
self._info = {}
self._info[key] = value
def _get_info(self, key):
"""
Get the information of this variable corresponding to key.
Args:
key(str): Key for this information.
Returns:
object
"""
if hasattr(self, "_info") and key in self._info:
return self._info[key]
return None
def _slice_indices(self, slice, length): def _slice_indices(self, slice, length):
""" """
Reference implementation for the slice.indices method. Reference implementation for the slice.indices method.
......
...@@ -17655,10 +17655,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): ...@@ -17655,10 +17655,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
""" """
op_type = 'shard_index' op_type = 'shard_index'
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
if index_num % nshards != 0:
raise ValueError(
'The index_num(%d) cannot be evenly divided by nshards(%d)' %
(index_num, nshards))
if shard_id < 0 or shard_id >= nshards: if shard_id < 0 or shard_id >= nshards:
raise ValueError('The shard_id(%d) should be in [0, %d)' % raise ValueError('The shard_id(%d) should be in [0, %d)' %
(shard_id, nshards)) (shard_id, nshards))
......
...@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value): ...@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value):
x = [np.random.randint(0, index_num - 1) for i in range(N)] x = [np.random.randint(0, index_num - 1) for i in range(N)]
x = np.array(x).astype('int32').reshape([N, 1]) x = np.array(x).astype('int32').reshape([N, 1])
shard_size = index_num // nshards shard_size = (index_num + nshards - 1) // nshards
out = np.zeros(shape=x.shape).astype('int32') out = np.zeros(shape=x.shape).astype('int32')
for i in range(N): for i in range(N):
if x[i] // shard_size == shard_id: if x[i] // shard_size == shard_id:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for Dataset,
including create, config, run, etc.
"""
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
import os
import shutil
import unittest
class TestVarInfo(unittest.TestCase):
""" TestCases for Dataset. """
def test_var_info(self):
""" Testcase for get and set info for variable. """
value = np.random.randn(1)
var = fluid.layers.create_global_var([1], value, "float32")
var._set_info("name", "test")
ret = var._get_info("name")
assert ret == "test"
ret = var._get_info("not_exist")
assert ret == None
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册