Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5ede6fd4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5ede6fd4
编写于
9月 18, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete cuda impl, complete comments, modify variable naming
上级
6b3ae01e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
121 addition
and
209 deletion
+121
-209
paddle/operators/transpose_op.cc
paddle/operators/transpose_op.cc
+47
-30
paddle/operators/transpose_op.cu
paddle/operators/transpose_op.cu
+5
-112
paddle/operators/transpose_op.h
paddle/operators/transpose_op.h
+31
-52
python/paddle/v2/framework/tests/test_transpose_op.py
python/paddle/v2/framework/tests/test_transpose_op.py
+38
-15
未找到文件。
paddle/operators/transpose_op.cc
浏览文件 @
5ede6fd4
...
...
@@ -13,8 +13,6 @@
limitations under the License. */
#include "paddle/operators/transpose_op.h"
#include <vector>
#include "paddle/framework/ddim.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -27,28 +25,31 @@ class TransposeOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
in_dim
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Input"
),
"Input(Input) should not be null"
);
auto
input_dim
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
axis
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
size_t
in
_dim_size
=
in
_dim
.
size
();
size_t
in
put_dim_size
=
input
_dim
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
in_dim_size
,
axis_size
,
"the input tensor dimensions should be equal to the axis size"
);
PADDLE_ENFORCE_EQ
(
input_dim_size
,
axis_size
,
"the input tensor's dimension(%d) "
"should be equal to the axis's size(%d)"
,
input_dim_size
,
axis_size
);
std
::
vector
<
int
>
axis_sorted
(
axis
);
std
::
sort
(
axis_sorted
.
begin
(),
axis_sorted
.
end
());
for
(
size_t
i
=
0
;
i
<
axis_sorted
.
size
();
i
++
)
{
PADDLE_ENFORCE_EQ
(
axis_sorted
[
i
],
(
int
)
i
,
PADDLE_ENFORCE_EQ
(
axis_sorted
[
i
],
static_cast
<
int
>
(
i
)
,
"the sorted axis should be [0, 1, ... dims - 1], "
"
the dims equals to the input tensor dimensions
"
);
"
where the dims is the axis's size
"
);
}
framework
::
DDim
out
_dim
(
in
_dim
);
framework
::
DDim
out
put_dim
(
input
_dim
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
out
_dim
[
i
]
=
in
_dim
[
axis
[
i
]];
out
put_dim
[
i
]
=
input
_dim
[
axis
[
i
]];
}
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
o
ut_dim
);
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Output"
)
->
Resize
(
outp
ut_dim
);
}
};
...
...
@@ -57,16 +58,30 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
TransposeOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of transpose op"
);
AddOutput
(
"Out"
,
"The output of transpose op"
);
AddInput
(
"Input"
,
"(Tensor)The input tensor, tensors with rank at most 7 are supported"
);
AddOutput
(
"Output"
,
"(Tensor)The output tensor"
);
AddAttr
<
std
::
vector
<
int
>>
(
"axis"
,
"a list of values, and the size of the list should be "
"
(vector<int>)
a list of values, and the size of the list should be "
"the same with the input tensor dimensions, the tensor will "
"permute the axes according the the values given"
);
AddComment
(
R"DOC(
The Tensor will be permuted according to the axis values given.
For example, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
The op is very much like the numpy.transpose function in python
For example:
>> input = numpy.arange(6).reshape((2,3))
>> input
array([[0, 1, 2],
[3, 4, 5]])
>> axis = [1, 0]
>> output = input.transpose(axis)
>> output
array([[0, 3],
[1, 4],
[2, 5]])
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
)DOC"
);
}
...
...
@@ -78,20 +93,22 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
out_grad_dims
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
dims
();
auto
out_dims
=
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
dims
();
PADDLE_ENFORCE
(
out_grad_dims
==
out_dims
,
"Out@GRAD dims must equal to Input(X) dims"
);
x_grad
->
Resize
(
x_dims
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Input"
),
"Input(Input) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Output"
)),
"Input(Output@GRAD) should not be null"
);
auto
input_dims
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
dims
();
auto
*
input_grad
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
output_grad_dims
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
))
->
dims
();
auto
output_dims
=
ctx
.
Input
<
Tensor
>
(
"Output"
)
->
dims
();
PADDLE_ENFORCE
(
output_grad_dims
==
output_dims
,
"Output@GRAD dims must equal to Input(Input) dims"
);
input_grad
->
Resize
(
input_dims
);
}
};
...
...
paddle/operators/transpose_op.cu
浏览文件 @
5ede6fd4
...
...
@@ -12,118 +12,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/transpose_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
transpose_kernel
(
int
nthreads
,
const
T
*
in_data
,
T
*
out_data
,
int
*
offset_buffer
,
int
ndims
)
{
int
*
in_offset
=
offset_buffer
;
int
*
out_offset
=
offset_buffer
+
ndims
;
int
*
axis
=
offset_buffer
+
ndims
*
2
;
int
to_index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
to_index
<
nthreads
)
{
int
from_index
=
0
;
int
temp
=
to_index
;
for
(
size_t
i
=
0
;
i
<
ndims
;
i
++
)
{
from_index
+=
(
temp
/
out_offset
[
i
])
*
in_offset
[
axis
[
i
]];
temp
=
temp
%
out_offset
[
i
];
}
out_data
[
to_index
]
=
in_data
[
from_index
];
}
}
template
<
typename
T
>
void
TransposeCUDA
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
in
,
framework
::
Tensor
&
out
,
std
::
vector
<
int
>
axis
)
{
auto
*
in_data
=
in
.
template
data
<
T
>();
auto
*
out_data
=
out
.
template
mutable_data
<
T
>(
context
.
GetPlace
());
auto
in_dim
=
in
.
dims
();
auto
out_dim
=
out
.
dims
();
auto
data_size
=
product
(
in_dim
);
size_t
ndims
=
in_dim
.
size
();
std
::
vector
<
int
>
in_offset
(
ndims
,
1
);
std
::
vector
<
int
>
out_offset
(
ndims
,
1
);
auto
cpu_place
=
platform
::
CPUPlace
();
auto
gpu_place
=
boost
::
get
<
platform
::
GPUPlace
>
(
context
.
GetPlace
());
// Get a host_buffer to cache the input offset, output offset and the axis.
std
::
vector
<
int64_t
>
buffer_dim_shape
(
1
,
ndims
*
3
);
auto
buffer_dims
=
framework
::
make_ddim
(
buffer_dim_shape
);
framework
::
Tensor
host_buffer
;
int
*
host_buffer_data
=
host_buffer
.
mutable_data
<
int
>
(
buffer_dims
,
cpu_place
);
for
(
int
i
=
ndims
-
2
;
i
>=
0
;
i
--
)
{
in_offset
[
i
]
=
in_offset
[
i
+
1
]
*
in_dim
[
i
+
1
];
out_offset
[
i
]
=
out_offset
[
i
+
1
]
*
out_dim
[
i
+
1
];
}
// copy the data to the host_buffer
for
(
int
i
=
0
;
i
<
ndims
;
i
++
)
{
host_buffer_data
[
i
]
=
in_offset
[
i
];
host_buffer_data
[
i
+
ndims
]
=
out_offset
[
i
];
host_buffer_data
[
i
+
ndims
*
2
]
=
axis
[
i
];
}
// Get a device_buffer to cache the input offset, output offset and the axis.
auto
offset_buffer
=
memory
::
Alloc
(
gpu_place
,
ndims
*
3
*
sizeof
(
int
));
auto
*
cuda_device_context
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
));
// copy the host_buffer data to the device_buffer
memory
::
Copy
(
gpu_place
,
offset_buffer
,
cpu_place
,
host_buffer_data
,
ndims
*
3
*
sizeof
(
int
),
cuda_device_context
->
stream
());
int
block
=
512
;
int
grid
=
(
data_size
+
block
-
1
)
/
block
;
transpose_kernel
<
T
><<<
grid
,
block
>>>
(
data_size
,
in_data
,
out_data
,
static_cast
<
int
*>
(
offset_buffer
),
ndims
);
memory
::
Free
(
gpu_place
,
offset_buffer
);
}
template
<
typename
T
>
class
TransposeCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"It must use GPUPlace."
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
TransposeCUDA
<
T
>
(
context
,
*
in
,
*
out
,
axis
);
}
};
template
<
typename
T
>
class
TransposeGradCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"It must use GPUPlace."
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
axis_temp
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
(
axis_temp
);
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
axis
[
axis_temp
[
i
]]
=
i
;
}
TransposeCUDA
<
T
>
(
context
,
*
in
,
*
out
,
axis
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
transpose
,
ops
::
TransposeCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/transpose_op.h
浏览文件 @
5ede6fd4
...
...
@@ -20,41 +20,10 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
void
NaiveCpuTranspose
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
in
,
framework
::
Tensor
&
out
,
std
::
vector
<
int
>
axis
)
{
auto
in_data
=
in
.
data
<
T
>
();
auto
out_data
=
out
.
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
in_dim
=
in
.
dims
();
auto
out_dim
=
out
.
dims
();
size_t
ndims
=
in_dim
.
size
();
std
::
vector
<
int
>
in_offset
(
ndims
,
1
);
std
::
vector
<
int
>
out_offset
(
ndims
,
1
);
for
(
int
i
=
ndims
-
2
;
i
>=
0
;
i
--
)
{
in_offset
[
i
]
=
in_offset
[
i
+
1
]
*
in_dim
[
i
+
1
];
out_offset
[
i
]
=
out_offset
[
i
+
1
]
*
out_dim
[
i
+
1
];
}
size_t
data_size
=
product
(
in_dim
);
for
(
size_t
to_index
=
0
;
to_index
<
data_size
;
to_index
++
)
{
int
from_index
=
0
;
int
temp
=
to_index
;
for
(
size_t
i
=
0
;
i
<
ndims
;
i
++
)
{
from_index
+=
(
temp
/
out_offset
[
i
])
*
in_offset
[
axis
[
i
]];
temp
=
temp
%
out_offset
[
i
];
}
out_data
[
to_index
]
=
in_data
[
from_index
];
}
}
template
<
typename
Place
,
typename
T
,
int
Dims
>
void
Do
Transpose
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
in
,
framework
::
Tensor
&
out
,
std
::
vector
<
int
>
axis
)
{
void
Eigen
Transpose
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
&
in
,
framework
::
Tensor
&
out
,
std
::
vector
<
int
>
axis
)
{
Eigen
::
array
<
int
,
Dims
>
permute
;
for
(
int
i
=
0
;
i
<
Dims
;
i
++
)
{
permute
[
i
]
=
axis
[
i
];
...
...
@@ -72,28 +41,32 @@ template <typename Place, typename T>
class
TransposeKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X
"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"O
ut"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
in
put
=
context
.
Input
<
framework
::
Tensor
>
(
"Input
"
);
auto
*
out
put
=
context
.
Output
<
framework
::
Tensor
>
(
"Outp
ut"
);
out
put
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
axis
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
int
ndims
=
axis
.
size
();
switch
(
ndims
)
{
case
1
:
break
;
case
2
:
DoTranspose
<
Place
,
T
,
2
>
(
context
,
*
in
,
*
o
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
break
;
case
3
:
DoTranspose
<
Place
,
T
,
3
>
(
context
,
*
in
,
*
o
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
break
;
case
4
:
DoTranspose
<
Place
,
T
,
4
>
(
context
,
*
in
,
*
o
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
break
;
case
5
:
DoTranspose
<
Place
,
T
,
5
>
(
context
,
*
in
,
*
o
ut
,
axis
);
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
break
;
default
:
NaiveCpuTranspose
<
Place
,
T
>
(
context
,
*
in
,
*
o
ut
,
axis
);
case
6
:
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
input
,
*
outp
ut
,
axis
);
break
;
default:
PADDLE_THROW
(
"Tensors with rank at most 6 are supported"
);
}
}
};
...
...
@@ -102,9 +75,11 @@ template <typename Place, typename T>
class
TransposeGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output_grad
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
*
input_grad
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
axis_temp
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"axis"
);
std
::
vector
<
int
>
axis
(
axis_temp
);
...
...
@@ -116,21 +91,25 @@ class TransposeGradKernel : public framework::OpKernel {
int
ndims
=
axis
.
size
();
switch
(
ndims
)
{
case
1
:
break
;
case
2
:
DoTranspose
<
Place
,
T
,
2
>
(
context
,
*
in
,
*
out
,
axis
);
EigenTranspose
<
Place
,
T
,
2
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
break
;
case
3
:
DoTranspose
<
Place
,
T
,
3
>
(
context
,
*
in
,
*
out
,
axis
);
EigenTranspose
<
Place
,
T
,
3
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
break
;
case
4
:
DoTranspose
<
Place
,
T
,
4
>
(
context
,
*
in
,
*
out
,
axis
);
EigenTranspose
<
Place
,
T
,
4
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
break
;
case
5
:
DoTranspose
<
Place
,
T
,
5
>
(
context
,
*
in
,
*
out
,
axis
);
EigenTranspose
<
Place
,
T
,
5
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
break
;
default
:
NaiveCpuTranspose
<
Place
,
T
>
(
context
,
*
in
,
*
out
,
axis
);
case
6
:
EigenTranspose
<
Place
,
T
,
6
>
(
context
,
*
output_grad
,
*
input_grad
,
axis
);
break
;
default:
PADDLE_THROW
(
"Tensors with rank at most 6 are supported"
);
}
}
};
...
...
python/paddle/v2/framework/tests/test_transpose_op.py
浏览文件 @
5ede6fd4
import
unittest
import
numpy
as
np
from
gradient_checker
import
GradientChecker
from
op_test_util
import
OpTestMeta
from
paddle.v2.framework.op
import
Operator
from
op_test
import
OpTest
class
TestTransposeOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
class
TestTransposeOp
(
OpTest
):
def
setUp
(
self
):
self
.
type
=
"transpose"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
3
,
4
)).
astype
(
"float32"
),
}
self
.
attrs
=
{
'axis'
:
[
1
,
0
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
transpose
((
1
,
0
))}
self
.
initTestCase
()
self
.
op_type
=
"transpose"
self
.
inputs
=
{
'Input'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
'axis'
:
list
(
self
.
axis
)}
self
.
outputs
=
{
'Output'
:
self
.
inputs
[
'Input'
].
transpose
(
self
.
axis
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'Input'
],
'Output'
)
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
4
)
self
.
axis
=
(
1
,
0
)
class
TestCase1
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
4
,
5
)
self
.
axis
=
(
0
,
2
,
1
)
class
TestCase2
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
2
,
3
,
4
,
5
)
self
.
axis
=
(
0
,
2
,
3
,
1
)
class
TestCase3
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
2
,
3
,
4
,
5
,
6
)
self
.
axis
=
(
4
,
2
,
3
,
1
,
0
)
class
TransposeGradOpTest
(
GradientChecker
):
def
test_transpose
(
self
):
op
=
Operator
(
"transpose"
,
X
=
"X"
,
Out
=
"Out"
,
axis
=
[
1
,
0
])
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
84
)).
astype
(
"float32"
),
}
self
.
check_grad
(
op
,
inputs
,
set
([
"X"
]),
"Out"
,
max_relative_error
=
0.5
)
class
TestCase4
(
TestTransposeOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
2
,
3
,
4
,
5
,
6
,
1
)
self
.
axis
=
(
4
,
2
,
3
,
1
,
0
,
5
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录