Commit ad63722e, authored by wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into average_model

@@ -47,3 +47,10 @@ DecayedAdagrad
:members:
:noindex:
Adadelta
--------------
.. autoclass:: paddle.fluid.optimizer.AdadeltaOptimizer
:members:
:noindex:
@@ -2,6 +2,9 @@ Distributed Training
====================
The effectiveness of a deep learning model is often directly related to the scale of the data: the same model generally achieves better results as the dataset grows. However, once the amount of data increases to a certain extent, it can no longer fit on a single computer, and using multiple computers for distributed training becomes a natural solution. In distributed training, the training data is divided into multiple shards, and each machine participating in the training reads its own shard and collaboratively updates the parameters of the overall model.
Distributed training generally uses the framework shown below:
.. image:: src/ps_en.png
   :width: 500
......
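To make the sharding idea concrete, here is a minimal Python sketch (not part of the Fluid API; shard_indices, trainer_id, and trainer_count are hypothetical names assumed to come from the cluster configuration) of how each trainer could pick its own portion of the data:

    def shard_indices(num_samples, trainer_id, trainer_count):
        """Return the indices of the samples this trainer should read."""
        return [i for i in range(num_samples) if i % trainer_count == trainer_id]

    # Example: 10 samples split across 3 trainers; trainer 1 reads samples 1, 4, 7.
    print(shard_indices(10, trainer_id=1, trainer_count=3))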
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
};
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -18,17 +18,18 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<float> dist(0, 1);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
};
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
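For reference, a small NumPy sketch of the dropout semantics these kernels implement (random masking during training, plain scaling by 1 - dropout_prob when is_test is set); the helper below is only an illustration, not the Fluid kernel:

    import numpy as np

    def dropout_forward(x, dropout_prob, is_test=False, seed=0):
        """Mask activations in training mode; scale them at inference time."""
        if is_test:
            return x * (1.0 - dropout_prob), None
        rng = np.random.RandomState(seed)
        mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
        return x * mask, mask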
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
@@ -89,6 +90,10 @@ class ListenAndServOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
int num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
framework::Executor executor(dev_place);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
@@ -132,12 +137,36 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->ShutDown();
break;
}
// put optimize blocks in the thread pool to start run, the last block
// should be global ops.
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async([&executor, &program, &recv_scope,
blkid]() {
int run_block = blkid; // thread local
try {
executor.Run(*program, &recv_scope, run_block,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (int i = 0; i < num_blocks - 2; ++i) fs[i].wait();
// Run global block at final step, or block1 if there are only 2 blocks
if (num_blocks >= 2) {
try {
executor.Run(*program, &recv_scope, num_blocks - 1,
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}
// Reset the received sparse variables; the sum operator would not
// sum the input sparse variables whose rows are empty at the next
// mini-batch.
......
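The block dispatch above follows a simple pattern: run blocks 1..num_blocks-2 concurrently, wait for them, then run the trailing global block. A Python sketch of that pattern, purely for illustration (run_block is a hypothetical stand-in for executor.Run on one block):

    from concurrent.futures import ThreadPoolExecutor

    def run_optimize_blocks(run_block, num_blocks):
        """Run the parallel blocks concurrently, then the final global block."""
        assert num_blocks >= 2, "server program should have at least 2 blocks"
        with ThreadPoolExecutor() as pool:
            futures = [pool.submit(run_block, blkid)
                       for blkid in range(1, num_blocks - 1)]
            for f in futures:
                f.result()  # wait, and surface any exception, before the global block
        run_block(num_blocks - 1)  # the last block holds the global ops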
@@ -371,6 +371,8 @@ template struct RowwiseAdd<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, float>;
template struct ColwiseSum<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, int>;
template struct ColwiseSum<platform::CPUDeviceContext, int64_t>;
template struct RowwiseSum<platform::CPUDeviceContext, float>;
template struct RowwiseSum<platform::CPUDeviceContext, double>;
......
@@ -422,6 +422,8 @@ struct RowwiseAdd<platform::CUDADeviceContext, T> {
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The ColwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case. So it was reimplemented.
......
@@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() {
buffer_->Close();
prefetcher_.join();
delete buffer_;
}
bool HasNext() const override;
private:
void PrefetchThreadFunc();
std::thread prefetcher_;
framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
@@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
prefetcher_.join();
delete buffer_;
start_thread();
}
@@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() {
if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
}
buffer_->Close();
VLOG(5) << "Prefetch thread terminates.";
}
bool DoubleBufferReader::HasNext() const {
......
@@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader {
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
@@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader {
buffer_.clear();
buffer_.reserve(buffer_size_);
iteration_pos_ = 0;
for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) {
break;
......
@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::LoDTensor;
class SequenceExpandOp : public framework::OperatorWithKernel {
public:
@@ -25,15 +25,71 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of SequenceExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceExpandOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto out_dims = x_dims;
int ref_level = ctx->Attrs().Get<int>("ref_level");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Dimension number of Input(X) should be at least 2.");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
framework::Variable* y_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
auto& x_lod = x_var->Get<LoDTensor>().lod();
auto& y_lod = y_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_LE(x_lod.size(), 1,
"Level number of Input(X)'s lod should not be "
"greater than 1.");
PADDLE_ENFORCE_GT(y_lod.size(), 0,
"Level number of Input(Y)'s lod should be "
"greater than 0.");
PADDLE_ENFORCE(
ref_level == -1 ||
(ref_level >= 0 && ref_level < static_cast<int>(y_lod.size())),
"Invalid `ref_level`, which should be either equal to -1 "
"or in [0, %d)",
y_lod.size());
if (ref_level == -1) ref_level = y_lod.size() - 1;
if (x_lod.size() > 0) {
PADDLE_ENFORCE(x_lod[0].size() == y_lod[ref_level].size(),
"Level number of Input(X)'s lod could be 0. Otherwise "
"size of Input(X)'s first level lod should be equal to "
"size of Input(Y)'s referred level lod.");
}
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() <= 1) {
out_first_dim = x_dims[0];
} else {
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
}
out_dims[0] = out_first_dim;
ctx->SetOutputDim("Out", out_dims);
} else {
out_dims[0] = -1;
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
@@ -42,83 +98,81 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1.");
AddInput("Y",
"(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
"lod (specified level) is referred by Input(X).");
AddOutput("Out",
"(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
"generated from Input(X) by referring lod of Input(Y).");
AddAttr<int>("ref_level", "Specify lod level of Input(Y).").SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
Sequence Expand Operator. Sequence Expand Operator.
This operator expands input(X) according to LOD of input(Y). This operator expands `X` according to specified level lod of `Y`. Current
implementation constaints that lod level of `X` should be at most 1. Attribute
`ref_level` is used to specify which level lod of `Y` is referred to expand `X`.
If set `ref_level` to -1, then last level lod of `Y` would be referred.
Please note, rank of `X` should be at least 2, when the rank exceeds 2, `X`
would be viewed as a 2-D tensor.
Following are cases to better explain how this works: Following are cases to better explain how this works:
Case 1:
Given a 1-level LoDTensor input(X)
X.lod = [[0, 2, 4]]
X.data = [[a], [b], [c], [d]]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
ref_level: 0
then we get 1-level LoDTensor
Out.lod = [[0, 2, 4, 6, 8]]
Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
Out.dims = [8, 1]
Case 2:
Given 1-level LoDTensor input(X)
X.lod = [[0, 1, 4]]
X.data = [[a], [b], [c], [d]]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
ref_level: 0
then we get 1-level LoDTensor
Out.lod = [[0, 1, 2, 5, 8]]
Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
Out.dims = [8, 1]
Case 3:
Given a common Tensor input(X)
X.data = [[a], [b], [c]]
X.dims = [3, 1]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
ref_level: -1
then we get a common Tensor
Out.data = [[a], [a], [b], [c], [c], [c]]
Out.dims = [6, 1]
Case 4:
Given a common Tensor input(X)
X.data = [[a, b], [c, d], [e, f]]
X.dims = [3, 2]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
ref_level: 0
then we get a common LoDTensor
Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
Out.dims = [6, 2]
)DOC");
}
};
@@ -129,12 +183,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
@@ -149,7 +205,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
sequence_expand_grad, ops::SequenceExpandOpGrad);
REGISTER_OP_CPU_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -18,7 +18,14 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
sequence_expand_grad,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
@@ -16,45 +16,75 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X");
auto* y = context.Input<LoDTensor>("Y");
auto* out = context.Output<LoDTensor>("Out");
int ref_level = context.Attr<int>("ref_level");
auto& x_lod = x->lod();
auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
out->mutable_data<T>(context.GetPlace());
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*x, context.GetPlace(), out);
return;
}
auto& out_lod = *out->mutable_lod();
if (x_lod.size() == 1) {
out_lod.resize(1);
out_lod[0] = {0};
}
int out_offset = 0;
auto& eigen_place =
*context.template device_context<DeviceContext>().eigen_device();
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
auto x_sub_tensor = x->Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset;
if (x_lod.size() == 1) {
out_start = out_lod[0][out_offset];
}
auto out_sub_tensor =
out->Slice(out_start, out_start + x_seq_len * repeat_num);
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
EigenMatrix<T>::From(x_sub_tensor)
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
}
for (int j = 0; j < repeat_num; ++j) {
if (x_lod.size() == 1) {
out_lod[0].push_back(out_lod[0].back() + x_seq_len);
}
out_offset++;
}
} }
}
};
@@ -75,27 +105,51 @@ template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = context.Input<LoDTensor>("X");
auto* y = context.Input<LoDTensor>("Y");
auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
int ref_level = context.Attr<int>("ref_level");
g_x->mutable_data<T>(context.GetPlace());
g_x->set_lod(x->lod());
auto& x_lod = x->lod();
auto& y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
// just copy the gradient
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*g_out, context.GetPlace(), g_x);
return;
}
auto& dev_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, g_x, static_cast<T>(0));
int g_out_offset = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
if (repeat_num > 0) {
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
auto g_x_sub = g_x->Slice(x_start, x_end);
g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
int g_out_end = g_out_offset + repeat_num * x_seq_len;
auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
math::ColwiseSum<DeviceContext, T> col_sum;
col_sum(dev_ctx, g_out_sub, &g_x_sub);
g_out_offset += repeat_num * x_seq_len;
}
}
}
};
......
@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif  // PADDLE_CUDA_FP16
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hadd(half(a), half(b)));
#else
return float16(float(a) + float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hsub(half(a), half(b)));
#else
return float16(float(a) - float(b));
#endif
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hmul(half(a), half(b)));
#else
return float16(float(a) * float(b));
#endif
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float num = __half2float(half(a));
float denom = __half2float(half(b));
return float16(num / denom);
#else
return float16(float(a) / float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hneg(half(a)));
#else
float16 res;
res.x = a.x ^ 0x8000;
return res;
#endif
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __heq(half(a), half(b));
#else
return float(a) == float(b);
#endif
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hne(half(a), half(b));
#else
return float(a) != float(b);
#endif
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(half(a), half(b));
#else
return float(a) < float(b);
#endif
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hle(half(a), half(b));
#else
return float(a) <= float(b);
#endif
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(half(a), half(b));
#else
return float(a) > float(b);
#endif
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hge(half(a), half(b));
#else
return float(a) >= float(b);
#endif
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
// Arithmetic operators for float16, software emulated on other CPU
#else
HOST inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b));
}
HOST inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}
HOST inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}
HOST inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}
HOST inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOST inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}
HOST inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}
HOST inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}
HOST inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}
HOST inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}
HOST inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}
HOST inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}
HOST inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}
HOST inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}
HOST inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
......
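The software-emulated branch above computes every float16 operation by converting to float and rounding the result back. A small NumPy sketch of that round-trip (illustration only, not the float16 class itself):

    import numpy as np

    def fp16_add(a, b):
        """Emulate float16 addition: promote to float32, add, round back."""
        return np.float16(np.float32(a) + np.float32(b))

    x, y = np.float16(0.1), np.float16(0.2)
    print(fp16_add(x, y))      # float16 result of the emulated addition
    print(np.float16(x + y))   # NumPy's own float16 arithmetic should agree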
@@ -307,15 +307,57 @@ class DistributeTranspiler:
# Iterate through the ops, and if an op and the optimize ops
# which are located on the current pserver are in one set, then
# append it into the sub program.
# We try to make the optimization program run in parallel. Assume the
# optimization program always looks like:
#
# prevop -> prevop -> opt op -> following op -> following op; ->
# prevop -> prevop -> opt op -> following op -> following op; ->
# global op -> global op
#
# We put operators that can run in parallel into many program blocks.
# In the above example, we separate ops by the ";". Global ops must run
# after all the optimize ops have finished.
global_ops = []
# HACK: optimization global ops are only used to scale beta1 and beta2;
# replace this with a dependency engine.
for op in self.optimize_ops:
if op.type == "scale":
for in_name in op.input_arg_names:
if in_name.startswith("beta1_pow_acc") or\
in_name.startswith("beta2_pow_acc"):
global_ops.append(op)
def __append_optimize_op__(op, block):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint,
default_main_program())
else:
self._append_pserver_non_opt_ops(block, op)
# append op to the current block
per_opt_block = optimize_block
for _, opt_op in enumerate(opt_op_on_pserver):
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \
op not in global_ops:
__append_optimize_op__(op, per_opt_block)
per_opt_block = pserver_program.create_block(0)
# append global ops
for glb_op in global_ops:
__append_optimize_op__(glb_op, per_opt_block)
# NOT USED: single block version:
#
# for _, op in enumerate(self.optimize_ops):
# for _, opt_op in enumerate(opt_op_on_pserver):
# if ufind.is_connected(op, opt_op):
# __append_optimize_op__(glb_op, optimize_block)
# break
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
@@ -660,10 +702,22 @@ class DistributeTranspiler:
# If one op's input is another op's output or
# one op's output is another op's input, we say
# the two operators are connected.
def _append_inname_remove_beta(varname_list):
op_input_names = []
for in_name in varname_list:
# HACK: remove beta1 and beta2 to avoid making all
# ops connected.
if in_name.startswith("beta2_pow_acc") or \
in_name.startswith("beta1_pow_acc"):
continue
else:
op_input_names.append(in_name)
return op_input_names
op1_input_names = _append_inname_remove_beta(op1.desc.input_arg_names())
op1_output_names = op1.desc.output_arg_names()
op2_input_names = _append_inname_remove_beta(op2.desc.input_arg_names())
op2_output_names = op2.desc.output_arg_names()
if set(op1_output_names) & set(op2_input_names) or \
......
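The connectivity rule described in the comments reduces to an input/output overlap check once the beta accumulators are filtered out. A minimal standalone sketch (function and argument names are ours, not the transpiler's API):

    def ops_connected(op1_inputs, op1_outputs, op2_inputs, op2_outputs):
        """Two ops are connected if either one's outputs feed the other's inputs."""
        return bool(set(op1_outputs) & set(op2_inputs)) or \
               bool(set(op2_outputs) & set(op1_inputs))

    print(ops_connected(['x'], ['y'], ['y'], ['z']))  # True: op2 consumes op1's output y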
@@ -1809,52 +1809,52 @@ def conv2d_transpose(input,
return out
def sequence_expand(x, y, ref_level=-1, name=None):
"""Sequence Expand Layer. This layer will expand the input variable **x**
according to specified level lod of **y**. Please note that lod level of
**x** is at most 1 and rank of **x** is at least 2. When rank of **x**
is greater than 2, then it would be viewed as a 2-D tensor.
The following examples explain how sequence_expand works:
.. code-block:: text
* Case 1
x is a LoDTensor:
x.lod = [[0, 2, 4]]
x.data = [[a], [b], [c], [d]]
x.dims = [4, 1]
y is a LoDTensor:
y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
ref_level: 0
then output is a 1-level LoDTensor:
out.lod = [[0, 2, 4, 6, 8]]
out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
out.dims = [8, 1]
* Case 2
x is a Tensor:
x.data = [[a], [b], [c]]
x.dims = [3, 1]
y is a LoDTensor:
y.lod = [[0, 2, 2, 5]]
ref_level: -1
then output is a Tensor:
out.data = [[a], [a], [c], [c], [c]]
out.dims = [5, 1]
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a LoDTensor.
ref_level (int): Lod level of `y` to be referred by `x`. If set to -1,
refer the last level of lod.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The expanded variable which is a LoDTensor.
@@ -1865,14 +1865,17 @@ def sequence_expand(x, y, ref_level=-1, name=None):
x = fluid.layers.data(name='x', shape=[10], dtype='float32')
y = fluid.layers.data(name='y', shape=[10, 20],
dtype='float32', lod_level=1)
out = layers.sequence_expand(x=x, y=y, ref_level=0)
"""
helper = LayerHelper('sequence_expand', input=x, **locals())
dtype = helper.input_dtype()
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type='sequence_expand',
inputs={'X': x,
'Y': y},
outputs={'Out': tmp},
attrs={'ref_level': ref_level})
return tmp
......
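To illustrate what the op computes, here is a NumPy sketch of the expansion (it mirrors Case 1 above; the function name is ours and it ignores the output lod):

    import numpy as np

    def sequence_expand_np(x_data, x_lod, y_lod, ref_level=-1):
        """Repeat each sequence of x by the repeat counts taken from y's lod."""
        if ref_level == -1:
            ref_level = len(y_lod) - 1
        # sequence boundaries of x (plain row indices when x carries no lod)
        x_idx = x_lod[0] if x_lod else list(range(x_data.shape[0] + 1))
        pieces = []
        for i in range(1, len(y_lod[ref_level])):
            repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]
            x_sub = x_data[x_idx[i - 1]:x_idx[i]]
            pieces.append(np.tile(x_sub, (repeat_num, 1)))
        return np.vstack(pieces)

    # Case 1: each 2-row sequence of x is repeated twice.
    x = np.arange(4).reshape(4, 1)   # stands for [[a], [b], [c], [d]]
    out = sequence_expand_np(x, x_lod=[[0, 2, 4]], y_lod=[[0, 2, 4]], ref_level=0)
    # out rows: 0, 1, 0, 1, 2, 3, 2, 3  ->  [a, b, a, b, c, d, c, d]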
@@ -27,7 +27,7 @@ from contextlib import contextmanager
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
'Adadelta', 'ModelAverage'
]
@@ -591,6 +591,88 @@ class DecayedAdagradOptimizer(Optimizer):
return decayed_adagrad_op
class AdadeltaOptimizer(Optimizer):
"""
**Adadelta Optimizer**
Simple Adadelta optimizer with average squared grad state and
average squared update state.
For details of Adadelta, please refer to
`ADADELTA: AN ADAPTIVE LEARNING RATE METHOD
<http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf>`_.
.. math::
E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2 \\\\
learning\\_rate &= sqrt( ( E(dx_{t-1}^2) + \\epsilon ) / ( \\
E(g_t^2) + \\epsilon ) ) \\\\
E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\\_rate)^2
Args:
learning_rate(float): global learning rate
rho(float): rho in equation
epsilon(float): epsilon in equation
Examples:
.. code-block:: python
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
"""
_avg_squared_grad_acc_str = "_avg_squared_grad"
_avg_squared_update_acc_str = "_avg_squared_update"
def __init__(self, learning_rate, epsilon=1.0e-6, rho=0.95, **kwargs):
if learning_rate is None:
raise ValueError("learning_rate is not set.")
if epsilon is None:
raise ValueError("epsilon is not set.")
if rho is None:
raise ValueError("rho is not set.")
super(AdadeltaOptimizer, self).__init__(
learning_rate=learning_rate, **kwargs)
self.type = "adadelta"
self._epsilon = epsilon
self._rho = rho
def _create_accumulators(self, block, parameters):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
for p in parameters:
self._add_accumulator(self._avg_squared_grad_acc_str, p)
self._add_accumulator(self._avg_squared_update_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
if not isinstance(block, framework.Block):
raise TypeError("block is not instance of framework.Block.")
avg_squared_grad_acc = self._get_accumulator(
self._avg_squared_grad_acc_str, param_and_grad[0])
avg_squared_update_acc = self._get_accumulator(
self._avg_squared_update_acc_str, param_and_grad[0])
# Create the adadelta optimizer op
adadelta_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"AvgSquaredGrad": avg_squared_grad_acc,
"AvgSquaredUpdate": avg_squared_update_acc
},
outputs={
"ParamOut": param_and_grad[0],
"AvgSquaredGradOut": avg_squared_grad_acc,
"AvgSquaredUpdateOut": avg_squared_update_acc
},
attrs={"epsilon": self._epsilon,
"rho": self._rho})
return adadelta_op
# We shorten the class name, since users will use the optimizer with the package
# name. The sample code:
#
@@ -605,6 +687,7 @@ Adagrad = AdagradOptimizer
Adam = AdamOptimizer
Adamax = AdamaxOptimizer
DecayedAdagrad = DecayedAdagradOptimizer
Adadelta = AdadeltaOptimizer
class ModelAverage(Optimizer):
......
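A minimal NumPy sketch of one Adadelta step following the update equations in the docstring (illustration only; it omits how the global learning_rate enters the Fluid op and uses the optimizer's default rho and epsilon):

    import numpy as np

    def adadelta_step(param, grad, avg_sq_grad, avg_sq_update,
                      rho=0.95, epsilon=1.0e-6):
        # E(g_t^2) = rho * E(g_{t-1}^2) + (1 - rho) * g^2
        avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad ** 2
        # per-element step size built from the two running averages
        step = np.sqrt((avg_sq_update + epsilon) / (avg_sq_grad + epsilon))
        # E(dx_t^2) = rho * E(dx_{t-1}^2) + (1 - rho) * (-g * step)^2
        avg_sq_update = rho * avg_sq_update + (1 - rho) * (grad * step) ** 2
        param = param - step * grad
        return param, avg_sq_grad, avg_sq_update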
@@ -118,12 +118,12 @@ def decoder_decode(context, is_sparse):
is_sparse=is_sparse)
# use rnn unit to update rnn
current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb],
size=decoder_size,
act='tanh')
current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
# use score to do beam search
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=50)
......
@@ -14,6 +14,7 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
self.check_output()
class TestFP16DropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.init_test_case()
x = np.random.random(self.input_size).astype("float16")
out = x * (1.0 - self.prob)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {
'dropout_prob': self.prob,
'fix_seed': self.fix_seed,
'is_test': True
}
self.outputs = {'Out': out}
def init_test_case(self):
self.input_size = [32, 64]
self.prob = 0.35
self.fix_seed = True
def test_check_output(self):
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
class TestFP16DropoutOp2(TestFP16DropoutOp):
def init_test_case(self):
self.input_size = [32, 64, 3]
self.prob = 0.75
self.fix_seed = False
if __name__ == '__main__':
unittest.main()
@@ -181,8 +181,8 @@ class TestBook(unittest.TestCase):
with program_guard(program):
x = layers.data(name='x', shape=[10], dtype='float32')
y = layers.data(
name='y', shape=[10, 20], dtype='float32', lod_level=2)
self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1))
print(str(program))
def test_lstm_unit(self):
......
@@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest):
def compute(self):
x = self.inputs['X']
x_data, x_lod = x if type(x) == tuple else (x, None)
y_data, y_lod = self.inputs['Y']
if hasattr(self, 'attrs'):
ref_level = self.attrs['ref_level']
else:
ref_level = len(y_lod) - 1
out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype)
if x_lod is None:
x_idx = [i for i in xrange(x_data.shape[0] + 1)]
else:
x_idx = x_lod[0]
out_lod = [[0]]
for i in xrange(1, len(y_lod[ref_level])):
repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]
x_len = x_idx[i] - x_idx[i - 1]
if repeat_num > 0:
x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
x_sub = np.repeat(x_sub, repeat_num, axis=0)
out = np.vstack((out, x_sub))
if x_lod is not None:
for j in xrange(repeat_num):
out_lod[0].append(out_lod[0][-1] + x_len)
if x_lod is None:
self.outputs = {'Out': out}
else:
self.outputs = {'Out': (out, out_lod)}
def setUp(self):
self.op_type = 'sequence_expand'
@@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand):
x_lod = [[0, 2, 5]]
y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
self.attrs = {'ref_level': 0}
class TestSequenceExpandCase2(TestSequenceExpand):
@@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand):
x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
x_lod = [[0, 1]]
y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
y_lod = [[0, 2], [0, 2]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
self.attrs = {'ref_level': 0}
class TestSequenceExpandCase3(TestSequenceExpand):
@@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand):
class TestSequenceExpandCase4(TestSequenceExpand):
def set_data(self):
data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
x_data = np.array(data).reshape([5, 2]).astype('float32')
x_lod = [[0, 2, 5]]
y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
y_lod = [[0, 1, 2], [0, 1, 2]]
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
......