diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/split_ids_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a54f8a2878c8606e6b487552324d1e7dfa94b9b8
--- /dev/null
+++ b/paddle/fluid/operators/split_ids_op.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/split_ids_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Ids", "(LoDTensor) the input ids with shape [batch_num, 1]");
+    AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
+        .AsDuplicable();
+
+    AddComment(R"DOC(
+Split a LoDTensor of Ids into multiple LoDTensors, one per parameter server.
+Example:
+  Input:
+    X = [1,2,3,4,5,6]
+
+  Out (3 outputs):
+    out0 = [3, 6]
+    out1 = [1, 4]
+    out2 = [2, 5]
+)DOC");
+  }
+};
+
+class SplitIdsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must have input Ids.");
+    PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must have output Out.");
+
+    auto ids_var_type = ctx->GetInputsVarType("Ids").front();
+    PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR);
+
+    auto ids_dims = ctx->GetInputDim("Ids");
+    PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
+    PADDLE_ENFORCE_EQ(ids_dims[1], 1);
+  }
+};
+
+class SplitIdsOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    for (auto &out_var : op_desc.Output("Out")) {
+      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
+                  ops::SplitIdsOpInferVarType);
+REGISTER_OP_CPU_KERNEL(
+    split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>);
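Note (editorial, not part of the patch): the DOC comment above routes each id to the output indexed by `id % shard_num`. A minimal Python sketch of that rule, using the hypothetical helper name `split_ids_by_mod`, reproduces the documented example:

```python
def split_ids_by_mod(ids, shard_num):
    """Route each id to shard id % shard_num, as in the op's DOC example."""
    shards = [[] for _ in range(shard_num)]
    for i in ids:
        shards[i % shard_num].append(i)
    return shards


# X = [1, 2, 3, 4, 5, 6] split across 3 pservers.
print(split_ids_by_mod([1, 2, 3, 4, 5, 6], 3))
# [[3, 6], [1, 4], [2, 5]]  i.e. out0, out1, out2
```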
diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e750ed2d171876ce2d3c232f5d34234217b3c3e
--- /dev/null
+++ b/paddle/fluid/operators/split_ids_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class SplitIdsOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto place = ctx.GetPlace();
+    if (!platform::is_cpu_place(place)) {
+      PADDLE_THROW("SplitIds does not support GPU kernel");
+    }
+
+    const auto* ids_t = ctx.Input<framework::LoDTensor>("Ids");
+    auto& ids_dims = ids_t->dims();
+    auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
+
+    const T* ids = ids_t->data<T>();
+
+    const size_t shard_num = outs.size();
+
+    std::vector<std::vector<T>> out_ids;
+    out_ids.resize(outs.size());
+
+    // Route each id to its shard: shard_id = id % shard_num.
+    for (size_t i = 0; i < static_cast<size_t>(ids_dims[0]); ++i) {
+      T id = ids[i];
+      size_t shard_id = static_cast<size_t>(id) % shard_num;
+      out_ids[shard_id].push_back(id);
+    }
+
+    // Create an [n, 1] tensor per shard; these are sent to parameter servers.
+    for (size_t i = 0; i < out_ids.size(); ++i) {
+      auto* shard_t = outs[i];
+      std::vector<T> ids = out_ids[i];
+      auto* shard_data = shard_t->mutable_data<T>(
+          framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
+      for (size_t j = 0; j < ids.size(); ++j) {
+        shard_data[j] = ids[j];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f0a06a56b42952800411d548bb3fc1732e031e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSplitIdsOp(OpTest):
+    def setUp(self):
+        self.op_type = "split_ids"
+        ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
+        out0 = np.array([[0], [3], [6]]).astype('int64')
+        out1 = np.array([[]]).astype('int64')
+        out2 = np.array([[2], [2], [5], [5]]).astype('int64')
+        self.inputs = {'Ids': ids}
+        self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
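Note (editorial, not part of the patch): the expected outputs in the unit test follow from the same `id % shard_num` rule, with each shard emitted as an [n, 1] column tensor and shard 1 left empty (no id in the input satisfies id % 3 == 1). A NumPy sketch cross-checking those expectations, using the hypothetical helper name `split_ids_np`:

```python
import numpy as np


def split_ids_np(ids, shard_num):
    """Group an [N, 1] int64 id tensor into shard_num [n, 1] arrays by id % shard_num."""
    shards = [[] for _ in range(shard_num)]
    for i in ids.flatten():
        shards[int(i) % shard_num].append(i)
    return [np.array(s, dtype='int64').reshape(-1, 1) for s in shards]


ids = np.array([[0], [2], [2], [3], [5], [5], [6]], dtype='int64')
out0, out1, out2 = split_ids_np(ids, 3)
assert (out0 == np.array([[0], [3], [6]])).all()
assert out1.size == 0  # shard 1 receives no ids
assert (out2 == np.array([[2], [2], [5], [5]])).all()
```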