提交 af230d9b 编写于 作者: Y Yang Yu

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpp_parallel_executor

......@@ -587,6 +587,9 @@ function(grpc_library TARGET_NAME)
get_filename_component(PROTO_WE ${grpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
#FIXME(putcn): the follwoing line is supposed to generate *.pb.h and cc, but
# somehow it didn't. line 602 to 604 is to patching this. Leaving this here
# for now to enable dist CI.
protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc")
set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h")
......@@ -597,6 +600,9 @@ function(grpc_library TARGET_NAME)
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}"
--plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN}" "${ABS_PROTO}"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" -I "${PROTO_PATH}"
"${ABS_PROTO}"
DEPENDS "${ABS_PROTO}" ${PROTOBUF_PROTOC_EXECUTABLE} extern_grpc)
# FIXME(typhoonzero): grpc generated code do not generate virtual-dtor, mark it
......
......@@ -113,7 +113,7 @@ To solve this problem, we introduce `ReaderHolder` as a wrapper. It acts as an e
To create and invoke readers, some new ops are introduced:
### CreateReaderOp
### Operators That Create Readers
Each reader has its creation op. File readers' creation ops have no input and yield the created file reader as its output. Decorated readers' creation ops take the underlying readers as inputs and then yield new decorated readers.
......@@ -153,19 +153,52 @@ double_buffer_reader = create_double_buffer_op(batch_reader)
The forwarding ops of the corresponding `main_program` would be like this:
```
while_op {
not_completed = true
pass_count = 0
while_op(not_completed) {
has_next = has_next_op(double_buffer_reader)
if_else_op(has_next) {
batch_data = read_op(double_buffer_reader)
... (subsequent training ops)
} else {
reset_op(double_buffer_reader)
increase_op(pass_count)
not_completed = less_than_op(pass_count, reqiured_pass_num)
}
}
```
Two important considerations for these programs are as follows:
A few important considerations for these programs are as follows:
1. The multiple\_reader is the batch\_reader's underlying reader, and the batch\_reader is the double\_buffer\_reader's underlying reader. `read_op`, `has_next_op` and other reader related ops will only invoke the top-most reader. In this case, it's the double\_buffer\_reader.
1. `not_completed`, `pass_count` and other variables shown above are all Fluid Variables.
2. All readers exist in both `startup_program` and `main_program`. And they are persistable.
2. The multiple\_reader is the batch\_reader's underlying reader, and the batch\_reader is the double\_buffer\_reader's underlying reader. `read_op`, `has_next_op` and other reader related ops will only invoke the top-most reader. In this case, it's the double\_buffer\_reader.
3. All readers exist in both `startup_program` and `main_program`. And they are persistable.
### Simplify Configuration by MultiPassReader
The Program configuration mentioned above is complicated. Users need to be very familiar to concepts of Program and Block to prevent making mistakes in their code. To make the usage of C++ readers more friendly to new users, we introduce `MultiPassReader`.
`MultiPassReader` is a decorated reader. A multi-pass reader is used to continuously yield data for several training passes. It takes the number of passes to run as one of its attributes('pass_num') and maintains a counter to record how many passes it has completed. Each time its underlying reader reaches the EOF, the multi-pass reader checks whether it has completed the training of given number of pass. If not, the underlying reader will be re-initialized and starts a new pass automatically. Before completing the whole training, the return of MultiPassReader's `HasNext()` will always be `true`.
With `MultiPassReader`, the startup program would be like this:
```
multiple_reader = open_files_op(...)
batch_reader = create_batch_reader_op(multiple_reader)
multi_pass_reader = create_multi_pass_reader_op(batch_reader)
double_buffer_reader = create_double_buffer_op(multi_pass_reader)
... (other initializers)
```
The forwarding part of the corresponding `main_program` would be like this:
```
not_completed = true
while_op(not_completed) {
batch_data = read_op(double_buffer_reader)
... (subsequent training ops)
not_completed = has_next_op(double_buffer_reader)
}
```
# Channel Design
## Introduction
A Channel is a data structure that allows for synchronous interprocess
communication via message passing. It is a fundemental component of CSP
(communicating sequential processes), and allows for users to pass data
between threads without having to worry about synchronization.
## How to use it
Paddle offers python APIs to open and close channels, along with sending
and receiving data to/from a channel.
### Create a channel
Creates a new channel that takes in variables of a specific dtype.
- **fluid.make_channel(dtype, capacity=0)**
- **dtype**: The data type of variables being sent/received through channel
- **capacity**: The capacity of the channel. A capacity of 0 represents
an unbuffered channel. Capacity > 0 represents a buffered channel
```
ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR, 10)
```
### Close a channel
Closes a channel. Any pending senders and receivers will be awoken during
this time. Receivers can still receive from a closed channel, but senders
are not allowed to send any additional data to the channel (Paddle will
raise an exception if users try to send to a closed channel.)
- **fluid.channel_close(channel)**
```
fluid.channel_close(ch)
```
### Send data to a channel
Sends a variable to a channel. Currently, variables of dtype `LoDTensor`,
`LoDRankTable`, `LoDTensorArray`, `SelectedRows`, `ReaderHolder`, and
`ChannelHolder` are supported.
By default, the data of the Variable is moved from the sender to the receiver,
however the user can optionally copy the data before performing the send.
- **channel_send(channel, variable, is_copy=False)**
- **channel**: The channel to send the variable to
- **variable**: The variable to send to the channel
- **is_copy**: If set to True, channel_send will perform a variable assign
to copy the source variable to a new variable to be sent.
```
ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
var = fill_constant(shape=[1],dtype=core.VarDesc.VarType.INT32, value=100)
fluid.channel_send(ch, var, True)
```
### Receive data from a channel
Receives a variable from a channel. The data of the variable is moved to the
receiving variable.
- **channel_recv(channel, return_variable)**
- **channel**: The channel to receive the variable from
- **return_variable**: The destination variable used to store the data of the
variable received from the channel
```
ch = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
var = fill_constant(shape=[1],dtype=core.VarDesc.VarType.INT32, value=-1)
fluid.channel_recv(ch, var)
```
## How it Works
Channels provides a simple interface for different threads to share data.
To support the synchronization requirements, channels utilizes a series of
internal queues, locks, and conditional variables.
### QueueMessage
QueueMessage encapsulates the state of the channel send/receive operation to be
put in the **sendq/recvq**. It contains a condition variable used to lock the
thread (when there are no available sends/receives). In addition, it contains
a callback function to notify a thread when the QueueMessage is being
processed by the channel.
### Queues
- **buff_**: This queue holds the data buffer in a buffered channel. The
capacity is set to the capacity of the channel. This data buffer is not
used in an unbuffered channel.
- **sendq**: This queue holds the QueueMessage of any pending senders of a
channel. When a thread performs a channel_send operation on the channel, the
channel_send operation will put a new QueueMessage on the sendq and block the
current thread under two conditions:
1. The channel is buffered and is full
2. The channel is unbuffered and does not have a receiver
- **recvq**: This queue holds the QueueMessage of any pending receivers of a
channel. When a thread performs a channel_recv operation on the channel, the
channel_recv operation will put a new QueueMessage on the recvq and block the
current thread under two conditions:
1. The channel is buffered and there is no data on the buff_
2. The channel is unbuffered and does not have a sender
### State diagram
#### Channel Send
<p align="center">
<img src="./images/channel_send.png"/><br/>
</p>
#### Channel Receive
<p align="center">
<img src="./images/channel_recv.png"/><br/>
</p>
## Limitations and Considerations
### Variable Copy
In golang, variables in channels are copied from the sender to the receiver.
In Paddle, the data from our variables are **moved** from sender to receiver.
As a result, these variables should not be used after they are sent. We
provide a flag in channel_send method to allow users to copy the variable to
be sent before it is sent.
Please note that this is acheived by adding an **assign** operator and creating
a temporary variable that is sent in place of the original variable. Please
note that **assign** operator has limited support for only certain variables
datatypes.
......@@ -104,7 +104,7 @@ cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(channel_test SRCS channel_test.cc)
# cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
......
......@@ -147,15 +147,52 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
}
auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
std::vector<std::string> &v) {
auto in_names = (*op)->InputArgumentNames();
v.insert(v.end(), in_names.begin(), in_names.end());
auto out_names = (*op)->OutputArgumentNames();
v.insert(v.end(), out_names.begin(), out_names.end());
std::sort(v.begin(), v.end());
auto last = std::unique(v.begin(), v.end());
v.erase(last, v.end());
};
need_update_ = true;
for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
auto names = (*it)->InputArgumentNames();
for (auto n : names) {
// TODO(typhoonzero): delete vars if no other op use it.
VLOG(3) << "deleting var " << n;
for (size_t i = s; i < e; i++) {
// since remove op one by one, every time remove the first op.
auto op = ops_.begin() + s;
// collect input and output variables from current delete op
std::vector<std::string> cur_vars;
get_vars(op, cur_vars);
// remove current op
ops_.erase(ops_.begin() + s);
// collect input and output variables from other ops
std::vector<std::string> other_vars;
for (auto it = ops_.begin(); it != ops_.end(); it++) {
get_vars(it, other_vars);
}
// variables should be deleted
std::vector<std::string> delete_vars;
// delete_vars = cur_vars - cur_vars ^ other_input_vars
std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
other_vars.end(),
std::inserter(delete_vars, delete_vars.end()));
// remove variables
for (size_t i = 0; i < delete_vars.size(); i++) {
auto name = delete_vars[i];
auto it = vars_.find(name);
PADDLE_ENFORCE(it != vars_.end(),
"%s is not in variable list, it should not be deleted",
name);
vars_.erase(it);
VLOG(3) << "deleting variable " << name;
}
}
ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
std::vector<OpDesc *> BlockDesc::AllOps() const {
......
......@@ -89,6 +89,11 @@ class BlockDesc {
OpDesc *InsertOp(size_t index);
/*
* Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or
* output variable of other ops, we should remain it.
*/
void RemoveOp(size_t s, size_t e);
std::vector<OpDesc *> AllOps() const;
......
......@@ -260,6 +260,36 @@ $out = floor(x)$
}
};
class CosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Cosine operator");
AddOutput("Out", "Output of Cosine operator");
AddComment(R"DOC(
Cosine Activation Operator.
$out = cos(x)$
)DOC");
}
};
class SinOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sine operator");
AddOutput("Out", "Output of Sine operator");
AddComment(R"DOC(
Sine Activation Operator.
$out = sin(x)$
)DOC");
}
};
class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
......@@ -561,6 +591,12 @@ REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
REGISTER_OP(floor, ops::ActivationOp, ops::FloorOpMaker, floor_grad,
ops::ActivationOpGrad);
REGISTER_OP(cos, ops::ActivationOp, ops::CosOpMaker, cos_grad,
ops::ActivationOpGrad);
REGISTER_OP(sin, ops::ActivationOp, ops::SinOpMaker, sin_grad,
ops::ActivationOpGrad);
REGISTER_OP(round, ops::ActivationOp, ops::RoundOpMaker, round_grad,
ops::ActivationOpGrad);
......
......@@ -331,6 +331,54 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};
template <typename T>
struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = -dout * x.unaryExpr(Sine<T>());
}
};
// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosine<T>());
}
};
// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosine<T>());
}
};
// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sine<T>());
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
......@@ -782,6 +830,8 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, FloorFunctor, ZeroGradFunctor); \
__macro(cos, CosFunctor, CosGradFunctor); \
__macro(sin, SinFunctor, SinGradFunctor); \
__macro(round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
......
......@@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}
grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
......@@ -59,12 +59,12 @@ message VariableMessage {
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 7;
bytes serialized = 8;
// selected_rows data
bytes rows = 8;
bytes rows = 9;
}
message VoidMessage {}
message TestMessage { int64 test_1 = 1; }
......@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
......@@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
......@@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
......@@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10;
int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
......@@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
......@@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[i], i);
}
// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
......@@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
for (int i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
}
void RunTestLodTensor(platform::Place place, int from_type = 0) {
......
......@@ -147,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
tensor->numel(),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
......@@ -165,7 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......@@ -348,6 +354,14 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
......
......@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work.
std::vector<std::future<void>> fs;
double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(
......@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
}
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
......
......@@ -19,8 +19,17 @@ namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
class MaxSeqPoolFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
......@@ -60,7 +69,7 @@ class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
};
template <typename T>
class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
class MaxSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
......@@ -93,10 +102,101 @@ class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
}
};
template class MaxSeqPoolFunctor<platform::CPUDeviceContext, float>;
template class MaxSeqPoolFunctor<platform::CPUDeviceContext, double>;
template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, float>;
template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, double>;
template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public:
/* max pool has index output */
void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output,
framework::Tensor* index = nullptr) {
if (pooltype == "MAX") {
math::MaxSeqPoolFunctor<T> max_pool;
max_pool(context, input, output, index);
return;
}
auto lod = input.lod()[0];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t =
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
Tensor out_t = output->Slice(i, i + 1);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t w = input.numel() / input.dims()[0];
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t);
if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
out_e.device(place) = in_e.chip(0, 0);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
}
};
template <typename T>
class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, const framework::Tensor& out_grad,
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr) {
if (pooltype == "MAX") {
math::MaxSeqPoolGradFunctor<T> max_pool_grad;
max_pool_grad(context, out_grad, *index, in_grad);
return;
}
if (pooltype == "LAST" || pooltype == "FIRST") {
// set X@Grad be zero at first when pooltype is LAST/FIRST
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(context, in_grad, 0);
}
auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1]));
auto out_g_t = out_grad.Slice(i, i + 1);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t w = in_grad->numel() / in_grad->dims()[0];
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
in_g_e.device(place) = (out_g_e).broadcast(bcast);
} else if (pooltype == "SQRT") {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
}
};
template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
......@@ -22,113 +23,331 @@ namespace math {
#define FLT_MAX __FLT_MAX__
template <typename T>
__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
T* output, int* index, int64_t num_seq,
int64_t dim) {
int dim_idx = threadIdx.x;
int seq_id = blockIdx.x;
if (seq_id >= num_seq) return;
size_t start = starts[seq_id];
size_t end = starts[seq_id + 1];
for (int64_t i = dim_idx; i < dim; i += blockDim.x) {
T max_val = static_cast<T>(-FLT_MAX);
int max_id = -1;
for (size_t step_id = start; step_id < end; step_id++) {
if (max_val < input[step_id * dim + i]) {
max_val = input[step_id * dim + i];
max_id = step_id;
struct MaxPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T max_val = static_cast<T>(-FLT_MAX);
int max_index = -1;
for (int i = start; i < end; ++i) {
if (max_val < input[item_dim * i + tid]) {
max_val = input[item_dim * i + tid];
max_index = i;
}
}
output[tid] = max_val;
index[tid] = max_index;
}
output[seq_id * dim + i] = max_val;
index[seq_id * dim + i] = max_id;
}
}
};
template <typename T>
class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
auto idx_dims = index->dims();
PADDLE_ENFORCE_GT(in_dims.size(), static_cast<int64_t>(1));
PADDLE_ENFORCE_GT(out_dims.size(), 1);
for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
struct AvgPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
// end, start is lod, so end - start != 0
output[tid] = val / static_cast<T>(end - start);
}
PADDLE_ENFORCE_EQ(idx_dims, out_dims);
}
};
auto starts = input.lod()[0];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
int* max_index = index->data<int>();
template <typename T>
struct SumPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
output[tid] = val;
}
}
};
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
template <typename T>
struct SqrtPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
T val = static_cast<T>(0);
for (int i = start; i < end; ++i) {
val += input[item_dim * i + tid];
}
// end, start is lod, so end - start != 0
output[tid] = val / sqrt(end - start);
}
}
};
dim3 threads(256, 1);
dim3 grid(num_seq, 1);
auto stream = context.stream();
KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
num_seq, dim);
template <typename T>
struct LastPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
output[tid] = input[item_dim * (end - 1) + tid];
}
}
};
template <typename T>
__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
T* in_grad, int64_t num_seq,
int64_t dim) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx % dim;
if (idx < num_seq * dim) {
int step_id = max_index[idx];
in_grad[step_id * dim + col_idx] = out_grad[idx];
struct FirstPoolFunctor {
HOSTDEVICE void operator()(const T* input, const size_t start,
const size_t end, const size_t item_dim, T* output,
int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
output[tid] = input[item_dim * start + tid];
}
}
};
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
const size_t* lod, const size_t lod_size,
const size_t item_dim, T* output,
int* index) {
int bid = blockIdx.x;
if (bid >= lod_size - 1) return;
size_t start = lod[bid];
size_t end = lod[bid + 1];
int* index_offset = nullptr;
if (index != nullptr) {
index_offset = &index[bid * item_dim];
}
op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
}
template <typename T>
class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, T> {
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad) {
auto og_dims = out_grad.dims();
auto idx_dims = index.dims();
auto ig_dims = in_grad->dims();
PADDLE_ENFORCE_GT(og_dims.size(), static_cast<int64_t>(1));
PADDLE_ENFORCE_GT(ig_dims.size(), static_cast<int64_t>(1));
for (int64_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output,
framework::Tensor* index = nullptr) {
auto lod = input.lod()[0];
const size_t item_dim = output->numel() / output->dims()[0];
dim3 threads(1024, 1);
dim3 grid(lod.size(), 1);
if (pooltype == "MAX") {
sequence_pool_kernel<
T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
MaxPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_kernel<
T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
AvgPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SUM") {
sequence_pool_kernel<
T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SumPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_kernel<
T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SqrtPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "LAST") {
sequence_pool_kernel<
T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
LastPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_kernel<
T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
FirstPoolFunctor<T>(), input.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
output->mutable_data<T>(context.GetPlace()), nullptr);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
PADDLE_ENFORCE_EQ(idx_dims, og_dims);
}
};
const T* og_data = out_grad.data<T>();
const int* max_index = index.data<int>();
T* ig_data = in_grad->data<T>();
template <typename T>
struct MaxPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
if (i == index[tid]) {
in_grad[item_dim * i + tid] = out_grad[tid];
} else {
in_grad[item_dim * i + tid] = static_cast<T>(0);
}
}
}
}
};
SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(context, in_grad, static_cast<T>(0.0));
int64_t num_seq = og_dims[0];
int64_t dim = out_grad.numel() / num_seq;
template <typename T>
struct AvgPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
}
}
}
};
unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
dim3 threads(128, 1);
dim3 grid(blocks, 1);
auto stream = context.stream();
KeMaxSequencePoolGrad<T><<<grid, threads, 0, stream>>>(
og_data, max_index, ig_data, num_seq, dim);
template <typename T>
struct SumPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
in_grad[item_dim * i + tid] = out_grad[tid];
}
}
}
};
template <typename T>
struct SqrtPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
in_grad[item_dim * i + tid] =
out_grad[tid] / (sqrt(static_cast<T>(end - start)));
}
}
}
};
template <typename T>
struct LastPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
if (i == end - 1) {
in_grad[item_dim * i + tid] = out_grad[tid];
} else {
in_grad[item_dim * i + tid] = static_cast<T>(0);
}
}
}
}
};
template <typename T>
struct FirstPoolGradFunctor {
HOSTDEVICE void operator()(const T* out_grad, const size_t start,
const size_t end, const size_t item_dim,
T* in_grad, const int* index) {
for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
for (int i = start; i < end; ++i) {
if (i == start) {
in_grad[item_dim * i + tid] = out_grad[tid];
} else {
in_grad[item_dim * i + tid] = static_cast<T>(0);
}
}
}
}
};
template <typename T, typename Range_OP>
__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
const size_t* lod,
const size_t lod_size,
const size_t item_dim, T* in_grad,
const int* index) {
int bid = blockIdx.x;
if (bid >= lod_size - 1) return;
size_t start = lod[bid];
size_t end = lod[bid + 1];
const int* index_offset = nullptr;
if (index != nullptr) {
index_offset = &index[bid * item_dim];
}
op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
}
template <typename T>
class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, const framework::Tensor& out_grad,
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr) {
auto lod = in_grad->lod()[0];
const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
dim3 threads(1024, 1);
dim3 grid(lod.size(), 1);
if (pooltype == "MAX") {
sequence_pool_grad_kernel<
T, MaxPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
MaxPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
} else if (pooltype == "AVERAGE") {
sequence_pool_grad_kernel<
T, AvgPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
AvgPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SUM") {
sequence_pool_grad_kernel<
T, SumPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SumPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "SQRT") {
sequence_pool_grad_kernel<
T, SqrtPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
SqrtPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "LAST") {
sequence_pool_grad_kernel<
T, LastPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
LastPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else if (pooltype == "FIRST") {
sequence_pool_grad_kernel<
T, FirstPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
FirstPoolGradFunctor<T>(), out_grad.data<T>(),
lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
in_grad->mutable_data<T>(context.GetPlace()), nullptr);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
};
template class MaxSeqPoolFunctor<platform::CUDADeviceContext, float>;
template class MaxSeqPoolFunctor<platform::CUDADeviceContext, double>;
template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, float>;
template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, double>;
// sequence pooling
template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
......
......@@ -21,23 +21,23 @@ namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
template <typename DeviceContext, typename T>
class MaxSeqPoolFunctor {
class SequencePoolFunctor {
public:
void operator()(const DeviceContext& context,
/* max pool has index output */
void operator()(const DeviceContext& context, const std::string pooltype,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index);
framework::Tensor* index = nullptr);
};
template <typename DeviceContext, class T>
class MaxSeqPoolGradFunctor {
template <typename DeviceContext, typename T>
class SequencePoolGradFunctor {
public:
void operator()(const DeviceContext& context,
void operator()(const DeviceContext& context, const std::string pooltype,
const framework::Tensor& out_grad,
const framework::Tensor& index,
framework::LoDTensor* in_grad);
framework::LoDTensor* in_grad,
/* max pool has index */
const framework::Tensor* index = nullptr);
};
} // namespace math
......
......@@ -81,10 +81,10 @@ class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
It takes the the number of pass to run as one of its attributes
It takes the number of passes to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
passes it has completed. When the underlying reader reach the EOF,
the multi-pass reader checks whether it has completed training
passes it has completed. When the underlying reader reaches the
EOF, the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
......
......@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
......@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
VLOG(2) << "batch barrier, ep: " << ep;
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
......
......@@ -23,12 +23,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class SequencePoolKernel : public framework::OpKernel<T> {
......@@ -37,11 +31,13 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype");
Tensor* index = nullptr;
if (pooltype == "MAX") {
index = context.Output<Tensor>("MaxIndex");
}
auto dims = in->dims();
auto lod = in->lod();
int64_t w = in->numel() / dims[0];
// InferShape by lod
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(
......@@ -50,45 +46,14 @@ class SequencePoolKernel : public framework::OpKernel<T> {
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
out->Resize({dims});
auto lod_level_0 = lod[0];
out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
if (pooltype == "MAX") {
math::MaxSeqPoolFunctor<DeviceContext, T> max_pool;
auto* index = context.Output<Tensor>("MaxIndex");
index->Resize({dims});
index->mutable_data<int>(context.GetPlace());
max_pool(dev_ctx, *in, out, index);
return;
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = out->Slice(i, i + 1);
int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]);
auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
auto out_e = EigenVector<T>::Flatten(out_t);
if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h));
} else if (pooltype == "LAST") {
out_e.device(place) = in_e.chip(h - 1, 0);
} else if (pooltype == "FIRST") {
out_e.device(place) = in_e.chip(0, 0);
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
}
math::SequencePoolFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
index);
}
};
......@@ -96,58 +61,17 @@ template <typename DeviceContext, typename T>
class SequencePoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
std::string pooltype = context.Attr<std::string>("pooltype");
auto dims = in->dims();
auto lod = in->lod()[0];
int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
const Tensor* index = nullptr;
if (pooltype == "MAX") {
math::MaxSeqPoolGradFunctor<DeviceContext, T> max_pool_grad;
auto* index = context.Input<Tensor>("MaxIndex");
max_pool_grad(dev_ctx, *out_g, *index, in_g);
return;
}
if (pooltype == "LAST" || pooltype == "FIRST") {
// set X@Grad be zero at first when pooltype is LAST/FIRST
math::SetConstant<DeviceContext, T> functor;
functor(dev_ctx, in_g, 0);
}
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t =
in_g->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
auto out_g_t = out_g->Slice(i, i + 1);
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
in_g_e.device(place) = (out_g_e).broadcast(bcast);
} else if (pooltype == "SQRT") {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
index = context.Input<Tensor>("MaxIndex");
}
in_g->mutable_data<T>(context.GetPlace());
math::SequencePoolGradFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(), pooltype, *out_g,
in_g, index);
}
};
......
......@@ -25,6 +25,8 @@ __activations__ = [
'abs',
'ceil',
'floor',
'cos',
'sin',
'round',
'reciprocal',
'log',
......
......@@ -196,6 +196,34 @@ class TestFloor(OpTest):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestCos(OpTest):
def setUp(self):
self.op_type = "cos"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.cos(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSin(OpTest):
def setUp(self):
self.op_type = "sin"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Out': np.sin(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestRound(OpTest):
def setUp(self):
self.op_type = "round"
......
......@@ -186,6 +186,34 @@ class TestBlockDesc(unittest.TestCase):
all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op0, op1, op2])
def test_remove_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op1 = block.append_op()
op2 = block.append_op()
var1 = block.var("var1")
var2 = block.var("var2")
var3 = block.var("var3")
var4 = block.var("var4")
var5 = block.var("var5")
op1.set_input("X", ["var1", "var2"])
op1.set_output("Y", ["var3", "var4"])
op2.set_input("X", ["var1"])
op2.set_output("Y", ["var4", "var5"])
# remove op1, its input var2 and output var3 will be removed at the same time,
# but its input var1 and output var4 will not be removed since they are used for op2.
block.remove_op(0, 1)
all_ops = []
for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op2])
all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var4, var5})
if __name__ == '__main__':
unittest.main()
......@@ -49,6 +49,61 @@ class TestSeqAvgPool(OpTest):
self.check_grad(["X"], "Out")
class TestSeqSumPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SUM"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
class TestSeqMaxPool(TestSeqAvgPool):
def set_data(self):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
lod = [[0, 4, 5, 8, 13]]
for i in range(4):
l = lod[0][i + 1] - lod[0][i]
x[lod[0][i] + np.random.randint(l), :] += 2.0
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 23)).astype('float32')
self.outputs = {'Out': out}
return x, lod, out
def compute(self, x, lod, out):
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = np.amax(sub_x, axis=0)
class TestSeqSqrtPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SQRT"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
len = lod[0][i + 1] - lod[0][i]
out[i] = sub_x.sum(axis=0) / np.sqrt(len)
class TestSeqLastPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "LAST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[-1, :]
class TestSeqFirstPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "FIRST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[0, :]
class TestSeqAvgPool2D(TestSeqAvgPool):
def set_data(self):
self.op_type = 'sequence_pool'
......@@ -68,14 +123,6 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
class TestSeqSumPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SUM"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x.sum(axis=0)
class TestSeqSumPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SUM"}
......@@ -84,15 +131,6 @@ class TestSeqSumPool2D(TestSeqAvgPool2D):
out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
class TestSeqSqrtPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SQRT"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
len = lod[0][i + 1] - lod[0][i]
out[i] = sub_x.sum(axis=0) / np.sqrt(len)
class TestSeqSqrtPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "SQRT"}
......@@ -108,28 +146,6 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
self.check_grad(["X"], "Out", max_relative_error=0.06)
class TestSeqMaxPool(TestSeqAvgPool):
def set_data(self):
self.op_type = 'sequence_pool'
x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
lod = [[0, 4, 5, 8, 13]]
for i in range(4):
l = lod[0][i + 1] - lod[0][i]
x[lod[0][i] + np.random.randint(l), :] += 2.0
self.inputs = {'X': (x, lod)}
out = np.zeros((4, 23)).astype('float32')
self.outputs = {'Out': out}
return x, lod, out
def compute(self, x, lod, out):
self.attrs = {'pooltype': "MAX"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = np.amax(sub_x, axis=0)
class TestSeqMaxPool2D(TestSeqAvgPool2D):
def set_data(self):
self.op_type = 'sequence_pool'
......@@ -151,14 +167,6 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
class TestSeqLastPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "LAST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[-1, :]
class TestSeqLastPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "LAST"}
......@@ -167,14 +175,6 @@ class TestSeqLastPool2D(TestSeqAvgPool2D):
out[i] = np.reshape(sub_x[-1, :], (3, 17))
class TestSeqFirstPool(TestSeqAvgPool):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "FIRST"}
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
out[i] = sub_x[0, :]
class TestSeqFirstPool2D(TestSeqAvgPool2D):
def compute(self, x, lod, out):
self.attrs = {'pooltype': "FIRST"}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册