diff --git a/AUTHORS.md b/AUTHORS.md index 46e2aef30ae87656343e36224f0ab6c0277a619b..5be71c9b2d598c4c9141ec23628fc8be898bf5e8 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -28,6 +28,7 @@ | lcy-seso | Ying Cao | | cjld | Dun Liang | | lipeng-unisound | Peng Li | +| gavin1332 | Yi Liu | | liuyuan | Yuan Liu | | livc | Zhao Li | | llxxxll | Yong-Feng Liu | diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index fc20278ef676f30a54500c15b203b970f7f9737e..7c31b5b4e1fa34b5cf4ca8dcffe3e4fc2c1fc84e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -158,7 +158,7 @@ paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'par paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '9461e67095a6fc5d568fb2ce8fef66ff')) paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '54e1675aa0364f4a78fa72804ec0f413')) paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c')) -paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '52db6229214fc6ab167d7009df29170d')) +paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'ec4115591be842868c86b2e5334245c6')) paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '98e7927f09ee2270535b29f048e481ec')) paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '6196c9ec3075ca5a9c058ea1f8492256')) paddle.fluid.layers.squeeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ebbac07662a6e22e8e299ced880c7775')) @@ -264,6 +264,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e')) paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6')) paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35')) +paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6')) diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..578dcd37bb42bdc4c69020c2cf500d4a6c203a55 --- /dev/null +++ b/paddle/fluid/operators/shard_index_op.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/shard_index_op.h" + +namespace paddle { +namespace operators { + +class ShardIndexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ShardIndexOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ShardIndexOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Rank of Input(X) should be at least 2."); + if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) { + PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, + "Last dimension of Input(X) should be 1."); + } + + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /* --> */ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class ShardIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LoDTensor, LoDTensor) Input variable. Each value " + "of X is an index."); + AddOutput( + "Out", + "(Tensor, Tensor) Output tensor with same shape as X. " + "The tensor consists of sharding representations of values in X."); + AddAttr("index_num", + "A positive integer to specify the range of the input X."); + + AddAttr("nshards", + "A positive integer to specify the number of shards."); + AddAttr("shard_id", "The current shard id"); + AddAttr("ignore_value", "An ingeter value out of sharded range") + .SetDefault(-1); + AddComment(R"DOC( +This layer creates the sharded index for input. This layers is used in +model- and data- parallel mixed training generally, in which the index +data (usually the label) should be recaculated in each trainer according +to + +.. math:: + + assert index_num % nshards == 0 + + shard_size = index_num / nshards + + y = x % shard_size if x / shard_size == shard_id else ignore_value + +We take the distributed one-hot representation to show what this layer is +used for. The distributed one-hot representation is seperated into multiple +shards, and each shard is filling zeros except the one with the index +inside. In order to create these sharded representation in each trainer, +the original index should be recalculated (i.e. sharded) before. + +Examples: + + X is a Tensor of integer values: + X.shape = [4, 1] + X.data = [[1], [6], [12], [19]] + + suppose index_num = 20 and nshards = 2, then we get shard_size = 10 + + if shard_id == 0, we get the Out: + Out.shape = [4, 1] + Out.data = [[1], [6], [-1], [-1]] + + if shard_id == 1, we get the Out: + Out.shape = [4, 1] + Out.data = [[-1], [-1], [2], [9]] + + the default `ignore_value` -1 is used in this example. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(shard_index, ops::ShardIndexOp, + ops::ShardIndexOpMaker); +REGISTER_OP_CPU_KERNEL(shard_index, ops::ShardIndexCPUKernel, + ops::ShardIndexCPUKernel); diff --git a/paddle/fluid/operators/shard_index_op.cu b/paddle/fluid/operators/shard_index_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..08503e3e1a8fe66b20f1e23012c584f9e32b4a01 --- /dev/null +++ b/paddle/fluid/operators/shard_index_op.cu @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/shard_index_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void ShardIndexInner(const T* in_data, T* out_data, + const int64_t numel, const int index_num, + const int nshards, const int shard_id, + const int ignore_value) { + int shard_size = index_num / nshards; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + assert(in_data[idx] >= 0 && in_data[idx] < index_num); + if (in_data[idx] / shard_size == shard_id) { + out_data[idx] = in_data[idx] % shard_size; + } else { + out_data[idx] = ignore_value; + } + } +} + +using LoDTensor = framework::LoDTensor; + +template +class ShardIndexCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int index_num = context.Attr("index_num"); + int nshards = context.Attr("nshards"); + int shard_id = context.Attr("shard_id"); + int ignore_value = context.Attr("ignore_value"); + PADDLE_ENFORCE_GT(index_num, 0); + PADDLE_ENFORCE_GT(nshards, 0); + PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, + "shard_id(%d) is not in range [0, %d)", shard_id, nshards); + + out->Resize(in->dims()); + out->set_lod(in->lod()); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = in->numel(); + auto stream = + context.template device_context().stream(); + ShardIndexInner<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, out_data, numel, index_num, nshards, shard_id, ignore_value); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(shard_index, ops::ShardIndexCUDAKernel, + ops::ShardIndexCUDAKernel); diff --git a/paddle/fluid/operators/shard_index_op.h b/paddle/fluid/operators/shard_index_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f060b3fdf182a2bf7fe03b1d86db41c4d1cfb340 --- /dev/null +++ b/paddle/fluid/operators/shard_index_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +template +class ShardIndexCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int index_num = context.Attr("index_num"); + int nshards = context.Attr("nshards"); + int shard_id = context.Attr("shard_id"); + int ignore_value = context.Attr("ignore_value"); + PADDLE_ENFORCE_GT(index_num, 0); + PADDLE_ENFORCE_GT(nshards, 0); + PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards, + "shard_id(%d) is not in range [0, %d)", shard_id, nshards); + + int shard_size = index_num / nshards; + + out->Resize(in->dims()); + out->set_lod(in->lod()); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = in->numel(); + for (int64_t i = 0; i < numel; ++i) { + PADDLE_ENFORCE(in_data[i] >= 0 && in_data[i] < index_num, + "Input index(%d) is out of range [0,%d)", in_data[i], + index_num); + if (in_data[i] / shard_size == shard_id) { + out_data[i] = in_data[i] % shard_size; + } else { + out_data[i] = ignore_value; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 22d30cbcf574ebacf06a3fc3d725f39a3c8777ee..9371406ddb68f391d1aaa720fda1030eb6af2a1c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3632,6 +3632,8 @@ class Parameter(Variable): self.do_model_average = kwargs.get('do_model_average', None) + self.is_distributed = False + def __str__(self): return self.to_string(True) diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py index 6beddac7aace007f2c37b154a1b941083144da8b..9e96624cf7c70f24a7f65a91c7ee41af45ddeb6c 100644 --- a/python/paddle/fluid/layers/collective.py +++ b/python/paddle/fluid/layers/collective.py @@ -14,6 +14,7 @@ from __future__ import print_function from ..layer_helper import LayerHelper, unique_name +from ..framework import Variable def _allreduce(x, out=None, reduce_type="sum", sync_mode=False): @@ -58,3 +59,122 @@ def _broadcast(x, root, sync_mode=False): attrs={"sync_mode": sync_mode, "root": root}) return x + + +def _c_allreduce(x, + out=None, + reduce_type='sum', + ring_id=0, + use_calc_stream=False): + helper = LayerHelper('c_allreduce', **locals()) + + if reduce_type not in ['sum', 'prob', 'max', 'min']: + raise TypeError('reduce type can only be "sum|prod|max|min]"') + + op_type = 'c_allreduce_' + reduce_type + if out is None: + out = helper.create_variable( + name=unique_name.generate_with_ignorable_key('.'.join( + [x.name, op_type])), + shape=x.shape, + dtype=x.dtype, + type=x.type, + persistable=x.persistable) + + helper.append_op( + type=op_type, + inputs={'X': [x]}, + outputs={'Out': [out]}, + attrs={'ring_id': ring_id, + 'use_calc_stream': use_calc_stream}) + return out + + +def _c_broadcast(x, root=0, ring_id=0, use_calc_stream=False): + op_type = 'c_broadcast' + helper = LayerHelper(op_type, **locals()) + helper.append_op( + type=op_type, + inputs={'X': [x]}, + outputs={'Out': [x]}, + attrs={ + 'root': root, + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream + }) + return x + + +def _c_allgather(x, nranks, ring_id=0, use_calc_stream=False): + op_type = 'c_allgather' + helper = LayerHelper(op_type, **locals()) + out_shape = list(x.shape[:]) + if out_shape[0] > 0: + out_shape[0] *= nranks + out = helper.create_variable( + name=unique_name.generate_with_ignorable_key('.'.join( + [x.name, op_type])), + shape=out_shape, + dtype=x.dtype, + type=x.type, + persistable=x.persistable) + helper.append_op( + type=op_type, + inputs={'X': [x]}, + outputs={'Out': [out]}, + attrs={ + 'nranks': nranks, + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream + }) + return out + + +def _c_reducescatter(x, nranks, ring_id=0, use_calc_stream=False): + if not isinstance(x, Variable): + raise TypeError('x must be a Variable') + + if x.shape[0] % nranks != 0: + raise ValueError('x.shape[0](%d) cannot be evenly divided by nranks(%d)' + % (x.shape[0], nranks)) + + op_type = 'c_reducescatter' + helper = LayerHelper(op_type, **locals()) + out_shape = list(x.shape[:]) + if out_shape[0] > 0: + out_shape[0] //= nranks + out = helper.create_variable( + name=unique_name.generate_with_ignorable_key('.'.join( + [x.name, op_type])), + shape=out_shape, + dtype=x.dtype, + type=x.type, + persistable=x.persistable) + helper.append_op( + type=op_type, + inputs={'X': [x]}, + outputs={'Out': [out]}, + attrs={ + 'nranks': nranks, + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream + }) + return out + + +def _c_sync_calc_stream(x): + op_type = 'c_sync_calc_stream' + helper = LayerHelper(op_type, **locals()) + helper.append_op(type=op_type, inputs={'X': [x]}, outputs={'Out': [x]}) + return x + + +def _c_sync_comm_stream(x, ring_id): + op_type = 'c_sync_comm_stream' + helper = LayerHelper(op_type, **locals()) + helper.append_op( + type=op_type, + inputs={'X': [x]}, + outputs={'Out': [x]}, + attrs={'ring_id': ring_id}) + return x diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 363a05dd50fc32cd2d14cde6209d53cf2c2aa840..f859a19a180522923a2863ad587d9c376e624b5e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -207,6 +207,7 @@ __all__ = [ 'deformable_conv', 'unfold', 'deformable_roi_pooling', + 'shard_index', ] kIgnoreIndex = -100 @@ -6643,13 +6644,17 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): return loss -def one_hot(input, depth): +def one_hot(input, depth, allow_out_of_range=False): """ This layer creates the one-hot representations for input indices. Args: input(Variable): Input indices, last dimension must be 1. depth(scalar): An interger defining the depth of the one-hot dimension. + allow_out_of_range(bool): A bool value indicating whether the input + indices could be out of range [0, depth). When input indices are + out of range, exceptions is raised if allow_out_of_range is False, + or zero-filling representations is created if it is set True Returns: Variable: The one-hot representations of input. @@ -12516,3 +12521,87 @@ def deformable_roi_pooling(input, "trans_std": trans_std }) return output + + +def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): + """ + This layer creates the sharded index for input. This layers is used in + model- and data- parallel mixed training generally, in which the index + data (usually the label) should be recaculated in each trainer according + to + + .. math:: + + assert index_num % nshards == 0 + + shard_size = index_num / nshards + + y = x % shard_size if x / shard_size == shard_id else ignore_value + + We take the distributed one-hot representation to show what this layer is + used for. The distributed one-hot representation is seperated into multiple + shards, and each shard is filling zeros except the one with the index + inside. In order to create these sharded representation in each trainer, + the original index should be recalculated (i.e. sharded) before. + + Examples: + + X is a Tensor of integer values: + X.shape = [4, 1] + X.data = [[1], [6], [12], [19]] + + suppose index_num = 20 and nshards = 2, then we get shard_size = 10 + + if shard_id == 0, we get the Out: + Out.shape = [4, 1] + Out.data = [[1], [6], [-1], [-1]] + + if shard_id == 1, we get the Out: + Out.shape = [4, 1] + Out.data = [[-1], [-1], [2], [9]] + + the default `ignore_value` -1 is used in this example. + + Args: + input(Variable): Input indices, last dimension must be 1. + index_num(scalar): An interger defining the range of the index. + nshards(scalar): The number of shards + shard_id(scalar): The index of the current shard + ignore_value(scalar): An ingeter value out of sharded index range + + Returns: + Variable: The shard index of input. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + shard_label = fluid.layers.shard_index(input=label, + index_num=20, + nshards=2, + shard_id=0) + """ + op_type = 'shard_index' + helper = LayerHelper(op_type, **locals()) + if index_num % nshards != 0: + raise ValueError( + 'The index_num(%d) cannot be evenly divided by nshards(%d)' % + (index_num, nshards)) + if shard_id < 0 or shard_id >= nshards: + raise ValueError('The shard_id(%d) should be in [0, %d)' % + (shard_id, nshards)) + + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type=op_type, + inputs={'X': [input]}, + outputs={'Out': out}, + attrs={ + 'index_num': index_num, + 'nshards': nshards, + 'shard_id': shard_id, + 'ignore_value': ignore_value + }, + stop_gradient=True) + return out diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_op.py index f213a0c77f4babdb46626c6e7d9b631a4e79a631..62184f771942b2f94b65ffd2f2253e1121d15f9d 100644 --- a/python/paddle/fluid/tests/unittests/test_one_hot_op.py +++ b/python/paddle/fluid/tests/unittests/test_one_hot_op.py @@ -118,6 +118,25 @@ class TestOneHotOp_default_dtype_attr(OpTest): self.check_output() +class TestOneHotOp_out_of_range(OpTest): + def setUp(self): + self.op_type = 'one_hot' + depth = 10 + x_lod = [[4, 1, 3, 3]] + x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))] + x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1]) + + out = np.zeros(shape=(np.product(x.shape[:-1]), + depth)).astype('float32') + + self.inputs = {'X': (x, x_lod)} + self.attrs = {'depth': depth, 'allow_out_of_range': True} + self.outputs = {'Out': (out, x_lod)} + + def test_check_output(self): + self.check_output() + + class TestOneHotOp_exception(OpTest): def setUp(self): self.op_type = 'one_hot' diff --git a/python/paddle/fluid/tests/unittests/test_shard_index_op.py b/python/paddle/fluid/tests/unittests/test_shard_index_op.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3c0a5458ab8cc675b4de43516164b6386a4882 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_shard_index_op.py @@ -0,0 +1,85 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +from op_test import OpTest +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.framework as framework +from paddle.fluid.framework import Program, program_guard + + +def common_setup(self, index_num, nshards, shard_id, ignore_value): + self.op_type = 'shard_index' + x_lod = [[i for i in range(10)]] + N = sum(x_lod[0]) + x = [np.random.randint(0, index_num - 1) for i in range(N)] + x = np.array(x).astype('int32').reshape([N, 1]) + + shard_size = index_num // nshards + out = np.zeros(shape=x.shape).astype('int32') + for i in range(N): + if x[i] // shard_size == shard_id: + out[i] = x[i] % shard_size + else: + out[i] = ignore_value + + self.inputs = {'X': (x, x_lod)} + self.attrs = { + 'index_num': index_num, + 'nshards': nshards, + 'shard_id': shard_id, + 'ignore_value': ignore_value + } + self.outputs = {'Out': (out, x_lod)} + + +class TestShardIndexShardId0Op(OpTest): + def setUp(self): + common_setup(self, 20, 2, 0, -1) + + def test_check_output(self): + self.check_output() + + +class TestShardIndexShardId1Op(OpTest): + def setUp(self): + common_setup(self, 20, 2, 1, -1) + + def test_check_output(self): + self.check_output() + + +class TestShardIndexIgnoreValueOp(OpTest): + def setUp(self): + common_setup(self, 20, 2, 0, -2) + + def test_check_output(self): + self.check_output() + + +class TestShardIndexNotEvenlyDividedOp(OpTest): + def setUp(self): + common_setup(self, 15, 2, 1, -1) + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index 18cf1fec417598a402127585c40ddd53bd49e9f5..12edb56d0b80b5b5e9b262ed3406c9ce740f1630 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -134,6 +134,9 @@ class Collective(object): block = self.startup_program.global_block() ring_id = -1 for param in block.iter_parameters(): + if param.is_distributed: + continue + ring_id = (ring_id + 1) % self.nrings block.append_op( type='c_broadcast', @@ -219,6 +222,9 @@ class GradAllReduce(Collective): for i in range(0, len(op_role_var), 2): param = block.vars[op_role_var[i]] grad = block.vars[op_role_var[i + 1]] + if param.is_distributed: + continue + if offset == idx: offset += 1 block._insert_op( @@ -273,6 +279,9 @@ class LocalSGD(Collective): block = self.startup_program.global_block() for param in block.iter_parameters(): + if param.is_distributed: + continue + snapshot = block.create_var( name=self.snapshot_name(param.name), shape=param.shape, @@ -294,6 +303,9 @@ class LocalSGD(Collective): for idx, op in reversed(list(enumerate(block.ops))): if self._is_update_op(op): param = block.vars[op.input('Param')[0]] + if param.is_distributed: + continue + snapshot = block.create_var( name=self.snapshot_name(param.name), shape=param.shape,