Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2c29cf1e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2c29cf1e
编写于
9月 19, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use Tensor as the temp variables instead of CUDA api
上级
8d9d537b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
48 addition
and
48 deletion
+48
-48
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+23
-23
paddle/operators/crop_op.cu
paddle/operators/crop_op.cu
+25
-25
未找到文件。
paddle/operators/crop_op.cc
浏览文件 @
2c29cf1e
...
...
@@ -27,12 +27,12 @@ class CropOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
x_dim
=
ctx
.
Input
<
LoDTensor
>
(
"X"
)
->
dims
();
auto
Y
=
ctx
.
Input
<
LoDTensor
>
(
"Y"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) of CropOp should not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
"Output(Out) of CropOp should not be null."
);
auto
x_dim
=
ctx
.
Input
<
LoDTensor
>
(
"X"
)
->
dims
();
auto
Y
=
ctx
.
Input
<
LoDTensor
>
(
"Y"
);
if
(
Y
==
nullptr
)
{
auto
shape
=
Attr
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE_EQ
(
...
...
@@ -40,7 +40,7 @@ class CropOp : public framework::OperatorWithKernel {
"Shape size should be equal to dimention size of input tensor."
);
std
::
vector
<
int64_t
>
tensor_shape
(
shape
.
size
());
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
tensor_shape
[
i
]
=
(
int64_t
)
shape
[
i
]
;
tensor_shape
[
i
]
=
static_cast
<
int64_t
>
(
shape
[
i
])
;
}
ctx
.
Output
<
LoDTensor
>
(
"Out"
)
->
Resize
(
framework
::
make_ddim
(
tensor_shape
));
}
else
{
...
...
@@ -65,6 +65,15 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"The output of crop op "
"with the same dimension as X."
);
AddAttr
<
std
::
vector
<
int
>>
(
"offsets"
,
"A list<int> describing offsets to be cropped."
"The size of offsets list should be as same as "
"dimension size of input X."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"A list<int> describing the shape of output."
"The size of shape list should be as same as "
"dimension size of input X."
)
.
SetDefault
(
std
::
vector
<
int
>
());
AddComment
(
R"DOC(
Crop Operator.
Crop input into output, as specified by offsets and shape.
...
...
@@ -81,33 +90,24 @@ The input should be a k-D tensor(k > 0 and k < 7). As an example:
Given:
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
and
offsets = [0, 1]
offsets = [0, 1]
and
shape = [2, 2]
shape = [2, 2]
then we get
Out = [[1, 2],
[3, 4]]
Out = [[1, 2],
[3, 4]]
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"offsets"
,
"A list<int> describing offsets to be cropped."
"The size of offsets list should be as same as "
"dimension size of input X."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"A list<int> describing the shape of output."
"The size of shape list should be as same as "
"dimension size of input X."
)
.
SetDefault
(
std
::
vector
<
int
>
());
}
};
...
...
@@ -149,17 +149,17 @@ template <typename T>
class
CropCPUKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
LoD
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoD
Tensor
>
(
"Out"
);
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_dims
=
x
->
dims
();
auto
out_dims
=
out
->
dims
();
int64_t
out_count
=
framework
::
product
(
out_dims
);
int64_t
out_count
=
out
->
numel
(
);
std
::
vector
<
int64_t
>
x_shape
=
framework
::
vectorize
(
x_dims
);
std
::
vector
<
int64_t
>
out_shape
=
framework
::
vectorize
(
out_dims
);
auto
offsets
=
context
.
op
().
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
...
...
paddle/operators/crop_op.cu
浏览文件 @
2c29cf1e
...
...
@@ -20,6 +20,7 @@ namespace paddle {
namespace
operators
{
using
framework
::
LoDTensor
;
using
framework
::
Tensor
;
template
<
typename
T
,
int
D
>
__global__
void
CropKernel
(
const
int
N
,
const
int64_t
*
out_shape
,
...
...
@@ -54,35 +55,36 @@ void CropCUDAFunctoin(const framework::ExecutionContext& context) {
T
*
out_data
=
out
->
mutable_data
<
T
>
(
paddle
::
platform
::
GPUPlace
());
auto
x_dims
=
x
->
dims
();
auto
out_dims
=
out
->
dims
();
int64_t
out_count
=
framework
::
product
(
out_dims
);
int64_t
x_shape
[
D
];
int64_t
out_shape
[
D
];
int64_t
out_count
=
out
->
numel
();
Tensor
x_shape
;
Tensor
out_shape
;
int64_t
*
x_shape_data
=
x_shape
.
mutable_data
<
int64_t
>
({
D
},
paddle
::
platform
::
CPUPlace
());
int64_t
*
out_shape_data
=
out_shape
.
mutable_data
<
int64_t
>
({
D
},
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
D
;
++
i
)
{
x_shape
[
i
]
=
x_dims
[
i
];
out_shape
[
i
]
=
out_dims
[
i
];
x_shape
_data
[
i
]
=
x_dims
[
i
];
out_shape
_data
[
i
]
=
out_dims
[
i
];
}
int64_t
*
x_shape_gpu
;
int64_t
*
out_shape_gpu
;
cudaMalloc
((
void
**
)
&
x_shape_gpu
,
sizeof
(
int64_t
)
*
D
);
cudaMemcpy
(
x_shape_gpu
,
x_shape
,
sizeof
(
int64_t
)
*
D
,
cudaMemcpyHostToDevice
);
cudaMalloc
((
void
**
)
&
out_shape_gpu
,
sizeof
(
int64_t
)
*
D
);
cudaMemcpy
(
out_shape_gpu
,
out_shape
,
sizeof
(
int64_t
)
*
D
,
cudaMemcpyHostToDevice
);
Tensor
x_shape_gpu
;
Tensor
out_shape_gpu
;
x_shape_gpu
.
CopyFrom
<
int64_t
>
(
x_shape
,
paddle
::
platform
::
GPUPlace
());
out_shape_gpu
.
CopyFrom
<
int64_t
>
(
out_shape
,
paddle
::
platform
::
GPUPlace
());
auto
offsets
=
context
.
op
().
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
D
,
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
int
crop_rules
[
D
*
2
];
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
crop_rules
[
i
*
2
]
=
offsets
[
i
];
crop_rules
[
i
*
2
+
1
]
=
x_dims
[
i
]
-
out_dims
[
i
]
-
offsets
[
i
];
Tensor
crop_rules
;
int
*
crop_rules_data
=
crop_rules
.
mutable_data
<
int
>
({
D
*
2
},
paddle
::
platform
::
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
D
;
++
i
)
{
crop_rules_data
[
i
*
2
]
=
offsets
[
i
];
crop_rules_data
[
i
*
2
+
1
]
=
x_dims
[
i
]
-
out_dims
[
i
]
-
offsets
[
i
];
}
int
*
crop_rules_gpu
;
cudaMalloc
((
void
**
)
&
crop_rules_gpu
,
sizeof
(
int
)
*
D
*
2
);
cudaMemcpy
(
crop_rules_gpu
,
crop_rules
,
sizeof
(
int
)
*
D
*
2
,
cudaMemcpyHostToDevice
);
Tensor
crop_rules_gpu
;
crop_rules_gpu
.
CopyFrom
<
int
>
(
crop_rules
,
paddle
::
platform
::
GPUPlace
());
int
n
=
out_dims
[
0
];
int
d
=
out_dims
[
1
];
...
...
@@ -94,11 +96,9 @@ void CropCUDAFunctoin(const framework::ExecutionContext& context) {
CropKernel
<
T
,
D
><<<
grid
,
block
,
0
,
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
device_context
)
->
stream
()
>>>
(
out_count
,
out_shape_gpu
,
x_shape_gpu
,
crop_rules_gpu
,
x_data
,
out_data
);
cudaFree
(
crop_rules_gpu
);
cudaFree
(
x_shape_gpu
);
cudaFree
(
out_shape_gpu
);
->
stream
()
>>>
(
out_count
,
out_shape_gpu
.
data
<
int64_t
>
(),
x_shape_gpu
.
data
<
int64_t
>
(),
crop_rules_gpu
.
data
<
int
>
(),
x_data
,
out_data
);
}
template
<
typename
T
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录