Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d080d3e6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d080d3e6
编写于
8月 10, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into timeline-support-pure-cpu
上级
e008600b
cf799a6a
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
369 addition
and
47 deletion
+369
-47
paddle/fluid/inference/analysis/data_flow_graph.cc
paddle/fluid/inference/analysis/data_flow_graph.cc
+28
-0
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+1
-0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
...fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
+1
-0
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
+3
-3
paddle/fluid/inference/api/api.cc
paddle/fluid/inference/api/api.cc
+22
-4
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+2
-1
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+6
-6
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+209
-9
python/paddle/dataset/conll05.py
python/paddle/dataset/conll05.py
+4
-4
python/paddle/dataset/wmt14.py
python/paddle/dataset/wmt14.py
+1
-1
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+78
-6
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+9
-8
python/paddle/v2/dataset/conll05.py
python/paddle/v2/dataset/conll05.py
+4
-4
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+1
-1
未找到文件。
paddle/fluid/inference/analysis/data_flow_graph.cc
浏览文件 @
d080d3e6
...
...
@@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std
::
vector
<
Node
*>
(
outputs
.
begin
(),
outputs
.
end
()));
}
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
)
{
std
::
vector
<
Node
*>
op_nodes
;
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
graph
).
nodes_in_TS
())
{
if
(
node
.
type
()
==
Node
::
Type
::
kValue
||
node
.
deleted
())
{
continue
;
}
op_nodes
.
push_back
(
&
node
);
}
size_t
op_num
=
op_nodes
.
size
();
for
(
size_t
i
=
0
;
i
<
op_num
;
i
++
)
{
if
(
op_nodes
[
i
]
->
type
()
==
Node
::
Type
::
kFunction
)
continue
;
std
::
unordered_set
<
std
::
string
>
follow_up_input_names
;
for
(
size_t
j
=
i
+
1
;
j
<
op_num
;
j
++
)
{
for
(
auto
*
in
:
op_nodes
[
j
]
->
inlinks
)
{
follow_up_input_names
.
insert
(
in
->
name
());
}
}
std
::
vector
<
Node
*>
filtered_subgraph_outlinks
;
for
(
auto
*
out
:
op_nodes
[
i
]
->
outlinks
)
{
if
(
follow_up_input_names
.
count
(
out
->
name
()))
{
filtered_subgraph_outlinks
.
push_back
(
out
);
}
}
PADDLE_ENFORCE_GE
(
filtered_subgraph_outlinks
.
size
(),
1UL
);
op_nodes
[
i
]
->
outlinks
=
filtered_subgraph_outlinks
;
}
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
d080d3e6
...
...
@@ -178,6 +178,7 @@ struct GraphTraits<DataFlowGraph> {
std
::
pair
<
std
::
vector
<
Node
*>
,
std
::
vector
<
Node
*>>
ExtractInputAndOutputOfSubGraph
(
std
::
vector
<
Node
*>
&
graph
);
// NOLINT
void
FilterRedundantOutputOfSubGraph
(
DataFlowGraph
*
graph
);
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
浏览文件 @
d080d3e6
...
...
@@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool
DataFlowGraphToFluidPass
::
Finalize
()
{
return
true
;
}
void
DataFlowGraphToFluidPass
::
Run
(
DataFlowGraph
*
graph
)
{
FilterRedundantOutputOfSubGraph
(
graph
);
LOG
(
INFO
)
<<
"graph.inputs "
<<
graph
->
inputs
.
size
();
for
(
auto
&
node
:
GraphTraits
<
DataFlowGraph
>
(
graph
).
nodes_in_TS
())
{
if
(
node
.
deleted
())
continue
;
...
...
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
浏览文件 @
d080d3e6
...
...
@@ -46,9 +46,9 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
for
(
size_t
i
=
0
;
i
<
graph
->
nodes
.
size
();
i
++
)
{
const
Node
&
node
=
graph
->
nodes
.
Get
(
i
);
if
(
!
config_
.
display_deleted_node
&&
node
.
deleted
())
continue
;
for
(
auto
&
in
:
node
.
in
links
)
{
if
(
!
config_
.
display_deleted_node
&&
in
->
deleted
())
continue
;
dot
.
AddEdge
(
in
->
repr
(),
node
.
repr
(),
{});
for
(
auto
&
out
:
node
.
out
links
)
{
if
(
!
config_
.
display_deleted_node
&&
out
->
deleted
())
continue
;
dot
.
AddEdge
(
node
.
repr
(),
out
->
repr
(),
{});
}
}
return
dot
.
Build
();
...
...
paddle/fluid/inference/api/api.cc
浏览文件 @
d080d3e6
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace
paddle
{
...
...
@@ -40,19 +41,36 @@ PaddleBuf::PaddleBuf(PaddleBuf&& other)
PaddleBuf
::
PaddleBuf
(
const
PaddleBuf
&
other
)
{
*
this
=
other
;
}
PaddleBuf
&
PaddleBuf
::
operator
=
(
const
PaddleBuf
&
other
)
{
if
(
!
other
.
memory_owned_
)
{
data_
=
other
.
data_
;
length_
=
other
.
length_
;
memory_owned_
=
other
.
memory_owned_
;
}
else
{
Resize
(
other
.
length
());
memcpy
(
data_
,
other
.
data
(),
other
.
length
());
length_
=
other
.
length
();
memory_owned_
=
true
;
}
return
*
this
;
}
PaddleBuf
&
PaddleBuf
::
operator
=
(
PaddleBuf
&&
other
)
{
// only the buffer with external memory can be copied
assert
(
!
other
.
memory_owned_
);
data_
=
other
.
data_
;
length_
=
other
.
length_
;
memory_owned_
=
other
.
memory_owned_
;
other
.
data_
=
nullptr
;
other
.
length_
=
0
;
other
.
memory_owned_
=
false
;
return
*
this
;
}
void
PaddleBuf
::
Resize
(
size_t
length
)
{
// Only the owned memory can be reset, the external memory can't be changed.
if
(
length_
==
length
)
return
;
assert
(
memory_owned_
);
Free
();
if
(
memory_owned_
)
{
Free
();
}
data_
=
new
char
[
length
];
length_
=
length
;
memory_owned_
=
true
;
...
...
@@ -68,7 +86,7 @@ void PaddleBuf::Reset(void* data, size_t length) {
void
PaddleBuf
::
Free
()
{
if
(
memory_owned_
&&
data_
)
{
assert
(
length_
>
0
);
delete
static_cast
<
char
*>
(
data_
);
delete
[]
static_cast
<
char
*>
(
data_
);
data_
=
nullptr
;
length_
=
0
;
}
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
d080d3e6
...
...
@@ -40,11 +40,12 @@ class PaddleBuf {
// Copy only available when memory is managed externally.
explicit
PaddleBuf
(
const
PaddleBuf
&
);
PaddleBuf
&
operator
=
(
const
PaddleBuf
&
);
PaddleBuf
&
operator
=
(
PaddleBuf
&&
);
// Do not own the memory.
PaddleBuf
(
void
*
data
,
size_t
length
)
:
data_
(
data
),
length_
(
length
),
memory_owned_
{
false
}
{}
// Own memory.
explicit
PaddleBuf
(
size_t
length
)
PaddleBuf
(
size_t
length
)
:
data_
(
new
char
[
length
]),
length_
(
length
),
memory_owned_
(
true
)
{}
// Resize to `length` bytes.
void
Resize
(
size_t
length
);
...
...
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
d080d3e6
...
...
@@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
&
dout
,
int
axis
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
,
DX_OP
dx_op
,
DY_OP
dy_op
)
{
const
framework
::
DDim
x_dim
=
x
.
dims
();
const
framework
::
DDim
y_dim
=
y
.
dims
();
const
framework
::
DDim
&
x_dim
=
x
.
dims
();
const
framework
::
DDim
&
y_dim
=
y
.
dims
();
if
(
x
.
dims
()
==
y
.
dims
())
{
ElemwiseGradComputeNoBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ctx
,
x_dim
,
y_dim
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
...
...
@@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
,
DX_OP
dx_op
,
DY_OP
dy_op
)
{
if
(
dy
==
nullptr
)
{
const
framework
::
DDim
dx_dims
=
dout
.
dims
();
const
framework
::
DDim
&
dx_dims
=
dout
.
dims
();
auto
dy_dims
=
dx_dims
;
ElemwiseGradComputeNoBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ctx
,
dx_dims
,
dy_dims
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
}
else
{
if
(
dout
.
dims
()
==
dy
->
dims
())
{
const
framework
::
DDim
dx_dims
=
dout
.
dims
();
const
framework
::
DDim
dy_dims
=
dy
->
dims
();
const
framework
::
DDim
&
dx_dims
=
dout
.
dims
();
const
framework
::
DDim
&
dy_dims
=
dy
->
dims
();
ElemwiseGradComputeNoBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ctx
,
dx_dims
,
dy_dims
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
}
else
{
// Y is a scalar
auto
dx_dims
=
dout
.
dims
();
const
framework
::
DDim
dy_dims
=
dy
->
dims
();
const
framework
::
DDim
&
dy_dims
=
dy
->
dims
();
ElemwiseGradComputeWithBroadcast
<
DeviceContext
,
T
,
DX_OP
,
DY_OP
>
(
ctx
,
dx_dims
,
dy_dims
,
x
,
y
,
out
,
dout
,
axis
,
dx
,
dy
,
dx_op
,
dy_op
);
}
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
d080d3e6
/* Copyright (c) 201
6
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
namespace
paddle
{
...
...
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
logit_grad
[
ids
]
=
loss_grad
[
row_ids
]
*
(
logit_grad
[
ids
]
-
labels
[
ids
]);
}
}
}
// namespace
static
__device__
__forceinline__
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
static
__device__
__forceinline__
double
real_exp
(
double
x
)
{
return
exp
(
x
);
}
static
__device__
__forceinline__
float
real_log
(
float
x
)
{
return
math
::
TolerableValue
<
float
>
()(
logf
(
x
));
}
static
__device__
__forceinline__
double
real_log
(
double
x
)
{
return
math
::
TolerableValue
<
double
>
()(
log
(
x
));
}
/** In the following codes, 3 CUDA kernels are implemented to calculate softmax
* and loss **/
/*
Supposing the x is `logits` and y is `labels`, the equations are as
followings:
cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
= \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
= \sum_{j}(-y_i_j * tmp_i_j)
softmax_i_j = e^{tmp_i_j}
where:
max_i = \max_{j}{x_i_j}
logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: caculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate immediate result of softmax'_i_j =
x_i_j - max_i, and row-wise operation to calculate logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j = softmax'_i_j
- logDiffMaxSum_i, and finally get softmax_i_j and cross\_entropy_i
*/
// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template
<
typename
T
,
int
BlockDim
>
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
/*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/
>
;
template
<
typename
T
,
int
BlockDim
>
using
BlockReduceTempStorage
=
typename
BlockReduce
<
T
,
BlockDim
>::
TempStorage
;
// Make sure that BlockDim <= feature_size
// This kernel is used to calculate the max element of each row
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
T
cur_max
=
logits_data
[
beg_idx
];
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
if
(
cur_max
<
logits_data
[
beg_idx
])
{
cur_max
=
logits_data
[
beg_idx
];
}
beg_idx
+=
BlockDim
;
}
cur_max
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
cur_max
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
{
max_data
[
blockIdx
.
x
]
=
cur_max
<
-
64
?
-
64
:
cur_max
;
}
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
auto
block_max
=
max_data
[
blockIdx
.
x
];
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
}
diff_max_sum
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
diff_max_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
real_log
(
diff_max_sum
);
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
auto
end_idx
=
feature_size
*
(
blockIdx
.
x
+
1
);
// log_diff_max_sum shares memory with loss
auto
block_log_diff_max_sum
=
loss_data
[
blockIdx
.
x
];
auto
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
softmax
[
beg_idx
]
=
real_exp
(
tmp
);
auto
loss
=
-
labels_data
[
beg_idx
]
*
tmp
;
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
tmp
=
softmax
[
beg_idx
]
-
block_log_diff_max_sum
;
softmax
[
beg_idx
]
=
real_exp
(
tmp
);
loss
-=
(
labels_data
[
beg_idx
]
*
tmp
);
beg_idx
+=
BlockDim
;
}
loss
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
loss
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
loss_data
[
blockIdx
.
x
]
=
loss
;
}
template
<
typename
T
>
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
auto
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
batch_size
)
out
[
idx
]
=
static_cast
<
T
>
(
1
);
}
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
softmax_data
,
T
*
loss_data
,
int
batch_size
,
int
feature_size
,
cudaStream_t
stream
)
{
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
feature_size
>=
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
feature_size
)));
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, \
BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
RowReductionForSoftmaxAndCrossEntropy< \
T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, softmax_data, feature_size); \
break
switch
(
block_dim
)
{
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
256
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
128
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
64
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
32
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
16
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
batch_size
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
,
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
}
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template
<
typename
T
>
class
SoftmaxWithCrossEntropyCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
Tensor
*
softmax
=
context
.
Output
<
Tensor
>
(
"Softmax"
);
Tensor
*
loss
=
context
.
Output
<
Tensor
>
(
"Loss"
);
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SoftmaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
context
.
Attr
<
bool
>
(
"soft_label"
));
auto
*
softmax_data
=
softmax
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
loss_data
=
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
soft_label
=
context
.
Attr
<
bool
>
(
"soft_label"
);
if
(
soft_label
)
{
int
batch_size
=
logits
->
dims
()[
0
];
int
feature_size
=
logits
->
dims
()[
1
];
auto
*
logits_data
=
logits
->
data
<
T
>
();
auto
*
labels_data
=
labels
->
data
<
T
>
();
SoftmaxWithCrossEntropyFusedKernel
(
logits_data
,
labels_data
,
softmax_data
,
loss_data
,
batch_size
,
feature_size
,
context
.
cuda_device_context
().
stream
());
}
else
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
);
}
}
};
...
...
python/paddle/dataset/conll05.py
浏览文件 @
d080d3e6
...
...
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
DATA_URL
=
'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5
=
'387719152ae52d60422c016e92a742fc'
WORDDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
wordDict.txt'
WORDDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
wordDict.txt'
WORDDICT_MD5
=
'ea7fb7d4c75cc6254716f0177a506baa'
VERBDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
verbDict.txt'
VERBDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
verbDict.txt'
VERBDICT_MD5
=
'0d2977293bbb6cbefab5b0f97db1e77c'
TRGDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
targetDict.txt'
TRGDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
targetDict.txt'
TRGDICT_MD5
=
'd8c7f03ceb5fc2e5a0fa7503a4353751'
EMB_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
emb'
EMB_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
emb'
EMB_MD5
=
'bf436eb0faa1f6f9103017f8be57cdb7'
UNK_IDX
=
0
...
...
python/paddle/dataset/wmt14.py
浏览文件 @
d080d3e6
...
...
@@ -40,7 +40,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz'
)
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL
=
'http://paddle
paddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.
gz'
URL_MODEL
=
'http://paddle
models.bj.bcebos.com/wmt%2Fwmt14.t
gz'
MD5_MODEL
=
'0cb4a5366189b6acba876491c8724fa3'
START
=
"<s>"
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
d080d3e6
...
...
@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
self
.
origin_prog
=
main
.
clone
()
return
main
def
get_trainer
(
self
,
config
=
None
):
t
=
self
.
_transpiler_instance
(
config
)
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
return
t
.
get_trainer_program
()
def
get_pserver
(
self
,
ep
,
config
=
None
):
t
=
self
.
_transpiler_instance
(
config
)
def
get_pserver
(
self
,
ep
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
pserver
=
t
.
get_pserver_program
(
ep
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
return
pserver
,
startup
def
_transpiler_instance
(
self
,
config
=
None
):
def
_transpiler_instance
(
self
,
config
=
None
,
sync_mode
=
True
):
if
not
self
.
transpiler
:
main
=
self
.
get_main_program
()
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
...
...
@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
self
.
trainer_id
,
program
=
main
,
pservers
=
self
.
pserver_eps
,
trainers
=
self
.
trainers
)
trainers
=
self
.
trainers
,
sync_mode
=
sync_mode
)
return
self
.
transpiler
...
...
@@ -464,5 +465,76 @@ class TestDistLookupTable(TestDistLookupTableBase):
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
blocks
[
0
].
ops
],
ops
)
class
TestAsyncLocalLookupTable
(
TestDistLookupTableBase
):
def
net_conf
(
self
):
self
.
network_with_table
(
is_sparse
=
True
,
is_distributed
=
False
)
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
3
)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
1
].
ops
],
[
"adam"
,
"scale"
,
"scale"
])
# 2 optimize for table adam
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
2
].
ops
],
[
"adam"
,
"scale"
,
"scale"
])
trainer
=
self
.
get_trainer
(
config
)
self
.
assertEqual
(
len
(
trainer
.
blocks
),
1
)
ops
=
[
'lookup_table'
,
'sequence_pool'
,
'lookup_table'
,
'sequence_pool'
,
'concat'
,
'mul'
,
'elementwise_add'
,
'cross_entropy'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'cross_entropy_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'send'
,
'concat_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sum'
,
'split_selected_rows'
,
'send'
,
'recv'
,
'recv'
,
'recv'
,
'concat'
]
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
blocks
[
0
].
ops
],
ops
)
class
TestAsyncDistLookupTable
(
TestDistLookupTableBase
):
def
net_conf
(
self
):
self
.
network_with_table
(
is_sparse
=
True
,
is_distributed
=
True
)
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
6
)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
1
].
ops
],
[
"adam"
,
"scale"
,
"scale"
])
# 2 optimize for table sgd
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
2
].
ops
],
[
"sgd"
])
# 3 prefetch -> lookup_sparse_table for data0
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
3
].
ops
],
[
"lookup_sparse_table"
])
# 4 prefetch -> lookup_sparse_table for data1
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
4
].
ops
],
[
"lookup_sparse_table"
])
# 5 save table
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
5
].
ops
],
[
"save"
])
trainer
=
self
.
get_trainer
(
config
)
self
.
assertEqual
(
len
(
trainer
.
blocks
),
1
)
ops
=
[
'split_ids'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'split_ids'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'concat'
,
'mul'
,
'elementwise_add'
,
'cross_entropy'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'cross_entropy_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'send'
,
'concat_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sum'
,
'split_ids'
,
'send'
,
'recv'
,
'recv'
]
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
blocks
[
0
].
ops
],
ops
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
d080d3e6
...
...
@@ -293,14 +293,15 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
if
self
.
sync_mode
:
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
inputs
=
{},
outputs
=
{},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
for
varname
,
splited_var
in
self
.
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
...
...
python/paddle/v2/dataset/conll05.py
浏览文件 @
d080d3e6
...
...
@@ -29,13 +29,13 @@ __all__ = ['test, get_dict', 'get_embedding', 'convert']
DATA_URL
=
'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5
=
'387719152ae52d60422c016e92a742fc'
WORDDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
wordDict.txt'
WORDDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
wordDict.txt'
WORDDICT_MD5
=
'ea7fb7d4c75cc6254716f0177a506baa'
VERBDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
verbDict.txt'
VERBDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
verbDict.txt'
VERBDICT_MD5
=
'0d2977293bbb6cbefab5b0f97db1e77c'
TRGDICT_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
targetDict.txt'
TRGDICT_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
targetDict.txt'
TRGDICT_MD5
=
'd8c7f03ceb5fc2e5a0fa7503a4353751'
EMB_URL
=
'http://paddle
paddle.bj.bcebos.com/demo/srl_dict_and_embedding/
emb'
EMB_URL
=
'http://paddle
models.bj.bcebos.com/conll05st%2F
emb'
EMB_MD5
=
'bf436eb0faa1f6f9103017f8be57cdb7'
UNK_IDX
=
0
...
...
python/paddle/v2/dataset/wmt14.py
浏览文件 @
d080d3e6
...
...
@@ -41,7 +41,7 @@ URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz'
)
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL
=
'http://paddle
paddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.
gz'
URL_MODEL
=
'http://paddle
models.bj.bcebos.com/wmt%2Fwmt14.t
gz'
MD5_MODEL
=
'0cb4a5366189b6acba876491c8724fa3'
START
=
"<s>"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录