BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 6e361542
Authored on Oct 26, 2018 by JiabinYang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_reorg_op

Parents: 7bcba47e, a3efba17
Showing 29 changed files with 287 additions and 137 deletions (+287 −137)
Changed files:

.gitignore (+1 −0)
Dockerfile (+3 −3)
paddle/fluid/API.spec (+2 −2)
paddle/fluid/framework/details/multi_devices_graph_pass.cc (+3 −3)
paddle/fluid/framework/mixed_vector.h (+27 −0)
paddle/fluid/framework/naive_executor.cc (+0 −17)
paddle/fluid/framework/naive_executor.h (+0 −2)
paddle/fluid/framework/op_proto_maker.cc (+2 −0)
paddle/fluid/framework/op_proto_maker.h (+3 −0)
paddle/fluid/framework/parallel_executor.cc (+0 −6)
paddle/fluid/framework/threadpool.cc (+9 −22)
paddle/fluid/framework/threadpool.h (+0 −24)
paddle/fluid/framework/threadpool_test.cc (+10 −6)
paddle/fluid/operators/detection/generate_proposals_op.cc (+1 −1)
paddle/fluid/operators/dropout_op.cc (+28 −2)
paddle/fluid/operators/dropout_op.cu (+22 −7)
paddle/fluid/operators/dropout_op.h (+16 −3)
paddle/fluid/operators/math/jit_kernel_rnn.cc (+2 −0)
paddle/fluid/operators/softmax_cudnn_op.cu.cc (+3 −1)
paddle/fluid/operators/transpose_op.cc (+8 −5)
paddle/fluid/operators/transpose_op.cu.cc (+8 −5)
python/paddle/fluid/clip.py (+5 −4)
python/paddle/fluid/framework.py (+12 −3)
python/paddle/fluid/layers/nn.py (+40 −11)
python/paddle/fluid/optimizer.py (+9 −4)
python/paddle/fluid/regularizer.py (+2 −1)
python/paddle/fluid/tests/unittests/CMakeLists.txt (+3 −3)
python/paddle/fluid/tests/unittests/test_dropout_op.py (+63 −0)
python/paddle/fluid/transpiler/distribute_transpiler.py (+5 −2)
.gitignore
@@ -28,3 +28,4 @@ third_party/
 build_*
 # clion workspace.
 cmake-build-*
+model_test

Dockerfile
@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
     pip3 install -U docopt PyYAML sphinx==1.5.6 && \
     pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
     easy_install -U pip && \
-    pip install -U wheel && \
+    pip install -U pip setuptools wheel && \
     pip install -U docopt PyYAML sphinx==1.5.6 && \
     pip install sphinx-rtd-theme==0.1.9 recommonmark

-RUN pip3 install pre-commit 'ipython==5.3.0' && \
+RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
     pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
     pip3 install opencv-python && \
-    pip install pre-commit 'ipython==5.3.0' && \
+    pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
     pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
     pip install opencv-python

paddle/fluid/API.spec
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
 paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
+paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
 paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
 paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
-paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
+paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None))
 paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))

paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
   std::vector<ir::Node *> sorted_ret;
   for (size_t i = 0; i < ret.size(); ++i) {
     if (i < last_backward) {
-      if (boost::get<int>(ret[i]->Op()->GetAttr(
-              OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-          static_cast<int>(OpRole::kOptimize)) {
+      if (static_cast<bool>(boost::get<int>(ret[i]->Op()->GetAttr(
+              OpProtoAndCheckerMaker::OpRoleAttrName())) &
+          static_cast<int>(OpRole::kOptimize))) {
         optimize_ops.push_back(ret[i]);
       } else {
         sorted_ret.push_back(ret[i]);

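Note: the check above switches from an equality comparison to a bitwise test because an op's role attribute can now carry a combination of roles (op_proto_maker.cc below whitelists kOptimize | kLRSched). A minimal Python sketch of the difference, with illustrative bit values only; the real values live in the OpRole enum in op_proto_maker.h:

    # illustrative role bits, not the actual OpRole values
    K_OPTIMIZE = 0x0002
    K_LRSCHED = 0x0010

    role = K_OPTIMIZE | K_LRSCHED   # an op tagged with both roles
    assert role != K_OPTIMIZE       # the old equality check would skip this op
    assert role & K_OPTIMIZE        # the new bitwise check still treats it as an optimize op
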
paddle/fluid/framework/mixed_vector.h
@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
     this->reserve(this->size() + size_t(end - begin));
     this->insert(this->end(), begin, end);
   }
+
+  const T *CUDAData(platform::Place place) const {
+    PADDLE_THROW(
+        "Vector::CUDAData() method is not supported in CPU-only version");
+  }
+
+  T *CUDAMutableData(platform::Place place) {
+    PADDLE_THROW(
+        "Vector::CUDAMutableData() method is not supported in CPU-only "
+        "version");
+  }
+
+  const T *Data(platform::Place place) const {
+    PADDLE_ENFORCE(
+        platform::is_cpu_place(place),
+        "Vector::Data() method is not supported when not in CPUPlace");
+    return this->data();
+  }
+
+  T *MutableData(platform::Place place) {
+    PADDLE_ENFORCE(
+        platform::is_cpu_place(place),
+        "Vector::MutableData() method is not supported when not in CPUPlace");
+    return this->data();
+  }
+
+  const void *Handle() const { return static_cast<const void *>(this); }
 };

 template <typename T>

paddle/fluid/framework/naive_executor.cc
@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
   ops_.swap(ops);
 }

-void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
-#ifdef PADDLE_WITH_MKLDNN
-  VLOG(3) << "use_mkldnn=True";
-  for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
-    auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
-    for (auto *op : block->AllOps()) {
-      if (op->HasAttr("use_mkldnn")) {
-        op->SetAttr("use_mkldnn", true);
-      }
-    }
-  }
-#else
-  LOG(WARNING)
-      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
-#endif
-}
-
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/naive_executor.h
@@ -48,8 +48,6 @@ class NaiveExecutor {
   void CleanFeedFetchOps();

-  void EnableMKLDNN(const ProgramDesc& program);
-
  protected:
   void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);

paddle/fluid/framework/op_proto_maker.cc
@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
           static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
           static_cast<int>(OpRole::kLoss) |
               static_cast<int>(OpRole::kBackward),
+          static_cast<int>(OpRole::kOptimize) |
+              static_cast<int>(OpRole::kLRSched),
           static_cast<int>(OpRole::kNotSpecified)})
       .SetDefault(static_cast<int>(OpRole::kNotSpecified));
   AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),

paddle/fluid/framework/op_proto_maker.h
@@ -20,6 +20,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+//////////////////////////
+// Don't add more roles to make this too complicated!
+//////////////////////////
 enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,

paddle/fluid/framework/parallel_executor.cc
@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor(
                              params, member_->local_scopes_, member_->use_cuda_);
 #endif

-  // If the loss_var_name is given, the number of graph should be only one.
-  if (loss_var_name.size()) {
-    PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
-                      "The number of graph should be only one");
-  }
-
   if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
     member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, places, std::move(graph)));

paddle/fluid/framework/threadpool.cc
@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
 namespace paddle {
 namespace framework {

 std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
 std::once_flag ThreadPool::init_flag_;
@@ -47,8 +46,7 @@ void ThreadPool::Init() {
   }
 }

-ThreadPool::ThreadPool(int num_threads)
-    : total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
+ThreadPool::ThreadPool(int num_threads) : running_(true) {
   threads_.resize(num_threads);
   for (auto& thread : threads_) {
     // TODO(Yancey1989): binding the thread on the specify CPU number
@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
 ThreadPool::~ThreadPool() {
   {
     // notify all threads to stop running
+    std::lock_guard<std::mutex> l(mutex_);
     running_ = false;
     scheduled_.notify_all();
   }
@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
   }
 }

-void ThreadPool::Wait() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  completed_.wait(lock, [=] { return Done() == true; });
-}
-
 void ThreadPool::TaskLoop() {
-  while (running_) {
+  while (true) {
     std::unique_lock<std::mutex> lock(mutex_);
-    scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });

-    if (!running_) {
-      break;
+    scheduled_.wait(
+        lock, [this] { return !this->tasks_.empty() || !this->running_; });
+
+    if (!running_ || tasks_.empty()) {
+      return;
     }
+
     // pop a task from the task queue
     auto task = std::move(tasks_.front());
     tasks_.pop();
-    --idle_threads_;
     lock.unlock();

     // run the task
     task();
-
-    {
-      std::unique_lock<std::mutex> lock(mutex_);
-      ++idle_threads_;
-      if (Done()) {
-        completed_.notify_all();
-      }
-    }
   }
 }

paddle/fluid/framework/threadpool.h
@@ -57,15 +57,6 @@ class ThreadPool {
   ~ThreadPool();

-  // Returns the number of threads created by the constructor.
-  size_t Threads() const { return total_threads_; }
-
-  // Returns the number of currently idle threads.
-  size_t IdleThreads() {
-    std::unique_lock<std::mutex> lock(mutex_);
-    return idle_threads_;
-  }
-
   // Run pushes a function to the task queue and returns a std::future
   // object. To wait for the completion of the task, call
   // std::future::wait().
@@ -94,25 +85,13 @@ class ThreadPool {
     });
     std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
     tasks_.push(std::move(task));
     lock.unlock();
     scheduled_.notify_one();
     return f;
   }

-  // Wait until all the tasks are completed.
-  void Wait();
-
  private:
   DISABLE_COPY_AND_ASSIGN(ThreadPool);

-  // If the task queue is empty and avaialbe is equal to the number of
-  // threads, means that all tasks are completed. Note: this function
-  // is not thread-safe. Returns true if all tasks are completed.
-  // Note: don't delete the data member total_threads_ and use
-  // threads_.size() instead; because you'd need to lock the mutex
-  // before accessing threads_.
-  bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
-
   // The constructor starts threads to run TaskLoop, which retrieves
   // and runs tasks from the queue.
   void TaskLoop();
@@ -125,14 +104,11 @@ class ThreadPool {
   static std::once_flag init_flag_;

   std::vector<std::unique_ptr<std::thread>> threads_;
-  const size_t total_threads_;
-  size_t idle_threads_;

   std::queue<Task> tasks_;
   std::mutex mutex_;
   bool running_;
   std::condition_variable scheduled_;
-  std::condition_variable completed_;
 };

 class ThreadPoolIO : ThreadPool {

paddle/fluid/framework/threadpool_test.cc
@@ -19,10 +19,11 @@ limitations under the License. */
 namespace framework = paddle::framework;

-void do_sum(framework::ThreadPool* pool, std::atomic<int>* sum, int cnt) {
-  std::vector<std::future<void>> fs;
+void do_sum(std::vector<std::future<void>>* fs, std::mutex* mu,
+            std::atomic<int>* sum, int cnt) {
   for (int i = 0; i < cnt; ++i) {
-    fs.push_back(framework::Async([sum]() { sum->fetch_add(1); }));
+    std::lock_guard<std::mutex> l(*mu);
+    fs->push_back(framework::Async([sum]() { sum->fetch_add(1); }));
   }
 }
@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
 }

 TEST(ThreadPool, ConcurrentRun) {
-  framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
   std::atomic<int> sum(0);
   std::vector<std::thread> threads;
+  std::vector<std::future<void>> fs;
+  std::mutex fs_mu;
   int n = 50;
   // sum = (n * (n + 1)) / 2
   for (int i = 1; i <= n; ++i) {
-    std::thread t(do_sum, pool, &sum, i);
+    std::thread t(do_sum, &fs, &fs_mu, &sum, i);
     threads.push_back(std::move(t));
   }
   for (auto& t : threads) {
     t.join();
   }
-  pool->Wait();
+  for (auto& t : fs) {
+    t.wait();
+  }
   EXPECT_EQ(sum, ((n + 1) * n) / 2);
 }

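Note: since ThreadPool::Wait() is removed, the test now keeps the futures returned by framework::Async and waits on each of them. A rough Python analogue of the same pattern, using the standard library rather than Paddle's pool (illustrative only):

    import concurrent.futures
    import threading

    total = 0
    lock = threading.Lock()

    def add_one():
        global total
        with lock:
            total += 1

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(add_one) for _ in range(50)]
        for f in futures:   # wait on each future explicitly,
            f.result()      # instead of a pool-wide Wait()

    assert total == 50
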
paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
       selected_indices.push_back(idx);
       ++selected_num;
     }
-    sorted_indices.erase(sorted_indices.end());
+    sorted_indices.erase(sorted_indices.end() - 1);
     if (flag && eta < 1 && adaptive_threshold > 0.5) {
       adaptive_threshold *= eta;
     }

paddle/fluid/operators/dropout_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/dropout_op.h"
+#include <string>

 namespace paddle {
 namespace operators {
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
                   "will be dropped.")
         .SetDefault(false);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddAttr<std::string>(
+        "dropout_implementation",
+        "[\"downgrade_in_infer\"|\"upscale_in_train\"]"
+        "There are two kinds of ways to implement dropout"
+        "(the mask below is a tensor have the same shape with input"
+        "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
+        "1. downgrade_in_infer(default), downgrade the outcome at inference "
+        "time"
+        "   train: out = input * mask"
+        "   inference: out = input * dropout_prob"
+        "2. upscale_in_train, upscale the outcome at training time, do nothing "
+        "in inference"
+        "   train: out = input * mask / ( 1.0 - dropout_prob )"
+        "   inference: out = input"
+        "   dropout op can be removed from the program. the program will be "
+        "efficient")
+        .SetDefault("downgrade_in_infer")
+        .AddCustomChecker([](const std::string& type) {
+          PADDLE_ENFORCE(
+              type == "downgrade_in_infer" || type == "upscale_in_train",
+              "dropout_implementation can only be downgrade_in_infer or "
+              "upscale_in_train");
+        });

     AddComment(R"DOC(
 Dropout Operator.
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);

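Note: the new dropout_implementation attribute selects one of the two scalings described in the attribute string above. A small NumPy sketch of the intended semantics, following the kernel code (which multiplies by 1 - dropout_prob at inference for downgrade_in_infer); the mask draw is simplified and this is not the operator itself:

    import numpy as np

    def dropout_ref(x, p, is_test, implementation):
        if not is_test:
            mask = (np.random.uniform(size=x.shape) >= p).astype(x.dtype)
            if implementation == "upscale_in_train":
                return x * mask / (1.0 - p)   # scale up while training
            return x * mask                   # downgrade_in_infer
        if implementation == "upscale_in_train":
            return x                          # inference is a no-op
        return x * (1.0 - p)                  # downgrade_in_infer scales here

    x = np.ones((2, 3), dtype=np.float32)
    print(dropout_ref(x, 0.5, True, "downgrade_in_infer"))  # all 0.5
    print(dropout_ref(x, 0.5, True, "upscale_in_train"))    # all 1.0
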
paddle/fluid/operators/dropout_op.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include <string>
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/float16.h"
@@ -26,7 +27,8 @@ namespace operators {
 template <typename T>
 __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
-                                T* mask_data, T* dst) {
+                                T* mask_data, T* dst,
+                                bool is_upscale_in_train) {
   thrust::minstd_rand rng;
   rng.seed(seed);
   thrust::uniform_real_distribution<float> dist(0, 1);
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
     if (dist(rng) < dropout_prob) {
       mask = static_cast<T>(0);
     } else {
-      mask = static_cast<T>(1);
+      if (is_upscale_in_train) {
+        mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
+      } else {
+        mask = static_cast<T>(1);
+      }
     }
     dest = s * mask;
     mask_data[idx] = mask;
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     auto& place = *context.template device_context<Place>().eigen_device();
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       int grid = (x->numel() + threads - 1) / threads;
       RandomGenerator<
           T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
-          size, seed, dropout_prob, x_data, mask_data, y_data);
+          size, seed, dropout_prob, x_data, mask_data, y_data,
+          (dropout_implementation == "upscale_in_train"));
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
-    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(dropout_grad,
-                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
+    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/dropout_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <random>
+#include <string>

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");

+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       engine.seed(seed);

       std::uniform_real_distribution<float> dist(0, 1);
+
       size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
           mask_data[i] = 0;
           y_data[i] = 0;
         } else {
-          mask_data[i] = 1;
-          y_data[i] = x_data[i];
+          if (dropout_implementation == "upscale_in_train") {
+            mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
+            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
+          } else {
+            mask_data[i] = 1;
+            y_data[i] = x_data[i];
+          }
         }
       }
     } else {
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      Y.device(place) = X * (1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };

paddle/fluid/operators/math/jit_kernel_rnn.cc
@@ -136,6 +136,7 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
   return nullptr;
 }

+#ifdef __AVX__
 template <jit::cpu_isa_t isa>
 static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
   if (type == "sigmoid") {
@@ -150,6 +151,7 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
   PADDLE_THROW("Not support type: %s", type);
   return nullptr;
 }
+#endif

 /* LSTM JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>

paddle/fluid/operators/softmax_cudnn_op.cu.cc
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxCUDNNKernel<float>,
+                   ops::SoftmaxCUDNNKernel<double>,
                    ops::SoftmaxCUDNNKernel<plat::float16>);
 REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxGradCUDNNKernel<float>);
+                   ops::SoftmaxGradCUDNNKernel<float>,
+                   ops::SoftmaxGradCUDNNKernel<double>);

paddle/fluid/operators/transpose_op.cc
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
 REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);

 REGISTER_OP_CPU_KERNEL(
-    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);

 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker);
 REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);

 REGISTER_OP_CPU_KERNEL(
-    transpose2,
-    ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/transpose_op.cu.cc
@@ -16,15 +16,18 @@ limitations under the License. */
 namespace ops = paddle::operators;

 REGISTER_OP_CUDA_KERNEL(
-    transpose,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);

 REGISTER_OP_CUDA_KERNEL(
     transpose2,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);

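Note: the dropout, softmax (cuDNN), and transpose registration changes above only add float64 (double) kernel instantiations, so these operators can now run on float64 tensors. A hedged usage sketch against this 2018 fluid API, assuming a build that includes these kernels:

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4, 6], dtype='float64')
    y = fluid.layers.transpose(x, perm=[0, 2, 1])   # float64 kernel now registered
    y = fluid.layers.dropout(y, dropout_prob=0.3)   # float64 CPU/GPU dropout kernels
    y = fluid.layers.softmax(y, use_cudnn=True)     # float64 cuDNN softmax kernel
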
python/paddle/fluid/clip.py
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         )

         square = grad * grad
-        local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
+        local_norm_var = layers.reduce_sum(input=square)
         context[self.group_name].append(local_norm_var)

         self.context = context
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         if group_scale_name not in self.context:
             group_norm_var = layers.sums(input=self.context[self.group_name])
             group_norm_var = layers.sqrt(x=group_norm_var)
-            group_norm_var = layers.cast(group_norm_var, 'float32')
             clip_var = self.context[self.group_name + "_clip"]
             group_scale_var = layers.elementwise_div(
                 x=clip_var,
@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads):
     for p, g in param_grads:
         if g is None:
             continue
-        with p.block.program._optimized_guard([p, g]):
+        with p.block.program._optimized_guard(
+            [p, g]), framework.name_scope('append_clip'):
             clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
             if clip_attr is None:
                 clip_attr = NullGradientClipAttr()
@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads):
     for p, g in param_grads:
         if g is None:
             continue
-        with p.block.program._optimized_guard([p, g]):
+        with p.block.program._optimized_guard(
+            [p, g]), framework.name_scope('append_graident_clip'):
             res.append(clip_attr._create_operators(param=p, grad=g))

     return res

python/paddle/fluid/framework.py
@@ -1496,6 +1496,9 @@ class Program(object):
         >>> with program._optimized_guard([p,g]):
         >>>     p = p - 0.001 * g
         """
+        tmp_role = self._current_role
+        tmp_var = self._op_role_var
+
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.Optimize
         self._op_role_var = [
@@ -1503,11 +1506,11 @@ class Program(object):
             for var in param_and_grads
         ]
         yield
-        self._op_role_var = []
-        self._current_role = OpRole.Forward
+        self._op_role_var = tmp_var
+        self._current_role = tmp_role

     @contextlib.contextmanager
-    def _lr_schedule_guard(self):
+    def _lr_schedule_guard(self, is_with_opt=False):
         """
         A with guard to set :code:`LRSched` :code:`OpRole` and
         :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
@@ -1515,6 +1518,10 @@ class Program(object):
         Notes: This is a very low level API. Users should not use it directly.

+        Args:
+            is_with_opt: Only set to true if these ops a in the middle
+                 of a bunch of optimize ops so that it can be treated
+                 correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
+
         Examples:
@@ -1528,6 +1535,8 @@ class Program(object):
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.LRSched
+        if is_with_opt:
+            self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
         # TODO(typhoonzero): how to set target learning rate var
         self._op_role_var = []
         yield

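Note: _optimized_guard now saves and restores whatever role was active when the guard was entered, instead of unconditionally resetting it to Forward, so guards can nest (for example an optimize guard opened while an LRSched guard is active). A standalone sketch of the save/restore pattern, not Paddle code:

    import contextlib

    @contextlib.contextmanager
    def role_guard(state, new_role):
        saved = state['role']       # tmp_role in the diff above
        state['role'] = new_role
        yield
        state['role'] = saved       # restore, rather than forcing 'Forward'

    state = {'role': 'LRSched'}
    with role_guard(state, 'Optimize'):
        assert state['role'] == 'Optimize'
    assert state['role'] == 'LRSched'   # the outer role survives the nested guard
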
python/paddle/fluid/layers/nn.py
@@ -981,7 +981,12 @@ def cos_sim(X, Y):
     return out


-def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
+def dropout(x,
+            dropout_prob,
+            is_test=False,
+            seed=None,
+            name=None,
+            dropout_implementation="downgrade_in_infer"):
     """
     Computes dropout.
@@ -1001,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
                     units will be dropped. DO NOT use a fixed seed in training.
         name (str|None): A name for this layer(optional). If set None, the layer
                          will be named automatically.
+        dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
+                                        1. downgrade_in_infer(default), downgrade the outcome at inference
+                                           train: out = input * mask
+                                           inference: out = input * dropout_prob
+                                           (make is a tensor same shape with input, value is 0 or 1
+                                           ratio of 0 is dropout_prob)
+                                        2. upscale_in_train, upscale the outcome at training time
+                                           train: out = input * mask / ( 1.0 - dropout_prob )
+                                           inference: out = input
+                                           (make is a tensor same shape with input, value is 0 or 1
+                                           ratio of 0 is dropout_prob)
+                                        dropout op can be removed from the program.
+                                        the program will be efficient

     Returns:
         Variable: A tensor variable is the shape with `x`.
@@ -1030,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
             'dropout_prob': dropout_prob,
             'is_test': is_test,
             'fix_seed': seed is not None,
-            'seed': seed if seed is not None else 0
+            'seed': seed if seed is not None else 0,
+            'dropout_implementation': dropout_implementation,
         })
     return out
@@ -4845,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
     return counter


-def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
+def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
     """
     Gives a new shape to the input Tensor without changing its data.
@@ -4893,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
                                 :attr:`shape` specifying shape. That is to
                                 say :attr:`actual_shape` has a higher priority
                                 than :attr:`shape`.
-        act (str): The non-linear activation to be applied to output variable.
-        inplace(bool): If this flag is set true, the output
-                       shares data with input without copying, otherwise
-                       a new output tensor is created
-                       whose data is copied from input x.
+        act (str): The non-linear activation to be applied to the reshaped tensor
+                   variable.
+        inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
+                       operators. If this flag is set :attr:`True`, reuse input
+                       :attr:`x` to reshape, which will change the shape of
+                       tensor variable :attr:`x` and might cause errors when
+                       :attr:`x` is used in multiple operators. If :attr:`False`,
+                       preserve the shape :attr:`x` and create a new output tensor
+                       variable whose data is copied from input x but reshaped.
         name (str): The name of this layer. It is optional.

     Returns:
-        Variable: The output tensor.
+        Variable: The reshaped tensor variable if :attr:`act` is None. It is a \
+                  new tensor variable if :attr:`inplace` is :attr:`False`, \
+                  otherwise it is :attr:`x`. If :attr:`act` is not None, return \
+                  the activated tensor variable.

     Raises:
         TypeError: if actual_shape is neither Variable nor None.
@@ -4912,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
             data = fluid.layers.data(
                 name='data', shape=[2, 4, 6], dtype='float32')
             reshaped = fluid.layers.reshape(
-                x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True)
+                x=data, shape=[-1, 0, 3, 2], inplace=True)
     """

     if not (isinstance(shape, list) or isinstance(shape, tuple)):
@@ -4939,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
             "except one unknown dimension.")

     helper = LayerHelper("reshape2", **locals())
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    out = x if inplace else helper.create_variable_for_type_inference(
+        dtype=x.dtype)
     x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
         type="reshape2",

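Note: at the layer level this file adds a dropout_implementation argument to fluid.layers.dropout and changes the reshape default to inplace=False. A hedged usage sketch against this version of the fluid API:

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[32, 64], dtype='float32')

    # upscale_in_train: the op becomes an identity at inference time
    h = fluid.layers.dropout(x, dropout_prob=0.5,
                             dropout_implementation='upscale_in_train')

    # inplace now defaults to False; only pass inplace=True when the input
    # is not consumed by any other operator
    h = fluid.layers.reshape(h, shape=[-1, 32, 8, 8], inplace=False)
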
python/paddle/fluid/optimizer.py
@@ -111,7 +111,9 @@ class Optimizer(object):
         if param_lr == 1.0:
             return self._global_learning_rate()
         else:
-            with default_main_program()._lr_schedule_guard():
+            with default_main_program()._lr_schedule_guard(
+                    is_with_opt=True), framework.name_scope(
+                        'scale_with_param_lr'):
                 return self._global_learning_rate() * param_lr

     def _create_accumulators(self, block, parameters):
@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer):
         for param, grad in param_and_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                [param, grad]), name_scope("optimizer"):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer):
         for param, grad in parameters_and_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                [param, grad]), name_scope('adamx'):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 main_block.append_op(
@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer):
         for param, grad in self.params_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                [param, grad]), name_scope('move_average'):
                 self._append_average_accumulate_op(param)

         self.apply_program = Program()

python/paddle/fluid/regularizer.py
@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         if grad is None:
             params_and_grads.append((param, grad))
             continue
-        with param.block.program._optimized_guard([param, grad]):
+        with param.block.program._optimized_guard(
+            [param, grad]), framework.name_scope('regularization'):
             regularization_term = None
             if param.regularizer is not None:
                 # Add variable for regularization term in grad block

python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
         set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
         py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
         set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
-        py_test_modules(test_dist_transformer MODULES test_dist_transformer)
-        set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
+        # FIXME(typhoonzero): add this back
+        #py_test_modules(test_dist_transformer MODULES test_dist_transformer)
+        #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
     endif(NOT APPLE)
     py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
 endif()

python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
         self.check_output()


+class TestDropoutOp6(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 1.0,
+            'fix_seed': True,
+            'is_test': False,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {
+            'Out': np.zeros((32, 64)).astype('float32'),
+            'Mask': np.zeros((32, 64)).astype('float32')
+        }
+
+
+class TestDropoutOp7(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.0,
+            'fix_seed': True,
+            'is_test': False,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'Mask': np.ones((32, 64, 2)).astype('float32')
+        }
+
+
+class TestDropoutOp8(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.35,
+            'fix_seed': True,
+            'is_test': True,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {'Out': self.inputs['X']}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDropoutOp9(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.75,
+            'is_test': True,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {'Out': self.inputs['X']}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFP16DropoutOp(OpTest):
     def setUp(self):
         self.op_type = "dropout"

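Note: the new tests pin down upscale_in_train at its corner points (dropout_prob of 1.0 and 0.0, and is_test=True). A quick NumPy check of the expectation encoded in TestDropoutOp7, illustrative only:

    import numpy as np

    p = 0.0
    x = np.random.random((32, 64, 2)).astype('float32')
    mask = np.ones_like(x)        # nothing is dropped when p == 0.0
    out = x * mask / (1.0 - p)    # upscale_in_train training rule
    assert np.allclose(out, x)    # matches the expected 'Out' of TestDropoutOp7
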
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
+OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
 DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
 LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
@@ -1717,8 +1718,10 @@ to transpile() call.")
         lr_ops = []
         block = self.origin_program.global_block()
         for op in block.ops:
-            if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
-                    LR_SCHED_OP_ROLE_ATTR_VALUE):
+            role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
+            if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
+                role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
+                    int(OPT_OP_ROLE_ATTR_VALUE):
                 lr_ops.append(op)
                 log("append lr op: ", op.type)
         return lr_ops

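Note: with combined roles, an op emitted inside _lr_schedule_guard(is_with_opt=True) carries LRSched together with Optimize, so the transpiler must match both forms when collecting learning-rate ops. A short sketch of the widened match, with illustrative role values only (the real ones come from core.op_proto_and_checker_maker.OpRole):

    LR_SCHED, OPTIMIZE = 0x0010, 0x0002   # assumed values, for illustration

    def is_lr_op(role_id):
        return role_id == LR_SCHED or role_id == (LR_SCHED | OPTIMIZE)

    assert is_lr_op(LR_SCHED)
    assert is_lr_op(LR_SCHED | OPTIMIZE)  # only matched with the new check
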