提交 91b6d600 编写于 作者: F fengjiayi

Merge branch 'fix_bug_in_recordio' into dev_MultiEpochReader

......@@ -26,7 +26,7 @@ lookup of rows.
The following figure illustrates the multiplication of x with two
non-zero elements, or say, two symbols, and a lookup table W:
![lookup table](./lookup_table.png)
![lookup table](./src/lookup_table.png)
### The Backward Algorithm
......@@ -42,7 +42,7 @@ or some more sophisticated algorithms that rely on both W' and W:
$$W = f(W, W')$$
The following figure illustrates the backward pass of the lookup
operator: ![lookup table training](./lookup_table_training.png)
operator: ![lookup table training](./src/lookup_table_training.png)
## Distributed Storage Service
......
......@@ -103,7 +103,7 @@ In computability theory, a system of data-manipulation rules, such as a programm
There are two ways to execute a Fluid program. When a program is executed, it creates a protobuf message [`ProgramDesc`](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/paddle/framework/framework.proto#L145) that describes the process and is conceptually like an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree).
There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
Fluid is moving towards the direction of a compiler, which is explain in [fluid_compiler.md](fluid_compiler.md).
......
......@@ -47,3 +47,10 @@ DecayedAdagrad
:members:
:noindex:
Adadelta
--------------
.. autoclass:: paddle.fluid.optimizer.AdadeltaOptimizer
:members:
:noindex:
......@@ -871,3 +871,67 @@ TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
ch->Reset<int>(0);
ChannelHolderDestroyUnblockSenders(ch, false);
}
// This tests that closing a channelholder many times.
void ChannelHolderManyTimesClose(ChannelHolder *ch) {
const int num_threads = 15;
std::thread t[num_threads];
bool thread_ended[num_threads];
// Launches threads that try to send data to channel.
for (size_t i = 0; i < num_threads / 3; i++) {
thread_ended[i] = false;
t[i] = std::thread(
[&](bool *ended) {
int data = 10;
ch->Send(&data);
*ended = true;
},
&thread_ended[i]);
}
// Launches threads that try to receive data to channel.
for (size_t i = num_threads / 3; i < 2 * num_threads / 3; i++) {
thread_ended[i] = false;
t[i] = std::thread(
[&](bool *p) {
int data;
if (ch->Receive(&data)) {
EXPECT_EQ(data, 10);
}
*p = true;
},
&thread_ended[i]);
}
// Launches threads that try to close the channel.
for (size_t i = 2 * num_threads / 3; i < num_threads; i++) {
thread_ended[i] = false;
t[i] = std::thread(
[&](bool *p) {
if (!ch->IsClosed()) {
ch->close();
}
*p = true;
},
&thread_ended[i]);
}
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
// Verify that all threads are unblocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true);
}
EXPECT_TRUE(ch->IsClosed());
// delete the channel
delete ch;
for (size_t i = 0; i < num_threads; i++) t[i].join();
}
TEST(ChannelHolder, ChannelHolderManyTimesCloseTest) {
// Check for Buffered Channel
ChannelHolder *ch = new ChannelHolder();
ch->Reset<int>(10);
ChannelHolderManyTimesClose(ch);
}
......@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
};
template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>);
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -18,17 +18,18 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, typename AttrType>
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
const AttrType dropout_prob, const T* src,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<AttrType> dist(0, 1);
thrust::uniform_real_distribution<float> dist(0, 1);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
......@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T, typename AttrType>
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
float dropout_prob = context.Attr<float>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
......@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<T, AttrType><<<grid, threads, 0,
context.cuda_device_context().stream()>>>(
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
Y.device(place) = X * (1.0f - dropout_prob);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
};
......@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout,
ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
REGISTER_OP_CUDA_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T, typename AttrType>
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LoDResetOp should not be null.");
// If target LoD is not set form Input(), then it must be set from Attr().
if (!ctx->HasInput("TargetLoD")) {
if (!ctx->HasInput("Y")) {
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
PADDLE_ENFORCE(level0.size() > 1,
"Target LoD is not found, should be set to be a valid one "
"through Input() or Attr().");
PADDLE_ENFORCE_GT(level0.size(), 1,
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
......@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
AddInput("TargetLoD",
"(Tensor, optional) The target level 0 LoD from Input().")
AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output "
"variable inherits from.");
AddInput("Y",
"(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
"lod of Input(Y) would be considered as the target lod first, "
"otherwise data of Input(Y) would be considered as the "
"target lod.")
.AsDispensable();
AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
AddOutput("Out",
"(LoDTensor) Output variable of LoDResetOp which should be a "
"LoDTensor.");
AddAttr<std::vector<int>>("target_lod",
"The target level 0 LoD from Attr().")
.SetDefault(std::vector<int>{});
AddComment(R"DOC(LoDReset operator
Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
Currently the lod_reset operator only supports the reset of level 0 LoD.
At least one of Input(TargetLoD) and Attr(target_lod) must be set,
and if both of them are set, Input(TargetLoD) will be chosen as the
target LoD.
Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
provided, target LoD should be specified by attribute `target_lod`.
If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
is supported.
Example 1:
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
attr(target_lod): [0, 4, 6]
then we get a 1-level LoDTensor:
Out.lod = [[ 0, 4, 6 ]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
Example 2:
An example:
Given a float LoDTensor X with shape (6, 1), its transpose form represents
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
input(Y) is a Tensor:
Y.data = [[0, 2, 6]]
Y.dims = [1, 3]
with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
then we get a 1-level LoDTensor:
Out.lod = [[ 0, 2, 6 ]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
Example 3:
If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
the sequences that the LoDTensor Output(Out) contains becomes:
Given a 1-level LoDTensor input(X):
X.lod = [[ 0, 2, 5 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
input(Y) is a 2-level LoDTensor:
Y.lod = [[0, 2, 4], [0, 2, 5, 6]]
Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
Y.dims = [6, 1]
then we get a 2-level LoDTensor:
Out.lod = [[0, 2, 4], [0, 2, 5, 6]]
Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
Out.dims = [6, 1]
)DOC");
}
......@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
"Input(Out@Grad) of LoDResetGradOp should not be null.");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
}
protected:
......@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
ops::LoDResetGradOp);
REGISTER_OP_CPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OP_CPU_KERNEL(
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -18,8 +18,12 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lod_reset, ops::LoDResetKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>);
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoDResetKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoDResetGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -26,35 +26,46 @@ class LoDResetKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("X");
auto* lod_t = ctx.Input<framework::Tensor>("TargetLoD");
auto* lod_t = ctx.Input<framework::LoDTensor>("Y");
out->ShareDataWith(*in);
std::vector<int> level0;
if (lod_t) {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
framework::TensorCopy(*lod_t, platform::CPUPlace(),
ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>();
if (lod_t->lod().size() > 0) {
auto y_lod = lod_t->lod();
auto last_level = y_lod[y_lod.size() - 1];
PADDLE_ENFORCE_EQ(last_level.back(), in->dims()[0],
"Last value of `Y`'s last level LoD should be equal "
"to the first dimension of `X`");
out->set_lod(y_lod);
return; // early return, since lod already set
} else {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
framework::TensorCopy(*lod_t, platform::CPUPlace(),
ctx.device_context(), &lod_cpu);
lod = lod_cpu.data<int>();
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
} else {
level0 = ctx.Attr<std::vector<int>>("target_lod");
}
PADDLE_ENFORCE(level0.size() > 1UL,
"The size of target LoD should be greater than 1.");
PADDLE_ENFORCE(level0[0] == 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE(level0.back() == in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
PADDLE_ENFORCE_GT(level0.size(), 1UL,
"Size of target LoD should be greater than 1.");
PADDLE_ENFORCE_EQ(level0[0], 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE_EQ(level0.back(), in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
for (size_t i = 0; i < level0.size() - 1; ++i) {
PADDLE_ENFORCE(level0[i + 1] > level0[i],
"Target LoD should be an ascending vector.");
}
out->ShareDataWith(*in);
// cast level0 to size_t
std::vector<size_t> ulevel0(level0.size(), 0);
std::transform(level0.begin(), level0.end(), ulevel0.begin(),
......
......@@ -43,7 +43,7 @@ math_library(sequence2batch)
math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax)
math_library(softmax DEPS math_function)
math_library(unpooling)
math_library(vol2col)
......
......@@ -44,7 +44,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < out_rows; ++k) {
......@@ -87,7 +87,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
for (int k = 0; k < input_rows; ++k) {
......
......@@ -371,6 +371,8 @@ template struct RowwiseAdd<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, float>;
template struct ColwiseSum<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, int>;
template struct ColwiseSum<platform::CPUDeviceContext, int64_t>;
template struct RowwiseSum<platform::CPUDeviceContext, float>;
template struct RowwiseSum<platform::CPUDeviceContext, double>;
......
......@@ -422,6 +422,8 @@ struct RowwiseAdd<platform::CUDADeviceContext, T> {
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The ColwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
......
......@@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach();
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { buffer_->Close(); }
~DoubleBufferReader() {
buffer_->Close();
prefetcher_.join();
delete buffer_;
}
bool HasNext() const override;
private:
void PrefetchThreadFunc();
std::thread prefetcher_;
framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
......@@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
prefetcher_.join();
delete buffer_;
start_thread();
}
......@@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() {
if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates.";
"prefetch thread will terminate.";
break;
}
}
buffer_->Close();
VLOG(5) << "Prefetch thread terminates.";
}
bool DoubleBufferReader::HasNext() const {
......
......@@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader {
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
......@@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader {
buffer_.clear();
buffer_.reserve(buffer_size_);
iteration_pos_ = 0;
PADDLE_ENFORCE(reader_->HasNext());
for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) {
break;
......
......@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class SequenceExpandOp : public framework::OperatorWithKernel {
public:
......@@ -25,15 +25,71 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(ctx->HasInput("Y"));
framework::DDim out_dim;
auto y_dim = ctx->GetInputDim("Y");
out_dim = ctx->GetInputDim("X");
out_dim[0] = y_dim[0];
ctx->ShareLoD("Y", "Out");
ctx->SetOutputDim("Out", out_dim);
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceExpandOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto out_dims = x_dims;
int ref_level = ctx->Attrs().Get<int>("ref_level");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Dimension number of Input(X) should be at least 2.");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
framework::Variable* y_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
auto& x_lod = x_var->Get<LoDTensor>().lod();
auto& y_lod = y_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_LE(x_lod.size(), 1,
"Level number of Input(X)'s lod should not be "
"greater than 1.");
PADDLE_ENFORCE_GT(y_lod.size(), 0,
"Level number of Input(Y)'s lod should be "
"greater than 0.");
PADDLE_ENFORCE(
ref_level == -1 ||
(ref_level >= 0 && ref_level < static_cast<int>(y_lod.size())),
"Invlid `ref_level`, which should be either equal to -1 "
"or in [0, %d)",
y_lod.size());
if (ref_level == -1) ref_level = y_lod.size() - 1;
if (x_lod.size() > 0) {
PADDLE_ENFORCE(x_lod[0].size() == y_lod[ref_level].size(),
"Level number of Input(X)'s lod could be 0. Otherwise "
"size of Input(X)'s first level lod should be equal to "
"size of Input(Y)'s referred level lod.");
}
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() <= 1) {
out_first_dim = x_dims[0];
} else {
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
}
out_dims[0] = out_first_dim;
ctx->SetOutputDim("Out", out_dims);
} else {
out_dims[0] = -1;
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
......@@ -42,83 +98,81 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a "
"LoDTensor or a base Tensor.");
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1.");
AddInput("Y",
"(LoDTensor)The reference input(Y) of sequence_expand op."
"It must be a LoDTensor with k-level(k>0)."
"The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input(Y) "
"must be equal to dims[0] of input(X).");
"(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
"lod (specified level) is referred by Input(X).");
AddOutput("Out",
"(LodTensor)The output of sequence_expand op."
"The lod of output will be as same as input(Y)'s lod.");
"(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
"generated from Input(X) by referring lod of Input(Y).");
AddAttr<int>("ref_level", "Specify lod level of Input(Y).").SetDefault(-1);
AddComment(R"DOC(
Sequence Expand Operator.
This operator expands input(X) according to LOD of input(Y).
This operator expands `X` according to specified level lod of `Y`. Current
implementation constaints that lod level of `X` should be at most 1. Attribute
`ref_level` is used to specify which level lod of `Y` is referred to expand `X`.
If set `ref_level` to -1, then last level lod of `Y` would be referred.
Please note, rank of `X` should be at least 2, when the rank exceeds 2, `X`
would be viewed as a 2-D tensor.
Following are cases to better explain how this works:
Case 1:
Given a 2-level LoDTensor input(X)
X.lod = [[0, 2, 3],
[0, 1, 3, 4]]
X.data = [a, b, c, d]
Given a 1-level LoDTensor input(X)
X.lod = [[0, 2, 4]]
X.data = [[a], [b], [c], [d]]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
Out.data = [a, a, a, b, b, b, c, d]
ref_level: 0
then we get 1-level LoDTensor
Out.lod = [[0, 2, 4, 6, 8]]
Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
Out.dims = [8, 1]
Case 2:
Given 1-level LoDTensor input(X)
X.lod = [[0, 1, 4]]
X.data = [[a], [b], [c], [d]]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
ref_level: 0
then we get 1-level LoDTensor
Out.lod = [[0, 1, 2, 5, 8]]
Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
Out.dims = [8, 1]
Case 3:
Given a common Tensor input(X)
X.data = [a, b, c]
X.data = [[a], [b], [c]]
X.dims = [3, 1]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [a, a, b, c, c, c]
ref_level: -1
then we get a common Tensor
Out.data = [[a], [a], [b], [c], [c], [c]]
Out.dims = [6, 1]
Case 3:
Case 4:
Given a common Tensor input(X)
X.data = [[a, b], [c, d], [e, f]]
X.dims = [3, 2]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
ref_level: 0
then we get a common LoDTensor
Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]]
Out.dims = [6, 2]
Case 4:
Given 2-level a LoDTensor input(X)
X.lod = [[0, 2, 3],
[0, 1, 3, 4]]
X.data = [a, b, c, d]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
Out.data = [a, a, a, b, b, b, d, d]
Out.dims = [8, 1]
)DOC");
}
};
......@@ -129,12 +183,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Out"));
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"The input(Out@GRAD) should not be null");
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
......@@ -149,7 +205,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
sequence_expand_grad, ops::SequenceExpandOpGrad);
REGISTER_OP_CPU_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -18,7 +18,14 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -16,45 +16,75 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
const T* x_data = x->data<T>();
auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
y->lod().back().size() - 1,
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X).");
out->set_lod(y->lod());
auto* place =
context.template device_context<DeviceContext>().eigen_device();
size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
for (size_t i = 0; i < out_starts.size() - 1; i++) {
int scale = out_starts[i + 1] - out_starts[i];
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
x_t(x_data, 1, element_len);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
out_t(out_data, scale, element_len);
Eigen::array<int, 2> cast({{scale, 1}});
out_t.device(*place) = x_t.broadcast(cast);
x_data += element_len;
out_data += element_len * scale;
auto* out = context.Output<LoDTensor>("Out");
int ref_level = context.Attr<int>("ref_level");
auto& x_lod = x->lod();
auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
out->mutable_data<T>(context.GetPlace());
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*x, context.GetPlace(), out);
return;
}
auto& out_lod = *out->mutable_lod();
if (x_lod.size() == 1) {
out_lod.resize(1);
out_lod[0] = {0};
}
int out_offset = 0;
auto& eigen_place =
*context.template device_context<DeviceContext>().eigen_device();
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
auto x_sub_tensor = x->Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset;
if (x_lod.size() == 1) {
out_start = out_lod[0][out_offset];
}
auto out_sub_tensor =
out->Slice(out_start, out_start + x_seq_len * repeat_num);
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
EigenMatrix<T>::From(x_sub_tensor)
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
}
for (int j = 0; j < repeat_num; ++j) {
if (x_lod.size() == 1) {
out_lod[0].push_back(out_lod[0].back() + x_seq_len);
}
out_offset++;
}
}
}
};
......@@ -75,27 +105,51 @@ template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Input<LoDTensor>("Out");
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
auto out_last_level = out->lod().back();
d_x->set_lod(x->lod());
const T* d_out_data = d_out->data<T>();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = d_out->numel() / d_out->dims()[0];
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
size_t repeat = out_last_level[i + 1] - out_last_level[i];
Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
d_out_t(d_out_data, static_cast<int>(repeat), element_len);
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
d_x_t(d_x_data, static_cast<int>(element_len));
auto place =
context.template device_context<DeviceContext>().eigen_device();
d_x_t.device(*place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
d_out_data += (repeat * element_len);
d_x_data += element_len;
auto* y = context.Input<LoDTensor>("Y");
auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
int ref_level = context.Attr<int>("ref_level");
g_x->mutable_data<T>(context.GetPlace());
g_x->set_lod(x->lod());
auto& x_lod = x->lod();
auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
// just copy the gradient
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*g_out, context.GetPlace(), g_x);
return;
}
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, g_x, static_cast<T>(0));
int g_out_offset = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
if (repeat_num > 0) {
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
auto g_x_sub = g_x->Slice(x_start, x_end);
g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
int g_out_end = g_out_offset + repeat_num * x_seq_len;
auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
math::ColwiseSum<DeviceContext, T> col_sum;
col_sum(dev_ctx, g_out_sub, &g_x_sub);
g_out_offset += repeat_num * x_seq_len;
}
}
}
};
......
......@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16
// Arithmetic operators on ARMv8.2-A CPU
#if defined(PADDLE_WITH_NATIVE_FP16)
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hadd(half(a), half(b)));
#else
return float16(float(a) + float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hsub(half(a), half(b)));
#else
return float16(float(a) - float(b));
#endif
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hmul(half(a), half(b)));
#else
return float16(float(a) * float(b));
#endif
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float num = __half2float(half(a));
float denom = __half2float(half(b));
return float16(num / denom);
#else
return float16(float(a) / float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hneg(half(a)));
#else
float16 res;
res.x = a.x ^ 0x8000;
return res;
#endif
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __heq(half(a), half(b));
#else
return float(a) == float(b);
#endif
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hne(half(a), half(b));
#else
return float(a) != float(b);
#endif
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(half(a), half(b));
#else
return float(a) < float(b);
#endif
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hle(half(a), half(b));
#else
return float(a) <= float(b);
#endif
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(half(a), half(b));
#else
return float(a) > float(b);
#endif
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hge(half(a), half(b));
#else
return float(a) >= float(b);
#endif
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
......@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
// Arithmetic operators, software emulated on other CPU
// Arithmetic operators for float16, software emulated on other CPU
#else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
HOST inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b));
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
HOST inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
HOST inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
HOST inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}
HOSTDEVICE inline float16 operator-(const float16& a) {
HOST inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
HOST inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
HOST inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
HOST inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
HOST inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
HOST inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
HOST inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
HOST inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
HOST inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
HOST inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
HOST inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
......
......@@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
bool Header::Parse(std::istream& is) {
uint32_t magic;
size_t read_size =
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
size_t read_size = is.gcount();
if (read_size < sizeof(uint32_t)) {
return false;
}
......
......@@ -131,7 +131,7 @@ def make_channel(dtype, capacity=0):
return channel
def channel_send(channel, value):
def channel_send(channel, value, copy=False):
"""
Sends a value through a channel variable. Used by an unbuffered or buffered
channel to pass data from within or to a concurrent Go block, where
......@@ -141,6 +141,8 @@ def channel_send(channel, value):
channel (Variable|Channel): Channel variable created using
`make_channel`.
value (Variable): Value to send to channel
copy (bool): Copy data while channel send. If False, then data
is moved. The input cannot be used after move.
Returns:
Variable: The boolean status on whether or not the channel
successfully sent the passed value.
......@@ -162,11 +164,26 @@ def channel_send(channel, value):
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=core.VarDesc.VarType.BOOL)
X = value
if copy is True:
copied_X = helper.create_variable(
name=unique_name.generate(value.name + '_copy'),
type=value.type,
dtype=value.dtype,
shape=value.shape,
lod_level=value.lod_level,
capacity=value.capacity)
assign_op = channel_send_block.append_op(
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X
channel_send_op = channel_send_block.append_op(
type="channel_send",
inputs={
"Channel": channel,
"X": value,
"X": X,
},
outputs={"Status": status})
......
......@@ -73,6 +73,7 @@ __all__ = [
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'lod_reset',
]
......@@ -1808,52 +1809,52 @@ def conv2d_transpose(input,
return out
def sequence_expand(x, y, name=None):
def sequence_expand(x, y, ref_level=-1, name=None):
"""Sequence Expand Layer. This layer will expand the input variable **x**
according to LoD information of **y**. And the following examples will
explain how sequence_expand works:
according to specified level lod of **y**. Please note that lod level of
**x** is at most 1 and rank of **x** is at least 2. When rank of **x**
is greater than 2, then it would be viewed as a 2-D tensor.
Following examples will explain how sequence_expand works:
.. code-block:: text
* Case 1
x is a LoDTensor:
x.lod = [[0, 2, 3],
[0, 1, 3, 4]]
x.data = [a, b, c, d]
x.lod = [[0, 2, 4]]
x.data = [[a], [b], [c], [d]]
x.dims = [4, 1]
y is a LoDTensor:
y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
with condition len(y.lod[-1]) - 1 == x.dims[0]
ref_level: 0
then output is a 2-level LoDTensor:
out.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
out.data = [a, a, a, b, b, b, c, d]
then output is a 1-level LoDTensor:
out.lod = [[0, 2, 4, 6, 8]]
out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
out.dims = [8, 1]
* Case 2
x is a Tensor:
x.data = [a, b, c]
x.data = [[a], [b], [c]]
x.dims = [3, 1]
y is a LoDTensor:
y.lod = [[0, 2, 3, 6]]
with condition len(y.lod[-1]) - 1 == x.dims[0]
y.lod = [[0, 2, 2, 5]]
then output is a 1-level LoDTensor:
out.lod = [[0, 2, 3, 6]]
out.data = [a, a, b, c, c, c]
out.dims = [6, 1]
ref_level: -1
then output is a Tensor:
out.data = [[a], [a], [c], [c], [c]]
out.dims = [5, 1]
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a LoDTensor.
ref_level (int): Lod level of `y` to be referred by `x`. If set to -1,
refer the last level of lod.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
Returns:
Variable: The expanded variable which is a LoDTensor.
......@@ -1864,14 +1865,17 @@ def sequence_expand(x, y, name=None):
x = fluid.layers.data(name='x', shape=[10], dtype='float32')
y = fluid.layers.data(name='y', shape=[10, 20],
dtype='float32', lod_level=1)
out = layers.sequence_expand(x=x, y=y)
out = layers.sequence_expand(x=x, y=y, ref_level=0)
"""
helper = LayerHelper('sequence_expand', input=x, **locals())
dtype = helper.input_dtype()
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type='sequence_expand', inputs={'X': x,
'Y': y}, outputs={'Out': tmp})
type='sequence_expand',
inputs={'X': x,
'Y': y},
outputs={'Out': tmp},
attrs={'ref_level': ref_level})
return tmp
......@@ -2225,7 +2229,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
keep_dim (bool|False): Whether to reserve the reduced dimension in the
output Tensor. The result tensor will have one fewer dimension
than the :attr:`input` unless :attr:`keep_dim` is true.
name(str|None): A name for this layer(optional). If set None, the
name(str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
Returns:
......@@ -2241,7 +2245,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_prod(x) # [0.0002268]
fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63]
fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084]
fluid.layers.reduce_prod(x, dim=1,
fluid.layers.reduce_prod(x, dim=1,
keep_dim=True) # [[0.027], [0.0084]]
"""
helper = LayerHelper('reduce_prod', **locals())
......@@ -3292,3 +3296,98 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
counter.stop_gradient = True
return counter
def lod_reset(x, y=None, target_lod=None):
"""
LoD Reset Operator. Set LoD of **x** to a new one specified by **y** or
**target_lod**. When **y** provided, **y.lod** would be considered as target
LoD first, otherwise **y.data** would be considered as target LoD. If **y**
is not provided, target LoD should be specified by **target_lod**.
If target LoD is specified by **Y.data** or **target_lod**, only one level
LoD is supported.
.. code-block:: text
* Example 1:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
target_lod: [0, 4, 6]
then we get a 1-level LoDTensor:
out.lod = [[ 0, 4, 6 ]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
* Example 2:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
y is a Tensor:
y.data = [[0, 2, 6]]
y.dims = [1, 3]
then we get a 1-level LoDTensor:
out.lod = [[ 0, 2, 6 ]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
* Example 3:
Given a 1-level LoDTensor x:
x.lod = [[ 0, 2, 5 6 ]]
x.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
x.dims = [6, 1]
y is a 2-level LoDTensor:
y.lod = [[0, 2, 4], [0, 2, 5, 6]]
y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
y.dims = [6, 1]
then we get a 2-level LoDTensor:
out.lod = [[0, 2, 4], [0, 2, 5, 6]]
out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out.dims = [6, 1]
Args:
x (Variable): Input variable which could be a Tensor or LodTensor.
y (Variable|None): If provided, output's LoD would be derived from y.
target_lod (list|tuple|None): One level LoD which should be considered
as target LoD when y not provided.
Returns:
Variable: Output variable with LoD specified by this operator.
Raises:
ValueError: If y and target_lod are both None.
Examples:
.. code-block:: python
x = layers.data(name='x', shape=[10])
y = layers.data(name='y', shape=[10, 20], lod_level=2)
out = layers.lod_reset(x=x, y=y)
"""
helper = LayerHelper("lod_reset", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
if y is not None:
helper.append_op(
type="lod_reset", inputs={'X': x,
'Y': y}, outputs={'Out': out})
elif target_lod is not None:
helper.append_op(
type="lod_reset",
inputs={'X': x},
attrs={'target_lod': target_lod},
outputs={'Out': out})
else:
raise ValueError("y and target_lod should not be both None.")
return out
......@@ -24,7 +24,9 @@ from layer_helper import LayerHelper
from regularizer import append_regularization_ops
from clip import append_gradient_clip_ops, error_clip_callback
__all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad']
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Adadelta'
]
class Optimizer(object):
......@@ -580,6 +582,88 @@ class DecayedAdagradOptimizer(Optimizer):
return decayed_adagrad_op
class AdadeltaOptimizer(Optimizer):
"""
**Adadelta Optimizer**
Simple Adadelta optimizer with average squared grad state and
average squared update state.
The details of adadelta please refer to this
`ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
<http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf>`_.
.. math::
E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\
learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\
E(g_t^2) + \\epsilon ) ) \\\\
E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2
Args:
learning_rate(float): global leraning rate
rho(float): rho in equation
epsilon(float): epsilon in equation
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
"""
_avg_squared_grad_acc_str = "_avg_squared_grad"
_avg_squared_update_acc_str = "_avg_squared_update"
def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs):
if learning_rate is None:
raise ValueError("learning_rate is not set.")
if epsilon is None:
raise ValueError("epsilon is not set.")
if rho is None:
raise ValueError("rho is not set.")
super(AdadeltaOptimizer, self).__init__(
learning_rate=learning_rate, **kwargs)
self.type = "adadelta"
self._epsilon = epsilon
self._rho = rho
def _create_accumulators(self, block, parameters):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
for p in parameters:
self._add_accumulator(self._avg_squared_grad_acc_str, p)
self._add_accumulator(self._avg_squared_update_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
avg_squared_grad_acc = self._get_accumulator(
self._avg_squared_grad_acc_str, param_and_grad[0])
avg_squared_update_acc = self._get_accumulator(
self._avg_squared_update_acc_str, param_and_grad[0])
# Create the adadelta optimizer op
adadelta_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"AvgSquaredGrad": avg_squared_grad_acc,
"AvgSquaredUpdate": avg_squared_update_acc
},
outputs={
"ParamOut": param_and_grad[0],
"AvgSquaredGradOut": avg_squared_grad_acc,
"AvgSquaredUpdateOut": avg_squared_update_acc
},
attrs={"epsilon": self._epsilon,
"rho": self._rho})
return adadelta_op
# We short the class name, since users will use the optimizer with the package
# name. The sample code:
#
......@@ -594,3 +678,4 @@ Adagrad = AdagradOptimizer
Adam = AdamOptimizer
Adamax = AdamaxOptimizer
DecayedAdagrad = DecayedAdagradOptimizer
Adadelta = AdadeltaOptimizer
......@@ -118,12 +118,12 @@ def decoder_decode(context, is_sparse):
is_sparse=is_sparse)
# use rnn unit to update rnn
current_state = pd.fc(input=[pre_ids_emb, pre_state_expanded],
current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb],
size=decoder_size,
act='tanh')
current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
# use score to do beam search
current_score = pd.fc(input=current_state,
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=50)
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
self.check_output()
class TestFP16DropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.init_test_case()
x = np.random.random(self.input_size).astype("float16")
out = x * (1.0 - self.prob)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {
'dropout_prob': self.prob,
'fix_seed': self.fix_seed,
'is_test': True
}
self.outputs = {'Out': out}
def init_test_case(self):
self.input_size = [32, 64]
self.prob = 0.35
self.fix_seed = True
def test_check_output(self):
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
class TestFP16DropoutOp2(TestFP16DropoutOp):
def init_test_case(self):
self.input_size = [32, 64, 3]
self.prob = 0.75
self.fix_seed = False
if __name__ == '__main__':
unittest.main()
......@@ -181,8 +181,8 @@ class TestBook(unittest.TestCase):
with program_guard(program):
x = layers.data(name='x', shape=[10], dtype='float32')
y = layers.data(
name='y', shape=[10, 20], dtype='float32', lod_level=1)
self.assertIsNotNone(layers.sequence_expand(x=x, y=y))
name='y', shape=[10, 20], dtype='float32', lod_level=2)
self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1))
print(str(program))
def test_lstm_unit(self):
......@@ -327,6 +327,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(loss)
print(str(program))
def test_lod_reset(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[10], dtype='float32')
y = layers.data(
name='y', shape=[10, 20], dtype='float32', lod_level=2)
print(layers.lod_reset(x=x, y=y))
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -42,7 +42,7 @@ class TestLodResetOpByInput(OpTest):
target_lod_0 = [0, 4, 7, 10]
self.inputs = {
'X': (x, lod),
'TargetLoD': np.array([target_lod_0]).astype('int32')
'Y': np.array([target_lod_0]).astype('int32')
}
self.outputs = {'Out': (x, [target_lod_0])}
......@@ -50,7 +50,7 @@ class TestLodResetOpByInput(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD"))
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
class TestLodResetOpBoth(OpTest):
......@@ -62,7 +62,7 @@ class TestLodResetOpBoth(OpTest):
target_lod_0_in = [0, 4, 7, 10]
self.inputs = {
'X': (x, lod),
'TargetLoD': np.array(target_lod_0_in).astype('int32')
'Y': np.array(target_lod_0_in).astype('int32')
}
self.attrs = {'target_lod': target_lod_0_attr}
self.outputs = {'Out': (x, [target_lod_0_in])}
......@@ -71,7 +71,24 @@ class TestLodResetOpBoth(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD"))
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
class TestLodResetOpYIsLoDTensor(OpTest):
def setUp(self):
self.op_type = "lod_reset"
x = np.random.random((10, 20)).astype("float32")
lod = [[0, 3, 5, 10]]
y = np.random.random((10, 10)).astype("float32")
target_lod_0 = [[0, 4, 7, 10]]
self.inputs = {'X': (x, lod), 'Y': (y, target_lod_0)}
self.outputs = {'Out': (x, target_lod_0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", no_grad_set=set("Y"))
if __name__ == '__main__':
......
......@@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest):
def compute(self):
x = self.inputs['X']
x_data, x_lod = x if type(x) == tuple else (x, None)
n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0])
y_data, y_lod = self.inputs['Y']
repeats = [((y_lod[-1][i + 1] - y_lod[-1][i]))
for i in range(len(y_lod[-1]) - 1)]
out = x_data.repeat(repeats, axis=0)
self.outputs = {'Out': out}
if hasattr(self, 'attrs'):
ref_level = self.attrs['ref_level']
else:
ref_level = len(y_lod) - 1
out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype)
if x_lod is None:
x_idx = [i for i in xrange(x_data.shape[0] + 1)]
else:
x_idx = x_lod[0]
out_lod = [[0]]
for i in xrange(1, len(y_lod[ref_level])):
repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]
x_len = x_idx[i] - x_idx[i - 1]
if repeat_num > 0:
x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
x_sub = np.repeat(x_sub, repeat_num, axis=0)
out = np.vstack((out, x_sub))
if x_lod is not None:
for j in xrange(repeat_num):
out_lod[0].append(out_lod[0][-1] + x_len)
if x_lod is None:
self.outputs = {'Out': out}
else:
self.outputs = {'Out': (out, out_lod)}
def setUp(self):
self.op_type = 'sequence_expand'
......@@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand):
x_lod = [[0, 2, 5]]
y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
self.attrs = {'ref_level': 0}
class TestSequenceExpandCase2(TestSequenceExpand):
......@@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand):
x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
x_lod = [[0, 1]]
y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
y_lod = [[0, 2]]
y_lod = [[0, 2], [0, 2]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
self.attrs = {'ref_level': 0}
class TestSequenceExpandCase3(TestSequenceExpand):
......@@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand):
class TestSequenceExpandCase4(TestSequenceExpand):
def set_data(self):
x_data = np.array(
[0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape(
[2, 5]).astype('float32')
x_lod = [[
0,
1,
2,
]]
data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
x_data = np.array(data).reshape([5, 2]).astype('float32')
x_lod = [[0, 2, 5]]
y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
y_lod = [[0, 1, 2], [0, 1, 2]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册