未验证 提交 2409d0f7 编写于 作者: C chengduo 提交者: GitHub

Refine regularization for selected_rows (#12369)

* refine regularization for selected_rows

* clean lookup_table

* refine rpc_server_test

* temporally disable rpc_server_test

* fix rpc_server_test

* add unit test
上级 85c49127
......@@ -270,6 +270,7 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
......@@ -17,9 +17,9 @@ if(WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
return()
endif()
......
......@@ -30,7 +30,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table);
USE_NO_KERNEL_OP(lookup_sparse_table);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
......@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetType("lookup_sparse_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS);
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({10, 10});
return block;
......@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>();
out_var->GetMutable<framework::LoDTensor>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>();
ids_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
auto rows = ids_var->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
ids_var->mutable_value()->Resize({rows_numel, 1});
ids_var->mutable_value()->mutable_data<float>(*place);
auto ids_var = scope->Var("ids")->GetMutable<framework::LoDTensor>();
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
......@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) {
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ExtractRowsOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0],
framework::proto::VarType::SELECTED_ROWS,
"The type of input(X) must be SelectedRows.");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>{in_dims[0], 1}));
}
};
class ExtractRowsOp : public framework::OperatorBase {
public:
ExtractRowsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto in_rows = in.rows();
auto out_dim = framework::make_ddim(
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
if (paddle::platform::is_gpu_place(in.place())) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(in.place());
auto src_ptr = in_rows.Data(in.place());
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(*dev_ctx)
.stream();
memory::Copy(boost::get<platform::CUDAPlace>(out->place()), dst_ptr,
boost::get<platform::CUDAPlace>(in.place()), src_ptr,
in_rows.size() * sizeof(int64_t), stream);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
} else {
memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(),
in_rows.data(), in_rows.size() * sizeof(int64_t));
}
}
};
class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(SelectedRows). The input tensor of extract_rows operator,"
" and its type is SelectedRows.");
AddOutput("Out", "(Tensor). The the rows of input(X).");
AddComment(R"DOC(
ExtractRows Operator.
The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker,
ops::ExtractRowsOpInferShape);
......@@ -33,19 +33,15 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Ids", /*->*/ "Out");
}
}
protected:
......@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddInput(
"Ids",
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"the ids to be looked up in W and it must be a column vector with "
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddInput("Ids",
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"Ids must be a column vector with rank = 2. "
"The 2nd dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<bool>("is_sparse",
"(boolean, default false) "
"Sparse update.")
......@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
......
......@@ -23,7 +23,7 @@ namespace operators {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
bool PaddingFlag>
__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
const int64_t N, const int64_t K, const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
......@@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
T* out = output + idy * D;
const T* tab = table + id * D;
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
if (PaddingFlag) {
if (id == padding_idx)
......@@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
}
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
const int64_t N, const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
......@@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
const T* out = output + idy * D;
T* tab = table + id * D;
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
}
......@@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t K;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
......@@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
void Compute(const framework::ExecutionContext &context) const override {
auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>();
auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
auto stream = dev_ctx.stream();
......@@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]});
d_table_value->mutable_data<T>(context.GetPlace());
auto* d_table_data = d_table_value->data<T>();
auto* d_output_data = d_output->data<T>();
auto *d_table_data = d_table_value->data<T>();
auto *d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel() * sizeof(T), stream);
......@@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t* ids = ids_t->data<int64_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
const int64_t *ids = ids_t->data<int64_t>();
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
......
......@@ -36,43 +36,13 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto *ids_var = context.InputVar("Ids");
Tensor *output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
}
int64_t *ids;
int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto *ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t *>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto *ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t *>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_dim[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
......
......@@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer):
dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': grad},
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
......@@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': grad},
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
class TestExtractRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
feature_len = 12
rows = [0, 4, 4, 7]
np_array = np.ones((len(rows), feature_len)).astype("float32")
in_x = scope.var('X').get_selected_rows()
in_x.set_height(len(rows))
in_x.set_rows(rows)
in_x_tensor = in_x.get_tensor()
in_x_tensor.set(np_array, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
extract_rows_op = Operator("extract_rows", X='X', Out='Out')
extract_rows_op.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
result_array = [ele[0] for ele in result_array]
assert result_array == rows
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == '__main__':
unittest.main()
......@@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass
class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
# create and initialize W Variable
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
# create and initialize Ids Variable
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(len(rows))
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create Out Variable
Out = scope.var('Out').get_selected_rows()
# create and run lookup_table operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place)
# get result from Out
Out_tensor = Out.get_tensor()
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
class TestLookupTableWIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册