Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
44c662b4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
44c662b4
编写于
6月 06, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'origin/develop' into fix/cudnn
上级
2b9ef7e2
ea408d55
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
246 addition
and
84 deletion
+246
-84
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/execution_strategy.h
paddle/fluid/framework/details/execution_strategy.h
+1
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+76
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+53
-0
paddle/fluid/framework/details/ssa_graph_executor.cc
paddle/fluid/framework/details/ssa_graph_executor.cc
+0
-4
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+1
-5
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+4
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+16
-42
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+0
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+0
-3
paddle/fluid/platform/dynload/cublas.h
paddle/fluid/platform/dynload/cublas.h
+1
-1
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+1
-1
paddle/fluid/platform/dynload/cupti.h
paddle/fluid/platform/dynload/cupti.h
+1
-1
paddle/fluid/platform/dynload/curand.h
paddle/fluid/platform/dynload/curand.h
+1
-1
paddle/fluid/platform/dynload/nccl.h
paddle/fluid/platform/dynload/nccl.h
+1
-1
paddle/fluid/platform/dynload/tensorrt.h
paddle/fluid/platform/dynload/tensorrt.h
+1
-1
paddle/fluid/platform/dynload/warpctc.h
paddle/fluid/platform/dynload/warpctc.h
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+8
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+75
-15
python/paddle/fluid/tests/no_test_concurrency.py
python/paddle/fluid/tests/no_test_concurrency.py
+0
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-4
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
44c662b4
...
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor
scope_buffered_ssa_graph_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
44c662b4
...
...
@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
device_context broadcast_op_handle
)
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle
)
cc_library
(
scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor
)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
paddle/fluid/framework/details/execution_strategy.h
浏览文件 @
44c662b4
...
...
@@ -22,6 +22,7 @@ struct ExecutionStrategy {
size_t
num_threads_
{
0
};
bool
use_event_
{
true
};
bool
allow_op_delay_
{
false
};
size_t
num_iteration_per_drop_scope_
{
100
};
};
}
// namespace details
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
0 → 100644
浏览文件 @
44c662b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
ScopeBufferedSSAGraphExecutor
::
ScopeBufferedSSAGraphExecutor
(
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>
&&
underlying_executor
)
:
strategy_
(
std
::
move
(
strategy
)),
underlying_executor_
(
std
::
move
(
underlying_executor
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
var_infos_
(
std
::
move
(
var_infos
)),
places_
(
std
::
move
(
places
))
{}
FeedFetchList
ScopeBufferedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
if
(
drop_scope_counter_
==
0
)
{
// Create local scopes.
for
(
auto
it
=
local_scopes_
.
rbegin
();
it
!=
local_scopes_
.
rend
();
++
it
)
{
auto
&
scope
=
*
it
;
Scope
&
local_scope
=
scope
->
NewScope
();
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
()
=
&
local_scope
;
for
(
auto
&
info
:
var_infos_
)
{
if
(
scope
->
FindVar
(
info
.
name_
)
!=
nullptr
)
{
continue
;
}
if
(
info
.
persistable_
)
{
// Persistable
InitializeVariable
(
scope
->
Var
(
info
.
name_
),
info
.
type_
);
}
else
{
InitializeVariable
(
local_scope
.
Var
(
info
.
name_
),
info
.
type_
);
}
}
}
}
auto
fetch_data
=
underlying_executor_
->
Run
(
fetch_tensors
);
drop_scope_counter_
+=
1
;
if
(
!
fetch_tensors
.
empty
()
||
drop_scope_counter_
==
strategy_
.
num_iteration_per_drop_scope_
)
{
drop_scope_counter_
=
0
;
// Wait All computational streams
for
(
auto
p
:
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
for
(
auto
&
scope
:
local_scopes_
)
{
auto
&
local_scope
=
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
();
scope
->
DeleteScope
(
local_scope
);
}
}
return
fetch_data
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
0 → 100644
浏览文件 @
44c662b4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
struct
VariableInfo
{
std
::
string
name_
;
proto
::
VarType
::
Type
type_
;
bool
persistable_
;
};
class
ScopeBufferedSSAGraphExecutor
:
public
SSAGraphExecutor
{
public:
ScopeBufferedSSAGraphExecutor
(
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
size_t
drop_scope_counter_
{
0
};
ExecutionStrategy
strategy_
;
std
::
unique_ptr
<
SSAGraphExecutor
>
underlying_executor_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
VariableInfo
>
var_infos_
;
std
::
vector
<
platform
::
Place
>
places_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_executor.cc
浏览文件 @
44c662b4
...
...
@@ -17,10 +17,6 @@
namespace
paddle
{
namespace
framework
{
namespace
details
{
SSAGraphExecutor
::
SSAGraphExecutor
(
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
:
graph_
(
std
::
move
(
graph
))
{}
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
44c662b4
...
...
@@ -28,15 +28,11 @@ class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN
(
SSAGraphExecutor
);
public:
// Steal graph inside
explicit
SSAGraphExecutor
(
std
::
unique_ptr
<
SSAGraph
>
&&
graph
);
SSAGraphExecutor
()
{}
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
protected:
std
::
unique_ptr
<
SSAGraph
>
graph_
;
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
44c662b4
...
...
@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
:
SSAGraphExecutor
(
std
::
move
(
graph
)),
:
graph_
(
std
::
move
(
graph
)),
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
:
nullptr
),
local_scopes_
(
local_scopes
),
...
...
@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
try
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
}
op
->
Run
(
strategy_
.
use_event_
);
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
running_ops_
--
;
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
44c662b4
...
...
@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details
::
OpHandleBase
*
op
);
private:
std
::
unique_ptr
<
SSAGraph
>
graph_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
44c662b4
...
...
@@ -23,6 +23,7 @@ limitations under the License. */
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
#endif
std
::
vector
<
std
::
tuple
<
std
::
string
,
proto
::
VarType
::
Type
,
bool
>>
var_types_
;
bool
own_local_scope
;
};
...
...
@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
local_scopes
.
empty
())
{
// Is CUDA
BCastParamsToGPUs
(
bcast_vars
);
}
// Startup Program has been run. All local scopes has correct parameters.
// Startup Program has been run. All local scopes has correct parameters.
// Step 2. Create vars in each scope;
std
::
vector
<
details
::
VariableInfo
>
var_infos
;
for
(
auto
*
var
:
main_program
.
Block
(
0
).
AllVars
())
{
var_infos
.
emplace_back
();
var_infos
.
back
().
name_
=
var
->
Name
();
var_infos
.
back
().
type_
=
var
->
GetType
();
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
// Step
2
. Convert main_program to SSA form and dependency graph. Also, insert
// Step
3
. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
...
...
@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
build_strategy
);
#endif
auto
graph
=
builder
.
Build
(
main_program
);
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
// Step 3. Create vars in each scope;
for
(
auto
*
var
:
main_program
.
Block
(
0
).
AllVars
())
{
member_
->
var_types_
.
emplace_back
(
var
->
Name
(),
var
->
GetType
(),
var
->
Persistable
());
}
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
member_
->
places_
,
std
::
move
(
member_
->
executor_
)));
}
void
ParallelExecutor
::
BCastParamsToGPUs
(
...
...
@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
// Create local scopes.
for
(
auto
it
=
member_
->
local_scopes_
.
rbegin
();
it
!=
member_
->
local_scopes_
.
rend
();
++
it
)
{
auto
&
scope
=
*
it
;
Scope
&
local_scope
=
scope
->
NewScope
();
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
()
=
&
local_scope
;
for
(
auto
&
name_type_pair
:
member_
->
var_types_
)
{
if
(
scope
->
FindVar
(
std
::
get
<
0
>
(
name_type_pair
))
!=
nullptr
)
{
continue
;
}
if
(
std
::
get
<
2
>
(
name_type_pair
))
{
// Persistable
InitializeVariable
(
scope
->
Var
(
std
::
get
<
0
>
(
name_type_pair
)),
std
::
get
<
1
>
(
name_type_pair
));
}
else
{
InitializeVariable
(
local_scope
.
Var
(
std
::
get
<
0
>
(
name_type_pair
)),
std
::
get
<
1
>
(
name_type_pair
));
}
}
}
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
// Wait All computational streams
for
(
auto
p
:
member_
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
}
for
(
auto
&
scope
:
member_
->
local_scopes_
)
{
auto
&
local_scope
=
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
();
scope
->
DeleteScope
(
local_scope
);
}
}
void
ParallelExecutor
::
FeedTensorsIntoLocalScopes
(
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
44c662b4
...
...
@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() {
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
std
::
lock_guard
<
std
::
recursive_mutex
>
guard
(
mutex_
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaGetLastError
());
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
44c662b4
...
...
@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
guard
(
mutex_
);
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
...
...
@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
mutable
std
::
recursive_mutex
mutex_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
...
...
paddle/fluid/platform/dynload/cublas.h
浏览文件 @
44c662b4
...
...
@@ -45,7 +45,7 @@ extern void *cublas_dso_handle;
std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \
void *p_##__name = dlsym(cublas_dso_handle, #__name);
\
static void *p_##__name = dlsym(cublas_dso_handle, #__name);
\
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
} \
}; \
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
44c662b4
...
...
@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
}); \
EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name);
\
static void* p_##__name = dlsym(cudnn_dso_handle, #__name);
\
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
}; \
...
...
paddle/fluid/platform/dynload/cupti.h
浏览文件 @
44c662b4
...
...
@@ -45,7 +45,7 @@ extern void *cupti_dso_handle;
std::call_once(cupti_dso_flag, []() { \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
}); \
void *p_##__name = dlsym(cupti_dso_handle, #__name);
\
static void *p_##__name = dlsym(cupti_dso_handle, #__name);
\
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \
}; \
...
...
paddle/fluid/platform/dynload/curand.h
浏览文件 @
44c662b4
...
...
@@ -34,7 +34,7 @@ extern void *curand_dso_handle;
std::call_once(curand_dso_flag, []() { \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
}); \
void *p_##__name = dlsym(curand_dso_handle, #__name);
\
static void *p_##__name = dlsym(curand_dso_handle, #__name);
\
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
}; \
...
...
paddle/fluid/platform/dynload/nccl.h
浏览文件 @
44c662b4
...
...
@@ -37,7 +37,7 @@ extern void* nccl_dso_handle;
std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \
void* p_##__name = dlsym(nccl_dso_handle, #__name);
\
static void* p_##__name = dlsym(nccl_dso_handle, #__name);
\
return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \
}; \
...
...
paddle/fluid/platform/dynload/tensorrt.h
浏览文件 @
44c662b4
...
...
@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle;
paddle::platform::dynload::GetTensorRtDsoHandle(); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
}); \
void* p_##__name = dlsym(tensorrt_dso_handle, #__name);
\
static void* p_##__name = dlsym(tensorrt_dso_handle, #__name);
\
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \
...
...
paddle/fluid/platform/dynload/warpctc.h
浏览文件 @
44c662b4
...
...
@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle;
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
}); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name);
\
static void* p_##_name = dlsym(warpctc_dso_handle, #__name);
\
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \
}; \
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
44c662b4
...
...
@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle.
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
allow_op_delay_
;
},
[](
ExecutionStrategy
&
self
,
bool
allow_op_delay
)
{
self
.
allow_op_delay_
=
allow_op_delay
;
})
.
def_property
(
"num_iteration_per_drop_scope"
,
[](
const
ExecutionStrategy
&
self
)
{
return
self
.
num_iteration_per_drop_scope_
;
},
[](
ExecutionStrategy
&
self
,
size_t
num_iteration_per_drop_scope
)
{
self
.
num_iteration_per_drop_scope_
=
num_iteration_per_drop_scope
;
});
py
::
class_
<
BuildStrategy
>
build_strategy
(
pe
,
"BuildStrategy"
);
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
44c662b4
...
...
@@ -81,6 +81,8 @@ __all__ = [
'label_smooth'
,
'roi_pool'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'gather'
,
'random_crop'
,
...
...
@@ -3929,22 +3931,25 @@ def dice_loss(input, label, epsilon=0.00001):
return
reduce_mean
(
dice_score
)
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
def
image_resize
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
,
resample
=
'BILINEAR'
):
"""
The mathematical meaning of resize bilinear layer is
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
Resize a batch of images.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
The input must be a tensor of the shape (num_batches, channels, in_h, in_w),
and the resizing only applies on the last two dimensions(hight and width).
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
Args:
input (Variable): The input tensor of
resize bilinear
layer,
input (Variable): The input tensor of
image resize
layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
out_shape(list|tuple|Variable|None): Output shape of
resize bilinear
out_shape(list|tuple|Variable|None): Output shape of
image resize
layer, the shape is (out_h, out_w).
Default: None
scale(float|None): The multiplier for the input height or width.
...
...
@@ -3953,6 +3958,8 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It can only be 'BILINEAR' currently.
Default: 'BILINEAR'
Returns:
out (Variable): The output is a 4-D tensor of the shape
...
...
@@ -3961,8 +3968,12 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
Examples:
.. code-block:: python
out = fluid.layers.
resize_bilinear
(input, out_shape=[12, 12])
out = fluid.layers.
image_resize
(input, out_shape=[12, 12])
"""
resample_methods
=
{
'BILINEAR'
:
'bilinear_interp'
}
if
resample
not
in
resample_methods
:
raise
ValueError
(
"The 'resample' of image_resize can only be 'BILINEAR' currently."
)
if
out_shape
is
None
and
scale
is
None
:
raise
ValueError
(
"One of out_shape and scale must not be None"
)
helper
=
LayerHelper
(
'bilinear_interp'
,
**
locals
())
...
...
@@ -3990,7 +4001,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
type
=
"bilinear_interp"
,
type
=
resample_methods
[
resample
]
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
},
attrs
=
{
"out_h"
:
out_h
,
...
...
@@ -3998,6 +4009,55 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return
out
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
"""
This is an alias of layer 'image_resize' with bilinear interpolation.
The mathematical meaning of resize bilinear layer is
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
"""
return
image_resize
(
input
,
out_shape
,
scale
,
name
,
'BILINEAR'
)
def
image_resize_short
(
input
,
out_short_len
,
resample
=
'BILINEAR'
):
"""
Resize a batch of images. The short edge of input images will be
resized to the given 'out_short_len'. The long edge of input images
will be resized proportionately to make images' length-width ratio
constant.
Args:
input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
out_short_len(int): The length of output images' short edge.
Returns:
out (Variable): The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w).
"""
in_shape
=
input
.
shape
if
len
(
in_shape
)
!=
4
:
raise
ValueError
(
"The rank of input must be 4 (num_batches, channels, in_h, in_w)."
)
hw
=
in_shape
[
2
:
4
]
short_idx
=
hw
.
index
(
min
(
hw
))
long_idx
=
1
-
short_idx
out_shape
=
list
(
hw
)
out_shape
[
short_idx
]
=
out_short_len
out_shape
[
long_idx
]
=
int
(
float
(
out_shape
[
long_idx
])
*
(
float
(
out_short_len
)
/
float
(
hw
[
short_idx
]))
+
0.5
)
return
image_resize
(
input
=
input
,
out_shape
=
out_shape
,
resample
=
resample
)
def
gather
(
input
,
index
):
"""
Output is obtained by gathering entries of the outer-most dimension
...
...
@@ -4005,7 +4065,7 @@ def gather(input, index):
.. math::
Out = X[Index]
Out = X[Index]
.. code-block:: text
...
...
@@ -4013,8 +4073,8 @@ def gather(input, index):
Given:
X = [[1, 2],
[3, 4],
X = [[1, 2],
[3, 4],
[5, 6]]
Index = [1, 2]
...
...
python/paddle/fluid/tests/test_concurrency.py
→
python/paddle/fluid/tests/
no_
test_concurrency.py
浏览文件 @
44c662b4
文件已移动
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
44c662b4
...
...
@@ -43,12 +43,10 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list
(
REMOVE_ITEM TEST_OPS test_dist_train
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_crf
)
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed
)
# TODO(wuyi): this test hungs on CI, will add it back later
list
(
REMOVE_ITEM TEST_OPS test_listen_and_serv_op
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
)
endforeach
(
TEST_OP
)
py_test_modules
(
test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=
${
WARPCTC_LIB_DIR
}
SERIAL
)
py_test_modules
(
test_dist_train MODULES test_dist_train SERIAL
)
# FIXME(Yancey1989): this test would cost much more time on CUDAPlace
# since load cudnn libraries, so we use a longer timeout to make this
# unit test stability.
set_tests_properties
(
test_listen_and_serv_op PROPERTIES TIMEOUT 30
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录