未验证 提交 e06f4439 编写于 作者: L lilong12 提交者: GitHub

add the framework support for distfc (#21197) (#21463)

* add the framework support for distfc and ut, test=develop
* fix the implementation of shard_index_op, test=develop
上级 9c63b7c1
......@@ -26,7 +26,7 @@ __global__ void ShardIndexInner(const T* in_data, T* out_data,
const int64_t numel, const int index_num,
const int nshards, const int shard_id,
const int ignore_value) {
int shard_size = index_num / nshards;
int shard_size = (index_num + nshards - 1) / nshards;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
assert(in_data[idx] >= 0 && in_data[idx] < index_num);
......
......@@ -34,7 +34,7 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
"shard_id(%d) is not in range [0, %d)", shard_id, nshards);
int shard_size = index_num / nshards;
int shard_size = (index_num + nshards - 1) / nshards;
out->Resize(in->dims());
out->set_lod(in->lod());
......
......@@ -1250,6 +1250,35 @@ class Variable(object):
"""
self.error_clip = error_clip
def _set_info(self, key, value):
"""
Set key-value information for this variable.
Args:
key(str): Key for this information.
value(object): The value associated to the key.
Returns:
None
"""
if not hasattr(self, "_info"):
self._info = {}
self._info[key] = value
def _get_info(self, key):
"""
Get the information of this variable corresponding to key.
Args:
key(str): Key for this information.
Returns:
object
"""
if hasattr(self, "_info") and key in self._info:
return self._info[key]
return None
def _slice_indices(self, slice, length):
"""
Reference implementation for the slice.indices method.
......
......@@ -17655,10 +17655,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
op_type = 'shard_index'
helper = LayerHelper(op_type, **locals())
if index_num % nshards != 0:
raise ValueError(
'The index_num(%d) cannot be evenly divided by nshards(%d)' %
(index_num, nshards))
if shard_id < 0 or shard_id >= nshards:
raise ValueError('The shard_id(%d) should be in [0, %d)' %
(shard_id, nshards))
......
......@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value):
x = [np.random.randint(0, index_num - 1) for i in range(N)]
x = np.array(x).astype('int32').reshape([N, 1])
shard_size = index_num // nshards
shard_size = (index_num + nshards - 1) // nshards
out = np.zeros(shape=x.shape).astype('int32')
for i in range(N):
if x[i] // shard_size == shard_id:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for Dataset,
including create, config, run, etc.
"""
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
import os
import shutil
import unittest
class TestVarInfo(unittest.TestCase):
""" TestCases for Dataset. """
def test_var_info(self):
""" Testcase for get and set info for variable. """
value = np.random.randn(1)
var = fluid.layers.create_global_var([1], value, "float32")
var._set_info("name", "test")
ret = var._get_info("name")
assert ret == "test"
ret = var._get_info("not_exist")
assert ret == None
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册