提交 ce557989 编写于 作者: _青葱's avatar _青葱

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into doc

Merge branch develop
# C++ Data Feeding # C++ Data Feeding
In training with Paddle V2 API, data feeding wholly dependents on Python code. To get rid of the Python environment and achieve the goal of "wrapping the whole training by a while loop op" in Paddle Fluid, a C++ data feeding mechanism is required. While using Paddle V2 API for Training, data feeding completely depends on the Python code. To get rid of the Python environment and achieve the goal of "wrapping the whole training by a while loop op" in Paddle Fluid, a C++ data feeding mechanism is required.
In this document we show the fundamental design of C++ data feeding process, which includes the data reading, shuffling and batching. In this document we show the fundamental design of a C++ data feeding process, which includes data reading, shuffling and batching.
## Reader ## Reader
A new concept named 'Reader' is introduced. `Reader` is a series of inherited classes which can be hold by our `Variable` and they are used to read or process file data. In order to handle the above mentioned problem, a new concept called 'Reader' is introduced. `Reader` is a series of inherited classes which can be held by our `Variable` and they are used to read or process file data.
### `ReaderBase` ### `ReaderBase`
`ReaderBase` is the abstract base class of all readers. It defines the all readers' interfaces. `ReaderBase` is the abstract base class for all readers. It defines the interface for all readers.
```cpp ```cpp
class ReaderBase { class ReaderBase {
...@@ -20,10 +20,10 @@ class ReaderBase { ...@@ -20,10 +20,10 @@ class ReaderBase {
PADDLE_ENFORCE(!shapes_.empty()); PADDLE_ENFORCE(!shapes_.empty());
} }
// Read the next batch of data. (A 'batch' can be only one instance) // Read the next batch of data. (A 'batch' can be only one instance)
// If the next batch doesn't exist, the '*out' will be an empty std::vector. // If the next batch doesn't exist, '*out' will be an empty std::vector.
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
// Reinitialize the reader and read the file from the begin. // Reinitialize the reader and read the file from the beginning.
virtual void ReInit() = 0; virtual void ReInit() = 0;
// Get a certain read in data's shape. // Get a certain read in data's shape.
...@@ -42,36 +42,36 @@ class ReaderBase { ...@@ -42,36 +42,36 @@ class ReaderBase {
### `FileReader` and `DecoratedReader` ### `FileReader` and `DecoratedReader`
These two classes are derived from the `ReaderBase` and will further be derived by respective specific readers. That is to say, in our design, there are two kinds of readers: file readers and decorated readers. A file reader reads from a file of some specific format, and yield only one instance of data at a time. e.g. RecordIO reader, jpg reader, .... A decorated reader takes another reader(both file reader and decorated reader are OK) as its 'underlying reader'. It gets data from its underlying reader, does some process on them(shuffling, or batching), then yields processed data. The output data of a decorated reader can be a single instance or a batch. `ShuffleReader` and `BatchReader` are both decorated readers. These two classes are derived from the `ReaderBase` and will further be derived by more specific readers. Thus, in our design, there are two kinds of readers: file readers and decorated readers. A file reader reads from a file of some specific format, and yield only one instance of data at a time. For example, RecordIO reader, jpg reader, .... A decorated reader takes another reader(both file reader and decorated reader are OK) as its 'underlying reader'. It gets data from its underlying reader, does some processing on them(shuffling, or batching), then yields processed data. The output data of a decorated reader can be a single instance or a batch. `ShuffleReader` and `BatchReader` are both decorated readers.
All the readers share exactly the same interfaces defined in `ReaderBase`. So they can be decorated for more than one time: We can **shuffle** a reader's outputs and then **batch** the shuffle outputs. The interface consistency also allows related ops use readers without knowing what they are exactly. All the readers share exactly the same interface as defined in `ReaderBase`. So they can be decorated for more than one time: We can **shuffle** a reader's outputs and then **batch** the shuffle outputs. The interface consistency also allows related ops use readers without knowing what they are exactly.
### `ReaderHolder` ### `ReaderHolder`
Different readers belong to different class types. It leads to a problem: How can we drop them into `Variable`s and fetch them out by a unified method? For example, if a Variable holds a `BatchReader`, we can not get it by the following code: Different readers belong to different class types. This leads to a problem: How can we drop them into `Variable`s and fetch them out by a unified method? For example, if a Variable holds a `BatchReader`, we can not get it by the following code:
```cpp ```cpp
var->Get<ReaderBase>("batch_reader"); var->Get<ReaderBase>("batch_reader");
``` ```
we have to write: We would have to write:
```cpp ```cpp
var->Get<BatchReader>("batch_reader"); var->Get<BatchReader>("batch_reader");
``` ```
This requires each time getting a reader from a variable we must know the reader's type exactly. It is nearly impossible. This requires that in order to get a reader from a variable, every time, we must know the reader's type exactly. This is nearly impossible.
To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an empty decorator of `ReaderBase`, which erases reader's type. With `ReaderHolder` we are able to fetch all types of readers by `var->Get<ReaderHolder>("...")` and regard the obtained object as a reader. To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an empty decorator of `ReaderBase`, which hides reader's type. With `ReaderHolder` we are able to fetch all types of readers by `var->Get<ReaderHolder>("...")` and regard the obtained object as a reader.
## Related Operators ## Related Operators
To create and invoke readers, some now ops are introduced: To create and invoke readers, some new ops are introduced:
### `CreateReaderOp` ### `CreateReaderOp`
Each reader has its creating op. File readers' creating ops have no input and yield the created file reader as its output. Decorated readers' creating ops take the underlying readers as inputs and then yield new decorated readers. Each reader has its creation op. File readers' creation ops have no input and yield the created file reader as its output. Decorated readers' creation ops take the underlying readers as inputs and then yield new decorated readers.
### `ReadOp` ### `ReadOp`
......
# Design Doc: Distributed Training Architecture # Design Doc: Fluid Distributed Training Architecture
## Abstract ## Abstract
......
...@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true; return true;
} }
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
...@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++; req_count_++;
}
return true; void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
} }
bool RPCClient::Wait() { bool RPCClient::Wait() {
...@@ -154,7 +164,7 @@ bool RPCClient::Proceed() { ...@@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
// TODO(gongwb): add more retries. // TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) { if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String() LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message(); << " grpc error:" << c->status_.error_message();
...@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
} }
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max()); args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
...@@ -52,14 +52,14 @@ struct VarHandle { ...@@ -52,14 +52,14 @@ struct VarHandle {
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& msg); const sendrecv::VariableMessage& msg);
class ClientBase { class BaseProcessor {
public: public:
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) { explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
context_ = NULL; context_ = NULL;
} }
virtual ~ClientBase() {} virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
...@@ -91,9 +91,10 @@ class ClientBase { ...@@ -91,9 +91,10 @@ class ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)> typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
RequestSendCallBack; RequestSendCallBack;
class SendProcessor : public ClientBase { class SendProcessor : public BaseProcessor {
public: public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {} explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~SendProcessor() {} virtual ~SendProcessor() {}
...@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase { ...@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)> typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
RequestGetCallBack; RequestGetCallBack;
class GetProcessor : public ClientBase { class GetProcessor : public BaseProcessor {
public: public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {} explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~GetProcessor() {} virtual ~GetProcessor() {}
...@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase { ...@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse; RequestGetCallBack response_call_back_ = ProcGetResponse;
}; };
class BatchBarrierProcessor : public ClientBase { class BatchBarrierProcessor : public BaseProcessor {
public: public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {} : BaseProcessor(ch) {}
virtual ~BatchBarrierProcessor() {} virtual ~BatchBarrierProcessor() {}
...@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase { ...@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase {
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
}; };
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {}
virtual ~FetchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VariableMessage reply_;
};
class RPCClient { class RPCClient {
public: public:
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
...@@ -151,7 +164,10 @@ class RPCClient { ...@@ -151,7 +164,10 @@ class RPCClient {
const std::string& var_name, const std::string& var_name,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool AsyncSendBatchBarrier(const std::string& ep, void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool Wait(); bool Wait();
......
...@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase { ...@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope, grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<char>* queue) SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq), : RequestBase(service, cq),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
...@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase { ...@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name); auto* var = scope_->FindVar(var_name);
if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToMessage(var_name, var, *dev_ctx_, &reply_); SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
}
// TODO(gongwb): check var's info. // TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
queue_->Push('c'); MessageWithName msg_with_name =
// request name reply
std::make_pair(var_name, std::move(reply_));
queue_->Push(msg_with_name);
} }
protected: protected:
...@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase { ...@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_; ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_; framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
SimpleBlockQueue<char>* queue_; SimpleBlockQueue<MessageWithName>* queue_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
for (int i = 0; i < count; ++i) { int fetch_barriers = 0;
var_get_queue_.Pop(); while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
} }
} }
......
...@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_; SimpleBlockQueue<MessageWithName> var_recv_queue_;
SimpleBlockQueue<char> var_get_queue_; SimpleBlockQueue<MessageWithName> var_get_queue_;
// condition of the sub program // condition of the sub program
std::mutex barrier_mutex_; std::mutex barrier_mutex_;
......
...@@ -32,6 +32,7 @@ namespace detail { ...@@ -32,6 +32,7 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
......
...@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
} }
} }
if (exit_flag) { if (exit_flag) {
rpc_service_->ShutDown();
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break; break;
} }
try { try {
...@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
} }
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get. // FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(ins.size()); rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear(); sparse_vars.clear();
} // while(true) } // while(true)
} }
......
...@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1); PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out"); ctx->ShareLoD("Ids", /*->*/ "Out");
...@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker) LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W", AddInput("W",
"An input represents embedding tensors, " "(Tensor) The input represents embedding tensors, "
"which is a learnable parameter."); "which is a learnable parameter.");
AddInput("Ids", AddInput(
"An input with type int32 or int64 " "Ids",
"contains the ids to be looked up in W. " "(Tensor or SelectedRows) Ids's type can be Tensor or "
"Ids must be a column vector with rank = 2. " "SelectedRows, when Ids's type is Tensor, this tensor contains "
"The 2nd dimension size must be 1."); "the ids to be looked up in W and it must be a column vector with "
AddOutput("Out", "The lookup results, which have the same type as W."); "rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddAttr<bool>("is_sparse", AddAttr<bool>("is_sparse",
"(boolean, default false) " "(boolean, default false) "
"Sparse update") "Sparse update.")
.SetDefault(false); .SetDefault(false);
AddAttr<int64_t>("padding_idx", AddAttr<int64_t>("padding_idx",
"(int64, default -1) " "(int64, default -1) "
...@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator. Lookup Table Operator.
This operator is used to perform lookups on the parameter W, This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor. then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"); )DOC");
} }
......
...@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t K;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>(); auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
...@@ -29,25 +30,45 @@ template <typename T> ...@@ -29,25 +30,45 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor auto* ids_var = context.InputVar("Ids");
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>(); auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) { if (padding_idx == -1) {
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
} }
} else { } else {
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) { if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T)); memset(output + i * D, 0, D * sizeof(T));
} else { } else {
......
...@@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { ...@@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
" Input(Communicator) of AllReduce op input should not be NULL"); " Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of AllReduce op output should not be NULL"); " Output(Out) of AllReduce op output should not be NULL");
auto x_dims = ctx->GetInputsDim("X");
std::string reduction = ctx->Attrs().Get<std::string>("reduction"); std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"), reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction."); "invalid reduction.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims); ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
// AllReduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"(string, default 'ncclSum') "
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddComment(R"DOC(
NCCLAllReduce Operator.
AllReduce the input tensors.
)DOC");
}
};
// ReduceOp // ReduceOp
class NCCLReduceOp : public framework::OperatorWithKernel { class NCCLReduceOp : public framework::OperatorWithKernel {
public: public:
...@@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { ...@@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
} }
}; };
// BcastOp
class NCCLBcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
" Input(Communicator) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// AllreduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"(string, default 'ncclSum') "
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddComment(R"DOC(
NCCLAllReduce Operator.
AllReduce the input tensors.
)DOC");
}
};
// ReduceOp // ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -213,6 +188,29 @@ Reduce the tensors. ...@@ -213,6 +188,29 @@ Reduce the tensors.
} }
}; };
// BcastOp
class NCCLBcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
" Input(Communicator) of Bcast op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
// BcastOp // BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
......
...@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto* x = ctx.Input<LoDTensor>("X");
auto ins = ctx.MultiInput<LoDTensor>("X"); auto* out = ctx.Output<LoDTensor>("Out");
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* comm = ctx.Input<Communicator>("Communicator");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
reduction_op_ = ncclMin; reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") { } else if (reduction == "ncclMax") {
...@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW("Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = ctx.cuda_device_context().stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
VLOG(3) << "gpu : "
for (size_t i = 0; i < ins.size(); ++i) { << " invoke allreduce. send " << x->numel() << " recv "
VLOG(1) << "gpu : " << out->numel();
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()), x->data<T>(), out->mutable_data<T>(ctx.GetPlace()), out->numel(),
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_, NCCLTypeWrapper<T>::type, reduction_op_, comm->comms().at(idx),
comm->comms().at(idx), stream)); ctx.cuda_device_context().stream()));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); VLOG(3) << "gpu : "
<< " finished allreduce. send " << x->numel() << " recv "
VLOG(1) << "gpu : " << out->numel();
<< " finished allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
}
} }
}; };
...@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto x = ctx.Input<LoDTensor>("X"); // x0, x1, x2
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2 auto out = ctx.Output<LoDTensor>("Out");
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* comm = ctx.Input<Communicator>("Communicator");
int root = ctx.Attr<int>("root");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
reduction_op_ = ncclMin; reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") { } else if (reduction == "ncclMax") {
...@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW("Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms().size();
}
T* recvbuffer = nullptr; T* recvbuffer = nullptr;
if (root == gpu_id) { if (root == gpu_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace()); recvbuffer = out->mutable_data<T>(ctx.GetPlace());
} }
VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel()
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send " << " recv " << out->numel();
<< ins[i]->numel() << " recv " << outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclReduce( PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(), x->data<T>(), recvbuffer, x->numel(), NCCLTypeWrapper<T>::type,
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx), reduction_op_, root, comm->comms().at(idx),
stream)); ctx.cuda_device_context().stream()));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); VLOG(3) << "gpu : " << gpu_id << " finished reduce. send " << x->numel()
<< " recv " << out->numel();
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
}
} }
}; };
...@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> { ...@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
if (idx == root) { if (idx == root) {
auto ins = ctx.MultiInput<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
for (size_t i = 0; i < ins.size(); ++i) { VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel();
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
<< ins[i]->numel();
VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type, (void*)x->data<T>(), x->numel(), NCCLTypeWrapper<T>::type, root,
root, comm->comms().at(idx), stream)); comm->comms().at(idx), ctx.cuda_device_context().stream()));
VLOG(1) << " after ncclBcast"; VLOG(3) << "gpu : " << gpu_id << " finished Bcast.";
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
}
} else { } else {
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* out = ctx.Output<LoDTensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) { VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " << framework::product(out->dims());
<< framework::product(outs[i]->dims());
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(), out->mutable_data<T>(ctx.GetPlace()), out->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream)); NCCLTypeWrapper<T>::type, root, comm->comms().at(idx),
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); ctx.cuda_device_context().stream()));
VLOG(3) << "gpu : " << gpu_id << " finished Bcast. recv " << out->numel();
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
<< outs[i]->numel();
}
} }
} }
}; };
......
...@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase { ...@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
} }
} }
}; };
......
...@@ -250,6 +250,8 @@ class DistributeTranspiler: ...@@ -250,6 +250,8 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops) self.program.global_block().delete_ops(self.optimize_ops)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program return self.program
def get_pserver_program(self, endpoint): def get_pserver_program(self, endpoint):
...@@ -309,7 +311,8 @@ class DistributeTranspiler: ...@@ -309,7 +311,8 @@ class DistributeTranspiler:
for _, opt_op in enumerate(opt_op_on_pserver): for _, opt_op in enumerate(opt_op_on_pserver):
if ufind.is_connected(op, opt_op): if ufind.is_connected(op, opt_op):
if self._is_opt_op(op): if self._is_opt_op(op):
self._append_pserver_ops(optimize_block, op, endpoint) self._append_pserver_ops(optimize_block, op, endpoint,
default_main_program())
else: else:
self._append_pserver_non_opt_ops(optimize_block, op) self._append_pserver_non_opt_ops(optimize_block, op)
break break
...@@ -520,7 +523,8 @@ class DistributeTranspiler: ...@@ -520,7 +523,8 @@ class DistributeTranspiler:
orig_var_name = varname[:suff_idx] orig_var_name = varname[:suff_idx]
return orig_var_name return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint): def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
program = optimize_block.program program = optimize_block.program
pserver_block = program.global_block() pserver_block = program.global_block()
new_inputs = dict() new_inputs = dict()
...@@ -576,7 +580,17 @@ class DistributeTranspiler: ...@@ -576,7 +580,17 @@ class DistributeTranspiler:
elif key == "LearningRate": elif key == "LearningRate":
# leraning rate variable has already be created by non-optimize op, # leraning rate variable has already be created by non-optimize op,
# don't create it once again. # don't create it once again.
lr_varname = opt_op.input(key)[0]
if pserver_block.vars.has_key(lr_varname):
new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
else:
origin_var = origin_program.global_block().vars[lr_varname]
tmpvar = pserver_block.create_var(
name=origin_var.name,
persistable=origin_var.persistable,
dtype=origin_var.dtype,
shape=origin_var.shape)
new_inputs[key] = tmpvar
for key in opt_op.input_names: for key in opt_op.input_names:
new_shape = None new_shape = None
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestLookupTableOp(OpTest): class TestLookupTableOp(OpTest):
...@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): ...@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass pass
class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
# create and initialize W Variable
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
# create and initialize Ids Variable
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(len(rows))
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create Out Variable
Out = scope.var('Out').get_selected_rows()
# create and run lookup_table operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place)
# get result from Out
Out_tensor = Out.get_tensor()
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册