diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake
index 739a910c7c670b7b9f89e543582a32a80546fb11..df3f0c7f0c31efaa127515bb98e5668b8f9df199 100644
--- a/cmake/external/mklml.cmake
+++ b/cmake/external/mklml.cmake
@@ -34,7 +34,7 @@ SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
 SET(MKLML_DST_DIR "mklml")
 SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
 SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
-SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER})
+SET(MKLML_ROOT ${MKLML_INSTALL_DIR})
 SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
 SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
 SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
@@ -46,7 +46,7 @@ INCLUDE_DIRECTORIES(${MKLML_INC_DIR})
 FILE(WRITE ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt
 "PROJECT(MKLML)\n"
 "cmake_minimum_required(VERSION 3.0)\n"
- "install(DIRECTORY ${MKLML_VER}\n"
+ "install(DIRECTORY ${MKLML_VER}/include ${MKLML_VER}/lib \n"
 " DESTINATION ${MKLML_DST_DIR})\n")
 ExternalProject_Add(
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 6b2237b858380f384be0aa3c6ae24a4c83ad646d..0323cd9698cba916d2aa04403be97c0a6a463830 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -69,6 +69,12 @@ if(NOT CBLAS_FOUND)
 SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include
 DSTS ${dst_dir} ${dst_dir}
 )
+elseif (WITH_MKLML)
+  set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/mklml")
+  copy(mklml_lib
+    SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
+    DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}
+  )
 endif()
 # paddle fluid module
diff --git a/doc/design/parallel_executor.md b/doc/design/parallel_executor.md
index 076c55d281f3d747d4c9e7dd2795af50a93d9ced..9aed3b059a1595ba3971d7d5acfc0d16a731584b 100644
--- a/doc/design/parallel_executor.md
+++ b/doc/design/parallel_executor.md
@@ -8,7 +8,7 @@ The executor is a very naive interpreter. It runs operators one by one. We can u
 We want a `ProgramDesc` can be run on different nodes. It is better not to contain device information in `ProgramDesc`. However, we can write a high-performance interpreter, which can hold an alternative intermediate representation of `ProgramDesc`, to take full usage of Multi-GPUs.
-ParallelExecutor is an interpreter of `ProgramDesc` which will [out-of-order execute](Out-of-order execution) `Program` in data parallelism mode and maximise the utility of Multi-GPUs.
+ParallelExecutor is an interpreter of `ProgramDesc` which will [out-of-order execute](https://en.wikipedia.org/wiki/Out-of-order_execution) `Program` in data parallelism mode and maximise the utility of Multi-GPUs.
 ## Overview of MultiGPUs logic
diff --git a/doc/fluid/design/concurrent/go_op.md b/doc/fluid/design/concurrent/go_op.md
new file mode 100644
index 0000000000000000000000000000000000000000..c18b788e80f432ebb2f14b15229e7823c112001e
--- /dev/null
+++ b/doc/fluid/design/concurrent/go_op.md
@@ -0,0 +1,231 @@
+# go_op Design
+
+## Introduction
+
+The **go_op** allows users of PaddlePaddle to run program blocks on a detached
+thread. It works in conjunction with CSP operators (channel_send,
+channel_receive, channel_open, channel_close, and select) to allow users to
+concurrently process data and communicate easily between different threads.
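+
+For readers less familiar with CSP, the pattern above is roughly analogous to the
+following sketch that uses only the C++ standard library: a separate worker thread
+sends a value through a blocking queue standing in for a channel. This is purely
+illustrative and is not Paddle's channel API; the `BlockingQueue`, `Send`, and
+`Receive` names here are hypothetical.
+
+```cpp
+#include <condition_variable>
+#include <iostream>
+#include <mutex>
+#include <queue>
+#include <thread>
+
+// A minimal blocking queue standing in for a CSP channel.
+template <typename T>
+class BlockingQueue {
+ public:
+  void Send(T value) {
+    {
+      std::lock_guard<std::mutex> guard(mu_);
+      q_.push(std::move(value));
+    }
+    cv_.notify_one();
+  }
+
+  T Receive() {
+    std::unique_lock<std::mutex> lock(mu_);
+    cv_.wait(lock, [this] { return !q_.empty(); });
+    T value = std::move(q_.front());
+    q_.pop();
+    return value;
+  }
+
+ private:
+  std::mutex mu_;
+  std::condition_variable cv_;
+  std::queue<T> q_;
+};
+
+int main() {
+  BlockingQueue<int> channel;
+  // The "go block": send a value from a separate thread.
+  std::thread sender([&channel] { channel.Send(99); });
+  // The main thread receives the value.
+  std::cout << channel.Receive() << std::endl;
+  sender.join();
+  return 0;
+}
+```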
+
+## How to use it
+
+```
+channel = fluid.make_channel(dtype=core.VarDesc.VarType.LOD_TENSOR)
+
+with fluid.Go():
+    # Send a tensor of value 99 to "channel" on a detached thread
+    tensor = fill_constant(shape=[1], dtype='int', value=99)
+    tensor.stop_gradient = True
+    fluid.channel_send(channel, tensor)
+
+# Receive sent tensor from "channel" on the main thread
+result = fill_constant(shape=[1], dtype='int', value=-1)
+fluid.channel_recv(channel, result)
+```
+
+The go operator can be accessed by using the fluid.Go() control flow. This
+will create a new sub block, where the user can add additional operators
+to be run on the thread.
+
+**Note:** Since back propagation is currently not supported in the go_op, users
+should ensure that operators in the go block do not require gradient
+calculations.
+
+## How it Works
+
+Similar to other control blocks, go_op will create a sub block and add it
+as a child to the current block. Operators and variables defined in this
+block will be added to the go sub_block.
+
+In addition, the go operator will create a new child scope whose parent is
+the global scope. Please refer to [block captures](#block-captures) for more
+information.
+
+When the Paddle executor runs go_op, go_op will take the sub_block and pass it to
+the executor.run method (along with a newly created local scope) on a detached
+thread.
+
+An example of the generated program description is shown below. Take note of
+the **go_op** in particular. It is added as an operator in the current
+block (in this example, block0). The **go_op** contains a `sub_block`
+attribute, which points to the id of the block that will be executed in a
+detached thread.
+
+```
+blocks {
+  idx: 0
+  parent_idx: -1
+  vars {
+    name: "return_value"
+    type {
+      type: LOD_TENSOR
+      lod_tensor {
+        tensor {
+          data_type: INT64
+        }
+      }
+    }
+  }
+  vars {
+    name: "status_recv"
+    type {
+      type: LOD_TENSOR
+      lod_tensor {
+        tensor {
+          data_type: BOOL
+        }
+      }
+    }
+  }
+  ...
+  ops {
+    outputs {
+      parameter: "Out"
+      arguments: "channel"
+    }
+    type: "channel_create"
+    attrs {
+      name: "data_type"
+      type: INT
+      i: 7
+    }
+    attrs {
+      name: "capacity"
+      type: INT
+      i: 0
+    }
+  }
+  ops {
+    inputs {
+      parameter: "X"
+      arguments: "channel"
+    }
+    type: "go"
+    attrs {
+      name: "sub_block"
+      type: BLOCK
+      block_idx: 1
+    }
+  }
+  ops {
+    inputs {
+      parameter: "Channel"
+      arguments: "channel"
+    }
+    outputs {
+      parameter: "Out"
+      arguments: "return_value"
+    }
+    outputs {
+      parameter: "Status"
+      arguments: "status_recv"
+    }
+    type: "channel_recv"
+  }
+  ...
+}
+
+blocks {
+  idx: 1
+  parent_idx: 0
+  vars {
+    name: "status"
+    type {
+      type: LOD_TENSOR
+      lod_tensor {
+        tensor {
+          data_type: BOOL
+        }
+      }
+    }
+  }
+  ...
+
+  ops {
+    outputs {
+      parameter: "Out"
+      arguments: "fill_constant_1.tmp_0"
+    }
+    type: "fill_constant"
+    attrs {
+      name: "force_cpu"
+      type: BOOLEAN
+      b: false
+    }
+    attrs {
+      name: "value"
+      type: FLOAT
+      f: 99.0
+    }
+    attrs {
+      name: "shape"
+      type: INTS
+      ints: 1
+    }
+    attrs {
+      name: "dtype"
+      type: INT
+      i: 3
+    }
+  }
+  ops {
+    inputs {
+      parameter: "Channel"
+      arguments: "channel"
+    }
+    inputs {
+      parameter: "X"
+      arguments: "fill_constant_1.tmp_0"
+    }
+    outputs {
+      parameter: "Status"
+      arguments: "status"
+    }
+    type: "channel_send"
+    attrs {
+      name: "copy"
+      type: BOOLEAN
+      b: false
+    }
+  }
+```
+
+## Current Limitations
+
+#### Scopes and block captures:
+
+Paddle utilizes [scopes](./../concepts/scope.md) to store variables used in a
+block.
+When a block is executed, a new local scope is created from the parent
+scope (ie: scope derived from the parent block) and associated with the new
+child block. After the block finishes executing, the local scope and
+all associated variables in the scope are deleted.
+
+This works well in a single-threaded scenario; however, with the introduction of
+go_op, a child block may continue to execute even after the parent block has
+exited. If the go_op tries to access variables located in the parent block's
+scope, it may receive a segmentation fault because the parent scope may have
+been deleted.
+
+We need to implement block closures in order to prevent access to parent
+scope variables from causing a segmentation fault. As a temporary workaround,
+please ensure that all variables accessed in the go block are not destructed
+before they are accessed. Currently, the go_op will explicitly enforce
+this requirement and raise an exception if a variable could not be found in
+the scope.
+
+Please refer to [Closure issue](https://github.com/PaddlePaddle/Paddle/issues/8502)
+for more details.
+
+#### Green Threads
+
+Golang utilizes `green threads`, which is a mechanism for the runtime library to
+manage multiple threads (instead of natively by the OS). Green threads usually
+allow for faster thread creation and switching, as there is less overhead
+when spawning these threads. For the first version of CSP, we only support
+OS threads.
+
+
+#### Backward Propagation:
+
+go_op currently does not support backward propagation. Please use go_op with
+non-training operators.
diff --git a/doc/v2/dev/index_en.rst b/doc/v2/dev/index_en.rst
index 549f5fa9aace7eb699d229e5f61fe10ae4ed4d66..36516b7953224e799e1065fd7930509eec0aa650 100644
--- a/doc/v2/dev/index_en.rst
+++ b/doc/v2/dev/index_en.rst
@@ -1,9 +1,27 @@
 Development
 ------------
+
+PaddlePaddle adheres to the following three sections of code and documentation specifications.
+
+
+PaddlePaddle uses git for version control and Docker for the build and test environment. The code includes CUDA, C++, Python, Shell and other programming languages, which comply with the Google C++ Style and PEP-8 guidelines, and the code base includes style checking by an automatic inspection tool. Code comments need to follow the Doxygen specification. Code that does not meet the style requirements will fail to compile. We provide the following guidelines for the use of Git, build testing and code development.
 .. toctree::
   :maxdepth: 1
   contribute_to_paddle_en.md
+
+
+PaddlePaddle is well documented in English and Chinese. We recommend using the English version of the documents and problem descriptions. The design documents focus on the problem description and background, followed by the solution. As the documents are generated by Sphinx, code comments should comply with the Sphinx documentation standard. We recommend using the paddlepaddle.org tool to compile, generate and preview the documents locally. Please refer to:
+
+.. toctree::
+  :maxdepth: 1
+  write_docs_en.rst
+
+PaddlePaddle V2 defines new operations by adding new Layers. You can implement various complex layers by combining basic APIs to satisfy most applications. If you want to customize a layer, please refer to the following guide; patches are welcome.
+
+.. toctree::
+  :maxdepth: 1
+  new_layer_en.rst
diff --git a/doc/v2/faq/cluster/index_en.rst b/doc/v2/faq/cluster/index_en.rst
index 855b7e8e53307b82a72c156be4ef509e27edf822..fa942a09625bef78b28456beeb735272b686e061 100644
--- a/doc/v2/faq/cluster/index_en.rst
+++ b/doc/v2/faq/cluster/index_en.rst
@@ -2,4 +2,15 @@
 Cluster Training and Prediction
 ###############################
-TBD
+.. contents::
+
+1. Network connection errors in the log during multi-node cluster training
+---------------------------------------------------------------------------
+You may see errors in the log related to network connection problems during multi-node cluster training, for example, :code:`Connection reset by peer`.
+This kind of error is usually caused by the abnormal exit of a training process on some node, after which the other nodes can no longer connect to it. Steps to troubleshoot the problem are as follows:
+
+* Find the first error in :code:`train.log` and :code:`server.log`, and check whether some other fault caused the problem, such as a floating point exception (FPE), or running out of memory or disk space.
+
+* If the first error in server.log says "Address already used", this may be caused by a port conflict from non-exclusive execution. Contact the system administrator to check whether the current MPI cluster supports jobs submitted with the parameter :code:`resource=full`. If the current MPI cluster does not support this parameter, change the server port and try again.
+
+* If the current MPI cluster does not support an exclusive mode in which a process can occupy the whole node, ask the administrator to replace or upgrade the cluster.
diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h
index adfaba26ace78f547161ad4029a741f3ca8a6764..019bea600f496a6b58579ad0aa8af836cd6134a9 100644
--- a/paddle/fluid/framework/channel.h
+++ b/paddle/fluid/framework/channel.h
@@ -34,7 +34,7 @@ class Channel {
 public:
  virtual bool CanSend() = 0;
  virtual bool CanReceive() = 0;
- virtual bool Send(T*) = 0;
+ virtual void Send(T*) = 0;
  virtual bool Receive(T*) = 0;
  virtual size_t Cap() = 0;
  virtual void Lock() = 0;
@@ -84,69 +84,81 @@ class ChannelHolder {
  }
  template
- bool Send(T* data) {
-   if (!IsInitialized()) return false;
-   PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+ void Send(T* data) {
+   PADDLE_ENFORCE_EQ(IsInitialized(), true,
+                     "The Channel hasn't been initialized");
+   PADDLE_ENFORCE_EQ(
+       holder_->Type(), std::type_index(typeid(T)),
+       "Channel type is not same as the type of the data being sent");
    // Static cast should be safe because we have ensured that types are same
    Channel* channel = static_cast*>(holder_->Ptr());
-   return channel != nullptr ? channel->Send(data) : false;
+   PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
+   channel->Send(data);
  }
  template
  bool Receive(T* data) {
-   if (!IsInitialized()) return false;
-   PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+   PADDLE_ENFORCE_EQ(IsInitialized(), true,
+                     "The Channel hasn't been initialized");
+   PADDLE_ENFORCE_EQ(
+       holder_->Type(), std::type_index(typeid(T)),
+       "Channel type is not same as the type of the data being sent");
    Channel* channel = static_cast*>(holder_->Ptr());
-   return channel != nullptr ?
channel->Receive(data) : false; + PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null."); + return channel->Receive(data); } bool IsClosed() { - if (IsInitialized()) { - return holder_->IsClosed(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->IsClosed(); } bool CanSend() { - if (IsInitialized()) { - return holder_->CanSend(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->CanSend(); } bool CanReceive() { - if (IsInitialized()) { - return holder_->CanReceive(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->CanReceive(); } void close() { - if (IsInitialized()) holder_->Close(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Close(); } size_t Cap() { - if (IsInitialized()) return holder_->Cap(); - return -1; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->Cap(); } void Lock() { - if (IsInitialized()) holder_->Lock(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Lock(); } void Unlock() { - if (IsInitialized()) holder_->Unlock(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Unlock(); } template void AddToSendQ(const void* referrer, T* data, std::shared_ptr cond, std::function cb) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->AddToSendQ(referrer, data, cond, cb); - } + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + Channel* channel = static_cast*>(holder_->Ptr()); + if (channel != nullptr) { + channel->AddToSendQ(referrer, data, cond, cb); } } @@ -154,26 +166,31 @@ class ChannelHolder { void AddToReceiveQ(const void* referrer, T* data, std::shared_ptr cond, std::function cb) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->AddToReceiveQ(referrer, data, cond, cb); - } + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + Channel* channel = static_cast*>(holder_->Ptr()); + if (channel != nullptr) { + channel->AddToReceiveQ(referrer, data, cond, cb); } } void RemoveFromSendQ(const void* referrer) { - if (IsInitialized()) holder_->RemoveFromSendQ(referrer); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->RemoveFromSendQ(referrer); } void RemoveFromReceiveQ(const void* referrer) { - if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->RemoveFromReceiveQ(referrer); } inline bool IsInitialized() const { return holder_ != nullptr; } inline const std::type_index Type() { - PADDLE_ENFORCE_EQ(IsInitialized(), true); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); return holder_->Type(); } diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h index 457abbf373d4549229e8fd8bd6b2087cc6b8f5c8..378a0bab1cc7408266fa45a0b3dc19619dd4fb4c 100644 --- a/paddle/fluid/framework/channel_impl.h +++ b/paddle/fluid/framework/channel_impl.h @@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel { public: virtual bool CanSend(); 
virtual bool CanReceive(); - virtual bool Send(T *); + virtual void Send(T *); virtual bool Receive(T *); virtual size_t Cap() { return cap_; } virtual void Lock(); @@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel { } }; - bool send_return(bool value) { + void send_return() { send_ctr--; destructor_cond_.notify_all(); - return value; } bool recv_return(bool value) { @@ -118,15 +117,15 @@ bool ChannelImpl::CanReceive() { } template -bool ChannelImpl::Send(T *item) { +void ChannelImpl::Send(T *item) { send_ctr++; std::unique_lock lock{mu_}; - // If channel is closed, do nothing + // If channel is closed, throw exception if (closed_) { lock.unlock(); - // TODO(abhinavarora) Should panic on closed channel - return send_return(false); + send_return(); + PADDLE_THROW("Cannot send on closed channel"); } // If there is a receiver, directly pass the value we want @@ -143,7 +142,7 @@ bool ChannelImpl::Send(T *item) { if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND); if (do_send) *(m->data) = std::move(*item); - else + else { // We cannot do the data transfer because // this QueueMessage was added by Select // and some other case was executed. @@ -151,12 +150,17 @@ bool ChannelImpl::Send(T *item) { // We do not care about notifying other // because they would have been notified // by the executed select case. - return send_return(Send(item)); + lock.unlock(); + Send(item); + send_return(); + return; + } // Wake up the blocked process and unlock m->Notify(); lock.unlock(); - return send_return(true); + send_return(); + return; } // Unbuffered channel will always bypass this @@ -167,7 +171,8 @@ bool ChannelImpl::Send(T *item) { buf_.push_back(std::move(*item)); // Release lock and return true lock.unlock(); - return send_return(true); + send_return(); + return; } // Block on channel, because some receiver will complete @@ -175,8 +180,12 @@ bool ChannelImpl::Send(T *item) { auto m = std::make_shared(item); sendq.push_back(m); m->Wait(lock); - // TODO(abhinavarora) Should panic on closed channel - return send_return(!m->chan_closed); + if (m->chan_closed) { + lock.unlock(); + send_return(); + PADDLE_THROW("Cannot send on closed channel"); + } + send_return(); } template diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index 73be5cdbe2a1f5994ecee4c415e83962f50532fe..e2380bb54bd25c4f30f79cad30f95f7cb056eef0 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
*/ #include #include - #include "gtest/gtest.h" using paddle::framework::Channel; @@ -41,7 +40,7 @@ void RecevingOrderEqualToSendingOrder(Channel *ch) { unsigned sum_send = 0; std::thread t([&]() { for (int i = 0; i < 5; i++) { - EXPECT_EQ(ch->Send(&i), true); + ch->Send(&i); sum_send += i; } }); @@ -61,7 +60,7 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); for (size_t i = 0; i < buffer_size; ++i) { - EXPECT_EQ(ch->Send(&i), true); // should not block + ch->Send(&i); } size_t out; @@ -82,7 +81,7 @@ void SendReceiveWithACloseChannelShouldPanic(Channel *ch) { const size_t data = 5; std::thread send_thread{[&]() { size_t i = data; - EXPECT_EQ(ch->Send(&i), true); // should not block + ch->Send(&i); // should not block }}; std::thread recv_thread{[&]() { @@ -94,12 +93,18 @@ void SendReceiveWithACloseChannelShouldPanic(Channel *ch) { send_thread.join(); recv_thread.join(); - // After closing send should return false. Receive should - // also return false as there is no data in queue. + // After closing send should panic. Receive should + // also false as there is no data in queue. CloseChannel(ch); send_thread = std::thread{[&]() { size_t i = data; - EXPECT_EQ(ch->Send(&i), false); // should return false + bool is_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); }}; recv_thread = std::thread{[&]() { size_t i; @@ -129,7 +134,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { auto ch = MakeChannel(buffer_size); for (size_t i = 0; i < buffer_size; ++i) { - EXPECT_EQ(ch->Send(&i), true); // sending should not block + ch->Send(&i); // sending should not block } size_t out; @@ -160,9 +165,16 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { // Try to write more than buffer size. 
for (size_t i = 0; i < 2 * buffer_size; ++i) { if (i < buffer_size) - EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations - else - EXPECT_EQ(ch->Send(&i), false); + ch->Send(&i); // should block after 10 iterations + else { + bool is_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + } } }); std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec @@ -231,7 +243,13 @@ void ChannelCloseUnblocksSendersTest(Channel *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -316,8 +334,11 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { // Try to send more number of times // than receivers for (int i = 0; i < 4; i++) { - ch->Send(&i); - sum_send += i; + try { + ch->Send(&i); + sum_send += i; + } catch (paddle::platform::EnforceNotMet e) { + } } }); for (int i = 0; i < 3; i++) { @@ -382,7 +403,13 @@ void ChannelDestroyUnblockSenders(Channel *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -508,7 +535,7 @@ void ChannelHolderSendReceive(ChannelHolder *ch) { unsigned sum_send = 0; std::thread t([&]() { for (int i = 0; i < 5; i++) { - EXPECT_EQ(ch->Send(&i), true); + ch->Send(&i); sum_send += i; } }); @@ -541,8 +568,22 @@ TEST(ChannelHolder, ChannelUninitializedTest) { ChannelHolder *ch = new ChannelHolder(); EXPECT_EQ(ch->IsInitialized(), false); int i = 10; - EXPECT_EQ(ch->Send(&i), false); - EXPECT_EQ(ch->Receive(&i), false); + bool send_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + send_exception = true; + } + EXPECT_EQ(send_exception, true); + + bool recv_exception = false; + try { + ch->Receive(&i); + } catch (paddle::platform::EnforceNotMet e) { + recv_exception = true; + } + EXPECT_EQ(recv_exception, true); + bool is_exception = false; try { ch->Type(); @@ -669,7 +710,13 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -760,7 +807,13 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 6f878541e6de1deec1829145b1b325ecd176a034..f7a6b5ba84ca1762bd903790aa3c0346b22ed035 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -45,10 +45,11 @@ class Tensor { 
friend struct EigenVector; public: - Tensor() : offset_(0) {} + Tensor() : offset_(0), is_pinned_(false) {} /*! Constructor with place should only be used in pybind. */ - explicit Tensor(const platform::Place& place) : offset_(0) { + explicit Tensor(const platform::Place& place) + : offset_(0), is_pinned_(false) { holder_->set_place(place); } @@ -69,11 +70,12 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(platform::Place place); + inline T* mutable_data(platform::Place place, bool is_pinned = false); - inline void* mutable_data(platform::Place place, std::type_index type); + inline void* mutable_data(platform::Place place, std::type_index type, + bool is_pinned = false); - inline void* mutable_data(platform::Place place); + inline void* mutable_data(platform::Place place, bool is_pinned = false); /** * @brief Return a pointer to mutable memory block. @@ -84,7 +86,8 @@ class Tensor { * @note If not exist, then allocation. */ template - inline T* mutable_data(DDim dims, platform::Place place); + inline T* mutable_data(DDim dims, platform::Place place, + bool is_pinned = false); /*! Return the dimensions of the memory block. */ inline const DDim& dims() const; @@ -92,6 +95,9 @@ class Tensor { /*! Return the numel of the memory block. */ inline int64_t numel() const; + /*! Return the numel of the memory block. */ + inline bool isPinned() const; + /*! Resize the dimensions of the memory block. */ inline Tensor& Resize(const DDim& dims); @@ -146,12 +152,14 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { - PlaceholderImpl(Place place, size_t size, std::type_index type) - : ptr_(static_cast(memory::Alloc(place, size)), - memory::PODDeleter(place)), + PlaceholderImpl(Place place, size_t size, std::type_index type, + bool is_pinned = false) + : ptr_(static_cast(memory::Alloc(place, size, is_pinned)), + memory::PODDeleter(place, is_pinned)), place_(place), size_(size), - type_(type) { + type_(type), + is_pinned_(is_pinned) { PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.", (is_cpu_place(place_) ? "CPU" : "GPU")); } @@ -174,6 +182,9 @@ class Tensor { /* the current type of memory */ std::type_index type_; + + /*! use pinned memory or not. */ + bool is_pinned_; }; /*! holds the memory block if allocated. */ @@ -208,6 +219,7 @@ class Tensor { * PlaceHolder::ptr_ and where the tensor data really begins. 
*/ size_t offset_; + bool is_pinned_; }; inline void Tensor::switch_place(platform::Place new_place) { diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 7a4839044008338dda43f75b5ee6def500b78270..113814971e115fa88bd0ded34017fa26a9dd5803 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -101,19 +101,21 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place) { +inline T* Tensor::mutable_data(DDim dims, platform::Place place, + bool is_pinned) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); - return mutable_data(place); + return mutable_data(place, is_pinned); } template -inline T* Tensor::mutable_data(platform::Place place) { +inline T* Tensor::mutable_data(platform::Place place, bool is_pinned) { static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(place, typeid(T))); + return reinterpret_cast(mutable_data(place, typeid(T), is_pinned)); } -inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { +inline void* Tensor::mutable_data(platform::Place place, std::type_index type, + bool is_pinned) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -127,26 +129,27 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { holder_->size() < size + offset_) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl( - boost::get(place), size, type)); + boost::get(place), size, type, is_pinned)); } else if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); } #else holder_.reset(new PlaceholderImpl( - boost::get(place), size, type)); + boost::get(place), size, type, is_pinned)); } #endif offset_ = 0; + is_pinned_ = is_pinned; } return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } -inline void* Tensor::mutable_data(platform::Place place) { +inline void* Tensor::mutable_data(platform::Place place, bool is_pinned) { PADDLE_ENFORCE(this->holder_ != nullptr, "Cannot invoke mutable data if current hold nothing"); - return mutable_data(place, holder_->type()); + return mutable_data(place, holder_->type(), is_pinned); } inline Tensor& Tensor::ShareDataWith(const Tensor& src) { @@ -188,6 +191,8 @@ inline const DDim& Tensor::dims() const { return dims_; } inline int64_t Tensor::numel() const { return product(dims_); } +inline bool Tensor::isPinned() const { return is_pinned_; } + inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { Tensor res; res.ShareDataWith(src); diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 9949d80434c43ce846895c8d4c84221008a7fd8a..22f6f506748735d1a0fe75375aeea22bd92b8b7e 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -130,6 +130,50 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { bool GPUAllocator::UseGpu() const { return true; } +// PINNED memory allows direct DMA transfers by the GPU to and from system +// memory. It’s locked to a physical address. +void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) { + if (size <= 0) return nullptr; + void* p; + // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size + // of host pinned allocation. Allocates too much would reduce + // the amount of memory available to the underlying system for paging. 
+ + size_t usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_; + + if (size > usable) return nullptr; + + // PINNED memory is visible to all CUDA contexts. + cudaError_t result = cudaMallocHost(&p, size); + if (result == cudaSuccess) { + index = 1; + fallback_alloc_size_ += size; + return p; + } + + return nullptr; +} + +void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { + cudaError_t err; + PADDLE_ASSERT(index == 1); + + PADDLE_ASSERT(fallback_alloc_size_ >= size); + fallback_alloc_size_ -= size; + err = cudaFreeHost(p); + + // Purposefully allow cudaErrorCudartUnloading, because + // that is returned if you ever call cudaFreeHost after the + // driver has already shutdown. This happens only if the + // process is terminating, in which case we don't care if + // cudaFreeHost succeeds. + if (err != cudaErrorCudartUnloading) { + PADDLE_ENFORCE(err, "cudaFreeHost failed in GPUPinnedAllocator::Free."); + } +} + +bool CUDAPinnedAllocator::UseGpu() const { return true; } + #endif } // namespace detail diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index c103d0864012d23d0390076840ee1a61b12ad048..e8479e73f433f1d741b2933da4843c0ba80276d5 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -54,6 +54,18 @@ class GPUAllocator : public SystemAllocator { size_t fallback_alloc_size_ = 0; int gpu_id_; }; + +class CUDAPinnedAllocator : public SystemAllocator { + public: + virtual void* Alloc(size_t& index, size_t size); + virtual void Free(void* p, size_t size, size_t index); + virtual bool UseGpu() const; + + private: + size_t gpu_alloc_size_ = + 0; // TODO(zcd): how to define the upper limit of CUDAPinnedMemory? 
+ size_t fallback_alloc_size_ = 0; +}; #endif } // namespace detail diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc index 1985f1f4e68db1e62ee7cfd3649312581840d02c..56593653a622bce323306d86156d140c46f58d18 100644 --- a/paddle/fluid/memory/memory.cc +++ b/paddle/fluid/memory/memory.cc @@ -38,7 +38,8 @@ BuddyAllocator* GetCPUBuddyAllocator() { } template <> -void* Alloc(platform::CPUPlace place, size_t size) { +void* Alloc(platform::CPUPlace place, size_t size, + bool is_pinned) { VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place); void* p = GetCPUBuddyAllocator()->Alloc(size); VLOG(10) << " pointer=" << p; @@ -46,7 +47,8 @@ void* Alloc(platform::CPUPlace place, size_t size) { } template <> -void Free(platform::CPUPlace place, void* p) { +void Free(platform::CPUPlace place, void* p, + bool is_pinned) { VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); GetCPUBuddyAllocator()->Free(p); } @@ -82,15 +84,47 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { return as[gpu_id]; } +BuddyAllocator* GetCUDAPinnedBuddyAllocator(int gpu_id) { + static BuddyAllocator** as = NULL; + if (as == NULL) { + int gpu_num = platform::GetCUDADeviceCount(); + as = new BuddyAllocator*[gpu_num]; + for (int gpu = 0; gpu < gpu_num; gpu++) { + as[gpu] = nullptr; + } + } + platform::SetDeviceId(gpu_id); + if (!as[gpu_id]) { + as[gpu_id] = new BuddyAllocator(new detail::CUDAPinnedAllocator, + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize()); + VLOG(10) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 + << "% of GPU memory.\n" + << "You can set GFlags environment variable '" + << "FLAGS_fraction_of_gpu_memory_to_use" + << "' to change the fraction of GPU usage.\n\n"; + } + return as[gpu_id]; +} + template <> size_t Used(platform::CUDAPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } template <> -void* Alloc(platform::CUDAPlace place, size_t size) { - auto* buddy_allocator = GetGPUBuddyAllocator(place.device); - auto* ptr = buddy_allocator->Alloc(size); +void* Alloc(platform::CUDAPlace place, size_t size, + bool is_pinned) { + void* ptr; + if (is_pinned) { + auto* buddy_allocator = GetCUDAPinnedBuddyAllocator(place.device); + ptr = buddy_allocator->Alloc(size); + } else { + auto* buddy_allocator = GetGPUBuddyAllocator(place.device); + ptr = buddy_allocator->Alloc(size); + } + if (ptr == nullptr) { int cur_dev = platform::GetCurrentDeviceId(); platform::SetDeviceId(place.device); @@ -108,8 +142,13 @@ void* Alloc(platform::CUDAPlace place, size_t size) { } template <> -void Free(platform::CUDAPlace place, void* p) { - GetGPUBuddyAllocator(place.device)->Free(p); +void Free(platform::CUDAPlace place, void* p, + bool is_pinned) { + if (is_pinned) { + GetCUDAPinnedBuddyAllocator(place.device)->Free(p); + } else { + GetGPUBuddyAllocator(place.device)->Free(p); + } } #endif diff --git a/paddle/fluid/memory/memory.h b/paddle/fluid/memory/memory.h index 7c5db815d6543f026ab99f7cf895a87db4e5a3d8..062bfc880e78dc5d90c567ffe5c4e521704c9ca6 100644 --- a/paddle/fluid/memory/memory.h +++ b/paddle/fluid/memory/memory.h @@ -33,7 +33,7 @@ namespace memory { * address is valid or not. */ template -void* Alloc(Place place, size_t size); +void* Alloc(Place place, size_t size, bool is_pinned = false); /** * \brief Free memory block in one place. 
@@ -43,7 +43,7 @@ void* Alloc(Place place, size_t size); * */ template -void Free(Place place, void* ptr); +void Free(Place place, void* ptr, bool is_pinned = false); /** * \brief Total size of used memory in one place. @@ -74,11 +74,13 @@ class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); public: - explicit PODDeleter(Place place) : place_(place) {} - void operator()(T* ptr) { Free(place_, static_cast(ptr)); } + explicit PODDeleter(Place place, bool is_pinned = false) + : place_(place), is_pinned_(is_pinned) {} + void operator()(T* ptr) { Free(place_, static_cast(ptr), is_pinned_); } private: Place place_; + bool is_pinned_; }; /** diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9a11e1be7050adb1803b1fd835ffb811d9cae4cd..8341170d6897d71ddf95d4de95f521f5d31ab7cd 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -264,3 +264,4 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memor cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) +nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) diff --git a/paddle/fluid/operators/channel_send_op.cc b/paddle/fluid/operators/channel_send_op.cc index 47cf7d7efc9996e8a8db11b79c0310f77c2435a4..66d33617ede5bef8a95de14f5b447c0910fe3eb4 100644 --- a/paddle/fluid/operators/channel_send_op.cc +++ b/paddle/fluid/operators/channel_send_op.cc @@ -23,21 +23,10 @@ limitations under the License. */ static constexpr char Channel[] = "Channel"; static constexpr char X[] = "X"; -static constexpr char Status[] = "Status"; -static constexpr char copy[] = "copy"; namespace paddle { namespace operators { -void SetSendStatus(const platform::Place &dev_place, - framework::Variable &status_var, bool status) { - auto cpu = platform::CPUPlace(); - auto status_tensor = - status_var.GetMutable()->mutable_data({1}, - cpu); - status_tensor[0] = status; -} - class ChannelSendOp : public framework::OperatorBase { public: ChannelSendOp(const std::string &type, @@ -51,9 +40,6 @@ class ChannelSendOp : public framework::OperatorBase { "Input(Channel) of ChannelSendOp should not be null."); PADDLE_ENFORCE(ctx->HasInput(X), "Input(X) of ChannelSendOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput(Status), - "Output(Status) of ChannelSendOp should not be null."); - ctx->SetOutputDim("Status", {1}); } private: @@ -65,10 +51,7 @@ class ChannelSendOp : public framework::OperatorBase { auto input_var = scope.FindVar(Input(X)); // Send the input data through the channel. - bool ok = concurrency::ChannelSend(ch, input_var); - - // Set the status output of the `ChannelSend` call. 
- SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok); + concurrency::ChannelSend(ch, input_var); } }; @@ -82,12 +65,6 @@ class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker { .AsDuplicable(); AddInput(X, "(Variable) The value which gets sent by the channel.") .AsDuplicable(); - AddOutput(Status, - "(Tensor) An LoD Tensor that returns a boolean status of the" - "result of the send operation.") - .AsDuplicable(); - AddAttr(copy, "(bool, default false) Should copy before send") - .SetDefault(false); AddComment(R"DOC( )DOC"); } diff --git a/paddle/fluid/operators/concurrency/channel_util.cc b/paddle/fluid/operators/concurrency/channel_util.cc index a483af7affd824da7d18676d934dc959167ef71f..246c99489c45efec16babb1d3980606318236605 100644 --- a/paddle/fluid/operators/concurrency/channel_util.cc +++ b/paddle/fluid/operators/concurrency/channel_util.cc @@ -17,20 +17,20 @@ limitations under the License. */ namespace poc = paddle::operators::concurrency; -bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { +void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { auto type = framework::ToVarType(var->Type()); if (type == framework::proto::VarType_Type_LOD_TENSOR) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_SELECTED_ROWS) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_READER) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_CHANNEL) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else PADDLE_THROW("ChannelSend:Unsupported type"); } diff --git a/paddle/fluid/operators/concurrency/channel_util.h b/paddle/fluid/operators/concurrency/channel_util.h index c3674bd9815df451751707bfa84d18dbb5fa0f6b..cd18ca78c6fdecdc6c72748611ccdd9c2690ef46 100644 --- a/paddle/fluid/operators/concurrency/channel_util.h +++ b/paddle/fluid/operators/concurrency/channel_util.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { namespace concurrency { -bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); +void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var); void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer, diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 12e8eb0b4da2252b104415aef4156bf100c3e565..bdda5703436765480f353ee964624364f45dbefb 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, void* dest, int size) { const void* data = NULL; int size_to_write = 0; + int length = size; + int total_written = 0; if (platform::is_gpu_place(place)) { #ifdef PADDLE_WITH_CUDA @@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, platform::CPUPlace cpu; char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, 
&size_to_write)) { return false; } - + // NOTE: if raw buffer is large and have two neighbor fields of raw + // buffers GetDirectBufferPointer can get all of them, use length to + // truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } memory::Copy(boost::get(place), reinterpret_cast(p), cpu, data, size_to_write, gpu_dev_ctx.stream()); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, } char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) { return false; } + // NOTE: if raw buffer is large and have two neighbor fields of raw buffers + // GetDirectBufferPointer can get all of them, use length to truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } // TODO(gongwb): can we avoid copy? platform::CPUPlace cpu; memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -153,6 +165,7 @@ bool VariableResponse::CopySelectRowsData( const platform::DeviceContext& ctx, int length) { auto var = scope_->FindVar(meta_.varname()); auto* slr = var->GetMutable(); + slr->mutable_rows()->resize(length / 8); // int64 int64_t* rows_data = slr->mutable_rows()->data(); // copy rows CPU data, GPU data will be copied lazily. @@ -233,7 +246,6 @@ int VariableResponse::Parse(Source* source) { if (tag != 0) { return -1; } - return 0; } diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 94382739b5077b1449a8fd5be7952f35737ca340..184c095e487a302ebc4d251dd6f332333c415c6d 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -55,9 +55,6 @@ class GPUDropoutKernel : public framework::OpKernel { y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); - auto X = EigenMatrix::Reshape(*x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); - auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); @@ -76,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel { T><<>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); Y.device(place) = X * static_cast(1.0f - dropout_prob); } } diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..db97ba4f64105c37c49cafbc3fbc4829c5077467 --- /dev/null +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(dropout); + +void Compare(f::Scope& scope, p::DeviceContext& ctx) { + // init + auto var = scope.Var("X"); + auto tensor = var->GetMutable(); + tensor->Resize({10, 10}); + + std::vector init; + for (int64_t i = 0; i < 10 * 10; ++i) { + init.push_back(1.0); + } + + TensorFromVector(init, ctx, tensor); + + auto place = ctx.GetPlace(); + auto out_var = scope.Var("Out"); + auto out_tensor = out_var->GetMutable(); + out_tensor->Resize({10, 10}); + out_tensor->mutable_data(place); // allocate + + auto mask_var = scope.Var("Mask"); + auto mask_tensor = mask_var->GetMutable(); + mask_tensor->Resize({10, 10}); + mask_tensor->mutable_data(place); // allocate + + // run + f::AttributeMap attrs; + float dropout_prob = 0.5; + attrs.insert({"fix_seed", 1}); + attrs.insert({"seed", 3}); + attrs.insert({"dropout_prob", dropout_prob}); + auto dropout_op = f::OpRegistry::CreateOp( + "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); + + dropout_op->Run(scope, place); + + std::vector out_vec; + TensorToVector(*out_tensor, ctx, &out_vec); + + std::vector std_out = { + 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, + 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, + 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, + 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1}; + + EXPECT_EQ(out_vec.size(), std_out.size()); + for (uint32_t i = 0; i < out_vec.size(); i++) { + EXPECT_EQ(out_vec[i], std_out[i]); + } +} + +TEST(Dropout, CPUDense) { + f::Scope scope; + p::CPUPlace place; + p::CPUDeviceContext ctx(place); + Compare(scope, ctx); +} + +TEST(Dropout, GPUDense) { + f::Scope scope; + p::CUDAPlace place; + p::CUDADeviceContext ctx(place); + Compare(scope, ctx); +} diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 605b5c258ca57b1a63c9b741a1a30dcb9fca2248..7b84ba0a7daf10e9e636f62eea6bd759ebec9541 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -22,6 +22,103 @@ limitations under the License. */ namespace paddle { namespace operators { +// Wrap RowwiseMean and ColwiseMean. +// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is +// significantly faster. Unlike the RowwiseMean and ColwiseMean, the +// implementation only considers 2D. 
+template +struct RowwiseMean2D { + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx); + + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + +#ifdef PADDLE_WITH_CUDA +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) + : left_(left), right_(right) { + framework::DDim ones_dim({right_}); + divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); + math::set_constant(dev_ctx, &divisor_, 1.0 / right); + } + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + math::gemv( + context, false, left_, right_, 1., input.data(), divisor_.data(), + 0., out->data()); + } + + private: + int left_; + int right_; + framework::Tensor divisor_; +}; +#endif + +template +class RowwiseMean2D { + public: + RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {} + + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + row_mean_(context, input, out); + } + + private: + math::RowwiseMean row_mean_; +}; + +template +struct ColwiseSum2D { + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx); + + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + +#ifdef PADDLE_WITH_CUDA +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) + : left_(left), right_(right) { + framework::DDim ones_dim({left_}); + divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); + math::set_constant(dev_ctx, &divisor_, 1.0); + } + + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + math::gemv( + context, true, left_, right_, 1., input.data(), divisor_.data(), + 0., out->data()); + } + + private: + int left_; + int right_; + framework::Tensor divisor_; +}; +#endif + +template +class ColwiseSum2D { + public: + ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {} + + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + col_wise_(context, input, out); + } + + private: + math::ColwiseSum col_wise_; +}; + template struct SubAndSquareFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } @@ -67,15 +164,15 @@ using DataLayout = framework::DataLayout; template class LayerNormKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); auto x = *ctx.Input("X"); - auto *y = ctx.Output("Y"); - auto *mean = ctx.Output("Mean"); - auto *var = ctx.Output("Variance"); + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* var = ctx.Output("Variance"); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); const auto x_dims = x.dims(); @@ -94,8 +191,8 @@ class LayerNormKernel : public framework::OpKernel { out.ShareDataWith(*y); out.Resize(matrix_shape); - auto &dev_ctx = ctx.template device_context(); - math::RowwiseMean row_mean; + auto& dev_ctx = ctx.template device_context(); + 
RowwiseMean2D row_mean(left, right, ctx.device_context()); // get mean row_mean(dev_ctx, x, mean); @@ -126,31 +223,32 @@ class LayerNormKernel : public framework::OpKernel { template class LayerNormGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); auto x = *ctx.Input("X"); - auto *y = ctx.Input("Y"); - auto *mean = ctx.Input("Mean"); - auto *var = ctx.Input("Variance"); - auto *scale = ctx.Input("Scale"); - auto *bias = ctx.Input("Bias"); + auto* y = ctx.Input("Y"); + auto* mean = ctx.Input("Mean"); + auto* var = ctx.Input("Variance"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); auto d_y = *ctx.Input(framework::GradVarName("Y")); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_scale = ctx.Output(framework::GradVarName("Scale")); + auto* d_bias = ctx.Output(framework::GradVarName("Bias")); - const auto &x_dims = x.dims(); + const auto& x_dims = x.dims(); auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); int left = static_cast(matrix_dim[0]); int right = static_cast(matrix_dim[1]); framework::DDim matrix_shape({left, right}); d_y.Resize(matrix_shape); - auto &dev_ctx = ctx.template device_context(); - math::ColwiseSum colwise_sum; + auto& dev_ctx = ctx.template device_context(); + ColwiseSum2D colwise_sum(left, right, + ctx.device_context()); Tensor temp; Tensor temp_norm; @@ -190,7 +288,8 @@ class LayerNormGradKernel : public framework::OpKernel { Tensor temp_vec; temp_vec.mutable_data(vec_shape, ctx.GetPlace()); - math::RowwiseMean row_mean; + RowwiseMean2D row_mean(left, right, + ctx.device_context()); if (d_scale) { // dy_dx diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc index 3bead16ce44c26b9d7a6f2a5c6b471612494d595..0a18882e8199c2a375a230a693b8b01d12aabfa0 100644 --- a/paddle/fluid/operators/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/lrn_mkldnn_op.cc @@ -36,6 +36,14 @@ std::shared_ptr insert_to_context(const std::string& key, return p; } + +template +void run_primitive(Args&&... 
args) { + auto forward_op = mkldnn::lrn_forward{args...}; + + std::vector pipeline = {forward_op}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); +} } // namespace template @@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine}, static_cast(output_data)}; - std::unique_ptr forward_op = nullptr; - if (!is_test) { const std::string key = ctx.op().Output("Out"); const std::string key_src_memory = key + "@lrn_src_memory"; @@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { key_workspace_memory, dev_ctx, forward_pd->workspace_primitive_desc()); - forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory, - *workspace_memory, dst_memory}); - + run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); } else { auto forward_pd = mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; @@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto workspace_memory = mkldnn::memory{forward_pd.workspace_primitive_desc()}; - forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory, - workspace_memory, dst_memory}); + run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); } - - std::vector pipeline = {*forward_op}; - mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); } }; @@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { "MKLDNN LRN must use float data."); PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "MKLDNN LRN must use CPUPlace."); + PADDLE_ENFORCE( + !ctx.Attr("is_test"), + "is_test attribute should be set to False in training phase."); auto x = ctx.Input("X"); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 2b1947a187bbd17871107553127647032ac7d7f9..b36b5c3a339bd7e534bcc3eb7a2efef313cb2a5d 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); ctx->SetOutputDim("Out", x_dim); - ctx->SetOutputDim("MidOut", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("MidOut", x_dim); } framework::OpKernelType GetExpectedKernelType( diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index 47e2386d0578265330088eeac6c57fe2518f951a..cdbc7bfb37e83c6c2b696ba010277c9eec49f2a8 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -19,13 +19,6 @@ limitations under the License. 
*/ #include #endif -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - #ifdef PADDLE_USE_OPENBLAS #include #include diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 4001b9a130348b4e3ea99f3017eae6d85e41fc6e..b28c16b13fce30c6e9be9953009b53e722cf4885 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -144,7 +144,12 @@ class ParallelDoOp : public framework::OperatorBase { PADDLE_ENFORCE(scope.FindVar(param)->IsType(), "Only support parameter type as LoDTensor"); auto &src = scope.FindVar(param)->Get(); - for (size_t i = 0; i < sub_scopes.size(); ++i) { + + auto *sub_scope0 = sub_scopes[0]; + auto *dst0 = sub_scope0->Var(param)->GetMutable(); + dst0->ShareDataWith(src); + + for (size_t i = 1; i < sub_scopes.size(); ++i) { auto &place = places[i]; auto *sub_scope = sub_scopes[i]; auto *dst = sub_scope->Var(param)->GetMutable(); diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 76cdb794ccdb4a015ae8630940a5c26845e7a7b3..141a3eb93555c32efabc2465dc6daadf41c9d659 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -166,7 +166,9 @@ void DoubleBufferReader::PrefetchThreadFunc() { std::swap(gpu_batch, batch.payloads_); } - if (!buffer_->Send(&batch)) { + try { + buffer_->Send(&batch); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread will terminate."; break; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 414c76fea0bb916dfeafe38c0448a7a800889e03..b6ac7b21d56f7760b3f4814581c90b0ff2cc4a6a 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -146,14 +146,19 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name, while (reader->HasNext()) { std::vector ins; reader->ReadNext(&ins); - if (!buffer_->Send(&ins)) { + try { + buffer_->Send(&ins); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " "thread of file '" << file_name << "' will terminate."; break; } } - if (!available_thread_idx_->Send(&thread_idx)) { + + try { + available_thread_idx_->Send(&thread_idx); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " "Fail to send thread_idx."; } diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/split_ids_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a54f8a2878c8606e6b487552324d1e7dfa94b9b8 --- /dev/null +++ b/paddle/fluid/operators/split_ids_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/split_ids_op.h" + +namespace paddle { +namespace operators { + +class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); + AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") + .AsDuplicable(); + + AddComment(R"DOC( +Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number +Example: + Input: + X = [1,2,3,4,5,6] + + Out(3 output): + out0 = [3, 6] + out1 = [1, 4] + out2 = [2, 5] +)DOC"); + } +}; + +class SplitIdsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids."); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); + + auto ids_var_type = ctx->GetInputsVarType("Ids").front(); + PADDLE_ENFORCE_EQ(ids_var_type, framework::proto::VarType::LOD_TENSOR); + + auto ids_dims = ctx->GetInputDim("Ids"); + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); + } +}; + +class SplitIdsOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + for (auto &out_var : op_desc.Output("Out")) { + block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, + ops::SplitIdsOpInferVarType); +REGISTER_OP_CPU_KERNEL( + split_ids, ops::SplitIdsOpKernel); diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3e750ed2d171876ce2d3c232f5d34234217b3c3e --- /dev/null +++ b/paddle/fluid/operators/split_ids_op.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace paddle { +namespace operators { + +template +class SplitIdsOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + if (!platform::is_cpu_place(place)) { + PADDLE_THROW("SplitIds do not support GPU kernel"); + } + + const auto* ids_t = ctx.Input("Ids"); + auto& ids_dims = ids_t->dims(); + auto outs = ctx.MultiOutput("Out"); + + const T* ids = ids_t->data(); + + const size_t shard_num = outs.size(); + + std::vector> out_ids; + out_ids.resize(outs.size()); + + // split id by their shard_num. 
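+    // A given id always lands in shard (id % shard_num); e.g. with the
+    // 3-shard example from split_ids_op.cc, ids 3 and 6 go to out0,
+    // ids 1 and 4 to out1, and ids 2 and 5 to out2.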
+ for (size_t i = 0; i < ids_dims[0]; ++i) { + T id = ids[i]; + size_t shard_id = static_cast(id) % shard_num; + out_ids[shard_id].push_back(id); + } + + // create tensor for each shard and send to parameter server + for (size_t i = 0; i < out_ids.size(); ++i) { + auto* shard_t = outs[i]; + std::vector ids = out_ids[i]; + auto* shard_data = shard_t->mutable_data( + framework::make_ddim({static_cast(ids.size()), 1}), place); + for (size_t i = 0; i < ids.size(); ++i) { + shard_data[i] = ids[i]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index b2ff4bc3232a8e5d5d7b49bf49c62fe756d303f4..de404cad89fba8021b8645a40e25c1f5b7e86596 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -59,17 +59,10 @@ void* lapack_dso_handle = nullptr; } __name; // struct DynLoad__##__name #endif -#ifdef PADDLE_USE_ATLAS - #define PADDLE_SGETRF clapack_sgetrf - #define PADDLE_DGETRF clapack_dgetrf - #define PADDLE_SGETRI clapack_sgetri - #define PADDLE_DGETRI clapack_dgetri -#else - #define PADDLE_SGETRF LAPACKE_sgetrf - #define PADDLE_DGETRF LAPACKE_dgetrf - #define PADDLE_SGETRI LAPACKE_sgetri - #define PADDLE_DGETRI LAPACKE_dgetri -#endif +#define PADDLE_SGETRF LAPACKE_sgetrf +#define PADDLE_DGETRF LAPACKE_dgetrf +#define PADDLE_SGETRI LAPACKE_sgetri +#define PADDLE_DGETRI LAPACKE_dgetri #define LAPACK_ROUTINE_EACH(__macro) \ __macro(PADDLE_SGETRF) \ diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h index f4cf6bd6c2c06f95cda098af389b37b7ff2983eb..f3d8b1a39e849d5f5a9e79cf33252b60170ced81 100644 --- a/paddle/math/MathFunctions.h +++ b/paddle/math/MathFunctions.h @@ -21,7 +21,7 @@ limitations under the License. */ #include #endif -#if defined(PADDLE_USE_ATLAS) || defined(PADDLE_USE_VECLIB) +#if defined(PADDLE_USE_VECLIB) extern "C" { #include #include diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index 80fa0c72af65cbdc21ba955389318a233e02657c..1283de9d957a46b848c7bb6caf9c5f49398468e2 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -153,9 +153,15 @@ if [ $? -ne 0 ]; then exit 1 fi -INSTALLED_VERSION=`pip freeze 2>/dev/null | grep '^paddle' | sed 's/.*==//g'` +if [ "@WITH_GPU@" == "ON" ]; then + PADDLE_NAME="paddlepaddle-gpu" +else + PADDLE_NAME="paddlepaddle" +fi + +INSTALLED_VERSION=`pip freeze 2>/dev/null | grep "^${PADDLE_NAME}==" | sed 's/.*==//g'` -if [ -z ${INSTALLED_VERSION} ]; then +if [ -z "${INSTALLED_VERSION}" ]; then INSTALLED_VERSION="0.0.0" # not installed fi cat <`_ + + Args: + input (Variable): The input tensor of this layer, and the dimension of input tensor must be 4. + n (int, default 5): The number of channels to sum over. + k (float, default 1.0): An offset (usually positive to avoid dividing by 0). + alpha (float, default 1e-4): The scaling parameter. + beta (float, default 0.75): The exponent. + name (str, default None): A name for this operation. + + Raises: + ValueError: If rank of the input tensor is not 4. + + Returns: + A tensor variable storing the transformation result. + + Examples: + .. 
code-block:: python + + data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32") + lrn = fluid.layers.lrn(input=data) + """ + helper = LayerHelper('lrn', **locals()) + dtype = helper.input_dtype() + input_shape = input.shape + dims = len(input_shape) + + if dims != 4: + raise ValueError( + "dims of input must be 4(not %d), and it's order must be NCHW" % + (dims)) + + mid_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) + lrn_out = helper.create_tmp_variable(dtype) + helper.append_op( + type="lrn", + inputs={"X": input}, + outputs={ + "Out": lrn_out, + "MidOut": mid_out, + }, + attrs={"n": n, + "k": k, + "alpha": alpha, + "beta": beta}) + + return lrn_out diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 3b2e1a3073251a6d6460450dc957e1b5c7a873c5..bbedf6fde0872fd32d81c103bf5fe61449b7f57b 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -98,7 +98,7 @@ def img_conv_group(input, use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act) + tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/test_concurrency.py b/python/paddle/fluid/tests/test_concurrency.py index 924895a9afac610059bac5f617c49712441339cc..e8f6cfb4a907b2c01e9662e7e9bf2cb0fbd6cb1b 100644 --- a/python/paddle/fluid/tests/test_concurrency.py +++ b/python/paddle/fluid/tests/test_concurrency.py @@ -173,16 +173,10 @@ class TestRoutineOp(unittest.TestCase): with while_op.block(): result2 = fill_constant( shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) - x_to_send_tmp = fill_constant( - shape=[1], dtype=core.VarDesc.VarType.INT32, value=0) - - # TODO(abhinav): Need to perform copy when doing a channel send. 
- # Once this is complete, we can remove these lines - assign(input=x, output=x_to_send_tmp) with fluid.Select() as select: - with select.case(fluid.channel_send, channel, - x_to_send_tmp): + with select.case( + fluid.channel_send, channel, x, is_copy=True): assign(input=x, output=x_tmp) assign(input=y, output=x) assign(elementwise_add(x=x_tmp, y=y), output=y) @@ -230,21 +224,12 @@ class TestRoutineOp(unittest.TestCase): core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.FP64) - pong_result = self._create_tensor('pong_return_value', - core.VarDesc.VarType.LOD_TENSOR, - core.VarDesc.VarType.FP64) - def ping(ch, message): - message_to_send_tmp = fill_constant( - shape=[1], dtype=core.VarDesc.VarType.FP64, value=0) - - assign(input=message, output=message_to_send_tmp) - fluid.channel_send(ch, message_to_send_tmp) + fluid.channel_send(ch, message, is_copy=True) def pong(ch1, ch2): fluid.channel_recv(ch1, ping_result) - assign(input=ping_result, output=pong_result) - fluid.channel_send(ch2, pong_result) + fluid.channel_send(ch2, ping_result, is_copy=True) pings = fluid.make_channel( dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index b5fd59cf3a1bea50b799c3ace8f3b9cea088b9d5..2179826d81f715d6d280aea28a76f919330dd644 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -231,6 +231,13 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.softmax(hid)) print(str(program)) + def test_lrn(self): + program = Program() + with program_guard(program): + data = layers.data(name='data', shape=[6, 2, 2], dtype='float32') + self.assertIsNotNone(layers.lrn(data)) + print(str(program)) + def test_get_places(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index 2268eafdbd08cd0d6a175d19cedd79b7b984289b..8fa480b9bce84d2936f23cce9e41e8e54014b074 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp): self.check_output(atol=0.002) +class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): + def get_attrs(self): + attrs = TestLRNMKLDNNOp.get_attrs(self) + attrs['is_test'] = True + return attrs + + def test_check_grad_normal(self): + def check_raise_is_test(): + try: + self.check_grad(['X'], 'Out', max_relative_error=0.01) + except Exception as e: + t = \ + "is_test attribute should be set to False in training phase." 
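+                # The MKLDNN LRN grad kernel enforces is_test == False, so the
+                # exception raised by check_grad should contain this message.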
+ if t in str(e): + raise AttributeError + + self.assertRaises(AttributeError, check_raise_is_test) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_recv_op.py b/python/paddle/fluid/tests/unittests/test_recv_op.py index 985d892c568472614c5f3e6691f54807ddccc4bd..854238c6279528d8f3adf173140a47e233134f43 100644 --- a/python/paddle/fluid/tests/unittests/test_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_recv_op.py @@ -38,14 +38,15 @@ class TestRecvOp(unittest.TestCase): def init_serv(self, place): main = fluid.Program() with fluid.program_guard(main): - x = layers.data( - shape=[32, 32], - dtype='float32', - name="X", - append_batch_size=False) - fluid.initializer.Constant(value=1.0)(x, main.global_block()) - serv = layers.ListenAndServ("127.0.0.1:6174", optimizer_mode=False) + serv = layers.ListenAndServ( + "127.0.0.1:6174", ["X"], optimizer_mode=False) with serv.do(): + x = layers.data( + shape=[32, 32], + dtype='float32', + name="X", + append_batch_size=False) + fluid.initializer.Constant(value=1.0)(x, main.global_block()) o = layers.scale(x=x, scale=10.0) main.global_block().create_var( name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape) diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f0a06a56b42952800411d548bb3fc1732e031e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -0,0 +1,35 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestSplitIdsOp(OpTest): + def setUp(self): + self.op_type = "split_ids" + ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + out0 = np.array([[0], [3], [6]]).astype('int64') + out1 = np.array([[]]).astype('int64') + out2 = np.array([[2], [2], [5], [5]]).astype('int64') + self.inputs = {'Ids': ids} + self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/trainer_config_helpers/activations.py b/python/paddle/trainer_config_helpers/activations.py index 00efc01c0592107314f5b23c951706d039d49a88..3683968262266a2d654d2480b828173bc761152b 100644 --- a/python/paddle/trainer_config_helpers/activations.py +++ b/python/paddle/trainer_config_helpers/activations.py @@ -77,7 +77,7 @@ class SoftmaxActivation(BaseActivation): .. math:: - P(y=j|x) = \\frac{e^{x_j}} {\\sum^K_{k=1} e^{x_j} } + P(y=j|x) = \\frac{e^{x_j}} {\\sum^K_{k=1} e^{x_k} } """ def __init__(self):