提交 509cb0bc 编写于 作者: Q qiaolongfei

add unit test, pass the unit test

上级 7cebec4b
...@@ -73,6 +73,15 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -73,6 +73,15 @@ class MergeIdsOp : public framework::OperatorWithKernel {
} }
ctx->ShareLoD("Ids", "Out"); ctx->ShareLoD("Ids", "Out");
} }
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
}
}; };
class MergeIdsOpInferVarType : public framework::VarTypeInference { class MergeIdsOpInferVarType : public framework::VarTypeInference {
...@@ -93,5 +102,4 @@ namespace ops = paddle::operators; ...@@ -93,5 +102,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker, REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
ops::MergeIdsOpInferVarType); ops::MergeIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, int64_t>, merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
...@@ -30,6 +30,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -30,6 +30,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (!platform::is_cpu_place(place)) { if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel"); PADDLE_THROW("MergeIds do not support GPU kernel");
} }
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids"); const auto *ids_var = ctx.InputVar("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
...@@ -37,7 +38,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -37,7 +38,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>(); const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
const auto &ids_dims = ids_tensor.dims(); const auto &ids_dims = ids_tensor.dims();
const T *ids = ids_tensor.data<T>(); const int64_t *ids = ids_tensor.data<int64_t>();
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X"); auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
...@@ -49,10 +50,12 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -49,10 +50,12 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (embedding_size == 0) { if (embedding_size == 0) {
embedding_size = input->dims()[1]; embedding_size = input->dims()[1];
} }
if (framework::product(input->dims()) != 0) {
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
"embedding size of all input should be the same"); "embedding size of all input should be the same");
batch_size += input->dims()[0]; batch_size += input->dims()[0];
} }
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0], batch_size, ids_dims[0],
"the batch size of ids and embedding value should be the same"); "the batch size of ids and embedding value should be the same");
...@@ -61,13 +64,14 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -61,13 +64,14 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (shard_num == 1) { if (shard_num == 1) {
VLOG(3) << "only one shard, we can copy the data directly"; VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(ids_tensor, place, out); TensorCopy(*x_tensors[0], place, out);
} else { } else {
std::vector<int> in_indexs(shard_num, 0); std::vector<int> in_indexs(shard_num, 0);
auto *out_data = out->mutable_data<T>(ids_dims, place); auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place);
// copy data from ins[shard_num] to out. // copy data from ins[shard_num] to out.
for (int i = 0; i < ids_dims[0]; ++i) { for (int i = 0; i < ids_dims[0]; ++i) {
T id = ids[i]; int64_t id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num; size_t shard_id = static_cast<size_t>(id) % shard_num;
int index = in_indexs[shard_id]; int index = in_indexs[shard_id];
memcpy(out_data + embedding_size * i, memcpy(out_data + embedding_size * i,
...@@ -75,6 +79,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -75,6 +79,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
sizeof(T) * embedding_size); sizeof(T) * embedding_size);
in_indexs[shard_id] += 1; in_indexs[shard_id] += 1;
} }
for (int i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
}
} }
} }
}; };
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestMergeIdsOp(OpTest):
def setUp(self):
self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
[0.5, 0.6]]).astype('float32')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册