BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit 681d908e, authored Jan 14, 2020 by FlyingQianMM, committed by GitHub on Jan 14, 2020.
add backward gradient computation for op argsort. cherry-pick #22203. test=release/1.7 (#22233)
Parent: c63a63d5
Showing 4 changed files with 287 additions and 6 deletions (+287 −6).
paddle/fluid/operators/argsort_op.cc    +50 −5
paddle/fluid/operators/argsort_op.cu    +123 −0
paddle/fluid/operators/argsort_op.h    +97 −0
python/paddle/fluid/tests/unittests/test_argsort_op.py    +17 −1
paddle/fluid/operators/argsort_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/argsort_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of ArgsortOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel {
   }
 };
 
+class ArgsortGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
 class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
 }
 };
 
+template <typename T>
+class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("argsort_grad");
+    op->SetInput("Indices", this->Output("Indices"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference,
+                                      "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
-                  paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-                  paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+                  ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
+                  ops::ArgsortGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(argsort,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, int>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    argsort_grad,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, float>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, double>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int64_t>);
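
A note on the wiring above (our gloss): ArgsortGradOpMaker only declares dataflow; the kernels do the actual work. Because the backward reads values only from Indices and Out@GRAD, the commit also marks "X" as a no-need-buffer input via DECLARE_NO_NEED_BUFFER_VARS_INFERENCE: X's buffer may be freed after the forward pass, since only its shape is needed to size X@GRAD. A hypothetical rendering (ours; the variable names are placeholders, not Paddle API) of the description Apply() builds, where "@GRAD" is the suffix GradVarName appends:

# Hypothetical sketch of the argsort_grad op description built above.
argsort_grad_desc = {
    "type": "argsort_grad",
    "inputs": {
        "Indices": ["argsort_out.Indices"],    # saved forward indices
        "X": ["argsort_in.X"],                 # shape only; buffer not kept
        "Out@GRAD": ["argsort_out.Out@GRAD"],  # incoming gradient
    },
    "outputs": {"X@GRAD": ["argsort_in.X@GRAD"]},
    "attrs": {"axis": -1, "descending": False},  # copied via SetAttrMap
}
print(argsort_grad_desc["type"])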
paddle/fluid/operators/argsort_op.cu
@@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
+template <typename T, typename IndType>
+static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
+                                IndType num_rows, IndType num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (IndType j = row_id; j < num_rows; j += gridDim.x) {
+    for (IndType i = col_id; i < num_cols; i += blockDim.x) {
+      dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
+    }
+  }
+}
+
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
@@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                  temp_storage_bytes, cudaGetErrorString(err));
 }
 
+template <typename T, typename IndType>
+void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
+                   const Tensor* indices, Tensor* dX, const IndType num_rows,
+                   const IndType num_cols) {
+  auto cu_stream = ctx.stream();
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // actually, int num_rows < max_grid_size
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
+      dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
+      num_cols);
+}
+
 template <typename T>
 class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<Tensor>("Indices");
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int axis = ctx.Attr<int>("axis");
+
+    dX->mutable_data<T>(ctx.GetPlace());
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
+
+    auto in_dims = indices->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    int64_t numel = indices->numel();
+
+    // Special case for full sort, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
+                                input_width);
+    } else {
+      // if not full sort, do transpose first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_dO;
+      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
+      Tensor trans_ind;
+      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Do transpose
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
+                                                   &trans_dO, trans);
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, *indices, &trans_ind, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+
+      ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
+                                input_height, input_width);
+
+      // transpose back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                   dX, trans);
+      return;
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
     paddle::operators::ArgsortOpCUDAKernel<int>,
     paddle::operators::ArgsortOpCUDAKernel<int64_t>,
     paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
+    paddle::operators::ArgsortGradOpCUDAKernel<double>,
+    paddle::operators::ArgsortGradOpCUDAKernel<int>,
+    paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
+    paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);
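
On the launch configuration in ArgFullAssign above: FillGrad is given one CUDA block per row, with the grid clamped to the device's maximum x-dimension (the kernel's j += gridDim.x loop picks up any leftover rows), and the block size steps up in powers of two with the row width, capped at 1024 threads. The same heuristic restated in Python, as a sketch (ours):

# Sketch of ArgFullAssign's ComputeBlockSize lambda and grid clamping.
def compute_block_size(col):
    if col > 512:
        return 1024
    if col > 256:
        return 512
    if col > 128:
        return 256
    if col > 64:
        return 128
    return 64

def launch_config(num_rows, num_cols, max_grid_dim_x=2**31 - 1):
    # max_grid_dim_x is a placeholder; the real code queries the device
    # via ctx.GetCUDAMaxGridDimSize().x.
    return (min(num_rows, max_grid_dim_x), compute_block_size(num_cols))

print(launch_config(10000, 300))  # (10000, 512)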
paddle/fluid/operators/argsort_op.h
@@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim,
     }
   }
 }
+
+template <typename T, typename Type>
+static void FullAssign(Type input_height, Type input_width, int input_dim,
+                       const framework::Tensor* input,
+                       const framework::Tensor* indices, T* t_out) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    if (input_dim == 1) {
+      auto e_input = EigenVector<T>::Flatten(*input);
+      auto e_indices = EigenVector<Type>::Flatten(*indices);
+      for (Type j = 0; j < input_width; ++j) {
+        t_out[i * input_width + e_indices(j)] = e_input(e_indices(j));
+      }
+    } else {
+      auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        t_out[i * input_width + e_indices(i, j)] =
+            e_input(i, e_indices(i, j));
+      }
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
@@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class ArgsortGradientKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<Tensor>("Indices");
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int axis = ctx.Attr<int>("axis");
+
+    auto in_dims = indices->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    dX->mutable_data<T>(ctx.GetPlace());
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
+                       .eigen_device();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
+
+    // Do full assign
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+
+      FullAssign<T, int64_t>(input_height, input_width, in_dims.size(), dO,
+                             indices, dX->data<T>());
+    } else {
+      // If not full assign do transpose
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (size_t i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_dO;
+      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
+      Tensor trans_ind;
+      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      auto& dev_ctx =
+          ctx.template device_context<platform::CPUDeviceContext>();
+      // Do transpose
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *dO,
+                                                  &trans_dO, trans);
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_ctx, *indices, &trans_ind, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+
+      FullAssign<T, int64_t>(input_height, input_width, in_dims.size(),
+                             &trans_dO, &trans_ind, t_out);
+
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                  dX, trans);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
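
Both this CPU kernel and the CUDA kernel handle a sort axis that is not the last one the same way: build a permutation that swaps the sort axis with the last axis, transpose dOut and Indices, scatter along the now-last axis, and transpose the result back. The permutation is its own inverse, which is why the same trans vector serves for the return trip. A NumPy sketch (ours) of that permutation:

import numpy as np

# Mirror of the `trans` vector built above: axes before the sort axis stay,
# the last axis takes the sort axis's slot, and the sort axis goes last.
def trans_perm(ndim, axis):
    return (list(range(axis)) + [ndim - 1] +
            list(range(axis + 1, ndim - 1)) + [axis])

perm = trans_perm(3, axis=1)               # [0, 2, 1]
x = np.zeros((2, 3, 4))
y = x.transpose(perm)                      # sort axis now last: (2, 4, 3)
assert y.transpose(perm).shape == x.shape  # the swap is its own inverse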
python/paddle/fluid/tests/unittests/test_argsort_op.py
(file mode changed: 100755 → 100644)
@@ -48,7 +48,7 @@ class TestArgsortOp(OpTest):
         self.axis = -1
 
     def init_datatype(self):
-        self.dtype = "float32"
+        self.dtype = "float64"
 
     def init_direction(self):
         self.descending = False
@@ -56,6 +56,9 @@ class TestArgsortOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestArgsortOpAxis0(TestArgsortOp):
     def init_axis(self):
@@ -146,5 +149,18 @@ class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
         self.descending = True
 
 
+class TestArgsortOpFP32Axis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+
+class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()
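
The new test_check_grad delegates to OpTest.check_grad, which compares the registered argsort_grad kernel against a numeric gradient of the forward op. A self-contained NumPy sketch (ours) of that comparison for the full-sort case, with input values spaced widely enough that the perturbation cannot change the ordering:

import numpy as np

# Analytic scatter gradient vs. central finite difference of np.sort.
rng = np.random.RandomState(0)
x = rng.permutation(np.linspace(0.0, 1.0, 10)).reshape(2, 5)
d_out = rng.rand(2, 5)
eps = 1e-4

ind = np.argsort(x, axis=-1)
d_x = np.zeros_like(x)
np.put_along_axis(d_x, ind, d_out, axis=-1)  # dX[ind[j, i]] = dOut[j, i]

num = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    num[idx] = np.sum((np.sort(xp, -1) - np.sort(xm, -1)) / (2 * eps) * d_out)

assert np.allclose(d_x, num, atol=1e-6)
print("gradient check passed")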