Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
394828b7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
394828b7
编写于
2月 14, 2018
作者:
K
Kavya Srinet
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into refine_pod
上级
4f23dbd5
118d950e
变更
57
隐藏空白更改
内联
并排
Showing
57 changed file
with
619 addition
and
653 deletion
+619
-653
cmake/external/boost.cmake
cmake/external/boost.cmake
+7
-0
doc/design/parallel_do.md
doc/design/parallel_do.md
+162
-0
paddle/fluid/framework/data_device_transform.cc
paddle/fluid/framework/data_device_transform.cc
+1
-1
paddle/fluid/framework/data_device_transform_test.cu
paddle/fluid/framework/data_device_transform_test.cu
+2
-2
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-2
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+6
-6
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+2
-2
paddle/fluid/framework/mixed_vector.h
paddle/fluid/framework/mixed_vector.h
+7
-5
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+1
-1
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+2
-2
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+191
-7
paddle/fluid/framework/tensor_util.cu
paddle/fluid/framework/tensor_util.cu
+1
-119
paddle/fluid/framework/tensor_util.cu
paddle/fluid/framework/tensor_util.cu
+1
-119
paddle/fluid/framework/tensor_util.h
paddle/fluid/framework/tensor_util.h
+29
-232
paddle/fluid/framework/tensor_util_test.cc
paddle/fluid/framework/tensor_util_test.cc
+34
-31
paddle/fluid/framework/tensor_util_test.cu
paddle/fluid/framework/tensor_util_test.cu
+4
-4
paddle/fluid/framework/threadpool.h
paddle/fluid/framework/threadpool.h
+1
-1
paddle/fluid/operators/array_operator.h
paddle/fluid/operators/array_operator.h
+1
-1
paddle/fluid/operators/array_to_lod_tensor_op.cc
paddle/fluid/operators/array_to_lod_tensor_op.cc
+2
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+2
-2
paddle/fluid/operators/assign_value_op.h
paddle/fluid/operators/assign_value_op.h
+1
-1
paddle/fluid/operators/beam_search_decode_op.h
paddle/fluid/operators/beam_search_decode_op.h
+2
-2
paddle/fluid/operators/detection_output_op.h
paddle/fluid/operators/detection_output_op.h
+17
-17
paddle/fluid/operators/expand_op.h
paddle/fluid/operators/expand_op.h
+2
-1
paddle/fluid/operators/feed_op.cc
paddle/fluid/operators/feed_op.cc
+1
-1
paddle/fluid/operators/fetch_op.cc
paddle/fluid/operators/fetch_op.cc
+1
-1
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+1
-1
paddle/fluid/operators/layer_norm_op.h
paddle/fluid/operators/layer_norm_op.h
+2
-2
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+1
-1
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+1
-1
paddle/fluid/operators/lod_reset_op.h
paddle/fluid/operators/lod_reset_op.h
+2
-2
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+3
-3
paddle/fluid/operators/math/context_project.h
paddle/fluid/operators/math/context_project.h
+4
-2
paddle/fluid/operators/math/im2col_test.cc
paddle/fluid/operators/math/im2col_test.cc
+7
-7
paddle/fluid/operators/math/math_function_test.cu
paddle/fluid/operators/math/math_function_test.cu
+18
-18
paddle/fluid/operators/math/selected_rows_functor_test.cu
paddle/fluid/operators/math/selected_rows_functor_test.cu
+4
-4
paddle/fluid/operators/math/sequence_padding.cu
paddle/fluid/operators/math/sequence_padding.cu
+2
-2
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+2
-2
paddle/fluid/operators/math/vol2col_test.cc
paddle/fluid/operators/math/vol2col_test.cc
+4
-4
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+4
-3
paddle/fluid/operators/mine_hard_examples_op.cc
paddle/fluid/operators/mine_hard_examples_op.cc
+2
-1
paddle/fluid/operators/multiplex_op.cu
paddle/fluid/operators/multiplex_op.cu
+2
-2
paddle/fluid/operators/nccl_op_test.cu.cc
paddle/fluid/operators/nccl_op_test.cu.cc
+1
-1
paddle/fluid/operators/parallel_do_op.cc
paddle/fluid/operators/parallel_do_op.cc
+3
-3
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+1
-1
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+4
-4
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
+1
-1
paddle/fluid/operators/reshape_op.h
paddle/fluid/operators/reshape_op.h
+2
-2
paddle/fluid/operators/sequence_reshape_op.h
paddle/fluid/operators/sequence_reshape_op.h
+2
-2
paddle/fluid/operators/sequence_slice_op.h
paddle/fluid/operators/sequence_slice_op.h
+8
-8
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+1
-1
paddle/fluid/operators/split_lod_tensor_op.cc
paddle/fluid/operators/split_lod_tensor_op.cc
+5
-4
paddle/fluid/operators/sum_op.h
paddle/fluid/operators/sum_op.h
+2
-2
paddle/fluid/operators/tensor_array_read_write_op.cc
paddle/fluid/operators/tensor_array_read_write_op.cc
+2
-2
paddle/fluid/operators/warpctc_op.h
paddle/fluid/operators/warpctc_op.h
+3
-2
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+3
-3
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+38
-0
未找到文件。
cmake/external/boost.cmake
浏览文件 @
394828b7
...
...
@@ -15,6 +15,13 @@
include
(
ExternalProject
)
set
(
BOOST_PROJECT
"extern_boost"
)
# To release PaddlePaddle as a pip package, we have to follow the
# manylinux1 standard, which features as old Linux kernels and
# compilers as possible and recommends CentOS 5. Indeed, the earliest
# CentOS version that works with NVIDIA CUDA is CentOS 6. And a new
# version of boost, say, 1.66.0, doesn't build on CentOS 6. We
# checked that the devtools package of CentOS 6 installs boost 1.41.0.
# So we use 1.41.0 here.
set
(
BOOST_VER
"1.41.0"
)
set
(
BOOST_TAR
"boost_1_41_0"
)
set
(
BOOST_URL
"http://paddlepaddledeps.s3-website-us-west-1.amazonaws.com/
${
BOOST_TAR
}
.tar.gz"
)
...
...
doc/design/parallel_do.md
0 → 100644
浏览文件 @
394828b7
# Design Doc: Parallel_Do in PaddlePaddle
In PaddlePaddle, we use parallel_do primitive to represent multithread data parallel processing.
## Design overview
The definition of a parallel_do op looks like the following
```
c++
AddInput
(
kInputs
,
"Inputs needed to be split onto different devices"
).
AsDuplicable
();
AddInput
(
kParameters
,
"Parameters are duplicated over different devices"
)
.
AsDuplicable
();
AddInput
(
kPlaces
,
"Devices used for parallel processing"
);
AddOutput
(
kOutputs
,
"Outputs needed to be merged from different devices"
).
AsDuplicable
();
AddOutput
(
kParallelScopes
,
"Scopes for all local variables in forward pass. One scope for each device"
);
AddAttr
<
framework
::
BlockDesc
*>
(
kParallelBlock
,
"List of operaters to be executed in parallel"
);
```
A vanilla implementation of parallel_do can be shown as the following (
`|`
means single thread and
`||||`
means multiple threads)
```
In the forward pass
| Split input onto different devices
| Copy parameter to onto different devices
|||| Compute forward pass in parallel
| Merge output from different devices
In the backward pass
| Split output@grad onto different devices
|||| Compute backward pass in parallel
| accumulate param@grad from different devices to the first device
| Merge input@grad from different devices
| Copy param@grad to the place of parallel_do_op
```
This implementation allows to write mixed device program like this
```
python
# get embedding feature on CPU
feature
=
some_cpu_only_op
(
data
)
gpu_places
=
get_place
(
use_gpu
=
True
)
# parallel processing on multiple GPUs
pd
=
ParallelDo
(
gpu_places
)
with
pd
.
do
():
read_input
(
feature
)
prediction
=
my_net
(
feature
)
write_output
(
prediction
)
prediction
=
pd
()
loss
=
cross_entropy
(
prediction
,
label
)
```
And the programDesc are like the following
```
# start_program will be run by executor(CPUPlace), all w1, w2 will be allocated on CPU
start_program
{
vars: w1, w2
ops: init(w1), init(w2)
}
main_program
{
block0 {
vars: data, places, w1, w2
ops: data, get_place, parallel_do(block1),
parallel_do_grad(block2),
sgd(w2, w2_grad),
sgd(w1, w1_grad)
}
block1 {
parent_block: 0
vars: data, h1, h2, loss
ops: fc, fc, softmax
}
block2 {
parent_block: 1
vars: data_grad, h1_grad, h2_grad, loss_gard, w1_grad, w2_grad
ops: softmax_grad,
fc_grad
fc_grad
}
}
```
## Proformance Imporvement
There are serial places we can make this parallel_do faster.
### forward: split input onto different devices
If the input of the parallel_do is independent from any prior opeartors, we can avoid this step by
prefetching the input onto different devices in a seperate background thread. And the python code
looks like this.
```
python
pd
=
ParallelDo
(
gpu_places
)
with
pd
.
do
():
feature
=
get_data_from_prefetch_queue
(
gpu_places
)
prediction
=
my_net
(
feature
)
write_output
(
activation
)
```
### forward: Copy parameter to onto different devices
We can avoid this step by making each device have a copy of the parameter. This requires:
1.
`fluid.default_start_up_program()`
to be run on all devices
1.
In the backward, allreduce param@grad at different devices, this requires
1.
`backward.py`
add
`allreduce`
operators at parallel_do_grad
1.
`allreduce`
operators need to be called in async mode to achieve maximum throughput
1.
apply gradients related op(i.e. cliping, normalization, decay, sgd) on different devices in parallel
By doing so, we also avoided "backward: accumulate param@grad from different devices to the first device".
And the ProgramDesc looks like the following
```
# w1, w2 will be allocated on all GPUs
start_program
{
block0 {
parallel_do(block1)
}
block1 {
parent_block: 0
vars: w1, w2
ops: init(w1), init(w2)
}
}
main_program
{
block0 {
vars: data, places, w1, w2
ops: data, get_place, parallel_do(block1),
parallel_do_grad(block2), # append_backward
parallel_do(block3) # append_optimization
}
block1 {
parent_block: 0
vars: data, h1, h2, loss
ops: fc, fc, softmax
}
block2 {
parent_block: 1
vars: data_grad, h1_grad, h2_grad, loss_gard, w1_grad, w2_grad
ops: softmax_grad,
fc_grad, allreduce(places, scopes, w1_grad),
fc_grad, allreduce(places, scopes, w2_grad)
}
block3 {
parent_block: 0
vars: lr
ops: sgd(w2, w2_grad),
sgd(w1, w1_grad)
}
}
```
paddle/fluid/framework/data_device_transform.cc
浏览文件 @
394828b7
...
...
@@ -37,7 +37,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
<<
" dst_place: "
<<
dst_place
;
auto
*
dev_ctx
=
GetDeviceContext
(
in
.
place
(),
dst_place
);
dev_ctx
->
Wait
();
Copy
(
in
,
dst_place
,
*
dev_ctx
,
out
);
Tensor
Copy
(
in
,
dst_place
,
*
dev_ctx
,
out
);
dev_ctx
->
Wait
();
}
...
...
paddle/fluid/framework/data_device_transform_test.cu
浏览文件 @
394828b7
...
...
@@ -157,8 +157,8 @@ TEST(Operator, CPUtoGPU) {
auto
dev_ctx
=
pool
.
Get
(
cuda_place
);
paddle
::
framework
::
Tensor
output_tensor
;
Copy
(
output2
->
Get
<
LoDTensor
>
(),
paddle
::
platform
::
CPUPlace
(),
*
dev_ctx
,
&
output_tensor
);
Tensor
Copy
(
output2
->
Get
<
LoDTensor
>
(),
paddle
::
platform
::
CPUPlace
(),
*
dev_ctx
,
&
output_tensor
);
dev_ctx
->
Wait
();
float
*
output2_ptr
=
output_tensor
.
data
<
float
>
();
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
394828b7
...
...
@@ -73,8 +73,10 @@ static void CheckTensorNANOrInf(const std::string& name,
tensor
.
type
().
hash_code
()
!=
typeid
(
double
).
hash_code
())
{
return
;
}
PADDLE_ENFORCE
(
!
framework
::
HasInf
(
tensor
),
"Tensor %s has Inf"
,
name
);
PADDLE_ENFORCE
(
!
framework
::
HasNAN
(
tensor
),
"Tensor %s has NAN"
,
name
);
PADDLE_ENFORCE
(
!
framework
::
TensorContainsInf
(
tensor
),
"Tensor %s contains Inf"
,
name
);
PADDLE_ENFORCE
(
!
framework
::
TensorContainsNAN
(
tensor
),
"Tensor %s contains NAN"
,
name
);
}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
...
...
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
394828b7
...
...
@@ -46,7 +46,7 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
if
(
!
platform
::
is_cpu_place
(
t
.
place
()))
{
LoDTensor
tt
;
framework
::
Copy
(
t
,
platform
::
CPUPlace
(),
&
tt
);
framework
::
Tensor
Copy
(
t
,
platform
::
CPUPlace
(),
&
tt
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
t
.
place
());
dev_ctx
.
Wait
();
...
...
@@ -255,7 +255,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
}
}
// the 3st field, Tensor
Serialize
ToStream
(
os
,
static_cast
<
Tensor
>
(
tensor
),
dev_ctx
);
Tensor
ToStream
(
os
,
static_cast
<
Tensor
>
(
tensor
),
dev_ctx
);
}
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
...
...
@@ -282,7 +282,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
}
}
// the 3st filed, Tensor
Deserialize
FromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
Tensor
FromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
std
::
vector
<
LoDTensor
>
LoDTensor
::
SplitLoDTensor
(
...
...
@@ -308,14 +308,14 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if
(
lod
().
empty
())
{
auto
src
=
Slice
(
begin
,
end
);
auto
&
dst_place
=
places
[
i
];
framework
::
Copy
(
src
,
dst_place
,
&
dst
);
framework
::
Tensor
Copy
(
src
,
dst_place
,
&
dst
);
}
else
{
auto
lod_and_offset
=
GetSubLoDAndAbsoluteOffset
(
lod
(),
begin
,
end
,
0
);
auto
&
offset
=
lod_and_offset
.
second
;
auto
src
=
Slice
(
offset
.
first
,
offset
.
second
);
auto
&
dst_place
=
places
[
i
];
framework
::
Copy
(
src
,
dst_place
,
&
dst
);
framework
::
Tensor
Copy
(
src
,
dst_place
,
&
dst
);
LoD
my_lod
;
for
(
auto
&
l
:
lod_and_offset
.
first
)
{
...
...
@@ -369,7 +369,7 @@ void LoDTensor::MergeLoDTensor(
for
(
auto
*
src
:
lod_tensors
)
{
int
end
=
begin
+
src
->
dims
()[
0
];
auto
dst
=
Slice
(
begin
,
end
);
framework
::
Copy
(
*
src
,
dst_place
,
&
dst
);
framework
::
Tensor
Copy
(
*
src
,
dst_place
,
&
dst
);
begin
=
end
;
}
}
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
394828b7
...
...
@@ -175,8 +175,8 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
for
(
size_t
ins
=
0
;
ins
<
num_instances
;
ins
++
)
{
for
(
size_t
elem
=
lod_level
[
ins
];
elem
<
lod_level
[
ins
+
1
];
elem
++
)
{
auto
slice
=
tensor
.
Slice
(
elem
,
elem
+
1
);
Copy
(
source
.
Slice
(
ins
,
ins
+
1
),
platform
::
CPUPlace
(),
platform
::
CPUDeviceContext
(),
&
slice
);
Tensor
Copy
(
source
.
Slice
(
ins
,
ins
+
1
),
platform
::
CPUPlace
(),
platform
::
CPUDeviceContext
(),
&
slice
);
}
}
return
tensor
;
...
...
paddle/fluid/framework/mixed_vector.h
浏览文件 @
394828b7
...
...
@@ -291,7 +291,7 @@ class Vector {
void
CopyToCPU
()
const
{
// COPY GPU Data To CPU
Copy
(
cuda_vec_
,
platform
::
CPUPlace
(),
&
cpu_vec_
);
Tensor
Copy
(
cuda_vec_
,
platform
::
CPUPlace
(),
&
cpu_vec_
);
WaitPlace
(
cuda_vec_
.
place
());
}
...
...
@@ -305,13 +305,14 @@ class Vector {
void
ImmutableCUDA
(
platform
::
Place
place
)
const
{
if
(
IsDirty
())
{
if
(
IsInCPU
())
{
Copy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
TensorCopy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
WaitPlace
(
place
);
UnsetFlag
(
kDirty
);
SetFlag
(
kDataInCUDA
);
}
else
if
(
IsInCUDA
()
&&
!
(
place
==
cuda_vec_
.
place
()))
{
framework
::
Tensor
tmp
;
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
Tensor
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
WaitPlace
(
cuda_vec_
.
place
());
cuda_vec_
.
ShareDataWith
(
tmp
);
// Still dirty
...
...
@@ -322,13 +323,14 @@ class Vector {
}
else
{
if
(
!
IsInCUDA
())
{
// Even data is not dirty. However, data is not in CUDA. Copy data.
Copy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
TensorCopy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
WaitPlace
(
place
);
SetFlag
(
kDataInCUDA
);
}
else
if
(
!
(
place
==
cuda_vec_
.
place
()))
{
framework
::
Tensor
tmp
;
WaitPlace
(
cuda_vec_
.
place
());
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
Tensor
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
WaitPlace
(
cuda_vec_
.
place
());
WaitPlace
(
place
);
cuda_vec_
.
ShareDataWith
(
tmp
);
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
394828b7
...
...
@@ -105,7 +105,7 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
}
}
Tensor
dst
=
out_tensor
.
Slice
(
dst_offset
,
dst_offset
+
ins_shape
[
0
]);
Copy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
Tensor
Copy
(
buffer_
[
i
][
j
],
platform
::
CPUPlace
(),
&
dst
);
dst_offset
+=
ins_shape
[
0
];
}
out_tensor
.
set_lod
(
batch_lod
);
...
...
paddle/fluid/framework/selected_rows.cc
浏览文件 @
394828b7
...
...
@@ -34,7 +34,7 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
height
),
sizeof
(
height
));
}
// the 4st field, Tensor data
Serialize
ToStream
(
os
,
selected_rows
.
value
(),
dev_ctx
);
Tensor
ToStream
(
os
,
selected_rows
.
value
(),
dev_ctx
);
}
void
DeserializeFromStream
(
std
::
istream
&
is
,
SelectedRows
*
selected_rows
,
...
...
@@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
selected_rows
->
set_height
(
height
);
}
// the 4st field, tensor which contains the data
Deserialize
FromStream
(
is
,
selected_rows
->
mutable_value
(),
dev_ctx
);
Tensor
FromStream
(
is
,
selected_rows
->
mutable_value
(),
dev_ctx
);
}
}
// namespace framework
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
394828b7
...
...
@@ -16,6 +16,76 @@
namespace
paddle
{
namespace
framework
{
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
VLOG
(
3
)
<<
"TensorCopy "
<<
src
.
dims
()
<<
" from "
<<
src
.
place
()
<<
" to "
<<
dst_place
;
src
.
check_memory_size
();
dst
->
Resize
(
src
.
dims
());
dst
->
set_layout
(
src
.
layout
());
auto
src_place
=
src
.
place
();
auto
src_ptr
=
src
.
data
<
void
>
();
auto
dst_ptr
=
dst
->
mutable_data
(
dst_place
,
src
.
type
());
auto
size
=
src
.
numel
()
*
SizeOfType
(
src
.
type
());
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_cpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
);
}
#ifdef PADDLE_WITH_CUDA
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
// NOLINT
platform
::
is_cpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
else
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
dst_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
#endif
}
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
const
platform
::
DeviceContext
*
dev_ctx
;
if
(
platform
::
is_gpu_place
(
src
.
place
()))
{
dev_ctx
=
pool
.
Get
(
src
.
place
());
}
else
{
dev_ctx
=
pool
.
Get
(
dst_place
);
}
TensorCopy
(
src
,
dst_place
,
*
dev_ctx
,
dst
);
}
template
<
typename
Predicate
,
typename
DevCtx
>
struct
AnyDTypeVisitor
{
Predicate
predicate_
;
...
...
@@ -69,7 +139,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
tmp
.
mutable_data
<
bool
>
(
cpu
);
auto
gpuctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
gpu
);
gpuctx
->
Wait
();
Copy
(
out
,
cpu
,
*
gpuctx
,
&
tmp
);
Tensor
Copy
(
out
,
cpu
,
*
gpuctx
,
&
tmp
);
gpuctx
->
Wait
();
return
GetResult
(
tmp
,
cpu
);
}
...
...
@@ -87,7 +157,7 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return
platform
::
VisitPlace
(
place
,
visitor
);
}
struct
Ha
sNANPredicate
{
struct
Contain
sNANPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isnan
())
{
...
...
@@ -96,12 +166,12 @@ struct HasNANPredicate {
}
};
bool
Ha
sNAN
(
const
framework
::
Tensor
&
tensor
)
{
Ha
sNANPredicate
predicate
;
bool
TensorContain
sNAN
(
const
framework
::
Tensor
&
tensor
)
{
Contain
sNANPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
struct
Ha
sInfPredicate
{
struct
Contain
sInfPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isinf
())
{
...
...
@@ -110,10 +180,124 @@ struct HasInfPredicate {
}
};
bool
Ha
sInf
(
const
framework
::
Tensor
&
tensor
)
{
Ha
sInfPredicate
predicate
;
bool
TensorContain
sInf
(
const
framework
::
Tensor
&
tensor
)
{
Contain
sInfPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
void
TensorToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
// TODO(typhoonzero): serialize to ostream
{
// the 1st field, uint32_t version
constexpr
uint32_t
version
=
0
;
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
}
{
// the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto
::
VarType
::
TensorDesc
desc
;
desc
.
set_data_type
(
framework
::
ToDataType
(
tensor
.
type
()));
auto
dims
=
framework
::
vectorize
(
tensor
.
dims
());
auto
*
pb_dims
=
desc
.
mutable_dims
();
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
std
::
copy
(
dims
.
begin
(),
dims
.
end
(),
pb_dims
->
begin
());
int32_t
size
=
desc
.
ByteSize
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
auto
out
=
desc
.
SerializeAsString
();
os
.
write
(
out
.
data
(),
size
);
}
{
// the 3rd field, tensor data
uint64_t
size
=
tensor
.
memory_size
();
auto
*
data_ptr
=
tensor
.
data
<
void
>
();
PADDLE_ENFORCE
(
size
<
std
::
numeric_limits
<
std
::
streamsize
>::
max
(),
"Index overflow when writing tensor"
);
if
(
platform
::
is_gpu_place
(
tensor
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
constexpr
size_t
kBufSize
=
1024
*
1024
*
64
;
// 64MB
std
::
unique_ptr
<
char
[]
>
buf
(
new
char
[
kBufSize
]);
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
platform
::
CPUPlace
cpu
;
uintptr_t
data
=
reinterpret_cast
<
uintptr_t
>
(
data_ptr
);
while
(
size
!=
0
)
{
size_t
size_to_write
=
std
::
min
(
kBufSize
,
static_cast
<
size_t
>
(
size
));
memory
::
Copy
(
cpu
,
buf
.
get
(),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
data
),
size_to_write
,
gpu_dev_ctx
.
stream
());
gpu_dev_ctx
.
Wait
();
os
.
write
(
buf
.
get
(),
size_to_write
);
data
+=
size_to_write
;
size
-=
size_to_write
;
}
#else
PADDLE_THROW
(
"Unexpected branch"
);
#endif
}
else
{
os
.
write
(
static_cast
<
const
char
*>
(
data_ptr
),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
}
struct
DeserializedDataFunctor
{
DeserializedDataFunctor
(
void
**
buf
,
Tensor
*
tensor
,
const
platform
::
Place
&
place
)
:
buf_
(
buf
),
tensor_
(
tensor
),
place_
(
place
)
{}
template
<
typename
T
>
void
operator
()()
{
*
buf_
=
tensor_
->
mutable_data
<
T
>
(
place_
);
}
void
**
buf_
;
Tensor
*
tensor_
;
platform
::
Place
place_
;
};
void
TensorFromStream
(
std
::
istream
&
is
,
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
uint32_t
version
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
PADDLE_ENFORCE_EQ
(
version
,
0U
,
"Only version 0 is supported"
);
proto
::
VarType
::
TensorDesc
desc
;
{
// int32_t size
// proto buffer
int32_t
size
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
size
),
sizeof
(
size
));
std
::
unique_ptr
<
char
[]
>
buf
(
new
char
[
size
]);
is
.
read
(
reinterpret_cast
<
char
*>
(
buf
.
get
()),
size
);
PADDLE_ENFORCE
(
desc
.
ParseFromArray
(
buf
.
get
(),
size
),
"Cannot parse tensor desc"
);
}
{
// read tensor
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
static_cast
<
size_t
>
(
desc
.
dims
().
size
()));
std
::
copy
(
desc
.
dims
().
begin
(),
desc
.
dims
().
end
(),
std
::
back_inserter
(
dims
));
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
void
*
buf
;
auto
ctx
=
platform
::
CPUDeviceContext
();
if
(
platform
::
is_gpu_place
(
dev_ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
Tensor
cpu_tensor
;
cpu_tensor
.
Resize
(
framework
::
make_ddim
(
dims
));
framework
::
VisitDataType
(
desc
.
data_type
(),
DeserializedDataFunctor
(
&
buf
,
&
cpu_tensor
,
ctx
.
GetPlace
()));
is
.
read
(
static_cast
<
char
*>
(
buf
),
cpu_tensor
.
memory_size
());
auto
dst_place
=
dev_ctx
.
GetPlace
();
framework
::
TensorCopy
(
cpu_tensor
,
dst_place
,
dev_ctx
,
tensor
);
#else
PADDLE_THROW
(
"Unexpected branch"
);
#endif
}
else
{
framework
::
VisitDataType
(
desc
.
data_type
(),
DeserializedDataFunctor
(
&
buf
,
tensor
,
ctx
.
GetPlace
()));
is
.
read
(
static_cast
<
char
*>
(
buf
),
tensor
->
memory_size
());
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/tensor_util.cu
已删除
100644 → 0
浏览文件 @
4f23dbd5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
Predicate
,
typename
DevCtx
>
struct
AnyDTypeVisitor
{
Predicate
predicate_
;
const
Tensor
&
tensor_
;
const
DevCtx
&
ctx_
;
Tensor
*
out_
;
AnyDTypeVisitor
(
Predicate
predicate
,
const
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
Tensor
*
out
)
:
predicate_
(
predicate
),
tensor_
(
tensor
),
ctx_
(
ctx
),
out_
(
out
)
{}
template
<
typename
T
>
void
operator
()()
const
{
auto
t
=
EigenVector
<
T
>::
Flatten
(
tensor_
);
auto
o
=
EigenScalar
<
bool
>::
From
(
*
out_
);
// return any of predicate_(t) is true.
o
.
device
(
*
ctx_
.
eigen_device
())
=
predicate_
(
t
).
any
();
}
};
template
<
typename
Predicate
,
typename
DevCtx
>
inline
void
AnyImpl
(
Predicate
predicate
,
const
framework
::
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
framework
::
Tensor
*
out
)
{
VisitDataType
(
ToDataType
(
tensor
.
type
()),
AnyDTypeVisitor
<
Predicate
,
DevCtx
>
(
predicate
,
tensor
,
ctx
,
out
));
}
template
<
typename
Predicate
>
struct
AnyVisitor
:
public
boost
::
static_visitor
<
bool
>
{
const
framework
::
Tensor
&
tensor_
;
Predicate
predicate_
;
AnyVisitor
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
:
tensor_
(
tensor
),
predicate_
(
std
::
move
(
predicate
))
{}
template
<
typename
Place
>
bool
operator
()(
const
Place
&
place
)
const
{
framework
::
Tensor
out
;
out
.
Resize
({
1
});
out
.
mutable_data
<
bool
>
(
place
);
auto
*
ctx
=
platform
::
DeviceContextPool
::
Instance
().
GetByPlace
(
place
);
AnyImpl
(
predicate_
,
tensor_
,
*
ctx
,
&
out
);
return
this
->
GetResult
(
out
,
place
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CUDAPlace
&
gpu
)
const
{
platform
::
CPUPlace
cpu
;
framework
::
Tensor
tmp
;
tmp
.
Resize
({
1
});
tmp
.
mutable_data
<
bool
>
(
cpu
);
auto
gpuctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
gpu
);
gpuctx
->
Wait
();
Copy
(
out
,
cpu
,
*
gpuctx
,
&
tmp
);
gpuctx
->
Wait
();
return
GetResult
(
tmp
,
cpu
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CPUPlace
&
cpu
)
const
{
return
*
out
.
data
<
bool
>
();
}
};
template
<
typename
Predicate
>
inline
bool
Any
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
{
AnyVisitor
<
Predicate
>
visitor
(
tensor
,
predicate
);
auto
place
=
tensor
.
place
();
return
platform
::
VisitPlace
(
place
,
visitor
);
}
struct
HasNANPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isnan
())
{
// Cast eigen_vector to vector of bool. true if is inf.
return
eigen_vec
.
isnan
();
}
};
bool
HasNAN
(
const
framework
::
Tensor
&
tensor
)
{
HasNANPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
struct
HasInfPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isinf
())
{
// Cast eigen_vector to vector of bool. true if is inf.
return
eigen_vec
.
isinf
();
}
};
bool
HasInf
(
const
framework
::
Tensor
&
tensor
)
{
HasInfPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/tensor_util.cu
0 → 120000
浏览文件 @
394828b7
tensor_util
.
cc
\ No newline at end of file
paddle/fluid/framework/tensor_util.h
浏览文件 @
394828b7
...
...
@@ -22,106 +22,38 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
/**
* @brief Copy the content of external tensor to a new place.
*
* @param[in] src The external tensor.
* @param[in] dst_place The dst place.
* @param[in] ctx The device context contains device resources.
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline
void
Copy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
VLOG
(
3
)
<<
"Copy "
<<
src
.
dims
()
<<
" from "
<<
src
.
place
()
<<
" to "
<<
dst_place
;
src
.
check_memory_size
();
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
);
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
);
dst
->
Resize
(
src
.
dims
());
dst
->
set_layout
(
src
.
layout
());
auto
src_place
=
src
.
place
();
auto
src_ptr
=
src
.
data
<
void
>
();
template
<
typename
T
>
void
TensorFromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
);
template
<
typename
T
>
void
TensorFromVector
(
const
std
::
vector
<
T
>&
src
,
Tensor
*
dst
);
auto
dst_ptr
=
dst
->
mutable_data
(
dst_place
,
src
.
type
());
template
<
typename
T
>
void
TensorToVector
(
const
Tensor
&
src
,
const
platform
::
DeviceContext
&
ctx
,
std
::
vector
<
T
>*
dst
);
template
<
typename
T
>
void
TesnorToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>*
dst
);
auto
size
=
src
.
numel
()
*
SizeOfType
(
src
.
type
());
bool
TensorContainsNAN
(
const
framework
::
Tensor
&
tensor
);
bool
TensorContainsInf
(
const
framework
::
Tensor
&
tensor
);
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_cpu_place
(
dst_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
);
}
#ifdef PADDLE_WITH_CUDA
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
// NOLINT
platform
::
is_cpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
else
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
dst_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
}
#endif
}
void
TensorToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
void
TensorFromStream
(
std
::
istream
&
is
,
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
/**
* @brief Wrapper on
* Copy(const Tensor& src, const platform::Place& dst_place,
* const platform::DeviceContext& ctx, Tensor* dst);
*
* @param[in] src The external tensor.
* @param[in] dst_place The dst place.
*
* @note Copy supports CPU <-> GPU, GPU <-> GPU.
*/
inline
void
Copy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
const
platform
::
DeviceContext
*
dev_ctx
;
if
(
platform
::
is_gpu_place
(
src
.
place
()))
{
dev_ctx
=
pool
.
Get
(
src
.
place
());
}
else
{
dev_ctx
=
pool
.
Get
(
dst_place
);
}
Copy
(
src
,
dst_place
,
*
dev_ctx
,
dst
);
}
//
// The implementation of template functions.
//
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector will resize dst to an 1D tensor with the same
* size as src.
*/
template
<
typename
T
>
inline
void
Copy
FromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
void
Tensor
FromVector
(
const
std
::
vector
<
T
>&
src
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
auto
dst_place
=
ctx
.
GetPlace
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
platform
::
CPUPlace
src_place
;
...
...
@@ -143,11 +75,8 @@ inline void CopyFromVector(const std::vector<T>& src,
#endif
}
/**
* @brief CopyFromVector CPU vector -> CPU Tensor
*/
template
<
typename
T
>
inline
void
Copy
FromVector
(
const
std
::
vector
<
T
>&
src
,
Tensor
*
dst
)
{
void
Tensor
FromVector
(
const
std
::
vector
<
T
>&
src
,
Tensor
*
dst
)
{
platform
::
CPUPlace
dst_place
=
platform
::
CPUPlace
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
());
platform
::
CPUPlace
src_place
;
...
...
@@ -158,18 +87,9 @@ inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) {
memory
::
Copy
(
dst_place
,
dst_ptr
,
src_place
,
src_ptr
,
size
);
}
/**
* @brief Copy the content of a tensor to a vector
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template
<
typename
T
>
inline
void
Copy
ToVector
(
const
Tensor
&
src
,
const
platform
::
DeviceContext
&
ctx
,
std
::
vector
<
T
>*
dst
)
{
void
Tensor
ToVector
(
const
Tensor
&
src
,
const
platform
::
DeviceContext
&
ctx
,
std
::
vector
<
T
>*
dst
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
size
=
src
.
numel
()
*
sizeof
(
T
);
...
...
@@ -191,11 +111,8 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
#endif
}
/**
* @brief CopyToVector CPUTensor <-> CPU Vector
*/
template
<
typename
T
>
inline
void
Copy
ToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>*
dst
)
{
void
Tensor
ToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>*
dst
)
{
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
size
=
src
.
numel
()
*
sizeof
(
T
);
...
...
@@ -209,125 +126,5 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr
,
size
);
}
// Returns true if a tensor contains NAN, i.e., Not A Number.
bool
HasNAN
(
const
framework
::
Tensor
&
tensor
);
// Returns true if a tensor contains Inf, i.e., Infinity.
bool
HasInf
(
const
framework
::
Tensor
&
tensor
);
inline
void
SerializeToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
// TODO(typhoonzero): serialize to ostream
{
// the 1st field, uint32_t version
constexpr
uint32_t
version
=
0
;
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
}
{
// the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto
::
VarType
::
TensorDesc
desc
;
desc
.
set_data_type
(
framework
::
ToDataType
(
tensor
.
type
()));
auto
dims
=
framework
::
vectorize
(
tensor
.
dims
());
auto
*
pb_dims
=
desc
.
mutable_dims
();
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
std
::
copy
(
dims
.
begin
(),
dims
.
end
(),
pb_dims
->
begin
());
int32_t
size
=
desc
.
ByteSize
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
auto
out
=
desc
.
SerializeAsString
();
os
.
write
(
out
.
data
(),
size
);
}
{
// the 3rd field, tensor data
uint64_t
size
=
tensor
.
memory_size
();
auto
*
data_ptr
=
tensor
.
data
<
void
>
();
PADDLE_ENFORCE
(
size
<
std
::
numeric_limits
<
std
::
streamsize
>::
max
(),
"Index overflow when writing tensor"
);
if
(
platform
::
is_gpu_place
(
tensor
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
constexpr
size_t
kBufSize
=
1024
*
1024
*
64
;
// 64MB
std
::
unique_ptr
<
char
[]
>
buf
(
new
char
[
kBufSize
]);
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
platform
::
CPUPlace
cpu
;
uintptr_t
data
=
reinterpret_cast
<
uintptr_t
>
(
data_ptr
);
while
(
size
!=
0
)
{
size_t
size_to_write
=
std
::
min
(
kBufSize
,
static_cast
<
size_t
>
(
size
));
memory
::
Copy
(
cpu
,
buf
.
get
(),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
reinterpret_cast
<
const
void
*>
(
data
),
size_to_write
,
gpu_dev_ctx
.
stream
());
gpu_dev_ctx
.
Wait
();
os
.
write
(
buf
.
get
(),
size_to_write
);
data
+=
size_to_write
;
size
-=
size_to_write
;
}
#else
PADDLE_THROW
(
"Unexpected branch"
);
#endif
}
else
{
os
.
write
(
static_cast
<
const
char
*>
(
data_ptr
),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
}
struct
DeserializedDataFunctor
{
DeserializedDataFunctor
(
void
**
buf
,
Tensor
*
tensor
,
const
platform
::
Place
&
place
)
:
buf_
(
buf
),
tensor_
(
tensor
),
place_
(
place
)
{}
template
<
typename
T
>
void
operator
()()
{
*
buf_
=
tensor_
->
mutable_data
<
T
>
(
place_
);
}
void
**
buf_
;
Tensor
*
tensor_
;
platform
::
Place
place_
;
};
inline
void
DeserializeFromStream
(
std
::
istream
&
is
,
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
uint32_t
version
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
version
),
sizeof
(
version
));
PADDLE_ENFORCE_EQ
(
version
,
0U
,
"Only version 0 is supported"
);
proto
::
VarType
::
TensorDesc
desc
;
{
// int32_t size
// proto buffer
int32_t
size
;
is
.
read
(
reinterpret_cast
<
char
*>
(
&
size
),
sizeof
(
size
));
std
::
unique_ptr
<
char
[]
>
buf
(
new
char
[
size
]);
is
.
read
(
reinterpret_cast
<
char
*>
(
buf
.
get
()),
size
);
PADDLE_ENFORCE
(
desc
.
ParseFromArray
(
buf
.
get
(),
size
),
"Cannot parse tensor desc"
);
}
{
// read tensor
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
static_cast
<
size_t
>
(
desc
.
dims
().
size
()));
std
::
copy
(
desc
.
dims
().
begin
(),
desc
.
dims
().
end
(),
std
::
back_inserter
(
dims
));
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
void
*
buf
;
auto
ctx
=
platform
::
CPUDeviceContext
();
if
(
platform
::
is_gpu_place
(
dev_ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
Tensor
cpu_tensor
;
cpu_tensor
.
Resize
(
framework
::
make_ddim
(
dims
));
framework
::
VisitDataType
(
desc
.
data_type
(),
DeserializedDataFunctor
(
&
buf
,
&
cpu_tensor
,
ctx
.
GetPlace
()));
is
.
read
(
static_cast
<
char
*>
(
buf
),
cpu_tensor
.
memory_size
());
auto
dst_place
=
dev_ctx
.
GetPlace
();
framework
::
Copy
(
cpu_tensor
,
dst_place
,
dev_ctx
,
tensor
);
#else
PADDLE_THROW
(
"Unexpected branch"
);
#endif
}
else
{
framework
::
VisitDataType
(
desc
.
data_type
(),
DeserializedDataFunctor
(
&
buf
,
tensor
,
ctx
.
GetPlace
()));
is
.
read
(
static_cast
<
char
*>
(
buf
),
tensor
->
memory_size
());
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/tensor_util_test.cc
浏览文件 @
394828b7
...
...
@@ -20,7 +20,7 @@
namespace
paddle
{
namespace
framework
{
TEST
(
Copy
,
Tensor
)
{
TEST
(
Tensor
Copy
,
Tensor
)
{
Tensor
src_tensor
;
Tensor
dst_tensor
;
platform
::
CPUDeviceContext
cpu_ctx
((
platform
::
CPUPlace
()));
...
...
@@ -33,7 +33,7 @@ TEST(Copy, Tensor) {
src_tensor
.
set_layout
(
DataLayout
::
kAnyLayout
);
auto
cpu_place
=
new
platform
::
CPUPlace
();
Copy
(
src_tensor
,
*
cpu_place
,
&
dst_tensor
);
Tensor
Copy
(
src_tensor
,
*
cpu_place
,
&
dst_tensor
);
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
dst_ptr
);
...
...
@@ -44,7 +44,7 @@ TEST(Copy, Tensor) {
EXPECT_TRUE
(
dst_tensor
.
layout
()
==
src_tensor
.
layout
());
Tensor
slice_tensor
=
src_tensor
.
Slice
(
1
,
2
);
Copy
(
slice_tensor
,
*
cpu_place
,
&
dst_tensor
);
Tensor
Copy
(
slice_tensor
,
*
cpu_place
,
&
dst_tensor
);
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
dst_ptr
,
slice_ptr
);
...
...
@@ -68,11 +68,11 @@ TEST(Copy, Tensor) {
// CPU Tensor to GPU Tensor
auto
gpu_place
=
new
platform
::
CUDAPlace
(
0
);
platform
::
CUDADeviceContext
gpu_ctx
(
*
gpu_place
);
Copy
(
src_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
Copy
(
src_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
// GPU Tensor to CPU Tensor
auto
cpu_place
=
new
platform
::
CPUPlace
();
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
Tensor
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
// Sync before Compare Tensors
gpu_ctx
.
Wait
();
...
...
@@ -85,10 +85,10 @@ TEST(Copy, Tensor) {
Tensor
slice_tensor
=
src_tensor
.
Slice
(
1
,
2
);
// CPU Slice Tensor to GPU Tensor
Copy
(
slice_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
Copy
(
slice_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
// GPU Tensor to CPU Tensor
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
Tensor
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
// Sync before Compare Slice Tensors
gpu_ctx
.
Wait
();
...
...
@@ -104,7 +104,7 @@ TEST(Copy, Tensor) {
#endif
}
TEST
(
Copy
FromVector
,
Tensor
)
{
TEST
(
Tensor
FromVector
,
Tensor
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
{
...
...
@@ -114,7 +114,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to CPU Tensor
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
Copy
FromVector
<
int
>
(
src_vec
,
&
cpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
&
cpu_tensor
);
// Compare Tensors
const
int
*
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
...
...
@@ -126,7 +126,7 @@ TEST(CopyFromVector, Tensor) {
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
Copy
FromVector
<
int
>
(
src_vec
,
&
cpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
&
cpu_tensor
);
cpu_ptr
=
cpu_tensor
.
data
<
int
>
();
src_ptr
=
src_vec
.
data
();
ASSERT_NE
(
src_ptr
,
cpu_ptr
);
...
...
@@ -148,15 +148,15 @@ TEST(CopyFromVector, Tensor) {
cpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
CPUDeviceContext
cpu_ctx
(
*
cpu_place
);
Copy
FromVector
<
int
>
(
src_vec
,
cpu_ctx
,
&
cpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
cpu_ctx
,
&
cpu_tensor
);
// Copy to GPUTensor
gpu_tensor
.
Resize
(
make_ddim
({
3
,
3
}));
auto
gpu_place
=
new
paddle
::
platform
::
CUDAPlace
();
CUDADeviceContext
gpu_ctx
(
*
gpu_place
);
Copy
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
// Copy from GPU to CPU tensor for comparison
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
Tensor
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
// Sync before Compare Tensors
gpu_ctx
.
Wait
();
...
...
@@ -173,10 +173,10 @@ TEST(CopyFromVector, Tensor) {
src_vec
.
erase
(
src_vec
.
begin
(),
src_vec
.
begin
()
+
5
);
cpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
Copy
FromVector
<
int
>
(
src_vec
,
cpu_ctx
,
&
cpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
cpu_ctx
,
&
cpu_tensor
);
gpu_tensor
.
Resize
(
make_ddim
({
2
,
2
}));
Copy
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
Copy
(
gpu_tensor
,
*
cpu_place
,
gpu_ctx
,
&
dst_tensor
);
// Sync before Compare Tensors
gpu_ctx
.
Wait
();
...
...
@@ -196,7 +196,7 @@ TEST(CopyFromVector, Tensor) {
#endif
}
TEST
(
Copy
ToVector
,
Tensor
)
{
TEST
(
Tensor
ToVector
,
Tensor
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
{
...
...
@@ -208,7 +208,7 @@ TEST(CopyToVector, Tensor) {
CPUPlace
place
;
std
::
vector
<
int
>
dst
;
Copy
ToVector
<
int
>
(
src
,
&
dst
);
Tensor
ToVector
<
int
>
(
src
,
&
dst
);
for
(
int
i
=
0
;
i
<
3
*
3
;
++
i
)
{
EXPECT_EQ
(
src_ptr
[
i
],
dst
[
i
]);
...
...
@@ -220,10 +220,10 @@ TEST(CopyToVector, Tensor) {
Tensor
gpu_tensor
;
CUDAPlace
place
;
CUDADeviceContext
gpu_ctx
(
place
);
Copy
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
FromVector
<
int
>
(
src_vec
,
gpu_ctx
,
&
gpu_tensor
);
std
::
vector
<
int
>
dst
;
Copy
ToVector
<
int
>
(
gpu_tensor
,
gpu_ctx
,
&
dst
);
Tensor
ToVector
<
int
>
(
gpu_tensor
,
gpu_ctx
,
&
dst
);
for
(
int
i
=
0
;
i
<
3
*
3
;
++
i
)
{
EXPECT_EQ
(
src_vec
[
i
],
dst
[
i
]);
...
...
@@ -232,7 +232,7 @@ TEST(CopyToVector, Tensor) {
#endif
}
TEST
(
Ha
sNAN
,
CPU
)
{
TEST
(
TensorContain
sNAN
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
...
...
@@ -240,11 +240,12 @@ TEST(HasNAN, CPU) {
buf
[
0
]
=
0.0
;
buf
[
1
]
=
NAN
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasNAN
(
src
));
ASSERT_TRUE
(
TensorContainsNAN
(
src
));
buf
[
1
]
=
0.0
;
ASSERT_FALSE
(
TensorContainsNAN
(
src
));
}
TEST
(
Ha
sInf
,
CPU
)
{
TEST
(
TensorContain
sInf
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
...
...
@@ -252,10 +253,12 @@ TEST(HasInf, CPU) {
buf
[
0
]
=
1.0
;
buf
[
1
]
=
INFINITY
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasInf
(
src
));
ASSERT_TRUE
(
TensorContainsInf
(
src
));
buf
[
1
]
=
1.0
;
ASSERT_FALSE
(
TensorContainsInf
(
src
));
}
TEST
(
Tensor
,
SerializeAndDeserialize
)
{
TEST
(
Tensor
,
FromAndToStream
)
{
framework
::
Tensor
src_tensor
;
int
array
[
6
]
=
{
1
,
2
,
3
,
4
,
5
,
6
};
src_tensor
.
Resize
({
2
,
3
});
...
...
@@ -268,10 +271,10 @@ TEST(Tensor, SerializeAndDeserialize) {
auto
place
=
new
platform
::
CPUPlace
();
platform
::
CPUDeviceContext
cpu_ctx
(
*
place
);
std
::
ostringstream
oss
;
Serialize
ToStream
(
oss
,
src_tensor
,
cpu_ctx
);
Tensor
ToStream
(
oss
,
src_tensor
,
cpu_ctx
);
std
::
istringstream
iss
(
oss
.
str
());
Deserialize
FromStream
(
iss
,
&
dst_tensor
,
cpu_ctx
);
Tensor
FromStream
(
iss
,
&
dst_tensor
,
cpu_ctx
);
int
*
dst_ptr
=
dst_tensor
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
5
;
++
i
)
{
ASSERT_EQ
(
dst_ptr
[
i
],
array
[
i
]);
...
...
@@ -288,13 +291,13 @@ TEST(Tensor, SerializeAndDeserialize) {
auto
gpu_place
=
new
platform
::
CUDAPlace
();
platform
::
CUDADeviceContext
gpu_ctx
(
*
gpu_place
);
Copy
(
src_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
Tensor
Copy
(
src_tensor
,
*
gpu_place
,
gpu_ctx
,
&
gpu_tensor
);
std
::
ostringstream
oss
;
Serialize
ToStream
(
oss
,
gpu_tensor
,
gpu_ctx
);
Tensor
ToStream
(
oss
,
gpu_tensor
,
gpu_ctx
);
std
::
istringstream
iss
(
oss
.
str
());
Deserialize
FromStream
(
iss
,
&
dst_tensor
,
gpu_ctx
);
Tensor
FromStream
(
iss
,
&
dst_tensor
,
gpu_ctx
);
int
*
dst_ptr
=
dst_tensor
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
...
...
paddle/fluid/framework/tensor_util_test.cu
浏览文件 @
394828b7
...
...
@@ -31,7 +31,7 @@ static __global__ void FillInf(float* buf) {
buf
[
2
]
=
0.5
;
}
TEST
(
Ha
sNAN
,
GPU
)
{
TEST
(
TensorContain
sNAN
,
GPU
)
{
Tensor
tensor
;
platform
::
CUDAPlace
gpu
(
0
);
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -39,10 +39,10 @@ TEST(HasNAN, GPU) {
float
*
buf
=
tensor
.
mutable_data
<
float
>
({
3
},
gpu
);
FillNAN
<<<
1
,
1
,
0
,
cuda_ctx
->
stream
()
>>>
(
buf
);
cuda_ctx
->
Wait
();
ASSERT_TRUE
(
Ha
sNAN
(
tensor
));
ASSERT_TRUE
(
TensorContain
sNAN
(
tensor
));
}
TEST
(
Ha
sInf
,
GPU
)
{
TEST
(
TensorContain
sInf
,
GPU
)
{
Tensor
tensor
;
platform
::
CUDAPlace
gpu
(
0
);
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -50,7 +50,7 @@ TEST(HasInf, GPU) {
float
*
buf
=
tensor
.
mutable_data
<
float
>
({
3
},
gpu
);
FillInf
<<<
1
,
1
,
0
,
cuda_ctx
->
stream
()
>>>
(
buf
);
cuda_ctx
->
Wait
();
ASSERT_TRUE
(
Ha
sInf
(
tensor
));
ASSERT_TRUE
(
TensorContain
sInf
(
tensor
));
}
}
// namespace framework
...
...
paddle/fluid/framework/threadpool.h
浏览文件 @
394828b7
...
...
@@ -64,7 +64,6 @@ class ThreadPool {
Task
task
([
fn
]()
->
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
{
try
{
fn
();
return
nullptr
;
}
catch
(
platform
::
EnforceNotMet
ex
)
{
return
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
(
new
platform
::
EnforceNotMet
(
ex
));
...
...
@@ -73,6 +72,7 @@ class ThreadPool {
<<
"Unexpected exception is catched in thread pool. All "
"throwable exception in Fluid should be an EnforceNotMet."
;
}
return
nullptr
;
});
std
::
future
<
std
::
unique_ptr
<
platform
::
EnforceNotMet
>>
f
=
task
.
get_future
();
tasks_
.
push
(
std
::
move
(
task
));
...
...
paddle/fluid/operators/array_operator.h
浏览文件 @
394828b7
...
...
@@ -42,7 +42,7 @@ class ArrayOp : public framework::OperatorBase {
if
(
platform
::
is_gpu_place
(
i_tensor
.
place
()))
{
// FIXME: Avoid copy from GPU to CPU
framework
::
Tensor
t
;
framework
::
Copy
(
i_tensor
,
platform
::
CPUPlace
(),
dev_ctx
,
&
t
);
framework
::
Tensor
Copy
(
i_tensor
,
platform
::
CPUPlace
(),
dev_ctx
,
&
t
);
dev_ctx
.
Wait
();
offset
=
static_cast
<
size_t
>
(
*
t
.
data
<
int64_t
>
());
}
else
{
...
...
paddle/fluid/operators/array_to_lod_tensor_op.cc
浏览文件 @
394828b7
...
...
@@ -112,8 +112,8 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Copy
(
x
[
x_idx
].
Slice
(
start_offset
,
end_offset
),
place
,
dev_ctx
,
&
slice
);
framework
::
Tensor
Copy
(
x
[
x_idx
].
Slice
(
start_offset
,
end_offset
),
place
,
dev_ctx
,
&
slice
);
out_offset
+=
len
;
}
}
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
394828b7
...
...
@@ -45,7 +45,7 @@ class AssignFunctor {
out_rows
.
set_height
(
rows
.
height
());
auto
&
t
=
rows
.
value
();
auto
*
m
=
out_rows
.
mutable_value
();
framework
::
Copy
(
t
,
t
.
place
(),
dev_ctx_
,
m
);
framework
::
Tensor
Copy
(
t
,
t
.
place
(),
dev_ctx_
,
m
);
}
template
<
typename
T
>
...
...
@@ -57,7 +57,7 @@ class AssignFunctor {
void
copy_tensor
(
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
*
out
)
const
{
auto
&
out_tensor
=
*
out
;
Copy
(
lod_tensor
,
lod_tensor
.
place
(),
dev_ctx_
,
&
out_tensor
);
Tensor
Copy
(
lod_tensor
,
lod_tensor
.
place
(),
dev_ctx_
,
&
out_tensor
);
out_tensor
.
set_lod
(
lod_tensor
.
lod
());
}
...
...
paddle/fluid/operators/assign_value_op.h
浏览文件 @
394828b7
...
...
@@ -41,7 +41,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
break
;
}
auto
values
=
ctx
.
Attr
<
std
::
vector
<
T
>>
(
value_name
);
framework
::
Copy
FromVector
(
values
,
ctx
.
device_context
(),
out
);
framework
::
Tensor
FromVector
(
values
,
ctx
.
device_context
(),
out
);
out
->
Resize
(
framework
::
make_ddim
(
shape
));
}
};
...
...
paddle/fluid/operators/beam_search_decode_op.h
浏览文件 @
394828b7
...
...
@@ -232,12 +232,12 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
id_tensor
->
set_lod
(
lod
);
id_tensor
->
Resize
({
static_cast
<
int64_t
>
(
id_data
.
size
())});
id_tensor
->
mutable_data
<
int64_t
>
(
paddle
::
platform
::
CPUPlace
());
framework
::
Copy
FromVector
<
int64_t
>
(
id_data
,
cpu_ctx
,
id_tensor
);
framework
::
Tensor
FromVector
<
int64_t
>
(
id_data
,
cpu_ctx
,
id_tensor
);
score_tensor
->
set_lod
(
lod
);
score_tensor
->
Resize
({
static_cast
<
int64_t
>
(
score_data
.
size
())});
score_tensor
->
mutable_data
<
T
>
(
paddle
::
platform
::
CPUPlace
());
framework
::
Copy
FromVector
<
T
>
(
score_data
,
cpu_ctx
,
score_tensor
);
framework
::
Tensor
FromVector
<
T
>
(
score_data
,
cpu_ctx
,
score_tensor
);
}
template
<
typename
T
>
...
...
paddle/fluid/operators/detection_output_op.h
浏览文件 @
394828b7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -98,16 +98,16 @@ class DetectionOutputKernel : public framework::OpKernel<T> {
T
*
conf_data
=
conf_tensor
.
data
<
T
>
();
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
loc_cpu
.
mutable_data
<
T
>
(
loc_tensor
.
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
loc_tensor
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
loc_cpu
);
framework
::
Tensor
Copy
(
loc_tensor
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
loc_cpu
);
loc_data
=
loc_cpu
.
data
<
T
>
();
conf_cpu
.
mutable_data
<
T
>
(
conf_tensor
.
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
conf_tensor
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
conf_cpu
);
framework
::
Tensor
Copy
(
conf_tensor
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
conf_cpu
);
conf_data
=
conf_cpu
.
data
<
T
>
();
priorbox_cpu
.
mutable_data
<
T
>
(
in_priorbox
->
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
*
in_priorbox
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
priorbox_cpu
);
framework
::
Tensor
Copy
(
*
in_priorbox
,
platform
::
CPUPlace
(),
context
.
device_context
(),
&
priorbox_cpu
);
priorbox_data
=
priorbox_cpu
.
data
<
T
>
();
}
// get decode bboxes
...
...
@@ -158,8 +158,8 @@ class DetectionOutputKernel : public framework::OpKernel<T> {
batch_size
,
all_indices
,
all_decoded_bboxes
,
out_data
);
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
framework
::
Copy
(
out_cpu
,
platform
::
CUDAPlace
(),
context
.
device_context
(),
out
);
framework
::
TensorCopy
(
out_cpu
,
platform
::
CUDAPlace
(),
context
.
device_context
(),
out
);
}
}
};
...
...
paddle/fluid/operators/expand_op.h
浏览文件 @
394828b7
...
...
@@ -126,7 +126,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto
*
in0
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out0
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
out0
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Copy
(
*
in0
,
context
.
GetPlace
(),
context
.
device_context
(),
out0
);
framework
::
TensorCopy
(
*
in0
,
context
.
GetPlace
(),
context
.
device_context
(),
out0
);
}
else
{
switch
(
dims
)
{
REP_EXPAND_GRAD_TEMPLATE
(
72
)
...
...
paddle/fluid/operators/feed_op.cc
浏览文件 @
394828b7
...
...
@@ -57,7 +57,7 @@ class FeedOp : public framework::OperatorBase {
if
(
platform
::
is_same_place
(
feed_item
.
place
(),
place
))
{
out_item
->
ShareDataWith
(
feed_item
);
}
else
{
framework
::
Copy
(
feed_item
,
place
,
dev_ctx
,
out_item
);
framework
::
Tensor
Copy
(
feed_item
,
place
,
dev_ctx
,
out_item
);
}
out_item
->
set_lod
(
feed_item
.
lod
());
}
...
...
paddle/fluid/operators/fetch_op.cc
浏览文件 @
394828b7
...
...
@@ -56,7 +56,7 @@ class FetchOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
src_item
.
place
());
Copy
(
src_item
,
platform
::
CPUPlace
(),
dev_ctx
,
&
dst_item
);
Tensor
Copy
(
src_item
,
platform
::
CPUPlace
(),
dev_ctx
,
&
dst_item
);
dev_ctx
.
Wait
();
dst_item
.
set_lod
(
src_item
.
lod
());
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
394828b7
...
...
@@ -75,7 +75,7 @@ class FillOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Copy
(
tensor
,
place
,
dev_ctx
,
&
out
);
framework
::
Tensor
Copy
(
tensor
,
place
,
dev_ctx
,
&
out
);
}
}
};
...
...
paddle/fluid/operators/layer_norm_op.h
浏览文件 @
394828b7
...
...
@@ -196,7 +196,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
// dy_dx
ElementwiseComputeEx
<
MulFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
d_y
,
scale
,
/*axis*/
1
,
MulFunctor
<
T
>
(),
&
temp
);
framework
::
Copy
(
temp
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
framework
::
Tensor
Copy
(
temp
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
// dy_dmean_dx
row_mean
(
dev_ctx
,
temp
,
&
temp_vec
);
...
...
@@ -208,7 +208,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
ctx
,
&
temp
,
&
temp_norm
,
/*axis*/
0
,
MulFunctor
<
T
>
(),
&
temp
);
}
else
{
// dy_dx
framework
::
Copy
(
d_y
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
framework
::
Tensor
Copy
(
d_y
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
// dy_dmean_dx
row_mean
(
dev_ctx
,
d_y
,
&
temp_vec
);
...
...
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
394828b7
...
...
@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase {
out_var
->
Clear
();
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
set_lod
(
cpu_tensor
.
lod
());
Copy
(
cpu_tensor
,
place
,
dev_ctx
,
tensor
);
Tensor
Copy
(
cpu_tensor
,
place
,
dev_ctx
,
tensor
);
}
}
}
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
394828b7
...
...
@@ -55,7 +55,7 @@ class LoadOp : public framework::OperatorBase {
out_var
->
Clear
();
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
set_lod
(
cpu_tensor
.
lod
());
Copy
(
cpu_tensor
,
place
,
dev_ctx
,
tensor
);
Tensor
Copy
(
cpu_tensor
,
place
,
dev_ctx
,
tensor
);
}
}
};
...
...
paddle/fluid/operators/lod_reset_op.h
浏览文件 @
394828b7
...
...
@@ -33,8 +33,8 @@ class LoDResetKernel : public framework::OpKernel<T> {
auto
*
lod
=
lod_t
->
data
<
int
>
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
framework
::
Tensor
lod_cpu
;
framework
::
Copy
(
*
lod_t
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
lod_cpu
);
framework
::
TensorCopy
(
*
lod_t
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
lod_cpu
);
lod
=
lod_cpu
.
data
<
int
>
();
}
level0
=
std
::
vector
<
int
>
(
lod
,
lod
+
lod_t
->
numel
());
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
394828b7
...
...
@@ -94,9 +94,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Copy
(
x
.
Slice
(
static_cast
<
int
>
(
each_range
.
begin
),
static_cast
<
int
>
(
each_range
.
end
)),
x
.
place
(),
dev_ctx
,
&
slice
);
framework
::
Tensor
Copy
(
x
.
Slice
(
static_cast
<
int
>
(
each_range
.
begin
),
static_cast
<
int
>
(
each_range
.
end
)),
x
.
place
(),
dev_ctx
,
&
slice
);
offset
+=
len
;
}
}
...
...
paddle/fluid/operators/math/context_project.h
浏览文件 @
394828b7
...
...
@@ -149,7 +149,8 @@ class ContextProjectFunctor {
Tensor
out_t_sub
=
out_t
.
Slice
(
k
*
context_length
,
k
*
context_length
+
padding_size
);
Tensor
w_sub
=
padding_data
.
Slice
(
k
,
k
+
padding_size
);
framework
::
Copy
(
w_sub
,
context
.
GetPlace
(),
context
,
&
out_t_sub
);
framework
::
TensorCopy
(
w_sub
,
context
.
GetPlace
(),
context
,
&
out_t_sub
);
}
}
if
(
down_pad
>
0
)
{
// add down pad
...
...
@@ -179,7 +180,8 @@ class ContextProjectFunctor {
(
down_pad_begin_row
+
t
)
*
context_length
);
Tensor
w_sub
=
padding_data
.
Slice
(
up_pad
+
padding_idx
,
up_pad
+
padding_idx
+
padding_size
);
framework
::
Copy
(
w_sub
,
context
.
GetPlace
(),
context
,
&
out_t_sub
);
framework
::
TensorCopy
(
w_sub
,
context
.
GetPlace
(),
context
,
&
out_t_sub
);
}
}
out_t
.
Resize
({
sequence_height
,
context_length
*
sequence_width
});
...
...
paddle/fluid/operators/math/im2col_test.cc
浏览文件 @
394828b7
...
...
@@ -62,7 +62,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
input
=
input_tmp
;
}
else
{
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
Tensor
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
}
output_cfo
.
mutable_data
<
float
>
(
{
1
,
filter_size
,
filter_size
,
output_height
,
output_width
},
*
place
);
...
...
@@ -87,7 +87,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
out_cfo_ptr
=
output_cfo
.
data
<
float
>
();
}
else
{
Copy
(
output_cfo
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
Tensor
Copy
(
output_cfo
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
out_cfo_ptr
=
output_tmp
.
data
<
float
>
();
}
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
...
...
@@ -98,7 +98,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
out_ocf_ptr
=
output_ocf
.
data
<
float
>
();
}
else
{
Copy
(
output_ocf
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
Tensor
Copy
(
output_ocf
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
out_ocf_ptr
=
output_tmp
.
data
<
float
>
();
}
...
...
@@ -119,7 +119,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
input
=
input_tmp
;
}
else
{
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
Tensor
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
}
col2im
(
*
context
,
output_cfo
,
dilation
,
stride
,
padding
,
&
input
);
...
...
@@ -128,7 +128,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
}
else
{
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
Tensor
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
in_ptr
=
input_tmp
.
data
<
float
>
();
}
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
...
...
@@ -140,7 +140,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
input
=
input_tmp
;
}
else
{
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
Tensor
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
}
col2im_ocf
(
*
context
,
output_ocf
,
dilation
,
stride
,
padding
,
&
input
);
...
...
@@ -148,7 +148,7 @@ void testIm2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
}
else
{
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
Tensor
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
in_ptr
=
input_tmp
.
data
<
float
>
();
}
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
...
...
paddle/fluid/operators/math/math_function_test.cu
浏览文件 @
394828b7
...
...
@@ -29,15 +29,15 @@ TEST(math_function, notrans_mul_trans) {
auto
*
gpu_place
=
new
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input2_gpu
);
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
paddle
::
framework
::
Copy
(
out_gpu
,
*
cpu_place
,
context
,
&
out
);
paddle
::
framework
::
Tensor
Copy
(
out_gpu
,
*
cpu_place
,
context
,
&
out
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
...
...
@@ -63,15 +63,15 @@ TEST(math_function, trans_mul_notrans) {
auto
*
gpu_place
=
new
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input2_gpu
);
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
paddle
::
framework
::
Copy
(
out_gpu
,
*
cpu_place
,
context
,
&
out
);
paddle
::
framework
::
Tensor
Copy
(
out_gpu
,
*
cpu_place
,
context
,
&
out
);
float
*
out_ptr
=
out
.
data
<
float
>
();
context
.
Wait
();
...
...
@@ -112,9 +112,9 @@ TEST(math_function, gemm_notrans_cublas) {
auto
*
gpu_place
=
new
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Copy
(
input2
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Copy
(
input3
,
*
gpu_place
,
context
,
&
input3_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input2
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input3
,
*
gpu_place
,
context
,
&
input3_gpu
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
...
...
@@ -122,7 +122,7 @@ TEST(math_function, gemm_notrans_cublas) {
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
paddle
::
framework
::
Copy
(
input3_gpu
,
*
cpu_place
,
context
,
&
input3
);
paddle
::
framework
::
Tensor
Copy
(
input3_gpu
,
*
cpu_place
,
context
,
&
input3
);
// numpy code:
// a = np.arange(6).reshape(2, 3)
...
...
@@ -167,9 +167,9 @@ TEST(math_function, gemm_trans_cublas) {
auto
*
gpu_place
=
new
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
paddle
::
framework
::
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Copy
(
input2
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Copy
(
input3
,
*
gpu_place
,
context
,
&
input3_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input1
,
*
gpu_place
,
context
,
&
input1_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input2
,
*
gpu_place
,
context
,
&
input2_gpu
);
paddle
::
framework
::
Tensor
Copy
(
input3
,
*
gpu_place
,
context
,
&
input3_gpu
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
...
...
@@ -177,7 +177,7 @@ TEST(math_function, gemm_trans_cublas) {
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
paddle
::
framework
::
Copy
(
input3_gpu
,
*
cpu_place
,
context
,
&
input3
);
paddle
::
framework
::
Tensor
Copy
(
input3_gpu
,
*
cpu_place
,
context
,
&
input3
);
context
.
Wait
();
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
...
...
@@ -218,15 +218,15 @@ void GemvTest(int m, int n, bool trans) {
}
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
paddle
::
framework
::
Copy
(
mat_a
,
*
gpu_place
,
context
,
&
g_mat_a
);
paddle
::
framework
::
Copy
(
vec_b
,
*
gpu_place
,
context
,
&
g_vec_b
);
paddle
::
framework
::
Tensor
Copy
(
mat_a
,
*
gpu_place
,
context
,
&
g_mat_a
);
paddle
::
framework
::
Tensor
Copy
(
vec_b
,
*
gpu_place
,
context
,
&
g_vec_b
);
paddle
::
operators
::
math
::
gemv
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
context
,
trans
,
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
1.
,
g_data_a
,
g_data_b
,
0.
,
g_data_c
);
paddle
::
framework
::
Copy
(
g_vec_c
,
paddle
::
platform
::
CPUPlace
(),
context
,
&
vec_c
);
paddle
::
framework
::
Tensor
Copy
(
g_vec_c
,
paddle
::
platform
::
CPUPlace
(),
context
,
&
vec_c
);
if
(
!
trans
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
...
...
paddle/fluid/operators/math/selected_rows_functor_test.cu
浏览文件 @
394828b7
...
...
@@ -67,7 +67,7 @@ TEST(selected_rows_functor, gpu_add) {
EXPECT_EQ
(
out_rows
[
6
],
9
);
Tensor
out_cpu
;
Copy
(
*
out_value
,
cpu_place
,
ctx
,
&
out_cpu
);
Tensor
Copy
(
*
out_value
,
cpu_place
,
ctx
,
&
out_cpu
);
ctx
.
Wait
();
auto
*
out_cpu_data
=
out_cpu
.
data
<
float
>
();
...
...
@@ -94,7 +94,7 @@ TEST(selected_rows_functor, gpu_add) {
add_tensor_functor
(
ctx
,
*
output
,
*
tensor1
,
tensor2
.
get
());
Tensor
tensor2_cpu
;
Copy
(
*
tensor2
,
cpu_place
,
ctx
,
&
tensor2_cpu
);
Tensor
Copy
(
*
tensor2
,
cpu_place
,
ctx
,
&
tensor2_cpu
);
ctx
.
Wait
();
auto
*
tensor2_cpu_data
=
tensor2_cpu
.
data
<
float
>
();
...
...
@@ -167,7 +167,7 @@ TEST(selected_rows_functor, gpu_add_to) {
EXPECT_EQ
(
out_rows
[
6
],
9
);
Tensor
out_cpu
;
Copy
(
*
out_value
,
cpu_place
,
ctx
,
&
out_cpu
);
Tensor
Copy
(
*
out_value
,
cpu_place
,
ctx
,
&
out_cpu
);
ctx
.
Wait
();
auto
*
out_cpu_data
=
out_cpu
.
data
<
float
>
();
...
...
@@ -191,7 +191,7 @@ TEST(selected_rows_functor, gpu_add_to) {
add_to_tensor_functor
(
ctx
,
*
output
,
tensor1
.
get
());
Tensor
tensor1_cpu
;
Copy
(
*
tensor1
,
cpu_place
,
ctx
,
&
tensor1_cpu
);
Tensor
Copy
(
*
tensor1
,
cpu_place
,
ctx
,
&
tensor1_cpu
);
ctx
.
Wait
();
auto
*
tensor1_cpu_data
=
tensor1_cpu
.
data
<
float
>
();
...
...
paddle/fluid/operators/math/sequence_padding.cu
浏览文件 @
394828b7
...
...
@@ -97,7 +97,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
Copy
(
seq
,
context
.
GetPlace
(),
context
,
&
padding
);
Tensor
Copy
(
seq
,
context
.
GetPlace
(),
context
,
&
padding
);
padding
.
Resize
(
padding_dims
);
return
;
}
...
...
@@ -172,7 +172,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
Copy
(
padding
,
context
.
GetPlace
(),
context
,
&
seq
);
Tensor
Copy
(
padding
,
context
.
GetPlace
(),
context
,
&
seq
);
seq
.
Resize
(
seq_dims
);
return
;
}
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
394828b7
...
...
@@ -40,7 +40,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
seq
=
cpu_seq
;
}
else
{
Copy
(
cpu_seq
,
*
place
,
*
context
,
&
seq
);
Tensor
Copy
(
cpu_seq
,
*
place
,
*
context
,
&
seq
);
seq
.
set_lod
(
lod
);
}
...
...
@@ -63,7 +63,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
cpu_seq_back
=
seq_back
;
}
else
{
Copy
(
seq_back
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
cpu_seq_back
);
Tensor
Copy
(
seq_back
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
cpu_seq_back
);
cpu_seq_back
.
set_lod
(
lod
);
}
...
...
paddle/fluid/operators/math/vol2col_test.cc
浏览文件 @
394828b7
...
...
@@ -71,7 +71,7 @@ void testVol2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
input
=
input_tmp
;
}
else
{
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
paddle
::
framework
::
Tensor
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
}
output
.
mutable_data
<
float
>
({
1
,
filter_size
,
filter_size
,
filter_size
,
output_depth
,
output_height
,
output_width
},
...
...
@@ -85,7 +85,7 @@ void testVol2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
out_cfo_ptr
=
output
.
data
<
float
>
();
}
else
{
Copy
(
output
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
Tensor
Copy
(
output
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
output_tmp
);
out_cfo_ptr
=
output_tmp
.
data
<
float
>
();
}
...
...
@@ -99,7 +99,7 @@ void testVol2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
input
=
input_tmp
;
}
else
{
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
Tensor
Copy
(
input_tmp
,
*
place
,
*
context
,
&
input
);
}
paddle
::
operators
::
math
::
Col2VolFunctor
<
DeviceContext
,
float
>
col2vol
;
...
...
@@ -109,7 +109,7 @@ void testVol2col() {
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
in_ptr
=
input
.
data
<
float
>
();
}
else
{
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
Tensor
Copy
(
input
,
paddle
::
platform
::
CPUPlace
(),
*
context
,
&
input_tmp
);
in_ptr
=
input_tmp
.
data
<
float
>
();
}
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
394828b7
...
...
@@ -51,7 +51,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
cpu_mask
->
ShareDataWith
(
mask
);
}
else
if
(
platform
::
is_gpu_place
(
mask
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
framework
::
Copy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
cpu_mask
.
get
());
framework
::
TensorCopy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
cpu_mask
.
get
());
#else
PADDLE_THROW
(
"Not supported GPU, Please compile WITH_GPU option"
);
#endif
...
...
@@ -106,8 +107,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
continue
;
}
auto
slice
=
out
->
Slice
(
out_offset
,
out_offset
+
len
);
framework
::
Copy
(
input
->
Slice
(
start_offset
,
end_offset
),
place
,
dev_ctx
,
&
slice
);
framework
::
TensorCopy
(
input
->
Slice
(
start_offset
,
end_offset
),
place
,
dev_ctx
,
&
slice
);
out_offset
+=
len
;
(
*
in_idx
)
+=
1
;
}
...
...
paddle/fluid/operators/mine_hard_examples_op.cc
浏览文件 @
394828b7
...
...
@@ -67,7 +67,8 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
auto
out_match_indices
=
ctx
.
Output
<
framework
::
Tensor
>
(
"UpdatedMatchIndices"
);
framework
::
Copy
(
*
in_matched_indices
,
ctx
.
GetPlace
(),
out_match_indices
);
framework
::
TensorCopy
(
*
in_matched_indices
,
ctx
.
GetPlace
(),
out_match_indices
);
int
batch_size
=
in_matched_indices
->
dims
()[
0
];
int
prior_num
=
in_matched_indices
->
dims
()[
1
];
...
...
paddle/fluid/operators/multiplex_op.cu
浏览文件 @
394828b7
...
...
@@ -33,7 +33,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
// copy index to cpu
Tensor
index_t_cpu
;
Copy
(
*
ids
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
index_t_cpu
);
Tensor
Copy
(
*
ids
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
index_t_cpu
);
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
platform
::
CUDAPlace
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
...
...
@@ -69,7 +69,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto
cols
=
ins
[
0
]
->
numel
()
/
rows
;
// copy index to cpu
Tensor
index_t_cpu
;
Copy
(
*
ids
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
index_t_cpu
);
Tensor
Copy
(
*
ids
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
index_t_cpu
);
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
...
...
paddle/fluid/operators/nccl_op_test.cu.cc
浏览文件 @
394828b7
...
...
@@ -98,7 +98,7 @@ class NCCLTester : public ::testing::Test {
send_tensor
->
mutable_data
<
T
>
(
kDims
,
place
);
std
::
vector
<
T
>
send_vector
(
f
::
product
(
kDims
),
gpu_id
);
paddle
::
framework
::
Copy
FromVector
<
T
>
(
send_vector
,
*
ctx
,
send_tensor
);
paddle
::
framework
::
Tensor
FromVector
<
T
>
(
send_vector
,
*
ctx
,
send_tensor
);
ctx
->
Wait
();
VLOG
(
1
)
<<
"Send Tensor filled with elements "
<<
send_tensor
->
numel
();
}
...
...
paddle/fluid/operators/parallel_do_op.cc
浏览文件 @
394828b7
...
...
@@ -78,7 +78,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst
->
GetMutable
<
LoDTensor
>
()
->
ShareDataWith
(
src
.
Get
<
LoDTensor
>
());
dst
->
GetMutable
<
LoDTensor
>
()
->
set_lod
(
src
.
Get
<
LoDTensor
>
().
lod
());
}
else
{
Copy
(
src
.
Get
<
LoDTensor
>
(),
dst_place
,
dst
->
GetMutable
<
LoDTensor
>
());
Tensor
Copy
(
src
.
Get
<
LoDTensor
>
(),
dst_place
,
dst
->
GetMutable
<
LoDTensor
>
());
}
}
else
if
(
src
.
IsType
<
SelectedRows
>
())
{
auto
&
src_sr
=
src
.
Get
<
SelectedRows
>
();
...
...
@@ -88,7 +88,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr
->
mutable_value
()
->
ShareDataWith
(
src_sr
.
value
());
dst_sr
->
set_rows
(
src_sr
.
rows
());
}
else
{
Copy
(
src_sr
.
value
(),
dst_place
,
dst_sr
->
mutable_value
());
Tensor
Copy
(
src_sr
.
value
(),
dst_place
,
dst_sr
->
mutable_value
());
}
}
else
{
PADDLE_THROW
(
"Expect LoDTensor/SelectedRows, get %s"
,
src
.
Type
().
name
());
...
...
@@ -146,7 +146,7 @@ class ParallelDoOp : public framework::OperatorBase {
auto
&
place
=
places
[
i
];
auto
*
sub_scope
=
sub_scopes
[
i
];
auto
*
dst
=
sub_scope
->
Var
(
param
)
->
GetMutable
<
LoDTensor
>
();
framework
::
Copy
(
src
,
place
,
dst
);
framework
::
Tensor
Copy
(
src
,
place
,
dst
);
}
}
WaitOnPlaces
(
places
);
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
394828b7
...
...
@@ -179,7 +179,7 @@ class TensorPrintOp : public framework::OperatorBase {
}
else
{
// copy data to cpu to print
platform
::
CPUPlace
place
;
framework
::
Copy
(
in_tensor
,
place
,
&
printed_tensor
);
framework
::
Tensor
Copy
(
in_tensor
,
place
,
&
printed_tensor
);
}
Formater
formater
;
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
394828b7
...
...
@@ -291,7 +291,7 @@ class RecurrentOp : public RecurrentBase {
auto
dst_out
=
dst_tensor
->
Slice
(
seq_offset
,
seq_offset
+
1
);
// Explicit copy output since the local RNN scope can be destroyed
// early.
framework
::
Copy
(
src_tensor
,
place
,
dev_ctx
,
&
dst_out
);
framework
::
Tensor
Copy
(
src_tensor
,
place
,
dev_ctx
,
&
dst_out
);
});
scopes
.
Next
();
...
...
@@ -378,7 +378,7 @@ class RecurrentGradOp : public RecurrentBase {
auto
*
cur_grad_var
=
cur_scope
.
Var
(
cur_grad
);
auto
cur_grad_tensor
=
cur_grad_var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
Copy
(
ex_tensor
,
place
,
dev_ctx
,
cur_grad_tensor
);
framework
::
Tensor
Copy
(
ex_tensor
,
place
,
dev_ctx
,
cur_grad_tensor
);
}
}
...
...
@@ -452,7 +452,7 @@ class RecurrentGradOp : public RecurrentBase {
}
auto
dst
=
outside
->
Slice
(
seq_offset
,
seq_offset
+
1
);
framework
::
Copy
(
inside
,
place
,
dev_ctx
,
&
dst
);
framework
::
Tensor
Copy
(
inside
,
place
,
dev_ctx
,
&
dst
);
});
VLOG
(
5
)
<<
"Link outside gradient finished "
;
...
...
@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase {
framework
::
LoDTensor
*
outside
)
{
outside
->
Resize
(
inside
.
dims
());
outside
->
mutable_data
(
place
,
inside
.
type
());
framework
::
Copy
(
inside
,
place
,
dev_ctx
,
outside
);
framework
::
Tensor
Copy
(
inside
,
place
,
dev_ctx
,
outside
);
});
VLOG
(
5
)
<<
"Link initialize state gradient finished "
;
}
...
...
paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
394828b7
...
...
@@ -170,7 +170,7 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Copy
(
x_sliced
,
out_sliced
.
place
(),
dev_ctx
,
&
out_sliced
);
framework
::
Tensor
Copy
(
x_sliced
,
out_sliced
.
place
(),
dev_ctx
,
&
out_sliced
);
out_offset
+=
len
;
return
out_offset
;
}
...
...
paddle/fluid/operators/reshape_op.h
浏览文件 @
394828b7
...
...
@@ -28,7 +28,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out_dims
=
out
->
dims
();
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
Copy
(
*
in
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
out
);
framework
::
Tensor
Copy
(
*
in
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
out
);
out
->
Resize
(
out_dims
);
}
};
...
...
@@ -42,7 +42,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
in_dims
=
d_x
->
dims
();
framework
::
Copy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
framework
::
Tensor
Copy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
d_x
->
Resize
(
in_dims
);
}
};
...
...
paddle/fluid/operators/sequence_reshape_op.h
浏览文件 @
394828b7
...
...
@@ -61,7 +61,7 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
}
}
framework
::
Copy
(
*
in
,
context
.
GetPlace
(),
out
);
framework
::
Tensor
Copy
(
*
in
,
context
.
GetPlace
(),
out
);
out
->
Resize
({
static_cast
<
int64_t
>
(
out
->
lod
()[
0
].
back
()),
out_width
});
}
};
...
...
@@ -77,7 +77,7 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
xg_tensor_ptr
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Copy
(
*
outg_tensor_ptr
,
context
.
GetPlace
(),
xg_tensor_ptr
);
framework
::
Tensor
Copy
(
*
outg_tensor_ptr
,
context
.
GetPlace
(),
xg_tensor_ptr
);
xg_tensor_ptr
->
Resize
(
x_tensor_ptr
->
dims
());
}
};
...
...
paddle/fluid/operators/sequence_slice_op.h
浏览文件 @
394828b7
...
...
@@ -66,13 +66,13 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
offset_cpu
.
mutable_data
<
T
>
(
offset
->
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
*
offset
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
offset_cpu
);
framework
::
Tensor
Copy
(
*
offset
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
offset_cpu
);
offset_data
=
offset_cpu
.
data
<
int64_t
>
();
length_cpu
.
mutable_data
<
T
>
(
length
->
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
*
length
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
length_cpu
);
framework
::
Tensor
Copy
(
*
length
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
length_cpu
);
length_data
=
length_cpu
.
data
<
int64_t
>
();
}
...
...
@@ -127,13 +127,13 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
offset_cpu
.
mutable_data
<
T
>
(
offset
->
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
*
offset
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
offset_cpu
);
framework
::
Tensor
Copy
(
*
offset
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
offset_cpu
);
offset_data
=
offset_cpu
.
data
<
int64_t
>
();
length_cpu
.
mutable_data
<
T
>
(
length
->
dims
(),
platform
::
CPUPlace
());
framework
::
Copy
(
*
length
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
length_cpu
);
framework
::
Tensor
Copy
(
*
length
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
length_cpu
);
length_data
=
length_cpu
.
data
<
int64_t
>
();
}
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
394828b7
...
...
@@ -133,7 +133,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
auto
&
dout_tensor
=
dout_var
->
Get
<
framework
::
LoDTensor
>
();
auto
height
=
dout_tensor
.
dims
()[
0
];
auto
slice
=
dx_tensor
.
Slice
(
0
,
static_cast
<
int
>
(
height
));
framework
::
Copy
(
dout_tensor
,
dout_tensor
.
place
(),
dev_ctx
,
&
slice
);
framework
::
Tensor
Copy
(
dout_tensor
,
dout_tensor
.
place
(),
dev_ctx
,
&
slice
);
if
(
dx_tensor
.
dims
()[
0
]
>
height
)
{
auto
rest_tensor
=
dx_tensor
.
Slice
(
static_cast
<
int
>
(
height
),
static_cast
<
int
>
(
dx_tensor
.
dims
()[
0
]));
...
...
paddle/fluid/operators/split_lod_tensor_op.cc
浏览文件 @
394828b7
...
...
@@ -55,7 +55,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
cpu_mask
->
ShareDataWith
(
mask
);
}
else
if
(
platform
::
is_gpu_place
(
mask
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
framework
::
Copy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
cpu_mask
.
get
());
framework
::
TensorCopy
(
mask
,
platform
::
CPUPlace
(),
dev_ctx
,
cpu_mask
.
get
());
#else
PADDLE_THROW
(
"Not supported GPU, Please compile WITH_GPU option"
);
#endif
...
...
@@ -113,9 +114,9 @@ class SplitLoDTensorOp : public framework::OperatorBase {
// out[offset: offset+len] = x[each_range.begin: each_range.end]
auto
slice
=
out
->
Slice
(
static_cast
<
int
>
(
offset
),
static_cast
<
int
>
(
offset
+
len
));
framework
::
Copy
(
x
.
Slice
(
static_cast
<
int
>
(
each_range
.
begin
),
static_cast
<
int
>
(
each_range
.
end
)),
x
.
place
(),
dev_ctx
,
&
slice
);
framework
::
Tensor
Copy
(
x
.
Slice
(
static_cast
<
int
>
(
each_range
.
begin
),
static_cast
<
int
>
(
each_range
.
end
)),
x
.
place
(),
dev_ctx
,
&
slice
);
offset
+=
len
;
}
}
...
...
paddle/fluid/operators/sum_op.h
浏览文件 @
394828b7
...
...
@@ -137,8 +137,8 @@ class SumKernel : public framework::OpKernel<T> {
out_array
.
resize
(
i
+
1
);
}
if
(
out_array
[
i
].
numel
()
==
0
)
{
framework
::
Copy
(
in_array
[
i
],
in_array
[
i
].
place
(),
context
.
device_context
(),
&
out_array
[
i
]);
framework
::
Tensor
Copy
(
in_array
[
i
],
in_array
[
i
].
place
(),
context
.
device_context
(),
&
out_array
[
i
]);
out_array
[
i
].
set_lod
(
in_array
[
i
].
lod
());
}
else
{
PADDLE_ENFORCE
(
out_array
[
i
].
lod
()
==
in_array
[
i
].
lod
());
...
...
paddle/fluid/operators/tensor_array_read_write_op.cc
浏览文件 @
394828b7
...
...
@@ -45,7 +45,7 @@ class WriteToArrayOp : public ArrayOp {
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
Copy
(
x_tensor
,
place
,
dev_ctx
,
out_tensor
);
Tensor
Copy
(
x_tensor
,
place
,
dev_ctx
,
out_tensor
);
out_tensor
->
set_lod
(
x_tensor
.
lod
());
}
else
{
VLOG
(
10
)
<<
"WARNING: The input tensor 'x_tensor' holds no memory, so "
...
...
@@ -138,7 +138,7 @@ class ReadFromArrayOp : public ArrayOp {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
Copy
(
x_array
[
offset
],
place
,
dev_ctx
,
out_tensor
);
framework
::
Tensor
Copy
(
x_array
[
offset
],
place
,
dev_ctx
,
out_tensor
);
out_tensor
->
set_lod
(
x_array
[
offset
].
lod
());
}
else
{
VLOG
(
10
)
<<
"offset "
<<
offset
<<
" >= "
<<
x_array
.
size
();
...
...
paddle/fluid/operators/warpctc_op.h
浏览文件 @
394828b7
...
...
@@ -185,7 +185,8 @@ class WarpCTCKernel : public framework::OpKernel<T> {
// warpctc accesses labels in CPU memory
Tensor
warpctc_label
;
Copy
(
*
label
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
warpctc_label
);
TensorCopy
(
*
label
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
warpctc_label
);
const
int
*
warpctc_label_data
=
warpctc_label
.
data
<
int
>
();
// warpctc stores loss in CPU memory
Tensor
warpctc_loss
;
...
...
@@ -200,7 +201,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
sequence_width
,
num_sequences
,
blank
,
warpctc_loss_data
);
// Copy the loss back
Copy
(
warpctc_loss
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
loss
);
Tensor
Copy
(
warpctc_loss
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
loss
);
}
};
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
394828b7
...
...
@@ -101,7 +101,7 @@ T TensorGetElement(framework::Tensor &self, size_t offset) {
return
self
.
data
<
T
>
()[
offset
];
}
else
{
std
::
shared_ptr
<
framework
::
Tensor
>
dst
(
new
framework
::
Tensor
);
framework
::
Copy
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
framework
::
Tensor
Copy
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
return
dst
->
data
<
T
>
()[
offset
];
}
}
...
...
@@ -111,9 +111,9 @@ template <typename T>
void
TensorSetElement
(
framework
::
Tensor
&
self
,
size_t
offset
,
T
elem
)
{
if
(
platform
::
is_gpu_place
(
self
.
place
()))
{
std
::
shared_ptr
<
framework
::
Tensor
>
dst
(
new
framework
::
Tensor
);
framework
::
Copy
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
framework
::
Tensor
Copy
(
self
,
platform
::
CPUPlace
(),
dst
.
get
());
dst
->
data
<
T
>
()[
offset
]
=
elem
;
framework
::
Copy
(
*
dst
.
get
(),
self
.
place
(),
&
self
);
framework
::
Tensor
Copy
(
*
dst
.
get
(),
self
.
place
(),
&
self
);
}
else
if
(
platform
::
is_cpu_place
(
self
.
place
()))
{
self
.
data
<
T
>
()[
offset
]
=
elem
;
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
394828b7
...
...
@@ -68,6 +68,7 @@ __all__ = [
'layer_norm'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'one_hot'
,
]
...
...
@@ -3212,3 +3213,40 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
'Out'
:
loss
},
attrs
=
{
'sigma'
:
sigma
})
return
loss
def
one_hot
(
input
,
depth
):
"""
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator.
Args:
input(Tensor/LodTensor): A Tensor/LodTensor of indices, last dimension must be 1.
depth(scalar): an interger defining the depth of the one hot dimension.
Returns:
The one-hot tensor or LodTensor, same as input.
Examples:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]]
"""
helper
=
LayerHelper
(
"one_hot"
,
**
locals
())
one_hot_out
=
helper
.
create_tmp_variable
(
dtype
=
'float32'
)
helper
.
append_op
(
type
=
"one_hot"
,
inputs
=
{
'X'
:
input
},
attrs
=
{
'depth'
:
depth
},
outputs
=
{
'Out'
:
one_hot_out
})
return
one_hot_out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录