提交 85d5f8e2 编写于 作者: T typhoonzero

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rename_rpc_ops

......@@ -18,6 +18,11 @@ dynamic_lstm
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
:noindex:
dynamic_lstmp
-------------
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstmp
:noindex:
dynamic_gru
-----------
.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
......
......@@ -174,7 +174,7 @@ class MaxOutFunctor {
};
```
CPU implemention is in .cc file
CPU implementation is in .cc file
```
template <typename T>
......@@ -188,7 +188,7 @@ class MaxOutFunctor<platform::CPUDeviceContext, T> {
};
```
CUDA implemention is in .cu file
CUDA implementation is in .cu file
```
template <typename T>
......@@ -203,9 +203,9 @@ class MaxOutFunctor<platform::CUDADeviceContext, T> {
```
We first obtain the computing handle from a concrete DeviceContext, and then compute on tensors.
We first obtain the computing handle from a concrete DeviceContext and then compute on tensors.
The implemention of `OpKernel` is similar to math functors, the extra thing we need to do is to register the OpKernel in a global map.
The implementation of `OpKernel` is similar to math functors, the extra thing we need to do is to register the OpKernel in a global map.
Fluid provides different register interfaces in op_registry.h
......
......@@ -98,3 +98,5 @@ if(NOT WITH_C_API AND WITH_FLUID)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/framework.pb.h DESTINATION include/paddle/framework)
install(FILES details/cow_ptr.h details/op_registry.h DESTINATION include/paddle/framework/details)
endif()
cc_test(channel_test SRCS channel_test.cc)
......@@ -13,75 +13,52 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <mutex>
#include <queue>
#include <stddef.h> // for size_t
namespace paddle {
namespace framework {
// Channel is the abstract class of buffered and un-buffered channels.
template <typename T>
class Channel {
public:
explicit Channel(std::size_t capacity) : capacity_(capacity) {}
void Send(T* channel_element) {
std::unique_lock<std::mutex> lock(mu_);
if (IsBounded()) {
full_cond_var_.wait(lock, [this]() {
bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;
return capacity_valid;
});
}
channel_.push_back(std::move(*channel_element));
lock.unlock();
empty_cond_var_.notify_one();
}
virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
virtual size_t Cap() = 0;
T* Receive() {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
T* channel_element = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
return channel_element;
}
size_t Size() {
std::unique_lock<std::mutex> lock(mu_);
return channel_.size();
}
// Don't delete channels; instead, call Channel::Close.
protected:
virtual ~Channel() {}
};
void Clear() {
std::unique_lock<std::mutex> lock(mu_);
channel_.clear();
// Forward declaration of channel implementations.
namespace details {
template <typename T>
class Buffered;
template <typename T>
class UnBuffered;
} // namespace details
NotifyAllSenders(&lock);
template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
if (buffer_size > 0) {
return new details::Buffered<T>(buffer_size);
}
return new details::UnBuffered<T>();
}
private:
std::size_t capacity_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
private:
void NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
if (IsBounded()) {
lock->unlock();
full_cond_var_.notify_one();
}
template <typename T>
void CloseChannel(Channel<T>* ch) {
if (ch->Cap() > 0) {
delete dynamic_cast<details::Buffered<T>*>(ch);
} else {
delete dynamic_cast<details::UnBuffered<T>*>(ch);
}
}
bool IsBounded() const { return capacity_ > 0; }
bool IsCapacityFull() const { return channel_.size() >= capacity_; }
};
} // namespace operator
} // namespace framework
} // namespace paddle
#include "paddle/framework/details/buffered_channel.h"
#include "paddle/framework/details/unbuffered_channel.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/channel.h"
#include "gtest/gtest.h"
TEST(Channel, MakeAndClose) {
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
Channel<int>* ch = MakeChannel<int>(10);
CloseChannel(ch);
}
......@@ -79,5 +79,33 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
}
}
inline std::string DataTypeToString(const proto::DataType type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP16:
return "float16";
case DataType::FP32:
return "float32";
case DataType::FP64:
return "float64";
case DataType::INT16:
return "int16";
case DataType::INT32:
return "int32";
case DataType::INT64:
return "int64";
case DataType::BOOL:
return "bool";
default:
PADDLE_THROW("Not support type %d", type);
}
}
inline std::ostream& operator<<(std::ostream& out,
const proto::DataType& type) {
out << DataTypeToString(type);
return out;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
template <typename T>
class Buffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return cap_; }
private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
Buffered(size_t cap) : cap_(cap) {}
virtual ~Buffered();
void NotifyAllSenders(std::unique_lock<std::mutex>*);
};
template <typename T>
void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
template <typename T>
void Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
*item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
}
template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
channel_.clear();
NotifyAllSenders(&lock);
}
template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_one();
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
template <typename T>
class UnBuffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return 0; }
private:
UnBuffered() {}
virtual ~UnBuffered();
};
template <typename T>
void UnBuffered<T>::Send(T* channel_element) {}
template <typename T>
void UnBuffered<T>::Receive(T*) {}
template <typename T>
UnBuffered<T>::~UnBuffered() {}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -26,9 +26,9 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
ASSERT_EQ(
paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type["
"CUDNN]");
}
TEST(OpKernelType, Hash) {
......
......@@ -122,9 +122,11 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor)
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op listen_and_serv_op sum_op executor)
else()
set(DEPS_OPS ${DEPS_OPS} send_op recv_op)
set(DEPS_OPS ${DEPS_OPS} send_op recv_op listen_and_serv_op)
endif()
op_library(cond_op DEPS framework_proto tensor net_op)
......@@ -147,6 +149,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
......
......@@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}
bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;
return true;
}
bool RPCClient::Wait() {
if (req_count_ <= 0) {
return true;
......
......@@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline);
}
virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
}
virtual void Process() = 0;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
......@@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class BatchBarrierProcessor : public ClientBase {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}
virtual ~BatchBarrierProcessor() {}
virtual void Process() {}
sendrecv::VoidMessage reply_;
};
class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
......@@ -130,6 +150,10 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
bool Wait();
private:
......
......@@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() {
cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;
......@@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
// wait server
......@@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status();
VLOG(4) << "Create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne() {
......@@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
VLOG(4) << "Create RequestGet status:" << get->Status();
}
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
......
......@@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();
protected:
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
......
......@@ -30,6 +30,9 @@ namespace paddle {
namespace operators {
namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
......
......@@ -25,9 +25,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class SendOp : public framework::OperatorBase {
class RecvOp : public framework::OperatorBase {
public:
SendOp(const std::string& type, const framework::VariableNameMap& inputs,
RecvOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
......@@ -44,7 +44,6 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(client_.Wait());
}
......@@ -52,21 +51,22 @@ class SendOp : public framework::OperatorBase {
mutable detail::RPCClient client_;
};
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(OpProto* proto, OpAttrChecker* op_checker)
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
AddComment(R"DOC(
Send operator
Recv operator
This operator will send tensor to recv_op at the parameter server.
This operator can get variables from server side.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
"Server endpoints to recv variables"
"from.")
.SetDefault({});
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
......@@ -81,4 +81,4 @@ This operator will send tensor to recv_op at the parameter server.
namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);
......@@ -38,26 +38,34 @@ class SendOp : public framework::OperatorBase {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
bool do_get = Attr<bool>("DoGet");
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(client_.Wait());
if (do_get) {
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(client_.Wait());
}
PADDLE_ENFORCE(client_.Wait());
}
private:
// TODO(typhoonzero): put RPCClient in a Variable.
mutable detail::RPCClient client_;
};
......
......@@ -132,7 +132,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", block});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"RX", {"x1"}}}, {}, attrs);
f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs);
listen_and_serv_op->Run(scope, place);
}
......
......@@ -32,7 +32,7 @@ function cmake_gen() {
cat <<EOF
========================================
Configuring cmake in /paddle/build ...
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_BUILD_TYPE=${BUILD_TYPE:Release}
${PYTHON_FLAGS}
-DWITH_DOC=OFF
-DWITH_GPU=${WITH_GPU:-OFF}
......@@ -54,7 +54,7 @@ EOF
# docker environment is fully controlled by this script.
# See /Paddle/CMakeLists.txt, UNITTEST_USE_VIRTUALENV option.
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE:Release} \
${PYTHON_FLAGS} \
-DWITH_DOC=OFF \
-DWITH_GPU=${WITH_GPU:-OFF} \
......
......@@ -471,11 +471,10 @@ class DistributeTranspiler:
else:
self._append_pserver_non_opt_ops(optimize_sub_program,
pserver_program, opt_op)
# Append the recv op
# Append the listen_and_serv op
pserver_program.global_block().append_op(
type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
}, # grads to recv
type="listen_and_serv",
inputs={},
outputs={},
attrs={
"OptimizeBlock": optimize_sub_program.global_block(),
......
......@@ -478,7 +478,7 @@ class Operator(object):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'parallel_do'
'recv', 'listen_and_serv', 'parallel_do'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -14,8 +14,10 @@
from .. import core
from ..layer_helper import LayerHelper
from control_flow import BlockGuard
from ..layer_helper import LayerHelper
__all__ = ['data']
__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send']
def data(name,
......@@ -74,3 +76,151 @@ def data(name,
type=type,
stop_gradient=stop_gradient,
lod_level=lod_level)
class BlockGuardServ(BlockGuard):
"""
BlockGuardServ class.
BlockGuardServ class is used to create an op with a block in a program.
"""
def __init__(self, server):
if not (isinstance(server, ListenAndServ)):
raise TypeError("BlockGuardServ takes a ListenAndServ")
super(BlockGuardServ, self).__init__(server.helper.main_program)
self.server = server
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.server.complete_op()
return super(BlockGuardServ, self).__exit__(exc_type, exc_val, exc_tb)
class ListenAndServ(object):
"""
ListenAndServ class.
ListenAndServ class is used to wrap listen_and_serv op to create a server
which can receive variables from clients and run a block.
"""
def __init__(self, endpoint, fan_in=1, optimizer_mode=True):
self.helper = LayerHelper("listen_and_serv")
self.inputs = []
self.outputs = []
self.endpoint = endpoint
self.fan_in = fan_in
# FIXME(typhoonzero): add optimizer_mode is stupid, should make it more
# general.
self.optimizer_mode = optimizer_mode
def do(self):
return BlockGuardServ(self)
def get_params_and_grads(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
# params and grads in the same order.
params = list()
grads = list()
for op in current_block.ops:
# FIXME(typhoonzero): op.inputs is None if it's cloned.
if self.optimizer_mode:
if "Grad" in op.inputs and "Param" in op.inputs:
params.append(op.inputs["Param"].name)
grads.append(op.inputs["Grad"].name)
else:
# simple recv mode, recv operators inputs.
for iname in op.input_names:
for in_var_name in op.input(iname):
params.append(parent_block.var(in_var_name))
grads.append(parent_block.var(in_var_name))
return params, grads
def parent_block(self):
prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def complete_op(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
params, grads = self.get_params_and_grads()
param_names = [p.name for p in params]
grad_names = [g.name for g in grads]
parent_block.append_op(
type='listen_and_serv',
inputs={},
outputs={},
attrs={
'endpoint': self.endpoint,
'Fanin': self.fan_in,
'ParamList': param_names,
'GradList': grad_names,
'OptimizeBlock': current_block
})
def Send(endpoints, send_vars, get_vars):
"""
Send layer
Args:
endpoints: comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
Send variables to the server side, and get vars from server
side when server have finished running server side program.
"""
assert (type(send_vars) == list)
assert (type(get_vars) == list)
epmap = endpoints.split(",")
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
def Recv(endpoints, get_vars):
"""
Recv layer
Args:
endpoints: comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
Send variables to the server side, and get vars from server
side when server have finished running server side program.
"""
assert (type(send_vars) == list)
assert (type(get_vars) == list)
epmap = endpoints.split(",")
endpoints = list(set(epmap))
helper = LayerHelper("Recv", **locals())
helper.append_op(
type="recv",
inputs={"X": get_vars},
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
......@@ -26,6 +26,7 @@ __all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_lstmp',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
......@@ -256,7 +257,8 @@ def dynamic_lstm(input,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
dtype='float32'):
dtype='float32',
name=None):
"""
**Dynamic LSTM Layer**
......@@ -282,7 +284,7 @@ def dynamic_lstm(input,
W_{fc}, W_{oc}` are diagonal weight matrices for peephole connections. In
our implementation, we use vectors to reprenset these diagonal weight
matrices. The :math:`b` terms denote bias vectors (:math:`b_i` is the input
gate bias vector), :math:`\sigma` is the non-line activations, such as
gate bias vector), :math:`\sigma` is the non-linear activations, such as
logistic sigmoid function, and :math:`i, f, o` and :math:`c` are the input
gate, forget gate, output gate, and cell activation vectors, respectively,
all of which have the same size as the cell output activation vector :math:`h`.
......@@ -308,25 +310,25 @@ def dynamic_lstm(input,
(T X 4D), where T is the total time steps in this
mini-batch, D is the hidden size.
size(int): 4 * hidden size.
param_attr(ParamAttr): The parameter attribute for the learnable
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weights.
- The shape is (D x 4D), where D is the hidden
size.
- Weights = {:math:`W_{ch}, W_{ih}, \
W_{fh}, W_{oh}`}
bias_attr(ParamAttr): The bias attribute for the learnable bias
- The shape is (D x 4D), where D is the hidden
size.
bias_attr(ParamAttr|None): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
setting `use_peepholes` to `True`.
1. `use_peepholes = False`
- The shape is (1 x 4D).
- Biases = {:math:`b_c, b_i, b_f, b_o`}.
- The shape is (1 x 4D).
2. `use_peepholes = True`
- The shape is (1 x 7D).
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is (1 x 7D).
use_peepholes(bool): Whether to enable diagonal/peephole connections,
default `True`.
is_reverse(bool): Whether to compute reversed LSTM, default `False`.
......@@ -339,6 +341,8 @@ def dynamic_lstm(input,
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
tuple: The hidden state, and cell state of LSTM. The shape of both \
......@@ -353,6 +357,7 @@ def dynamic_lstm(input,
forward, _ = fluid.layers.dynamic_lstm(
input=forward_proj, size=hidden_dim * 4, use_peepholes=False)
"""
helper = LayerHelper('lstm', **locals())
size = size / 4
weight = helper.create_parameter(
......@@ -389,6 +394,192 @@ def dynamic_lstm(input,
return hidden, cell
def dynamic_lstmp(input,
size,
proj_size,
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
proj_activation='tanh',
dtype='float32',
name=None):
"""
**Dynamic LSTMP Layer**
LSTMP (LSTM with recurrent projection) layer has a separate projection
layer after the LSTM layer, projecting the original hidden state to a
lower-dimensional one, which is proposed to reduce the number of total
parameters and furthermore computational complexity for the LSTM,
espeacially for the case that the size of output units is relative
large (https://research.google.com/pubs/archive/43905.pdf).
The formula is as follows:
.. math::
i_t & = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i)
f_t & = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f)
\\tilde{c_t} & = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c)
o_t & = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o)
c_t & = f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t & = o_t \odot act_h(c_t)
r_t & = \overline{act_h}(W_{rh}h_t)
In the above formula:
* :math:`W`: Denotes weight matrices (e.g. :math:`W_{xi}` is \
the matrix of weights from the input gate to the input).
* :math:`W_{ic}`, :math:`W_{fc}`, :math:`W_{oc}`: Diagonal weight \
matrices for peephole connections. In our implementation, \
we use vectors to reprenset these diagonal weight matrices.
* :math:`b`: Denotes bias vectors (e.g. :math:`b_i` is the input gate \
bias vector).
* :math:`\sigma`: The activation, such as logistic sigmoid function.
* :math:`i, f, o` and :math:`c`: The input gate, forget gate, output \
gate, and cell activation vectors, respectively, all of which have \
the same size as the cell output activation vector :math:`h`.
* :math:`h`: The hidden state.
* :math:`r`: The recurrent projection of the hidden state.
* :math:`\\tilde{c_t}`: The candidate hidden state, whose \
computation is based on the current input and previous hidden state.
* :math:`\odot`: The element-wise product of the vectors.
* :math:`act_g` and :math:`act_h`: The cell input and cell output \
activation functions and `tanh` is usually used for them.
* :math:`\overline{act_h}`: The activation function for the projection \
output, usually using `identity` or same as :math:`act_h`.
Set `use_peepholes` to `False` to disable peephole connection. The formula
is omitted here, please refer to the paper
http://www.bioinf.jku.at/publications/older/2604.pdf for details.
Note that these :math:`W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}`
operations on the input :math:`x_{t}` are NOT included in this operator.
Users can choose to use fully-connected layer before LSTMP layer.
Args:
input(Variable): The input of dynamic_lstmp layer, which supports
variable-time length input sequence. The underlying
tensor in this Variable is a matrix with shape
(T X 4D), where T is the total time steps in this
mini-batch, D is the hidden size.
size(int): 4 * hidden size.
proj_size(int): The size of projection output.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight and projection weight.
- Hidden-hidden weight = {:math:`W_{ch}, W_{ih}, \
W_{fh}, W_{oh}`}.
- The shape of hidden-hidden weight is (P x 4D),
where P is the projection size and D the hidden
size.
- Projection weight = {:math:`W_{rh}`}.
- The shape of projection weight is (D x P).
bias_attr(ParamAttr|None): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
setting `use_peepholes` to `True`.
1. `use_peepholes = False`
- Biases = {:math:`b_c, b_i, b_f, b_o`}.
- The shape is (1 x 4D).
2. `use_peepholes = True`
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is (1 x 7D).
use_peepholes(bool): Whether to enable diagonal/peephole connections,
default `True`.
is_reverse(bool): Whether to compute reversed LSTM, default `False`.
gate_activation(str): The activation for input gate, forget gate and
output gate. Choices = ["sigmoid", "tanh", "relu",
"identity"], default "sigmoid".
cell_activation(str): The activation for cell output. Choices = ["sigmoid",
"tanh", "relu", "identity"], default "tanh".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
proj_activation(str): The activation for projection output.
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
tuple: The projection of hidden state, and cell state of LSTMP. The \
shape of projection is (T x P), for the cell state which is \
(T x D), and both LoD is the same with the `input`.
Examples:
.. code-block:: python
hidden_dim, proj_dim = 512, 256
fc_out = fluid.layers.fc(input=input_seq, size=hidden_dim * 4,
act=None, bias_attr=None)
proj_out, _ = fluid.layers.dynamic_lstmp(input=fc_out,
size=hidden_dim * 4,
proj_size=proj_dim,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh")
"""
helper = LayerHelper('lstmp', **locals())
size = size / 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[proj_size, 4 * size], dtype=dtype)
proj_weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, proj_size], dtype=dtype)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
projection = helper.create_tmp_variable(dtype)
cell = helper.create_tmp_variable(dtype)
ordered_proj0 = helper.create_tmp_variable(dtype)
batch_hidden = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype)
helper.append_op(
type='lstmp',
inputs={
'Input': input,
'Weight': weight,
'ProjWeight': proj_weight,
'Bias': bias
},
outputs={
'Projection': projection,
'Cell': cell,
'OrderedP0': ordered_proj0,
'BatchHidden': batch_hidden,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation,
'proj_activation': proj_activation
})
return projection, cell
def dynamic_gru(input,
size,
param_attr=None,
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_OPS test_recv_op)
endif(NOT WITH_DISTRIBUTE)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
......
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
......
......@@ -202,6 +202,18 @@ class TestBook(unittest.TestCase):
x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell))
print(str(program))
def test_dynamic_lstmp(self):
program = Program()
with program_guard(program):
hidden_dim, proj_dim = 16, 8
seq_data = layers.data(
name='seq_data', shape=[10, 10], dtype='float32', lod_level=1)
fc_out = layers.fc(input=seq_data, size=4 * hidden_dim)
self.assertIsNotNone(
layers.dynamic_lstmp(
input=fc_out, size=4 * hidden_dim, proj_size=proj_dim))
print(str(program))
def test_sequence_softmax(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
import numpy
from multiprocessing import Process
import os, sys
import time
class TestRecvOp(unittest.TestCase):
def test_send(self):
# Run init_serv in a thread
place = fluid.CPUPlace()
p = Process(target=self.init_serv, args=(place, ))
p.daemon = True
p.start()
time.sleep(5)
self.init_client(place)
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
os.system("kill -9 %d" % p.pid)
p.join()
def init_serv(self, place):
main = fluid.Program()
with fluid.program_guard(main):
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
serv = layers.ListenAndServ("127.0.0.1:6174", optimizer_mode=False)
with serv.do():
o = layers.scale(x=x, scale=10.0)
main.global_block().create_var(
name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape)
exe = fluid.Executor(place)
exe.run(main)
def init_client(self, place):
main = fluid.Program()
with fluid.program_guard(main):
x = layers.data(
shape=[32, 32],
dtype='float32',
name='X',
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
layers.Send("127.0.0.1:6174", [x], [x])
exe = fluid.Executor(place)
exe.run(main)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册