Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a6edeb39
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a6edeb39
编写于
5月 03, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/clean_blas
上级
caa4027d
73650a83
变更
83
隐藏空白更改
内联
并排
Showing
83 changed file
with
1900 addition
and
672 deletion
+1900
-672
Dockerfile
Dockerfile
+1
-1
benchmark/cluster/vgg16/run_vgg_dist.sh
benchmark/cluster/vgg16/run_vgg_dist.sh
+21
-0
benchmark/cluster/vgg16/vgg16_fluid.py
benchmark/cluster/vgg16/vgg16_fluid.py
+8
-9
paddle/capi/Matrix.cpp
paddle/capi/Matrix.cpp
+1
-1
paddle/cuda/include/hl_base.h
paddle/cuda/include/hl_base.h
+18
-6
paddle/cuda/src/hl_cuda_lstm.cu
paddle/cuda/src/hl_cuda_lstm.cu
+9
-5
paddle/cuda/src/hl_top_k.cu
paddle/cuda/src/hl_top_k.cu
+4
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+5
-5
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+2
-2
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+2
-2
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+37
-4
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+2
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+3
-3
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-1
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+5
-5
paddle/fluid/framework/selected_rows.h
paddle/fluid/framework/selected_rows.h
+5
-3
paddle/fluid/framework/selected_rows_test.cc
paddle/fluid/framework/selected_rows_test.cc
+4
-4
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+1
-1
paddle/fluid/inference/tensorrt/helper.h
paddle/fluid/inference/tensorrt/helper.h
+5
-5
paddle/fluid/inference/tensorrt/test_tensorrt.cc
paddle/fluid/inference/tensorrt/test_tensorrt.cc
+13
-13
paddle/fluid/operators/batch_norm_mkldnn_op.cc
paddle/fluid/operators/batch_norm_mkldnn_op.cc
+325
-0
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+42
-8
paddle/fluid/operators/batch_norm_op.cu.cc
paddle/fluid/operators/batch_norm_op.cu.cc
+3
-1
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+6
-4
paddle/fluid/operators/cross_entropy_op.cu
paddle/fluid/operators/cross_entropy_op.cu
+6
-93
paddle/fluid/operators/cross_entropy_op.h
paddle/fluid/operators/cross_entropy_op.h
+77
-40
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+22
-6
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+6
-1
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+1
-1
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+3
-39
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+69
-38
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+13
-5
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
+165
-0
paddle/fluid/operators/math/cross_entropy.cu
paddle/fluid/operators/math/cross_entropy.cu
+10
-55
paddle/fluid/operators/math/pooling.cc
paddle/fluid/operators/math/pooling.cc
+55
-52
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+50
-34
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+49
-34
paddle/fluid/operators/math/sequence_padding.cc
paddle/fluid/operators/math/sequence_padding.cc
+8
-8
paddle/fluid/operators/math/sequence_padding.cu
paddle/fluid/operators/math/sequence_padding.cu
+13
-12
paddle/fluid/operators/math/sequence_padding.h
paddle/fluid/operators/math/sequence_padding.h
+3
-2
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+2
-2
paddle/fluid/operators/momentum_op.cc
paddle/fluid/operators/momentum_op.cc
+8
-0
paddle/fluid/operators/mul_op.cc
paddle/fluid/operators/mul_op.cc
+4
-2
paddle/fluid/operators/mul_op.cu.cc
paddle/fluid/operators/mul_op.cu.cc
+3
-1
paddle/fluid/operators/row_conv_op.cu
paddle/fluid/operators/row_conv_op.cu
+1
-1
paddle/fluid/operators/save_load_op_test.cc
paddle/fluid/operators/save_load_op_test.cc
+33
-0
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+20
-1
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+4
-6
paddle/fluid/operators/send_recv_op_test.cc
paddle/fluid/operators/send_recv_op_test.cc
+30
-13
paddle/fluid/operators/sgd_op.cc
paddle/fluid/operators/sgd_op.cc
+20
-1
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+4
-2
paddle/fluid/operators/softmax_op.cu.cc
paddle/fluid/operators/softmax_op.cu.cc
+4
-2
paddle/fluid/operators/top_k_op.cc
paddle/fluid/operators/top_k_op.cc
+2
-1
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+2
-1
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+22
-2
paddle/fluid/operators/warpctc_op.h
paddle/fluid/operators/warpctc_op.h
+2
-2
paddle/fluid/platform/cuda_device_function.h
paddle/fluid/platform/cuda_device_function.h
+74
-0
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+0
-17
paddle/fluid/platform/profiler.h
paddle/fluid/platform/profiler.h
+0
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+5
-5
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+3
-3
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+10
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-2
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+41
-17
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+6
-4
python/paddle/fluid/layers/math_op_patch.py
python/paddle/fluid/layers/math_op_patch.py
+2
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+7
-3
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+2
-1
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+7
-2
python/paddle/fluid/tests/book/test_fit_a_line.py
python/paddle/fluid/tests/book/test_fit_a_line.py
+1
-6
python/paddle/fluid/tests/book/test_image_classification.py
python/paddle/fluid/tests/book/test_image_classification.py
+1
-6
python/paddle/fluid/tests/book/test_label_semantic_roles.py
python/paddle/fluid/tests/book/test_label_semantic_roles.py
+1
-6
python/paddle/fluid/tests/book/test_machine_translation.py
python/paddle/fluid/tests/book/test_machine_translation.py
+1
-6
python/paddle/fluid/tests/book/test_recognize_digits.py
python/paddle/fluid/tests/book/test_recognize_digits.py
+1
-6
python/paddle/fluid/tests/book/test_recommender_system.py
python/paddle/fluid/tests/book/test_recommender_system.py
+1
-6
python/paddle/fluid/tests/book/test_understand_sentiment.py
python/paddle/fluid/tests/book/test_understand_sentiment.py
+1
-6
python/paddle/fluid/tests/book/test_word2vec.py
python/paddle/fluid/tests/book/test_word2vec.py
+1
-6
python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
...dle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
+147
-0
python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py
...paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py
+56
-0
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+38
-11
python/paddle/fluid/tests/unittests/test_dist_train.py
python/paddle/fluid/tests/unittests/test_dist_train.py
+1
-1
python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
...ddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
+86
-0
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+172
-16
未找到文件。
Dockerfile
浏览文件 @
a6edeb39
...
...
@@ -32,7 +32,7 @@ RUN apt-get update && \
automake locales clang-format swig doxygen cmake
\
liblapack-dev liblapacke-dev
\
clang-3.8 llvm-3.8 libclang-3.8-dev
\
net-tools libtool
&&
\
net-tools libtool
ccache
&&
\
apt-get clean
-y
# Install Go and glide
...
...
benchmark/cluster/vgg16/run_vgg_dist.sh
0 → 100644
浏览文件 @
a6edeb39
#!/bin/bash
# Update to point to the source file.
VGG_SRC
=
"vgg16_fluid.py"
export
TRAINING_ROLE
=
PSERVER
export
TRAINERS
=
2
export
POD_IP
=
127.0.0.1
export
PADDLE_INIT_PORT
=
6174
MKL_NUM_THREADS
=
1 python
-u
${
VGG_SRC
}
--local
0
--ps_host
=
127.0.0.1:6174
--trainer_hosts
=
127.0.0.1:6174 &
# Need to wait for the ps to start first.
sleep
10
echo
"done start ps"
export
TRAINING_ROLE
=
TRAINER
export
TRAINERS
=
2
export
POD_IP
=
127.0.0.1
export
PADDLE_INIT_PORT
=
6174
CUDA_VISIBLE_DEVICES
=
4
MKL_NUM_THREADS
=
1 python
-u
${
VGG_SRC
}
--local
0
--ps_host
=
127.0.0.1:6174
--trainer_hosts
=
127.0.0.1:6174
--device
=
GPU
--task_index
=
0 &
CUDA_VISIBLE_DEVICES
=
5
MKL_NUM_THREADS
=
1 python
-u
${
VGG_SRC
}
--local
0
--ps_host
=
127.0.0.1:6174
--trainer_hosts
=
127.0.0.1:6174
--device
=
GPU
--task_index
=
1 &
benchmark/cluster/vgg16/vgg16_fluid.py
浏览文件 @
a6edeb39
...
...
@@ -200,18 +200,19 @@ def main():
num_samples
+=
len
(
data
)
train_pass_acc
.
add
(
value
=
acc
,
weight
=
b_size
)
print
(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
ts
))
"Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s "
%
(
args
.
task_index
,
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
ts
))
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed
=
time
.
time
()
-
start_time
pass_train_acc
=
train_pass_acc
.
eval
()
pass_test_acc
=
test
(
exe
)
print
(
"Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f
\n
"
%
(
pass_id
,
num_samples
/
pass_elapsed
,
pass_train_acc
,
pass_test_acc
))
print
(
"Task:%d Pass = %d, Training performance = %f imgs/s, "
"Train accuracy = %f, Test accuracy = %f
\n
"
%
(
args
.
task_index
,
pass_id
,
num_samples
/
pass_elapsed
,
pass_t
rain_acc
,
pass_t
est_acc
))
if
args
.
local
:
# Parameter initialization
...
...
@@ -239,8 +240,6 @@ def main():
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
=
args
.
task_index
,
pservers
=
args
.
ps_hosts
,
trainers
=
trainers
)
...
...
paddle/capi/Matrix.cpp
浏览文件 @
a6edeb39
...
...
@@ -108,7 +108,7 @@ paddle_error paddle_matrix_get_row(paddle_matrix mat,
paddle_error
paddle_matrix_get_shape
(
paddle_matrix
mat
,
uint64_t
*
height
,
uint64_t
*
width
)
{
if
(
mat
==
nullptr
)
return
kPD_NULLPTR
;
if
(
mat
==
nullptr
||
cast
(
mat
)
->
mat
==
nullptr
)
return
kPD_NULLPTR
;
if
(
height
!=
nullptr
)
{
*
height
=
cast
(
mat
)
->
mat
->
getHeight
();
}
...
...
paddle/cuda/include/hl_base.h
浏览文件 @
a6edeb39
...
...
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_BASE_H_
#define HL_BASE_H_
#pragma once
#include <cstddef>
...
...
@@ -207,8 +206,8 @@ typedef struct {
#ifdef __NVCC__
#include
"cuda_runtime.h"
#include "hl_cuda.h"
#include
<cuda_runtime.h>
#include "
paddle/cuda/include/
hl_cuda.h"
#include "paddle/utils/Logging.h"
extern
__thread
bool
g_sync_flag
;
...
...
@@ -228,6 +227,19 @@ extern __thread cudaStream_t default_stream;
<< "CUDA error: " << hl_get_device_error_string((size_t)err); \
}
#endif
/* __NVCC__ */
// __shfl has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_sync
(
unsigned
,
T
val
,
int
src_line
,
int
width
)
{
return
__shfl
(
val
,
src_line
,
width
);
}
#endif
/* HL_BASE_H_ */
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
#endif // __NVCC__
paddle/cuda/src/hl_cuda_lstm.cu
浏览文件 @
a6edeb39
...
...
@@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
}
__device__
__forceinline__
void
transpose_32x32
(
real
a
[],
const
int
idx
)
{
int
addr
=
idx
%
32
;
const
int
warp_size
=
32
;
int
addr
=
idx
%
warp_size
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
addr
<
warp_size
);
#pragma unroll
for
(
int
k
=
1
;
k
<
32
;
k
++
)
{
// rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
addr
=
__shfl_sync
(
addr
,
(
idx
+
1
)
%
32
,
32
);
a
[
k
]
=
__shfl_sync
(
a
[
k
],
addr
,
32
);
addr
=
__shfl_sync
(
mask
,
addr
,
(
idx
+
1
)
%
32
,
32
);
a
[
k
]
=
__shfl_sync
(
mask
,
a
[
k
],
addr
,
32
);
}
#pragma unroll
...
...
@@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
}
addr
=
(
32
-
idx
)
%
32
;
CREATE_SHFL_MASK
(
mask
,
idx
%
32
<
warp_size
);
#pragma unroll
for
(
int
k
=
0
;
k
<
32
;
k
++
)
{
a
[
k
]
=
__shfl_sync
(
a
[
k
],
addr
,
32
);
addr
=
__shfl_sync
(
addr
,
(
idx
+
31
)
%
32
,
32
);
a
[
k
]
=
__shfl_sync
(
mask
,
a
[
k
],
addr
,
32
);
addr
=
__shfl_sync
(
mask
,
addr
,
(
idx
+
31
)
%
32
,
32
);
}
}
...
...
paddle/cuda/src/hl_top_k.cu
浏览文件 @
a6edeb39
...
...
@@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if
(
--
beamSize
==
0
)
break
;
__syncthreads
();
unsigned
mask
=
0u
;
// CREATE_SHFL_MASK(mask, tid < len);
if
(
tid
==
maxId
[
0
])
{
if
(
beam
<
maxLength
)
{
shTopK
[
tid
]
=
topK
[
beam
];
}
}
if
(
maxId
[
0
]
/
32
==
warp
)
{
if
(
__shfl_sync
(
beam
,
(
maxId
[
0
])
%
32
,
32
)
==
maxLength
)
break
;
if
(
__shfl_sync
(
mask
,
beam
,
(
maxId
[
0
])
%
32
,
32
)
==
maxLength
)
break
;
}
}
}
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
a6edeb39
...
...
@@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
skip_scale_loss
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
use_default_grad_scale
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
...
...
@@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
skip_scale_loss
)
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
use_default_grad_scale
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
)
{
...
...
@@ -53,7 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for
(
auto
&
p
:
params
)
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
skip_scale_loss_
=
skip_scale_loss
;
use_default_grad_scale_
=
use_default_grad_scale
;
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
SSAGraph
*
result
,
...
...
@@ -126,8 +126,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
else
if
(
IsDistTrainOp
(
*
op
,
send_op
))
{
CreateComputationalOps
(
&
result
,
*
op
,
1
);
}
else
if
(
IsScaleLossOp
(
*
op
))
{
// user can customize loss@grad if
skip_scale_loss
_
if
(
!
skip_scale_loss
_
)
{
// user can customize loss@grad if
not use_default_grad_scale
_
if
(
use_default_grad_scale
_
)
{
CreateScaleLossGradOp
(
&
result
);
}
is_forwarding
=
false
;
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
a6edeb39
...
...
@@ -41,7 +41,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
skip_scale_loss
);
bool
use_default_grad_scale
);
#endif
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
;
...
...
@@ -59,7 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
bool
skip_scale_loss
_
;
bool
use_default_grad_scale
_
;
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
...
...
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
a6edeb39
...
...
@@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
auto
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
static_cast
<
size_t
>
(
2
)
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
static_cast
<
size_t
>
(
2
)
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
}
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
a6edeb39
...
...
@@ -93,6 +93,14 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
RunImpl
(
scope
,
place
);
}
bool
OperatorBase
::
HasInputs
(
const
std
::
string
&
name
)
const
{
if
(
inputs_
.
find
(
name
)
!=
inputs_
.
end
())
{
return
true
;
}
else
{
return
false
;
}
}
std
::
string
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE_LE
(
ins
.
size
(),
1UL
,
...
...
@@ -109,6 +117,14 @@ const std::vector<std::string>& OperatorBase::Inputs(
return
it
->
second
;
}
bool
OperatorBase
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
if
(
outputs_
.
find
(
name
)
!=
outputs_
.
end
())
{
return
true
;
}
else
{
return
false
;
}
}
std
::
string
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
auto
&
outs
=
Outputs
(
name
);
PADDLE_ENFORCE_LE
(
outs
.
size
(),
1UL
,
...
...
@@ -220,13 +236,18 @@ void OperatorBase::CheckAllInputOutputSet() const {
if
(
op_info
==
nullptr
||
op_info
->
proto_
==
nullptr
)
return
;
for
(
auto
&
in
:
op_info
->
Proto
().
inputs
())
{
PADDLE_ENFORCE
(
inputs_
.
find
(
in
.
name
())
!=
inputs_
.
end
(),
"Type %s's input %s is not set"
,
Type
(),
in
.
name
());
if
(
!
in
.
dispensable
())
{
PADDLE_ENFORCE
(
inputs_
.
find
(
in
.
name
())
!=
inputs_
.
end
(),
"Operator %s's input, %s, is not set"
,
Type
(),
in
.
name
());
}
}
for
(
auto
&
out
:
op_info
->
Proto
().
outputs
())
{
PADDLE_ENFORCE
(
outputs_
.
find
(
out
.
name
())
!=
outputs_
.
end
(),
"Type %s's output %s is not set"
,
Type
(),
out
.
name
());
if
(
!
out
.
dispensable
())
{
PADDLE_ENFORCE
(
outputs_
.
find
(
out
.
name
())
!=
outputs_
.
end
(),
"Operator %s's output, %s, is not set"
,
Type
(),
out
.
name
());
}
}
}
...
...
@@ -332,6 +353,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
:
op_
(
op
),
scope_
(
scope
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasInputs
(
name
))
{
return
false
;
}
auto
&
ins
=
Inputs
(
name
);
size_t
length
=
ins
.
size
();
if
(
length
==
0
)
{
...
...
@@ -345,6 +369,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasOutputs
(
name
))
{
return
false
;
}
auto
&
outs
=
Outputs
(
name
);
size_t
length
=
outs
.
size
();
if
(
length
==
0
)
{
...
...
@@ -358,6 +385,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasInputs
(
name
))
{
return
false
;
}
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
empty
())
{
return
false
;
...
...
@@ -371,6 +401,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasOutputs
(
name
))
{
return
false
;
}
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
empty
())
{
return
false
;
...
...
paddle/fluid/framework/operator.h
浏览文件 @
a6edeb39
...
...
@@ -105,6 +105,7 @@ class OperatorBase {
const
VariableNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
;
//! Get a input with argument's name described in `op_proto`
std
::
string
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
...
...
@@ -112,6 +113,7 @@ class OperatorBase {
//! Get all inputs variable names
std
::
vector
<
std
::
string
>
InputVars
()
const
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
;
//! Get a output with argument's name described in `op_proto`
std
::
string
Output
(
const
std
::
string
&
name
)
const
;
//! Get an output which has multiple variables.
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
a6edeb39
...
...
@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
,
bool
customize_scale_loss
)
bool
use_default_grad_scale
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
...
...
@@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
customize_scale_loss
,
member_
->
nccl_ctxs_
.
get
());
use_default_grad_scale
,
member_
->
nccl_ctxs_
.
get
());
#else
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
customize_scale_loss
);
use_default_grad_scale
);
#endif
auto
graph
=
builder
.
Build
(
main_program
);
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
a6edeb39
...
...
@@ -40,7 +40,7 @@ class ParallelExecutor {
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
bool
allow_op_delay
,
bool
customize_scale_loss
);
bool
allow_op_delay
,
bool
use_default_grad_scale
);
~
ParallelExecutor
();
...
...
paddle/fluid/framework/selected_rows.cc
浏览文件 @
a6edeb39
...
...
@@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
:
true
;
}
std
::
vector
<
int64_t
>
SelectedRows
::
Get
(
std
::
vector
<
int64_t
>
keys
,
framework
::
Tensor
*
value
)
const
{
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
SelectedRows
::
Get
(
std
::
vector
<
int64_t
>
keys
,
framework
::
Tensor
*
value
)
const
{
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
"The value tensor should be initialized."
);
std
::
vector
<
int64_t
>
non_keys
;
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
non_keys_pair
;
int64_t
value_width
=
value_
->
numel
()
/
value_
->
dims
()[
0
];
PADDLE_ENFORCE_EQ
(
value_width
,
value
->
numel
()
/
value
->
dims
()[
0
],
"output tensor should have the same shape with table "
...
...
@@ -133,7 +133,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
int64_t
index
=
Index
(
keys
[
i
]);
if
(
index
==
-
1
)
{
non_keys
.
push_back
(
keys
[
i
]
);
non_keys
_pair
.
push_back
(
std
::
make_pair
(
keys
[
i
],
static_cast
<
int64_t
>
(
i
))
);
}
else
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
...
...
@@ -141,7 +141,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
index
*
value_width
,
value_width
));
}
}
return
non_keys
;
return
non_keys
_pair
;
}
bool
SelectedRows
::
Set
(
int64_t
key
,
const
framework
::
Tensor
&
value
)
{
...
...
paddle/fluid/framework/selected_rows.h
浏览文件 @
a6edeb39
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
...
...
@@ -78,10 +79,11 @@ class SelectedRows {
/*
* @brief Get value by the key list, if the
*
* @return a list of keys which does not exists in table
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
std
::
vector
<
int64_t
>
Get
(
std
::
vector
<
int64_t
>
keys
,
framework
::
Tensor
*
tensor
)
const
;
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>
>
Get
(
std
::
vector
<
int64_t
>
keys
,
framework
::
Tensor
*
value
)
const
;
/*
* @brief Set a key-value pair into the table.
...
...
paddle/fluid/framework/selected_rows_test.cc
浏览文件 @
a6edeb39
...
...
@@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ
(
selected_rows_
->
GetCompleteDims
(),
dst_tensor
.
GetCompleteDims
());
}
TEST_F
(
SelectedRowsTester
,
Table
)
{
TEST_F
(
SelectedRowsTester
,
Sparse
Table
)
{
platform
::
CPUPlace
cpu
;
SelectedRows
table
;
// initialize a sparse table
...
...
@@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) {
framework
::
Tensor
get_value
;
get_value
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
2
,
100
}),
cpu
);
std
::
vector
<
int64_t
>
keys
({
non_key
,
key
});
auto
non_keys
=
table
.
Get
(
keys
,
&
get_value
);
auto
non_key
_pair
s
=
table
.
Get
(
keys
,
&
get_value
);
ASSERT_EQ
(
get_value
.
data
<
float
>
()[
100
],
static_cast
<
float
>
(
10
));
ASSERT_EQ
(
non_keys
.
size
(),
static_cast
<
size_t
>
(
1
));
ASSERT_EQ
(
non_key
s
[
0
]
,
non_key
);
ASSERT_EQ
(
non_key
_pair
s
.
size
(),
static_cast
<
size_t
>
(
1
));
ASSERT_EQ
(
non_key
_pairs
[
0
].
first
,
non_key
);
}
}
// namespace framework
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
a6edeb39
...
...
@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void
InitNetwork
()
{
infer_builder_
.
reset
(
createInferBuilder
(
logger_
));
infer_builder_
.
reset
(
createInferBuilder
(
&
logger_
));
infer_network_
.
reset
(
infer_builder_
->
createNetwork
());
}
// After finishing adding ops, freeze this network and creates the executation
...
...
paddle/fluid/inference/tensorrt/helper.h
浏览文件 @
a6edeb39
...
...
@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
// The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library.
static
nvinfer1
::
IBuilder
*
createInferBuilder
(
nvinfer1
::
ILogger
&
logger
)
{
static
nvinfer1
::
IBuilder
*
createInferBuilder
(
nvinfer1
::
ILogger
*
logger
)
{
return
static_cast
<
nvinfer1
::
IBuilder
*>
(
dy
::
createInferBuilder_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
dy
::
createInferBuilder_INTERNAL
(
logger
,
NV_TENSORRT_VERSION
));
}
static
nvinfer1
::
IRuntime
*
createInferRuntime
(
nvinfer1
::
ILogger
&
logger
)
{
static
nvinfer1
::
IRuntime
*
createInferRuntime
(
nvinfer1
::
ILogger
*
logger
)
{
return
static_cast
<
nvinfer1
::
IRuntime
*>
(
dy
::
createInferRuntime_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
dy
::
createInferRuntime_INTERNAL
(
logger
,
NV_TENSORRT_VERSION
));
}
// A logger for create TensorRT infer builder.
...
...
@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
return
*
x
;
}
virtual
~
NaiveLogger
()
override
{}
~
NaiveLogger
()
override
{}
};
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/test_tensorrt.cc
浏览文件 @
a6edeb39
...
...
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "NvInfer.h"
#include "cuda.h"
#include "cuda_runtime_api.h"
#include "paddle/fluid/platform/dynload/tensorrt.h"
namespace
dy
=
paddle
::
platform
::
dynload
;
...
...
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
class
ScopedWeights
{
public:
ScopedWeights
(
float
value
)
:
value_
(
value
)
{
explicit
ScopedWeights
(
float
value
)
:
value_
(
value
)
{
w
.
type
=
nvinfer1
::
DataType
::
kFLOAT
;
w
.
values
=
&
value_
;
w
.
count
=
1
;
...
...
@@ -58,13 +58,13 @@ class ScopedWeights {
// The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library.
nvinfer1
::
IBuilder
*
createInferBuilder
(
nvinfer1
::
ILogger
&
logger
)
{
nvinfer1
::
IBuilder
*
createInferBuilder
(
nvinfer1
::
ILogger
*
logger
)
{
return
static_cast
<
nvinfer1
::
IBuilder
*>
(
dy
::
createInferBuilder_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
dy
::
createInferBuilder_INTERNAL
(
logger
,
NV_TENSORRT_VERSION
));
}
nvinfer1
::
IRuntime
*
createInferRuntime
(
nvinfer1
::
ILogger
&
logger
)
{
nvinfer1
::
IRuntime
*
createInferRuntime
(
nvinfer1
::
ILogger
*
logger
)
{
return
static_cast
<
nvinfer1
::
IRuntime
*>
(
dy
::
createInferRuntime_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
dy
::
createInferRuntime_INTERNAL
(
logger
,
NV_TENSORRT_VERSION
));
}
const
char
*
kInputTensor
=
"input"
;
...
...
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
nvinfer1
::
IHostMemory
*
CreateNetwork
()
{
Logger
logger
;
// Create the engine.
nvinfer1
::
IBuilder
*
builder
=
createInferBuilder
(
logger
);
nvinfer1
::
IBuilder
*
builder
=
createInferBuilder
(
&
logger
);
ScopedWeights
weights
(
2.
);
ScopedWeights
bias
(
3.
);
...
...
@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
return
model
;
}
void
Execute
(
nvinfer1
::
IExecutionContext
&
context
,
const
float
*
input
,
void
Execute
(
nvinfer1
::
IExecutionContext
*
context
,
const
float
*
input
,
float
*
output
)
{
const
nvinfer1
::
ICudaEngine
&
engine
=
context
.
getEngine
();
const
nvinfer1
::
ICudaEngine
&
engine
=
context
->
getEngine
();
// Two binds, input and output
ASSERT_EQ
(
engine
.
getNbBindings
(),
2
);
const
int
input_index
=
engine
.
getBindingIndex
(
kInputTensor
);
...
...
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
// Copy the input to the GPU, execute the network, and copy the output back.
ASSERT_EQ
(
0
,
cudaMemcpyAsync
(
buffers
[
input_index
],
input
,
sizeof
(
float
),
cudaMemcpyHostToDevice
,
stream
));
context
.
enqueue
(
1
,
buffers
,
stream
,
nullptr
);
context
->
enqueue
(
1
,
buffers
,
stream
,
nullptr
);
ASSERT_EQ
(
0
,
cudaMemcpyAsync
(
output
,
buffers
[
output_index
],
sizeof
(
float
),
cudaMemcpyDeviceToHost
,
stream
));
cudaStreamSynchronize
(
stream
);
...
...
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
// Use the model to create an engine and an execution context.
Logger
logger
;
nvinfer1
::
IRuntime
*
runtime
=
createInferRuntime
(
logger
);
nvinfer1
::
IRuntime
*
runtime
=
createInferRuntime
(
&
logger
);
nvinfer1
::
ICudaEngine
*
engine
=
runtime
->
deserializeCudaEngine
(
model
->
data
(),
model
->
size
(),
nullptr
);
model
->
destroy
();
...
...
@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
// Execute the network.
float
input
=
1234
;
float
output
;
Execute
(
*
context
,
&
input
,
&
output
);
Execute
(
context
,
&
input
,
&
output
);
EXPECT_EQ
(
output
,
input
*
2
+
3
);
// Destroy the engine.
...
...
paddle/fluid/operators/batch_norm_mkldnn_op.cc
0 → 100644
浏览文件 @
a6edeb39
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNMemDesc
;
using
mkldnn
::
memory
;
template
<
typename
T
>
using
EigenArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
ConstEigenArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
EigenVectorArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
template
<
typename
T
>
using
ConstEigenVectorArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
namespace
{
template
<
typename
T
>
struct
bn_type_traits
{
using
op_type
=
T
;
using
op_desc
=
typename
op_type
::
desc
;
using
op_prim
=
typename
op_type
::
primitive_desc
;
};
template
<
typename
T
,
typename
Container
>
void
copy_to_weights
(
T
scale_begin
,
T
scale_end
,
T
shift_begin
,
T
shift_end
,
Container
*
c
)
{
auto
it
=
std
::
begin
(
*
c
);
std
::
copy
(
scale_begin
,
scale_end
,
std
::
inserter
(
*
c
,
it
));
std
::
copy
(
shift_begin
,
shift_end
,
std
::
inserter
(
*
c
,
std
::
next
(
it
,
std
::
distance
(
scale_begin
,
scale_end
))));
}
template
<
typename
Op
,
typename
...
Args
>
void
run_batch_norm_op
(
Args
&&
...
args
)
{
Op
batch_norm_op
{
args
...};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
pipeline
.
push_back
(
batch_norm_op
);
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
template
<
typename
T
>
inline
void
*
cast_const_to_void
(
const
T
*
t
)
{
return
static_cast
<
void
*>
(
const_cast
<
T
*>
(
t
));
}
}
// namespace
template
<
typename
T
>
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE
(
data_layout
==
framework
::
DataLayout
::
kNCHW
,
"MKLDNN batch normalization handles only NCHW data layout"
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
mean_out
=
ctx
.
Output
<
Tensor
>
(
"MeanOut"
);
auto
*
variance_out
=
ctx
.
Output
<
Tensor
>
(
"VarianceOut"
);
auto
*
batch_mean
=
ctx
.
Output
<
Tensor
>
(
"SavedMean"
);
auto
*
batch_variance
=
ctx
.
Output
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
!
is_test
)
{
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
auto
propagation
=
is_test
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
;
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
src_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
dst_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
src_pd
=
mkldnn
::
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
};
auto
dst_pd
=
mkldnn
::
memory
::
primitive_desc
{
dst_md
,
mkldnn_engine
};
auto
src
=
mkldnn
::
memory
{
src_pd
,
cast_const_to_void
(
x
->
data
<
T
>
())};
auto
dst
=
mkldnn
::
memory
{
dst_pd
,
y
->
data
<
T
>
()};
unsigned
flags
=
mkldnn
::
use_scale_shift
;
if
(
is_test
)
flags
|=
mkldnn
::
use_global_stats
;
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
propagation
,
src_md
,
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
bn_fwd_types
::
op_prim
{
batch_norm_fwd_desc
,
mkldnn_engine
};
const
unsigned
int
ic
=
dims
[
1
];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
auto
scaleshift_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
()};
if
(
is_test
)
{
auto
mean_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
mean_primitive_desc
(),
cast_const_to_void
(
mean
->
data
<
T
>
())};
auto
variance_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
variance_primitive_desc
(),
cast_const_to_void
(
variance
->
data
<
T
>
())};
run_batch_norm_op
<
typename
bn_fwd_types
::
op_type
>
(
batch_norm_fwd_pd
,
src
,
(
const
mkldnn
::
primitive
::
at
&
)
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
variance_memory
,
scaleshift_memory
,
dst
);
}
else
{
auto
mean_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
mean_primitive_desc
(),
cast_const_to_void
(
batch_mean
->
data
<
T
>
())};
auto
variance_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
variance_primitive_desc
(),
cast_const_to_void
(
batch_variance
->
data
<
T
>
())};
run_batch_norm_op
<
bn_fwd_types
::
op_type
>
(
batch_norm_fwd_pd
,
src
,
scaleshift_memory
,
dst
,
mean_memory
,
variance_memory
);
}
if
(
!
is_test
)
{
const
unsigned
int
in
=
dims
[
0
];
const
unsigned
int
sample_size
=
x
->
numel
()
/
in
/
ic
;
// saved_xx is use just in this batch of data
EigenVectorArrayMap
<
T
>
saved_mean_e
(
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
EigenVectorArrayMap
<
T
>
saved_variance_e
(
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
saved_mean_e
.
setZero
();
saved_variance_e
.
setZero
();
const
unsigned
int
x_arr_size
=
in
*
ic
;
ConstEigenArrayMap
<
T
>
x_arr
(
x
->
data
<
T
>
(),
sample_size
,
x_arr_size
);
for
(
unsigned
int
nc
=
0
;
nc
<
x_arr_size
;
++
nc
)
{
saved_mean_e
(
nc
%
ic
)
+=
x_arr
.
col
(
nc
).
sum
();
}
saved_mean_e
/=
in
*
sample_size
;
for
(
unsigned
int
nc
=
0
;
nc
<
x_arr_size
;
++
nc
)
{
saved_variance_e
(
nc
%
ic
)
+=
(
x_arr
.
col
(
nc
)
-
saved_mean_e
(
nc
%
ic
)).
matrix
().
squaredNorm
();
}
saved_variance_e
/=
in
*
sample_size
;
ConstEigenVectorArrayMap
<
T
>
mean_arr
{
mean
->
data
<
T
>
(),
ic
};
ConstEigenVectorArrayMap
<
T
>
variance_arr
{
variance
->
data
<
T
>
(),
ic
};
EigenVectorArrayMap
<
T
>
running_mean_arr
(
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
EigenVectorArrayMap
<
T
>
running_var_arr
(
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
auto
one_minus_momentum
=
1.
-
momentum
;
running_mean_arr
=
mean_arr
*
momentum
+
saved_mean_e
*
one_minus_momentum
;
running_var_arr
=
variance_arr
*
momentum
+
saved_variance_e
*
one_minus_momentum
;
}
}
};
template
<
typename
T
>
class
BatchNormMKLDNNGradOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE
(
data_layout
==
framework
::
DataLayout
::
kNCHW
,
"MKLDNN batch normalization handles only NCHW data layout"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
batch_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
batch_variance
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
diff_shift
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
unsigned
flags
=
mkldnn
::
use_scale_shift
|
!
mkldnn
::
use_global_stats
;
auto
src_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
dst_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
diff_src_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
diff_dst_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
forward_training
,
src_md
,
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
bn_fwd_types
::
op_prim
{
batch_norm_fwd_desc
,
mkldnn_engine
};
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
dst_md
,
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
batch_norm_fwd_pd
};
auto
src
=
mkldnn
::
memory
{{
src_md
,
mkldnn_engine
},
cast_const_to_void
(
x
->
data
<
T
>
())};
auto
mean
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
mean_primitive_desc
(),
cast_const_to_void
(
batch_mean
->
data
<
T
>
())};
auto
variance
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
variance_primitive_desc
(),
cast_const_to_void
(
batch_variance
->
data
<
T
>
())};
auto
diff_dst
=
mkldnn
::
memory
{{
diff_dst_md
,
mkldnn_engine
},
cast_const_to_void
(
diff_y
->
data
<
T
>
())};
const
unsigned
int
ic
=
dims
[
1
];
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
auto
scaleshift_memory
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
()};
std
::
vector
<
T
>
diff_scaleshift_data
;
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
diff_scale
->
data
<
T
>
(),
diff_scale
->
data
<
T
>
()
+
ic
,
diff_shift
->
data
<
T
>
(),
diff_shift
->
data
<
T
>
()
+
ic
,
&
diff_scaleshift_data
);
auto
diff_scaleshift_memory
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
()};
auto
diff_src
=
mkldnn
::
memory
{{
diff_src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
diff_x
->
data
<
T
>
())};
run_batch_norm_op
<
bn_bwd_types
::
op_type
>
(
batch_norm_bwd_pd
,
src
,
mean
,
variance
,
diff_dst
,
scaleshift_memory
,
diff_src
,
diff_scaleshift_memory
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale
->
data
<
T
>
());
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
diff_shift
->
data
<
T
>
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
batch_norm
,
MKLDNN
,
paddle
::
platform
::
CPUPlace
,
ops
::
BatchNormMKLDNNOpKernel
<
float
>
);
REGISTER_OP_KERNEL
(
batch_norm_grad
,
MKLDNN
,
paddle
::
platform
::
CPUPlace
,
ops
::
BatchNormMKLDNNGradOpKernel
<
float
>
);
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
a6edeb39
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -87,9 +90,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should both be float.
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto
bn_param_type
=
framework
::
proto
::
VarType
::
FP32
;
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP64
)
{
bn_param_type
=
framework
::
proto
::
VarType
::
FP64
;
}
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Scale"
)
->
type
()),
"Scale input should be of float type"
);
...
...
@@ -102,7 +109,18 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
()),
"Variance input should be of float type"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library_
);
}
};
...
...
@@ -147,6 +165,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Variance of the current mini batch, "
"will apply to output when training"
)
.
AsIntermediate
();
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Batch Normalization.
...
...
@@ -345,8 +366,19 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
t
->
type
()),
ctx
.
GetPlace
());
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
layout
,
library_
);
}
};
...
...
@@ -470,6 +502,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
OutputGrad
(
"Y"
));
op
->
SetInput
(
"Scale"
,
Input
(
"Scale"
));
op
->
SetInput
(
"Bias"
,
Input
(
"Bias"
));
op
->
SetInput
(
"SavedMean"
,
Output
(
"SavedMean"
));
op
->
SetInput
(
"SavedVariance"
,
Output
(
"SavedVariance"
));
...
...
@@ -492,8 +525,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
REGISTER_OPERATOR
(
batch_norm_grad
,
ops
::
BatchNormGradOp
);
REGISTER_OP_CPU_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
batch_norm
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
batch_norm_grad
,
ops
::
BatchNormGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
BatchNormGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
BatchNormGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/batch_norm_op.cu.cc
浏览文件 @
a6edeb39
...
...
@@ -287,6 +287,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
BatchNormKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
batch_norm_grad
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
batch_norm_grad
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
BatchNormGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
a6edeb39
...
...
@@ -164,11 +164,13 @@ or not. But the output only shares the LoD information with input X.
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CPUCtx
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
cross_entropy
,
ops
::
CrossEntropyOp
,
ops
::
CrossEntropyOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
float
>
,
ops
::
CrossEntropyOpKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
CPUCtx
,
float
>
,
ops
::
CrossEntropyOpKernel
<
CPUCtx
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOpKernel
<
float
>
,
ops
::
CrossEntropyGradientOpKernel
<
double
>
);
ops
::
CrossEntropyGradientOpKernel
<
CPUCtx
,
float
>
,
ops
::
CrossEntropyGradientOpKernel
<
CPUCtx
,
double
>
);
paddle/fluid/operators/cross_entropy_op.cu
浏览文件 @
a6edeb39
...
...
@@ -14,98 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h"
namespace
paddle
{
namespace
operators
{
namespace
{
template
<
typename
T
>
__global__
void
CrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
int64_t
*
label
,
const
int
N
,
const
int
D
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
idx
=
i
*
D
+
label
[
i
];
dX
[
idx
]
=
-
dY
[
i
]
/
X
[
idx
];
}
}
template
<
typename
T
>
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
int
ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
ids
<
N
*
D
)
{
int
row_ids
=
ids
/
D
;
dX
[
ids
]
=
-
label
[
ids
]
*
dY
[
row_ids
]
/
X
[
ids
];
}
}
}
// namespace
template
<
typename
T
>
class
CrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
y
,
x
,
label
,
ctx
.
Attr
<
bool
>
(
"soft_label"
));
}
};
template
<
typename
T
>
class
CrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
dy_data
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
data
<
T
>
();
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
int64_t
batch_size
=
x
->
dims
()[
0
];
int64_t
class_num
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
batch_size
*
class_num
+
block
-
1
)
/
block
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
stream
=
dev_ctx
.
stream
();
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
else
{
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
functor
;
functor
(
dev_ctx
,
dx
,
0
);
auto
*
label_data
=
label
->
data
<
int64_t
>
();
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpCUDAKernel
<
float
>
,
ops
::
CrossEntropyOpCUDAKernel
<
double
>
);
using
CUDACtx
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
CUDACtx
,
float
>
,
ops
::
CrossEntropyOpKernel
<
CUDACtx
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOp
CUDAKernel
<
float
>
,
ops
::
CrossEntropyGradientOp
CUDAKernel
<
double
>
);
ops
::
CrossEntropyGradientOp
Kernel
<
CUDACtx
,
float
>
,
ops
::
CrossEntropyGradientOp
Kernel
<
CUDACtx
,
double
>
);
paddle/fluid/operators/cross_entropy_op.h
浏览文件 @
a6edeb39
...
...
@@ -17,69 +17,106 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on CPU."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
CrossEntropyFunctor
<
platform
::
CPU
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CPU
DeviceContext
>(),
y
,
x
,
labels
,
math
::
CrossEntropyFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
y
,
x
,
labels
,
ctx
.
Attr
<
bool
>
(
"soft_label"
));
}
};
template
<
typename
T
>
class
XeSoftlabelGradFunctor
{
public:
XeSoftlabelGradFunctor
(
T
*
dx
,
const
T
*
dy
,
// NOLINT
const
T
*
x
,
// NOLINT
const
T
*
label
,
// NOLINT
size_t
num_classes
)
:
dx_
(
dx
),
dy_
(
dy
),
x_
(
x
),
label_
(
label
),
num_classes_
(
num_classes
)
{}
HOSTDEVICE
void
operator
()(
size_t
i
)
{
auto
row_ids
=
i
/
num_classes_
;
dx_
[
i
]
=
-
label_
[
i
]
*
dy_
[
row_ids
]
/
x_
[
i
];
}
private:
T
*
dx_
;
const
T
*
dy_
;
const
T
*
x_
;
const
T
*
label_
;
size_t
num_classes_
;
};
template
<
typename
T
>
class
XeGradFunctor
{
public:
XeGradFunctor
(
T
*
dx
,
const
T
*
dy
,
// NOLINT
const
T
*
x
,
// NOLINT
const
int64_t
*
label
,
// NOLINT
size_t
num_classes
)
:
dx_
(
dx
),
dy_
(
dy
),
x_
(
x
),
label_
(
label
),
num_classes_
(
num_classes
)
{}
HOSTDEVICE
void
operator
()(
size_t
sample_id
)
{
auto
x_is_true_offset
=
sample_id
*
num_classes_
+
label_
[
sample_id
];
for
(
size_t
x_offset
=
sample_id
*
num_classes_
;
x_offset
<
(
sample_id
+
1
)
*
num_classes_
;
++
x_offset
)
{
dx_
[
x_offset
]
=
x_offset
!=
x_is_true_offset
?
static_cast
<
T
>
(
0
)
:
-
dy_
[
sample_id
]
/
x_
[
x_offset
];
}
}
private:
T
*
dx_
;
const
T
*
dy_
;
const
T
*
x_
;
const
int64_t
*
label_
;
size_t
num_classes_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
CrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on CPU."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
class_num
=
x
->
dims
()[
1
];
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
x_mat
=
EigenMatrix
<
T
>::
From
(
*
x
);
auto
dy_mat
=
EigenMatrix
<
T
>::
From
(
*
dy
);
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
label
);
auto
dx_mat
=
EigenMatrix
<
T
>::
From
(
*
dx
);
dx_mat
.
device
(
*
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>()
.
eigen_device
())
=
-
(
lbl_mat
*
dy_mat
.
broadcast
(
Eigen
::
DSizes
<
int64_t
,
2
>
(
1
,
class_num
))
/
x_mat
);
XeSoftlabelGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label
->
data
<
T
>
(),
static_cast
<
size_t
>
(
class_num
));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
dx
->
numel
()));
for_range
(
functor
);
}
else
{
int64_t
batch_size
=
x
->
dims
()[
0
];
const
T
*
dy_data
=
dy
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int64_t
*
label_data
=
label
->
data
<
int64_t
>
();
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
functor
;
functor
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
dx
,
0
);
for
(
int64_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
int64_t
index
=
i
*
class_num
+
label_data
[
i
];
dx_data
[
index
]
=
math
::
TolerableValue
<
T
>
()(
-
dy_data
[
i
]
/
x_data
[
index
]);
}
XeGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label
->
data
<
int64_t
>
(),
static_cast
<
size_t
>
(
class_num
));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
dy
->
numel
()));
for_range
(
functor
);
}
}
};
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
a6edeb39
...
...
@@ -82,7 +82,9 @@ class RequestSend final : public RequestBase {
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
virtual
void
Process
()
{
queue_
->
Push
(
std
::
make_pair
(
request_
->
Varname
(),
request_
));
std
::
string
var_name
=
GetReqName
();
VLOG
(
3
)
<<
"RequestSend "
<<
var_name
;
queue_
->
Push
(
std
::
make_pair
(
var_name
,
request_
));
sendrecv
::
VoidMessage
reply
;
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
...
...
@@ -106,7 +108,7 @@ class RequestGet final : public RequestBase {
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
)
{
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
...
...
@@ -118,6 +120,7 @@ class RequestGet final : public RequestBase {
virtual
void
Process
()
{
// proc request.
std
::
string
var_name
=
request_
.
varname
();
VLOG
(
3
)
<<
"RequestGet "
<<
var_name
;
auto
*
var
=
scope_
->
FindVar
(
var_name
);
::
grpc
::
ByteBuffer
reply
;
...
...
@@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase {
::
grpc
::
ByteBuffer
reply
;
std
::
string
var_name
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"
prefetch var
"
<<
var_name
;
VLOG
(
3
)
<<
"
RequestPrefetch
"
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
&
scope_
->
NewScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
...
...
@@ -208,6 +211,11 @@ void AsyncGRPCServer::WaitClientGet(int count) {
}
}
void
AsyncGRPCServer
::
WaitServerReady
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
(),
...
...
@@ -241,6 +249,12 @@ void AsyncGRPCServer::RunSyncUpdate() {
t_prefetch_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
}
condition_ready_
.
notify_all
();
// wait server
server_
->
Wait
();
t_send_
->
join
();
...
...
@@ -307,18 +321,20 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
bool
ok
=
false
;
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" w
hile in
"
;
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" w
ait Next
"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue shutdown!"
;
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
"
while after
Next"
;
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
"
get
Next"
;
PADDLE_ENFORCE
(
tag
);
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" after WaitCond"
;
}
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
);
...
...
@@ -336,9 +352,9 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
switch
(
base
->
Status
())
{
case
PROCESS
:
{
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
break
;
}
case
FINISH
:
{
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
a6edeb39
...
...
@@ -45,8 +45,9 @@ class RequestBase;
class
AsyncGRPCServer
final
{
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
:
address_
(
address
),
sync_mode_
(
sync_mode
)
,
ready_
(
0
)
{}
void
WaitServerReady
();
void
RunSyncUpdate
();
// functions to sync server barrier status.
...
...
@@ -118,6 +119,10 @@ class AsyncGRPCServer final {
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
std
::
mutex
mutex_ready_
;
std
::
condition_variable
condition_ready_
;
int
ready_
;
};
};
// namespace detail
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
a6edeb39
...
...
@@ -108,7 +108,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
32.7
);
}
for
(
size_t
i
=
0
;
i
<
rows2
->
size
();
++
i
)
{
EXPECT_EQ
(
rows_data2
[
i
],
i
);
EXPECT_EQ
(
rows_data2
[
i
],
static_cast
<
int64_t
>
(
i
)
);
}
EXPECT_EQ
(
slr2
->
height
(),
1000
);
}
...
...
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
a6edeb39
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
...
...
@@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
#ifdef __NVCC__
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const
int
warpSize
=
32
;
__shared__
T
shm
[
warpSize
];
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
__syncthreads
();
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
...
...
@@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if
(
dy
)
{
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
reduceSum
(
val
,
tid
,
h
);
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
}
...
...
@@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if
(
dy
)
{
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
reduceSum
(
val
,
tid
,
h
);
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
}
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
a6edeb39
...
...
@@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
}
}
static
void
AsyncExecuteBlock
(
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
,
framework
::
Scope
*
scope
)
{
std
::
future
<
void
>
future
=
framework
::
Async
([
&
executor
,
&
prepared
,
&
scope
]()
{
try
{
executor
->
RunPreparedContext
(
prepared
,
scope
,
false
,
false
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
// TODO(qiao) maybe we can remove this
future
.
wait
();
}
static
void
ParallelExecuteBlocks
(
const
std
::
vector
<
size_t
>
&
parallel_blkids
,
framework
::
Executor
*
executor
,
const
std
::
vector
<
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
...
...
@@ -80,12 +66,7 @@ static void ParallelExecuteBlocks(
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
}
static
void
SavePort
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service
)
{
std
::
ofstream
port_file
;
port_file
.
open
(
"/tmp/paddle.selected_port"
);
port_file
<<
rpc_service
->
GetSelectedPort
();
port_file
.
close
();
}
std
::
atomic_int
ListenAndServOp
::
selected_port_
{
0
};
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
...
...
@@ -93,15 +74,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
int
ListenAndServOp
::
GetSelectedPort
()
const
{
return
rpc_service_
->
GetSelectedPort
();
}
void
ListenAndServOp
::
Stop
()
{
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
server_thread_
->
join
();
}
void
ListenAndServOp
::
SavePort
(
const
std
::
string
&
file_path
)
const
{
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_
=
rpc_service_
->
GetSelectedPort
();
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
.
load
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
ListenAndServOp
::
WaitServerReady
()
{
while
(
selected_port_
.
load
()
==
0
)
{
}
}
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
...
...
@@ -201,14 +194,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
}
// while(true)
}
static
void
AsyncUpdateThread
(
const
std
::
string
&
var_name
,
const
bool
&
exit_flag
,
const
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
&
queue
,
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
)
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" started"
;
while
(
!
exit_flag
)
{
const
detail
::
ReceivedMessage
v
=
queue
->
Pop
();
auto
recv_var_name
=
v
.
first
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
auto
fs
=
framework
::
Async
([
var_name
,
&
executor
,
&
v
,
prepared
]
{
try
{
executor
->
RunPreparedContext
(
prepared
,
v
.
second
->
GetMutableLocalScope
(),
false
,
false
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
fs
.
wait
();
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
framework
::
ProgramDesc
*
program
)
const
{
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
detail
::
ReceivedQueue
>>
grad_to_queue
;
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
...
...
@@ -220,6 +239,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
grad_to_queue
[
pieces
[
0
]]
=
std
::
make_shared
<
detail
::
ReceivedQueue
>
();
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
size_t
num_blocks
=
program
->
Size
();
...
...
@@ -238,8 +258,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
bool
exit_flag
=
false
;
VLOG
(
3
)
<<
"start async optimize threads"
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
auto
iter
=
grad_to_queue
.
begin
();
iter
!=
grad_to_queue
.
end
();
iter
++
)
{
std
::
string
grad_name
=
iter
->
first
;
VLOG
(
3
)
<<
"create async update thread for "
<<
grad_name
;
fs
.
push_back
(
framework
::
AsyncIO
([
grad_name
,
&
exit_flag
,
&
executor
,
&
grad_to_queue
,
&
grad_to_prepared_ctx
]()
{
AsyncUpdateThread
(
grad_name
,
exit_flag
,
grad_to_queue
[
grad_name
],
executor
,
grad_to_prepared_ctx
[
grad_name
].
get
());
}));
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
while
(
!
exit_flag
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
...
...
@@ -249,13 +282,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
AsyncExecuteBlock
(
executor
,
grad_to_prepared_ctx
[
recv_var_name
].
get
(),
v
.
second
->
GetMutableLocalScope
());
grad_to_queue
[
recv_var_name
]
->
Push
(
v
);
}
if
(
exit_flag
)
{
...
...
@@ -298,13 +325,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
VLOG
(
3
)
<<
"wait server thread to become ready..."
;
sleep
(
5
);
rpc_service_
->
WaitServerReady
();
// Write to a file of server selected port for python use.
SavePort
(
rpc_service_
);
std
::
string
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.selected_port"
,
static_cast
<
int
>
(
::
getpid
()));
SavePort
(
file_path
);
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
else
{
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
RunAsyncLoop
(
&
executor
,
program
);
}
}
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
a6edeb39
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <ostream>
#include <string>
...
...
@@ -39,26 +40,33 @@ class ListenAndServOp : public framework::OperatorBase {
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
int
GetSelectedPort
()
const
;
void
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
framework
::
ProgramDesc
*
program
)
const
;
void
SavePort
(
const
std
::
string
&
file_path
=
"/tmp/paddle.selected_port"
)
const
;
void
WaitServerReady
();
int
GetSelectedPort
()
{
return
selected_port_
;
}
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
static
void
ResetPort
()
{
selected_port_
=
0
;
}
protected:
mutable
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
// FIXME(wuyi): it's static so that the operator can be cloned.
static
std
::
atomic_int
selected_port_
;
};
}
// namespace operators
...
...
paddle/fluid/operators/lookup_sparse_table_op.cc
0 → 100644
浏览文件 @
a6edeb39
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
constexpr
int64_t
kNoPadding
=
-
1
;
class
LookupSparseTableInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of LookupSparseTableOp should not be null."
);
auto
shape_w
=
ctx
->
GetInputDim
(
"W"
);
auto
shape_ids
=
ctx
->
GetInputDim
(
"Ids"
);
shape_w
[
0
]
=
shape_ids
.
size
();
ctx
->
SetOutputDim
(
"Out"
,
shape_w
);
}
};
class
LookupSparseTableOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
w_var
=
scope
.
FindVar
(
Input
(
"W"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"Ids"
));
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
Attr
<
int
>
(
"seed"
));
float
min
=
Attr
<
float
>
(
"min"
);
float
max
=
Attr
<
float
>
(
"max"
);
bool
auto_grown_table
=
Attr
<
bool
>
(
"auto_grown_table"
);
PADDLE_ENFORCE
(
out_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The type of Out var should be LodTensor."
);
PADDLE_ENFORCE
(
w_var
->
IsType
<
framework
::
SelectedRows
>
(),
"The type of W var should be SelectedRows."
);
PADDLE_ENFORCE
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The type of Ids var should be LoDTensor."
);
auto
&
ids_t
=
ids_var
->
Get
<
framework
::
LoDTensor
>
();
auto
out_t
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
w_t
=
w_var
->
GetMutable
<
framework
::
SelectedRows
>
();
std
::
vector
<
int64_t
>
keys
;
keys
.
resize
(
ids_t
.
numel
());
for
(
size_t
i
=
0
;
i
<
ids_t
.
numel
();
++
i
)
{
keys
[
i
]
=
ids_t
.
data
<
int64_t
>
()[
i
];
}
// TODO(Yancey1989): support CUDA Place for the sparse table
platform
::
CPUPlace
cpu
;
auto
out_shape
=
w_t
->
value
().
dims
();
out_shape
[
0
]
=
keys
.
size
();
out_t
->
Resize
(
out_shape
);
out_t
->
mutable_data
(
cpu
,
w_t
->
value
().
type
());
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
framework
::
proto
::
VarType
::
FP32
,
"The sparse table only support FP32"
);
auto
non_keys_pair
=
w_t
->
Get
(
keys
,
out_t
);
if
(
!
auto_grown_table
)
{
PADDLE_ENFORCE_EQ
(
non_keys_pair
.
size
(),
static_cast
<
size_t
>
(
0
),
"there is some keys does exists in the sparse table."
);
}
auto
value_shape
=
w_t
->
value
().
dims
();
value_shape
[
0
]
=
1
;
for
(
const
auto
&
it
:
non_keys_pair
)
{
const
auto
key
=
it
.
first
;
const
auto
index
=
it
.
second
;
framework
::
Tensor
value
;
value
.
Resize
(
value_shape
);
auto
data
=
value
.
mutable_data
<
float
>
(
cpu
);
std
::
minstd_rand
engine
;
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
min
,
max
);
int64_t
size
=
value
.
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
w_t
->
Set
(
key
,
value
);
memory
::
Copy
(
cpu
,
out_t
->
mutable_data
<
float
>
(
cpu
)
+
index
*
value
.
numel
(),
cpu
,
value
.
data
<
float
>
(),
value
.
numel
()
*
sizeof
(
float
));
}
}
};
class
LookupSparseTableOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
LookupSparseTableOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"W"
,
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter."
);
AddInput
(
"Ids"
,
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W."
);
AddOutput
(
"Out"
,
"(LoDTensor) The lookup results, which have the "
"same type as W."
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids."
)
.
SetDefault
(
kNoPadding
);
AddAttr
<
float
>
(
"min"
,
"(float, default -1.0) "
"Minimum value of uniform random"
)
.
SetDefault
(
-
1.0
f
);
AddAttr
<
float
>
(
"max"
,
"(float, default 1.0) "
"Maximun value of uniform random"
)
.
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"auto_grown_table"
,
"(bool default false)"
"Whether create new value if for nonexistent key."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
lookup_sparse_table
,
ops
::
LookupSparseTableOp
,
ops
::
LookupSparseTableInferShape
,
ops
::
LookupSparseTableOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle/fluid/operators/math/cross_entropy.cu
浏览文件 @
a6edeb39
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
...
...
@@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
}
}
template
<
typename
T
>
__device__
__forceinline__
T
sum_single_warp
(
T
val
)
{
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
16
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
8
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
4
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
2
);
val
+=
platform
::
__shfl_down_sync
(
0
,
val
,
1
);
return
val
;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template
<
typename
T
>
struct
SharedMemory
{
// Ensure that we won't compile any un-specialized types
__device__
T
*
GetPointer
()
{
return
NULL
;
}
};
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
GetPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
template
<
>
struct
SharedMemory
<
double
>
{
__device__
double
*
GetPointer
()
{
extern
__shared__
double
s_double
[];
return
s_double
;
}
};
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
class_num
)
{
int
tid
=
threadIdx
.
x
;
SharedMemory
<
T
>
d_sum_shared
;
T
*
d_sum
=
d_sum_shared
.
GetPointer
();
d_sum
[
tid
]
=
0
;
T
val
=
0
;
int
cur_idx
=
tid
;
int
next_idx
=
blockIdx
.
x
*
class_num
+
tid
;
while
(
cur_idx
<
class_num
)
{
d_sum
[
tid
]
+=
math
::
TolerableValue
<
T
>
()(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
blockDim
.
x
;
cur_idx
+=
blockDim
.
x
;
int
idx
=
blockIdx
.
x
*
class_num
+
tid
;
int
end
=
blockIdx
.
x
*
class_num
+
class_num
;
for
(;
idx
<
end
;
idx
+=
blockDim
.
x
)
{
val
+=
math
::
TolerableValue
<
T
>
()(
std
::
log
(
X
[
idx
]))
*
label
[
idx
];
}
__syncthreads
();
for
(
unsigned
int
stride
=
blockDim
.
x
>>
1
;
stride
>=
32
;
stride
>>=
1
)
{
if
(
tid
<
stride
)
d_sum
[
tid
]
+=
d_sum
[
tid
+
stride
];
__syncthreads
()
;
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
Y
[
blockIdx
.
x
]
=
-
val
;
}
T
val
=
d_sum
[
tid
];
val
=
sum_single_warp
<
T
>
(
val
);
if
(
tid
==
0
)
Y
[
blockIdx
.
x
]
=
-
val
;
}
}
// namespace
...
...
@@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
?
512
:
pow
(
2
,
static_cast
<
int
>
(
std
::
log2
(
class_num
)));
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
block
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
0
,
ctx
.
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
class_num
);
}
else
{
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
...
...
paddle/fluid/operators/math/pooling.cc
浏览文件 @
a6edeb39
...
...
@@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/pooling.h"
#include <algorithm>
#include <vector>
namespace
paddle
{
namespace
operators
{
...
...
@@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
T
ele
=
pool_process
.
initial
();
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[
h
*
input_width
+
w
]
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
pool_size
))
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
}
}
...
...
@@ -86,13 +88,12 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
typename
PoolProcess
,
class
T
>
class
Pool2dGradFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
input_data
[
h
*
input_width
+
w
],
output_data
[
ph
*
output_width
+
pw
],
output_grad_data
[
ph
*
output_width
+
pw
],
input_grad_data
[
h
*
input_width
+
w
]
,
static_cast
<
T
>
(
scale
)
);
static_cast
<
T
>
(
scale
)
,
input_grad_data
+
h
*
input_width
+
w
);
}
}
}
...
...
@@ -154,12 +155,11 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
class
T
>
class
MaxPool2dGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
);
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
,
&
ele
);
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
static_cast
<
T
>
(
pool_size
)
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
output_idx
]
=
ele
;
}
}
...
...
@@ -320,13 +321,12 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
typename
PoolProcess
,
class
T
>
class
Pool3dGradFunctor
<
platform
::
CPUDeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_grad_process
.
compute
(
input_data
[
input_idx
],
output_data
[
output_idx
],
output_grad_data
[
output_idx
],
input_grad_data
[
input_idx
],
static_cast
<
T
>
(
scale
)
);
output_grad_data
[
output_idx
],
static_cast
<
T
>
(
scale
),
input_grad_data
+
input_idx
);
}
}
}
...
...
@@ -404,12 +404,11 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
template
<
class
T
>
class
MaxPool3dGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -510,9 +509,10 @@ template <typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUDeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_height
=
input_grad
->
dims
()[
2
];
...
...
@@ -628,9 +629,10 @@ template <typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUDeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_depth
=
input_grad
->
dims
()[
2
];
...
...
paddle/fluid/operators/math/pooling.cu
浏览文件 @
a6edeb39
...
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -47,11 +49,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
T
ele
=
pool_process
.
initial
();
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[
h
*
input_width
+
w
]
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
(
static_cast
<
T
>
(
pool_size
))
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
}
...
...
@@ -96,8 +98,8 @@ __global__ void KernelPool2DGrad(
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
output_sub_idx
=
ph
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
gradient
,
static_cast
<
T
>
(
1.0
/
pool_size
));
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
)
,
&
gradient
);
}
}
input_grad
[
index
]
=
gradient
;
...
...
@@ -158,9 +160,10 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
CUDADeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -201,9 +204,11 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -246,8 +251,10 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
...
...
@@ -340,12 +347,12 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
pool_process
.
compute
(
ele
,
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
]
);
input_data
[(
d
*
input_height
+
h
)
*
input_width
+
w
],
&
ele
);
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
pool_process
.
finalize
(
ele
,
static_cast
<
T
>
(
pool_size
)
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
}
...
...
@@ -405,8 +412,8 @@ __global__ void KernelPool3DGrad(
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
output_sub_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
gradient
,
static_cast
<
T
>
(
1.0
/
pool_size
));
output_grad
[
output_sub_idx
],
static_cast
<
T
>
(
1.0
/
pool_size
)
,
&
gradient
);
}
}
}
...
...
@@ -474,9 +481,10 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
CUDADeviceContext
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -525,9 +533,11 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -578,8 +588,10 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
...
...
@@ -736,9 +748,10 @@ template <typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
<
platform
::
CUDADeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -779,8 +792,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
...
...
@@ -937,9 +951,10 @@ template <typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
<
platform
::
CUDADeviceContext
,
T1
,
T2
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -987,8 +1002,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
a6edeb39
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -23,8 +24,8 @@ namespace operators {
namespace
math
{
#define FLT_MAX \
__FLT_MAX__ //
It might need to be placed in another file, but I'm still
// wondering where to put it.
__FLT_MAX__ //
TODO(zcd) :It might need to be placed in another file, but I'm
//
still
wondering where to put it.
/*
* \brief Extracting simple operations from pooling.
...
...
@@ -40,33 +41,33 @@ template <class T>
class
MaxPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
=
y
>
x
?
y
:
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
pool_field
)
{}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
DEVICE
inline
void
finalize
(
const
T
&
pool_field
,
T
*
y
)
{}
};
template
<
class
T
>
class
AvgPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
0
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
+=
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
pool_field
)
{
y
/=
pool_field
;
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
+=
x
;
}
DEVICE
inline
void
finalize
(
const
T
&
pool_field
,
T
*
y
)
{
*
y
/=
pool_field
;
}
};
template
<
class
T
>
class
MaxPoolGrad
{
public:
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
&
dx
,
T
scale
)
{
dx
+=
dy
*
(
x
==
y
);
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
scale
,
T
*
dx
)
{
*
dx
+=
dy
*
(
x
==
y
);
}
};
template
<
class
T
>
class
AvgPoolGrad
{
public:
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
&
dx
,
T
scale
)
{
dx
+=
(
scale
*
dy
);
DEVICE
inline
void
compute
(
const
T
&
x
,
const
T
&
y
,
const
T
&
dy
,
T
scale
,
T
*
dx
)
{
*
dx
+=
(
scale
*
dy
);
}
};
...
...
@@ -88,8 +89,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
class
Pool2dFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
...
...
@@ -98,9 +100,11 @@ class Pool2dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
@@ -108,8 +112,10 @@ class MaxPool2dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -117,8 +123,9 @@ template <typename DeviceContext, typename PoolProcess, typename T>
class
Pool3dFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
...
...
@@ -127,9 +134,11 @@ class Pool3dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
@@ -137,8 +146,10 @@ class MaxPool3dGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -153,8 +164,9 @@ template <typename DeviceContext, typename T1, typename T2>
class
MaxPool2dWithIndexFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
...
...
@@ -163,8 +175,9 @@ class MaxPool2dWithIndexGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
@@ -172,8 +185,9 @@ template <typename DeviceContext, typename T1, typename T2>
class
MaxPool3dWithIndexFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
...
...
@@ -182,8 +196,9 @@ class MaxPool3dWithIndexGradFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
const
framework
::
Tensor
&
mask
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
...
...
paddle/fluid/operators/math/sequence_padding.cc
浏览文件 @
a6edeb39
...
...
@@ -22,7 +22,7 @@ template <typename T>
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
&
padding
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
*
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
...
...
@@ -37,7 +37,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
.
dims
();
auto
padding_dims
=
padding
->
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width]."
);
...
...
@@ -58,7 +58,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"width of sequence in LoDTensor seq."
);
const
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
padding_data
=
padding
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
max_sequence_length
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
num_sequences
;
++
j
)
{
int64_t
start_pos
=
abs_offset_lod
[
level
][
j
];
...
...
@@ -84,16 +84,16 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
framework
::
LoDTensor
&
seq
,
const
framework
::
Tensor
&
padding
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
auto
lod
=
seq
->
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The LoD of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
.
dims
();
auto
seq_dims
=
seq
->
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
...
...
@@ -114,13 +114,13 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
.
numel
()
/
seq_dims
[
0
];
const
int64_t
sequence_width
=
seq
->
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
const
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
seq_data
=
seq
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
num_sequences
;
++
i
)
{
int64_t
start_pos
=
abs_offset_lod
[
level
][
i
];
int64_t
sequence_length
=
abs_offset_lod
[
level
][
i
+
1
]
-
start_pos
;
...
...
paddle/fluid/operators/math/sequence_padding.cu
浏览文件 @
a6edeb39
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace
paddle
{
...
...
@@ -61,7 +62,7 @@ template <typename T>
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
&
padding
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
*
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
...
...
@@ -76,7 +77,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
.
dims
();
auto
padding_dims
=
padding
->
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width]."
);
...
...
@@ -97,8 +98,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
TensorCopy
(
seq
,
context
.
GetPlace
(),
context
,
&
padding
);
padding
.
Resize
(
padding_dims
);
TensorCopy
(
seq
,
context
.
GetPlace
(),
context
,
padding
);
padding
->
Resize
(
padding_dims
);
return
;
}
...
...
@@ -117,7 +118,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
padding_data
=
padding
->
data
<
T
>
();
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
padding_data
,
const_cast
<
T
*>
(
seq_data
),
...
...
@@ -136,16 +137,16 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
framework
::
LoDTensor
&
seq
,
const
framework
::
Tensor
&
padding
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
auto
lod
=
seq
->
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The lod of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
.
dims
();
auto
seq_dims
=
seq
->
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
...
...
@@ -166,14 +167,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
.
numel
()
/
seq_dims
[
0
];
const
int64_t
sequence_width
=
seq
->
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
TensorCopy
(
padding
,
context
.
GetPlace
(),
context
,
&
seq
);
seq
.
Resize
(
seq_dims
);
TensorCopy
(
padding
,
context
.
GetPlace
(),
context
,
seq
);
seq
->
Resize
(
seq_dims
);
return
;
}
...
...
@@ -192,7 +193,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
seq_data
=
seq
->
data
<
T
>
();
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
padding_data
),
seq_data
,
...
...
paddle/fluid/operators/math/sequence_padding.h
浏览文件 @
a6edeb39
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -64,13 +65,13 @@ template <typename DeviceContext, typename T>
class
PaddingLoDTensorFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
&
padding
,
bool
norm_by_times
);
framework
::
Tensor
*
padding
,
bool
norm_by_times
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
UnpaddingLoDTensorFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
framework
::
LoDTensor
&
seq
,
void
operator
()(
const
DeviceContext
&
context
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
);
};
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
a6edeb39
...
...
@@ -54,12 +54,12 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
static_cast
<
int64_t
>
(
sequence_width
)});
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
seq
,
padding
,
false
);
*
context
,
seq
,
&
padding
,
false
);
seq_back
.
set_lod
(
lod
);
seq_back
.
mutable_data
<
T
>
(
seq_dims
,
*
place
);
paddle
::
operators
::
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
seq_back
,
padding
,
false
);
*
context
,
&
seq_back
,
padding
,
false
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
cpu_seq_back
=
seq_back
;
...
...
paddle/fluid/operators/momentum_op.cc
浏览文件 @
a6edeb39
...
...
@@ -17,6 +17,8 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
MomentumOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"ParamOut"
,
param_dim
);
ctx
->
SetOutputDim
(
"VelocityOut"
,
param_dim
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/mul_op.cc
浏览文件 @
a6edeb39
...
...
@@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
mul_grad
,
ops
::
MulGradOp
);
REGISTER_OP_CPU_KERNEL
(
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
mul_grad
,
ops
::
MulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
mul_grad
,
ops
::
MulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MulGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/mul_op.cu.cc
浏览文件 @
a6edeb39
...
...
@@ -18,6 +18,8 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
mul
,
ops
::
MulKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
MulKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
MulKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
mul_grad
,
ops
::
MulGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
MulGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
MulGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/row_conv_op.cu
浏览文件 @
a6edeb39
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_
primitives
.h"
#include "paddle/fluid/platform/cuda_
device_function
.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/save_load_op_test.cc
浏览文件 @
a6edeb39
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP
(
save
);
USE_NO_KERNEL_OP
(
load
);
...
...
@@ -61,3 +62,35 @@ TEST(SaveLoadOp, CPU) {
}
}
}
TEST
(
SaveLoadFP16Op
,
CPU
)
{
paddle
::
framework
::
Scope
scope
;
paddle
::
platform
::
CPUPlace
place
;
auto
var
=
scope
.
Var
(
"test_var"
);
auto
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
tensor
->
Resize
({
3
,
10
});
float
*
expect
=
tensor
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
tensor
->
numel
();
++
i
)
{
expect
[
i
]
=
static_cast
<
float
>
(
paddle
::
platform
::
float16
(
i
));
}
paddle
::
framework
::
AttributeMap
attrs
;
attrs
.
insert
({
"file_path"
,
std
::
string
(
"tensor.save"
)});
attrs
.
insert
({
"save_as_fp16"
,
true
});
auto
save_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"save"
,
{{
"X"
,
{
"test_var"
}}},
{},
attrs
);
save_op
->
Run
(
scope
,
place
);
auto
load_var
=
scope
.
Var
(
"out_var"
);
auto
target
=
load_var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
auto
load_op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
"load"
,
{},
{{
"Out"
,
{
"out_var"
}}},
attrs
);
load_op
->
Run
(
scope
,
place
);
paddle
::
platform
::
float16
*
actual
=
target
->
data
<
paddle
::
platform
::
float16
>
();
for
(
int64_t
i
=
0
;
i
<
tensor
->
numel
();
++
i
)
{
EXPECT_EQ
(
expect
[
i
],
static_cast
<
float
>
(
actual
[
i
]));
}
}
paddle/fluid/operators/save_op.cc
浏览文件 @
a6edeb39
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <numeric>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
auto
filename
=
Attr
<
std
::
string
>
(
"file_path"
);
auto
overwrite
=
Attr
<
bool
>
(
"overwrite"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
if
(
FileExists
(
filename
)
&&
!
overwrite
)
{
PADDLE_THROW
(
"%s is existed, cannot save to it when overwrite=false"
,
...
...
@@ -96,7 +98,18 @@ class SaveOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
());
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
if
(
in_dtype
!=
out_dtype
)
{
auto
in_kernel_type
=
framework
::
OpKernelType
(
in_dtype
,
place
);
auto
out_kernel_type
=
framework
::
OpKernelType
(
out_dtype
,
place
);
framework
::
LoDTensor
out
;
framework
::
TransDataType
(
in_kernel_type
,
out_kernel_type
,
tensor
,
&
out
);
framework
::
SerializeToStream
(
fout
,
out
,
dev_ctx
);
}
else
{
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
}
}
};
...
...
@@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk.
"(boolean, default true)"
"Overwrite the output file if exist"
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"save_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string)"
"The
\"
file_path
\"
where the variable will be saved."
)
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
a6edeb39
...
...
@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
}
};
template
<
typename
AttrType
>
class
ScaleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
ScaleOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -47,9 +46,9 @@ Scale operator
$$Out = scale*X$$
)DOC"
);
AddAttr
<
AttrType
>
(
"scale"
,
"(float, default 1.0)"
"The scaling factor of the scale operator."
)
AddAttr
<
float
>
(
"scale"
,
"(float, default 1.0)"
"The scaling factor of the scale operator."
)
.
SetDefault
(
1.0
);
}
};
...
...
@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
<
float
>
,
ops
::
ScaleGradMaker
);
REGISTER_OPERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
,
ops
::
ScaleGradMaker
);
REGISTER_OP_CPU_KERNEL
(
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
...
...
paddle/fluid/operators/send_recv_op_test.cc
浏览文件 @
a6edeb39
...
...
@@ -113,15 +113,15 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
op
->
SetAttrMap
(
attrs
);
}
void
StartServerNet
(
bool
is_sparse
)
{
void
StartServerNet
(
bool
is_sparse
,
std
::
atomic
<
bool
>
*
initialized
)
{
f
::
Scope
scope
;
p
::
CPUPlace
place
;
VLOG
(
4
)
<<
"before init tensor"
;
if
(
is_sparse
)
{
InitSelectedRowsInScope
(
place
,
&
scope
);
}
else
{
InitTensorsInScope
(
place
,
&
scope
);
}
// sub program run in listen_and_serv_op, for simple test we use sum
f
::
ProgramDesc
program
;
const
auto
&
root_block
=
program
.
Block
(
0
);
...
...
@@ -129,7 +129,6 @@ void StartServerNet(bool is_sparse) {
auto
*
prefetch_block
=
program
.
AppendBlock
(
root_block
);
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp
(
"sum"
,
{{
"X"
,
{
"x0"
,
"x1"
}}},
{{
"Out"
,
{
"Out"
}}},
{},
optimize_block
);
f
::
AttributeMap
attrs
;
attrs
.
insert
({
"endpoint"
,
std
::
string
(
"127.0.0.1:0"
)});
attrs
.
insert
({
"Fanin"
,
1
});
...
...
@@ -139,15 +138,22 @@ void StartServerNet(bool is_sparse) {
attrs
.
insert
({
"PrefetchBlock"
,
prefetch_block
});
attrs
.
insert
({
"grad_to_block_id"
,
std
::
vector
<
std
::
string
>
({
""
})});
attrs
.
insert
({
"sync_mode"
,
true
});
VLOG
(
4
)
<<
"before init op"
;
listen_and_serv_op
=
f
::
OpRegistry
::
CreateOp
(
"listen_and_serv"
,
{{
"X"
,
{
"x1"
}}},
{},
attrs
);
*
initialized
=
true
;
listen_and_serv_op
->
Run
(
scope
,
place
);
LOG
(
INFO
)
<<
"server exit"
;
}
TEST
(
SendRecvOp
,
CPUDense
)
{
std
::
thread
server_thread
(
StartServerNet
,
false
);
sleep
(
5
);
// wait server to start
std
::
atomic
<
bool
>
initialized
{
false
};
std
::
thread
server_thread
(
StartServerNet
,
false
,
&
initialized
);
while
(
!
initialized
)
{
}
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
())
->
WaitServerReady
();
// local net
f
::
Scope
scope
;
p
::
CPUPlace
place
;
...
...
@@ -156,9 +162,11 @@ TEST(SendRecvOp, CPUDense) {
scope
.
Var
(
"RPC_CLIENT_VAR"
);
f
::
AttributeMap
attrs
;
selected_port
=
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
())
->
GetSelectedPort
();
auto
*
listen_and_serv_op_ptr
=
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
());
ASSERT_TRUE
(
listen_and_serv_op_ptr
!=
nullptr
);
selected_port
=
listen_and_serv_op_ptr
->
GetSelectedPort
();
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
...
...
@@ -181,11 +189,21 @@ TEST(SendRecvOp, CPUDense) {
listen_and_serv_op
->
Stop
();
server_thread
.
join
();
listen_and_serv_op
.
reset
(
nullptr
);
paddle
::
operators
::
ListenAndServOp
::
ResetPort
();
}
TEST
(
SendRecvOp
,
CPUSparse
)
{
std
::
thread
server_thread
(
StartServerNet
,
true
);
sleep
(
3
);
// wait server to start
std
::
atomic
<
bool
>
initialized
;
initialized
=
false
;
std
::
thread
server_thread
(
StartServerNet
,
true
,
&
initialized
);
while
(
!
initialized
)
{
}
auto
*
listen_and_serv_op_ptr
=
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
());
ASSERT_TRUE
(
listen_and_serv_op_ptr
!=
nullptr
);
listen_and_serv_op_ptr
->
WaitServerReady
();
// local net
f
::
Scope
scope
;
p
::
CPUPlace
place
;
...
...
@@ -193,9 +211,7 @@ TEST(SendRecvOp, CPUSparse) {
InitSelectedRowsInScope
(
place
,
&
scope
);
scope
.
Var
(
"RPC_CLIENT_VAR"
);
f
::
AttributeMap
attrs
;
selected_port
=
static_cast
<
paddle
::
operators
::
ListenAndServOp
*>
(
listen_and_serv_op
.
get
())
->
GetSelectedPort
();
selected_port
=
listen_and_serv_op_ptr
->
GetSelectedPort
();
std
::
string
endpoint
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
selected_port
);
attrs
.
insert
({
"endpoints"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
attrs
.
insert
({
"epmap"
,
std
::
vector
<
std
::
string
>
({
endpoint
})});
...
...
@@ -226,4 +242,5 @@ TEST(SendRecvOp, CPUSparse) {
listen_and_serv_op
->
Stop
();
server_thread
.
join
();
listen_and_serv_op
.
reset
();
paddle
::
operators
::
ListenAndServOp
::
ResetPort
();
}
paddle/fluid/operators/sgd_op.cc
浏览文件 @
a6edeb39
...
...
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
}
};
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
};
class
SGDOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
SGDOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
sgd
,
ops
::
SGDOp
,
ops
::
SGDOpMaker
);
REGISTER_OPERATOR
(
sgd
,
ops
::
SGDOp
,
ops
::
SGDOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SGDOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
sgd
,
ops
::
SGDOpKernel
<
float
>
,
ops
::
SGDOpKernel
<
double
>
);
paddle/fluid/operators/softmax_op.cc
浏览文件 @
a6edeb39
...
...
@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
softmax_grad
,
ops
::
SoftmaxOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
softmax_grad
,
ops
::
SoftmaxGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
SoftmaxGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SoftmaxGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/softmax_op.cu.cc
浏览文件 @
a6edeb39
...
...
@@ -19,6 +19,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
SoftmaxKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
SoftmaxKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
softmax_grad
,
ops
::
SoftmaxGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
softmax_grad
,
ops
::
SoftmaxGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
SoftmaxGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/top_k_op.cc
浏览文件 @
a6edeb39
...
...
@@ -75,4 +75,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
top_k
,
ops
::
TopkOp
,
ops
::
TopkOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
top_k
,
ops
::
TopkKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
TopkKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
TopkKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/fluid/operators/top_k_op.cu
浏览文件 @
a6edeb39
...
...
@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
top_k
,
paddle
::
operators
::
TopkOpCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
top_k
,
paddle
::
operators
::
TopkOpCUDAKernel
<
float
>
,
paddle
::
operators
::
TopkOpCUDAKernel
<
double
>
);
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
a6edeb39
...
...
@@ -116,11 +116,31 @@ uniform distribution.
.
SetDefault
(
framework
::
proto
::
VarType
::
FP32
);
}
};
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
if
(
block
->
FindRecursiveOrCreateVar
(
out_var_name
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_WITHOUT_GRADIENT
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
);
REGISTER_OPERATOR
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
UniformRandomOpVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
CPUUniformRandomKernel
<
float
>
,
paddle
::
operators
::
CPUUniformRandomKernel
<
double
>
);
...
...
paddle/fluid/operators/warpctc_op.h
浏览文件 @
a6edeb39
...
...
@@ -162,7 +162,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
static_cast
<
int64_t
>
(
sequence_width
)});
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
warpctc_logits
,
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
false
);
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
...
...
@@ -217,7 +217,7 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
logits_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
norm_by_times
=
ctx
.
Attr
<
bool
>
(
"norm_by_times"
);
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits_grad
,
ctx
.
template
device_context
<
DeviceContext
>(),
logits_grad
,
*
warpctc_grad
,
norm_by_times
);
const
T
*
loss_grad_data
=
loss_grad
->
data
<
T
>
();
...
...
paddle/fluid/platform/cuda_device_function.h
0 → 100644
浏览文件 @
a6edeb39
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace
paddle
{
namespace
platform
{
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_sync
(
unsigned
,
T
val
,
int
src_line
,
int
width
)
{
return
__shfl
(
val
,
src_line
,
width
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const
int
warpSize
=
32
;
__shared__
T
shm
[
warpSize
];
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
__syncthreads
();
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
platform
::
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
a6edeb39
...
...
@@ -66,22 +66,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
mask
,
T
val
,
int
delta
)
{
return
__shfl_down
(
mask
,
val
,
delta
);
}
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/profiler.h
浏览文件 @
a6edeb39
...
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.pb.h"
namespace
paddle
{
namespace
platform
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
a6edeb39
...
...
@@ -502,11 +502,11 @@ All parameter, weight, gradient are variables in Paddle.
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
allow_op_delay
,
bool
customize_loss_grad
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
bcast_vars
,
main_program
,
loss_var_name
,
scope
,
local_scopes
,
allow_op_delay
,
customize_loss_grad
);
bool
allow_op_delay
,
bool
use_default_grad_scale
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
bcast_vars
,
main_program
,
loss_var_name
,
scope
,
local_scopes
,
allow_op_delay
,
use_default_grad_scale
);
})
.
def
(
"bcast_params"
,
&
ParallelExecutor
::
BCastParamsToGPUs
)
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
a6edeb39
...
...
@@ -107,7 +107,7 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) {
return
self
.
data
<
T
>
()[
offset
];
}
else
{
std
::
shared_ptr
<
framework
::
Tensor
>
dst
(
new
framework
::
Tensor
);
framework
::
TensorCopy
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
framework
::
TensorCopy
Sync
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
return
dst
->
data
<
T
>
()[
offset
];
}
}
...
...
@@ -117,9 +117,9 @@ template <typename T>
void
TensorSetElement
(
framework
::
Tensor
*
self
,
size_t
offset
,
T
elem
)
{
if
(
platform
::
is_gpu_place
(
self
->
place
()))
{
std
::
shared_ptr
<
framework
::
Tensor
>
dst
(
new
framework
::
Tensor
);
framework
::
TensorCopy
(
*
self
,
platform
::
CPUPlace
(),
dst
.
get
());
framework
::
TensorCopy
Sync
(
*
self
,
platform
::
CPUPlace
(),
dst
.
get
());
dst
->
data
<
T
>
()[
offset
]
=
elem
;
framework
::
TensorCopy
(
*
dst
.
get
(),
self
->
place
(),
self
);
framework
::
TensorCopy
Sync
(
*
dst
.
get
(),
self
->
place
(),
self
);
}
else
if
(
platform
::
is_cpu_place
(
self
->
place
()))
{
self
->
data
<
T
>
()[
offset
]
=
elem
;
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
a6edeb39
...
...
@@ -40,6 +40,7 @@ function print_usage() {
${
BLUE
}
capi
${
NONE
}
: generate paddle CAPI package
${
BLUE
}
fluid_inference_lib
${
NONE
}
: deploy fluid inference library
${
BLUE
}
check_style
${
NONE
}
: run code style check
${
BLUE
}
cicheck
${
NONE
}
: run CI tasks
"
}
...
...
@@ -453,6 +454,8 @@ function gen_capi_package() {
}
function
gen_fluid_inference_lib
()
{
mkdir
-p
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
if
[
${
WITH_C_API
:-
OFF
}
==
"OFF"
]
;
then
cat
<<
EOF
========================================
...
...
@@ -503,6 +506,13 @@ function main() {
check_style
)
check_style
;;
cicheck
)
cmake_gen
${
PYTHON_ABI
:-
""
}
build
run_test
gen_capi_package
gen_fluid_inference_lib
;;
*
)
print_usage
exit
0
...
...
python/paddle/fluid/__init__.py
浏览文件 @
a6edeb39
...
...
@@ -21,8 +21,7 @@ import executor
from
executor
import
*
import
trainer
from
trainer
import
Trainer
from
trainer
import
Event
from
trainer
import
*
import
inferencer
from
inferencer
import
Inferencer
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
a6edeb39
...
...
@@ -137,8 +137,6 @@ def split_dense_variable(var_list,
class
DistributeTranspiler
:
def
transpile
(
self
,
optimize_ops
,
params_grads
,
trainer_id
,
program
=
None
,
pservers
=
"127.0.0.1:6174"
,
...
...
@@ -169,11 +167,6 @@ class DistributeTranspiler:
4. append ops that should run on current server instance.
5. add listen_and_serv op
:param optimize_ops: op list of optimization, should be the
return value of Optimizer.minimize
:type optimize_ops: list
:param params_grads: list of tuple(weight, gradient)
:type params_grads: list
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
...
...
@@ -194,7 +187,6 @@ class DistributeTranspiler:
program
=
default_main_program
()
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
optimize_ops
=
optimize_ops
self
.
sync_mode
=
sync_mode
# TODO(typhoonzero): currently trainer_id is fetched from cluster system
# like Kubernetes, we should port this to use etcd later when developing
...
...
@@ -202,6 +194,7 @@ class DistributeTranspiler:
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
optimize_ops
,
params_grads
=
self
.
_get_optimize_pass
()
# process lookup_table_op
# 1. check all lookup_table_op is distributed
...
...
@@ -408,11 +401,8 @@ class DistributeTranspiler:
# HACK: optimization global ops only used to scale beta1 and beta2
# replace it with dependency engine.
for
op
in
self
.
optimize_ops
:
if
op
.
type
==
"scale"
:
for
in_name
in
op
.
input_arg_names
:
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
in_name
.
startswith
(
"beta2_pow_acc"
):
global_ops
.
append
(
op
)
if
self
.
_is_adam_connected_op
(
op
):
global_ops
.
append
(
op
)
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
):
if
self
.
_is_opt_op
(
op
):
...
...
@@ -661,7 +651,7 @@ class DistributeTranspiler:
shape
=
trainer_out
.
shape
,
dtype
=
trainer_out
.
dtype
)
prefetch_block
.
append_op
(
type
=
LOOKUP_TABLE_TYPE
,
type
=
"lookup_sparse_table"
,
inputs
=
{
'Ids'
:
pserver_ids
,
"W"
:
table_var
},
outputs
=
{
"Out"
:
pserver_out
},
...
...
@@ -685,9 +675,14 @@ class DistributeTranspiler:
# STEP: create table optimize block
# create table param and grad var in pserver program
param_var
=
_clone_var
(
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
self
.
table_name
])
origin_param_var
=
self
.
origin_program
.
global_block
().
vars
[
self
.
table_name
]
param_var
=
pserver_program
.
global_block
().
create_var
(
name
=
origin_param_var
.
name
,
shape
=
origin_param_var
.
shape
,
dtype
=
origin_param_var
.
dtype
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
persistable
=
True
)
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
...
...
@@ -1142,3 +1137,32 @@ class DistributeTranspiler:
# we only need to append op for once
break
return
lr_ops
def
_get_optimize_pass
(
self
):
block
=
self
.
origin_program
.
global_block
()
opt_ops
=
[]
params_grads
=
[]
for
op
in
block
.
ops
:
if
self
.
_is_opt_op
(
op
):
opt_ops
.
append
(
op
)
params_grads
.
append
((
self
.
origin_program
.
global_block
().
var
(
op
.
input
(
"Param"
)[
0
]),
self
.
origin_program
.
global_block
().
var
(
op
.
input
(
"Grad"
)[
0
])))
elif
self
.
_is_adam_connected_op
(
op
):
opt_ops
.
append
(
op
)
else
:
pass
return
opt_ops
,
params_grads
def
_is_adam_connected_op
(
self
,
op
):
"""
A hack function to determinate whether the input operator
is connected to optimize operator.
"""
if
op
.
type
==
"scale"
:
for
in_name
in
op
.
input_arg_names
:
if
in_name
.
startswith
(
"beta1_pow_acc"
)
or
\
in_name
.
startswith
(
"beta2_pow_acc"
):
return
True
return
False
python/paddle/fluid/layers/io.py
浏览文件 @
a6edeb39
...
...
@@ -50,8 +50,6 @@ def data(name,
dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
main_program(Program): Name of the main program that calls this
startup_program(Program): Name of the startup program
stop_gradient(bool): A boolean that mentions whether gradient should flow.
Returns:
...
...
@@ -74,13 +72,15 @@ def data(name,
if
append_batch_size
:
shape
=
[
-
1
]
+
shape
# append batch size as -1
return
helper
.
create_global_variable
(
data_var
=
helper
.
create_global_variable
(
name
=
name
,
shape
=
shape
,
dtype
=
dtype
,
type
=
type
,
stop_gradient
=
stop_gradient
,
lod_level
=
lod_level
)
data_var
.
is_data
=
True
return
data_var
class
BlockGuardServ
(
BlockGuard
):
...
...
@@ -168,7 +168,9 @@ class ListenAndServ(object):
'endpoint'
:
self
.
endpoint
,
'Fanin'
:
self
.
fan_in
,
'OptimizeBlock'
:
current_block
,
'PrefetchBlock'
:
empty_block
'PrefetchBlock'
:
empty_block
,
'sync_mode'
:
True
,
# did not support async now in layers
'grad_to_block_id'
:
[
""
]
})
...
...
python/paddle/fluid/layers/math_op_patch.py
浏览文件 @
a6edeb39
...
...
@@ -169,7 +169,9 @@ def monkey_patch_variable():
# a*b == b*a. Do not need to reverse explicitly
(
"__rmul__"
,
"elementwise_mul"
,
False
),
(
"__div__"
,
"elementwise_div"
,
False
),
(
"__truediv__"
,
"elementwise_div"
,
False
),
(
"__rdiv__"
,
"elementwise_div"
,
True
),
(
"__rtruediv__"
,
"elementwise_div"
,
True
),
(
"__pow__"
,
"elementwise_pow"
,
False
),
(
"__rpow__"
,
"elementwise_pow"
,
True
),
# for logical compare
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
a6edeb39
...
...
@@ -1496,6 +1496,7 @@ def batch_norm(input,
bias_attr
=
None
,
data_layout
=
'NCHW'
,
in_place
=
False
,
use_mkldnn
=
False
,
name
=
None
,
moving_mean_name
=
None
,
moving_variance_name
=
None
,
...
...
@@ -1574,9 +1575,12 @@ def batch_norm(input,
"SavedMean"
:
saved_mean
,
"SavedVariance"
:
saved_variance
},
attrs
=
{
"momentum"
:
momentum
,
"epsilon"
:
epsilon
,
"is_test"
:
is_test
})
attrs
=
{
"momentum"
:
momentum
,
"epsilon"
:
epsilon
,
"is_test"
:
is_test
,
"use_mkldnn"
:
use_mkldnn
})
return
helper
.
append_activation
(
batch_norm_out
)
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a6edeb39
...
...
@@ -28,7 +28,8 @@ from contextlib import contextmanager
__all__
=
[
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'SGDOptimizer'
,
'MomentumOptimizer'
,
'AdagradOptimizer'
,
'AdamOptimizer'
,
'AdamaxOptimizer'
,
'DecayedAdagradOptimizer'
,
'Adadelta'
,
'ModelAverage'
'AdamaxOptimizer'
,
'DecayedAdagradOptimizer'
,
'Adadelta'
,
'ModelAverage'
,
'Optimizer'
]
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
a6edeb39
...
...
@@ -30,7 +30,7 @@ class ParallelExecutor(object):
num_threads
=
None
,
allow_op_delay
=
False
,
share_vars_from
=
None
,
customize_loss_grad
=
Fals
e
):
use_default_grad_scale
=
Tru
e
):
"""
ParallelExecutor can run program in parallel.
...
...
@@ -46,6 +46,11 @@ class ParallelExecutor(object):
improve performance in some cases, defalut False.
share_vars_from(ParallelExecutor, default None): If provied,
it will share variables from the specified ParallelExecutor.
use_default_grad_scale(bool, default True): If set True, a default
scale value equal to `1./device_count` would be multiplied to
gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed
to the network.
Returns:
A ParallelExecutor object.
...
...
@@ -124,7 +129,7 @@ class ParallelExecutor(object):
scope
,
local_scopes
,
allow_op_delay
,
customize_loss_grad
)
use_default_grad_scale
)
self
.
scope
=
scope
def
run
(
self
,
fetch_list
,
feed
=
None
,
feed_dict
=
None
):
...
...
python/paddle/fluid/tests/book/test_fit_a_line.py
浏览文件 @
a6edeb39
...
...
@@ -80,12 +80,7 @@ def train(use_cuda, save_dirname, is_local):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_image_classification.py
浏览文件 @
a6edeb39
...
...
@@ -189,12 +189,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_label_semantic_roles.py
浏览文件 @
a6edeb39
...
...
@@ -259,12 +259,7 @@ def train(use_cuda, save_dirname=None, is_local=True):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_machine_translation.py
浏览文件 @
a6edeb39
...
...
@@ -231,12 +231,7 @@ def train_main(use_cuda, is_sparse, is_local=True):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_recognize_digits.py
浏览文件 @
a6edeb39
...
...
@@ -162,12 +162,7 @@ def train(nn_type,
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_recommender_system.py
浏览文件 @
a6edeb39
...
...
@@ -261,12 +261,7 @@ def train(use_cuda, save_dirname, is_local=True):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_understand_sentiment.py
浏览文件 @
a6edeb39
...
...
@@ -213,12 +213,7 @@ def train(word_dict,
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/test_word2vec.py
浏览文件 @
a6edeb39
...
...
@@ -145,12 +145,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_INIT_TRAINER_ID"
))
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
t
=
fluid
.
DistributeTranspiler
()
t
.
transpile
(
optimize_ops
,
params_grads
,
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup
=
t
.
get_startup_program
(
current_endpoint
,
...
...
python/paddle/fluid/tests/book/word2vec/no_test_word2vec_new_api.py
0 → 100644
浏览文件 @
a6edeb39
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
import
math
import
sys
from
functools
import
partial
PASS_NUM
=
100
EMBED_SIZE
=
32
HIDDEN_SIZE
=
256
N
=
5
BATCH_SIZE
=
32
def
create_random_lodtensor
(
lod
,
place
,
low
,
high
):
# The range of data elements is [low, high]
data
=
np
.
random
.
random_integers
(
low
,
high
,
[
lod
[
-
1
],
1
]).
astype
(
"int64"
)
res
=
fluid
.
LoDTensor
()
res
.
set
(
data
,
place
)
res
.
set_lod
([
lod
])
return
res
word_dict
=
paddle
.
dataset
.
imikolov
.
build_dict
()
dict_size
=
len
(
word_dict
)
def
inference_network
(
is_sparse
):
first_word
=
fluid
.
layers
.
data
(
name
=
'firstw'
,
shape
=
[
1
],
dtype
=
'int64'
)
second_word
=
fluid
.
layers
.
data
(
name
=
'secondw'
,
shape
=
[
1
],
dtype
=
'int64'
)
third_word
=
fluid
.
layers
.
data
(
name
=
'thirdw'
,
shape
=
[
1
],
dtype
=
'int64'
)
forth_word
=
fluid
.
layers
.
data
(
name
=
'forthw'
,
shape
=
[
1
],
dtype
=
'int64'
)
embed_first
=
fluid
.
layers
.
embedding
(
input
=
first_word
,
size
=
[
dict_size
,
EMBED_SIZE
],
dtype
=
'float32'
,
is_sparse
=
is_sparse
,
param_attr
=
'shared_w'
)
embed_second
=
fluid
.
layers
.
embedding
(
input
=
second_word
,
size
=
[
dict_size
,
EMBED_SIZE
],
dtype
=
'float32'
,
is_sparse
=
is_sparse
,
param_attr
=
'shared_w'
)
embed_third
=
fluid
.
layers
.
embedding
(
input
=
third_word
,
size
=
[
dict_size
,
EMBED_SIZE
],
dtype
=
'float32'
,
is_sparse
=
is_sparse
,
param_attr
=
'shared_w'
)
embed_forth
=
fluid
.
layers
.
embedding
(
input
=
forth_word
,
size
=
[
dict_size
,
EMBED_SIZE
],
dtype
=
'float32'
,
is_sparse
=
is_sparse
,
param_attr
=
'shared_w'
)
concat_embed
=
fluid
.
layers
.
concat
(
input
=
[
embed_first
,
embed_second
,
embed_third
,
embed_forth
],
axis
=
1
)
hidden1
=
fluid
.
layers
.
fc
(
input
=
concat_embed
,
size
=
HIDDEN_SIZE
,
act
=
'sigmoid'
)
predict_word
=
fluid
.
layers
.
fc
(
input
=
hidden1
,
size
=
dict_size
,
act
=
'softmax'
)
return
predict_word
def
train_network
(
is_sparse
):
next_word
=
fluid
.
layers
.
data
(
name
=
'nextw'
,
shape
=
[
1
],
dtype
=
'int64'
)
predict_word
=
inference_network
(
is_sparse
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict_word
,
label
=
next_word
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
return
avg_cost
def
train
(
use_cuda
,
is_sparse
,
save_path
):
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
imikolov
.
train
(
word_dict
,
N
),
BATCH_SIZE
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
def
event_handler
(
event
):
print
type
(
event
)
if
isinstance
(
event
,
fluid
.
EndEpochEvent
):
avg_cost
=
trainer
.
test
(
reader
=
paddle
.
dataset
.
imikolov
.
test
(
word_dict
,
N
))
if
avg_cost
<
5.0
:
trainer
.
params
.
save
(
save_path
)
return
if
math
.
isnan
(
avg_cost
):
sys
.
exit
(
"got NaN loss, training failed."
)
trainer
=
fluid
.
Trainer
(
partial
(
train_network
,
is_sparse
),
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
),
place
=
place
)
trainer
.
train
(
reader
=
train_reader
,
num_epochs
=
100
,
event_handler
=
event_handler
)
def
infer
(
use_cuda
,
save_path
):
params
=
fluid
.
Params
(
save_path
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
inferencer
=
fluid
.
Inferencer
(
inference_network
,
params
,
place
=
place
)
lod
=
[
0
,
1
]
first_word
=
create_random_lodtensor
(
lod
,
place
,
low
=
0
,
high
=
dict_size
-
1
)
second_word
=
create_random_lodtensor
(
lod
,
place
,
low
=
0
,
high
=
dict_size
-
1
)
third_word
=
create_random_lodtensor
(
lod
,
place
,
low
=
0
,
high
=
dict_size
-
1
)
fourth_word
=
create_random_lodtensor
(
lod
,
place
,
low
=
0
,
high
=
dict_size
-
1
)
result
=
inferencer
.
infer
({
'firstw'
:
first_word
,
'secondw'
:
second_word
,
'thirdw'
:
third_word
,
'forthw'
:
fourth_word
})
print
(
result
)
def
main
(
use_cuda
,
is_sparse
):
if
use_cuda
and
not
fluid
.
core
.
is_compiled_with_cuda
():
return
save_path
=
"word2vec.inference.model"
train
(
use_cuda
,
is_sparse
,
save_path
)
infer
(
use_cuda
,
save_path
)
if
__name__
==
'__main__'
:
for
use_cuda
in
(
False
,
True
):
for
is_sparse
in
(
False
,
True
):
main
(
use_cuda
=
use_cuda
,
is_sparse
=
is_sparse
)
python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py
0 → 100644
浏览文件 @
a6edeb39
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
from
paddle.fluid.framework
import
grad_var_name
from
test_batch_norm_op
import
TestBatchNormOpInference
,
TestBatchNormOpTraining
,
_reference_training
,
_reference_grad
class
TestMKLDNNBatchNormOpTraining
(
TestBatchNormOpTraining
):
def
init_kernel_type
(
self
):
self
.
use_mkldnn
=
True
self
.
data_formats
=
[
"NCHW"
]
def
ref_forward_backward
(
self
,
x
,
y_grad
,
scale
,
bias
,
mean
,
variance
,
epsilon
,
momentum
,
shape
,
data_layout
):
# run forward
y
,
saved_mean
,
saved_variance
=
_reference_training
(
x
,
scale
,
bias
,
epsilon
,
data_layout
)
mean_out
=
saved_mean
*
(
1.
-
momentum
)
+
momentum
*
mean
variance_out
=
saved_variance
*
(
1.
-
momentum
)
+
momentum
*
variance
# run backward
x_grad
,
scale_grad
,
bias_grad
=
_reference_grad
(
x
,
y_grad
,
scale
,
saved_mean
,
saved_variance
,
epsilon
,
data_layout
)
return
y
,
mean_out
,
variance_out
,
saved_mean
,
saved_variance
,
x_grad
,
scale_grad
,
bias_grad
class
TestMKLDNNBatchNormOpInference
(
TestBatchNormOpInference
):
def
init_kernel_type
(
self
):
self
.
use_mkldnn
=
True
def
test_check_output
(
self
):
place
=
core
.
CPUPlace
()
data_format
=
"NCHW"
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
,
4
,
5
])
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
浏览文件 @
a6edeb39
...
...
@@ -158,6 +158,8 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
class
TestBatchNormOpInference
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dtype
=
np
.
float32
self
.
use_mkldnn
=
False
self
.
init_kernel_type
()
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
...
...
@@ -230,6 +232,7 @@ class TestBatchNormOpInference(unittest.TestCase):
# attrs
is_test
=
True
,
data_layout
=
data_layout
,
use_mkldnn
=
self
.
use_mkldnn
,
epsilon
=
epsilon
)
batch_norm_op
.
run
(
scope
,
place
)
...
...
@@ -254,10 +257,15 @@ class TestBatchNormOpInference(unittest.TestCase):
[
2
,
3
,
4
,
5
])
self
.
check_with_place
(
place
,
data_format
,
self
.
dtype
,
[
2
,
3
])
def
init_kernel_type
(
self
):
pass
class
TestFP16BatchNormOpInference
(
TestBatchNormOpInference
):
def
setUp
(
self
):
self
.
dtype
=
np
.
float16
self
.
use_mkldnn
=
False
self
.
init_kernel_type
()
def
test_check_output
(
self
):
places
=
[]
...
...
@@ -274,9 +282,28 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class
TestBatchNormOpTraining
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_mkldnn
=
False
self
.
data_formats
=
[
"NCHW"
,
"NHWC"
]
self
.
init_kernel_type
()
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
)
def
ref_forward_backward
(
self
,
x
,
y_grad
,
scale
,
bias
,
mean
,
variance
,
epsilon
,
momentum
,
shape
,
data_layout
):
# run forward
y
,
saved_mean
,
var_ref
=
_reference_training
(
x
,
scale
,
bias
,
epsilon
,
data_layout
)
mean_out
=
saved_mean
*
(
1.
-
momentum
)
+
momentum
*
mean
variance_out
=
var_ref
*
(
1.
-
momentum
)
+
momentum
*
variance
saved_variance
=
1.
/
np
.
sqrt
(
var_ref
+
epsilon
)
# run backward
x_grad
,
scale_grad
,
bias_grad
=
_reference_grad
(
x
,
y_grad
,
scale
,
saved_mean
,
var_ref
,
epsilon
,
data_layout
)
return
y
,
mean_out
,
variance_out
,
saved_mean
,
saved_variance
,
x_grad
,
scale_grad
,
bias_grad
def
test_forward_backward
(
self
):
def
test_with_place
(
place
,
data_layout
,
shape
):
# attr
...
...
@@ -295,16 +322,11 @@ class TestBatchNormOpTraining(unittest.TestCase):
mean
=
np
.
zeros
(
scale_shape
).
astype
(
np
.
float32
)
variance
=
np
.
ones
(
scale_shape
).
astype
(
np
.
float32
)
# run forward
y
,
saved_mean
,
var_ref
=
_reference_training
(
x
,
scale
,
bias
,
epsilon
,
data_layout
)
mean_out
=
saved_mean
*
(
1.
-
momentum
)
+
momentum
*
mean
variance_out
=
var_ref
*
(
1.
-
momentum
)
+
momentum
*
variance
saved_variance
=
1.
/
np
.
sqrt
(
var_ref
+
epsilon
)
# run backward
y_grad
=
np
.
random
.
random_sample
(
shape
).
astype
(
np
.
float32
)
x_grad
,
scale_grad
,
bias_grad
=
_reference_grad
(
x
,
y_grad
,
scale
,
saved_mean
,
var_ref
,
epsilon
,
data_layout
)
y
,
mean_out
,
variance_out
,
saved_mean
,
saved_variance
,
x_grad
,
scale_grad
,
bias_grad
=
self
.
ref_forward_backward
(
x
,
y_grad
,
scale
,
bias
,
mean
,
variance
,
epsilon
,
momentum
,
shape
,
data_layout
)
var_dict
=
locals
()
var_dict
[
'y@GRAD'
]
=
y_grad
...
...
@@ -344,7 +366,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
"momentum"
:
momentum
,
"epsilon"
:
epsilon
,
"is_test"
:
False
,
"data_layout"
:
data_layout
"data_layout"
:
data_layout
,
"use_mkldnn"
:
self
.
use_mkldnn
})
block
.
create_var
(
name
=
'y@GRAD'
,
dtype
=
'float32'
,
shape
=
y
.
shape
)
...
...
@@ -387,13 +410,17 @@ class TestBatchNormOpTraining(unittest.TestCase):
print
"op test forward passed: "
,
str
(
place
),
data_layout
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"batch_norm"
):
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
for
data_format
in
[
"NCHW"
,
"NHWC"
]
:
for
data_format
in
self
.
data_formats
:
test_with_place
(
place
,
data_format
,
[
2
,
3
,
4
,
5
])
def
init_kernel_type
(
self
):
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dist_train.py
浏览文件 @
a6edeb39
...
...
@@ -34,7 +34,7 @@ class TestSendOp(unittest.TestCase):
p
.
start
()
time
.
sleep
(
10
)
with
open
(
"/tmp/paddle.
selected_port"
,
"r"
)
as
fn
:
with
open
(
"/tmp/paddle.
%d.selected_port"
%
p
.
pid
,
"r"
)
as
fn
:
selected_port
=
int
(
fn
.
readlines
()[
0
])
self
.
init_client
(
place
,
selected_port
)
...
...
python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
0 → 100644
浏览文件 @
a6edeb39
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
def
output_hist
(
out
):
hist
,
_
=
np
.
histogram
(
out
,
range
=
(
-
5
,
10
))
hist
=
hist
.
astype
(
"float32"
)
hist
/=
float
(
out
.
size
)
prob
=
0.1
*
np
.
ones
((
10
))
return
hist
,
prob
class
TestLookupSpraseTable
(
OpTest
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Id Variable
ids
=
scope
.
var
(
"Ids"
).
get_tensor
()
ids_array
=
np
.
array
([
0
,
2
,
3
,
5
,
100
]).
astype
(
"int64"
)
ids
.
set
(
ids_array
,
place
)
# create and initialize W Variable
rows
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
row_numel
=
10000
w_selected_rows
=
scope
.
var
(
'W'
).
get_selected_rows
()
w_selected_rows
.
set_height
(
len
(
rows
))
w_selected_rows
.
set_rows
(
rows
)
w_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
for
i
in
range
(
len
(
rows
)):
w_array
[
i
]
*=
i
w_tensor
=
w_selected_rows
.
get_tensor
()
w_tensor
.
set
(
w_array
,
place
)
# create Out Variable
out_tensor
=
scope
.
var
(
'Out'
).
get_tensor
()
# create and run lookup_table operator
lookup_table
=
Operator
(
"lookup_sparse_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
,
min
=-
5.0
,
max
=
10.0
,
seed
=
10
)
lookup_table
.
run
(
scope
,
place
)
# get result from Out
result_array
=
np
.
array
(
out_tensor
)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for
idx
,
row
in
enumerate
(
ids_array
[:
-
2
]):
assert
(
row
==
result_array
[
idx
]).
all
()
# check the random value
hist
,
prob
=
output_hist
(
result_array
[
-
1
])
self
.
assertTrue
(
np
.
allclose
(
hist
,
prob
,
rtol
=
0
,
atol
=
0.01
),
"hist: "
+
str
(
hist
))
def
test_w_is_selected_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
# currently only support CPU
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/trainer.py
浏览文件 @
a6edeb39
...
...
@@ -12,44 +12,200 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
core
import
framework
import
executor
import
data_feeder
import
contextlib
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import
optimizer
as
opt_module
__all__
=
[
'Event'
,
'Trainer'
,
'BeginEpochEvent'
,
'EndEpochEvent'
,
'BeginStepEvent'
,
'EndStepEvent'
,
]
class
Event
(
object
):
BEGIN_EPOCH
=
0
END_EPOCH
=
1
BEGIN_STEP
=
2
END_STEP
=
3
class
BeginEpochEvent
(
object
):
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
class
EndEpochEvent
(
object
):
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
def
__init__
(
self
):
self
.
step
=
0
self
.
epoch
=
0
self
.
type
=
Event
.
BEGIN_EPOCH
class
BeginStepEvent
(
object
):
def
__init__
(
self
,
epoch_id
,
step_id
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
class
EndStepEvent
(
object
):
def
__init__
(
self
,
epoch_id
,
step_id
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
class
Trainer
(
object
):
"""
Args:
network_func(callable): A function which will return loss. The loss must be a scaler.
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
params:
place: The device place of this trainer.
"""
def
__init__
(
self
,
network_func
,
optimizer
,
params
=
None
,
place
=
None
):
# 1. we need to generate a framework.Program by calling
# network_func. Reference: fluid.program_guard in
# test_word2vec.py
self
.
scope
=
self
.
_get_scope_from_params
(
params
)
self
.
startup_program
=
framework
.
Program
()
self
.
train_program
=
framework
.
Program
()
with
framework
.
program_guard
(
self
.
train_program
,
self
.
startup_program
):
loss
=
network_func
()
if
not
isinstance
(
optimizer
,
opt_module
.
Optimizer
):
raise
TypeError
(
"The optimizer should be an instance of Optimizer"
)
optimizer
.
minimize
(
loss
)
self
.
place
=
Trainer
.
_check_and_get_place
(
place
)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
# Run startup program
if
params
is
None
:
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
,
scope
=
self
.
scope
)
# 3. call self.params.add_vars with the initialized scope, it
# will add the new vars of the initialized scope into
# self.params.
self
.
network_func
=
network_func
self
.
optimizer
=
optimizer
self
.
params
=
params
self
.
place
=
place
# TODO(yuyang): This depends on parameters implementation.
# TODO(helin): support distributed training
def
train
(
self
,
reader
,
num_epochs
,
event_handler
):
pass
def
train
(
self
,
num_epochs
,
event_handler
,
reader
=
None
,
parallel
=
False
,
feed_order
=
None
):
"""
Train the model.
Args:
num_epochs: The number of epoch. An epoch will process all data in reader
event_handler: The event handler. A function with type (ev:Event)->void
reader:
parallel: True if use multi-CPUs or multi-GPUs
feed_order: Feeding order of reader. None will following the defining
order in program
Returns:
"""
if
parallel
:
raise
NotImplementedError
(
"Parallel Executor version of trainer is not implemented"
)
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
def
test
(
self
,
reader
):
pass
def
_get_scope_from_params
(
self
,
params
):
"""
Get Scope from parameter object.
Args:
params(Parameter|None): The parameter object instance. Could be None.
Returns: New scope if params is None. Or params.scope()
NOTE: This method is WIP. Not fully implemented.
"""
if
params
is
None
:
return
core
.
Scope
()
# new scope when params is None
else
:
raise
NotImplementedError
(
"Not implemented right now."
)
@
staticmethod
def
_check_and_get_place
(
place
):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if
place
is
None
:
if
core
.
is_compiled_with_cuda
():
return
core
.
CUDAPlace
(
0
)
else
:
return
core
.
CPUPlace
()
else
:
if
not
isinstance
(
place
,
core
.
CUDAPlace
)
and
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
TypeError
(
"Place should be either CUDAPlace or CPUPlace"
)
return
place
@
contextlib
.
contextmanager
def
_prog_and_scope_guard
(
self
):
with
framework
.
program_guard
(
main_program
=
self
.
train_program
,
startup_program
=
self
.
startup_program
):
with
executor
.
scope_guard
(
self
.
scope
):
yield
def
_train_by_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
"""
Train by Executor and single device.
Args:
num_epochs:
event_handler:
reader:
feed_order:
Returns:
"""
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
if
feed_order
is
None
:
feed_var_list
=
[
var
for
var
in
self
.
train_program
.
global_block
(
).
vars
.
itervalues
()
if
hasattr
(
var
,
'is_data'
)
and
var
.
is_data
]
else
:
feed_var_list
=
[
self
.
train_program
.
global_block
().
var
(
var_name
)
for
var_name
in
feed_order
]
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
for
epoch_id
in
range
(
num_epochs
):
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
event_handler
(
BeginStepEvent
(
epoch_id
,
step_id
))
exe
.
run
(
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[])
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
))
event_handler
(
EndEpochEvent
(
epoch_id
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录