magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 00e78bf6
Authored May 09, 2020 by wilfChen
gpu support MinimumGrad & MaximumGrad kernel
Parent: bab6e0f5

Showing 6 changed files with 729 additions and 1 deletion (+729 -1)
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu    +116  -0
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh   +38   -0
mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc   +38   -0
mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h    +149  -0
tests/st/ops/gpu/test_maximum_op.py                            +168  -1
tests/st/ops/gpu/test_minimum_op.py                            +220  -0
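The gradient contract these kernels implement follows directly from the tests below: dy is routed to whichever input won the elementwise min/max (ties go to the second input), and any dimension an input was broadcast along is reduced by summation. A minimal NumPy reference for MinimumGrad (the helper names are illustrative, not part of this commit; MaximumGrad just flips the comparison):

    import numpy as np

    def minimum_grad_reference(x1, x2, dy):
        # Route dy to the smaller operand; ties go to x2, matching the
        # kernel's else-branch. dy has the broadcast shape of (x1, x2).
        x1b, x2b = np.broadcast_arrays(x1, x2)
        dx1 = np.where(x1b < x2b, dy, 0.0)
        dx2 = np.where(x1b < x2b, 0.0, dy)

        def reduce_to(shape, g):
            # Sum away the axes this input was broadcast along.
            while g.ndim > len(shape):
                g = g.sum(axis=0)
            for axis, n in enumerate(shape):
                if n == 1:
                    g = g.sum(axis=axis, keepdims=True)
            return g

        return reduce_to(x1.shape, dx1), reduce_to(x2.shape, dx2)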
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T>
struct MinimumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
    if (x1 < x2) {
      atomicAdd(dx1, dy);
    } else {
      atomicAdd(dx2, dy);
    }
  }
};

template <typename T>
struct MaximumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
    if (x1 > x2) {
      atomicAdd(dx1, dy);
    } else {
      atomicAdd(dx2, dy);
    }
  }
};

__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }

template <typename T, typename Func>
__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                      const int &r0, const int &r1, const int &r2, const int &r3,
                                                      const int &d0, const int &d1, const int &d2, const int &d3,
                                                      const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
    int i = pos / (d1 * d2 * d3) % d0;
    int j = pos / (d2 * d3) % d1;
    int k = pos / d3 % d2;
    int l = pos % d3;
    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
    Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index);
  }
}

template <typename T>
__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                    const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
                                    enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                                    T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
                                                          dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
                                                          dx1, dx2);
  }
}

template <typename T>
void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
                   cudaStream_t stream) {
  int size = d0 * d1 * d2 * d3;
  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
                                                                    x1, x2, dy, dx1, dx2);
}

template <typename T, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1,
                                                    T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
    Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos);
  }
}

template <typename T>
__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2,
                                      const T *dy, T *dx1, T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
  }
}

template <typename T>
void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                     T *dx2, cudaStream_t stream) {
  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, x1, x2, dy, dx1, dx2);
}

template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2,
                              const float *dy, float *dx1, float *dx2, cudaStream_t stream);
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                            enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1,
                            float *dx2, cudaStream_t stream);
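BroadcastGradOperator walks the flattened output, decodes each position into 4-D coordinates (i, j, k, l), and rebuilds per-input offsets with Index(), which pins any size-1 (broadcast) dimension to offset 0. Because many output positions can then land on the same input element, the functors accumulate through atomicAdd. A NumPy sketch of the same addressing, assuming MinimumGrad routing (helper names are mine, not from this commit):

    import numpy as np

    def index(i, dim):
        # Mirrors the device helper: a broadcast (size-1) dim always maps to 0.
        return 0 if dim == 1 else i

    def broadcast_grad_reference(l, r, d, x1, x2, dy):
        # l, r, d are the 4-element shapes of x1, x2 and dy (padded with 1s).
        dx1 = np.zeros(l, np.float32).ravel()
        dx2 = np.zeros(r, np.float32).ravel()
        x1f, x2f, dyf = x1.ravel(), x2.ravel(), dy.ravel()
        for pos in range(d[0] * d[1] * d[2] * d[3]):
            i = pos // (d[1] * d[2] * d[3]) % d[0]
            j = pos // (d[2] * d[3]) % d[1]
            k = pos // d[3] % d[2]
            m = pos % d[3]
            l_idx = ((index(i, l[0]) * l[1] + index(j, l[1])) * l[2] + index(k, l[2])) * l[3] + index(m, l[3])
            r_idx = ((index(i, r[0]) * r[1] + index(j, r[1])) * r[2] + index(k, r[2])) * r[3] + index(m, r[3])
            if x1f[l_idx] < x2f[r_idx]:   # MinimumGrad routing; CUDA uses atomicAdd here
                dx1[l_idx] += dyf[pos]
            else:
                dx2[r_idx] += dyf[pos]
        return dx1.reshape(l), dx2.reshape(r)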
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
#include "device/gpu/cuda_common.h"
enum BroadcastGradOpType {
  BROADCAST_GRAD_TYPE_MAXIMUM = 0,
  BROADCAST_GRAD_TYPE_MINIMUM = 1,
  BROADCAST_GRAD_TYPE_INVALID = 0xffffffff,
};

template <typename T>
void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
                   cudaStream_t stream);

template <typename T>
void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                     T *dx2, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/math/broadcast_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(MinimumGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      BroadcastOpGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      BroadcastOpGradGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
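Both registrations take three float32 inputs (x1, x2, dy) and produce two float32 outputs (dx1, dx2). The kernels are reached through autodiff rather than called directly; a minimal sketch of that path, mirroring the pattern the new tests use (assuming a GPU device is available):

    import numpy as np
    import mindspore.context as context
    from mindspore.nn import Cell
    from mindspore.ops import operations as P
    from mindspore.ops import composite as C
    from mindspore.common.tensor import Tensor

    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    class MaxNet(Cell):
        def __init__(self):
            super(MaxNet, self).__init__()
            self.max = P.Maximum()

        def construct(self, x1, x2):
            return self.max(x1, x2)

    # Differentiating Maximum emits a MaximumGrad node, which dispatches to the
    # GPU kernel registered above (MinimumGrad works the same way via P.Minimum).
    grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
    x1 = Tensor(np.random.rand(3, 4).astype(np.float32))
    x2 = Tensor(np.random.rand(3, 4).astype(np.float32))
    dy = Tensor(np.ones((3, 4)).astype(np.float32))
    dx1, dx2 = grad(MaxNet())(x1, x2, dy)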
mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class BroadcastOpGradGpuKernel : public GpuKernel {
 public:
  BroadcastOpGradGpuKernel()
      : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1),
        output_num_(1) {}
  ~BroadcastOpGradGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *x1 = GetDeviceAddress<T>(inputs, 0);
    T *x2 = GetDeviceAddress<T>(inputs, 1);
    T *dy = GetDeviceAddress<T>(inputs, 2);
    T *dx1 = GetDeviceAddress<T>(outputs, 0);
    T *dx2 = GetDeviceAddress<T>(outputs, 1);

    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemSet Failed");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemSet Failed");
    if (need_broadcast_) {
      BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2],
                    x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1,
                    dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    need_broadcast_ = IsBroadcast(shape1, shape2);
    if (need_broadcast_ && shape1.size() > 4) {
      MS_LOG(EXCEPTION) << "Broadcast operation does not support dim greater than 4";
    }

    for (size_t i = 0; i < shape3.size(); i++) {
      dy_shape_[i] = shape3[i];
      output_num_ *= shape3[i];
    }
    int offset = shape3.size() - shape1.size();
    for (size_t i = 0; i < shape1.size(); i++) {
      x1_shape_[i + offset] = shape1[i];
      input1_num_ *= shape1[i];
    }
    offset = shape3.size() - shape2.size();
    for (size_t i = 0; i < shape2.size(); i++) {
      x2_shape_[i + offset] = shape2[i];
      input2_num_ *= shape2[i];
    }

    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override { return; }
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(T));
    input_size_list_.push_back(output_num_ * sizeof(T));
    output_size_list_.push_back(input1_num_ * sizeof(T));
    output_size_list_.push_back(input2_num_ * sizeof(T));
  }

 private:
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, BroadcastGradOpType> kBroadcastTypeMap = {
      {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM},
      {"MinimumGrad", BROADCAST_GRAD_TYPE_MINIMUM},
    };
    auto iter = kBroadcastTypeMap.find(kernel_name);
    if (iter == kBroadcastTypeMap.end()) {
      MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
    } else {
      op_type_ = iter->second;
    }
  }

  bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
    if (lhs.size() != rhs.size()) {
      return true;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (lhs[i] != rhs[i]) {
        return true;
      }
    }
    return false;
  }

  BroadcastGradOpType op_type_;
  bool need_broadcast_;
  int input1_num_;
  int input2_num_;
  int output_num_;
  int x1_shape_[4] = {1, 1, 1, 1};
  int x2_shape_[4] = {1, 1, 1, 1};
  int dy_shape_[4] = {1, 1, 1, 1};
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
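Init() right-aligns each input's shape against dy's rank before filling the 4-element shape arrays, so a lower-rank input lands in the trailing axes of the output's view; slots left at 1 are harmless because Index() maps size-1 dimensions to offset 0. The outputs are also zeroed with cudaMemsetAsync before launch, since the kernels accumulate with atomicAdd. A small sketch of the alignment, using the shapes from test_broadcast_diff_dims (the helper name is mine):

    def align_shapes(shape, out_shape):
        # Mimics Init(): fill a length-4 array with 1s, then write the input's
        # dims right-aligned against the output's rank (not against index 4).
        padded = [1, 1, 1, 1]
        offset = len(out_shape) - len(shape)
        for i, n in enumerate(shape):
            padded[i + offset] = n
        return padded

    # dy of shape (2, 2, 3) and x2 of shape (2, 3):
    assert align_shapes((2, 2, 3), (2, 2, 3)) == [2, 2, 3, 1]  # dy_shape_
    assert align_shapes((2, 3), (2, 2, 3)) == [1, 2, 3, 1]     # x2_shape_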
tests/st/ops/gpu/test_maximum_op.py
@@ -15,6 +15,7 @@
 import pytest
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.nn import Cell
 from mindspore.common.tensor import Tensor
 import mindspore.context as context
@@ -29,11 +30,20 @@ class Net(Cell):
     def construct(self, x, y):
         return self.max(x, y)

+
+class Grad(Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, x1, x2, sens):
+        gout = self.grad(self.network)(x1, x2, sens)
+        return gout
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_max():
+def test_maximum():
     x = Tensor(np.array([[1, 2, 3]]).astype(np.float32))
     y = Tensor(np.array([[2]]).astype(np.float32))
     expect = [[2, 2, 3]]
@@ -53,3 +63,160 @@ def test_max():
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')
+
+    x1_np = np.array([[[[0.659578],
+                        [0.49113268],
+                        [0.75909054],
+                        [0.71681815],
+                        [0.30421826]]],
+                      [[[0.30322495],
+                        [0.02858258],
+                        [0.06398096],
+                        [0.09519596],
+                        [0.12498625]]],
+                      [[[0.7347768],
+                        [0.166469],
+                        [0.328553],
+                        [0.54908437],
+                        [0.23673844]]]]).astype(np.float32)
+    x2_np = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
+                       [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
+                       [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]],
+                       [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32)
+    dy_np = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524],
+                        [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796],
+                        [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618],
+                        [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
+                        [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],
+                       [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
+                        [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
+                        [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
+                        [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658],
+                        [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]],
+                       [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634],
+                        [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
+                        [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
+                        [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395],
+                        [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],
+                       [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
+                        [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
+                        [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
+                        [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
+                        [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]],
+                      [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
+                        [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436],
+                        [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127],
+                        [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
+                        [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],
+                       [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033],
+                        [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369],
+                        [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
+                        [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872],
+                        [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]],
+                       [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199],
+                        [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
+                        [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
+                        [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414],
+                        [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],
+                       [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
+                        [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
+                        [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
+                        [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044],
+                        [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],
+                      [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683],
+                        [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
+                        [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658],
+                        [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
+                        [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]],
+                       [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
+                        [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
+                        [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048],
+                        [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874],
+                        [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],
+                       [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226],
+                        [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
+                        [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743],
+                        [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
+                        [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],
+                       [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
+                        [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
+                        [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
+                        [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
+                        [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
+    expect_dx1 = np.array([[[[6.6534014],
+                             [5.649811],
+                             [10.071739],
+                             [6.6798244],
+                             [3.0426278]]],
+                           [[[4.2183976],
+                             [0.8096436],
+                             [0.6019127],
+                             [0.11872514],
+                             [0.09701069]]],
+                           [[[9.573029],
+                             [0.60534775],
+                             [3.917112],
+                             [5.9021177],
+                             [2.263672]]]]).astype(np.float32)
+    expect_dx2 = np.array([[[[6.4205275, 2.941831, 5.492452, 4.3212175, 2.4262471, 0.]],
+                            [[7.991917, 2.3792431, 4.9190216, 5.2013817, 6.348791, 8.351772]],
+                            [[5.518505, 8.401285, 4.691043, 6.463884, 7.504318, 7.620938]],
+                            [[5.2708025, 1.2835244, 4.1031275, 1.9843934, 4.9320035, 4.537787]]]]).astype(np.float32)
+
+    net = Grad(Net())
+    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
+    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
+    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_diff_dims():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')
+
+    x1_np = np.array([[[0.275478, 0.48933202, 0.71846116],
+                       [0.9803821, 0.57205725, 0.28511533]],
+                      [[0.61111903, 0.9671023, 0.70624334],
+                       [0.53730786, 0.90413177, 0.94349676]]]).astype(np.float32)
+    x2_np = np.array([[0.01045662, 0.82126397, 0.6365063],
+                      [0.9900942, 0.6584232, 0.98537433]]).astype(np.float32)
+    dy_np = np.array([[[0.3897645, 0.61152864, 0.33675498],
+                       [0.5303635, 0.84893036, 0.4959739]],
+                      [[0.5391046, 0.8443047, 0.4174708],
+                       [0.57513475, 0.9225578, 0.46760973]]]).astype(np.float32)
+    expect_dx1 = np.array([[[0.3897645, 0., 0.33675498],
+                            [0., 0., 0.]],
+                           [[0.5391046, 0.8443047, 0.4174708],
+                            [0., 0.9225578, 0.]]]).astype(np.float32)
+    expect_dx2 = np.array([[0., 0.61152864, 0.],
+                           [1.1054983, 0.84893036, 0.96358365]]).astype(np.float32)
+
+    net = Grad(Net())
+    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
+    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
+    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)
tests/st/ops/gpu/test_minimum_op.py  (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
import numpy as np


class MinimumNet(Cell):
    def __init__(self):
        super(MinimumNet, self).__init__()
        self.min = P.Minimum()

    def construct(self, x1, x2):
        x = self.min(x1, x2)
        return x


class Grad(Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, sens):
        gout = self.grad(self.network)(x1, x2, sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nobroadcast():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.random.rand(3, 4).astype(np.float32)
    x2_np = np.random.rand(3, 4).astype(np.float32)
    dy_np = np.random.rand(3, 4).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    output0_np = np.where(x1_np < x2_np, dy_np, 0)
    output1_np = np.where(x1_np < x2_np, 0, dy_np)
    assert np.allclose(output_ms[0].asnumpy(), output0_np)
    assert np.allclose(output_ms[1].asnumpy(), output1_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[[0.659578],
                        [0.49113268],
                        [0.75909054],
                        [0.71681815],
                        [0.30421826]]],
                      [[[0.30322495],
                        [0.02858258],
                        [0.06398096],
                        [0.09519596],
                        [0.12498625]]],
                      [[[0.7347768],
                        [0.166469],
                        [0.328553],
                        [0.54908437],
                        [0.23673844]]]]).astype(np.float32)
    x2_np = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
                       [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
                       [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]],
                       [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32)
    dy_np = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524],
                        [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796],
                        [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618],
                        [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
                        [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],
                       [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
                        [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
                        [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
                        [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658],
                        [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]],
                       [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634],
                        [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
                        [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
                        [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395],
                        [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],
                       [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
                        [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
                        [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
                        [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
                        [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]],
                      [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
                        [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436],
                        [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127],
                        [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
                        [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],
                       [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033],
                        [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369],
                        [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
                        [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872],
                        [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]],
                       [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199],
                        [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
                        [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
                        [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414],
                        [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],
                       [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
                        [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
                        [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
                        [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044],
                        [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],
                      [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683],
                        [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
                        [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658],
                        [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
                        [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]],
                       [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
                        [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
                        [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048],
                        [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874],
                        [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],
                       [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226],
                        [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
                        [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743],
                        [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
                        [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],
                       [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
                        [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
                        [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
                        [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
                        [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
    expect_dx1 = np.array([[[[5.7664223],
                             [6.981018],
                             [2.6029902],
                             [2.7598202],
                             [6.763105]]],
                           [[[10.06558],
                             [12.077246],
                             [9.338394],
                             [11.52271],
                             [8.889048]]],
                           [[[3.5789769],
                             [13.424448],
                             [8.732746],
                             [6.9677467],
                             [9.635765]]]]).astype(np.float32)
    expect_dx2 = np.array([[[[0., 4.250458, 2.5030296, 3.623167, 6.4171505, 7.2115746]],
                            [[0., 4.367449, 2.803152, 2.5352, 0., 0.]],
                            [[0.7087075, 0., 2.040332, 2.1372325, 0., 2.9222295]],
                            [[1.0278877, 5.247942, 2.6855955, 5.494814, 3.5657988, 0.66265094]]]]).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_diff_dims():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[0.275478, 0.48933202, 0.71846116],
                       [0.9803821, 0.57205725, 0.28511533]],
                      [[0.61111903, 0.9671023, 0.70624334],
                       [0.53730786, 0.90413177, 0.94349676]]]).astype(np.float32)
    x2_np = np.array([[0.01045662, 0.82126397, 0.6365063],
                      [0.9900942, 0.6584232, 0.98537433]]).astype(np.float32)
    dy_np = np.array([[[0.3897645, 0.61152864, 0.33675498],
                       [0.5303635, 0.84893036, 0.4959739]],
                      [[0.5391046, 0.8443047, 0.4174708],
                       [0.57513475, 0.9225578, 0.46760973]]]).astype(np.float32)
    expect_dx1 = np.array([[[0., 0.61152864, 0.],
                            [0.5303635, 0.84893036, 0.4959739]],
                           [[0., 0., 0.],
                            [0.57513475, 0., 0.46760973]]]).astype(np.float32)
    expect_dx2 = np.array([[0.92886907, 0.8443047, 0.7542258],
                           [0., 0.9225578, 0.]]).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)