diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 52a22c1fbf4779fa3c0ca687cab664bd3ca0410a..e3b9d94215a858c5c9a34e1b7e97540f1876801d 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING) /usr/lib/reference/ ) else() - # Diable the finding of reference cblas under host's system path + # Disable the finding of reference cblas under host's system path set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include) set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib) endif() diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index c2ca1bbc78f3ebc6066df6b666720af0d1fbbf59..513e720fd099bcd898a6c73afd1a3a16f6f53aab 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name, if (tensor.memory_size() == 0) { return; } - if (tensor.type().hash_code() != typeid(float).hash_code() && - tensor.type().hash_code() != typeid(double).hash_code()) { + if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT + tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT return; } PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), @@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, // Return true if the block has feed operators and holder of matching info. static bool has_feed_operators( const BlockDesc& block, - std::map& feed_targets, + const std::map& feed_targets, const std::string& feed_holder_name) { size_t feed_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFeedOpType) { feed_count++; + // The input variable's name of feed_op should be feed_holder_name. PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, "Input to feed op should be '%s'", feed_holder_name); std::string feed_target_name = op->Output("Out")[0]; @@ -166,13 +167,15 @@ static bool has_feed_operators( feed_count, feed_targets.size(), "The number of feed operators should match 'feed_targets'"); - // When feed operator are present, so should be feed_holder - auto var = block.FindVar(feed_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - feed_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, - "'%s' variable should be 'FEED_MINIBATCH' type", - feed_holder_name); + if (!feed_holder_name.empty()) { + // When feed operator are present, so should be feed_holder. + auto var = block.FindVar(feed_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + feed_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, + "'%s' variable should be 'FEED_MINIBATCH' type", + feed_holder_name); + } } return feed_count > 0; @@ -185,12 +188,14 @@ static bool has_feed_operators( // and fetch_holder_name. Raise exception when any mismatch is found. // Return true if the block has fetch operators and holder of matching info. static bool has_fetch_operators( - const BlockDesc& block, std::map& fetch_targets, + const BlockDesc& block, + const std::map& fetch_targets, const std::string& fetch_holder_name) { size_t fetch_count = 0; for (auto* op : block.AllOps()) { if (op->Type() == kFetchOpType) { fetch_count++; + // The output variable's name of fetch_op should be fetch_holder_name. PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, "Output of fetch op should be '%s'", fetch_holder_name); std::string fetch_target_name = op->Input("X")[0]; @@ -206,13 +211,15 @@ static bool has_fetch_operators( fetch_count, fetch_targets.size(), "The number of fetch operators should match 'fetch_targets'"); - // When fetch operator are present, so should be fetch_holder - auto var = block.FindVar(fetch_holder_name); - PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", - fetch_holder_name); - PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, - "'%s' variable should be 'FETCH_LIST' type", - fetch_holder_name); + if (!fetch_holder_name.empty()) { + // When fetch operator are present, so should be fetch_holder. + auto var = block.FindVar(fetch_holder_name); + PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", + fetch_holder_name); + PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, + "'%s' variable should be 'FETCH_LIST' type", + fetch_holder_name); + } } return fetch_count > 0; @@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - // map the data of feed_targets to feed_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFeedOpType) { - std::string feed_target_name = op->Output("Out")[0]; - int idx = boost::get(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); - } - } - if (!has_fetch_ops) { // create fetch_holder variable auto* fetch_holder = global_block->Var(fetch_holder_name); @@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - Run(*copy_program, scope, 0, create_vars, create_vars); - - // obtain the data of fetch_targets from fetch_holder - for (auto* op : global_block->AllOps()) { - if (op->Type() == kFetchOpType) { - std::string fetch_target_name = op->Input("X")[0]; - int idx = boost::get(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = - GetFetchVariable(*scope, fetch_holder_name, idx); - } - } + auto ctx = Prepare(*copy_program, 0); + RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars, + feed_holder_name, fetch_holder_name); } std::unique_ptr Executor::Prepare( @@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } +void Executor::RunPreparedContext( + ExecutorPrepareContext* ctx, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, bool create_vars, + const std::string& feed_holder_name, const std::string& fetch_holder_name) { + auto& global_block = ctx->prog_.Block(ctx->block_id_); + + PADDLE_ENFORCE( + has_feed_operators(global_block, feed_targets, feed_holder_name), + "Program in ExecutorPrepareContext should has feed_ops."); + PADDLE_ENFORCE( + has_fetch_operators(global_block, fetch_targets, fetch_holder_name), + "Program in the prepared context should has fetch_ops."); + + // map the data of feed_targets to feed_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFeedOpType) { + std::string feed_target_name = op->Output("Out")[0]; + int idx = boost::get(op->GetAttr("col")); + SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, + idx); + } + } + + RunPreparedContext(ctx, scope, create_vars, create_vars); + + // obtain the data of fetch_targets from fetch_holder + for (auto* op : global_block.AllOps()) { + if (op->Type() == kFetchOpType) { + std::string fetch_target_name = op->Input("X")[0]; + int idx = boost::get(op->GetAttr("col")); + *fetch_targets[fetch_target_name] = + GetFetchVariable(*scope, fetch_holder_name, idx); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 75b29b2f4065ad75b62a134b890b8f9f6730fdc7..43defdacf2a1c2f59cf3af2461ae6cfc4c61f5be 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -70,6 +73,13 @@ class Executor { bool create_local_scope = true, bool create_vars = true); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, + std::map& feed_targets, + std::map& fetch_targets, + bool create_vars = true, + const std::string& feed_holder_name = "feed", + const std::string& fetch_holder_name = "fetch"); + private: const platform::Place place_; }; diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 9854d618d2b29ed123833f55198179638c95d6db..f26f212d4d5793b88fd1e6d782cdf983bf341879 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -14,8 +14,12 @@ #include "paddle/fluid/framework/threadpool.h" +#include "gflags/gflags.h" #include "paddle/fluid/platform/enforce.h" +DEFINE_int32(io_threadpool_size, 100, + "number of threads used for doing IO, default 100"); + namespace paddle { namespace framework { @@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() { } } +std::unique_ptr ThreadPoolIO::io_threadpool_(nullptr); +std::once_flag ThreadPoolIO::io_init_flag_; + +ThreadPool* ThreadPoolIO::GetInstanceIO() { + std::call_once(io_init_flag_, &ThreadPoolIO::InitIO); + return io_threadpool_.get(); +} + +void ThreadPoolIO::InitIO() { + if (io_threadpool_.get() == nullptr) { + // TODO(typhoonzero1986): make this configurable + io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index f9dce7105e32ff0ba03d03f8faaac3a4ed1a3595..94111ee335b1a5df327b3e46d62069b4735c54f6 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once -#include +#include // NOLINT #include -#include -#include +#include // NOLINT +#include // NOLINT #include -#include +#include // NOLINT #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" @@ -28,6 +28,22 @@ limitations under the License. */ namespace paddle { namespace framework { +struct ExceptionHandler { + mutable std::future> future_; + explicit ExceptionHandler( + std::future>&& f) + : future_(std::move(f)) {} + void operator()() const { + auto ex = this->future_.get(); + if (ex != nullptr) { + LOG(FATAL) << "The exception is thrown inside the thread pool. You " + "should use RunAndGetException to handle the exception.\n" + "The default exception handler is LOG(FATAL)." + << ex->what(); + } + } +}; + // ThreadPool maintains a queue of tasks, and runs them using a fixed // number of threads. class ThreadPool { @@ -87,22 +103,6 @@ class ThreadPool { void Wait(); private: - struct ExceptionHandler { - mutable std::future> future_; - explicit ExceptionHandler( - std::future>&& f) - : future_(std::move(f)) {} - void operator()() const { - auto ex = this->future_.get(); - if (ex != nullptr) { - LOG(FATAL) << "The exception is thrown inside the thread pool. You " - "should use RunAndGetException to handle the exception.\n" - "The default exception handler is LOG(FATAL)." - << ex->what(); - } - } - }; - DISABLE_COPY_AND_ASSIGN(ThreadPool); // If the task queue is empty and avaialbe is equal to the number of @@ -135,6 +135,17 @@ class ThreadPool { std::condition_variable completed_; }; +class ThreadPoolIO : ThreadPool { + public: + static ThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + // Run a function asynchronously. // NOTE: The function must return void. If the function need to return a value, // you can use lambda to capture a value pointer. @@ -143,5 +154,10 @@ std::future Async(Callback callback) { return ThreadPool::GetInstance()->Run(callback); } +template +std::future AsyncIO(Callback callback) { + return ThreadPoolIO::GetInstanceIO()->Run(callback); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index a29d457b6fa9d0e8297252c8ff1117013d2055f8..3b58019db6e55fa8198d2f77731095c6cf356266 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -23,7 +23,7 @@ limitations under the License. */ namespace paddle { namespace inference { -// Temporarilly add this function for exposing framework::InitDevices() when +// Temporarily add this function for exposing framework::InitDevices() when // linking the inference shared library. void Init(bool init_p2p) { framework::InitDevices(init_p2p); } diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index ca2077d07411d2cd6095e0dc2a874af0890145c5..1e6555bb02033a28dedd2a1d1962981dfcc97cc2 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -46,8 +46,8 @@ TEST(inference, image_classification) { // Run inference on CPU LOG(INFO) << "--- CPU Runs: ---"; - TestInference(dirname, cpu_feeds, - cpu_fetchs1, FLAGS_repeat); + TestInference( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -57,8 +57,8 @@ TEST(inference, image_classification) { // Run inference on CUDA GPU LOG(INFO) << "--- GPU Runs: ---"; - TestInference(dirname, cpu_feeds, - cpu_fetchs2, FLAGS_repeat); + TestInference( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); CheckError(output1, output2); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 064e400f0c750872ab2142c5fc8e28dd3da85b1a..c3a8d0889c6a6dd9591837ccc523da56f8d13661 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template +template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, const std::vector& cpu_fetchs, @@ -175,8 +175,15 @@ void TestInference(const std::string& dirname, } // Ignore the profiling results of the first run - executor.Run(*inference_program, scope, feed_targets, fetch_targets, - CreateVars); + std::unique_ptr ctx; + if (PrepareContext) { + ctx = executor.Prepare(*inference_program, 0); + executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, + CreateVars); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); + } // Enable the profiler paddle::platform::EnableProfiler(state); @@ -187,8 +194,15 @@ void TestInference(const std::string& dirname, "run_inference", paddle::platform::DeviceContextPool::Instance().Get(place)); - executor.Run(*inference_program, scope, feed_targets, fetch_targets, - CreateVars); + if (PrepareContext) { + // Note: if you change the inference_program, you need to call + // executor.Prepare() again to get a new ExecutorPrepareContext. + executor.RunPreparedContext(ctx.get(), scope, feed_targets, + fetch_targets, CreateVars); + } else { + executor.Run(*inference_program, scope, feed_targets, fetch_targets, + CreateVars); + } } // Disable the profiler and print the timing information diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 45f88ec8697d9f3de2612f28889fefc36f7ddbf9..661dfa69fe1580ff3890f12defcd124225be0c06 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, + this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; @@ -89,7 +90,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, + this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -132,8 +134,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; @@ -196,7 +198,7 @@ bool RPCClient::Wait() { std::vector> waits(req_count_); for (int i = 0; i < req_count_; i++) { - waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 0b582a08bc0bfbcfdc8f338a6add8edaa5e80818..119e146e078e476b2768a8495ea63e468f952fd2 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -217,10 +217,10 @@ void AsyncGRPCServer::RunSyncUpdate() { std::function prefetch_register = std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + // TODO(wuyi): Run these "HandleRequest" in thread pool t_send_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); - t_get_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 33a50b5cebc1f65ccf9a00280f0eeadd00982555..0b7c1d6af714558d35918dac62d92d9e0f86c970 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -33,28 +33,14 @@ static constexpr size_t kChannelSize = 0; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: - struct Item { - Item() : ctx_(nullptr) {} - Item(Item&& b) { - payloads_ = std::move(b.payloads_); - ctx_ = std::move(b.ctx_); - } - Item& operator=(Item&& b) { - payloads_ = std::move(b.payloads_); - ctx_ = std::move(b.ctx_); - return *this; - } - - std::vector payloads_; - platform::DeviceContext* ctx_; - }; - explicit DoubleBufferReader( ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) : DecoratedReader(reader), place_(target_place) { + cpu_tensor_cache_.resize(kCacheSize); + gpu_tensor_cache_.resize(kCacheSize); #ifdef PADDLE_WITH_CUDA - for (size_t i = 0; i < kCacheSize; ++i) { - if (platform::is_gpu_place(place_)) { + if (platform::is_gpu_place(place_)) { + for (size_t i = 0; i < kCacheSize; ++i) { ctxs_.emplace_back(new platform::CUDADeviceContext( boost::get(place_))); } @@ -72,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader { bool HasNext() const; void StartPrefetcher() { - channel_ = framework::MakeChannel(kChannelSize); + channel_ = framework::MakeChannel(kChannelSize); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); } @@ -88,8 +74,10 @@ class DoubleBufferReader : public framework::DecoratedReader { void PrefetchThreadFunc(); std::thread prefetcher_; - framework::Channel* channel_; + framework::Channel* channel_; platform::Place place_; + std::vector> cpu_tensor_cache_; + std::vector> gpu_tensor_cache_; std::vector> ctxs_; }; @@ -153,11 +141,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { void DoubleBufferReader::ReadNext(std::vector* out) { out->clear(); if (HasNext()) { - Item batch; - channel_->Receive(&batch); - *out = batch.payloads_; - if (batch.ctx_) { - batch.ctx_->Wait(); + size_t cached_tensor_id; + channel_->Receive(&cached_tensor_id); + if (platform::is_gpu_place(place_)) { + *out = gpu_tensor_cache_[cached_tensor_id]; + ctxs_[cached_tensor_id]->Wait(); + } else { + // CPU place + *out = cpu_tensor_cache_[cached_tensor_id]; } } } @@ -176,42 +167,33 @@ bool DoubleBufferReader::HasNext() const { void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; - std::vector> cpu_tensor_cache(kCacheSize); - std::vector> gpu_tensor_cache(kCacheSize); size_t cached_tensor_id = 0; - while (true) { - Item batch; - auto& cpu_batch = cpu_tensor_cache[cached_tensor_id]; + auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id]; reader_->ReadNext(&cpu_batch); if (cpu_batch.empty()) { // The underlying reader have no next data. break; } if (platform::is_gpu_place(place_)) { - auto& gpu_batch = gpu_tensor_cache[cached_tensor_id]; + auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id]; auto* gpu_ctx = ctxs_[cached_tensor_id].get(); gpu_batch.resize(cpu_batch.size()); for (size_t i = 0; i < cpu_batch.size(); ++i) { framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]); gpu_batch[i].set_lod(cpu_batch[i].lod()); } - batch.payloads_ = gpu_batch; - batch.ctx_ = gpu_ctx; - } else { - // CPUPlace - batch.payloads_ = cpu_batch; } - ++cached_tensor_id; - cached_tensor_id %= kCacheSize; - try { - channel_->Send(&batch); + size_t tmp = cached_tensor_id; + channel_->Send(&tmp); } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread will terminate."; break; } + ++cached_tensor_id; + cached_tensor_id %= kCacheSize; } channel_->Close(); VLOG(5) << "Prefetch thread terminates."; diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 9abc78421a7554f51f56665e4d82d34e67c7c159..8320c257c9ab15efec29eabe99eca5b6f74c9e31 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -147,6 +147,7 @@ class ReshapeKernel : public framework::OpKernel { if (!inplace) { out->mutable_data(ctx.GetPlace()); framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); + ctx.device_context().Wait(); // TensorCopy will resize to in_dims. out->Resize(out_dims); } else { @@ -169,6 +170,7 @@ class ReshapeGradKernel : public framework::OpKernel { auto in_dims = d_x->dims(); if (!inplace) { framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); d_x->Resize(in_dims); } else { d_x->ShareDataWith(*d_out); diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 87699362b2b5a14750a01345098ec5e6cc9be115..acaefaacdaa593c090d81084fdc1b3665314833f 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -24,7 +24,19 @@ template class CPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* tensor = ctx.Output("Out"); + framework::Tensor* tensor = nullptr; + auto out_var = ctx.OutputVar("Out"); + if (out_var->IsType()) { + tensor = out_var->GetMutable(); + } else if (out_var->IsType()) { + auto shape = ctx.Attr>("shape"); + tensor = out_var->GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(shape)); + } else { + PADDLE_THROW( + "uniform_random_op's output only" + "supports SelectedRows and Tensor"); + } T* data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 1232cd1eb332441b12e59a34b2c2f75669925fd0..e1c7323a30233f4ec4f60e46aa6088ee6d8601b7 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -43,7 +43,19 @@ template class GPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); + framework::Tensor* tensor = nullptr; + auto out_var = context.OutputVar("Out"); + if (out_var->IsType()) { + tensor = out_var->GetMutable(); + } else if (out_var->IsType()) { + auto shape = context.Attr>("shape"); + tensor = out_var->GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(shape)); + } else { + PADDLE_THROW( + "uniform_random_op's output only" + "supports SelectedRows and Tensor"); + } T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); if (seed == 0) { diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index e4997b4069f60ff4382b4254bc026ae8ae29b345..5ec6890c1b0dabd2804a92071b63c9610299e67c 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -157,7 +157,6 @@ def train(nn_type, for ip in pserver_ips.split(","): eplist.append(':'.join([ip, port])) pserver_endpoints = ",".join(eplist) # ip:port,ip:port... - pserver_endpoints = os.getenv("PSERVERS") trainers = int(os.getenv("TRAINERS")) current_endpoint = os.getenv("POD_IP") + ":" + port trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID")) diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 75ff85a55fc4fd504ecd032e17f7e189c17192fb..346a949b6e7c96b5535f5e65ddbada11e110a0a7 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -15,6 +15,16 @@ import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator + + +def output_hist(out): + hist, _ = np.histogram(out, range=(-5, 10)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones((10)) + return hist, prob class TestUniformRandomOp(OpTest): @@ -33,11 +43,37 @@ class TestUniformRandomOp(OpTest): self.check_output_customized(self.verify_output) def verify_output(self, outs): - tensor = outs[0] - hist, _ = np.histogram(outs[0], range=(-5, 10)) - hist = hist.astype("float32") - hist /= float(outs[0].size) - prob = 0.1 * np.ones((10)) + hist, prob = output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + +class TestUniformRandomOpSelectedRows(unittest.TestCase): + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + def check_with_place(self, place): + scope = core.Scope() + out = scope.var("X").get_selected_rows() + + op = Operator( + "uniform_random", + Out="X", + shape=[4, 784], + min=-5.0, + max=10.0, + seed=10) + op.run(scope, place) + self.assertEqual(out.get_tensor().shape(), [4, 784]) + hist, prob = output_hist(np.array(out.get_tensor())) self.assertTrue( np.allclose( hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))