From d6c8d2675cd07ae679f922fef83a0a089d04c2ee Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 12 Jun 2018 23:26:58 +0800 Subject: [PATCH] optimize code and comment --- paddle/fluid/operators/merge_ids_op.cc | 10 ++++++---- paddle/fluid/operators/merge_ids_op.h | 10 +++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc index bae649adec..f3940231d7 100644 --- a/paddle/fluid/operators/merge_ids_op.cc +++ b/paddle/fluid/operators/merge_ids_op.cc @@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); - AddInput("X", - "(LoDTensor) the input tensor with shape{batch_num, N}, N is the " - "size of embedding table") + AddInput( + "X", + "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the " + "size of embedding table") .AsDuplicable(); AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); AddComment(R"DOC( Merge multi LoDTensor's into one according to Ids's shard num. 
-The values in the input LoDTensor are lookuped from the output of splite_ids_op +The values in the input LoDTensor are looked up from the output of split_ids_op + Example: Input: Ids = [1,2,3,4,5,6] diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h index 065368f8dd..83712a8519 100644 --- a/paddle/fluid/operators/merge_ids_op.h +++ b/paddle/fluid/operators/merge_ids_op.h @@ -47,10 +47,10 @@ class MergeIdsOpKernel : public framework::OpKernel { int batch_size = 0; int embedding_size = 0; for (auto &input : x_tensors) { - if (embedding_size == 0) { - embedding_size = input->dims()[1]; - } if (framework::product(input->dims()) != 0) { + if (embedding_size == 0) { + embedding_size = input->dims()[1]; + } PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], "embedding size of all input should be the same"); batch_size += input->dims()[0]; @@ -58,7 +58,7 @@ class MergeIdsOpKernel : public framework::OpKernel { } PADDLE_ENFORCE_EQ( batch_size, ids_dims[0], - "the batch size of ids and embedding value should be the same"); + "the batch size of ids and merged embedding value should be the same"); const size_t shard_num = x_tensors.size(); @@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel { in_indexs[shard_id] += 1; } - for (int i = 0; i < shard_num; ++i) { + for (size_t i = 0; i < shard_num; ++i) { PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0], "after merge, all data in x_tensor should be used"); } -- GitLab