Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
68b5e5bf
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
68b5e5bf
编写于
9月 21, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use stridecpy instead of CUDA kernel
上级
ce709b75
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
45 addition
and
179 deletion
+45
-179
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+1
-49
paddle/operators/crop_op.cu
paddle/operators/crop_op.cu
+1
-120
paddle/operators/crop_op.h
paddle/operators/crop_op.h
+43
-10
未找到文件。
paddle/operators/crop_op.cc
浏览文件 @
68b5e5bf
...
...
@@ -128,59 +128,11 @@ class CropOpGrad : public framework::OperatorWithKernel {
}
};
int64_t
transIndex
(
std
::
vector
<
int64_t
>
out_shape
,
std
::
vector
<
int64_t
>
x_shape
,
std
::
vector
<
std
::
pair
<
int
,
int
>>
crop_rules
,
size_t
index
)
{
int64_t
dim_size
=
out_shape
.
size
();
std
::
vector
<
int64_t
>
pos
(
dim_size
);
for
(
int64_t
i
=
out_shape
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
pos
[
i
]
=
(
index
%
out_shape
[
i
])
+
crop_rules
[
i
].
first
;
index
=
index
/
out_shape
[
i
];
}
size_t
result
=
pos
[
0
];
for
(
size_t
i
=
1
;
i
<
x_shape
.
size
();
++
i
)
{
result
=
result
*
x_shape
[
i
]
+
pos
[
i
];
}
return
result
;
}
template
<
typename
T
>
class
CropCPUKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_dims
=
x
->
dims
();
auto
out_dims
=
out
->
dims
();
int64_t
out_count
=
out
->
numel
();
std
::
vector
<
int64_t
>
x_shape
=
framework
::
vectorize
(
x_dims
);
std
::
vector
<
int64_t
>
out_shape
=
framework
::
vectorize
(
out_dims
);
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
std
::
vector
<
std
::
pair
<
int
,
int
>>
crop_rules
(
x_dims
.
size
());
for
(
size_t
i
=
0
;
i
<
crop_rules
.
size
();
++
i
)
{
crop_rules
[
i
].
first
=
offsets
[
i
];
crop_rules
[
i
].
second
=
x_dims
[
i
]
-
out_dims
[
i
]
-
offsets
[
i
];
}
for
(
int64_t
i
=
0
;
i
<
out_count
;
++
i
)
{
out_data
[
i
]
=
x_data
[
transIndex
(
out_shape
,
x_shape
,
crop_rules
,
i
)];
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
crop
,
ops
::
CropOp
,
ops
::
CropOpMaker
,
crop_grad
,
ops
::
CropOpGrad
);
REGISTER_OP_CPU_KERNEL
(
crop
,
ops
::
Crop
CPU
Kernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
crop
,
ops
::
CropKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
crop_grad
,
ops
::
CropGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/crop_op.cu
浏览文件 @
68b5e5bf
...
...
@@ -13,128 +13,9 @@
limitations under the License. */
#define EIGEN_USE_GPU
#include <stdio.h>
#include "paddle/operators/crop_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
LoDTensor
;
using
framework
::
Tensor
;
template
<
typename
T
,
int
D
>
__global__
void
CropKernel
(
const
int
N
,
const
int64_t
*
out_shape
,
const
int64_t
*
x_shape
,
const
int
*
crop_rules
,
const
T
*
x_data
,
T
*
out_data
)
{
int64_t
pos
[
D
];
int
tmp
;
int64_t
x_index
;
for
(
int
out_index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
out_index
<
N
;
out_index
+=
blockDim
.
x
*
gridDim
.
x
)
{
tmp
=
out_index
;
for
(
int64_t
i
=
D
-
1
;
i
>=
0
;
--
i
)
{
pos
[
i
]
=
(
tmp
%
out_shape
[
i
])
+
crop_rules
[
i
*
2
];
tmp
=
tmp
/
out_shape
[
i
];
}
x_index
=
pos
[
0
];
for
(
size_t
i
=
1
;
i
<
D
;
++
i
)
{
x_index
=
x_index
*
x_shape
[
i
]
+
pos
[
i
];
}
out_data
[
out_index
]
=
x_data
[
x_index
];
}
}
template
<
typename
T
,
int
D
>
void
CropCUDAFunctoin
(
const
framework
::
ExecutionContext
&
context
)
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"It must use GPUPlace."
);
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
paddle
::
platform
::
GPUPlace
());
auto
x_dims
=
x
->
dims
();
auto
out_dims
=
out
->
dims
();
int64_t
out_count
=
out
->
numel
();
Tensor
x_shape
;
Tensor
out_shape
;
int64_t
*
x_shape_data
=
x_shape
.
mutable_data
<
int64_t
>
({
D
},
paddle
::
platform
::
CPUPlace
());
int64_t
*
out_shape_data
=
out_shape
.
mutable_data
<
int64_t
>
({
D
},
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
D
;
++
i
)
{
x_shape_data
[
i
]
=
x_dims
[
i
];
out_shape_data
[
i
]
=
out_dims
[
i
];
}
Tensor
x_shape_gpu
;
Tensor
out_shape_gpu
;
x_shape_gpu
.
CopyFrom
<
int64_t
>
(
x_shape
,
paddle
::
platform
::
GPUPlace
());
out_shape_gpu
.
CopyFrom
<
int64_t
>
(
out_shape
,
paddle
::
platform
::
GPUPlace
());
auto
offsets
=
context
.
op
().
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
D
,
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
Tensor
crop_rules
;
int
*
crop_rules_data
=
crop_rules
.
mutable_data
<
int
>
({
D
*
2
},
paddle
::
platform
::
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
D
;
++
i
)
{
crop_rules_data
[
i
*
2
]
=
offsets
[
i
];
crop_rules_data
[
i
*
2
+
1
]
=
x_dims
[
i
]
-
out_dims
[
i
]
-
offsets
[
i
];
}
Tensor
crop_rules_gpu
;
crop_rules_gpu
.
CopyFrom
<
int
>
(
crop_rules
,
paddle
::
platform
::
GPUPlace
());
int
n
=
out_dims
[
0
];
int
d
=
out_dims
[
1
];
int
block
=
512
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
CropKernel
<
T
,
D
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
()
>>>
(
out_count
,
out_shape_gpu
.
data
<
int64_t
>
(),
x_shape_gpu
.
data
<
int64_t
>
(),
crop_rules_gpu
.
data
<
int
>
(),
x_data
,
out_data
);
}
template
<
typename
T
>
class
CropOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
size_t
rank
=
context
.
Input
<
LoDTensor
>
(
"X"
)
->
dims
().
size
();
switch
(
rank
)
{
case
1
:
CropCUDAFunctoin
<
T
,
1
>
(
context
);
break
;
case
2
:
CropCUDAFunctoin
<
T
,
2
>
(
context
);
break
;
case
3
:
CropCUDAFunctoin
<
T
,
3
>
(
context
);
break
;
case
4
:
CropCUDAFunctoin
<
T
,
4
>
(
context
);
break
;
case
5
:
CropCUDAFunctoin
<
T
,
5
>
(
context
);
break
;
case
6
:
CropCUDAFunctoin
<
T
,
6
>
(
context
);
break
;
default:
PADDLE_THROW
(
"CropOp only support tensors with no more than 6 dimensions."
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
crop
,
ops
::
Crop
OpCUDA
Kernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
crop
,
ops
::
CropKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
crop_grad
,
ops
::
CropGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/crop_op.h
浏览文件 @
68b5e5bf
...
...
@@ -16,6 +16,7 @@
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace
paddle
{
namespace
operators
{
// Internal
...
...
@@ -24,26 +25,58 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
framework
::
LoDTensor
;
using
framework
::
Tensor
;
using
framework
::
DDim
;
// TODO(wanghaoshuang): move this function to other place
DDim
stride
(
const
DDim
&
ddim
)
{
std
::
vector
<
int64_t
>
strides
(
ddim
.
size
());
strides
[
ddim
.
size
()
-
1
]
=
1
;
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
+
1
];
}
return
make_ddim
(
strides
);
}
template
<
typename
T
>
class
CropKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_stride
=
stride
(
x
->
dims
());
auto
out_stride
=
stride
(
out
->
dims
());
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
int64_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
offsets
.
size
();
++
i
)
{
offset
+=
(
x_stride
[
i
]
*
offsets
[
i
]);
}
StridedMemcpy
<
T
>
(
context
.
device_context
(),
x_data
+
offset
,
x_stride
,
out
->
dims
(),
out_stride
,
out_data
);
}
};
template
<
typename
Place
,
typename
T
,
size_t
D
>
void
CropGradFunction
(
const
framework
::
ExecutionContext
&
context
)
{
auto
*
d_out
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
d_x
!=
nullptr
)
{
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_x_dims
=
d_x
->
dims
();
auto
d_out_dims
=
d_out
->
dims
();
auto
offsets
=
context
.
op
().
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
Eigen
::
array
<
std
::
pair
<
int
,
int
>
,
D
>
paddings
;
for
(
int
i
=
0
;
i
<
d_out_dims
.
size
()
;
++
i
)
{
for
(
int
i
=
0
;
i
<
D
;
++
i
)
{
paddings
[
i
].
first
=
offsets
[
i
];
paddings
[
i
].
second
=
d_x_dims
[
i
]
-
d_out_dims
[
i
]
-
offsets
[
i
];
}
auto
d_x_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
d_x
);
auto
d_out_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
d_out
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
d_x_tensor
.
device
(
place
)
=
d_out_tensor
.
pad
(
paddings
,
0
);
d_x_tensor
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
d_out_tensor
.
pad
(
paddings
,
0
);
}
}
...
...
@@ -52,7 +85,7 @@ class CropGradKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
size_t
rank
=
context
.
Input
<
LoD
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
dims
().
size
();
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
dims
().
size
();
switch
(
rank
)
{
case
1
:
CropGradFunction
<
Place
,
T
,
1
>
(
context
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录