diff --git a/doc/design/distributed_lookup_table_design.md b/doc/fluid/design/dist_train/distributed_lookup_table_design.md
similarity index 97%
rename from doc/design/distributed_lookup_table_design.md
rename to doc/fluid/design/dist_train/distributed_lookup_table_design.md
index a09f2818c888397b07fc7d09ecd20056f4176982..e543adf0f97cc6b47415b807d7a1ed1effec9b22 100644
--- a/doc/design/distributed_lookup_table_design.md
+++ b/doc/fluid/design/dist_train/distributed_lookup_table_design.md
@@ -26,7 +26,7 @@ lookup of rows.
 The following figure illustrates the multiplication of x with two
 non-zero elements, or say, two symbols, and a lookup table W:
 
-![lookup table](./lookup_table.png)
+![lookup table](./src/lookup_table.png)
 
 ### The Backward Algorithm
 
@@ -42,7 +42,7 @@ or some more sophisticated algorithms that rely on both W' and W:
 
 $$W = f(W, W')$$
 
 The following figure illustrates the backward pass of the lookup
-operator: ![lookup table training](./lookup_table_training.png)
+operator: ![lookup table training](./src/lookup_table_training.png)
 
 ## Distributed Storage Service
diff --git a/doc/design/lookup_table.png b/doc/fluid/design/dist_train/src/lookup_table.png
similarity index 100%
rename from doc/design/lookup_table.png
rename to doc/fluid/design/dist_train/src/lookup_table.png
diff --git a/doc/design/lookup_table_training.png b/doc/fluid/design/dist_train/src/lookup_table_training.png
similarity index 100%
rename from doc/design/lookup_table_training.png
rename to doc/fluid/design/dist_train/src/lookup_table_training.png
diff --git a/doc/fluid/design/motivation/fluid.md b/doc/fluid/design/motivation/fluid.md
index f78fa8c1914124f33b9730f918c8887ced4f8d9d..110b7d78bf12ac8328fb3a913e4386e75d63c995 100644
--- a/doc/fluid/design/motivation/fluid.md
+++ b/doc/fluid/design/motivation/fluid.md
@@ -103,7 +103,7 @@ In computability theory, a system of data-manipulation rules, such as a programm
 
 There are two ways to execute a Fluid program.  When a program is executed, it creates a protobuf message [`ProgramDesc`](https://github.com/PaddlePaddle/Paddle/blob/a91efdde6910ce92a78e3aa7157412c4c88d9ee8/paddle/framework/framework.proto#L145) that describes the process and is conceptually like an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree).
 
-There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
+There is a C++ class [`Executor`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/executor.h), which runs a `ProgramDesc`, similar to how an interpreter runs a Python program.
 
 Fluid is moving towards the direction of a compiler, which is explain in [fluid_compiler.md](fluid_compiler.md).
 
diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 1074ed6acc22a81f46c466d917ef973945a12898..e4436549f6185ba04a5f270893596a6dcb11e89b 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
   }
 };
 
-template <typename AttrType>
 class DropoutOpGrad : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
-            ops::DropoutOpGrad<float>);
+REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
+            ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    dropout,
-    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad,
     ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index d6f9c04359d733cb4f3f0586e9239ee67deb7078..f6c85a2a537b37feb20e6d62729dc5075af2a5d9 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -18,17 +18,18 @@ limitations under the License. */
 #include
 #include
 #include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T, typename AttrType>
+template <typename T>
 __global__ void RandomGenerator(const size_t n, const int seed,
-                                const AttrType dropout_prob, const T* src,
+                                const float dropout_prob, const T* src,
                                 T* mask_data, T* dst) {
   thrust::minstd_rand rng;
   rng.seed(seed);
-  thrust::uniform_real_distribution<AttrType> dist(0, 1);
+  thrust::uniform_real_distribution<float> dist(0, 1);
 
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
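Reviewer note: the hunk above drops the `AttrType` template parameter and pins `dropout_prob` to `float`, so only the tensor element type `T` varies; that is what later allows a `float16` instantiation of the GPU kernel. Below is a minimal, self-contained sketch of a kernel with that shape. It is not Paddle's `RandomGenerator`; the names (`DropoutMaskKernel`, `seed`, `mask`, `dst`) and the launch configuration are illustrative only.

```cpp
// Sketch: the dropout probability stays a float, while the data and the mask
// are stored in the element type T (float, double, or a 16-bit float type).
#include <cuda_runtime.h>
#include <thrust/random.h>

template <typename T>
__global__ void DropoutMaskKernel(const size_t n, const int seed,
                                  const float dropout_prob, const T* src,
                                  T* mask, T* dst) {
  thrust::minstd_rand rng;
  rng.seed(seed);
  thrust::uniform_real_distribution<float> dist(0, 1);

  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    rng.discard(idx);  // decorrelate elements by skipping ahead in the stream
    mask[idx] = (dist(rng) < dropout_prob) ? static_cast<T>(0)
                                           : static_cast<T>(1);
    dst[idx] = mask[idx] * src[idx];
  }
}

int main() {
  const size_t n = 256;
  float *src, *mask, *dst;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&mask, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) src[i] = 1.0f;
  DropoutMaskKernel<float><<<1, 128>>>(n, /*seed=*/42, /*dropout_prob=*/0.35f,
                                       src, mask, dst);
  cudaDeviceSynchronize();
  cudaFree(src);
  cudaFree(mask);
  cudaFree(dst);
  return 0;
}
```

The point of the sketch is that the random threshold is always drawn in `float`, and only the stored mask and output are cast back to `T`.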
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
-template <typename Place, typename T, typename AttrType>
+template <typename Place, typename T>
 class GPUDropoutKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    float dropout_prob = context.Attr<float>("dropout_prob");
 
    auto X = EigenMatrix<T>::Reshape(*x, 1);
    auto Y = EigenMatrix<T>::Reshape(*y, 1);
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
 
       int threads = 512;
       int grid = (x->numel() + threads - 1) / threads;
-      RandomGenerator<T, AttrType><<<grid, threads, 0,
-                                     context.cuda_device_context().stream()>>>(
+      RandomGenerator<
+          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
           size, seed, dropout_prob, x_data, mask_data, y_data);
     } else {
-      Y.device(place) = X * (1.0f - dropout_prob);
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
     }
   }
 };
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    dropout,
-    ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
-REGISTER_OP_CUDA_KERNEL(
-    dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
+    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(dropout_grad,
+                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 209e4dec1756dc65fbf147c4dbbf0913d3c6ef7e..b5ee86ae2d11dfc835e1a3a6826ce016baf38a29 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename DeviceContext, typename T, typename AttrType>
+template <typename DeviceContext, typename T>
 class CPUDropoutKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index fba1612d10f0494f4ab06fabdd0e799a74dafd53..547d081006f1c28ba73cb02d38e36bb612cea494 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -43,7 +43,7 @@ math_library(sequence2batch)
 math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
-math_library(softmax)
+math_library(softmax DEPS math_function)
 math_library(unpooling)
 math_library(vol2col)
 
diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
index d0de092947eb04a1b7d06dedea919f6b1094dd06..bd0bb2ee3b0252f47318c59d9940d8dd478723de 100644
--- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
   void start_thread() {
     buffer_ = framework::MakeChannel(kDoubleBufferSize);
-    std::thread prefetch([this] { PrefetchThreadFunc(); });
-    prefetch.detach();
+    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
 
   void ReadNext(std::vector<framework::LoDTensor>* out) override;
   void ReInit() override;
 
-  ~DoubleBufferReader() { buffer_->Close(); }
+  ~DoubleBufferReader() {
+    buffer_->Close();
+    prefetcher_.join();
+    delete buffer_;
+  }
 
   bool HasNext() const override;
 
  private:
  void PrefetchThreadFunc();
 
+  std::thread prefetcher_;
  framework::Channel* buffer_;
  platform::Place place_;
  std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
@@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
 void DoubleBufferReader::ReInit() {
   reader_->ReInit();
   buffer_->Close();
+  prefetcher_.join();
+  delete buffer_;
   start_thread();
 }
 
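Reviewer note: the reader change above replaces a detached prefetch thread with a `prefetcher_` member that is joined after the channel is closed, both in the destructor and in `ReInit`, so the thread can no longer outlive the reader or touch a freed buffer. A minimal sketch of the same ownership pattern, assuming a stand-in `BoundedQueue` instead of `framework::Channel` (the consumer side is omitted):

```cpp
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t cap) : cap_(cap) {}
  // Blocks while the queue is full; returns false once the queue is closed,
  // mirroring Channel::Send in the diff above.
  bool Send(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return closed_ || q_.size() < cap_; });
    if (closed_) return false;
    q_.push(std::move(v));
    return true;
  }
  void Close() {
    std::lock_guard<std::mutex> lk(mu_);
    closed_ = true;
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
  size_t cap_;
  bool closed_ = false;
};

class DoubleBufferedSource {
 public:
  DoubleBufferedSource() { StartThread(); }
  ~DoubleBufferedSource() {
    buffer_->Close();    // unblocks the producer
    prefetcher_.join();  // the thread cannot outlive this object
    delete buffer_;
  }

 private:
  void StartThread() {
    buffer_ = new BoundedQueue<std::vector<float>>(2);
    prefetcher_ = std::thread([this] {
      std::vector<float> batch(16, 1.0f);  // pretend "read" result
      while (buffer_->Send(batch)) { /* keep prefetching */ }
    });
  }

  std::thread prefetcher_;
  BoundedQueue<std::vector<float>>* buffer_ = nullptr;
};

int main() { DoubleBufferedSource src; }
```

Closing the channel before joining is what guarantees the join terminates: the blocked `Send` wakes up, sees the closed flag, and the prefetch loop exits.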
@@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() {
     if (!buffer_->Send(&batch)) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
-                 "prefetch thread terminates.";
+                 "prefetch thread will terminate.";
       break;
     }
   }
   buffer_->Close();
+  VLOG(5) << "Prefetch thread terminates.";
 }
 
 bool DoubleBufferReader::HasNext() const {
diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
index 70e2f587dc414a850ddc341b98f26ae54636755c..3a1f3805a0483c2f5eabdc7432556051d8308964 100644
--- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
@@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader {
   }
 
   void ReadNext(std::vector<framework::LoDTensor>* out) override {
+    if (!HasNext()) {
+      PADDLE_THROW("There is no next data!");
+    }
     if (iteration_pos_ >= buffer_.size()) {
       VLOG(10) << "Resetting shuffle buffer";
       ReadIntoBuffers();
@@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader {
     buffer_.clear();
     buffer_.reserve(buffer_size_);
     iteration_pos_ = 0;
-    PADDLE_ENFORCE(reader_->HasNext());
     for (size_t i = 0; i < buffer_size_; ++i) {
       if (!reader_->HasNext()) {
         break;
diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h
index 52fb8c2531357ad7a2b2f8613e5c7fbcef52c6bb..2cf311c7e56a9bbb0bdb0078d5cfefb4bb50018b 100644
--- a/paddle/fluid/platform/float16.h
+++ b/paddle/fluid/platform/float16.h
@@ -483,9 +483,124 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
 #endif  // PADDLE_CUDA_FP16
 
-// Arithmetic operators on ARMv8.2-A CPU
-#if defined(PADDLE_WITH_NATIVE_FP16)
-HOST inline float16 operator+(const float16& a, const float16& b) {
+// Arithmetic operators for float16 on GPU
+#if defined(PADDLE_CUDA_FP16)
+HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hadd(half(a), half(b)));
+#else
+  return float16(float(a) + float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hsub(half(a), half(b)));
+#else
+  return float16(float(a) - float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hmul(half(a), half(b)));
+#else
+  return float16(float(a) * float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
+  // TODO(kexinzhao): check which cuda version starts to support __hdiv
+  float num = __half2float(half(a));
+  float denom = __half2float(half(b));
+  return float16(num / denom);
+#else
+  return float16(float(a) / float(b));
+#endif
+}
+
+HOSTDEVICE inline float16 operator-(const float16& a) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return float16(__hneg(half(a)));
+#else
+  float16 res;
+  res.x = a.x ^ 0x8000;
+  return res;
+#endif
+}
+
+HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+  a = a + b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+  a = a - b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+  a = a * b;
+  return a;
+}
+
+HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+  a = a / b;
+  return a;
+}
+
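Reviewer note: each new GPU operator above follows the same dispatch pattern: call the native `half` intrinsic only when compiling device code for compute capability 5.3 or higher, and fall back to `float` arithmetic otherwise. A compilable sketch of that guard, with an illustrative `half_add` helper rather than Paddle's `float16`:

```cpp
// Compiles as plain C++ (fallback branch only) or with nvcc, where device
// code for sm_53+ uses the native half intrinsics.
#ifdef __CUDACC__
#include <cuda_fp16.h>
#define HOSTDEVICE __host__ __device__
#else
#define HOSTDEVICE
#endif

HOSTDEVICE inline float half_add(float a, float b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device path: round-trip through the native half type and use __hadd.
  return __half2float(__hadd(__float2half(a), __float2half(b)));
#else
  // Host path (and older GPUs): promote to float and add.
  return a + b;
#endif
}

int main() { return half_add(1.0f, 2.0f) == 3.0f ? 0 : 1; }
```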
+HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __heq(half(a), half(b));
+#else
+  return float(a) == float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hne(half(a), half(b));
+#else
+  return float(a) != float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hlt(half(a), half(b));
+#else
+  return float(a) < float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hle(half(a), half(b));
+#else
+  return float(a) <= float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hgt(half(a), half(b));
+#else
+  return float(a) > float(b);
+#endif
+}
+
+HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+  return __hge(half(a), half(b));
+#else
+  return float(a) >= float(b);
+#endif
+}
+
+// Arithmetic operators for float16 on ARMv8.2-A CPU
+#elif defined(PADDLE_WITH_NATIVE_FP16)
+inline float16 operator+(const float16& a, const float16& b) {
   float16 res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -501,7 +616,7 @@ HOST inline float16 operator+(const float16& a, const float16& b) {
   return res;
 }
 
-HOST inline float16 operator-(const float16& a, const float16& b) {
+inline float16 operator-(const float16& a, const float16& b) {
   float16 res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -517,7 +632,7 @@ HOST inline float16 operator-(const float16& a, const float16& b) {
   return res;
 }
 
-HOST inline float16 operator*(const float16& a, const float16& b) {
+inline float16 operator*(const float16& a, const float16& b) {
   float16 res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -533,7 +648,7 @@ HOST inline float16 operator*(const float16& a, const float16& b) {
   return res;
 }
 
-HOST inline float16 operator/(const float16& a, const float16& b) {
+inline float16 operator/(const float16& a, const float16& b) {
   float16 res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -549,7 +664,7 @@ HOST inline float16 operator/(const float16& a, const float16& b) {
   return res;
 }
 
-HOST inline float16 operator-(const float16& a) {
+inline float16 operator-(const float16& a) {
   float16 res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -564,27 +679,27 @@ HOST inline float16 operator-(const float16& a) {
   return res;
 }
 
-HOST inline float16& operator+=(float16& a, const float16& b) {
+inline float16& operator+=(float16& a, const float16& b) {
   a = a + b;
   return a;
 }
 
-HOST inline float16& operator-=(float16& a, const float16& b) {
+inline float16& operator-=(float16& a, const float16& b) {
   a = a - b;
   return a;
 }
 
-HOST inline float16& operator*=(float16& a, const float16& b) {
+inline float16& operator*=(float16& a, const float16& b) {
   a = a * b;
   return a;
 }
 
-HOST inline float16& operator/=(float16& a, const float16& b) {
+inline float16& operator/=(float16& a, const float16& b) {
   a = a / b;
   return a;
 }
 
-HOST inline bool operator==(const float16& a, const float16& b) {
+inline bool operator==(const float16& a, const float16& b) {
   uint16_t res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -600,11 +715,9 @@ HOST inline bool operator==(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-HOST inline bool operator!=(const float16& a, const float16& b) {
-  return !(a == b);
-}
+inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }
 
-HOST inline bool operator<(const float16& a, const float16& b) {
+inline bool operator<(const float16& a, const float16& b) {
   uint16_t res;
   asm volatile(
       "ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -620,7 +733,7 @@ HOST inline bool operator<(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-HOST inline bool operator<=(const float16& a, const float16& b) {
+inline bool operator<=(const float16& a, const float16& b) {
   uint16_t res;
   asm volatile(
       "ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -636,7 +749,7 @@ HOST inline bool operator<=(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-HOST inline bool operator>(const float16& a, const float16& b) {
+inline bool operator>(const float16& a, const float16& b) {
   uint16_t res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -652,7 +765,7 @@ HOST inline bool operator>(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-HOST inline bool operator>=(const float16& a, const float16& b) {
+inline bool operator>=(const float16& a, const float16& b) {
   uint16_t res;
   asm volatile(
       "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -668,71 +781,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
   return (res & 0xffff) != 0;
 }
 
-// Arithmetic operators, software emulated on other CPU
+// Arithmetic operators for float16, software emulated on other CPU
 #else
-HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
+inline float16 operator+(const float16& a, const float16& b) {
   return float16(float(a) + float(b));
 }
 
-HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
+inline float16 operator-(const float16& a, const float16& b) {
   return float16(float(a) - float(b));
 }
 
-HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
+inline float16 operator*(const float16& a, const float16& b) {
   return float16(float(a) * float(b));
 }
 
-HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
+inline float16 operator/(const float16& a, const float16& b) {
   return float16(float(a) / float(b));
 }
 
-HOSTDEVICE inline float16 operator-(const float16& a) {
+inline float16 operator-(const float16& a) {
   float16 res;
   res.x = a.x ^ 0x8000;
   return res;
 }
 
-HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
+inline float16& operator+=(float16& a, const float16& b) {
   a = float16(float(a) + float(b));
   return a;
 }
 
-HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
+inline float16& operator-=(float16& a, const float16& b) {
   a = float16(float(a) - float(b));
   return a;
 }
 
-HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
+inline float16& operator*=(float16& a, const float16& b) {
   a = float16(float(a) * float(b));
   return a;
 }
 
-HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
+inline float16& operator/=(float16& a, const float16& b) {
   a = float16(float(a) / float(b));
   return a;
 }
 
-HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
+inline bool operator==(const float16& a, const float16& b) {
   return float(a) == float(b);
 }
 
-HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
+inline bool operator!=(const float16& a, const float16& b) {
   return float(a) != float(b);
 }
 
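Reviewer note: on targets with neither CUDA fp16 nor ARMv8.2 native fp16, the operators above emulate arithmetic by round-tripping through `float`, and negation only flips the sign bit of the stored binary16 pattern. A tiny sketch of that sign-bit trick, using an illustrative `raw16` struct rather than Paddle's `float16`:

```cpp
#include <cassert>
#include <cstdint>

struct raw16 {
  uint16_t x;  // IEEE 754 binary16 bit pattern
};

inline raw16 negate(raw16 a) {
  raw16 r;
  r.x = a.x ^ 0x8000;  // bit 15 is the sign bit; exponent and mantissa stay put
  return r;
}

int main() {
  raw16 one{0x3C00};                // 1.0 in binary16
  assert(negate(one).x == 0xBC00);  // -1.0 in binary16
  return 0;
}
```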
-HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
+inline bool operator<(const float16& a, const float16& b) {
   return float(a) < float(b);
 }
 
-HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
+inline bool operator<=(const float16& a, const float16& b) {
   return float(a) <= float(b);
 }
 
-HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
+inline bool operator>(const float16& a, const float16& b) {
   return float(a) > float(b);
 }
 
-HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
+inline bool operator>=(const float16& a, const float16& b) {
   return float(a) >= float(b);
 }
 #endif
 
diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py
index 60930a612c128cbf18e89711b9246d148e41ec58..eaa3435a86462236a99489749abe877648677053 100644
--- a/python/paddle/fluid/tests/unittests/test_dropout_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -14,6 +14,7 @@
 
 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest
 
 
@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
         self.check_output()
 
 
+class TestFP16DropoutOp(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.init_test_case()
+
+        x = np.random.random(self.input_size).astype("float16")
+        out = x * (1.0 - self.prob)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.attrs = {
+            'dropout_prob': self.prob,
+            'fix_seed': self.fix_seed,
+            'is_test': True
+        }
+        self.outputs = {'Out': out}
+
+    def init_test_case(self):
+        self.input_size = [32, 64]
+        self.prob = 0.35
+        self.fix_seed = True
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
+            self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
+
+
+class TestFP16DropoutOp2(TestFP16DropoutOp):
+    def init_test_case(self):
+        self.input_size = [32, 64, 3]
+        self.prob = 0.75
+        self.fix_seed = False
+
+
 if __name__ == '__main__':
     unittest.main()
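Reviewer note: the new FP16 tests run the op with `is_test=True`, so the expected output is simply `x * (1 - dropout_prob)`, compared against a float32 reference with `atol=1e-3`. The sketch below motivates that tolerance: rounding a product in `[0, 1)` to binary16 perturbs it by far less than `1e-3`. The quantizer is a simplified round-to-nearest model for in-range values, not Paddle code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Round v to the nearest value on the binary16 grid (normal range only).
float RoundToHalfPrecision(float v) {
  if (v == 0.0f) return 0.0f;
  int e;
  std::frexp(v, &e);                     // |v| lies in [2^(e-1), 2^e)
  float ulp = std::ldexp(1.0f, e - 11);  // spacing of binary16 values there
  return std::round(v / ulp) * ulp;
}

int main() {
  const float prob = 0.35f;
  for (int i = 0; i < 10000; ++i) {
    float x = static_cast<float>(std::rand()) / RAND_MAX;  // mimics the test input
    float exact = x * (1.0f - prob);
    float fp16_like = RoundToHalfPrecision(exact);
    assert(std::fabs(fp16_like - exact) < 1e-3f);  // same bound as atol=1e-3
  }
  return 0;
}
```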