Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6e361542
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e361542
编写于
10月 26, 2018
作者:
J
JiabinYang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into add_reorg_op
上级
7bcba47e
a3efba17
变更
29
隐藏空白更改
内联
并排
Showing
29 changed file
with
287 addition
and
137 deletion
+287
-137
.gitignore
.gitignore
+1
-0
Dockerfile
Dockerfile
+3
-3
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-2
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+3
-3
paddle/fluid/framework/mixed_vector.h
paddle/fluid/framework/mixed_vector.h
+27
-0
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+0
-17
paddle/fluid/framework/naive_executor.h
paddle/fluid/framework/naive_executor.h
+0
-2
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+2
-0
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+3
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+0
-6
paddle/fluid/framework/threadpool.cc
paddle/fluid/framework/threadpool.cc
+9
-22
paddle/fluid/framework/threadpool.h
paddle/fluid/framework/threadpool.h
+0
-24
paddle/fluid/framework/threadpool_test.cc
paddle/fluid/framework/threadpool_test.cc
+10
-6
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+1
-1
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+28
-2
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+22
-7
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+16
-3
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+2
-0
paddle/fluid/operators/softmax_cudnn_op.cu.cc
paddle/fluid/operators/softmax_cudnn_op.cu.cc
+3
-1
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+8
-5
paddle/fluid/operators/transpose_op.cu.cc
paddle/fluid/operators/transpose_op.cu.cc
+8
-5
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+5
-4
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+12
-3
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+40
-11
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+9
-4
python/paddle/fluid/regularizer.py
python/paddle/fluid/regularizer.py
+2
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-3
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+63
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+5
-2
未找到文件。
.gitignore
浏览文件 @
6e361542
...
...
@@ -28,3 +28,4 @@ third_party/
build_*
# clion workspace.
cmake-build-*
model_test
Dockerfile
浏览文件 @
6e361542
...
...
@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
pip3
install
-U
docopt PyYAML
sphinx
==
1.5.6
&&
\
pip3
install
sphinx-rtd-theme
==
0.1.9 recommonmark
&&
\
easy_install
-U
pip
&&
\
pip
install
-U
wheel
&&
\
pip
install
-U
pip setuptools
wheel
&&
\
pip
install
-U
docopt PyYAML
sphinx
==
1.5.6
&&
\
pip
install
sphinx-rtd-theme
==
0.1.9 recommonmark
RUN
pip3
install
pre-commit
'ipython==5.3.0'
&&
\
RUN
pip3
install
'pre-commit==1.10.4'
'ipython==5.3.0'
&&
\
pip3
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
&&
\
pip3
install
opencv-python
&&
\
pip
install
pre-commit
'ipython==5.3.0'
&&
\
pip
install
'pre-commit==1.10.4'
'ipython==5.3.0'
&&
\
pip
install
'ipykernel==4.6.0'
'jupyter==1.0.0'
&&
\
pip
install
opencv-python
...
...
paddle/fluid/API.spec
浏览文件 @
6e361542
...
...
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
], varargs=None, keywords=None, defaults=(False, None, None
))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
, 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'
))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...
...
@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None,
Tru
e, None))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None,
Fals
e, None))
paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
6e361542
...
...
@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
i
<
last_backward
)
{
if
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
if
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
static_cast
<
int
>
(
OpRole
::
kOptimize
)
))
{
optimize_ops
.
push_back
(
ret
[
i
]);
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
...
...
paddle/fluid/framework/mixed_vector.h
浏览文件 @
6e361542
...
...
@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
this
->
reserve
(
this
->
size
()
+
size_t
(
end
-
begin
));
this
->
insert
(
this
->
end
(),
begin
,
end
);
}
const
T
*
CUDAData
(
platform
::
Place
place
)
const
{
PADDLE_THROW
(
"Vector::CUDAData() method is not supported in CPU-only version"
);
}
T
*
CUDAMutableData
(
platform
::
Place
place
)
{
PADDLE_THROW
(
"Vector::CUDAMutableData() method is not supported in CPU-only "
"version"
);
}
const
T
*
Data
(
platform
::
Place
place
)
const
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
),
"Vector::Data() method is not supported when not in CPUPlace"
);
return
this
->
data
();
}
T
*
MutableData
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
),
"Vector::MutableData() method is not supported when not in CPUPlace"
);
return
this
->
data
();
}
const
void
*
Handle
()
const
{
return
static_cast
<
const
void
*>
(
this
);
}
};
template
<
typename
T
>
...
...
paddle/fluid/framework/naive_executor.cc
浏览文件 @
6e361542
...
...
@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_
.
swap
(
ops
);
}
void
NaiveExecutor
::
EnableMKLDNN
(
const
ProgramDesc
&
program
)
{
#ifdef PADDLE_WITH_MKLDNN
VLOG
(
3
)
<<
"use_mkldnn=True"
;
for
(
size_t
block_id
=
0
;
block_id
<
program
.
Size
();
++
block_id
)
{
auto
*
block
=
const_cast
<
ProgramDesc
&>
(
program
).
MutableBlock
(
block_id
);
for
(
auto
*
op
:
block
->
AllOps
())
{
if
(
op
->
HasAttr
(
"use_mkldnn"
))
{
op
->
SetAttr
(
"use_mkldnn"
,
true
);
}
}
}
#else
LOG
(
WARNING
)
<<
"'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"
;
#endif
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/naive_executor.h
浏览文件 @
6e361542
...
...
@@ -48,8 +48,6 @@ class NaiveExecutor {
void
CleanFeedFetchOps
();
void
EnableMKLDNN
(
const
ProgramDesc
&
program
);
protected:
void
CreateVariables
(
const
ProgramDesc
&
desc
,
Scope
*
scope
,
int
block_id
);
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
6e361542
...
...
@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kForward
),
static_cast
<
int
>
(
OpRole
::
kLoss
)
|
static_cast
<
int
>
(
OpRole
::
kBackward
),
static_cast
<
int
>
(
OpRole
::
kOptimize
)
|
static_cast
<
int
>
(
OpRole
::
kLRSched
),
static_cast
<
int
>
(
OpRole
::
kNotSpecified
)})
.
SetDefault
(
static_cast
<
int
>
(
OpRole
::
kNotSpecified
));
AddAttr
<
std
::
vector
<
std
::
string
>>
(
OpRoleVarAttrName
(),
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
6e361542
...
...
@@ -20,6 +20,9 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
//////////////////////////
// Don't add more roles to make this too complicated!
//////////////////////////
enum
class
OpRole
{
kForward
=
0x0000
,
kBackward
=
0x0001
,
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
6e361542
...
...
@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
"The number of graph should be only one"
);
}
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
paddle/fluid/framework/threadpool.cc
浏览文件 @
6e361542
...
...
@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
namespace
paddle
{
namespace
framework
{
std
::
unique_ptr
<
ThreadPool
>
ThreadPool
::
threadpool_
(
nullptr
);
std
::
once_flag
ThreadPool
::
init_flag_
;
...
...
@@ -47,8 +46,7 @@ void ThreadPool::Init() {
}
}
ThreadPool
::
ThreadPool
(
int
num_threads
)
:
total_threads_
(
num_threads
),
idle_threads_
(
num_threads
),
running_
(
true
)
{
ThreadPool
::
ThreadPool
(
int
num_threads
)
:
running_
(
true
)
{
threads_
.
resize
(
num_threads
);
for
(
auto
&
thread
:
threads_
)
{
// TODO(Yancey1989): binding the thread on the specify CPU number
...
...
@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool
::~
ThreadPool
()
{
{
// notify all threads to stop running
std
::
lock_guard
<
std
::
mutex
>
l
(
mutex_
);
running_
=
false
;
scheduled_
.
notify_all
();
}
...
...
@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
}
}
void
ThreadPool
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
completed_
.
wait
(
lock
,
[
=
]
{
return
Done
()
==
true
;
});
}
void
ThreadPool
::
TaskLoop
()
{
while
(
running_
)
{
while
(
true
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
scheduled_
.
wait
(
lock
,
[
=
]
{
return
!
tasks_
.
empty
()
||
!
running_
;
});
if
(
!
running_
)
{
break
;
scheduled_
.
wait
(
lock
,
[
this
]
{
return
!
this
->
tasks_
.
empty
()
||
!
this
->
running_
;
});
if
(
!
running_
||
tasks_
.
empty
())
{
return
;
}
// pop a task from the task queue
auto
task
=
std
::
move
(
tasks_
.
front
());
tasks_
.
pop
();
--
idle_threads_
;
lock
.
unlock
();
// run the task
task
();
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
++
idle_threads_
;
if
(
Done
())
{
completed_
.
notify_all
();
}
}
}
}
...
...
paddle/fluid/framework/threadpool.h
浏览文件 @
6e361542
...
...
@@ -57,15 +57,6 @@ class ThreadPool {
~
ThreadPool
();
// Returns the number of threads created by the constructor.
size_t
Threads
()
const
{
return
total_threads_
;
}
// Returns the number of currently idle threads.
size_t
IdleThreads
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
idle_threads_
;
}
// Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call
// std::future::wait().
...
...
@@ -94,25 +85,13 @@ class ThreadPool {
});
std
::
future
<
std
::
unique_ptr
<
platform
::
EnforceNotMet
>>
f
=
task
.
get_future
();
tasks_
.
push
(
std
::
move
(
task
));
lock
.
unlock
();
scheduled_
.
notify_one
();
return
f
;
}
// Wait until all the tasks are completed.
void
Wait
();
private:
DISABLE_COPY_AND_ASSIGN
(
ThreadPool
);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool
Done
()
{
return
tasks_
.
empty
()
&&
idle_threads_
==
total_threads_
;
}
// The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue.
void
TaskLoop
();
...
...
@@ -125,14 +104,11 @@ class ThreadPool {
static
std
::
once_flag
init_flag_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
threads_
;
const
size_t
total_threads_
;
size_t
idle_threads_
;
std
::
queue
<
Task
>
tasks_
;
std
::
mutex
mutex_
;
bool
running_
;
std
::
condition_variable
scheduled_
;
std
::
condition_variable
completed_
;
};
class
ThreadPoolIO
:
ThreadPool
{
...
...
paddle/fluid/framework/threadpool_test.cc
浏览文件 @
6e361542
...
...
@@ -19,10 +19,11 @@ limitations under the License. */
namespace
framework
=
paddle
::
framework
;
void
do_sum
(
framework
::
ThreadPool
*
pool
,
std
::
atomic
<
int
>*
sum
,
int
cnt
)
{
std
::
vector
<
std
::
future
<
void
>>
fs
;
void
do_sum
(
std
::
vector
<
std
::
future
<
void
>>*
fs
,
std
::
mutex
*
mu
,
std
::
atomic
<
int
>*
sum
,
int
cnt
)
{
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
fs
.
push_back
(
framework
::
Async
([
sum
]()
{
sum
->
fetch_add
(
1
);
}));
std
::
lock_guard
<
std
::
mutex
>
l
(
*
mu
);
fs
->
push_back
(
framework
::
Async
([
sum
]()
{
sum
->
fetch_add
(
1
);
}));
}
}
...
...
@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
}
TEST
(
ThreadPool
,
ConcurrentRun
)
{
framework
::
ThreadPool
*
pool
=
framework
::
ThreadPool
::
GetInstance
();
std
::
atomic
<
int
>
sum
(
0
);
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
std
::
mutex
fs_mu
;
int
n
=
50
;
// sum = (n * (n + 1)) / 2
for
(
int
i
=
1
;
i
<=
n
;
++
i
)
{
std
::
thread
t
(
do_sum
,
pool
,
&
sum
,
i
);
std
::
thread
t
(
do_sum
,
&
fs
,
&
fs_mu
,
&
sum
,
i
);
threads
.
push_back
(
std
::
move
(
t
));
}
for
(
auto
&
t
:
threads
)
{
t
.
join
();
}
pool
->
Wait
();
for
(
auto
&
t
:
fs
)
{
t
.
wait
();
}
EXPECT_EQ
(
sum
,
((
n
+
1
)
*
n
)
/
2
);
}
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
6e361542
...
...
@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
selected_indices
.
push_back
(
idx
);
++
selected_num
;
}
sorted_indices
.
erase
(
sorted_indices
.
end
());
sorted_indices
.
erase
(
sorted_indices
.
end
()
-
1
);
if
(
flag
&&
eta
<
1
&&
adaptive_threshold
>
0.5
)
{
adaptive_threshold
*=
eta
;
}
...
...
paddle/fluid/operators/dropout_op.cc
浏览文件 @
6e361542
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
...
...
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"dropout_implementation"
,
"[
\"
downgrade_in_infer
\"
|
\"
upscale_in_train
\"
]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient"
)
.
SetDefault
(
"downgrade_in_infer"
)
.
AddCustomChecker
([](
const
std
::
string
&
type
)
{
PADDLE_ENFORCE
(
type
==
"downgrade_in_infer"
||
type
==
"upscale_in_train"
,
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"
);
});
AddComment
(
R"DOC(
Dropout Operator.
...
...
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.cu
浏览文件 @
6e361542
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -26,7 +27,8 @@ namespace operators {
template
<
typename
T
>
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
float
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
T
*
mask_data
,
T
*
dst
,
bool
is_upscale_in_train
)
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
...
...
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if
(
dist
(
rng
)
<
dropout_prob
)
{
mask
=
static_cast
<
T
>
(
0
);
}
else
{
mask
=
static_cast
<
T
>
(
1
);
if
(
is_upscale_in_train
)
{
mask
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
}
else
{
mask
=
static_cast
<
T
>
(
1
);
}
}
dest
=
s
*
mask
;
mask_data
[
idx
]
=
mask
;
...
...
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
...
...
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
(
dropout_implementation
==
"upscale_in_train"
));
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.h
浏览文件 @
6e361542
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
if
(
dropout_implementation
==
"upscale_in_train"
)
{
mask_data
[
i
]
=
1.0
f
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
...
...
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
6e361542
...
...
@@ -136,6 +136,7 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return
nullptr
;
}
#ifdef __AVX__
template
<
jit
::
cpu_isa_t
isa
>
static
std
::
unique_ptr
<
AVXAct
>
GetAVXAct
(
const
std
::
string
&
type
)
{
if
(
type
==
"sigmoid"
)
{
...
...
@@ -150,6 +151,7 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
#endif
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
paddle/fluid/operators/softmax_cudnn_op.cu.cc
浏览文件 @
6e361542
...
...
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
softmax
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxCUDNNKernel
<
float
>
,
ops
::
SoftmaxCUDNNKernel
<
double
>
,
ops
::
SoftmaxCUDNNKernel
<
plat
::
float16
>
);
REGISTER_OP_KERNEL
(
softmax_grad
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxGradCUDNNKernel
<
float
>
);
ops
::
SoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SoftmaxGradCUDNNKernel
<
double
>
);
paddle/fluid/operators/transpose_op.cc
浏览文件 @
6e361542
...
...
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
ops
::
Transpose2OpMaker
,
ops
::
Transpose2GradMaker
);
REGISTER_OPERATOR
(
transpose2_grad
,
ops
::
Transpose2OpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/transpose_op.cu.cc
浏览文件 @
6e361542
...
...
@@ -16,15 +16,18 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
python/paddle/fluid/clip.py
浏览文件 @
6e361542
...
...
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
)
square
=
grad
*
grad
local_norm_var
=
layers
.
cast
(
layers
.
reduce_sum
(
input
=
square
),
'float64'
)
local_norm_var
=
layers
.
reduce_sum
(
input
=
square
)
context
[
self
.
group_name
].
append
(
local_norm_var
)
self
.
context
=
context
...
...
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
group_norm_var
=
layers
.
sqrt
(
x
=
group_norm_var
)
group_norm_var
=
layers
.
cast
(
group_norm_var
,
'float32'
)
clip_var
=
self
.
context
[
self
.
group_name
+
"_clip"
]
group_scale_var
=
layers
.
elementwise_div
(
x
=
clip_var
,
...
...
@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads):
for
p
,
g
in
param_grads
:
if
g
is
None
:
continue
with
p
.
block
.
program
.
_optimized_guard
([
p
,
g
]):
with
p
.
block
.
program
.
_optimized_guard
(
[
p
,
g
]),
framework
.
name_scope
(
'append_clip'
):
clip_attr
=
getattr
(
p
,
'gradient_clip_attr'
,
NullGradientClipAttr
())
if
clip_attr
is
None
:
clip_attr
=
NullGradientClipAttr
()
...
...
@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads):
for
p
,
g
in
param_grads
:
if
g
is
None
:
continue
with
p
.
block
.
program
.
_optimized_guard
([
p
,
g
]):
with
p
.
block
.
program
.
_optimized_guard
(
[
p
,
g
]),
framework
.
name_scope
(
'append_graident_clip'
):
res
.
append
(
clip_attr
.
_create_operators
(
param
=
p
,
grad
=
g
))
return
res
...
...
python/paddle/fluid/framework.py
浏览文件 @
6e361542
...
...
@@ -1496,6 +1496,9 @@ class Program(object):
>>> with program._optimized_guard([p,g]):
>>> p = p - 0.001 * g
"""
tmp_role
=
self
.
_current_role
tmp_var
=
self
.
_op_role_var
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
Optimize
self
.
_op_role_var
=
[
...
...
@@ -1503,11 +1506,11 @@ class Program(object):
for
var
in
param_and_grads
]
yield
self
.
_op_role_var
=
[]
self
.
_current_role
=
OpRole
.
Forward
self
.
_op_role_var
=
tmp_var
self
.
_current_role
=
tmp_role
@
contextlib
.
contextmanager
def
_lr_schedule_guard
(
self
):
def
_lr_schedule_guard
(
self
,
is_with_opt
=
False
):
"""
A with guard to set :code:`LRSched` :code:`OpRole` and
:code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
...
...
@@ -1515,6 +1518,10 @@ class Program(object):
Notes: This is a very low level API. Users should not use it directly.
Args:
is_with_opt: Only set to true if these ops a in the middle
of a bunch of optimize ops so that it can be treated
correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
Examples:
...
...
@@ -1528,6 +1535,8 @@ class Program(object):
OpRole
=
core
.
op_proto_and_checker_maker
.
OpRole
self
.
_current_role
=
OpRole
.
LRSched
if
is_with_opt
:
self
.
_current_role
=
int
(
OpRole
.
LRSched
)
|
int
(
OpRole
.
Optimize
)
# TODO(typhoonzero): how to set target learning rate var
self
.
_op_role_var
=
[]
yield
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
6e361542
...
...
@@ -981,7 +981,12 @@ def cos_sim(X, Y):
return
out
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
):
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
,
dropout_implementation
=
"downgrade_in_infer"
):
"""
Computes dropout.
...
...
@@ -1001,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * dropout_prob
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns:
Variable: A tensor variable is the shape with `x`.
...
...
@@ -1030,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob'
:
dropout_prob
,
'is_test'
:
is_test
,
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
'seed'
:
seed
if
seed
is
not
None
else
0
,
'dropout_implementation'
:
dropout_implementation
,
})
return
out
...
...
@@ -4845,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return
counter
def
reshape
(
x
,
shape
,
actual_shape
=
None
,
act
=
None
,
inplace
=
Tru
e
,
name
=
None
):
def
reshape
(
x
,
shape
,
actual_shape
=
None
,
act
=
None
,
inplace
=
Fals
e
,
name
=
None
):
"""
Gives a new shape to the input Tensor without changing its data.
...
...
@@ -4893,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
:attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority
than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, the output
shares data with input without copying, otherwise
a new output tensor is created
whose data is copied from input x.
act (str): The non-linear activation to be applied to the reshaped tensor
variable.
inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
operators. If this flag is set :attr:`True`, reuse input
:attr:`x` to reshape, which will change the shape of
tensor variable :attr:`x` and might cause errors when
:attr:`x` is used in multiple operators. If :attr:`False`,
preserve the shape :attr:`x` and create a new output tensor
variable whose data is copied from input x but reshaped.
name (str): The name of this layer. It is optional.
Returns:
Variable: The output tensor.
Variable: The reshaped tensor variable if :attr:`act` is None. It is a
\
new tensor variable if :attr:`inplace` is :attr:`False`,
\
otherwise it is :attr:`x`. If :attr:`act` is not None, return
\
the activated tensor variable.
Raises:
TypeError: if actual_shape is neither Variable nor None.
...
...
@@ -4912,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2],
act='tanh',
inplace=True)
x=data, shape=[-1, 0, 3, 2], inplace=True)
"""
if
not
(
isinstance
(
shape
,
list
)
or
isinstance
(
shape
,
tuple
)):
...
...
@@ -4939,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension."
)
helper
=
LayerHelper
(
"reshape2"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
x
if
inplace
else
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
x_shape
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"reshape2"
,
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
6e361542
...
...
@@ -111,7 +111,9 @@ class Optimizer(object):
if
param_lr
==
1.0
:
return
self
.
_global_learning_rate
()
else
:
with
default_main_program
().
_lr_schedule_guard
():
with
default_main_program
().
_lr_schedule_guard
(
is_with_opt
=
True
),
framework
.
name_scope
(
'scale_with_param_lr'
):
return
self
.
_global_learning_rate
()
*
param_lr
def
_create_accumulators
(
self
,
block
,
parameters
):
...
...
@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer):
for
param
,
grad
in
param_and_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
"optimizer"
):
beta1_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta1_pow_acc_str
,
param
)
beta2_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta2_pow_acc_str
,
...
...
@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer):
for
param
,
grad
in
parameters_and_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
'adamx'
):
beta1_pow_acc
=
self
.
_get_accumulator
(
self
.
_beta1_pow_acc_str
,
param
)
main_block
.
append_op
(
...
...
@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer):
for
param
,
grad
in
self
.
params_grads
:
if
grad
is
None
:
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
'move_average'
):
self
.
_append_average_accumulate_op
(
param
)
self
.
apply_program
=
Program
()
...
...
python/paddle/fluid/regularizer.py
浏览文件 @
6e361542
...
...
@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
if
grad
is
None
:
params_and_grads
.
append
((
param
,
grad
))
continue
with
param
.
block
.
program
.
_optimized_guard
([
param
,
grad
]):
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
framework
.
name_scope
(
'regularization'
):
regularization_term
=
None
if
param
.
regularizer
is
not
None
:
# Add variable for regularization term in grad block
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
6e361542
...
...
@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties
(
test_dist_word2vec PROPERTIES TIMEOUT 200
)
py_test_modules
(
test_dist_se_resnext MODULES test_dist_se_resnext
)
set_tests_properties
(
test_dist_se_resnext PROPERTIES TIMEOUT 1000
)
py_test_modules
(
test_dist_transformer MODULES test_dist_transformer
)
set_tests_properties
(
test_dist_transformer PROPERTIES TIMEOUT 1000
)
# FIXME(typhoonzero): add this back
#
py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#
set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif
(
NOT APPLE
)
py_test_modules
(
test_dist_transpiler MODULES test_dist_transpiler
)
endif
()
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
6e361542
...
...
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self
.
check_output
()
class
TestDropoutOp6
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
),
'Mask'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
)
}
class
TestDropoutOp7
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
)).
astype
(
'float32'
)
}
class
TestDropoutOp8
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'fix_seed'
:
True
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestDropoutOp9
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
6e361542
...
...
@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME
=
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
RPC_OP_ROLE_ATTR_NAME
=
op_role_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
(
)
OPT_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
RPC_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
RPC
DIST_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Dist
LR_SCHED_OP_ROLE_ATTR_VALUE
=
core
.
op_proto_and_checker_maker
.
OpRole
.
LRSched
...
...
@@ -1717,8 +1718,10 @@ to transpile() call.")
lr_ops
=
[]
block
=
self
.
origin_program
.
global_block
()
for
op
in
block
.
ops
:
if
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
):
role_id
=
int
(
op
.
attr
(
RPC_OP_ROLE_ATTR_NAME
))
if
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
or
\
role_id
==
int
(
LR_SCHED_OP_ROLE_ATTR_VALUE
)
|
\
int
(
OPT_OP_ROLE_ATTR_VALUE
):
lr_ops
.
append
(
op
)
log
(
"append lr op: "
,
op
.
type
)
return
lr_ops
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录