From 509cb0bc76ea3b22423e293f608d4956a63deda7 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 10 Jun 2018 23:31:41 +0800 Subject: [PATCH] add unit test, pass the unit test --- paddle/fluid/operators/merge_ids_op.cc | 12 +++++- paddle/fluid/operators/merge_ids_op.h | 23 +++++++---- .../tests/unittests/test_merge_ids_op.py | 38 +++++++++++++++++++ 3 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_merge_ids_op.py diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc index 939561509c2..bae649adecb 100644 --- a/paddle/fluid/operators/merge_ids_op.cc +++ b/paddle/fluid/operators/merge_ids_op.cc @@ -73,6 +73,15 @@ class MergeIdsOp : public framework::OperatorWithKernel { } ctx->ShareLoD("Ids", "Out"); } + + private: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.MultiInput("X").front()->type()), + ctx.GetPlace()); + } }; class MergeIdsOpInferVarType : public framework::VarTypeInference { @@ -93,5 +102,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker, ops::MergeIdsOpInferVarType); REGISTER_OP_CPU_KERNEL( - merge_ids, ops::MergeIdsOpKernel, - ops::MergeIdsOpKernel); + merge_ids, ops::MergeIdsOpKernel); diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h index fd5b542ceb7..065368f8dd5 100644 --- a/paddle/fluid/operators/merge_ids_op.h +++ b/paddle/fluid/operators/merge_ids_op.h @@ -30,6 +30,7 @@ class MergeIdsOpKernel : public framework::OpKernel { if (!platform::is_cpu_place(place)) { PADDLE_THROW("MergeIds do not support GPU kernel"); } + VLOG(3) << "run in MergeIdsOpKernel"; const auto *ids_var = ctx.InputVar("Ids"); PADDLE_ENFORCE(ids_var->IsType(), @@ -37,7 +38,7 @@ class MergeIdsOpKernel : public framework::OpKernel { const auto &ids_tensor = ids_var->Get(); const auto &ids_dims = ids_tensor.dims(); - const T *ids = ids_tensor.data(); + const int64_t *ids = ids_tensor.data(); auto x_tensors = ctx.MultiInput("X"); @@ -49,9 +50,11 @@ class MergeIdsOpKernel : public framework::OpKernel { if (embedding_size == 0) { embedding_size = input->dims()[1]; } - PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], - "embedding size of all input should be the same"); - batch_size += input->dims()[0]; + if (framework::product(input->dims()) != 0) { + PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], + "embedding size of all input should be the same"); + batch_size += input->dims()[0]; + } } PADDLE_ENFORCE_EQ( batch_size, ids_dims[0], @@ -61,13 +64,14 @@ class MergeIdsOpKernel : public framework::OpKernel { if (shard_num == 1) { VLOG(3) << "only one shard, we can copy the data directly"; - TensorCopy(ids_tensor, place, out); + TensorCopy(*x_tensors[0], place, out); } else { std::vector in_indexs(shard_num, 0); - auto *out_data = out->mutable_data(ids_dims, place); + auto *out_data = out->mutable_data( + framework::make_ddim({batch_size, embedding_size}), place); // copy data from ins[shard_num] to out. for (int i = 0; i < ids_dims[0]; ++i) { - T id = ids[i]; + int64_t id = ids[i]; size_t shard_id = static_cast(id) % shard_num; int index = in_indexs[shard_id]; memcpy(out_data + embedding_size * i, @@ -75,6 +79,11 @@ class MergeIdsOpKernel : public framework::OpKernel { sizeof(T) * embedding_size); in_indexs[shard_id] += 1; } + + for (int i = 0; i < shard_num; ++i) { + PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0], + "after merge, all data in x_tensor should be used"); + } } } }; diff --git a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py new file mode 100644 index 00000000000..f209bdf30fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestMergeIdsOp(OpTest): + def setUp(self): + self.op_type = "merge_ids" + ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') + x1 = np.array([]).astype('float32') + x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], + [0.5, 0.6]]).astype('float32') + out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], + [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32') + self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab