diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 18770fe2861380ea1320aef5cb7ec3432147d7ce..33ef6860e1d38f4e87c4431addf43f9f8a655fc2 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -186,6 +186,11 @@ function(cc_library TARGET_NAME) add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) endif() if (cc_library_DEPS) + # Don't need link libwarpctc.so + if ("${cc_library_DEPS};" MATCHES "warpctc;") + list(REMOVE_ITEM cc_library_DEPS warpctc) + add_dependencies(${TARGET_NAME} warpctc) + endif() add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) endif() @@ -465,10 +470,10 @@ function(py_test TARGET_NAME) if(WITH_TESTING) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS ARGS) + set(multiValueArgs SRCS DEPS ARGS ENVS) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_test(NAME ${TARGET_NAME} - COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python + COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python ${py_test_ENVS} ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() diff --git a/doc/design/switch.md b/doc/design/switch.md new file mode 100644 index 0000000000000000000000000000000000000000..9db1b2782a521c2ff4b28b8f9efcdf1492242ed4 --- /dev/null +++ b/doc/design/switch.md @@ -0,0 +1,32 @@ +### Design Doc: Switch + +### Background + +Many programming languages provide `switch` as a generalization of `if-elif-else`. We want to add it to Fluid. + +The following example shows the usage of `fluid.switch`. + +```python +a = fluid.Var(10) +b = fluid.Var(0) + +switch = fluid.switch() +with switch.block(): + with switch.case(fluid.less_equal(a, 10)): + fluid.print("Case 1") + with switch.case(fluid.larger(a, 0)): + fluid.print("Case 2") + with switch.default(): + fluid.print("Case 3") +``` + +### The Semantics + +1. A `switch` control-flow checks cases one-by-one. +1. The condition of each case is a boolean value, which is a scalar, and differs from the `fluid.if_else` control-flow, which condition could be a vector of boolean values. +1. It runs the first matched case, or the default case if there is one. +1. Once it matches a case, it runs the corresponding branch and only that branch. It's like there is a C's `break` keyword at the end of each case. + +The above program should print and print only "Case 1". + +The implementation of the backward pass of the `switch` control-flow is easier than the backward of the `if_else`, because `switch` runs at most one branch, whereas `if-else` could run more than one branches. diff --git a/doc/howto/usage/cluster/cluster_train_cn.md b/doc/howto/usage/cluster/cluster_train_cn.md index c2fc86687d7106aac7c74d6dd16bc229353cb7c1..0f3db59607fb6b43da01f5fdb46949087517ed6c 100644 --- a/doc/howto/usage/cluster/cluster_train_cn.md +++ b/doc/howto/usage/cluster/cluster_train_cn.md @@ -92,11 +92,11 @@ paddle.init( 参数说明 - use_gpu: **可选,默认False**,是否启用GPU训练 -- trainer_count:**必选,默认1**,当前训练任务trainer总个数 +- trainer_count:**必选,默认1**,当前trainer的线程数目 - port:**必选,默认7164**,连接到pserver的端口 - ports_num:**必选,默认1**,连接到pserver的端口个数 - ports_num_for_sparse:**必选,默认0**,和pserver之间用于稀疏类型参数通信的端口个数 -- num_gradient_servers:**必选,默认1**,当前训练任务pserver总数 +- num_gradient_servers:**必选,默认1**,当前训练任务trainer总数 - trainer_id:**必选,默认0**,每个trainer的唯一ID,从0开始的整数 - pservers:**必选,默认127.0.0.1**,当前训练任务启动的pserver的IP列表,多个IP使用“,”隔开 diff --git a/doc/howto/usage/cluster/cluster_train_en.md b/doc/howto/usage/cluster/cluster_train_en.md index 28cd1fa7903e559e33a7fc2f00172fdfbe2fdc97..f9424f8f1a29fcf001c4e7976086512b22f6e858 100644 --- a/doc/howto/usage/cluster/cluster_train_en.md +++ b/doc/howto/usage/cluster/cluster_train_en.md @@ -95,11 +95,11 @@ paddle.init( Parameter Description - use_gpu: **optional, default False**, set to "True" to enable GPU training. -- trainer_count: **required, default 1**, total count of trainers in the training job. +- trainer_count: **required, default 1**, number of threads in current trainer. - port: **required, default 7164**, port to connect to parameter server. - ports_num: **required, default 1**, number of ports for communication. - ports_num_for_sparse: **required, default 0**, number of ports for sparse type caculation. -- num_gradient_servers: **required, default 1**, total number of gradient server. +- num_gradient_servers: **required, default 1**, number of trainers in current job. - trainer_id: **required, default 0**, ID for every trainer, start from 0. - pservers: **required, default 127.0.0.1**, list of IPs of parameter servers, separated by ",". diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index 1510fb8abf54f05804bd404d9bd00ecc42fbef63..31ac72eda98859327f9857c18287398d0f459c7b 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -29,16 +29,16 @@ TEST(Channel, MakeAndClose) { { // MakeChannel should return a buffered channel is buffer_size > 0. auto ch = MakeChannel(10); - EXPECT_NE(dynamic_cast*>(ch), nullptr); - EXPECT_EQ(dynamic_cast*>(ch), nullptr); + EXPECT_NE(dynamic_cast *>(ch), nullptr); + EXPECT_EQ(dynamic_cast *>(ch), nullptr); CloseChannel(ch); delete ch; } { // MakeChannel should return an un-buffered channel is buffer_size = 0. auto ch = MakeChannel(0); - EXPECT_EQ(dynamic_cast*>(ch), nullptr); - EXPECT_NE(dynamic_cast*>(ch), nullptr); + EXPECT_EQ(dynamic_cast *>(ch), nullptr); + EXPECT_NE(dynamic_cast *>(ch), nullptr); CloseChannel(ch); delete ch; } @@ -78,3 +78,132 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { t.join(); delete ch; } + +TEST(Channel, SimpleUnbufferedChannelTest) { + auto ch = MakeChannel(0); + unsigned sum_send = 0; + std::thread t([&]() { + for (int i = 0; i < 5; i++) { + ch->Send(&i); + sum_send += i; + } + }); + for (int i = 0; i < 5; i++) { + int recv; + ch->Receive(&recv); + EXPECT_EQ(recv, i); + } + + CloseChannel(ch); + t.join(); + EXPECT_EQ(sum_send, 10U); + delete ch; +} + +// This tests that closing an unbuffered channel also unblocks +// unblocks any receivers waiting for senders +TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { + auto ch = MakeChannel(0); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked becausew of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + ch->Receive(&data); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the thread + // This should unblock all receivers + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + +// This tests that closing an unbuffered channel also unblocks +// unblocks any senders waiting for senders +TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { + auto ch = MakeChannel(0); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked becausew of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data = 10; + ch->Send(&data); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the thread + // This should unblock all receivers + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + +TEST(Channel, UnbufferedLessReceiveMoreSendTest) { + auto ch = MakeChannel(0); + unsigned sum_send = 0; + // Send should block after three iterations + // since we only have three receivers. + std::thread t([&]() { + // Try to send more number of times + // than receivers + for (int i = 0; i < 4; i++) { + ch->Send(&i); + sum_send += i; + } + }); + for (int i = 0; i < 3; i++) { + int recv; + ch->Receive(&recv); + EXPECT_EQ(recv, i); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec + EXPECT_EQ(sum_send, 3U); + + CloseChannel(ch); + t.join(); + delete ch; +} diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h index cc2d2e587eca981307d4e522bd569fbffa450207..0dc5afd7e57c1f59dfc1b86093eea231d46966f1 100644 --- a/paddle/framework/details/unbuffered_channel.h +++ b/paddle/framework/details/unbuffered_channel.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include -#include #include #include "paddle/framework/channel.h" @@ -36,20 +36,104 @@ class UnBuffered : public paddle::framework::Channel { virtual ~UnBuffered(); private: - UnBuffered() {} + std::mutex mu_ch_; + // Mutex for readers and writers who are waiting for other reader + // and writer to complete execution + std::recursive_mutex mu_read_, mu_write_; + // reader_found_ is set true when a reader is ready to accept data + // writer_found_ is set true when a writer is ready to send data + // A transaction occurs only when both are true + std::atomic reader_found_{false}, writer_found_{false}; + std::condition_variable cv_channel_; + std::condition_variable_any cv_reader_, cv_writer_; + T* item{nullptr}; + std::atomic closed_{false}; + + UnBuffered() : closed_(false) {} + + void NotifyAllParticipants(std::unique_lock*); }; +// This function implements the concept of how data should +// be sent from a writer to a reader. +template +void UnBuffered::Send(T* data) { + // Prevent other writers from entering + std::unique_lock writer_lock(mu_write_); + writer_found_ = true; + std::unique_lock cv_lock(mu_write_); + // If writer comes first, it should wait till a reader arrives + cv_writer_.wait(cv_lock, + [this]() { return reader_found_ == true || closed_; }); + cv_reader_.notify_one(); + if (!closed_) { + std::unique_lock channel_lock(mu_ch_); + item = data; + channel_lock.unlock(); + cv_channel_.notify_one(); + channel_lock.lock(); + cv_channel_.wait(channel_lock, + [this]() { return item == nullptr || closed_; }); + } + writer_found_ = false; +} + +// This function implements the concept of how +// data that was sent by a writer is read from a reader. template -void UnBuffered::Send(T* channel_element) {} +void UnBuffered::Receive(T* data) { + // Prevent other readers from entering + std::unique_lock read_lock{mu_read_}; + reader_found_ = true; + std::unique_lock cv_lock{mu_read_}; + // If reader comes first, it should wait till a writer arrives + cv_reader_.wait(cv_lock, + [this]() { return writer_found_ == true || closed_; }); + cv_writer_.notify_one(); + if (!closed_) { + std::unique_lock lock_ch{mu_ch_}; + // Reader should wait for the writer to first write its data + cv_channel_.wait(lock_ch, [this]() { return item != nullptr || closed_; }); + if (!closed_) { + *data = std::move(*item); + item = nullptr; + lock_ch.unlock(); + } + cv_channel_.notify_one(); + } + reader_found_ = false; +} +// This function implements the sequence of events +// that take place once the channel is closed. template -void UnBuffered::Receive(T*) {} +void UnBuffered::Close() { + std::unique_lock lock(mu_ch_); + item = nullptr; + closed_ = true; + NotifyAllParticipants(&lock); +} +// This function implements the sequence of events +// that are executed once the object of an UnBuffered +// channel is destroyed. template -void UnBuffered::Close() {} +UnBuffered::~UnBuffered() { + std::unique_lock lock(mu_ch_); + item = nullptr; + closed_ = true; + NotifyAllParticipants(&lock); +} +// This function notifies all the readers, writers and +// the channel condition variables. template -UnBuffered::~UnBuffered() {} +void UnBuffered::NotifyAllParticipants(std::unique_lock* lock) { + lock->unlock(); + cv_writer_.notify_all(); + cv_channel_.notify_all(); + cv_reader_.notify_all(); +} } // namespace details } // namespace framework diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h index 0e0e23958602343f8e0106e3a88eaac9c6d71066..85caac8dcd9ede4fe997e2fd246d1421aa73c80a 100644 --- a/paddle/framework/mixed_vector.h +++ b/paddle/framework/mixed_vector.h @@ -34,18 +34,6 @@ namespace framework { template class Vector : public std::vector { - public: - /* NOTE(dzhwinter): - * Data always store and modified on Host. - * If the data is modified when use cuda_data interface, - * You need to call the CopyFromCUDA explicitly to synchronize data. - * - */ - enum class kDataPosition { - kDataOnHost = 0, - kDataOnDevice = 1, - }; - public: using std::vector::vector; @@ -55,11 +43,12 @@ class Vector : public std::vector { virtual ~Vector() { #ifdef PADDLE_WITH_CUDA if (cuda_ptr_ != nullptr) { - memory::Free(place_, static_cast(cuda_ptr_)); + memory::Free(place_, cuda_ptr_); } #endif } + /* Get device vector */ T *cuda_data() { CopyToCUDA(); PADDLE_ENFORCE_NOT_NULL( @@ -67,81 +56,73 @@ class Vector : public std::vector { return static_cast(cuda_ptr_); } + /* Get host vector */ T *data() { return std::vector::data(); } - const T *data() const { return std::vector::data(); } + /* Synchronize host vector to device vector */ void CopyToCUDA(); - + /* Synchronize device vector to host vector */ void CopyFromCUDA(); - + /* Switch device vector location */ void CopyToPeer(platform::Place); private: void *cuda_ptr_ = nullptr; - size_t cuda_size_ = 0; - /*The DataPosition is unused now, - if we want support random access from cpu and cuda, - we need to overload all the vector method */ - - kDataPosition position_ = kDataPosition::kDataOnHost; + size_t cuda_size_ = 0; // device vector numel platform::CUDAPlace place_; }; template void Vector::CopyToCUDA() { #ifdef PADDLE_WITH_CUDA - if (cuda_ptr_ == nullptr) { + if (cuda_size_ < this->size()) { + if (cuda_ptr_ != nullptr) { + memory::Free(place_, cuda_ptr_); + } cuda_ptr_ = memory::Alloc(place_, this->size() * sizeof(T)); } + cuda_size_ = this->size(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *cuda_ctx = pool.GetByPlace(place_); - - memory::Copy(place_, static_cast(cuda_ptr_), platform::CPUPlace(), + auto *ctx = pool.GetByPlace(place_); + memory::Copy(place_, cuda_ptr_, platform::CPUPlace(), static_cast(this->data()), - this->size() * sizeof(T), cuda_ctx->stream()); - cuda_ctx->Wait(); - - cuda_size_ = this->size(); + this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); #endif } template void Vector::CopyFromCUDA() { #ifdef PADDLE_WITH_CUDA - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *cuda_ctx = pool.GetByPlace(place_); if (cuda_ptr_ == nullptr) { - LOG(WARNING) << "No uncommited cuda data."; + LOG(WARNING) << "No uncommitted cuda data."; return; } this->resize(cuda_size_); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *ctx = pool.GetByPlace(place_); memory::Copy(platform::CPUPlace(), static_cast(this->data()), place_, static_cast(cuda_ptr_), this->size() * sizeof(T), - cuda_ctx->stream()); - cuda_ctx->Wait(); - + ctx->stream()); + ctx->Wait(); #endif } template void Vector::CopyToPeer(platform::Place peer_place) { - if (platform::is_cpu_place(peer_place)) { - return; - } #ifdef PADDLE_WITH_CUDA - auto *cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(place_); - void *peer_cuda_ptr_ = memory::Alloc( + auto *ctx = platform::DeviceContextPool::Instance().GetByPlace(place_); + void *peer_cuda_ptr = memory::Alloc( boost::get(peer_place), this->size() * sizeof(T)); - memory::Copy(boost::get(peer_place), - static_cast(peer_cuda_ptr_), place_, - static_cast(cuda_ptr_), this->size() * sizeof(T), - cuda_ctx->stream()); - cuda_ctx->Wait(); - memory::Free(place_, static_cast(cuda_ptr_)); + memory::Copy(boost::get(peer_place), peer_cuda_ptr, + place_, cuda_ptr_, this->size() * sizeof(T), ctx->stream()); + ctx->Wait(); + + memory::Free(place_, cuda_ptr_); place_ = boost::get(peer_place); - cuda_ptr_ = peer_cuda_ptr_; + cuda_ptr_ = peer_cuda_ptr; #endif } diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index cbdbf5335d32d55a0221728758025c9d2cb3e7d1..a9876cec2aabf7d116443b685391ee9d20bc1370 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -178,19 +178,22 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + real* colData = NULL; bool needIm2col = isNeedIm2col(filter); TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape; - real* colData = NULL; - size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; - size_t colWidth = outputHeight * outputWidth; - // Max col matrix height 256, Max col matrix width 1024 - size_t stepColHeight = std::min(colHeight, static_cast(256)); - size_t stepColWidth = std::min(colWidth, static_cast(2048)); + // Max col matrix width 4096, Max col matrix size 4M. + size_t outputHeightSteps = + std::min(std::max(4096 / outputWidth, (size_t)1), outputHeight); + size_t maxColWidth = outputHeightSteps * outputWidth; + size_t channelSteps = + std::min(std::max((1048576 / maxColWidth) / filterHeight * filterWidth, + (size_t)1), + inputChannels / groups_); + size_t maxColHeight = channelSteps * filterHeight * filterWidth; if (needIm2col) { colShape = TensorShape({inputChannels / groups_, @@ -199,7 +202,7 @@ public: outputHeight, outputWidth}); - resizeBuffer(stepColHeight * stepColWidth * sizeof(real)); + resizeBuffer(maxColHeight * maxColWidth * sizeof(real)); colData = reinterpret_cast(memory_->getBuf()); } @@ -209,20 +212,24 @@ public: (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; - int nStride = colWidth; - int kStride = colHeight; + int nStride = outputHeight * outputWidth; + int kStride = inputChannels / groups_ * filterHeight * filterWidth; for (size_t i = 0; i < batchSize; i++) { + filterData = inputs[1].data(); for (size_t g = 0; g < groups_; g++) { if (needIm2col) { real beta_ = beta; - for (size_t colHeightStart = 0; colHeightStart < colHeight; - colHeightStart += stepColHeight) { - for (size_t colWidthStart = 0; colWidthStart < colWidth; - colWidthStart += stepColWidth) { - int N = std::min(colWidth - colWidthStart, stepColWidth); - int K = std::min(colHeight - colHeightStart, stepColHeight); + for (size_t ic = 0; ic < inputChannels / groups_; + ic += channelSteps) { + int channels = std::min(inputChannels / groups_ - ic, channelSteps); + for (size_t oh = 0; oh < outputHeight; oh += outputHeightSteps) { + int height = std::min(outputHeight - oh, outputHeightSteps); + + int M = outputChannels / groups_; + int N = height * outputWidth; + int K = channels * filterHeight * filterWidth; // im2col - im2col(inputData + g * inputOffset, + im2col(inputData, imShape, colData, colShape, @@ -232,13 +239,12 @@ public: paddingW(), dilationH(), dilationW(), - colHeightStart, - K, - colWidthStart, + channels, + oh, + height, N); // gemm - int M = outputChannels / groups_; BlasGemm::compute( false, false, @@ -246,12 +252,12 @@ public: N, K, 1.0f, - filterData + g * filterOffset + colHeightStart, + filterData + ic * filterHeight * filterWidth, kStride, colData, N, beta_, - outputData + g * outputOffset + colWidthStart, + outputData + oh * outputWidth, nStride); } beta_ = 1.0; @@ -266,17 +272,18 @@ public: N, K, 1.0f, - filterData + g * filterOffset, + filterData, K, - inputData + g * inputOffset, + inputData, N, beta, - outputData + g * outputOffset, + outputData, N); } + inputData += inputOffset; + outputData += outputOffset; + filterData += filterOffset; } - inputData += inputChannels * inputHeight * inputWidth; - outputData += outputChannels * outputHeight * outputWidth; } memory_.reset(); diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 36a9bcf84e4b14965c83627821b71d1c7c0da1b2..915119e291caaa223249cf8e37078723621517b0 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -111,39 +111,42 @@ public: int paddingWidth, int dilationHeight, int dilationWidth, - int colHeightStart, - int colHeightSize, - int colWidthStart, - int colWidthSize) { + int inputChannels, + int colOffset, + int colOutputHeight, + int colWidth) { int inputHeight = imShape[1]; int inputWidth = imShape[2]; int filterHeight = colShape[1]; int filterWidth = colShape[2]; int outputWidth = colShape[4]; - for (int colh = 0; colh < colHeightSize; colh++) { - int wOffset = (colHeightStart + colh) % filterWidth; - int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; - int c_im = (colHeightStart + colh) / filterWidth / filterHeight; - - for (int colw = 0; colw < colWidthSize; colw++) { - int h = (colWidthStart + colw) / outputWidth; - int w = (colWidthStart + colw) % outputWidth; - - int imRowIdx = h * strideHeight + hOffset * dilationHeight; - int imColIdx = w * strideWidth + wOffset * dilationWidth; - if ((imRowIdx - paddingHeight) < 0 || - (imRowIdx - paddingHeight) >= inputHeight || - (imColIdx - paddingWidth) < 0 || - (imColIdx - paddingWidth) >= inputWidth) { - colData[colh * colWidthSize + colw] = static_cast(0); - } else { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - colData[colh * colWidthSize + colw] = - imData[imRowIdx * inputWidth + imColIdx]; + for (int ic = 0; ic < inputChannels; ic++) { + for (int oh = 0; oh < colOutputHeight; oh++) { + T* dstData = colData + oh * outputWidth; + for (int fh = 0; fh < filterHeight; fh++) { + for (int fw = 0; fw < filterWidth; fw++) { + int imRowIdx = (oh + colOffset) * strideHeight + + fh * dilationHeight - paddingHeight; + if (imRowIdx < 0 || imRowIdx >= inputHeight) { + memset(dstData, 0, outputWidth * sizeof(T)); + } else { + for (int ow = 0; ow < outputWidth; ow++) { + int imColIdx = + ow * strideWidth + fw * dilationWidth - paddingWidth; + if (imColIdx < 0 || imColIdx >= inputWidth) { + dstData[ow] = T(0); + } else { + dstData[ow] = imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + dstData += colWidth; + } } } + colData += filterHeight * filterWidth * colWidth; + imData += inputHeight * inputWidth; } } }; diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 3ba866dcdd845403d52f7a85adfef08cbb11c305..fe44a8bf79005efb87c56f6a79f46421129bab22 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() { padding, dilation, dilation, + channels, 0, - height, - 0, - width); + outputHeight, + outputHeight * outputWidth); autotest::TensorCheckEqual(*output1, *output2); } diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt index d3798fb8fd8769aef5940d4ce724cb0cc8686422..0e987eb0240301c58cfb74c9e995d3b525130125 100644 --- a/paddle/inference/tests/book/CMakeLists.txt +++ b/paddle/inference/tests/book/CMakeLists.txt @@ -4,4 +4,4 @@ cc_test(test_inference_recognize_digits_mlp DEPS ARCHIVE_START paddle_fluid ARCHIVE_END ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model) set_tests_properties(test_inference_recognize_digits_mlp - PROPERTIES DEPENDS test_recognize_digits_mlp_cpu) + PROPERTIES DEPENDS test_recognize_digits) diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 1ec4336cabbc7d3073b7638b7484bf61e83a2dc5..cc86b12be08ba987f9682ebf3fda56c2f07fb576 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2015,13 +2015,6 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, CHECK_EQ(channels * outLength, maskMatP->getWidth()); } - /* initialize the data_ */ - for (size_t i = 0; i < height_; i++) { - for (size_t j = 0; j < width_; j++) { - outData[i * outStride + j] = -(real)FLT_MAX; - } - } - /* pool max one by one */ for (size_t n = 0; n < num; ++n) { // frame by frame if (!isContiguous()) { @@ -2030,19 +2023,24 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t ph = 0; ph < outputH; ++ph) { int hstart = ph * strideH - paddingH; - int hend = std::min(hstart + sizeY, imgSizeH); - hstart = std::max(hstart, 0); + int hend = hstart + sizeY; + hstart = hstart < 0 ? 0 : hstart; + hend = hend < (int)imgSizeH ? hend : (int)imgSizeH; for (size_t pw = 0; pw < outputW; ++pw) { int wstart = pw * strideW - paddingW; - int wend = std::min(wstart + sizeX, imgSizeW); - wstart = std::max(wstart, 0); + int wend = wstart + sizeX; + wstart = wstart < 0 ? 0 : wstart; + wend = wend < (int)imgSizeW ? wend : (int)imgSizeW; if (maskData == NULL) { + real tmp = -(real)FLT_MAX; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - outData[ph * outputW + pw] = std::max( - outData[ph * outputW + pw], inputData[h * imgSizeW + w]); + tmp = tmp < inputData[h * imgSizeW + w] + ? inputData[h * imgSizeW + w] + : tmp; } } + outData[ph * outputW + pw] = tmp; } else { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index b2e73b6f23bd36e29be4e97237a269d12b92bd90..e903f43ba69ee9e28b3a03e8921a41ffa81a2542 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -122,9 +122,11 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(recv_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor) + op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS}) + set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op listen_and_serv_op sum_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op recv_op) + set(DEPS_OPS ${DEPS_OPS} send_op recv_op listen_and_serv_op) endif() op_library(cond_op DEPS framework_proto tensor net_op) diff --git a/paddle/operators/listen_and_serv_op.cc b/paddle/operators/listen_and_serv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..099f6b23736adcc2a6e9c27dca297178687ae785 --- /dev/null +++ b/paddle/operators/listen_and_serv_op.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include + +#include "paddle/framework/executor.h" +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/proto_desc.h" +#include "paddle/operators/detail/grpc_server.h" +#include "paddle/operators/detail/sendrecvop_utils.h" +#include "paddle/operators/detail/simple_block_queue.h" +#include "paddle/string/printf.h" + +namespace paddle { +namespace operators { + +constexpr char kOptimizeBlock[] = "OptimizeBlock"; + +void RunServer(std::shared_ptr service) { + service->RunSyncUpdate(); + VLOG(4) << "RunServer thread end"; +} + +static void CreateTensorFromMessageType(framework::Variable *var, + sendrecv::VarType var_type) { + if (var_type == sendrecv::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else { + PADDLE_THROW( + "VariableMessage type %d is not in " + "[LoDTensor, SelectedRows]", + var_type); + } +} + +class ListenAndServOp : public framework::OperatorBase { + public: + ListenAndServOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) { + if (!rpc_service_) { + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + server_thread_.reset(new std::thread(RunServer, rpc_service_)); + } + } + + void Stop() override { + detail::MessageWithName term_msg; + term_msg.first = LISTEN_TERMINATE_MESSAGE; + rpc_service_->Push(term_msg); + rpc_service_->ShutDown(); + server_thread_->join(); + } + + std::string GetGradVarNameForTrainer(const std::string &varname) const { + if (grads_counter_.find(varname) == grads_counter_.end()) { + grads_counter_[varname] = 0; + } + return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); + } + + void Run(const framework::Scope &scope, + const platform::Place &dev_place) const override { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + framework::Scope &recv_scope = scope.NewScope(); + + // FIXME(Yancey1989): initialize rpc server with lazy mode. + rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); + auto param_list = Attr>("ParamList"); + auto grad_list = Attr>("GradList"); + auto fan_in = Attr("Fanin"); + + auto *block = Attr(kOptimizeBlock); + auto *program = block->Program(); + framework::Executor executor(dev_place); + + // TODO(typhoonzero): change this to a while_op for every cluster-batch. + bool exit_flag = false; + while (!exit_flag) { + // Get from multiple trainers, we don't care about the order in which + // the gradients arrives, just add suffix 0~n and merge the gradient. + rpc_service_->SetCond(0); + size_t recv_var_cnt = 0; + int batch_barrier = 0; + while (batch_barrier != fan_in) { + const detail::MessageWithName &v = rpc_service_->Get(); + auto grad_var_name = v.first; + if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { + LOG(INFO) << "received terminate message and exit"; + exit_flag = true; + break; + } else if (grad_var_name == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "recv batch barrier message"; + batch_barrier++; + continue; + } else { + // receive a variable + recv_var_cnt++; + auto it = + std::find(grad_list.begin(), grad_list.end(), grad_var_name); + std::string param_var_name; + if (it != grad_list.end()) { + param_var_name = param_list[it - grad_list.begin()]; + } else { + LOG(ERROR) << "grad has no paired param:" << grad_var_name; + } + VLOG(3) << "received grad: " << grad_var_name + << " updating param: " << param_var_name; + + if (fan_in > 1) { + grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); + } + auto *var = recv_scope.FindVar(grad_var_name); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << grad_var_name; + PADDLE_THROW("Can not find server side var"); + } + detail::DeserializeFromMessage(v.second, dev_ctx, var); + } + } + VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; + // TODO(Yancey1989): merge SelectedRows variables here + if (exit_flag) { + rpc_service_->ShutDown(); + } + + try { + executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + rpc_service_->SetCond(1); + rpc_service_->WaitClientGet(recv_var_cnt); + grads_counter_.clear(); + } // while(true) + } + + protected: + std::shared_ptr rpc_service_; + std::shared_ptr server_thread_; + mutable std::unordered_map grads_counter_; +}; + +class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddComment(R"DOC( +ListenAndServ operator + +This operator will start a RPC server which can receive variables +from send_op and send back variables to recv_op. +)DOC"); + AddAttr("endpoint", + "(string, default 127.0.0.1:6164)" + "IP address to listen on.") + .SetDefault("127.0.0.1:6164") + .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr(kOptimizeBlock, + "BlockID to run on server side."); + AddAttr>( + "ParamList", "type list of string", + "grad->param name mapping to find which parameters to optimize.") + .SetDefault({}); + AddAttr>( + "GradList", "type list of string", + "grad->param name mapping to find which parameters to optimize.") + .SetDefault({}); + AddAttr("Fanin", "type int", + "Number of trainers in the current cluster job") + .SetDefault(1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp, + ops::ListenAndServOpMaker); diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 49e1eb3402482e7ff12d9b2b640f7271a80cf6d9..ba71094219f37eb7a3c2df68be986cec7afbf7ab 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -12,187 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include -#include -#include - -#include "paddle/framework/executor.h" +#include "paddle/framework/data_type.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/proto_desc.h" -#include "paddle/operators/detail/grpc_server.h" -#include "paddle/operators/detail/sendrecvop_utils.h" -#include "paddle/operators/detail/simple_block_queue.h" -#include "paddle/string/printf.h" + +#include +#include "paddle/operators/detail/grpc_client.h" namespace paddle { namespace operators { -constexpr char kOptimizeBlock[] = "OptimizeBlock"; - -void RunServer(std::shared_ptr service) { - service->RunSyncUpdate(); - VLOG(4) << "RunServer thread end"; -} - -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } -} - class RecvOp : public framework::OperatorBase { public: - RecvOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) { - if (!rpc_service_) { - std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); - server_thread_.reset(new std::thread(RunServer, rpc_service_)); - } - } - - void Stop() override { - detail::MessageWithName term_msg; - term_msg.first = LISTEN_TERMINATE_MESSAGE; - rpc_service_->Push(term_msg); - rpc_service_->ShutDown(); - server_thread_->join(); - } - - std::string GetGradVarNameForTrainer(const std::string &varname) const { - if (grads_counter_.find(varname) == grads_counter_.end()) { - grads_counter_[varname] = 0; + RecvOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::Place& place) const override { + auto outs = Outputs("Out"); + std::vector epmap = Attr>("epmap"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + for (size_t i = 0; i < outs.size(); i++) { + VLOG(3) << "getting " << outs[i]; + client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); + PADDLE_ENFORCE(client_.Wait()); } - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - framework::Scope &recv_scope = scope.NewScope(); - - // FIXME(Yancey1989): initialize rpc server with laze mode. - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - auto param_list = Attr>("ParamList"); - auto grad_list = Attr>("GradList"); - auto fan_in = Attr("Fanin"); - - auto *block = Attr(kOptimizeBlock); - auto *program = block->Program(); - framework::Executor executor(dev_place); - - // TODO(typhoonzero): change this to a while_op for every cluster-batch. - bool exit_flag = false; - while (!exit_flag) { - // Get from multiple trainers, we don't care about the order in which - // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(0); - size_t recv_var_cnt = 0; - int batch_barrier = 0; - while (batch_barrier != fan_in) { - const detail::MessageWithName &v = rpc_service_->Get(); - auto grad_var_name = v.first; - if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; - break; - } else if (grad_var_name == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "recv batch barrier message"; - batch_barrier++; - continue; - } else { - // receive a variable - recv_var_cnt++; - auto it = - std::find(grad_list.begin(), grad_list.end(), grad_var_name); - std::string param_var_name; - if (it != grad_list.end()) { - param_var_name = param_list[it - grad_list.begin()]; - } else { - LOG(ERROR) << "grad has no paired param:" << grad_var_name; - } - VLOG(3) << "received grad: " << grad_var_name - << " updating param: " << param_var_name; - - if (fan_in > 1) { - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); - } - auto *var = recv_scope.FindVar(grad_var_name); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << grad_var_name; - PADDLE_THROW("Can not find server side var"); - } - detail::DeserializeFromMessage(v.second, dev_ctx, var); - } - } - VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; - // TODO(Yancey1989): merge SelectedRows variables here - if (exit_flag) { - break; - } - - try { - executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ - false /*create_local_scope*/, false /*create_vars*/); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - rpc_service_->SetCond(1); - rpc_service_->WaitClientGet(recv_var_cnt); - grads_counter_.clear(); - } // while(true) - } - - protected: - std::shared_ptr rpc_service_; - std::shared_ptr server_thread_; - mutable std::unordered_map grads_counter_; + private: + mutable detail::RPCClient client_; }; class RecvOpMaker : public framework::OpProtoAndCheckerMaker { public: - RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker) + RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); AddComment(R"DOC( Recv operator -This operator will recieve tensor from send_op +This operator can get variables from server side. )DOC"); - AddAttr("endpoint", - "(string, default 127.0.0.1:6164)" - "IP address to listen on.") - .SetDefault("127.0.0.1:6164") - .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); - AddAttr( - kOptimizeBlock, "Serialized ProgramDesc string for recv to run."); - AddAttr>( - "ParamList", "type list of string", - "grad->param name mapping to find which parameters to optimize.") - .SetDefault({}); - AddAttr>( - "GradList", "type list of string", - "grad->param name mapping to find which parameters to optimize.") + AddAttr>("epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input " + "variables for mapping") .SetDefault({}); - AddAttr("Fanin", "type int", - "Number of trainers in the current cluster job") - .SetDefault(1); } }; diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index be41b527f2289d5d657a58f3cb6d7be725323cd0..ee0f268b0e4dfa23bf878d71404d47553183a977 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -62,11 +62,13 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(rpc_client->Wait()); - for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + if (outs.size() > 0) { + for (size_t i = 0; i < outs.size(); i++) { + VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; + rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); } }; @@ -85,6 +87,8 @@ Send operator This operator will send tensor to recv_op at the parameter server. )DOC"); + // TODO(typhoonzero): remove this attr generate de-duplicated vector from + // epmap when initializing. AddAttr>("endpoints", "(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index 045a0f5434f339bab345d14881ed05450ce6588d..31527a906d56da54d2571910de627757d708a996 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/string/printf.h" USE_NO_KERNEL_OP(send); -USE_NO_KERNEL_OP(recv); +USE_NO_KERNEL_OP(listen_and_serv); USE_OP(sum); namespace f = paddle::framework; @@ -33,7 +33,7 @@ namespace p = paddle::platform; namespace m = paddle::operators::math; // global for simplicity. -std::unique_ptr recv_op; +std::unique_ptr listen_and_serv_op; void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { p::CPUDeviceContext ctx(place); @@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) { InitTensorsInScope(scope, place); } - // sub program run in recv_op, for simple test we use sum + // sub program run in listen_and_serv_op, for simple test we use sum f::ProgramDesc program; f::BlockDesc *block = program.MutableBlock(0); // X for server side tensors, RX for received tensers, must be of same shape. @@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) { attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); attrs.insert({"OptimizeBlock", block}); - recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs); - recv_op->Run(scope, place); + listen_and_serv_op = + f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs); + listen_and_serv_op->Run(scope, place); } TEST(SendRecvOp, CPUDense) { @@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) { for (int64_t i = 0; i < target->numel(); ++i) { EXPECT_EQ(expected[i] * 2, actual[i]); } - recv_op->Stop(); + listen_and_serv_op->Stop(); server_thread.join(); - recv_op.reset(nullptr); + listen_and_serv_op.reset(nullptr); } TEST(SendRecvOp, CPUSparse) { @@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) { EXPECT_EQ(expect_value->mutable_data(place)[i], actual->mutable_data(place)[i]); } - recv_op->Stop(); + listen_and_serv_op->Stop(); server_thread.join(); - recv_op.reset(); + listen_and_serv_op.reset(); } diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 2fdd25dbbe68659f8a0a9da13a87148ed259127a..a744ebd61595403ee495a2e2c9e84181422e92ff 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -53,6 +53,8 @@ class WhileOp : public framework::OperatorBase { auto step_scopes = scope.FindVar(Output(kStepScopes))->GetMutable(); + PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), + "Condition of while op must in CPU memory."); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); @@ -99,6 +101,9 @@ class WhileGradOp : public framework::OperatorBase { void Run(const framework::Scope &scope, const platform::Place &dev_place) const override { + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); auto *program = block->Program(); @@ -205,6 +210,8 @@ class WhileGradOp : public framework::OperatorBase { sum_op->Run(cur_scope, dev_place); cur_scope.Rename(new_inside_name, inside_grad_name); } + dev_ctx.Wait(); + const_cast(scope).DeleteScope(&cur_scope); } } }; diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 8369ded8cb9ab3dc43fc5690e67af518fc5d011a..df7310d6b70ac95953177024a7c2981d1c81a901 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -32,7 +32,7 @@ function cmake_gen() { cat <(new_argv.size()); char** new_argv_address = new_argv.data(); diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index f52346c3b59264370f46844d0e6b1e2d489299c7..3ee58393c72c0b6f9bec96be51ad3946752a35dd 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -76,7 +76,9 @@ def __bootstrap__(): os.environ['OMP_NUM_THREADS'] = str(num_threads) - read_env_flags = ['use_pinned_memory', 'check_nan_inf', 'benchmark'] + read_env_flags = [ + 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir' + ] if core.is_compiled_with_cuda(): read_env_flags += ['fraction_of_gpu_memory_to_use'] core.init_gflags([sys.argv[0]] + diff --git a/python/paddle/v2/fluid/debuger.py b/python/paddle/v2/fluid/debuger.py new file mode 100644 index 0000000000000000000000000000000000000000..d37935244291d15da8199bd3d3979118e4944b48 --- /dev/null +++ b/python/paddle/v2/fluid/debuger.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from graphviz import GraphPreviewGenerator +import proto.framework_pb2 as framework_pb2 + + +def draw_block_graphviz(block, highlights=None, path="./temp.dot"): + ''' + Generate a debug graph for block. + Args: + block(Block): a block. + ''' + graph = GraphPreviewGenerator("some graph") + # collect parameters and args + protostr = block.desc.serialize_to_string() + desc = framework_pb2.BlockDesc.FromString(str(protostr)) + + def need_highlight(name): + if highlights is None: return False + for pattern in highlights: + assert type(pattern) is str + if re.match(pattern, name): + return True + return False + + # draw parameters and args + vars = {} + for var in desc.vars: + shape = [str(i) for i in var.lod_tensor.tensor.dims] + if not shape: + shape = ['null'] + # create var + if var.persistable: + varn = graph.add_param( + var.name, var.type, shape, highlight=need_highlight(var.name)) + else: + varn = graph.add_arg(var.name, highlight=need_highlight(var.name)) + vars[var.name] = varn + + def add_op_link_var(op, var, op2var=False): + for arg in var.arguments: + if arg not in vars: + # add missing variables as argument + vars[arg] = graph.add_arg(arg, highlight=need_highlight(arg)) + varn = vars[arg] + highlight = need_highlight(op.description) or need_highlight( + varn.description) + if op2var: + graph.add_edge(op, varn, highlight=highlight) + else: + graph.add_edge(varn, op, highlight=highlight) + + for op in desc.ops: + opn = graph.add_op(op.type, highlight=need_highlight(op.type)) + for var in op.inputs: + add_op_link_var(opn, var, False) + for var in op.outputs: + add_op_link_var(opn, var, True) + + graph(path, show=True) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index a4464a281aae714d79a531ec8a2cf793d6330a12..121b407cae41fa477843b7252ebacc9053d5f7aa 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -478,9 +478,9 @@ class DistributeTranspiler: else: self._append_pserver_non_opt_ops(optimize_sub_program, pserver_program, opt_op) - # Append the recv op + # Append the listen_and_serv op pserver_program.global_block().append_op( - type="recv", + type="listen_and_serv", inputs={}, outputs={}, attrs={ diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 8bf545e2ecc3939b00ba25d003a6b3887a54f860..69cbebe41e57309bb0993148833836b715e417ce 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -451,9 +451,8 @@ class Operator(object): if not given == need: raise ValueError(("Incorrect setting for output(s) of " "operator \"%s\". Need: [%s] Given: [%s]") % - (type, ", ".join(str(e) - for e in need), ", ".join( - str(e) for e in given))) + (type, ", ".join(str(e) for e in need), + ", ".join(str(e) for e in given))) for out_proto in proto.outputs: out_args = outputs[out_proto.name] @@ -489,7 +488,8 @@ class Operator(object): no_kernel_op_set = { 'feed', 'fetch', 'save', 'load', 'recurrent', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', - 'recv', 'parallel_do' + 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', + 'load_combine' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/v2/fluid/graphviz.py b/python/paddle/v2/fluid/graphviz.py new file mode 100644 index 0000000000000000000000000000000000000000..5881119c39231282b5654cd60720a1d8a7877896 --- /dev/null +++ b/python/paddle/v2/fluid/graphviz.py @@ -0,0 +1,272 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import subprocess +import logging + + +def crepr(v): + if type(v) is str or type(v) is unicode: + return '"%s"' % v + return str(v) + + +class Rank(object): + def __init__(self, kind, name, priority): + ''' + kind: str + name: str + priority: int + ''' + self.kind = kind + self.name = name + self.priority = priority + self.nodes = [] + + def __str__(self): + if not self.nodes: + return '' + + return '{' + 'rank={};'.format(self.kind) + \ + ','.join([node.name for node in self.nodes]) + '}' + + +class Graph(object): + rank_counter = 0 + + def __init__(self, title, **attrs): + self.title = title + self.attrs = attrs + self.nodes = [] + self.edges = [] + self.rank_groups = {} + + def code(self): + return self.__str__() + + def rank_group(self, kind, priority): + name = "rankgroup-%d" % Graph.rank_counter + Graph.rank_counter += 1 + rank = Rank(kind, name, priority) + self.rank_groups[name] = rank + return name + + def node(self, label, prefix, description="", **attrs): + node = Node(label, prefix, description, **attrs) + + if 'rank' in attrs: + rank = self.rank_groups[attrs['rank']] + del attrs['rank'] + rank.nodes.append(node) + self.nodes.append(node) + return node + + def edge(self, source, target, **attrs): + edge = Edge(source, target, **attrs) + self.edges.append(edge) + return edge + + def compile(self, dot_path): + file = open(dot_path, 'w') + file.write(self.__str__()) + image_path = os.path.join( + os.path.dirname(__file__), dot_path[:-3] + "pdf") + cmd = ["dot", "-Tpdf", dot_path, "-o", image_path] + subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + logging.warning("write block debug graph to {}".format(image_path)) + return image_path + + def show(self, dot_path): + image = self.compile(dot_path) + cmd = ["open", image] + subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + def _rank_repr(self): + ranks = sorted( + self.rank_groups.items(), + cmp=lambda a, b: a[1].priority > b[1].priority) + repr = [] + for x in ranks: + repr.append(str(x[1])) + return '\n'.join(repr) + '\n' + + def __str__(self): + reprs = [ + 'digraph G {', + 'title = {}'.format(crepr(self.title)), + ] + + for attr in self.attrs: + reprs.append("{key}={value};".format( + key=attr, value=crepr(self.attrs[attr]))) + + reprs.append(self._rank_repr()) + + random.shuffle(self.nodes) + reprs += [str(node) for node in self.nodes] + + for x in self.edges: + reprs.append(str(x)) + + reprs.append('}') + return '\n'.join(reprs) + + +class Node(object): + counter = 1 + + def __init__(self, label, prefix, description="", **attrs): + self.label = label + self.name = "%s_%d" % (prefix, Node.counter) + self.description = description + self.attrs = attrs + Node.counter += 1 + + def __str__(self): + reprs = '{name} [label={label} {extra} ];'.format( + name=self.name, + label=self.label, + extra=',' + ','.join("%s=%s" % (key, crepr(value)) + for key, value in self.attrs.items()) + if self.attrs else "") + return reprs + + +class Edge(object): + def __init__(self, source, target, **attrs): + ''' + Link source to target. + :param source: Node + :param target: Node + :param graph: Graph + :param attrs: dic + ''' + self.source = source + self.target = target + self.attrs = attrs + + def __str__(self): + repr = "{source} -> {target} {extra}".format( + source=self.source.name, + target=self.target.name, + extra="" if not self.attrs else + "[" + ','.join("{}={}".format(attr[0], crepr(attr[1])) + for attr in self.attrs.items()) + "]") + return repr + + +class GraphPreviewGenerator(object): + ''' + Generate a graph image for ONNX proto. + ''' + + def __init__(self, title): + # init graphviz graph + self.graph = Graph( + title, + layout="dot", + concentrate="true", + rankdir="TB", ) + + self.op_rank = self.graph.rank_group('same', 2) + self.param_rank = self.graph.rank_group('same', 1) + self.arg_rank = self.graph.rank_group('same', 0) + + def __call__(self, path='temp.dot', show=False): + if not show: + self.graph.compile(path) + else: + self.graph.show(path) + + def add_param(self, name, data_type, shape, highlight=False): + label = '\n'.join([ + '<', + ' ', + ' ', + ' ', + ' ', + ' ' + ' ', + ' ', + ' ' + ' ', + '
', + ' ', + name, + ' ', + '
', + str(data_type), + '
', + '[%s]' % 'x'.join(shape), + '
>', + ]) + return self.graph.node( + label, + prefix="param", + description=name, + shape="none", + style="rounded,filled,bold", + width="1.3", + color="#148b97" if not highlight else "orange", + fontcolor="#ffffff", + fontname="Arial") + + def add_op(self, opType, **kwargs): + highlight = False + if 'highlight' in kwargs: + highlight = kwargs['highlight'] + del kwargs['highlight'] + return self.graph.node( + "<%s>" % opType, + prefix="op", + description=opType, + shape="box", + style="rounded, filled, bold", + color="#303A3A" if not highlight else "orange", + fontname="Arial", + fontcolor="#ffffff", + width="1.3", + height="0.84", ) + + def add_arg(self, name, highlight=False): + return self.graph.node( + crepr(name), + prefix="arg", + description=name, + shape="box", + style="rounded,filled,bold", + fontname="Arial", + fontcolor="#999999", + color="#dddddd" if not highlight else "orange") + + def add_edge(self, source, target, **kwargs): + highlight = False + if 'highlight' in kwargs: + highlight = kwargs['highlight'] + del kwargs['highlight'] + return self.graph.edge( + source, + target, + color="#00000" if not highlight else "orange", + **kwargs) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index d56ec45c538b580f5520bc060b4b339bb1be0539..613dc20b6ea5533d126a73b7ec47796b3f812db5 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -46,6 +46,9 @@ def is_parameter(var): def is_persistable(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST: + return False return var.persistable @@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var): persistable=True) -def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): +def save_vars(executor, + dirname, + main_program=None, + vars=None, + predicate=None, + save_file_name=None): """ Save variables to directory by executor. @@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default default_main_program. :param predicate: The Predicate describes a callable that returns a variable - as a bool. If it returns true, the variables will be saved. - :param vars: variables need to be saved. If specify vars, program & predicate + as a bool. If it returns true, the corresponding input variable will be saved. + :param vars: variables need to be saved. If vars is specified, program & predicate will be ignored + :param save_file_name: The name of a single file that all vars are saved to. + If it is None, save variables to separate files. + :return: None """ if vars is None: @@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): save_vars( executor, dirname=dirname, - vars=filter(predicate, main_program.list_vars())) + vars=filter(predicate, main_program.list_vars()), + save_file_name=save_file_name) else: save_program = Program() save_block = save_program.global_block() + + save_var_map = {} for each_var in vars: new_var = _clone_var_in_block_(save_block, each_var) + if save_file_name is None: + save_block.append_op( + type='save', + inputs={'X': [new_var]}, + outputs={}, + attrs={'file_path': os.path.join(dirname, new_var.name)}) + else: + save_var_map[new_var.name] = new_var + + if save_file_name is not None: + save_var_list = [] + for name in sorted(save_var_map.keys()): + save_var_list.append(save_var_map[name]) + save_block.append_op( - type='save', - inputs={'X': [new_var]}, + type='save_combine', + inputs={'X': save_var_list}, outputs={}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + attrs={'file_path': os.path.join(dirname, save_file_name)}) + executor.run(save_program) -def save_params(executor, dirname, main_program=None): +def save_params(executor, dirname, main_program=None, save_file_name=None): """ Save all parameters to directory with executor. """ @@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None): dirname=dirname, main_program=main_program, vars=None, - predicate=is_parameter) + predicate=is_parameter, + save_file_name=save_file_name) -def save_persistables(executor, dirname, main_program=None): +def save_persistables(executor, dirname, main_program=None, + save_file_name=None): """ Save all persistables to directory with executor. """ @@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None): dirname=dirname, main_program=main_program, vars=None, - predicate=is_persistable) + predicate=is_persistable, + save_file_name=save_file_name) -def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): +def load_vars(executor, + dirname, + main_program=None, + vars=None, + predicate=None, + load_file_name=None): """ Load variables from directory by executor. - :param executor: executor that save variable + :param executor: executor that load variable :param dirname: directory path :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default default_main_program(). :param predicate: The Predicate describes a callable that returns a variable - as a bool. If it returns true, the variables will be loaded. - :param vars: variables need to be loaded. If specify vars, program & + as a bool. If it returns true, the corresponding input variable will be loaded. + :param vars: variables need to be loaded. If vars is specified, program & predicate will be ignored + :param load_file_name: The name of the single file that all vars are loaded from. + If it is None, load variables from separate files. + :return: None """ if vars is None: @@ -144,23 +184,40 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): load_vars( executor, dirname=dirname, - vars=filter(predicate, main_program.list_vars())) + vars=filter(predicate, main_program.list_vars()), + load_file_name=load_file_name) else: load_prog = Program() load_block = load_prog.global_block() + + load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) new_var = _clone_var_in_block_(load_block, each_var) + if load_file_name is None: + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [new_var]}, + attrs={'file_path': os.path.join(dirname, new_var.name)}) + else: + load_var_map[new_var.name] = new_var + + if load_file_name is not None: + load_var_list = [] + for name in sorted(load_var_map.keys()): + load_var_list.append(load_var_map[name]) + load_block.append_op( - type='load', + type='load_combine', inputs={}, - outputs={"Out": [new_var]}, - attrs={'file_path': os.path.join(dirname, new_var.name)}) + outputs={"Out": load_var_list}, + attrs={'file_path': os.path.join(dirname, load_file_name)}) executor.run(load_prog) -def load_params(executor, dirname, main_program=None): +def load_params(executor, dirname, main_program=None, load_file_name=None): """ load all parameters from directory by executor. """ @@ -168,10 +225,12 @@ def load_params(executor, dirname, main_program=None): executor, dirname=dirname, main_program=main_program, - predicate=is_parameter) + predicate=is_parameter, + load_file_name=load_file_name) -def load_persistables(executor, dirname, main_program=None): +def load_persistables(executor, dirname, main_program=None, + load_file_name=None): """ load all persistables from directory by executor. """ @@ -179,7 +238,8 @@ def load_persistables(executor, dirname, main_program=None): executor, dirname=dirname, main_program=main_program, - predicate=is_persistable) + predicate=is_persistable, + load_file_name=load_file_name) def get_inference_program(target_vars, main_program=None): @@ -238,7 +298,8 @@ def save_inference_model(dirname, feeded_var_names, target_vars, executor, - main_program=None): + main_program=None, + save_file_name=None): """ Build a model especially for inference, and save it to directory by the executor. @@ -249,6 +310,8 @@ def save_inference_model(dirname, :param executor: executor that save inference model :param main_program: original program, which will be pruned to build the inference model. Default default_main_program(). + :param save_file_name: The name of a single file that all parameters are saved to. + If it is None, save parameters to separate files. :return: None """ @@ -283,25 +346,7 @@ def save_inference_model(dirname, with open(model_file_name, "wb") as f: f.write(inference_program.desc.serialize_to_string()) - save_params(executor, dirname, main_program) - - -def load_persistables_if_exist(executor, dirname, main_program=None): - filenames = next(os.walk(dirname))[2] - filenames = set(filenames) - - def _is_presistable_and_exist_(var): - if not is_persistable(var): - return False - else: - return var.name in filenames - - load_vars( - executor, - dirname, - main_program=main_program, - vars=None, - predicate=_is_presistable_and_exist_) + save_persistables(executor, dirname, inference_program, save_file_name) def get_feed_targets_names(program): @@ -322,13 +367,15 @@ def get_fetch_targets_names(program): return fetch_targets_names -def load_inference_model(dirname, executor): +def load_inference_model(dirname, executor, load_file_name=None): """ Load inference model from a directory :param dirname: directory path :param executor: executor that load inference model - + :param load_file_name: The name of the single file that all parameters are loaded from. + If it is None, load parameters from separate files. + :return: [program, feed_target_names, fetch_targets] program: program especially for inference. feed_target_names: Names of variables that need to feed data @@ -342,7 +389,7 @@ def load_inference_model(dirname, executor): program_desc_str = f.read() program = Program.parse_from_string(program_desc_str) - load_persistables_if_exist(executor, dirname, program) + load_persistables(executor, dirname, program, load_file_name) feed_target_names = get_feed_targets_names(program) fetch_target_names = get_fetch_targets_names(program) @@ -359,6 +406,7 @@ def get_parameter_value(para, executor): :param executor: executor for retrieving the value :param para: the given parameter + :return: the LoDTensor for the parameter """ assert is_parameter(para) @@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None): :param name: the name of the parameter :param program: the program where the variable is found Default default_main_program(). + :return: the LoDTensor for the variable """ if program is None: diff --git a/python/paddle/v2/fluid/layers/io.py b/python/paddle/v2/fluid/layers/io.py index b7b2cf2296cc8868dd0b5eb6cd6d58b9ae795d5d..85e44a0e5149bd36f2787d9f2d516dbe4abdbb2e 100644 --- a/python/paddle/v2/fluid/layers/io.py +++ b/python/paddle/v2/fluid/layers/io.py @@ -108,7 +108,7 @@ class ListenAndServ(object): """ def __init__(self, endpoint, fan_in=1, optimizer_mode=True): - self.helper = LayerHelper("recv") + self.helper = LayerHelper("listen_and_serv") self.inputs = [] self.outputs = [] self.endpoint = endpoint @@ -158,7 +158,7 @@ class ListenAndServ(object): param_names = [p.name for p in params] grad_names = [g.name for g in grads] parent_block.append_op( - type='recv', + type='listen_and_serv', inputs={}, outputs={}, attrs={ @@ -196,3 +196,31 @@ def Send(endpoints, send_vars, get_vars): outputs={"Out": get_vars}, attrs={"endpoints": endpoints, "epmap": epmap}) + + +def Recv(endpoints, get_vars): + """ + Recv layer + + Args: + endpoints: comma seperated IP:PORT pairs in the order + of send_vars to send + send_vars: vars to send + get_vars: vars to get from server after send completes. + + Send variables to the server side, and get vars from server + side when server have finished running server side program. + """ + assert (type(send_vars) == list) + assert (type(get_vars) == list) + + epmap = endpoints.split(",") + endpoints = list(set(epmap)) + + helper = LayerHelper("Recv", **locals()) + helper.append_op( + type="recv", + inputs={"X": get_vars}, + outputs={"Out": get_vars}, + attrs={"endpoints": endpoints, + "epmap": epmap}) diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index ee3172c7b8dfd65c693e5aee9b55179e654ce7be..c701e79ad266d996038c6868718106664e1009b5 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -59,6 +59,7 @@ __all__ = [ 'elementwise_pow', 'clip', 'clip_by_norm', + 'softmax', 'sequence_softmax', ] + __activations__ diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index c435c5206d1ef1ef57683a1a47bf089be6526f38..8460af2a08a54c5f241553757296ed2bf02f0167 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -295,7 +295,7 @@ def fill_constant_batch_size_like(input, return out -def ones(shape, dtype): +def ones(shape, dtype, force_cpu=False): """ **ones** @@ -319,7 +319,7 @@ def ones(shape, dtype): return fill_constant(value=1.0, **locals()) -def zeros(shape, dtype): +def zeros(shape, dtype, force_cpu=False): """ **zeros** diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py index 956c5b66da28fd8e74d4fd12f249688daa72d8ac..2b00923f5e85e6ba8fcdedebf5bbbc29403472c6 100644 --- a/python/paddle/v2/fluid/memory_optimization_transpiler.py +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -31,7 +31,7 @@ dtype_to_size = { class ControlFlowGraph(object): - def __init__(self, Program, ops, forward_num): + def __init__(self, Program, ops, forward_num, skip_opt): self._program = Program self._ops = ops self._forward_num = forward_num @@ -41,6 +41,7 @@ class ControlFlowGraph(object): self._defs = defaultdict(set) self._live_in = defaultdict(set) self._live_out = defaultdict(set) + self._skip_opt = skip_opt def _add_connections(self, connections): for node1, node2 in connections: @@ -130,6 +131,10 @@ class ControlFlowGraph(object): block_desc, x, is_forward).type() != core.VarDesc.VarType.LOD_TENSOR: return False + if x in self._skip_opt: + return False + if not self._find_var(block_desc, x, is_forward).shape(): + return False return True self._build_graph() @@ -140,6 +145,7 @@ class ControlFlowGraph(object): if op.type() == "while" or op.type() == "while_grad": continue block_desc = op.block() + self.current_block_desc = block_desc is_forward = i < self._forward_num if self.pool: defs_can_optimize = filter( @@ -197,28 +203,32 @@ def get_cfgs(input_program): block_desc = pdesc.block(0) op_size = block_desc.op_size() # Get global block ops - ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size)) + ops_list.append( + ([block_desc.op(i) for i in range(op_size)], op_size, set())) while_sub_block_ids = [] while_grad_sub_block_ids = [] - while_pair = [] + while_op_output = set() + while_block_id_pair = [] for i in range(op_size): op = block_desc.op(i) if op.type() == "while": while_sub_block_ids.append(op.attr("sub_block").id) + while_op_output.update(op.output_arg_names()) elif op.type() == "while_grad": while_grad_sub_block_ids.append(op.attr("sub_block").id) + while_op_output.update(op.output_arg_names()) # Find while/while_grad block pair for grad_id in while_grad_sub_block_ids: parent_id = pdesc.block(grad_id).parent if parent_id in while_sub_block_ids: - while_pair.append((parent_id, grad_id)) + while_block_id_pair.append((parent_id, grad_id)) while_sub_block_ids.remove(parent_id) # Get while/while_grad block ops - for parent_id, grad_id in while_pair: + for parent_id, grad_id in while_block_id_pair: while_block_ops = [] while_block = pdesc.block(parent_id) while_block_op_size = while_block.op_size() @@ -230,7 +240,7 @@ def get_cfgs(input_program): for i in range(while_grad_block_op_size): while_block_ops.append(while_grad_block.op(i)) - ops_list.append((while_block_ops, while_block_op_size)) + ops_list.append((while_block_ops, while_block_op_size, while_op_output)) # Process rest while block ops for parent_id in while_sub_block_ids: @@ -242,7 +252,7 @@ def get_cfgs(input_program): ops_list.append((while_block_ops, while_block_op_size)) - cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list] + cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list] return cfgs diff --git a/python/paddle/v2/fluid/tests/CMakeLists.txt b/python/paddle/v2/fluid/tests/CMakeLists.txt index 628ce60b406d880d961d705a6abd2b5236fb1c8c..26a80abcb5839e80b5a22f9415315519ce3042e8 100644 --- a/python/paddle/v2/fluid/tests/CMakeLists.txt +++ b/python/paddle/v2/fluid/tests/CMakeLists.txt @@ -5,9 +5,11 @@ if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_recv_op) endif(NOT WITH_DISTRIBUTE) +list(REMOVE_ITEM TEST_OPS test_warpctc_op) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() +py_test(test_warpctc_op SRCS test_warpctc_op.py ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR}) add_subdirectory(book) add_subdirectory(book_distribute) diff --git a/python/paddle/v2/fluid/tests/book/.gitignore b/python/paddle/v2/fluid/tests/book/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f0b574b9396706a1d68393482296360362dca750 --- /dev/null +++ b/python/paddle/v2/fluid/tests/book/.gitignore @@ -0,0 +1 @@ +recognize_digits_*.inference.model diff --git a/python/paddle/v2/fluid/tests/book/CMakeLists.txt b/python/paddle/v2/fluid/tests/book/CMakeLists.txt index dda02c03fd531445c1b33b39a6ded10921991d9c..673c965b662a022739f8d489c331f4de9455a926 100644 --- a/python/paddle/v2/fluid/tests/book/CMakeLists.txt +++ b/python/paddle/v2/fluid/tests/book/CMakeLists.txt @@ -1,34 +1,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS test_image_classification_train test_recognize_digits) -py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet) -py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg) -py_test(test_recognize_digits_mlp_cpu - SRCS test_recognize_digits.py - ARGS mlp) -py_test(test_recognize_digits_mlp_cuda - SRCS test_recognize_digits.py - ARGS mlp --use_cuda) -py_test(test_recognize_digits_conv_cpu - SRCS test_recognize_digits.py - ARGS conv) -py_test(test_recognize_digits_conv_cuda - SRCS test_recognize_digits.py - ARGS conv --use_cuda) -py_test(test_recognize_digits_mlp_cpu_parallel - SRCS test_recognize_digits.py - ARGS mlp --parallel) -py_test(test_recognize_digits_mlp_cuda_parallel - SRCS test_recognize_digits.py - ARGS mlp --use_cuda --parallel) -py_test(test_recognize_digits_conv_cpu_parallel - SRCS test_recognize_digits.py - ARGS conv --parallel) -py_test(test_recognize_digits_conv_cuda_parallel - SRCS test_recognize_digits.py - ARGS conv --use_cuda --parallel) - # default test foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) diff --git a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py index 0b954c60b6bc2d721c0373243e747056f8f572cf..27f34b17339db31ef3c07555db946fa76d6f1922 100644 --- a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py @@ -12,44 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid +import contextlib +import unittest -x = fluid.layers.data(name='x', shape=[13], dtype='float32') -y_predict = fluid.layers.fc(input=x, size=1, act=None) +def main(use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return -y = fluid.layers.data(name='y', shape=[1], dtype='float32') + x = fluid.layers.data(name='x', shape=[13], dtype='float32') -cost = fluid.layers.square_error_cost(input=y_predict, label=y) -avg_cost = fluid.layers.mean(x=cost) + y_predict = fluid.layers.fc(input=x, size=1, act=None) -sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) -sgd_optimizer.minimize(avg_cost) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') -BATCH_SIZE = 20 + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(x=cost) -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.uci_housing.train(), buf_size=500), - batch_size=BATCH_SIZE) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer.minimize(avg_cost) -place = fluid.CPUPlace() -feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) -exe = fluid.Executor(place) + BATCH_SIZE = 20 -exe.run(fluid.default_startup_program()) + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.uci_housing.train(), buf_size=500), + batch_size=BATCH_SIZE) -PASS_NUM = 100 -for pass_id in range(PASS_NUM): - fluid.io.save_persistables(exe, "./fit_a_line.model/") - fluid.io.load_persistables(exe, "./fit_a_line.model/") - for data in train_reader(): - avg_loss_value, = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost]) - print(avg_loss_value) - if avg_loss_value[0] < 10.0: - exit(0) # if avg cost less than 10.0, we think our code is good. -exit(1) + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) + exe = fluid.Executor(place) + + exe.run(fluid.default_startup_program()) + + PASS_NUM = 100 + for pass_id in range(PASS_NUM): + fluid.io.save_persistables(exe, "./fit_a_line.model/") + fluid.io.load_persistables(exe, "./fit_a_line.model/") + for data in train_reader(): + avg_loss_value, = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost]) + print(avg_loss_value) + if avg_loss_value[0] < 10.0: + return + raise AssertionError("Fit a line cost is too large, {0:2.2}".format( + avg_loss_value[0])) + + +class TestFitALine(unittest.TestCase): + def test_cpu(self): + with self.program_scope_guard(): + main(use_cuda=False) + + def test_cuda(self): + with self.program_scope_guard(): + main(use_cuda=True) + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index 30582a21d0a5eeab125f3a2764b45b51aa4f94b6..a4168d16db06f904faed811fdda3f0fe52f0b27b 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -14,10 +14,10 @@ from __future__ import print_function -import sys - import paddle.v2 as paddle import paddle.v2.fluid as fluid +import unittest +import contextlib def resnet_cifar10(input, depth=32): @@ -89,56 +89,89 @@ def vgg16_bn_drop(input): return fc2 -classdim = 10 -data_shape = [3, 32, 32] - -images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') -label = fluid.layers.data(name='label', shape=[1], dtype='int64') - -net_type = "vgg" -if len(sys.argv) >= 2: - net_type = sys.argv[1] - -if net_type == "vgg": - print("train vgg net") - net = vgg16_bn_drop(images) -elif net_type == "resnet": - print("train resnet") - net = resnet_cifar10(images, 32) -else: - raise ValueError("%s network is not supported" % net_type) - -predict = fluid.layers.fc(input=net, size=classdim, act='softmax') -cost = fluid.layers.cross_entropy(input=predict, label=label) -avg_cost = fluid.layers.mean(x=cost) - -optimizer = fluid.optimizer.Adam(learning_rate=0.001) -opts = optimizer.minimize(avg_cost) - -accuracy = fluid.evaluator.Accuracy(input=predict, label=label) - -BATCH_SIZE = 128 -PASS_NUM = 1 - -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.cifar.train10(), buf_size=128 * 10), - batch_size=BATCH_SIZE) - -place = fluid.CPUPlace() -exe = fluid.Executor(place) -feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) -exe.run(fluid.default_startup_program()) - -for pass_id in range(PASS_NUM): - accuracy.reset(exe) - for data in train_reader(): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) - print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( - pass_acc)) - # this model is slow, so if we can train two mini batch, we think it works properly. - exit(0) -exit(1) +def main(net_type, use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + + classdim = 10 + data_shape = [3, 32, 32] + + images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + if net_type == "vgg": + print("train vgg net") + net = vgg16_bn_drop(images) + elif net_type == "resnet": + print("train resnet") + net = resnet_cifar10(images, 32) + else: + raise ValueError("%s network is not supported" % net_type) + + predict = fluid.layers.fc(input=net, size=classdim, act='softmax') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Adam(learning_rate=0.001) + optimizer.minimize(avg_cost) + + accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + + BATCH_SIZE = 128 + PASS_NUM = 1 + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.cifar.train10(), buf_size=128 * 10), + batch_size=BATCH_SIZE) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) + exe.run(fluid.default_startup_program()) + + loss = 0.0 + for pass_id in range(PASS_NUM): + accuracy.reset(exe) + for data in train_reader(): + loss, acc = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) + pass_acc = accuracy.eval(exe) + print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( + pass_acc)) + return + + raise AssertionError( + "Image classification loss is too large, {0:2.2}".format(loss)) + + +class TestImageClassification(unittest.TestCase): + def test_vgg_cuda(self): + with self.scope_prog_guard(): + main('vgg', use_cuda=True) + + def test_resnet_cuda(self): + with self.scope_prog_guard(): + main('resnet', use_cuda=True) + + def test_vgg_cpu(self): + with self.scope_prog_guard(): + main('vgg', use_cuda=False) + + def test_resnet_cpu(self): + with self.scope_prog_guard(): + main('resnet', use_cuda=False) + + @contextlib.contextmanager + def scope_prog_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/book/test_machine_translation.py b/python/paddle/v2/fluid/tests/book/test_machine_translation.py index 82b760d693560dae1ab1fa39afdc186f60423e65..5716ddd3dda90958ad1008679e018542c4fb73d7 100644 --- a/python/paddle/v2/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/v2/fluid/tests/book/test_machine_translation.py @@ -11,21 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid -import paddle.v2.fluid.core as core import paddle.v2.fluid.framework as framework import paddle.v2.fluid.layers as pd from paddle.v2.fluid.executor import Executor +import unittest dict_size = 30000 source_dict_dim = target_dict_dim = dict_size -src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size) hidden_dim = 32 word_dim = 16 -IS_SPARSE = True batch_size = 2 max_length = 8 topk_size = 50 @@ -34,10 +33,8 @@ beam_size = 2 decoder_size = hidden_dim -place = core.CPUPlace() - -def encoder(): +def encoder(is_sparse): # encoder src_word_id = pd.data( name="src_word_id", shape=[1], dtype='int64', lod_level=1) @@ -45,7 +42,7 @@ def encoder(): input=src_word_id, size=[dict_size, word_dim], dtype='float32', - is_sparse=IS_SPARSE, + is_sparse=is_sparse, param_attr=fluid.ParamAttr(name='vemb')) fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh') @@ -54,7 +51,7 @@ def encoder(): return encoder_out -def decoder_train(context): +def decoder_train(context, is_sparse): # decoder trg_language_word = pd.data( name="target_language_word", shape=[1], dtype='int64', lod_level=1) @@ -62,7 +59,7 @@ def decoder_train(context): input=trg_language_word, size=[dict_size, word_dim], dtype='float32', - is_sparse=IS_SPARSE, + is_sparse=is_sparse, param_attr=fluid.ParamAttr(name='vemb')) rnn = pd.DynamicRNN() @@ -82,10 +79,10 @@ def decoder_train(context): return rnn() -def decoder_decode(context): +def decoder_decode(context, is_sparse): init_state = context array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length) - counter = pd.zeros(shape=[1], dtype='int64') + counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True) # fill the first element with init_state state_array = pd.create_array('float32') @@ -117,7 +114,7 @@ def decoder_decode(context): input=pre_ids, size=[dict_size, word_dim], dtype='float32', - is_sparse=IS_SPARSE) + is_sparse=is_sparse) # use rnn unit to update rnn current_state = pd.fc(input=[pre_ids_emb, pre_state_expanded], @@ -150,7 +147,7 @@ def decoder_decode(context): def set_init_lod(data, lod, place): - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(data, place) res.set_lod(lod) return res @@ -165,15 +162,19 @@ def to_lodtensor(data, place): lod.append(cur_len) flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(flattened_data, place) res.set_lod([lod]) return res -def train_main(): - context = encoder() - rnn_out = decoder_train(context) +def train_main(use_cuda, is_sparse): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + context = encoder(is_sparse) + rnn_out = decoder_train(context, is_sparse) label = pd.data( name="target_language_next_word", shape=[1], dtype='int64', lod_level=1) cost = pd.cross_entropy(input=rnn_out, label=label) @@ -212,9 +213,13 @@ def train_main(): batch_id += 1 -def decode_main(): - context = encoder() - translation_ids, translation_scores = decoder_decode(context) +def decode_main(use_cuda, is_sparse): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + context = encoder(is_sparse) + translation_ids, translation_scores = decoder_decode(context, is_sparse) exe = Executor(place) exe.run(framework.default_startup_program()) @@ -250,6 +255,60 @@ def decode_main(): break +class TestMachineTranslation(unittest.TestCase): + pass + + +@contextlib.contextmanager +def scope_prog_guard(): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +def inject_test_train(use_cuda, is_sparse): + f_name = 'test_{0}_{1}_train'.format('cuda' if use_cuda else 'cpu', 'sparse' + if is_sparse else 'dense') + + def f(*args): + with scope_prog_guard(): + train_main(use_cuda, is_sparse) + + setattr(TestMachineTranslation, f_name, f) + + +def inject_test_decode(use_cuda, is_sparse, decorator=None): + f_name = 'test_{0}_{1}_decode'.format('cuda' + if use_cuda else 'cpu', 'sparse' + if is_sparse else 'dense') + + def f(*args): + with scope_prog_guard(): + decode_main(use_cuda, is_sparse) + + if decorator is not None: + f = decorator(f) + + setattr(TestMachineTranslation, f_name, f) + + +for _use_cuda_ in (False, True): + for _is_sparse_ in (False, True): + inject_test_train(_use_cuda_, _is_sparse_) + +for _use_cuda_ in (False, True): + for _is_sparse_ in (False, True): + + _decorator_ = None + if _use_cuda_: + _decorator_ = unittest.skip( + reason='Beam Search does not support CUDA!') + + inject_test_decode( + is_sparse=_is_sparse_, use_cuda=_use_cuda_, decorator=_decorator_) + if __name__ == '__main__': - # train_main() - decode_main() + unittest.main() diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py index b4b6020f58e7538dfe0f98c17d61f3614c3c6fc4..b8f55c813b6984a5a1d266acc1c159c45c23b665 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -17,6 +17,7 @@ import paddle.v2.fluid as fluid import paddle.v2 as paddle import sys import numpy +import unittest def parse_arg(): @@ -74,18 +75,18 @@ def conv_net(img, label): return loss_net(conv_pool_2, label) -def train(args, save_dirname=None): - print("recognize digits with args: {0}".format(" ".join(sys.argv[1:]))) - +def train(nn_type, use_cuda, parallel, save_dirname): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') - if args.nn_type == 'mlp': + if nn_type == 'mlp': net_conf = mlp else: net_conf = conv_net - if args.parallel: + if parallel: places = fluid.layers.get_places() pd = fluid.layers.ParallelDo(places) with pd.do(): @@ -107,7 +108,7 @@ def train(args, save_dirname=None): optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer.minimize(avg_loss) - place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) @@ -147,13 +148,14 @@ def train(args, save_dirname=None): 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'. format(pass_id, batch_id + 1, float(avg_loss_val), float(acc_val))) + raise AssertionError("Loss of recognize digits is too large") -def infer(args, save_dirname=None): +def infer(use_cuda, save_dirname=None): if save_dirname is None: return - place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) # Use fluid.io.load_inference_model to obtain the inference program desc, @@ -174,11 +176,48 @@ def infer(args, save_dirname=None): print("infer results: ", results[0]) -if __name__ == '__main__': - args = parse_arg() - if not args.use_cuda and not args.parallel: - save_dirname = "recognize_digits_" + args.nn_type + ".inference.model" +def main(use_cuda, parallel, nn_type): + if not use_cuda and not parallel: + save_dirname = "recognize_digits_" + nn_type + ".inference.model" else: save_dirname = None - train(args, save_dirname) - infer(args, save_dirname) + + train( + nn_type=nn_type, + use_cuda=use_cuda, + parallel=parallel, + save_dirname=save_dirname) + infer(use_cuda=use_cuda, save_dirname=save_dirname) + + +class TestRecognizeDigits(unittest.TestCase): + pass + + +def inject_test_method(use_cuda, parallel, nn_type): + def __impl__(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + main(use_cuda, parallel, nn_type) + + fn = 'test_{0}_{1}_{2}'.format(nn_type, 'cuda' + if use_cuda else 'cpu', 'parallel' + if parallel else 'normal') + + setattr(TestRecognizeDigits, fn, __impl__) + + +def inject_all_tests(): + for use_cuda in (False, True): + for parallel in (False, True): + for nn_type in ('mlp', 'conv'): + inject_test_method(use_cuda, parallel, nn_type) + + +inject_all_tests() + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 3f54e28defb76d3430a82e791578e20b84833f16..aea43c2517a02c72c1ee3307afdd3b21910f0064 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -223,6 +223,14 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.sequence_softmax(x=seq)) print(str(program)) + def test_softmax(self): + program = Program() + with program_guard(program): + data = layers.data(name='data', shape=[10], dtype='float32') + hid = layers.fc(input=data, size=20) + self.assertIsNotNone(layers.softmax(x=hid)) + print(str(program)) + def test_get_places(self): program = Program() with program_guard(program): diff --git a/python/paddle/v2/fluid/tests/test_recv_op.py b/python/paddle/v2/fluid/tests/test_recv_op.py index 5c4cec028d354b99d6203281ec4c727d7e3eceac..3a02b882410fe896cd2add03060127a01cbdaa38 100644 --- a/python/paddle/v2/fluid/tests/test_recv_op.py +++ b/python/paddle/v2/fluid/tests/test_recv_op.py @@ -19,6 +19,7 @@ import paddle.v2.fluid.layers as layers import numpy from multiprocessing import Process import os, sys +import time class TestRecvOp(unittest.TestCase): @@ -28,6 +29,7 @@ class TestRecvOp(unittest.TestCase): p = Process(target=self.init_serv, args=(place, )) p.daemon = True p.start() + time.sleep(1) self.init_client(place) # FIXME(typhoonzero): find a way to gracefully shutdown the server. os.system("kill -9 %d" % p.pid)