Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d6651b9b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d6651b9b
编写于
9月 08, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed bug of the gpu impl
上级
17b4b980
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
52 addition
and
19 deletion
+52
-19
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-3
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+7
-5
paddle/operators/transpose_op.cu
paddle/operators/transpose_op.cu
+17
-10
paddle/operators/transpose_op.h
paddle/operators/transpose_op.h
+0
-1
python/paddle/v2/framework/tests/test_transpose_op.py
python/paddle/v2/framework/tests/test_transpose_op.py
+27
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
d6651b9b
...
...
@@ -51,8 +51,7 @@ list(REMOVE_ITEM GENERAL_OPS
minus_op
mul_op
recurrent_op
scale_op
transpose_op
)
scale_op
)
op_library
(
net_op SRCS net_op.cc
)
op_library
(
minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op
)
...
...
@@ -60,7 +59,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op
)
op_library
(
scale_op SRCS scale_op.cc scale_op.cu DEPS net_op
)
op_library
(
transpose_op SRCS transpose_op.cc transpose_op.cu DEPS paddle_memory device_context
)
foreach
(
src
${
GENERAL_OPS
}
)
op_library
(
${
src
}
SRCS
${
src
}
.cc
${
src
}
.cu
)
...
...
paddle/operators/transpose_op.cc
浏览文件 @
d6651b9b
...
...
@@ -31,6 +31,7 @@ class TransposeOp : public framework::OperatorWithKernel {
auto
axis
=
ctx
.
GetAttr
<
std
::
vector
<
int
>>
(
"axis"
);
size_t
in_dim_size
=
in_dim
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
in_dim_size
,
axis_size
,
"the input tensor dimensions should be equal to the axis size"
);
...
...
@@ -42,7 +43,7 @@ class TransposeOp : public framework::OperatorWithKernel {
"the sorted axis should be [0, 1, ... dims - 1], "
"the dims equals to the input tensor dimensions"
);
}
//
framework
::
DDim
out_dim
(
in_dim
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
out_dim
[
i
]
=
in_dim
[
axis
[
i
]];
...
...
@@ -60,11 +61,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"The output of transpose op"
);
AddAttr
<
std
::
vector
<
int
>>
(
"axis"
,
"a list of integers, and the num of integers should be "
"the same with the input tensor dimensions"
);
"a list of values, and the size of the list should be "
"the same with the input tensor dimensions, the tensor will "
"permute the axes according the the values given"
);
AddComment
(
R"DOC(
T
ranspose the input tensor.
For example,
input tensor shape(N, C, H, W) and ax
is {0, 2, 3, 1},
T
he Tensor will be permuted according to the axis values given.
For example,
given a input tensor of shape(N, C, H, W) and the axis
is {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
)DOC"
);
}
...
...
paddle/operators/transpose_op.cu
浏览文件 @
d6651b9b
...
...
@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/transpose_op.h"
...
...
@@ -24,7 +25,7 @@ __global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
int
*
offset_buffer
,
int
ndims
)
{
int
*
in_offset
=
offset_buffer
;
int
*
out_offset
=
offset_buffer
+
ndims
;
int
*
axis
=
offset_buffer
+
ndims
;
int
*
axis
=
offset_buffer
+
ndims
*
2
;
int
to_index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -51,31 +52,37 @@ void TransposeCUDA(const framework::ExecutionContext& context,
size_t
ndims
=
in_dim
.
size
();
std
::
vector
<
int
>
in_offset
(
ndims
,
1
);
std
::
vector
<
int
>
out_offset
(
ndims
,
1
);
std
::
vector
<
int64_t
>
buffer_dim_shape
(
1
,
ndims
*
3
);
auto
cpu_place
=
platform
::
CPUPlace
();
auto
gpu_place
=
boost
::
get
<
platform
::
GPUPlace
>
(
context
.
GetPlace
());
// Get a host_buffer to cache the input offset, output offset and the axis.
std
::
vector
<
int64_t
>
buffer_dim_shape
(
1
,
ndims
*
3
);
auto
buffer_dims
=
framework
::
make_ddim
(
buffer_dim_shape
);
framework
::
Tensor
host_buffer
;
platform
::
CPUPlace
cpu_place
;
platform
::
GPUPlace
gpu_place
;
int
*
host_buffer_data
=
host_buffer
.
mutable_data
<
int
>
(
buffer_dims
,
cpu_place
);
auto
offset_buffer
=
memory
::
Alloc
(
context
.
GetPlace
(),
ndims
*
3
*
sizeof
(
int
));
for
(
int
i
=
ndims
-
2
;
i
>=
0
;
i
--
)
{
in_offset
[
i
]
=
in_offset
[
i
+
1
]
*
in_dim
[
i
+
1
];
out_offset
[
i
]
=
out_offset
[
i
+
1
]
*
out_dim
[
i
+
1
];
}
// copy the data to the host_buffer
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
host_buffer_data
[
i
]
=
in_offset
[
i
];
host_buffer_data
[
i
+
ndims
]
=
out_offset
[
i
];
host_buffer_data
[
i
+
ndims
*
2
]
=
axis
[
i
];
}
// Get a device_buffer to cache the input offset, output offset and the axis.
auto
offset_buffer
=
memory
::
Alloc
(
gpu_place
,
ndims
*
3
*
sizeof
(
int
));
auto
*
cuda_device_context
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
));
// copy the host_buffer data to the device_buffer
memory
::
Copy
(
gpu_place
,
offset_buffer
,
cpu_place
,
host_buffer_data
,
ndims
*
3
*
sizeof
(
int
));
ndims
*
3
*
sizeof
(
int
),
cuda_device_context
->
stream
());
int
block
=
512
;
int
grid
=
(
data_size
+
block
-
1
)
/
block
;
transpose_kernel
<
T
><<<
grid
,
block
>>>
(
data_size
,
in_data
,
out_data
,
...
...
paddle/operators/transpose_op.h
浏览文件 @
d6651b9b
...
...
@@ -17,7 +17,6 @@
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
...
...
python/paddle/v2/framework/tests/test_transpose_op.py
0 → 100644
浏览文件 @
d6651b9b
import
unittest
import
numpy
as
np
from
gradient_checker
import
GradientChecker
from
op_test_util
import
OpTestMeta
from
paddle.v2.framework.op
import
Operator
class
TestTransposeOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
"transpose"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
3
,
4
)).
astype
(
"float32"
),
}
self
.
attrs
=
{
'axis'
:
[
1
,
0
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
transpose
((
1
,
0
))}
class
TransposeGradOpTest
(
GradientChecker
):
def
test_transpose
(
self
):
op
=
Operator
(
"transpose"
,
X
=
"X"
,
Out
=
"Out"
,
axis
=
[
1
,
0
])
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
84
)).
astype
(
"float32"
),
}
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
]),
"Out"
,
max_relative_error
=
0.5
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录