diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9b56ad4c55e35d497aa7abe4e1da3867a2084b88..2da52dbf48c870353a06efe29675f3b225aefa1d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -270,6 +270,7 @@ op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) +op_library(extract_rows_op DEPS memory) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 1612927055dd4ec5ee2220bc2b285e8d9b640ea8..da5d20505e9b06c0717af8d79d5456a9ade1e89c 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -17,9 +17,9 @@ if(WITH_GRPC) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(grpc_serde_test SRCS grpc_serde_test.cc - DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) - cc_test(rpc_server_test SRCS rpc_server_test.cc - DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL) + DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) + cc_test(rpc_server_test SRCS rpc_server_test.cc + DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) return() endif() diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index 9f2360ec70d2ce5d4e16435595e109c1bf04fd13..b50830c362d3f6ecf38affbfa6a1ffe2ed77e125 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -30,7 +30,7 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace distributed = paddle::operators::distributed; -USE_OP(lookup_table); +USE_NO_KERNEL_OP(lookup_sparse_table); std::unique_ptr g_rpc_service; std::unique_ptr g_req_handler; @@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); framework::VariableNameMap output({{"Output", {"out"}}}); auto op = block->AppendOp(); - op->SetType("lookup_table"); + op->SetType("lookup_sparse_table"); op->SetInput("W", {"w"}); op->SetInput("Ids", {"ids"}); op->SetOutput("Out", {"out"}); auto& out = *root_block->Var("out"); - out.SetType(framework::proto::VarType::SELECTED_ROWS); + out.SetType(framework::proto::VarType::LOD_TENSOR); out.SetShape({10, 10}); return block; @@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { w_var->GetMutable(); auto out_var = scope->Var("out"); - out_var->GetMutable(); + out_var->GetMutable(); auto ids_var = scope->Var("ids"); - ids_var->GetMutable(); + ids_var->GetMutable(); } void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, int64_t rows_numel) { CreateVarsOnScope(scope, place); - auto ids_var = scope->Var("ids")->GetMutable(); - auto rows = ids_var->mutable_rows(); - for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2); - ids_var->mutable_value()->Resize({rows_numel, 1}); - ids_var->mutable_value()->mutable_data(*place); + auto ids_var = scope->Var("ids")->GetMutable(); + int64_t* ids_ptr = + ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); + for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; } void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, @@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) { client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name); client->Wait(); auto var = scope.Var(out_var_name); - auto value = var->GetMutable()->value(); - auto ptr = value.mutable_data(place); + auto value = var->GetMutable(); + auto ptr = value->mutable_data(place); for (int64_t i = 0; i < rows_numel; ++i) { - EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast(i * 2)); + EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast(i * 2)); } } diff --git a/paddle/fluid/operators/extract_rows_op.cc b/paddle/fluid/operators/extract_rows_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a297d03cfb041e584159a5fc5ba214f8ac404b4 --- /dev/null +++ b/paddle/fluid/operators/extract_rows_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ExtractRowsOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ExtractRowsOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ExtractRowsOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0], + framework::proto::VarType::SELECTED_ROWS, + "The type of input(X) must be SelectedRows."); + auto in_dims = ctx->GetInputDim("X"); + + ctx->SetOutputDim( + "Out", framework::make_ddim(std::vector{in_dims[0], 1})); + } +}; + +class ExtractRowsOp : public framework::OperatorBase { + public: + ExtractRowsOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &in = scope.FindVar(Input("X"))->Get(); + auto out = scope.FindVar(Output("Out"))->GetMutable(); + + auto in_rows = in.rows(); + auto out_dim = framework::make_ddim( + std::vector{static_cast(in_rows.size()), 1}); + auto dst_ptr = out->mutable_data(out_dim, in.place()); + + if (paddle::platform::is_gpu_place(in.place())) { +#ifdef PADDLE_WITH_CUDA + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto *dev_ctx = pool.Get(in.place()); + auto src_ptr = in_rows.Data(in.place()); + auto stream = + reinterpret_cast(*dev_ctx) + .stream(); + memory::Copy(boost::get(out->place()), dst_ptr, + boost::get(in.place()), src_ptr, + in_rows.size() * sizeof(int64_t), stream); +#else + PADDLE_THROW("Not compiled with CUDA."); +#endif + } else { + memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(), + in_rows.data(), in_rows.size() * sizeof(int64_t)); + } + } +}; + +class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(SelectedRows). The input tensor of extract_rows operator," + " and its type is SelectedRows."); + AddOutput("Out", "(Tensor). The the rows of input(X)."); + + AddComment(R"DOC( + ExtractRows Operator. + +The function of extract_rows_op is extracting the rows from the input(X) +whose type is SelectedRows. + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker, + ops::ExtractRowsOpInferShape); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index bda499432214b8841c8dfc406ee45ca0367920e7..3e8f3ec5c5cd683343bcbdfc2388bd37c25e00f9 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -33,19 +33,15 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); - auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type - // is LoDTensor, this tensor contains the ids to be looked up in W - // and it must be a column vector with rank = 2 while the 2nd dimension - // size must be 1, when Ids's type is SelectedRows, the rows of Ids - // contains the ids to be looked up in W; - if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); - } + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); - ctx->ShareLoD("Ids", /*->*/ "Out"); + + if (ctx->GetOutputsVarType("Out")[0] == + framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("Ids", /*->*/ "Out"); + } } protected: @@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("W", "(Tensor) The input represents embedding tensors, " "which is a learnable parameter."); - AddInput( - "Ids", - "(Tensor or SelectedRows) Ids's type can be Tensor or " - "SelectedRows, when Ids's type is Tensor, this tensor contains " - "the ids to be looked up in W and it must be a column vector with " - "rank = 2 while the 2nd dimension size must be 1; when Ids's type is " - "SelectedRows, the rows of Ids contains the ids to be looked up " - "in W."); - AddOutput("Out", - "(Tensor or SelectedRows) The lookup results, which have the " - "same type as W."); + AddInput("Ids", + "An input with type int32 or int64 " + "contains the ids to be looked up in W. " + "Ids must be a column vector with rank = 2. " + "The 2nd dimension size must be 1."); + AddOutput("Out", "The lookup results, which have the same type as W."); AddAttr("is_sparse", "(boolean, default false) " "Sparse update.") @@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { Lookup Table Operator. This operator is used to perform lookups on the parameter W, -then concatenated into a dense or sparse tensor. - -The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's -type is SelectedRows, the rows of Ids contains the ids to be looked up in W; -when Ids's type is Tensor, this tensor contains the ids to be looked up in W -and it must be a column vector with rank = 2 while the 2nd dimension size must be 1, -at this time, Ids can carry the LoD (Level of Details) information, or not, and -the output only shares the LoD information with input Ids. +then concatenated into a dense tensor. +The input Ids can carry the LoD (Level of Details) information, +or not. And the output only shares the LoD information with input Ids. )DOC"); } diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 77722c50d39003d9342afb04a61ae3aaf6b21100..27483372b93a850d313445386c7973838c4a0710 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -23,7 +23,7 @@ namespace operators { template -__global__ void LookupTable(T* output, const T* table, const int64_t* ids, +__global__ void LookupTable(T *output, const T *table, const int64_t *ids, const int64_t N, const int64_t K, const int64_t D, const int64_t padding_idx) { int idx = threadIdx.x; @@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids, int64_t id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - T* out = output + idy * D; - const T* tab = table + id * D; + T *out = output + idy * D; + const T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { if (PaddingFlag) { if (id == padding_idx) @@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids, } template -__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, +__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids, const int64_t N, const int64_t K, const int64_t D) { int idx = threadIdx.x; @@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, int id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - const T* out = output + idy * D; - T* tab = table + id * D; + const T *out = output + idy * D; + T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { paddle::platform::CudaAtomicAdd(&tab[i], out[i]); } @@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids, template class LookupTableCUDAKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* table_t = context.Input("W"); + void Compute(const framework::ExecutionContext &context) const override { + auto *table_t = context.Input("W"); + auto *ids_t = context.Input("Ids"); + auto *output_t = context.Output("Out"); int64_t padding_idx = context.Attr("padding_idx"); - auto* ids_var = context.InputVar("Ids"); - Tensor* output_t = context.Output("Out"); - - int64_t* ids; - int64_t K; - - // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type - // is LoDTensor, this tensor contains the ids to be looked up in W; - // when Ids's type is SelectedRows, the rows of Ids contains the - // ids to be looked up in W. - if (ids_var->IsType()) { - auto* ids_t = context.Input("Ids"); - ids = const_cast(ids_t->data()); - K = ids_t->numel(); - } else if (ids_var->IsType()) { - auto* ids_t = context.Input("Ids"); - ids = const_cast(ids_t->rows().CUDAData(context.GetPlace())); - K = ids_t->rows().size(); - output_t->Resize({K, table_t->dims()[1]}); - } else { - PADDLE_THROW("Unsupported Variable Type of Ids"); - } size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - auto* table = table_t->data(); - auto* output = output_t->mutable_data(context.GetPlace()); + size_t K = ids_t->numel(); + + auto *ids = ids_t->data(); + auto *table = table_t->data(); + auto *output = output_t->mutable_data(context.GetPlace()); dim3 threads(128, 8); dim3 grids(8, 1); @@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel { template class LookupTableGradCUDAKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = + void Compute(const framework::ExecutionContext &context) const override { + auto &dev_ctx = context.template device_context(); bool is_sparse = context.Attr("is_sparse"); // Since paddings are not trainable and fixed in forward, the gradient of // paddings makes no sense and we don't deal with it in backward. if (is_sparse) { - auto* ids = context.Input("Ids"); - auto* table = context.Input("W"); - auto* d_output = context.Input(framework::GradVarName("Out")); - auto* d_table = context.Output(framework::GradVarName("W")); + auto *ids = context.Input("Ids"); + auto *table = context.Input("W"); + auto *d_output = context.Input(framework::GradVarName("Out")); + auto *d_table = context.Output(framework::GradVarName("W")); - auto* ids_data = ids->data(); + auto *ids_data = ids->data(); auto ids_dim = ids->dims(); auto stream = dev_ctx.stream(); @@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { d_table->set_rows(new_rows); - auto* d_table_value = d_table->mutable_value(); + auto *d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_dim[0], table->dims()[1]}); d_table_value->mutable_data(context.GetPlace()); - auto* d_table_data = d_table_value->data(); - auto* d_output_data = d_output->data(); + auto *d_table_data = d_table_value->data(); + auto *d_output_data = d_output->data(); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data, d_output->numel() * sizeof(T), stream); @@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { int N = d_table_t->dims()[0]; int D = d_table_t->dims()[1]; int K = ids_t->numel(); - const int64_t* ids = ids_t->data(); - const T* d_output = d_output_t->data(); - T* d_table = d_table_t->mutable_data(context.GetPlace()); + const int64_t *ids = ids_t->data(); + const T *d_output = d_output_t->data(); + T *d_table = d_table_t->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*d_table_t); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index d482506bf0361c11a019e32efbf348a64aaf5164..c9f074ca0e8dafb374dc9368165df5af5053a6b8 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -36,43 +36,13 @@ template class LookupTableKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { + auto *ids_t = context.Input("Ids"); // int tensor + auto *output_t = context.Output("Out"); // float tensor auto *table_var = context.InputVar("W"); - auto *ids_var = context.InputVar("Ids"); - Tensor *output_t = context.Output("Out"); - int64_t padding_idx = context.Attr("padding_idx"); - - DDim table_dim; - if (table_var->IsType()) { - table_dim = context.Input("W")->dims(); - } else if (table_var->IsType()) { - auto *table_t = context.Input("W"); - table_dim = table_t->value().dims(); - } else { - PADDLE_THROW( - "The parameter W of a LookupTable " - "must be either LoDTensor or SelectedRows"); - } - - int64_t *ids; - int64_t ids_numel; - - // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type - // is LoDTensor, this tensor contains the ids to be looked up in W; - // when Ids's type is SelectedRows, the rows of Ids contains the - // ids to be looked up in W. - if (ids_var->IsType()) { - auto *ids_t = context.Input("Ids"); - ids = const_cast(ids_t->data()); - ids_numel = ids_t->numel(); - } else if (ids_var->IsType()) { - auto *ids_t = context.Input("Ids"); - ids = const_cast(ids_t->rows().data()); - ids_numel = ids_t->rows().size(); - output_t->Resize({ids_numel, table_dim[1]}); - } else { - PADDLE_THROW("Unsupported Variable Type of Ids"); - } + int64_t padding_idx = context.Attr("padding_idx"); + int64_t *ids = const_cast(ids_t->data()); + int64_t ids_numel = ids_t->numel(); if (table_var->IsType()) { auto *table_t = context.Input("W"); diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index 080c185420bdc79d6da1d5a52fdd11fa4105d59a..3712955b3b32de457a0d47120a00ab7d4ecd5a66 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer): dtype="float32", shape=param.shape, lod_level=param.lod_level) if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + idx = block.create_var( + dtype="int64", + shape=param.shape, + type=core.VarDesc.VarType.LOD_TENSOR) decay = block.create_var( dtype="float32", shape=param.shape, type=core.VarDesc.VarType.SELECTED_ROWS) + block.append_op( + type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) block.append_op( type='lookup_table', inputs={'W': param, - 'Ids': grad}, + 'Ids': idx}, outputs={'Out': decay}, attrs={'is_sparse': True}) param = decay @@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer): dtype="float32", shape=param.shape, lod_level=param.lod_level) if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + idx = block.create_var( + dtype="int64", + shape=param.shape, + type=core.VarDesc.VarType.LOD_TENSOR) decay = block.create_var( dtype="float32", shape=param.shape, type=core.VarDesc.VarType.SELECTED_ROWS) + block.append_op( + type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) block.append_op( type='lookup_table', inputs={'W': param, - 'Ids': grad}, + 'Ids': idx}, outputs={'Out': decay}, attrs={'is_sparse': True}) diff --git a/python/paddle/fluid/tests/unittests/test_extract_rows_op.py b/python/paddle/fluid/tests/unittests/test_extract_rows_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6a41c44fe655b18626bdb727745dae032babe8ad --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_extract_rows_op.py @@ -0,0 +1,58 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from op_test import OpTest + + +class TestExtractRows(OpTest): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Variable + feature_len = 12 + rows = [0, 4, 4, 7] + np_array = np.ones((len(rows), feature_len)).astype("float32") + + in_x = scope.var('X').get_selected_rows() + in_x.set_height(len(rows)) + in_x.set_rows(rows) + in_x_tensor = in_x.get_tensor() + in_x_tensor.set(np_array, place) + + # create Out Variable + out_tensor = scope.var('Out').get_tensor() + + # create and run lookup_table operator + extract_rows_op = Operator("extract_rows", X='X', Out='Out') + extract_rows_op.run(scope, place) + + # get result from Out + result_array = np.array(out_tensor) + result_array = [ele[0] for ele in result_array] + assert result_array == rows + + def test_concat_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index f8d5785fbfe64843f4aa3b96b24809df60980c74..e16ab1d15f165bd0efa1b7d51add36c3020a1910 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): pass -class TestLookupTableIdsIsSelectedRows(OpTest): - def check_with_place(self, place): - scope = core.Scope() - - # create and initialize Variable - height = 10 - rows = [0, 4, 4, 7] - row_numel = 12 - - # create and initialize W Variable - W = scope.var('W').get_tensor() - W_array = np.full((height, row_numel), 1.0).astype("float32") - for i in range(height): - W_array[i] *= i - W.set(W_array, place) - - # create and initialize Ids Variable - ids_selected_rows = scope.var('Ids').get_selected_rows() - ids_selected_rows.set_height(len(rows)) - ids_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") - ids_tensor = ids_selected_rows.get_tensor() - ids_tensor.set(np_array, place) - - # create Out Variable - Out = scope.var('Out').get_selected_rows() - - # create and run lookup_table operator - concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out') - concat_rows_op.run(scope, place) - - # get result from Out - Out_tensor = Out.get_tensor() - result_array = np.array(Out_tensor) - - # all(): return True if all elements of the iterable are true (or if the iterable is empty) - for idx, row in enumerate(rows): - assert (row == result_array[idx]).all() - - def test_concat_rows(self): - places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) - for place in places: - self.check_with_place(place) - - class TestLookupTableWIsSelectedRows(OpTest): def check_with_place(self, place): scope = core.Scope()