Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cbe128bb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cbe128bb
编写于
10月 27, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into optimize-sum-seq-pooling-op
上级
de539d72
5ed3e6f3
变更
49
隐藏空白更改
内联
并排
Showing
49 changed file
with
1070 addition
and
328 deletion
+1070
-328
.gitignore
.gitignore
+1
-0
Dockerfile
Dockerfile
+3
-3
paddle/fluid/API.spec
paddle/fluid/API.spec
+3
-2
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+3
-3
paddle/fluid/framework/mixed_vector.h
paddle/fluid/framework/mixed_vector.h
+27
-0
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+0
-17
paddle/fluid/framework/naive_executor.h
paddle/fluid/framework/naive_executor.h
+0
-2
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+2
-0
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+3
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+0
-6
paddle/fluid/framework/threadpool.cc
paddle/fluid/framework/threadpool.cc
+9
-22
paddle/fluid/framework/threadpool.h
paddle/fluid/framework/threadpool.h
+0
-24
paddle/fluid/framework/threadpool_test.cc
paddle/fluid/framework/threadpool_test.cc
+10
-6
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+3
-0
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+0
-1
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+1
-7
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+1
-1
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+28
-2
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+22
-7
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+16
-3
paddle/fluid/operators/fusion_gru_op.cc
paddle/fluid/operators/fusion_gru_op.cc
+50
-98
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+2
-1
paddle/fluid/operators/math/algorithm.h
paddle/fluid/operators/math/algorithm.h
+46
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+9
-0
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+180
-55
paddle/fluid/operators/math/sequence_pooling.cc
paddle/fluid/operators/math/sequence_pooling.cc
+32
-7
paddle/fluid/operators/math/sequence_pooling_test.cc
paddle/fluid/operators/math/sequence_pooling_test.cc
+126
-0
paddle/fluid/operators/sequence_reverse_op.cc
paddle/fluid/operators/sequence_reverse_op.cc
+29
-0
paddle/fluid/operators/sequence_reverse_op.cu
paddle/fluid/operators/sequence_reverse_op.cu
+25
-0
paddle/fluid/operators/sequence_reverse_op.h
paddle/fluid/operators/sequence_reverse_op.h
+157
-0
paddle/fluid/operators/softmax_cudnn_op.cu.cc
paddle/fluid/operators/softmax_cudnn_op.cu.cc
+3
-1
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+8
-5
paddle/fluid/operators/transpose_op.cu.cc
paddle/fluid/operators/transpose_op.cu.cc
+8
-5
python/paddle/dataset/wmt16.py
python/paddle/dataset/wmt16.py
+1
-1
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+5
-4
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+12
-3
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+68
-11
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+9
-4
python/paddle/fluid/regularizer.py
python/paddle/fluid/regularizer.py
+2
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-3
python/paddle/fluid/tests/unittests/dist_transformer.py
python/paddle/fluid/tests/unittests/dist_transformer.py
+8
-15
python/paddle/fluid/tests/unittests/test_dist_mnist.py
python/paddle/fluid/tests/unittests/test_dist_mnist.py
+2
-1
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
+2
-1
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
+4
-2
python/paddle/fluid/tests/unittests/test_dist_transformer.py
python/paddle/fluid/tests/unittests/test_dist_transformer.py
+4
-2
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+63
-0
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
+6
-0
python/paddle/fluid/tests/unittests/test_sequence_reverse.py
python/paddle/fluid/tests/unittests/test_sequence_reverse.py
+69
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+5
-2
未找到文件。
.gitignore
浏览文件 @
cbe128bb
...
...
@@ -28,3 +28,4 @@ third_party/
build_*
# clion workspace.
cmake-build-*
model_test
Dockerfile
浏览文件 @
cbe128bb
...
...
@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
pip3
install
-U
docopt PyYAML
sphinx
==
1.5.6
&&
\
pip3
install
sphinx-rtd-theme
==
0.1.9 recommonmark
&&
\
easy_install
-U
pip
&&
\
pip
install
-U
wheel
&&
\
pip
install
-U
pip setuptools
wheel
&&
\
pip
install
-U
docopt PyYAML
sphinx
==
1.5.6
&&
\
pip
install
sphinx-rtd-theme
==
0.1.9 recommonmark
RUN
pip3
install
pre-commit
'ipython==5.3.0'
&&
\
RUN
pip3
install
'pre-commit==1.10.4'
'ipython==5.3.0'
&&
\
pip3
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
&&
\
pip3
install
opencv-python
&&
\
pip
install
pre-commit
'ipython==5.3.0'
&&
\
pip
install
'pre-commit==1.10.4'
'ipython==5.3.0'
&&
\
pip
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
&&
\
pip
install
opencv-python
...
...
paddle/fluid/API.spec
浏览文件 @
cbe128bb
...
...
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
], varargs=None, keywords=None, defaults=(False, None, None
))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
, 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'
))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...
...
@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None,
Tru
e, None))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None,
Fals
e, None))
paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
...
...
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
cbe128bb
...
...
@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
i
<
last_backward
)
{
if
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
if
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kOptimize
)
))
{
optimize_ops
.
push_back
(
ret
[
i
]);
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
...
...
paddle/fluid/framework/mixed_vector.h
浏览文件 @
cbe128bb
...
...
@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
this
->
reserve
(
this
->
size
()
+
size_t
(
end
-
begin
));
this
->
insert
(
this
->
end
(),
begin
,
end
);
}
const
T
*
CUDAData
(
platform
::
Place
place
)
const
{
PADDLE_THROW
(
"Vector::CUDAData() method is not supported in CPU-only version"
);
}
T
*
CUDAMutableData
(
platform
::
Place
place
)
{
PADDLE_THROW
(
"Vector::CUDAMutableData() method is not supported in CPU-only "
"version"
);
}
const
T
*
Data
(
platform
::
Place
place
)
const
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
),
"Vector::Data() method is not supported when not in CPUPlace"
);
return
this
->
data
();
}
T
*
MutableData
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
),
"Vector::MutableData() method is not supported when not in CPUPlace"
);
return
this
->
data
();
}
const
void
*
Handle
()
const
{
return
static_cast
<
const
void
*>
(
this
);
}
};
template
<
typename
T
>
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
cbe128bb
...
...
@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_
.
swap
(
ops
);
}
void
NaiveExecutor
::
EnableMKLDNN
(
const
ProgramDesc
&
program
)
{
#ifdef PADDLE_WITH_MKLDNN
VLOG
(
3
)
<<
"use_mkldnn=True"
;
for
(
size_t
block_id
=
0
;
block_id
<
program
.
Size
();
++
block_id
)
{
auto
*
block
=
const_cast
<
ProgramDesc
&>
(
program
).
MutableBlock
(
block_id
);
for
(
auto
*
op
:
block
->
AllOps
())
{
if
(
op
->
HasAttr
(
"use_mkldnn"
))
{
op
->
SetAttr
(
"use_mkldnn"
,
true
);
}
}
}
#else
LOG
(
WARNING
)
<<
"'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"
;
#endif
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/naive_executor.h
浏览文件 @
cbe128bb
...
...
@@ -48,8 +48,6 @@ class NaiveExecutor {
void
CleanFeedFetchOps
();
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
protected:
void
CreateVariables
(
const
ProgramDesc
&
desc
,
Scope
*
scope
,
int
block_id
);
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
cbe128bb
...
...
@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kOptimize
)
|
static_cast
<
int
>
(
OpRole
::
kLRSched
),
static_cast
<
int
>
(
OpRole
::
kNotSpecified
)})
.
SetDefault
(
static_cast
<
int
>
(
OpRole
::
kNotSpecified
));
AddAttr
<
std
::
vector
<
std
::
string
>>
(
OpRoleVarAttrName
(),
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
cbe128bb
...
...
@@ -20,6 +20,9 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
//////////////////////////
// Don't add more roles to make this too complicated!
//////////////////////////
enum
class
OpRole
{
kForward
=
0x0000
,
kBackward
=
0x0001
,
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
cbe128bb
...
...
@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
"The number of graph should be only one"
);
}
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
paddle/fluid/framework/threadpool.cc
浏览文件 @
cbe128bb
...
...
@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
namespace
paddle
{
namespace
framework
{
std
::
unique_ptr
<
ThreadPool
>
ThreadPool
::
threadpool_
(
nullptr
);
std
::
once_flag
ThreadPool
::
init_flag_
;
...
...
@@ -47,8 +46,7 @@ void ThreadPool::Init() {
}
}
ThreadPool
::
ThreadPool
(
int
num_threads
)
:
total_threads_
(
num_threads
),
idle_threads_
(
num_threads
),
running_
(
true
)
{
ThreadPool
::
ThreadPool
(
int
num_threads
)
:
running_
(
true
)
{
threads_
.
resize
(
num_threads
);
for
(
auto
&
thread
:
threads_
)
{
// TODO(Yancey1989): binding the thread on the specify CPU number
...
...
@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool
::~
ThreadPool
()
{
{
// notify all threads to stop running
std
::
lock_guard
<
std
::
mutex
>
l
(
mutex_
);
running_
=
false
;
scheduled_
.
notify_all
();
}
...
...
@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
}
}
void
ThreadPool
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
completed_
.
wait
(
lock
,
[
=
]
{
return
Done
()
==
true
;
});
}
void
ThreadPool
::
TaskLoop
()
{
while
(
running_
)
{
while
(
true
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
scheduled_
.
wait
(
lock
,
[
=
]
{
return
!
tasks_
.
empty
()
||
!
running_
;
});
if
(
!
running_
)
{
break
;
scheduled_
.
wait
(
lock
,
[
this
]
{
return
!
this
->
tasks_
.
empty
()
||
!
this
->
running_
;
});
if
(
!
running_
||
tasks_
.
empty
())
{
return
;
}
// pop a task from the task queue
auto
task
=
std
::
move
(
tasks_
.
front
());
tasks_
.
pop
();
--
idle_threads_
;
lock
.
unlock
();
// run the task
task
();
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
++
idle_threads_
;
if
(
Done
())
{
completed_
.
notify_all
();
}
}
}
}
...
...
paddle/fluid/framework/threadpool.h
浏览文件 @
cbe128bb
...
...
@@ -57,15 +57,6 @@ class ThreadPool {
~
ThreadPool
();
// Returns the number of threads created by the constructor.
size_t
Threads
()
const
{
return
total_threads_
;
}
// Returns the number of currently idle threads.
size_t
IdleThreads
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
idle_threads_
;
}
// Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call
// std::future::wait().
...
...
@@ -94,25 +85,13 @@ class ThreadPool {
});
std
::
future
<
std
::
unique_ptr
<
platform
::
EnforceNotMet
>>
f
=
task
.
get_future
();
tasks_
.
push
(
std
::
move
(
task
));
lock
.
unlock
();
scheduled_
.
notify_one
();
return
f
;
}
// Wait until all the tasks are completed.
void
Wait
();
private:
DISABLE_COPY_AND_ASSIGN
(
ThreadPool
);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool
Done
()
{
return
tasks_
.
empty
()
&&
idle_threads_
==
total_threads_
;
}
// The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue.
void
TaskLoop
();
...
...
@@ -125,14 +104,11 @@ class ThreadPool {
static
std
::
once_flag
init_flag_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
threads_
;
const
size_t
total_threads_
;
size_t
idle_threads_
;
std
::
queue
<
Task
>
tasks_
;
std
::
mutex
mutex_
;
bool
running_
;
std
::
condition_variable
scheduled_
;
std
::
condition_variable
completed_
;
};
class
ThreadPoolIO
:
ThreadPool
{
...
...
paddle/fluid/framework/threadpool_test.cc
浏览文件 @
cbe128bb
...
...
@@ -19,10 +19,11 @@ limitations under the License. */
namespace
framework
=
paddle
::
framework
;
void
do_sum
(
framework
::
ThreadPool
*
pool
,
std
::
atomic
<
int
>*
sum
,
int
cnt
)
{
std
::
vector
<
std
::
future
<
void
>>
fs
;
void
do_sum
(
std
::
vector
<
std
::
future
<
void
>>*
fs
,
std
::
mutex
*
mu
,
std
::
atomic
<
int
>*
sum
,
int
cnt
)
{
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
fs
.
push_back
(
framework
::
Async
([
sum
]()
{
sum
->
fetch_add
(
1
);
}));
std
::
lock_guard
<
std
::
mutex
>
l
(
*
mu
);
fs
->
push_back
(
framework
::
Async
([
sum
]()
{
sum
->
fetch_add
(
1
);
}));
}
}
...
...
@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
}
TEST
(
ThreadPool
,
ConcurrentRun
)
{
framework
::
ThreadPool
*
pool
=
framework
::
ThreadPool
::
GetInstance
();
std
::
atomic
<
int
>
sum
(
0
);
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
std
::
mutex
fs_mu
;
int
n
=
50
;
// sum = (n * (n + 1)) / 2
for
(
int
i
=
1
;
i
<=
n
;
++
i
)
{
std
::
thread
t
(
do_sum
,
pool
,
&
sum
,
i
);
std
::
thread
t
(
do_sum
,
&
fs
,
&
fs_mu
,
&
sum
,
i
);
threads
.
push_back
(
std
::
move
(
t
));
}
for
(
auto
&
t
:
threads
)
{
t
.
join
();
}
pool
->
Wait
();
for
(
auto
&
t
:
fs
)
{
t
.
wait
();
}
EXPECT_EQ
(
sum
,
((
n
+
1
)
*
n
)
/
2
);
}
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
cbe128bb
...
...
@@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) {
passes
.
push_back
(
"mkldnn_placement_pass"
);
}
#endif
// infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass.
passes
.
push_back
(
"infer_clean_graph_pass"
);
for
(
auto
&
pass
:
ir_passes_
)
{
if
(
!
disabled_ir_passes_
.
count
(
pass
))
{
passes
.
push_back
(
pass
);
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
cbe128bb
...
...
@@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion.
const
std
::
vector
<
std
::
string
>
all_ir_passes_
{{
// Manual update the passes here.
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"seqconv_eltadd_relu_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
cbe128bb
...
...
@@ -124,7 +124,7 @@ class ZeroCopyTensor {
std
::
vector
<
std
::
vector
<
size_t
>>
lod
()
const
;
protected:
ZeroCopyTensor
(
void
*
scope
)
:
scope_
{
scope
}
{}
explicit
ZeroCopyTensor
(
void
*
scope
)
:
scope_
{
scope
}
{}
void
SetName
(
const
std
::
string
&
name
)
{
name_
=
name
;
}
void
*
FindTensor
()
const
;
...
...
@@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig {
kExclude
// Specify the disabled passes in `ir_passes`.
};
void
SetIncludeMode
()
{
ir_mode
=
IrPassMode
::
kInclude
;
// this pass has to be run at the beginning of all fuse passes
ir_passes
=
{
"infer_clean_graph_pass"
};
}
// Determine whether to perform graph optimization.
bool
enable_ir_optim
=
true
;
// Manually determine the IR passes to run.
...
...
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
cbe128bb
...
...
@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
selected_indices
.
push_back
(
idx
);
++
selected_num
;
}
sorted_indices
.
erase
(
sorted_indices
.
end
());
sorted_indices
.
erase
(
sorted_indices
.
end
()
-
1
);
if
(
flag
&&
eta
<
1
&&
adaptive_threshold
>
0.5
)
{
adaptive_threshold
*=
eta
;
}
...
...
paddle/fluid/operators/dropout_op.cc
浏览文件 @
cbe128bb
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
...
...
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"dropout_implementation"
,
"[
\"
downgrade_in_infer
\"
|
\"
upscale_in_train
\"
]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient"
)
.
SetDefault
(
"downgrade_in_infer"
)
.
AddCustomChecker
([](
const
std
::
string
&
type
)
{
PADDLE_ENFORCE
(
type
==
"downgrade_in_infer"
||
type
==
"upscale_in_train"
,
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"
);
});
AddComment
(
R"DOC(
Dropout Operator.
...
...
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.cu
浏览文件 @
cbe128bb
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -26,7 +27,8 @@ namespace operators {
template
<
typename
T
>
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
float
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
T
*
mask_data
,
T
*
dst
,
bool
is_upscale_in_train
)
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
...
...
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if
(
dist
(
rng
)
<
dropout_prob
)
{
mask
=
static_cast
<
T
>
(
0
);
}
else
{
mask
=
static_cast
<
T
>
(
1
);
if
(
is_upscale_in_train
)
{
mask
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
}
else
{
mask
=
static_cast
<
T
>
(
1
);
}
}
dest
=
s
*
mask
;
mask_data
[
idx
]
=
mask
;
...
...
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
...
...
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
(
dropout_implementation
==
"upscale_in_train"
));
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.h
浏览文件 @
cbe128bb
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
if
(
dropout_implementation
==
"upscale_in_train"
)
{
mask_data
[
i
]
=
1.0
f
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
...
...
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
paddle/fluid/operators/fusion_gru_op.cc
浏览文件 @
cbe128bb
...
...
@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
}
}
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_state; \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::isa_any>; \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
auto x_dims = x->dims();
/* T x M*/
\
auto wh_dims = wh->dims();
/* D x 3D*/
\
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
#define INIT_BASE_DEFINES \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto x_lod = x->lod(); \
auto x_dims = x->dims();
/* T x M*/
\
auto wh_dims = wh->dims();
/* D x 3D*/
\
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const auto& ker = math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
auto
x_lod
=
x
->
lod
();
INIT_BASE_DEFINES
;
INIT_OTHER_DEFINES
;
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wh_state_data
=
wh_data
+
D
*
D2
;
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
place
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D3
,
M
,
x_data
,
wx_data
,
xx_data
,
...
...
@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if
(
h0_data
)
{
prev_hidden_data
=
h0_data
+
bid
*
D
;
}
else
{
// W: {W_update, W_reset; W_state}
// update gate
act_gate
(
D
,
xx_data
,
xx_data
);
// state gate
act_state
(
D
,
xx_data
+
D2
,
xx_data
+
D2
);
// out = a*b
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D2
,
hidden_out_data
);
// save prev
ker
->
ComputeH1
(
xx_data
,
hidden_out_data
);
prev_hidden_data
=
hidden_out_data
;
tstart
=
1
;
move_step
();
...
...
@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D2
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
wh_data
,
D2
,
static_cast
<
T
>
(
1
),
xx_data
,
D3
);
act_gate
(
D2
,
xx_data
,
xx_data
);
// rt = rt*ht_1 inplace result
blas
.
VMUL
(
D
,
prev_hidden_data
,
xx_data
+
D
,
hidden_out_data
);
ker
->
ComputeHtPart1
(
xx_data
,
prev_hidden_data
,
hidden_out_data
);
// gemm rt * Ws
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D
,
D
,
static_cast
<
T
>
(
1
),
hidden_out_data
,
D
,
wh_state_data
,
D
,
static_cast
<
T
>
(
1
),
xx_data
+
D2
,
D3
);
act_state
(
D
,
xx_data
+
D2
,
xx_data
+
D2
);
// out = zt*ht~ + (1-zt)*ht_1
cross
(
D
,
xx_data
,
xx_data
+
D2
,
prev_hidden_data
,
hidden_out_data
);
ker
->
ComputeHtPart2
(
xx_data
,
prev_hidden_data
,
hidden_out_data
);
// save prev
prev_hidden_data
=
hidden_out_data
;
move_step
();
...
...
@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
INIT_BASE_DEFINES
;
if
(
x_lod
[
0
].
size
()
==
2
)
{
xx
->
Resize
({
total_T
,
D3
});
SeqCompute
(
ctx
);
return
;
}
INIT_VEC_FUNC
INIT_OTHER_DEFINES
;
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedOut"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batched_out_data
=
batched_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_out_data
=
batched_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
...
...
@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T
*
prev_hidden_data
=
nullptr
;
if
(
h0
)
{
// reorder h0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()
);
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
prev_hidden_data
=
reordered_h0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
...
...
@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T
*
cur_out_data
=
batched_out_data
;
// W: {W_update, W_reset; W_state}
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// update gate
act_gate
(
D
,
cur_in_data
,
cur_in_data
);
// state gate
act_state
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D2
);
// out = a*b
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D2
,
cur_out_data
);
ker
->
ComputeH1
(
cur_in_data
,
cur_out_data
);
// add offset
cur_in_data
+=
D3
;
cur_out_data
+=
D
;
...
...
@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T
*
cur_out_data
=
batched_out_data
;
T
*
cur_prev_hidden_data
=
prev_hidden_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
act_gate
(
D2
,
cur_batched_data
,
cur_batched_data
);
// rt = rt*ht_1 inplace result
blas
.
VMUL
(
D
,
cur_prev_hidden_data
,
cur_batched_data
+
D
,
cur_out_data
);
ker
->
ComputeHtPart1
(
cur_batched_data
,
cur_prev_hidden_data
,
cur_out_data
);
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_out_data
+=
D
;
...
...
@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data
=
prev_hidden_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// ht~ = act_state(...)
act_state
(
D
,
cur_batched_data
+
D2
,
cur_batched_data
+
D2
);
// out = zt*ht~ + (1-zt)*ht_1
cross
(
D
,
cur_batched_data
,
cur_batched_data
+
D2
,
cur_prev_hidden_data
,
cur_out_data
);
ker
->
ComputeHtPart2
(
cur_batched_data
,
cur_prev_hidden_data
,
cur_out_data
);
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_out_data
+=
D
;
...
...
@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_out
,
hidden_out
);
}
#undef INIT_VEC_FUNC
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_DEFINES
};
}
// namespace operators
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
cbe128bb
...
...
@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test
(
im2col_test SRCS im2col_test.cc DEPS im2col
)
cc_test
(
vol2col_test SRCS vol2col_test.cc DEPS vol2col
)
cc_test
(
sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding
)
cc_test
(
sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling
)
if
(
WITH_GPU
)
nv_test
(
math_function_gpu_test SRCS math_function_test.cu DEPS math_function
)
nv_test
(
selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function
)
...
...
@@ -75,6 +76,6 @@ endif()
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_and_split
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_library
(
jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_
lstm
.cc
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_
rnn
.cc
DEPS cpu_info cblas
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
paddle/fluid/operators/math/algorithm.h
浏览文件 @
cbe128bb
...
...
@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
return
-
1
;
}
template
<
typename
T
>
HOSTDEVICE
inline
size_t
LowerBound
(
const
T
*
x
,
size_t
num
,
const
T
&
val
)
{
#ifdef __CUDA_ARCH__
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound
auto
*
first
=
x
;
int64_t
count
=
static_cast
<
int64_t
>
(
num
);
while
(
count
>
0
)
{
int64_t
step
=
(
count
>>
1
);
auto
*
it
=
first
+
step
;
if
(
*
it
<
val
)
{
first
=
++
it
;
count
-=
(
step
+
1
);
}
else
{
count
=
step
;
}
}
return
static_cast
<
size_t
>
(
first
-
x
);
#else
return
static_cast
<
size_t
>
(
std
::
lower_bound
(
x
,
x
+
num
,
val
)
-
x
);
#endif
}
template
<
typename
T
>
HOSTDEVICE
inline
size_t
UpperBound
(
const
T
*
x
,
size_t
num
,
const
T
&
val
)
{
#ifdef __CUDA_ARCH__
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto
*
first
=
x
;
int64_t
count
=
static_cast
<
int64_t
>
(
num
);
while
(
count
>
0
)
{
auto
step
=
(
count
>>
1
);
auto
*
it
=
first
+
step
;
if
(
val
<
*
it
)
{
count
=
step
;
}
else
{
first
=
++
it
;
count
-=
(
step
+
1
);
}
}
return
static_cast
<
size_t
>
(
first
-
x
);
#else
return
static_cast
<
size_t
>
(
std
::
upper_bound
(
x
,
x
+
num
,
val
)
-
x
);
#endif
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
cbe128bb
...
...
@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const
T
*
wp_data
=
nullptr
)
const
=
0
;
};
template
<
typename
T
>
class
GRUKernel
:
public
Kernel
{
public:
// compute h1 without h0
virtual
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
=
0
;
virtual
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
=
0
;
virtual
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
=
0
;
};
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_
lstm
.cc
→
paddle/fluid/operators/math/jit_kernel_
rnn
.cc
浏览文件 @
cbe128bb
...
...
@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return
nullptr
;
}
#ifdef __AVX__
template
<
jit
::
cpu_isa_t
isa
>
static
std
::
unique_ptr
<
AVXAct
>
GetAVXAct
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kSigmoid
,
isa
>
());
}
else
if
(
type
==
"relu"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kRelu
,
isa
>
());
}
else
if
(
type
==
"tanh"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kTanh
,
isa
>
());
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kIdentity
,
isa
>
());
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
#endif
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
...
...
@@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
if (type == "sigmoid") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
} else if (type == "relu") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
} else if (type == "tanh") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
} else if (type == "identity" || type == "") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
} \
PADDLE_THROW("Not support type: %s", type); \
}; \
avx_act_gate_ = GetAVXAct(act_gate); \
avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
...
...
@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
GRUKernelImpl
:
public
GRUKernel
<
T
>
{
public:
explicit
GRUKernelImpl
(
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_state
,
int
d
)
:
GRUKernel
<
T
>
()
{
d_
=
d
;
d2_
=
d
*
2
;
act_gate_d2_
=
GetActKernel
<
T
>
(
act_gate
,
d2_
);
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
act_state_d_
=
GetActKernel
<
T
>
(
act_state
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
}
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
(
gates
,
gates
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
);
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
);
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
T
*
y
=
gates
+
d2_
;
act_state_d_
->
Compute
(
y
,
y
);
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
}
}
private:
int
d_
,
d2_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_state_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
#ifdef __AVX__
std
::
unique_ptr
<
const
AVXAct
>
avx_act_gate_
,
avx_act_state_
;
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */
\
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */
\
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */
\
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/sequence_pooling.cc
浏览文件 @
cbe128bb
...
...
@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
}
};
template
<
typename
T
>
class
SumSeqPoolGradFunctor
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
out_grad
,
framework
::
LoDTensor
*
in_grad
)
{
auto
lod
=
in_grad
->
lod
()[
0
];
int64_t
out_w
=
out_grad
.
numel
()
/
out_grad
.
dims
()[
0
];
int64_t
in_w
=
in_grad
->
numel
()
/
in_grad
->
dims
()[
0
];
PADDLE_ENFORCE
(
in_w
==
out_w
);
const
T
*
out_g_data
=
out_grad
.
data
<
T
>
();
T
*
in_g_data
=
in_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
int64_t
h
=
static_cast
<
int64_t
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
int64_t
in_offset
=
lod
[
i
]
*
in_w
;
const
T
*
out_pos
=
out_g_data
+
i
*
out_w
;
T
*
in_pos
=
in_g_data
+
in_offset
;
for
(
int
r
=
0
;
r
!=
h
;
++
r
)
{
blas
.
VCOPY
(
in_w
,
out_pos
,
in_pos
+
r
*
in_w
);
}
}
}
};
template
<
typename
T
>
class
SequencePoolFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
...
...
@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
functor
;
functor
(
context
,
in_grad
,
0
);
}
if
(
pooltype
==
"SUM"
)
{
math
::
SumSeqPoolGradFunctor
<
T
>
sum_pool_grad
;
sum_pool_grad
(
context
,
out_grad
,
in_grad
);
return
;
}
auto
lod
=
in_grad
->
lod
()[
0
];
auto
&
place
=
*
context
.
eigen_device
();
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
auto
in_g_t
=
in_grad
->
Slice
(
static_cast
<
int
>
(
lod
[
i
]),
static_cast
<
int
>
(
lod
[
i
+
1
]));
...
...
@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if
(
pooltype
==
"AVERAGE"
)
{
in_g_e
.
device
(
place
)
=
(
out_g_e
/
static_cast
<
T
>
(
h
)).
broadcast
(
bcast
);
}
else
if
(
pooltype
==
"SUM"
)
{
const
T
*
out_g_data
=
out_g_t
.
data
<
T
>
();
T
*
in_g_data
=
in_g_t
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
r
=
0
;
r
!=
h
;
++
r
)
{
blas
.
VCOPY
(
w
,
out_g_data
,
in_g_data
+
r
*
w
);
}
}
else
if
(
pooltype
==
"SQRT"
)
{
in_g_e
.
device
(
place
)
=
(
out_g_e
/
std
::
sqrt
(
static_cast
<
T
>
(
h
))).
broadcast
(
bcast
);
...
...
paddle/fluid/operators/math/sequence_pooling_test.cc
0 → 100644
浏览文件 @
cbe128bb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include <gtest/gtest.h>
#include <vector>
template
<
typename
DeviceContext
,
typename
Place
,
typename
T
>
void
TestSequencePoolingSum
(
const
paddle
::
framework
::
LoD
&
lod
)
{
paddle
::
framework
::
LoDTensor
cpu_out_grad
;
paddle
::
framework
::
LoDTensor
cpu_in_grad
;
paddle
::
framework
::
LoDTensor
out_grad
;
paddle
::
framework
::
LoDTensor
in_grad
;
const
size_t
second_dim
=
128u
;
// construct out_grad's tensor in cpu
const
size_t
out_first_dim
=
lod
[
0
].
size
()
-
1
;
auto
out_dims
=
paddle
::
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
out_first_dim
),
static_cast
<
int64_t
>
(
second_dim
)});
cpu_out_grad
.
mutable_data
<
T
>
(
out_dims
,
paddle
::
platform
::
CPUPlace
());
for
(
int64_t
i
=
0
;
i
<
cpu_out_grad
.
numel
();
++
i
)
{
cpu_out_grad
.
data
<
T
>
()[
i
]
=
static_cast
<
T
>
(
i
);
}
// copy to dst out_grad
auto
*
place
=
new
Place
();
DeviceContext
*
context
=
new
DeviceContext
(
*
place
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
out_grad
=
cpu_out_grad
;
}
else
{
TensorCopySync
(
cpu_out_grad
,
*
place
,
&
out_grad
);
}
// construct in_grad
in_grad
.
set_lod
(
lod
);
auto
in_dims
=
paddle
::
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
lod
[
0
].
back
()),
static_cast
<
int64_t
>
(
second_dim
)});
in_grad
.
mutable_data
<
T
>
(
in_dims
,
context
->
GetPlace
());
// check tensor contruction result
PADDLE_ENFORCE_EQ
(
in_grad
.
dims
().
size
(),
out_grad
.
dims
().
size
());
for
(
int64_t
i
=
1
;
i
<
out_grad
.
dims
().
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
in_grad
.
dims
()[
i
],
out_grad
.
dims
()[
i
]);
}
// call functor
paddle
::
operators
::
math
::
SequencePoolGradFunctor
<
DeviceContext
,
T
>
()(
*
context
,
"SUM"
,
out_grad
,
&
in_grad
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
cpu_in_grad
=
in_grad
;
}
else
{
TensorCopySync
(
in_grad
,
paddle
::
platform
::
CPUPlace
(),
&
cpu_in_grad
);
cpu_in_grad
.
set_lod
(
in_grad
.
lod
());
}
EXPECT_EQ
(
in_grad
.
numel
(),
lod
[
0
].
back
()
*
second_dim
);
EXPECT_EQ
(
in_grad
.
lod
(),
lod
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
for
(
int64_t
i
=
0
;
i
<
in_grad
.
lod
()[
0
].
size
()
-
1
;
++
i
)
{
int64_t
begin
=
in_grad
.
lod
()[
0
][
i
];
int64_t
end
=
in_grad
.
lod
()[
0
][
i
+
1
];
paddle
::
framework
::
Tensor
tmp
=
in_grad
.
Slice
(
begin
,
end
);
for
(
int64_t
j
=
0
;
j
!=
tmp
.
numel
()
/
second_dim
;
++
j
)
{
for
(
int64_t
m
=
0
;
m
!=
second_dim
;
++
m
)
{
EXPECT_EQ
(
tmp
.
data
<
T
>
()[
m
+
j
*
second_dim
],
out_grad
.
data
<
T
>
()[
m
+
i
*
second_dim
]);
}
}
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
cpu_in_grad
.
lod
()[
0
].
size
()
-
1
;
++
i
)
{
int64_t
begin
=
cpu_in_grad
.
lod
()[
0
][
i
];
int64_t
end
=
cpu_in_grad
.
lod
()[
0
][
i
+
1
];
paddle
::
framework
::
Tensor
tmp
=
cpu_in_grad
.
Slice
(
begin
,
end
);
for
(
int64_t
j
=
0
;
j
!=
tmp
.
numel
()
/
second_dim
;
++
j
)
{
for
(
int64_t
m
=
0
;
m
!=
second_dim
;
++
m
)
{
EXPECT_EQ
(
tmp
.
data
<
T
>
()[
m
+
j
*
second_dim
],
cpu_out_grad
.
data
<
T
>
()[
m
+
i
*
second_dim
]);
}
}
}
}
delete
place
;
delete
context
;
}
TEST
(
SequencePoolingGrad
,
CPU_SUM
)
{
paddle
::
framework
::
LoD
lod1
;
lod1
.
push_back
(
std
::
vector
<
size_t
>
{
0
,
10
});
TestSequencePoolingSum
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
,
float
>
(
lod1
);
paddle
::
framework
::
LoD
lod2
;
lod2
.
push_back
(
std
::
vector
<
size_t
>
{
0
,
2
,
7
,
10
});
TestSequencePoolingSum
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
,
float
>
(
lod2
);
}
#ifdef PADDLE_WITH_CUDA
TEST
(
SequencePoolingGrad
,
CUDA_SUM
)
{
paddle
::
framework
::
LoD
lod1
;
lod1
.
push_back
(
std
::
vector
<
size_t
>
{
0
,
10
});
TestSequencePoolingSum
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
CUDAPlace
,
float
>
(
lod1
);
paddle
::
framework
::
LoD
lod2
;
lod2
.
push_back
(
std
::
vector
<
size_t
>
{
0
,
2
,
7
,
10
});
TestSequencePoolingSum
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
CUDAPlace
,
float
>
(
lod2
);
}
#endif
paddle/fluid/operators/sequence_reverse_op.cc
0 → 100644
浏览文件 @
cbe128bb
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
sequence_reverse
,
ops
::
SequenceReverseOp
,
ops
::
SequenceReverseOpMaker
,
ops
::
SequenceReverseGradOpDescMaker
);
REGISTER_OP_CPU_KERNEL
(
sequence_reverse
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
uint8_t
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/sequence_reverse_op.cu
0 → 100644
浏览文件 @
cbe128bb
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sequence_reverse
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequenceReverseOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/sequence_reverse_op.h
0 → 100644
浏览文件 @
cbe128bb
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
class
SequenceReverseOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must exist"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) must exist"
);
auto
x_dim
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_GE
(
x_dim
.
size
(),
2
,
"Rank of Input(X) must be not less than 2."
);
ctx
->
SetOutputDim
(
"Y"
,
x_dim
);
ctx
->
ShareLoD
(
"X"
,
"Y"
);
}
};
class
SequenceReverseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input LoDTensor of sequence_reverse op."
);
AddOutput
(
"Y"
,
"The output LoDTensor of sequence_reverse op."
);
AddComment
(
R"DOC(
SequenceReverse Operator.
Reverse each sequence in input X along dim 0.
Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where:
X.data() = [
[1, 2, 3, 4],
[5, 6, 7, 8], # the 0-th sequence with length 2
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20] # the 1-st sequence with length 3
]
The output Y would be a LoDTensor sharing the same dims and lod with input X,
and:
Y.data() = [
[5, 6, 7, 8],
[1, 2, 3, 4], # the reversed 0-th sequence with length 2
[17, 18, 19, 20],
[13, 14, 15, 16],
[9, 10, 11, 12] # the reversed 1-st sequence with length 3
]
This Operator is useful to build a reverse dynamic RNN network.
This Operator only supports one-level lod currently.
)DOC"
);
}
};
template
<
typename
T
>
struct
SequenceReverseFunctor
{
SequenceReverseFunctor
(
const
T
*
x
,
T
*
y
,
const
size_t
*
lod
,
size_t
lod_count
,
size_t
row_numel
)
:
x_
(
x
),
y_
(
y
),
lod_
(
lod
),
lod_count_
(
lod_count
),
row_numel_
(
row_numel
)
{}
HOSTDEVICE
void
operator
()(
size_t
idx_x
)
const
{
auto
row_idx_x
=
idx_x
/
row_numel_
;
auto
lod_idx
=
math
::
UpperBound
(
lod_
,
lod_count_
,
row_idx_x
);
auto
row_idx_y
=
lod_
[
lod_idx
-
1
]
+
(
lod_
[
lod_idx
]
-
1
-
row_idx_x
);
auto
idx_y
=
row_idx_y
*
row_numel_
+
idx_x
%
row_numel_
;
y_
[
idx_y
]
=
x_
[
idx_x
];
}
const
T
*
x_
;
T
*
y_
;
const
size_t
*
lod_
;
size_t
lod_count_
;
size_t
row_numel_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceReverseOpKernel
:
public
framework
::
OpKernel
<
T
>
{
using
LoDTensor
=
framework
::
LoDTensor
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
x
=
*
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
LoDTensor
>
(
"Y"
);
PADDLE_ENFORCE_EQ
(
x
.
lod
().
size
(),
1
,
"SequenceReverse Op only support one level lod."
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
const
size_t
*
lod
;
size_t
lod_count
=
x
.
lod
()[
0
].
size
();
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
lod
=
x
.
lod
()[
0
].
CUDAData
(
ctx
.
GetPlace
());
}
else
{
#endif
lod
=
x
.
lod
()[
0
].
data
();
#ifdef PADDLE_WITH_CUDA
}
#endif
size_t
limit
=
static_cast
<
size_t
>
(
x
.
numel
());
size_t
row_numel
=
static_cast
<
size_t
>
(
limit
/
x
.
dims
()[
0
]);
auto
*
x_data
=
x
.
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_NE
(
x_data
,
y_data
,
"SequenceReverse Op does not support in-place operation"
);
SequenceReverseFunctor
<
T
>
functor
(
x_data
,
y_data
,
lod
,
lod_count
,
row_numel
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
limit
);
for_range
(
functor
);
}
};
class
SequenceReverseGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
op
->
SetType
(
"sequence_reverse"
);
op
->
SetInput
(
"X"
,
OutputGrad
(
"Y"
));
op
->
SetOutput
(
"Y"
,
InputGrad
(
"X"
));
op
->
SetAttrMap
(
Attrs
());
return
op
;
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/softmax_cudnn_op.cu.cc
浏览文件 @
cbe128bb
...
...
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
softmax
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxCUDNNKernel
<
float
>
,
ops
::
SoftmaxCUDNNKernel
<
double
>
,
ops
::
SoftmaxCUDNNKernel
<
plat
::
float16
>
);
REGISTER_OP_KERNEL
(
softmax_grad
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxGradCUDNNKernel
<
float
>
);
ops
::
SoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SoftmaxGradCUDNNKernel
<
double
>
);
paddle/fluid/operators/transpose_op.cc
浏览文件 @
cbe128bb
...
...
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
ops
::
Transpose2OpMaker
,
ops
::
Transpose2GradMaker
);
REGISTER_OPERATOR
(
transpose2_grad
,
ops
::
Transpose2OpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/transpose_op.cu.cc
浏览文件 @
cbe128bb
...
...
@@ -16,15 +16,18 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
python/paddle/dataset/wmt16.py
浏览文件 @
cbe128bb
...
...
@@ -78,7 +78,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
six
.
iteritems
(
word_dict
),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)):
if
idx
+
3
==
dict_size
:
break
fout
.
write
(
"%s
\n
"
%
(
word
[
0
]
))
fout
.
write
(
"%s
\n
"
%
(
cpt
.
to_bytes
(
word
[
0
])
))
def
__load_dict
(
tar_file
,
dict_size
,
lang
,
reverse
=
False
):
...
...
python/paddle/fluid/clip.py
浏览文件 @
cbe128bb
...
...
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
)
square
=
grad
*
grad
local_norm_var
=
layers
.
cast
(
layers
.
reduce_sum
(
input
=
square
),
'float64'
)
local_norm_var
=
layers
.
reduce_sum
(
input
=
square
)
context
[
self
.
group_name
].
append
(
local_norm_var
)
self
.
context
=
context
...
...
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
group_norm_var
=
layers
.
sqrt
(
x
=
group_norm_var
)
group_norm_var
=
layers
.
cast
(
group_norm_var
,
'float32'
)
clip_var
=
self
.
context
[
self
.
group_name
+
"_clip"
]
group_scale_var
=
layers
.
elementwise_div
(
x
=
clip_var
,
...
...
@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads):
for
p
,
g
in
param_grads
:
if
g
is
None
:
continue
with
p
.
block
.
program
.
_optimized_guard
([
p
,
g
]):
with
p
.
block
.
program
.
_optimized_guard
(
[
p
,
g
]),
framework
.
name_scope
(
'append_clip'
):
clip_attr
=
getattr
(
p
,
'gradient_clip_attr'
,
NullGradientClipAttr
())
if
clip_attr
is
None
:
clip_attr
=
NullGradientClipAttr
()
...
...
@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads):
for
p
,
g
in
param_grads
:
if
g
is
None
:
continue
with
p
.
block
.
program
.
_optimized_guard
([
p
,
g
]):
with
p
.
block
.
program
.
_optimized_guard
(
[
p
,
g
]),
framework
.
name_scope
(
'append_graident_clip'
):
res
.
append
(
clip_attr
.
_create_operators
(
param
=
p
,
grad
=
g
))
return
res
...
...
python/paddle/fluid/framework.py
浏览文件 @
cbe128bb
...
...
@@ -1496,6 +1496,9 @@ class Program(object):
>>> with program._optimized_guard([p,g]):
>>> p = p - 0.001 * g
"""
tmp_role
=
self
.
_current_role
tmp_var
=
self
.
_op_role_var
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
Optimize
self
.
_op_role_var
=
[
...
...
@@ -1503,11 +1506,11 @@ class Program(object):
for
var
in
param_and_grads
]
yield
self
.
_op_role_var
=
[]
self
.
_current_role
=
OpRole
.
Forward
self
.
_op_role_var
=
tmp_var
self
.
_current_role
=
tmp_role
@
contextlib
.
contextmanager
def
_lr_schedule_guard
(
self
):
def
_lr_schedule_guard
(
self
,
is_with_opt
=
False
):
"""
A with guard to set :code:`LRSched` :code:`OpRole` and
:code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
...
...
@@ -1515,6 +1518,10 @@ class Program(object):
Notes: This is a very low level API. Users should not use it directly.
Args:
is_with_opt: Only set to true if these ops a in the middle
of a bunch of optimize ops so that it can be treated
correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
Examples:
...
...
@@ -1528,6 +1535,8 @@ class Program(object):
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
LRSched
if
is_with_opt
:
self
.
_current_role
=
int
(
OpRole
.
LRSched
)
|
int
(
OpRole
.
Optimize
)
# TODO(typhoonzero): how to set target learning rate var
self
.
_op_role_var
=
[]
yield
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
cbe128bb
...
...
@@ -154,6 +154,7 @@ __all__ = [
'mul'
,
'sigmoid_cross_entropy_with_logits'
,
'maxout'
,
'sequence_reverse'
,
'affine_channel'
,
]
...
...
@@ -980,7 +981,12 @@ def cos_sim(X, Y):
return
out
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
):
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
,
dropout_implementation
=
"downgrade_in_infer"
):
"""
Computes dropout.
...
...
@@ -1000,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * dropout_prob
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns:
Variable: A tensor variable is the shape with `x`.
...
...
@@ -1029,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob'
:
dropout_prob
,
'is_test'
:
is_test
,
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
'seed'
:
seed
if
seed
is
not
None
else
0
,
'dropout_implementation'
:
dropout_implementation
,
})
return
out
...
...
@@ -4844,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return
counter
def
reshape
(
x
,
shape
,
actual_shape
=
None
,
act
=
None
,
inplace
=
Tru
e
,
name
=
None
):
def
reshape
(
x
,
shape
,
actual_shape
=
None
,
act
=
None
,
inplace
=
Fals
e
,
name
=
None
):
"""
Gives a new shape to the input Tensor without changing its data.
...
...
@@ -4892,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
:attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority
than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, the output
shares data with input without copying, otherwise
a new output tensor is created
whose data is copied from input x.
act (str): The non-linear activation to be applied to the reshaped tensor
variable.
inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
operators. If this flag is set :attr:`True`, reuse input
:attr:`x` to reshape, which will change the shape of
tensor variable :attr:`x` and might cause errors when
:attr:`x` is used in multiple operators. If :attr:`False`,
preserve the shape :attr:`x` and create a new output tensor
variable whose data is copied from input x but reshaped.
name (str): The name of this layer. It is optional.
Returns:
Variable: The output tensor.
Variable: The reshaped tensor variable if :attr:`act` is None. It is a
\
new tensor variable if :attr:`inplace` is :attr:`False`,
\
otherwise it is :attr:`x`. If :attr:`act` is not None, return
\
the activated tensor variable.
Raises:
TypeError: if actual_shape is neither Variable nor None.
...
...
@@ -4911,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2],
act='tanh',
inplace=True)
x=data, shape=[-1, 0, 3, 2], inplace=True)
"""
if
not
(
isinstance
(
shape
,
list
)
or
isinstance
(
shape
,
tuple
)):
...
...
@@ -4938,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension."
)
helper
=
LayerHelper
(
"reshape2"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
x
if
inplace
else
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
x_shape
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"reshape2"
,
...
...
@@ -7455,6 +7485,33 @@ def maxout(x, groups, name=None):
return
out
@
templatedoc
()
def
sequence_reverse
(
x
,
name
=
None
):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(basestring|None): Name of the output.
Returns:
out(${y_type}): ${y_comment}
"""
helper
=
LayerHelper
(
"sequence_reverse"
,
**
locals
())
if
name
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
"sequence_reverse"
,
inputs
=
{
"X"
:
x
},
outputs
=
{
"Y"
:
out
},
attrs
=
dict
())
return
out
def
affine_channel
(
x
,
scale
=
None
,
bias
=
None
,
data_layout
=
'NCHW'
,
name
=
None
):
"""
Applies a separate affine transformation to each channel of the input.
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
cbe128bb
...
...
@@ -111,7 +111,9 @@ class Optimizer(object):
if
param_lr
==
1.0
:
return
self
.
_global_learning_rate
()
else
:
with
default_main_program
().
_lr_schedule_guard
():
with
default_main_program
().
_lr_schedule_guard
(
is_with_opt
=
True
),
framework
.
name_scope
(
'scale_with_param_lr'
):
return
self
.
_global_learning_rate
()
*
param_lr
def
_create_accumulators
(
self
,
block
,
parameters
):
...
...
@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer):
for
param
,
grad
in
param_and_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
"optimizer"
):
beta1_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta1_pow_acc_str
,
param
)
beta2_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta2_pow_acc_str
,
...
...
@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer):
for
param
,
grad
in
parameters_and_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
'adamx'
):
beta1_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta1_pow_acc_str
,
param
)
main_block
.
append_op
(
...
...
@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer):
for
param
,
grad
in
self
.
params_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
'move_average'
):
self
.
_append_average_accumulate_op
(
param
)
self
.
apply_program
=
Program
()
...
...
python/paddle/fluid/regularizer.py
浏览文件 @
cbe128bb
...
...
@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
if
grad
is
None
:
params_and_grads
.
append
((
param
,
grad
))
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
framework
.
name_scope
(
'regularization'
):
regularization_term
=
None
if
param
.
regularizer
is
not
None
:
# Add variable for regularization term in grad block
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
cbe128bb
...
...
@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties
(
test_dist_word2vec PROPERTIES TIMEOUT 200
)
py_test_modules
(
test_dist_se_resnext MODULES test_dist_se_resnext
)
set_tests_properties
(
test_dist_se_resnext PROPERTIES TIMEOUT 1000
)
#
TODO: fix this test
#py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
#
FIXME(typhoonzero): add this back
#py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif
(
NOT APPLE
)
py_test_modules
(
test_dist_transpiler MODULES test_dist_transpiler
)
endif
()
...
...
python/paddle/fluid/tests/unittests/dist_transformer.py
浏览文件 @
cbe128bb
...
...
@@ -35,7 +35,7 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
from
paddle.fluid
import
core
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
,
RUN_STEP
import
paddle.compat
as
cpt
from
paddle.compat
import
long_type
...
...
@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for
pass_id
in
six
.
moves
.
xrange
(
TrainTaskConfig
.
pass_num
):
pass_start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_data
()):
if
batch_id
>=
5
:
if
batch_id
>=
RUN_STEP
:
break
feed_list
=
[]
total_num_token
=
0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if
TrainTaskConfig
.
local
:
lr_rate
=
lr_scheduler
.
update_learning_rate
()
...
...
@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init
=
True
# Validate and save the model for inference.
if
batch_id
==
0
or
batch_id
==
4
:
if
TrainTaskConfig
.
val_file_pattern
is
not
None
:
val_avg_cost
,
val_ppl
=
test
()
print
(
"[%f]"
%
val_avg_cost
)
else
:
assert
(
False
)
if
TrainTaskConfig
.
val_file_pattern
is
not
None
:
val_avg_cost
,
val_ppl
=
test
()
print
(
"[%f]"
%
val_avg_cost
)
else
:
assert
(
False
)
#import transformer_reader as reader
...
...
@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def
run_trainer
(
self
,
args
):
TrainTaskConfig
.
use_gpu
=
args
.
use_cuda
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
=
get_model
(
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
,
test_program
=
get_model
(
args
.
is_dist
,
not
args
.
sync_mode
)
if
args
.
is_dist
:
...
...
python/paddle/fluid/tests/unittests/test_dist_mnist.py
浏览文件 @
cbe128bb
...
...
@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase):
self
.
_sync_mode
=
False
self
.
_use_reduce
=
False
def
test_dist_train
(
self
):
# FIXME(typhoonzero): fix async mode test later
def
no_test_dist_train
(
self
):
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
200
)
...
...
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
浏览文件 @
cbe128bb
...
...
@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self
.
_sync_mode
=
False
self
.
_use_reader_alloc
=
False
def
test_dist_train
(
self
):
#FIXME(typhoonzero): fix async mode later
def
no_test_dist_train
(
self
):
self
.
check_with_place
(
"dist_se_resnext.py"
,
delta
=
100
)
...
...
python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py
浏览文件 @
cbe128bb
...
...
@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self
.
_sync_mode
=
False
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
#FIXME(typhoonzero): fix async tests later
def
no_test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
...
...
@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self
.
_sync_mode
=
False
self
.
_enforce_place
=
"CPU"
def
test_simnet_bow
(
self
):
#FIXME(typhoonzero): fix async tests later
def
no_test_simnet_bow
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'1'
,
...
...
python/paddle/fluid/tests/unittests/test_dist_transformer.py
浏览文件 @
cbe128bb
...
...
@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def
test_dist_train
(
self
):
download_files
()
self
.
check_with_place
(
"dist_transformer.py"
,
delta
=
1e-5
)
self
.
check_with_place
(
"dist_transformer.py"
,
delta
=
1e-5
,
check_error_log
=
False
)
class
TestDistTransformer2x2Async
(
TestDistBase
):
...
...
@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def
test_dist_train
(
self
):
download_files
()
self
.
check_with_place
(
"dist_transformer.py"
,
delta
=
1.0
)
self
.
check_with_place
(
"dist_transformer.py"
,
delta
=
1.0
,
check_error_log
=
False
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
cbe128bb
...
...
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self
.
check_output
()
class
TestDropoutOp6
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
),
'Mask'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
)
}
class
TestDropoutOp7
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
)).
astype
(
'float32'
)
}
class
TestDropoutOp8
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'fix_seed'
:
True
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestDropoutOp9
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
python/paddle/fluid/tests/unittests/test_fusion_gru_op.py
浏览文件 @
cbe128bb
...
...
@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self
.
D
=
8
class
TestFusionGRUOpMD3
(
TestFusionGRUOp
):
def
set_confs
(
self
):
self
.
M
=
17
self
.
D
=
15
class
TestFusionGRUOpBS1
(
TestFusionGRUOp
):
def
set_confs
(
self
):
self
.
lod
=
[[
3
]]
...
...
python/paddle/fluid/tests/unittests/test_sequence_reverse.py
0 → 100644
浏览文件 @
cbe128bb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
import
numpy
as
np
class
TestSequenceReverseBase
(
OpTest
):
def
initParameters
(
self
):
pass
def
setUp
(
self
):
self
.
size
=
(
10
,
3
,
4
)
self
.
lod
=
[
2
,
3
,
5
]
self
.
dtype
=
'float32'
self
.
initParameters
()
self
.
op_type
=
'sequence_reverse'
self
.
x
=
np
.
random
.
random
(
self
.
size
).
astype
(
self
.
dtype
)
self
.
y
=
self
.
get_output
()
self
.
inputs
=
{
'X'
:
(
self
.
x
,
[
self
.
lod
,
]),
}
self
.
outputs
=
{
'Y'
:
(
self
.
y
,
[
self
.
lod
,
]),
}
def
get_output
(
self
):
tmp_x
=
np
.
reshape
(
self
.
x
,
newshape
=
[
self
.
x
.
shape
[
0
],
-
1
])
tmp_y
=
np
.
ndarray
(
tmp_x
.
shape
).
astype
(
self
.
dtype
)
prev_idx
=
0
for
cur_len
in
self
.
lod
:
idx_range
=
range
(
prev_idx
,
prev_idx
+
cur_len
)
tmp_y
[
idx_range
,
:]
=
np
.
flip
(
tmp_x
[
idx_range
,
:],
0
)
prev_idx
+=
cur_len
return
np
.
reshape
(
tmp_y
,
newshape
=
self
.
x
.
shape
).
astype
(
self
.
dtype
)
def
test_output
(
self
):
self
.
check_output
(
0
)
def
test_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
class
TestSequenceReserve1
(
TestSequenceReverseBase
):
def
initParameters
(
self
):
self
.
size
=
(
12
,
10
)
self
.
lod
=
[
4
,
5
,
3
]
class
TestSequenceReverse2
(
TestSequenceReverseBase
):
def
initParameters
(
self
):
self
.
size
=
(
12
,
10
)
self
.
lod
=
[
12
]
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
cbe128bb
...
...
@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME
=
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
OPT_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
DIST_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Dist
LR_SCHED_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
LRSched
...
...
@@ -1717,8 +1718,10 @@ to transpile() call.")
lr_ops
=
[]
block
=
self
.
origin_program
.
global_block
()
for
op
in
block
.
ops
:
if
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
):
role_id
=
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
if
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
or
\
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
|
\
int
(
OPT_OP_ROLE_ATTR_VALUE
):
lr_ops
.
append
(
op
)
log
(
"append lr op: "
,
op
.
type
)
return
lr_ops
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录