提交 44c662b4 编写于 作者: D dzhwinter

Merge remote-tracking branch 'origin/develop' into fix/cudnn

...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope ...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method) framework_proto glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha ...@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
device_context broadcast_op_handle) device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle) device_context gather_op_handle)
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle ) # device_context reduce_op_handle )
...@@ -22,6 +22,7 @@ struct ExecutionStrategy { ...@@ -22,6 +22,7 @@ struct ExecutionStrategy {
size_t num_threads_{0}; size_t num_threads_{0};
bool use_event_{true}; bool use_event_{true};
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100};
}; };
} // namespace details } // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
namespace paddle {
namespace framework {
namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)),
underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)),
var_infos_(std::move(var_infos)),
places_(std::move(places)) {}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
}
return fetch_data;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace details {
struct VariableInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public:
ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
: graph_(std::move(graph)) {}
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
} // namespace details } // namespace details
......
...@@ -28,15 +28,11 @@ class SSAGraphExecutor { ...@@ -28,15 +28,11 @@ class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor); DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
public: public:
// Steal graph inside SSAGraphExecutor() {}
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
virtual ~SSAGraphExecutor(); virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
protected:
std::unique_ptr<SSAGraph> graph_;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph) std::unique_ptr<SSAGraph> &&graph)
: SSAGraphExecutor(std::move(graph)), : graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
local_scopes_(local_scopes), local_scopes_(local_scopes),
...@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
op->Run(strategy_.use_event_); op->Run(strategy_.use_event_);
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
......
...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<SSAGraph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#endif #endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -42,8 +43,6 @@ class ParallelExecutorPrivate { ...@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif #endif
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
bool own_local_scope; bool own_local_scope;
}; };
...@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor( ...@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
local_scopes.empty()) { // Is CUDA local_scopes.empty()) { // Is CUDA
BCastParamsToGPUs(bcast_vars); BCastParamsToGPUs(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Create vars in each scope;
std::vector<details::VariableInfo> var_infos;
for (auto *var : main_program.Block(0).AllVars()) {
var_infos.emplace_back();
var_infos.back().name_ = var->Name();
var_infos.back().type_ = var->GetType();
var_infos.back().persistable_ = var->Persistable();
}
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(
...@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor( ...@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, params, member_->local_scopes_,
build_strategy); build_strategy);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
// Step 3. Create vars in each scope; member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
for (auto *var : main_program.Block(0).AllVars()) { exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->var_types_.emplace_back(var->Name(), var->GetType(), member_->places_, std::move(member_->executor_)));
var->Persistable());
}
} }
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
...@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
// Create local scopes.
for (auto it = member_->local_scopes_.rbegin();
it != member_->local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &name_type_pair : member_->var_types_) {
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
continue;
}
if (std::get<2>(name_type_pair)) { // Persistable
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
} else {
InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
}
}
}
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
// Wait All computational streams
for (auto p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : member_->local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
} }
void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedTensorsIntoLocalScopes(
......
...@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; } Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
std::lock_guard<std::recursive_mutex> guard(mutex_);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
} }
......
...@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) { void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
callback(); callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
...@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
mutable std::recursive_mutex mutex_;
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
......
...@@ -45,7 +45,7 @@ extern void *cublas_dso_handle; ...@@ -45,7 +45,7 @@ extern void *cublas_dso_handle;
std::call_once(cublas_dso_flag, []() { \ std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \ cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \ static void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \ return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
}); \ }); \
EnforceCUDNNLoaded(#__name); \ EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ static void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \ return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -45,7 +45,7 @@ extern void *cupti_dso_handle; ...@@ -45,7 +45,7 @@ extern void *cupti_dso_handle;
std::call_once(cupti_dso_flag, []() { \ std::call_once(cupti_dso_flag, []() { \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \ cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(cupti_dso_handle, #__name); \ static void *p_##__name = dlsym(cupti_dso_handle, #__name); \
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \ return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -34,7 +34,7 @@ extern void *curand_dso_handle; ...@@ -34,7 +34,7 @@ extern void *curand_dso_handle;
std::call_once(curand_dso_flag, []() { \ std::call_once(curand_dso_flag, []() { \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \ curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \ static void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \ return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -37,7 +37,7 @@ extern void* nccl_dso_handle; ...@@ -37,7 +37,7 @@ extern void* nccl_dso_handle;
std::call_once(nccl_dso_flag, []() { \ std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \ nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \ }); \
void* p_##__name = dlsym(nccl_dso_handle, #__name); \ static void* p_##__name = dlsym(nccl_dso_handle, #__name); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \ return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle; ...@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle;
paddle::platform::dynload::GetTensorRtDsoHandle(); \ paddle::platform::dynload::GetTensorRtDsoHandle(); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \ PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
}); \ }); \
void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \ PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \ return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \ } \
......
...@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle; ...@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle;
std::call_once(warpctc_dso_flag, []() { \ std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
}); \ }); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \ static void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \ return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \ } \
}; \ }; \
......
...@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle.
[](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](const ExecutionStrategy &self) { return self.allow_op_delay_; },
[](ExecutionStrategy &self, bool allow_op_delay) { [](ExecutionStrategy &self, bool allow_op_delay) {
self.allow_op_delay_ = allow_op_delay; self.allow_op_delay_ = allow_op_delay;
})
.def_property(
"num_iteration_per_drop_scope",
[](const ExecutionStrategy &self) {
return self.num_iteration_per_drop_scope_;
},
[](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
}); });
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy"); py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
......
...@@ -81,6 +81,8 @@ __all__ = [ ...@@ -81,6 +81,8 @@ __all__ = [
'label_smooth', 'label_smooth',
'roi_pool', 'roi_pool',
'dice_loss', 'dice_loss',
'image_resize',
'image_resize_short',
'resize_bilinear', 'resize_bilinear',
'gather', 'gather',
'random_crop', 'random_crop',
...@@ -3929,22 +3931,25 @@ def dice_loss(input, label, epsilon=0.00001): ...@@ -3929,22 +3931,25 @@ def dice_loss(input, label, epsilon=0.00001):
return reduce_mean(dice_score) return reduce_mean(dice_score)
def resize_bilinear(input, out_shape=None, scale=None, name=None): def image_resize(input,
out_shape=None,
scale=None,
name=None,
resample='BILINEAR'):
""" """
The mathematical meaning of resize bilinear layer is Resize a batch of images.
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia: The input must be a tensor of the shape (num_batches, channels, in_h, in_w),
https://en.wikipedia.org/wiki/Bilinear_interpolation and the resizing only applies on the last two dimensions(hight and width).
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
Args: Args:
input (Variable): The input tensor of resize bilinear layer, input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w). (num_batches, channels, in_h, in_w).
out_shape(list|tuple|Variable|None): Output shape of resize bilinear out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w).
Default: None Default: None
scale(float|None): The multiplier for the input height or width. scale(float|None): The multiplier for the input height or width.
...@@ -3953,6 +3958,8 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -3953,6 +3958,8 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Default: None Default: None
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
resample(str): The resample method. It can only be 'BILINEAR' currently.
Default: 'BILINEAR'
Returns: Returns:
out (Variable): The output is a 4-D tensor of the shape out (Variable): The output is a 4-D tensor of the shape
...@@ -3961,8 +3968,12 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -3961,8 +3968,12 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
out = fluid.layers.resize_bilinear(input, out_shape=[12, 12]) out = fluid.layers.image_resize(input, out_shape=[12, 12])
""" """
resample_methods = {'BILINEAR': 'bilinear_interp'}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' currently.")
if out_shape is None and scale is None: if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None") raise ValueError("One of out_shape and scale must not be None")
helper = LayerHelper('bilinear_interp', **locals()) helper = LayerHelper('bilinear_interp', **locals())
...@@ -3990,7 +4001,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -3990,7 +4001,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
out = helper.create_tmp_variable(dtype) out = helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type="bilinear_interp", type=resample_methods[resample],
inputs=inputs, inputs=inputs,
outputs={"Out": out}, outputs={"Out": out},
attrs={"out_h": out_h, attrs={"out_h": out_h,
...@@ -3998,6 +4009,55 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -3998,6 +4009,55 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return out return out
def resize_bilinear(input, out_shape=None, scale=None, name=None):
"""
This is an alias of layer 'image_resize' with bilinear interpolation.
The mathematical meaning of resize bilinear layer is
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
"""
return image_resize(input, out_shape, scale, name, 'BILINEAR')
def image_resize_short(input, out_short_len, resample='BILINEAR'):
"""
Resize a batch of images. The short edge of input images will be
resized to the given 'out_short_len'. The long edge of input images
will be resized proportionately to make images' length-width ratio
constant.
Args:
input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
out_short_len(int): The length of output images' short edge.
Returns:
out (Variable): The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w).
"""
in_shape = input.shape
if len(in_shape) != 4:
raise ValueError(
"The rank of input must be 4 (num_batches, channels, in_h, in_w).")
hw = in_shape[2:4]
short_idx = hw.index(min(hw))
long_idx = 1 - short_idx
out_shape = list(hw)
out_shape[short_idx] = out_short_len
out_shape[long_idx] = int(
float(out_shape[long_idx]) * (float(out_short_len) / float(hw[
short_idx])) + 0.5)
return image_resize(input=input, out_shape=out_shape, resample=resample)
def gather(input, index): def gather(input, index):
""" """
Output is obtained by gathering entries of the outer-most dimension Output is obtained by gathering entries of the outer-most dimension
...@@ -4005,7 +4065,7 @@ def gather(input, index): ...@@ -4005,7 +4065,7 @@ def gather(input, index):
.. math:: .. math::
Out = X[Index] Out = X[Index]
.. code-block:: text .. code-block:: text
...@@ -4013,8 +4073,8 @@ def gather(input, index): ...@@ -4013,8 +4073,8 @@ def gather(input, index):
Given: Given:
X = [[1, 2], X = [[1, 2],
[3, 4], [3, 4],
[5, 6]] [5, 6]]
Index = [1, 2] Index = [1, 2]
......
...@@ -43,12 +43,10 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op) ...@@ -43,12 +43,10 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dist_train) list(REMOVE_ITEM TEST_OPS test_dist_train)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
# TODO(wuyi): this test hungs on CI, will add it back later
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL) py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL) py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
# FIXME(Yancey1989): this test would cost much more time on CUDAPlace
# since load cudnn libraries, so we use a longer timeout to make this
# unit test stability.
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 30)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册