magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 5fb1ec2a
Authored on June 12, 2020 by mindspore-ci-bot; committed via Gitee on June 12, 2020.

!2043 code review

Merge pull request !2043 from liubuyu/master

Parents: c26cb9b1, 432f3925
Showing 1 changed file with 209 additions and 202 deletions.

mindspore/ccsrc/common/trans.cc  (+209 -202)
@@ -14,11 +14,9 @@
 * limitations under the License.
 */
#include "common/trans.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include "./securec.h"
#include "common/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel.h"
@@ -29,34 +27,7 @@
namespace mindspore {
namespace trans {
-namespace {
-std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
-  std::vector<size_t> shape_4d(4, 1);
-  switch (shape.size()) {
-    case 0:
-      return shape_4d;
-    case 1:
-      shape_4d[1] = shape[0];
-      break;
-    case 2:
-      shape_4d[1] = shape[0];
-      shape_4d[2] = shape[1];
-      break;
-    case 3:
-      shape_4d[1] = shape[0];
-      shape_4d[2] = shape[1];
-      shape_4d[3] = shape[2];
-      break;
-    case 4:
-      std::copy(shape.begin(), shape.end(), shape_4d.begin());
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
-  }
-  return shape_4d;
-}
-}  // namespace
-const size_t kNchwDims = 4;
+enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},  {kNumberTypeInt, 4},   {kNumberTypeInt8, 1},
                                           {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
                                           {kNumberTypeUInt, 4},  {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
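The new kAxis enum keeps the numeric values the bare indices had, with kN = 0, kC = 1, kH = 2, kW = 3, and kNchwDims taking the value 4, so it can replace the deleted const size_t kNchwDims = 4; without changing any comparison against it. A sanity check along these lines could be added to the file (hypothetical, not in the diff):

    // Hypothetical checks, not in the commit: the enumerators preserve the old
    // literal values, so shape.size() != kNchwDims still compares against 4.
    static_assert(kN == 0 && kC == 1 && kH == 2 && kW == 3, "NCHW axis order");
    static_assert(kNchwDims == 4, "kNchwDims doubles as the NCHW rank");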
@@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
template <typename T>
T DivCeil(T n1, T n2) {
-  return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
+  if (n2 != 0) {
+    return (n1 - 1) / n2 + 1;
+  }
+  return 0;
}

enum DataTypeTransMode {
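DivCeil implements ceiling division: for n1 >= 1 and n2 >= 1, (n1 - 1) / n2 + 1 equals ceil(n1 / n2); the rewrite only trades the ternary for an early-checked if. A minimal standalone sketch of the contract (values assumed for illustration):

    // Standalone sketch of the DivCeil contract shown above.
    #include <cassert>
    #include <cstddef>

    template <typename T>
    static T DivCeil(T n1, T n2) {
      if (n2 != 0) {
        return (n1 - 1) / n2 + 1;
      }
      return 0;
    }

    int main() {
      assert(DivCeil<std::size_t>(17, 16) == 2);  // 17 elements need two 16-wide cubes
      assert(DivCeil<std::size_t>(32, 16) == 2);  // exact multiples are not rounded up
      assert(DivCeil<std::size_t>(1, 16) == 1);
      // Caveat: n1 == 0 underflows (n1 - 1) for unsigned T; callers pass n1 >= 1.
      return 0;
    }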
@@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
}

size_t ShapeSize(const std::vector<size_t> &shape) {
-  size_t product = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
-  return product;
+  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
}

size_t TypeIdSize(const TypeId data_type) {
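The ShapeSize rewrite folds the two statements into one and, more importantly, seeds std::accumulate with IntToSize(1) (a size_t) instead of the int literal 1, so the whole product is computed in size_t. A standalone sketch of the difference, with a plain cast standing in for MindSpore's IntToSize helper:

    // Sketch: the accumulator type follows the init argument's type.
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<std::size_t> shape{2, 3, 4};
      // int seed: the running product is an int (overflow-prone for big shapes).
      int p1 = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<std::size_t>());
      // size_t seed: the running product stays size_t end to end.
      std::size_t p2 =
          std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<std::size_t>());
      std::cout << p1 << " " << p2 << "\n";  // 24 24 for this small shape
      return 0;
    }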
@@ -239,57 +212,9 @@ size_t TypeIdSize(const TypeId data_type) {
  return unsupported_type_error;
}

-bool IsNeedPadding(const std::string &format, const size_t shape_size) {
-  if (shape_size == 0) {
-    return false;
-  }
-  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
-    return false;
-  } else if (shape_size < 4) {
-    return true;
-  }
-  return false;
-}
-
-std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
-  std::vector<int> shape;
-  std::vector<size_t> host_shape;
-  if (node->isa<ValueNode>()) {
-    auto value_node = node->cast<ValueNodePtr>();
-    auto node_value = value_node->value();
-    auto tensor = node_value->cast<tensor::TensorPtr>();
-    if (tensor == nullptr) {
-      MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
-    }
-    auto shape_temp = tensor->shape();
-    (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
-    if (host_shape.empty()) {
-      host_shape.push_back(1);
-    }
-  } else {
-    host_shape = AnfAlgo::GetOutputInferShape(node, index);
-  }
-  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
-    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
-  }
-  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
-  return shape;
-}
-
-std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
-  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
-    return PaddingShapeTo4dByDefault(shape);
-  }
-  std::vector<size_t> shape_4d(4, 1);
-  for (size_t index = 0; index < padding_axis.size(); index++) {
-    shape_4d[padding_axis[index]] = shape[index];
-  }
-  return shape_4d;
-}
-
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
-  if (shape.size() != 4) {
+  if (shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Host shape dims shoud be 4";
    return false;
  }
@@ -308,10 +233,10 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  }
  std::vector<size_t> device_shape;
-  device_shape.push_back(shape[0]);
-  device_shape.push_back(shape[2]);
-  device_shape.push_back(shape[3]);
-  device_shape.push_back(shape[1]);
+  device_shape.push_back(shape[kN]);
+  device_shape.push_back(shape[kH]);
+  device_shape.push_back(shape[kW]);
+  device_shape.push_back(shape[kC]);
  return device_shape;
}
@@ -320,10 +245,10 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  device_shape.push_back(shape[2]);
-  device_shape.push_back(shape[3]);
-  device_shape.push_back(shape[1]);
-  device_shape.push_back(shape[0]);
+  device_shape.push_back(shape[kH]);
+  device_shape.push_back(shape[kW]);
+  device_shape.push_back(shape[kC]);
+  device_shape.push_back(shape[kN]);
  return device_shape;
}
@@ -332,9 +257,9 @@ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-  size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-  device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
+  const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+  const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+  device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
  device_shape.push_back(cout16 / kCubeSize);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
@@ -346,12 +271,12 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
-  size_t C0 = kCubeSize;
-  device_shape.push_back(shape[0]);
+  const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
+  const size_t C0 = kCubeSize;
+  device_shape.push_back(shape[kN]);
  device_shape.push_back(C1);
-  device_shape.push_back(shape[2]);
-  device_shape.push_back(shape[3]);
+  device_shape.push_back(shape[kH]);
+  device_shape.push_back(shape[kW]);
  device_shape.push_back(C0);
  return device_shape;
}
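Worked example for Nc1hwc0DeviceShape, assuming kCubeSize is 16 as defined earlier in this file: a host NCHW shape of {2, 17, 5, 5} yields C1 = (17 + 16 - 1) / 16 = 2 and C0 = 16, i.e. a device NC1HWC0 shape of {2, 2, 5, 5, 16}, with the 17 channels padded up to 32. As a standalone sketch:

    // Sketch of the NC1HWC0 computation above (assumes kCubeSize == 16).
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      const std::size_t kCube = 16;                           // stand-in for kCubeSize
      std::vector<std::size_t> shape{2, 17, 5, 5};            // N, C, H, W
      const std::size_t C1 = (shape[1] + kCube - 1) / kCube;  // == 2
      std::vector<std::size_t> dev{shape[0], C1, shape[2], shape[3], kCube};
      assert((dev == std::vector<std::size_t>{2, 2, 5, 5, 16}));
      return 0;
    }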
@@ -361,10 +286,10 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
-  device_shape.push_back(shape[2]);
-  device_shape.push_back(shape[3]);
-  device_shape.push_back(shape[0]);
+  device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
+  device_shape.push_back(shape[kH]);
+  device_shape.push_back(shape[kW]);
+  device_shape.push_back(shape[kN]);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(kCubeSize);
  return device_shape;
@@ -375,9 +300,9 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  size_t c0 = 4;
-  auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
-  auto no = DivCeil(shape.at(0), kCubeSize);
+  const size_t c0 = 4;
+  auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
+  auto no = DivCeil(shape.at(kN), kCubeSize);
  device_shape.push_back(first_dim);
  device_shape.push_back(no);
  device_shape.push_back(kCubeSize);
@@ -390,24 +315,101 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  std::vector<size_t> device_shape;
-  size_t C1 = 1;
-  size_t C0 = 4;
-  device_shape.push_back(shape[0]);
+  const size_t C1 = 1;
+  const size_t C0 = 4;
+  device_shape.push_back(shape[kN]);
  device_shape.push_back(C1);
-  device_shape.push_back(shape[2]);
-  device_shape.push_back(shape[3]);
+  device_shape.push_back(shape[kH]);
+  device_shape.push_back(shape[kW]);
  device_shape.push_back(C0);
  return device_shape;
}

std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
-  if (shape.size() < 5) {
+  if (shape.size() < kNdhwc) {
    MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  }
  return shape;
}

+std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
+  std::vector<size_t> shape_4d(kNchwDims, 1);
+  switch (shape.size()) {
+    case 0:
+      return shape_4d;
+    case 1:
+      shape_4d[kC] = shape[kN];
+      break;
+    case 2:
+      shape_4d[kC] = shape[kN];
+      shape_4d[kH] = shape[kC];
+      break;
+    case 3:
+      shape_4d[kC] = shape[kN];
+      shape_4d[kH] = shape[kC];
+      shape_4d[kW] = shape[kH];
+      break;
+    case 4:
+      std::copy(shape.begin(), shape.end(), shape_4d.begin());
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
+  }
+  return shape_4d;
+}
}  // namespace

+bool IsNeedPadding(const std::string &format, const size_t shape_size) {
+  if (shape_size == 0) {
+    return false;
+  }
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+    return false;
+  } else if (shape_size < kNchwDims) {
+    return true;
+  }
+  return false;
+}
+
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<int> shape;
+  std::vector<size_t> host_shape;
+  if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto node_value = value_node->value();
+    MS_EXCEPTION_IF_NULL(node_value);
+    auto tensor = node_value->cast<tensor::TensorPtr>();
+    if (tensor == nullptr) {
+      MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
+    }
+    auto shape_temp = tensor->shape();
+    (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
+    if (host_shape.empty()) {
+      host_shape.push_back(1);
+    }
+  } else {
+    host_shape = AnfAlgo::GetOutputInferShape(node, index);
+  }
+  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
+    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
+  }
+  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
+  return shape;
+}
+
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
+  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+    return PaddingShapeTo4dByDefault(shape);
+  }
+  std::vector<size_t> shape_4d(kNchwDims, 1);
+  for (size_t index = 0; index < padding_axis.size(); index++) {
+    shape_4d[padding_axis[index]] = shape[index];
+  }
+  return shape_4d;
+}
+
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
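The relocated PaddingShapeTo4dByDefault keeps its old behavior, now spelled with the enum: missing dimensions default to 1, and a rank-1, -2, or -3 input is written into the kC, kH, kW slots in order, leaving N as 1. For example, {5, 7} pads to {1, 5, 7, 1} and {3, 5, 7} pads to {1, 3, 5, 7}. Note the slightly confusing pairs such as shape_4d[kC] = shape[kN]: on the right-hand side the enumerators serve only as the ordinals 0, 1, 2 into the shorter input shape.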
@@ -439,7 +441,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
    device_shape.push_back(kCubeSize);
    return device_shape;
  }
-  if (shape.size() != 4) {
+  if (shape.size() != kNchwDims) {
    MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
    temp_shape = PaddingShapeTo4dByDefault(shape);
  }
@@ -455,6 +457,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
+  MS_EXCEPTION_IF_NULL(size);
+  MS_EXCEPTION_IF_NULL(total_size);
  *size = TypeIdSize(args.src_data_type);
  if (*size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
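The two added MS_EXCEPTION_IF_NULL guards validate the size and total_size out-parameters before the very next statement writes through them; previously a null pointer from the caller would have been dereferenced unchecked.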
@@ -540,10 +544,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
-  size_t n = args.host_shape[0];
-  size_t c = args.host_shape[1];
-  size_t h = args.host_shape[2];
-  size_t w = args.host_shape[3];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
@@ -572,10 +576,10 @@ bool ToNchw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
-  size_t n = args.host_shape[0];
-  size_t c = args.host_shape[1];
-  size_t h = args.host_shape[2];
-  size_t w = args.host_shape[3];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
@@ -602,32 +606,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
-  size_t size = TypeIdSize(args.src_data_type);
+  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
-  size_t c0 = CubeSizeByType(args.src_data_type);
+  auto c0 = CubeSizeByType(args.src_data_type);
  if (c0 < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
-  size_t c1 = DivCeil(c, c0);
-  size_t hw = h * w;
-  size_t chw = c * hw;
-  size_t hwc0 = hw * c0;
-  size_t nchw = n * chw;
-  size_t hf_cnt = DivCeil(n, kCubeSize);
-  size_t vf_cnt = c1 * hw;
-  size_t fractal_ele_cnt = c0 * kCubeSize;
-  size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
-  size_t dst_size = total_ele_cnt * size;
+  auto c1 = DivCeil(c, c0);
+  auto hw = h * w;
+  auto chw = c * hw;
+  auto hwc0 = hw * c0;
+  auto nchw = n * chw;
+  auto hf_cnt = DivCeil(n, kCubeSize);
+  auto vf_cnt = c1 * hw;
+  auto fractal_ele_cnt = c0 * kCubeSize;
+  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
+  auto dst_size = total_ele_cnt * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size."
                  << "dst size is :" << dst_size << "device size is :" << args.device_size;
@@ -647,7 +651,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
        auto src_ni = hfi * kCubeSize + col;
        auto src_idx = src_row_offset + chw * col;
        auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
-        auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false;
+        auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
        SetData(size, pad_zero, src_idx, dst_idx, args, result);
      }
    }
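This and the later pad_zero hunks are pure simplifications: (cond) ? true : false collapses to cond, and (cond) ? false : true to its negation, since the comparison already yields a bool.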
@@ -663,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
-  size_t size = TypeIdSize(args.src_data_type);
+  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
-  size_t total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
@@ -677,18 +681,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
  auto n0 = args.device_shape.at(1);
  auto ni = args.device_shape.at(2);
  auto c0 = args.device_shape.at(3);
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
-  size_t nc = ni * n0;
-  size_t ncc0 = nc * c0;
-  size_t wncc0 = w * ncc0;
-  size_t hwncc0 = h * wncc0;
-  size_t hw = h * w;
-  size_t chw = c * hw;
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
+  auto nc = ni * n0;
+  auto ncc0 = nc * c0;
+  auto wncc0 = w * ncc0;
+  auto hwncc0 = h * wncc0;
+  auto hw = h * w;
+  auto chw = c * hw;
  for (size_t n_idx = 0; n_idx < n; n_idx++) {
    size_t n_head_addr = n_idx * chw;
@@ -720,20 +722,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
-  size_t cube = kCubeSize;
-  size_t n = args.host_shape[0];
-  size_t c = args.host_shape[1];
-  size_t h = args.host_shape[2];
-  size_t w = args.host_shape[3];
-  size_t c0 = 4;
-  size_t c1 = DivCeil(c, c0);
-  size_t hwc0 = h * w * c0;
-  size_t hwc = h * w * c;
-  size_t nhwc = n * h * w * c;
-  size_t n_cnt = DivCeil(n, cube);
-  size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
+  auto cube = kCubeSize;
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
+  const size_t c0 = 4;
+  auto c1 = DivCeil(c, c0);
+  auto hwc0 = h * w * c0;
+  auto hwc = h * w * c;
+  auto nhwc = n * h * w * c;
+  auto n_cnt = DivCeil(n, cube);
+  auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  size_t dst_idx = 0;
  for (size_t vi = 0; vi < v_cnt; vi++) {
@@ -929,7 +929,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
-  size_t size = TypeIdSize(args.src_data_type);
+  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
@@ -940,20 +940,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
    return false;
  }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
-  size_t c0 = CubeSizeByType(args.src_data_type);
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
+  auto c0 = CubeSizeByType(args.src_data_type);
  if (c0 < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
-  size_t c1 = DivCeil(c, c0);
-  size_t hw = h * w;
-  size_t chw = c * hw;
-  size_t c1hwc0 = c1 * hw * c0;
-  size_t wc0 = w * c0;
+  if (args.device_format == kOpFormat_NC1HWC0_C04) {
+    c0 = 4;
+  }
+  auto c1 = DivCeil(c, c0);
+  auto hw = h * w;
+  auto chw = c * hw;
+  auto c1hwc0 = c1 * hw * c0;
+  auto wc0 = w * c0;
  for (size_t n_idx = 0; n_idx < n; n_idx++) {
    size_t n_head_addr = n_idx * c1hwc0;
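Note the ordering introduced above: the new kOpFormat_NC1HWC0_C04 branch forces c0 to 4 before c1 = DivCeil(c, c0) is evaluated, so C04 layouts tile channels in groups of four; e.g. c = 6 gives c1 = DivCeil(6, 4) = 2, where the 16-wide cube would have given 1.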
@@ -967,7 +970,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
          size_t dst_idx = c0_idx + w_head_addr;
          size_t c_idx = c0_idx + c1_idx * c0;
          size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
-          auto pad_zero = (c_idx < c) ? false : true;
+          auto pad_zero = c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
@@ -984,29 +987,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
-  size_t size = TypeIdSize(args.src_data_type);
+  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
-  size_t total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
  auto c1 = args.device_shape[1];
  auto c0 = args.device_shape[4];
-  size_t hw = h * w;
-  size_t chw = c * hw;
-  size_t wc0 = w * c0;
-  size_t hwc0 = h * wc0;
-  size_t c1hwc0 = c1 * hwc0;
+  auto hw = h * w;
+  auto chw = c * hw;
+  auto wc0 = w * c0;
+  auto hwc0 = h * wc0;
+  auto c1hwc0 = c1 * hwc0;
  for (size_t n_idx = 0; n_idx < n; n_idx++) {
    size_t n_head_addr = n_idx * chw;
@@ -1037,13 +1040,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
+  const int co_idx = 4;
+  const int c0_idx = 5;
  auto c1 = args.device_shape[0];
-  auto co = args.device_shape[4];
-  auto c0 = args.device_shape[5];
+  auto co = args.device_shape[co_idx];
+  auto c0 = args.device_shape[c0_idx];
  for (size_t c1_i = 0; c1_i < c1; c1_i++) {
    for (size_t h_i = 0; h_i < h; h_i++) {
@@ -1055,7 +1060,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
                             co_i * c0 + c0_i;
          size_t c_i = c0_i + c1_i * c0;
          size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
-          auto pad_zero = (c_i < c && c0_i == co_i) ? false : true;
+          auto pad_zero = !(c_i < c && c0_i == co_i);
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
@@ -1076,12 +1081,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
-  auto n = args.host_shape[0];
-  auto c = args.host_shape[1];
-  auto h = args.host_shape[2];
-  auto w = args.host_shape[3];
-  auto co = args.device_shape[4];
-  auto c0 = args.device_shape[5];
+  auto n = args.host_shape[kN];
+  auto c = args.host_shape[kC];
+  auto h = args.host_shape[kH];
+  auto w = args.host_shape[kW];
+  const int co_idx = 4;
+  const int c0_idx = 5;
+  auto co = args.device_shape[co_idx];
+  auto c0 = args.device_shape[c0_idx];
  for (size_t n_i = 0; n_i < n; n_i++) {
    for (size_t c_i = 0; c_i < c; c_i++) {
      for (size_t h_i = 0; h_i < h; h_i++) {