Commit d6c8d267 authored by Q qiaolongfei

optimize code and comment

Parent: f031555c
@@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
-    AddInput("X",
-             "(LoDTensor) the input tensor with shape{batch_num, N}, N is the "
-             "size of embedding table")
+    AddInput(
+        "X",
+        "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
+        "size of embedding table")
         .AsDuplicable();
     AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
     AddComment(R"DOC(
 Merge multi LoDTensor's into one according to Ids's shard num.
-The values in the input LoDTensor are lookuped from the output of splite_ids_op
+The values in the input LoDTensor are lookuped from the output of split_ids_op
 Example:
   Input:
     Ids = [1,2,3,4,5,6]
...
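The merge described in the DOC comment is the counterpart of split_ids_op: ids are routed to shards (by id modulo the shard count) and the per-shard lookup results are stitched back into the original id order. Below is a minimal standalone sketch of that bookkeeping, not part of the commit, assuming the id % shard_num routing and with made-up embedding values for illustration.

// Standalone sketch, not PaddlePaddle code: rebuild rows in the original id
// order from per-shard outputs, mirroring the in_indexs bookkeeping in the kernel.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<long long> ids = {1, 2, 3, 4, 5, 6};
  const size_t shard_num = 2;
  // Per-shard lookup results, in the order each shard received its ids:
  // shard 0 handled ids {2, 4, 6}, shard 1 handled ids {1, 3, 5}.
  const std::vector<std::vector<float>> shard_rows = {
      {0.2f, 0.4f, 0.6f},   // shard 0 (one value per row for brevity)
      {0.1f, 0.3f, 0.5f}};  // shard 1

  std::vector<size_t> in_index(shard_num, 0);  // next unread row per shard
  for (long long id : ids) {
    const size_t shard_id = static_cast<size_t>(id) % shard_num;
    const float row = shard_rows[shard_id][in_index[shard_id]++];
    std::printf("id=%lld -> %.1f\n", id, row);  // rows come out in id order
  }
  return 0;
}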
@@ -47,10 +47,10 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     int batch_size = 0;
     int embedding_size = 0;
     for (auto &input : x_tensors) {
-      if (embedding_size == 0) {
-        embedding_size = input->dims()[1];
-      }
       if (framework::product(input->dims()) != 0) {
+        if (embedding_size == 0) {
+          embedding_size = input->dims()[1];
+        }
         PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
                           "embedding size of all input should be the same");
         batch_size += input->dims()[0];
@@ -58,7 +58,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
-        "the batch size of ids and embedding value should be the same");
+        "the batch size of ids and merged embedding value should be the same");

     const size_t shard_num = x_tensors.size();
@@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
       in_indexs[shard_id] += 1;
     }

-    for (int i = 0; i < shard_num; ++i) {
+    for (size_t i = 0; i < shard_num; ++i) {
       PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
                         "after merge, all data in x_tensor should be used");
     }
...
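As far as the diff shows, the kernel hunks make two changes besides the message wording: the embedding_size seed is now taken only from inputs whose element count is non-zero, and the shard index is size_t to match shard_num. A standalone sketch of the reordered width check follows; the dims are invented for illustration, and the reading that an empty shard output should no longer seed or trip the check is an assumption, not stated in the commit.

// Standalone sketch, not Paddle code: width check that skips empty inputs,
// following the reordered logic in the hunk above.
#include <cassert>
#include <vector>

long long CommonWidth(const std::vector<std::vector<long long>> &dims_list) {
  long long embedding_size = 0;
  for (const auto &dims : dims_list) {
    long long numel = 1;
    for (long long d : dims) numel *= d;   // analogue of framework::product
    if (numel != 0) {                      // consider non-empty tensors only
      if (embedding_size == 0) {
        embedding_size = dims[1];          // seed width from first non-empty input
      }
      // mirrors PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], ...)
      assert(embedding_size == dims[1] &&
             "embedding size of all input should be the same");
    }
  }
  return embedding_size;
}

int main() {
  // The first shard produced no rows; the others agree on width 8.
  assert(CommonWidth({{0, 0}, {3, 8}, {2, 8}}) == 8);
  return 0;
}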