magicwindyyd / mindspore
Commit 5d225f93
Authored Apr 08, 2020 by lianliguang

change the padding strategy & refactor insert transdata

Parent: 60958d6b
Showing 23 changed files with 425 additions and 379 deletions (+425, -379).
mindspore/ccsrc/common/trans.cc                                             +94  -38
mindspore/ccsrc/common/trans.h                                              +5   -1
mindspore/ccsrc/device/ascend/ascend_device_address.cc                      +2   -2
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc                      +3   -2
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc                       +8   -5
mindspore/ccsrc/device/kernel_adjust.cc                                     +3   -1
mindspore/ccsrc/device/kernel_info.h                                        +4   -0
mindspore/ccsrc/device/kernel_runtime.cc                                    +4   -16
mindspore/ccsrc/kernel/kernel_build_info.cc                                 +12  -18
mindspore/ccsrc/kernel/kernel_build_info.h                                  +6   -2
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc                        +63  -126
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h                         +1   -1
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc  +2   -4
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc            +99  -99
mindspore/ccsrc/session/anf_runtime_algorithm.cc                            +67  -37
mindspore/ccsrc/session/anf_runtime_algorithm.h                             +5   -0
mindspore/ccsrc/session/ascend_session.cc                                   +3   -2
mindspore/ccsrc/session/kernel_graph.cc                                     +15  -3
mindspore/ccsrc/session/session_basic.cc                                    +8   -5
mindspore/ccsrc/utils/utils.h                                               +2   -2
tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc  +2  -0
tests/ut/cpp/session/anf_runtime_algorithm_test.cc                          +14  -12
tests/ut/cpp/session/kernel_graph_test.cc                                   +3   -3
mindspore/ccsrc/common/trans.cc

@@ -20,6 +20,8 @@
 #include <utility>
 #include "./securec.h"
 #include "common/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/kernel.h"
 #include "device/convert_tensor_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -27,6 +29,33 @@
 namespace mindspore {
 namespace trans {
+namespace {
+std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
+  std::vector<size_t> shape_4d(4, 1);
+  switch (shape.size()) {
+    case 0:
+      return shape_4d;
+    case 1:
+      shape_4d[1] = shape[0];
+      break;
+    case 2:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      break;
+    case 3:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      shape_4d[3] = shape[2];
+      break;
+    case 4:
+      std::copy(shape.begin(), shape.end(), shape_4d.begin());
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
+  }
+  return shape_4d;
+}
+}  // namespace
 const size_t kNchwDims = 4;
 const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},  {kNumberTypeInt, 4},   {kNumberTypeInt8, 1},
                                            {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
@@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) {
   return unsupported_type_error;
 }

-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
-  std::vector<size_t> shape_4d(4, 1);
-  switch (shape.size()) {
-    case 0:
-      break;
-    case 1:
-      shape_4d[1] = shape[0];
-      break;
-    case 2:
-      shape_4d[0] = shape[0];
-      shape_4d[1] = shape[1];
-      break;
-    case 3:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = 3,it should has a default format";
-    case 4:
-      for (size_t i = 0; i < 4; ++i) {
-        shape_4d[i] = shape[i];
-      }
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
-  }
-  return shape_4d;
-}
+bool IsNeedPadding(const std::string &format, const size_t shape_size) {
+  if (shape_size == 0) {
+    return false;
+  }
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+    return false;
+  } else if (shape_size < 4) {
+    return true;
+  }
+  return false;
+}
+
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+  std::vector<int> shape;
+  std::vector<size_t> host_shape;
+  if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    auto node_value = value_node->value();
+    auto tensor = node_value->cast<tensor::TensorPtr>();
+    if (tensor == nullptr) {
+      MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
+    }
+    shape = tensor->shape();
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
+    if (host_shape.empty()) {
+      host_shape.push_back(1);
+    }
+  } else {
+    host_shape = AnfAlgo::GetOutputInferShape(node, index);
+  }
+  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
+    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
+  }
+  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
+  return shape;
+}
+
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
+  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+    return PaddingShapeTo4dByDefault(shape);
+  }
+  std::vector<size_t> shape_4d(4, 1);
+  for (size_t index = 0; index < padding_axis.size(); index++) {
+    shape_4d[padding_axis[index]] = shape[index];
+  }
+  return shape_4d;
+}

 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
     return shape;
   }
+  auto temp_shape = shape;
   std::vector<size_t> device_shape;
   if (format == kOpFormat_FRAC_NZ) {
     if (shape.size() < 2) {
       MS_EXCEPTION(NotSupportError) << "Format " << format << " is not support shape " << shape.size();
     }
     if (shape.size() > 2) {
       MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
     } else {
       (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
     }
     auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
@@ -197,35 +252,36 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
     return device_shape;
   }
   if (shape.size() != 4) {
-    MS_LOG(EXCEPTION) << "shape_4d size should be 4";
+    MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
+    temp_shape = PaddingShapeTo4dByDefault(shape);
   }
   if (format == kOpFormat_NC1HWC0) {
-    size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+    size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
     size_t C0 = kCubeSize;
-    device_shape.push_back(shape[0]);
+    device_shape.push_back(temp_shape[0]);
     device_shape.push_back(C1);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
     device_shape.push_back(C0);
     return device_shape;
   } else if (format == kOpFormat_FRAC_Z) {
-    size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
+    size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
     device_shape.push_back(cout16 / kCubeSize);
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
   } else if (format == kOpFormat_NHWC) {
-    device_shape.push_back(shape[0]);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
-    device_shape.push_back(shape[1]);
+    device_shape.push_back(temp_shape[0]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
+    device_shape.push_back(temp_shape[1]);
     return device_shape;
-  } else if (format == kOpFormat_NCHW) {
-    return shape;
   } else if (format == kOpFormat_HWCN) {
-    return {shape[2], shape[3], shape[1], shape[0]};
+    return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
+  } else if (format == kOpFormat_NCHW) {
+    return temp_shape;
   }
   MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
 }
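The padding strategy above can be checked end to end with a small standalone program. This is a sketch, not MindSpore source: PadTo4dByDefault mirrors PaddingShapeTo4dByDefault, Nc1hwc0Shape mirrors the NC1HWC0 branch of TransShapeToDevice, and kCubeSize = 16 is an assumption here (the real constant is defined elsewhere in trans.cc).

#include <cstddef>
#include <iostream>
#include <vector>

constexpr size_t kCubeSize = 16;

// Mirrors PaddingShapeTo4dByDefault: low-rank shapes fill the C/H/W slots,
// with N left at 1; a rank-4 shape is copied verbatim.
std::vector<size_t> PadTo4dByDefault(const std::vector<size_t> &shape) {
  std::vector<size_t> shape_4d(4, 1);
  switch (shape.size()) {
    case 0:
      break;  // scalar -> {1, 1, 1, 1}
    case 1:
      shape_4d[1] = shape[0];  // {C} -> {1, C, 1, 1}
      break;
    case 2:
      shape_4d[1] = shape[0];  // {C, H} -> {1, C, H, 1}
      shape_4d[2] = shape[1];
      break;
    case 3:
      shape_4d[1] = shape[0];  // {C, H, W} -> {1, C, H, W}
      shape_4d[2] = shape[1];
      shape_4d[3] = shape[2];
      break;
    default:
      shape_4d = shape;  // rank 4 copied; larger ranks are an error upstream
  }
  return shape_4d;
}

// Mirrors the NC1HWC0 branch: the channel axis is blocked into C1 groups of
// C0 = kCubeSize lanes, C1 = ceil(C / C0).
std::vector<size_t> Nc1hwc0Shape(const std::vector<size_t> &nchw) {
  size_t c1 = (nchw[1] + kCubeSize - 1) / kCubeSize;
  return {nchw[0], c1, nchw[2], nchw[3], kCubeSize};  // {N, C1, H, W, C0}
}

int main() {
  auto device = Nc1hwc0Shape(PadTo4dByDefault({17}));
  for (size_t d : device) {
    std::cout << d << ' ';  // prints: 1 2 1 1 16
  }
  std::cout << '\n';
  return 0;
}

A rank-1 shape {17} pads to {1, 17, 1, 1} and lands on device as {1, 2, 1, 1, 16}: C1 = ceil(17/16) = 2 blocks of C0 = 16 lanes.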
mindspore/ccsrc/common/trans.h

@@ -24,6 +24,7 @@
 #include <utility>
 #include <vector>
 #include "ir/dtype.h"
+#include "kernel/kernel.h"
 #include "ir/dtype/type.h"
 namespace mindspore {
@@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type);
 size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape);
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
+                                     const std::vector<kernel::Axis> &padding_axis = {});
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
+bool IsNeedPadding(const std::string &format, const size_t shape_size);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
 bool TransDataType(const TypeIdArgs &args, void *result);
 bool TransFormat(const FormatArgs &args, void *result);
mindspore/ccsrc/device/ascend/ascend_device_address.cc

@@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
@@ -224,7 +224,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc

@@ -27,6 +27,7 @@
 #include "utils/context/ms_context.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
+#include "common/trans.h"
 #include "runtime/context.h"
 #include "device/ascend/ascend_stream_assign.h"
 #include "device/ascend/ascend_memory_pool.h"
@@ -150,7 +151,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
     for (size_t j = 0; j < output_size; ++j) {
       auto addr = AnfAlgo::GetOutputAddr(node, j);
-      auto shape = AnfAlgo::GetOutputInferShape(node, j);
+      auto shape = trans::GetRuntimePaddingShape(node, j);
       auto type = AnfAlgo::GetOutputInferDataType(node, j);
       auto format = kOpFormat_DEFAULT;
       string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
@@ -181,7 +182,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
       continue;
     }
     auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
-    auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX);
+    auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
     auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
     auto format = kOpFormat_DEFAULT;
     string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc

@@ -184,7 +184,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
   }
   if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
     if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
-        kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
+        kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) {
       (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
     }
     (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
@@ -210,19 +210,22 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
     (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
   }
 }
 }  // namespace

 void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
     auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
     MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index)) {
+      continue;
+    }
     auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
     MS_EXCEPTION_IF_NULL(input_with_index.first);
     auto real_input_node = input_with_index.first;
     if (real_input_node->isa<CNode>()) {
       continue;
     }
     if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
       continue;
     }
     std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
       std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
     // we set special device info of a input tensor.
mindspore/ccsrc/device/kernel_adjust.cc

@@ -25,6 +25,7 @@
 #include "session/anf_runtime_algorithm.h"
 #include "utils/context/ms_context.h"
+#include "common/trans.h"
 #include "utils/config_manager.h"
 #include "common/utils.h"
 #include "kernel/kernel_build_info.h"
@@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c
     auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
     MS_EXCEPTION_IF_NULL(device_address);
     tensor->set_device_address(device_address);
-    if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                          tensor->data_c(false))) {
+    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                          tensor->data_c(false))) {
       MS_LOG(INFO) << "SyncHostToDevice failed.";
       return false;
mindspore/ccsrc/device/kernel_info.h

@@ -31,6 +31,7 @@ class KernelInfo {
  public:
   KernelInfo() {
     kernel_mod_ = nullptr;
+    is_feature_map_ = false;
     select_kernel_build_info_ = nullptr;
     output_address_list_ = {};
     workspace_address_list_ = {};
@@ -45,6 +46,7 @@ class KernelInfo {
   void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
     select_kernel_build_info_ = select_kernel_build_info;
   }
+  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
   const DeviceAddress *GetOutputAddr(size_t index) const;
   DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
   bool OutputAddrExist(size_t index) const;
@@ -63,8 +65,10 @@ class KernelInfo {
   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
   uint32_t graph_id() const { return graph_id_; }
   bool operator==(const KernelInfo &other) const;
+  bool is_feature_map() const { return is_feature_map_; }

  private:
+  bool is_feature_map_;
   kernel::KernelBuildInfoPtr select_kernel_build_info_;
   std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
   std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
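A hedged sketch of the new flag's lifecycle (a stand-in class, not the real KernelInfo): it defaults to false in the constructor, is set once via SetFeatureMapFlag during kernel selection, and is read back through is_feature_map() by the new AnfRuntimeAlgorithm::IsFeatureMapOutput further down in this commit.

#include <iostream>

class KernelInfoSketch {
 public:
  KernelInfoSketch() { is_feature_map_ = false; }  // mirrors the added constructor line
  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
  bool is_feature_map() const { return is_feature_map_; }

 private:
  bool is_feature_map_;
};

int main() {
  KernelInfoSketch info;
  info.SetFeatureMapFlag(true);  // e.g. the node's output is produced at runtime
  std::cout << std::boolalpha << info.is_feature_map() << '\n';  // true
  return 0;
}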
mindspore/ccsrc/device/kernel_runtime.cc

@@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
   std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
   auto format = AnfAlgo::GetOutputFormat(node, output_index);
   if (shape.empty() && format != kOpFormat_DEFAULT) {
-    shape = trans::TransShapeTo4d(shape);
+    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
     shape = trans::TransShapeToDevice(shape, format);
   }
   // scalar's output shape is a empty vector
@@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
   MS_EXCEPTION_IF_NULL(address);
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
-  if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) {
-    MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
+  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
+                                 tensor->data_c(false))) {
+    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                  << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                  << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
   }
@@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(node_value);
   if (node_value->isa<Tensor>()) {
     AssignValueNodeTensor(value_node, node_value, 0);
-  } else if (node_value->isa<ValueTuple>()) {
-    auto value_tuple = node_value->cast<ValueTuplePtr>();
-    if (value_tuple == nullptr) {
-      MS_LOG(WARNING) << "value_tuple is null";
-      continue;
-    }
-    size_t i = 0;
-    auto value_list = value_tuple->value();
-    for (auto value_ptr : value_list) {
-      if (value_ptr->isa<Tensor>()) {
-        AssignValueNodeTensor(value_node, value_ptr, i++);
-      }
-    }
   } else if (node_value->isa<StringImm>()) {
     auto value = GetValue<std::string>(node_value);
     size_t tensor_size = value.size();
mindspore/ccsrc/kernel/kernel_build_info.cc

@@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }

-bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
   if (input_index >= input_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << input_index << "] is exceed the number of input node size "
-                    << input_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
+                      << input_reshape_type_.size();
   }
-  (void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return input_reshape_type_[input_index];
 }

-bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
   if (output_index >= output_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << output_index << "] is exceed the number of output node dixr"
-                    << output_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
+                      << output_reshape_type_.size();
   }
-  (void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return output_reshape_type_[output_index];
 }

 std::string KernelBuildInfo::ToString() const {
@@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
 }

+bool KernelBuildInfo::IsInputDefaultPadding() const { return output_reshape_type_.empty(); }
+
+bool KernelBuildInfo::IsOutputDefaultPadding() const { return input_reshape_type_.empty(); }
+
 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->kernel_type_ = kernel_type;
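The signature change above swaps a bool-plus-out-parameter API for return-by-value with a hard failure on a bad index. The following isolated sketch contrasts the two calling conventions; Axis is a stand-in enum and ReshapeTypeTable a hypothetical holder, not the real kernel::Axis or KernelBuildInfo.

#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

enum Axis { N, C, H, W };

class ReshapeTypeTable {
 public:
  explicit ReshapeTypeTable(std::vector<std::vector<Axis>> rows) : rows_(std::move(rows)) {}

  // Old style: caller supplies an output vector and must check a bool flag.
  bool GetReshapeTypeOld(size_t index, std::vector<Axis> *out) const {
    if (index >= rows_.size()) {
      return false;  // easy for a caller to forget to test
    }
    *out = rows_[index];
    return true;
  }

  // New style: a bad index is a hard error; the good path just returns the row.
  std::vector<Axis> GetReshapeTypeNew(size_t index) const {
    if (index >= rows_.size()) {
      throw std::out_of_range("reshape type index out of range");
    }
    return rows_[index];
  }

 private:
  std::vector<std::vector<Axis>> rows_;
};

int main() {
  ReshapeTypeTable table({{N, C, H, W}});
  std::vector<Axis> old_result;
  if (table.GetReshapeTypeOld(0, &old_result)) { /* use old_result */ }
  auto new_result = table.GetReshapeTypeNew(0);  // one expression, no flag to check
  return static_cast<int>(new_result.size()) - 4;  // 0 on success
}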
mindspore/ccsrc/kernel/kernel_build_info.h

@@ -54,9 +54,13 @@ class KernelBuildInfo {
   TypeId GetOutputDeviceType(size_t output_index) const;

-  bool GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  std::vector<Axis> GetInputReshapeType(size_t input_index) const;

-  bool GetOutputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  bool IsInputDefaultPadding() const;
+
+  bool IsOutputDefaultPadding() const;
+
+  std::vector<Axis> GetOutputReshapeType(size_t input_index) const;

   std::vector<std::string> GetAllInputFormats() const;
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc

@@ -18,20 +18,21 @@
 #include <set>
 #include "common/trans.h"
 #include "common/utils.h"
+#include "utils/utils.h"
 #include "device/kernel_info.h"
 #include "kernel/oplib/oplib.h"
 #include "operator/ops.h"
 #include "session/anf_runtime_algorithm.h"
 #include "session/kernel_graph.h"
 #include "utils/context/ms_context.h"
-#include "utils/utils.h"
 namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-                                                 const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) {
+kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
+                                                  const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) {
   KernelBuildInfoBuilder builder;
   builder.SetInputsFormat({input_format});
   builder.SetOutputsFormat({output_format});
@@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
   if (need_padding) {
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-                                        {trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))},
-                                        trans_node.get());
+    // if need padding we should set the transdata node's shape to the padding shape
+    AnfAlgo::SetOutputInferTypeAndShape(
+      {AnfAlgo::GetOutputInferDataType(input, 0)},
+      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
+      trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                 const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   auto input_node = AnfAlgo::GetInputNode(node, index);
-  if (input_node->isa<ValueNode>() || input_node->isa<Parameter>()) {
+  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
+  MS_EXCEPTION_IF_NULL(node_with_index.first);
+  auto real_input = node_with_index.first;
+  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
     input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     AnfAlgo::SetNodeInput(node, input_node, index);
@@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
   std::string origin_format = kOpFormat_DEFAULT;
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
-  if (dest_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
+  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
+                  << " To DefaultFormat , index: " << index;
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
+                                 true);
   }
   return input_node;
 }
@@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
 AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   std::string output_format;
   std::vector<size_t> origin_shape;
   if (!AnfAlgo::IsRealKernel(node)) {
@@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
   }
   std::string origin_format = output_format;
   std::string dest_format = kOpFormat_DEFAULT;
-  if (output_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                     dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
+  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
+                                 false);
   }
   return node;
 }
-
-void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) {
-  MS_EXCEPTION_IF_NULL(input_format);
-  if (AnfAlgo::IsRealKernel(node)) {
-    *input_format = AnfAlgo::GetOutputFormat(node, idx);
-  } else {
-    *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
-  }
-}

 AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   std::vector<AnfNodePtr> make_tuple_inputs;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
-    bool padding_flag = false;
-    std::string output_format;
-    GetTransDataInputFormat(node, output_idx, &output_format);
+    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
-      MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
+      MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
                         << node->DebugString();
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
     std::string origin_format = output_format;
     std::string dest_format = kOpFormat_DEFAULT;
-    if (output_format == kOpFormat_C1HWNCoC0) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                       origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_input);
-      return replace_input;
-    }
-    if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      // Insert a 5to4 trans op.
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_NZ) {
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx;
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
+    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
+                                                           dest_format, kTransDataOpName, false));
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   }
 }  // namespace
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = nullptr;
+  AnfNodePtr input_node = node;
   AnfNodePtr trans_data = nullptr;
   MS_EXCEPTION_IF_NULL(node);
   if (origin_format.empty() || dest_format.empty()) {
     MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
   }
   // if insert transdata for input we need to change the input
   if (is_insert_input) {
     if (!node->isa<CNode>()) {
       MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
@@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    if (padding_flag) {
-      auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0));
-      auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape);
-      trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name);
-    } else {
-      trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    }
-  } else {
-    input_node = node;
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    if (padding_flag) {
-      auto reshape_node =
-        CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
-      trans_node = reshape_node;
-    } else {
-      trans_node = trans_data;
-    }
+  }
+  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                       op_name == kTransDataOpName);
+  if (!need_padding) {
+    // don't need padding insert transdata only
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    trans_node = trans_data;
+  } else if (is_insert_input) {
+    // if need padding & is input need insert a transdata
+    // reshape[padding shape] -> transdata[padding shape] -> node
+    auto padding_shape =
+      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
+    trans_node = trans_data;
+  } else {
+    // if need padding & is output need insert a transdata
+    // node -> transdata[padding shape] -> reshape[ori_shape]
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    auto reshape_node =
+      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    trans_node = reshape_node;
   }
   // refresh the transdata's format to ori format & dst format
   MS_EXCEPTION_IF_NULL(trans_data);
   MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
   auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
-  auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
+  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
   return trans_node;
 }
@@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
     TypeId origin_type;
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
-    if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) {
+    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
+    auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
+      if (node->isa<ValueNode>()) {
+        return true;
+      } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
+        return true;
+      }
+      return false;
+    };
+    auto real_input_node = kernel_with_index.first;
+    if (is_weight_boundary(real_input_node)) {
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
     } else {
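The two padding comments above describe where the reshape lands relative to the transdata. Below is a toy model of that decision, with plain strings standing in for graph nodes (assumed semantics for illustration, not MindSpore APIs).

#include <iostream>
#include <string>

// Mirrors the three branches of AddTransOpNodeToGraph: no padding -> lone
// transdata; padding on the input side -> reshape then transdata in front of
// the node; padding on the output side -> transdata then reshape after it.
std::string InsertTransData(const std::string &node, bool need_padding, bool is_insert_input) {
  if (!need_padding) {
    return is_insert_input ? "transdata -> " + node : node + " -> transdata";
  }
  if (is_insert_input) {
    // reshape[padding shape] -> transdata[padding shape] -> node
    return "reshape[pad] -> transdata[pad] -> " + node;
  }
  // node -> transdata[padding shape] -> reshape[ori_shape]
  return node + " -> transdata[pad] -> reshape[ori]";
}

int main() {
  std::cout << InsertTransData("conv2d", /*need_padding=*/true, /*is_insert_input=*/true) << '\n';
  std::cout << InsertTransData("conv2d", true, false) << '\n';
  std::cout << InsertTransData("conv2d", false, true) << '\n';
  return 0;
}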
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h

@@ -48,7 +48,7 @@ class KernelQuery {
 using KernelQueryPtr = std::shared_ptr<KernelQuery>;
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input);
mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc

@@ -105,10 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
   // insert trans
   if (origin_format != cur_format) {
     auto kernel_select = std::make_shared<KernelSelect>();
-    bool need_padding =
-      (cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims);
-    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format,
-                                       origin_format, kTransDataOpName, false);
+    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
+                                       kTransDataOpName, false);
     final_index = 0;
     MS_EXCEPTION_IF_NULL(final_node);
     MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc

 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include "pre_activate/ascend/ir_fusion/transdata_split.h"
 #include <set>
 #include "pre_activate/ascend/ascend_helper.h"
 #include "session/anf_runtime_algorithm.h"
 #include "debug/anf_ir_dump.h"
 namespace mindspore {
 namespace opt {
 const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
                                                                   {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
                                                                   {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
                                                                   {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};

 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   bool changed = false;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
     if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
       CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
       if (IsFormatInvaild(node)) {
         changed = DoSplit(func_graph, node);
       }
     }
   }
   return changed;
 }

 bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   auto format_pair = std::make_pair(input_format, output_format);
   return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
 }

 // transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
 bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_node = node->cast<CNodePtr>()->input(1);
   MS_EXCEPTION_IF_NULL(input_node);

   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   AnfNodePtr new_transdata_node = nullptr;
   AnfNodePtr new_transpose_node = nullptr;
   AnfNodePtr new_replace_node = nullptr;
   // if output_format=default transdata need split transdata->transpose else transpose->transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format,
-                                               kOpFormat_HWCN, kTransDataOpName, true);
+    new_transdata_node =
+      AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
     // trans hwcn to default_format
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false,
-                                               kOpFormat_HWCN, output_format, prim::kPrimTranspose->name(), false);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, prim::kPrimTranspose->name(), false);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
     new_replace_node = new_transpose_node;
   } else {
     // trans default to hwcn
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format,
-                                               kOpFormat_HWCN, prim::kPrimTranspose->name(), true);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
+                                               prim::kPrimTranspose->name(), true);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
     // trans hwcn to output_format
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false,
-                                               kOpFormat_HWCN, output_format, kTransDataOpName, false);
+    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, kTransDataOpName, false);
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(func_graph);
   if (!manager->Replace(node, new_replace_node)) {
-    MS_LOG(EXCEPTION) << "manager replace node failed";
+    MS_LOG(EXCEPTION) << "Manager replace node failed";
   }
-  MS_LOG(INFO) << "transdata node:" << cnode->DebugString() << "split success.";
+  MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
   return true;
 }
 }  // namespace opt
 }  // namespace mindspore
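The two kAttrPerm vectors can be sanity-checked in isolation. A standalone sketch, assuming the usual transpose semantics where perm[i] names the source axis placed at output position i:

#include <array>
#include <iostream>
#include <string>

// Apply a 4-element axis permutation to an axis-name string.
std::string Permute(const std::string &axes, const std::array<int, 4> &perm) {
  std::string out(4, '?');
  for (int i = 0; i < 4; ++i) {
    out[i] = axes[perm[i]];
  }
  return out;
}

int main() {
  // FRAC_Z path: transdata produces HWCN, transpose {3, 2, 0, 1} restores NCHW.
  std::cout << Permute("HWCN", {3, 2, 0, 1}) << '\n';  // prints: NCHW
  // Reverse path: transpose {2, 3, 1, 0} turns NCHW into HWCN before transdata.
  std::cout << Permute("NCHW", {2, 3, 1, 0}) << '\n';  // prints: HWCN
  return 0;
}

Under that convention the two permutations map between HWCN and NCHW exactly as the comments in DoSplit claim.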
mindspore/ccsrc/session/anf_runtime_algorithm.cc
浏览文件 @
5d225f93
...
...
@@ -289,6 +289,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
std
::
string
AnfRuntimeAlgorithm
::
GetOutputFormat
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
output_idx
>
GetOutputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"Output index:"
<<
output_idx
<<
" is out of the node output range :"
<<
GetOutputTensorNum
(
node
)
<<
" #node ["
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -298,6 +303,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
std
::
string
AnfRuntimeAlgorithm
::
GetInputFormat
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
input_idx
>
GetInputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"Input index :"
<<
input_idx
<<
" is out of the number node Input range :"
<<
GetInputTensorNum
(
node
)
<<
"#node ["
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -362,62 +372,60 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo
std
::
vector
<
size_t
>
AnfRuntimeAlgorithm
::
GetOutputDeviceShape
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
auto
format
=
GetOutputFormat
(
node
,
output_idx
);
auto
infer_shape
=
GetOutputInferShape
(
node
,
output_idx
);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if
(
format
==
kOpFormat_DEFAULT
||
format
==
kOpFormat_NC1KHKWHWC0
)
{
return
infer_shape
;
}
// scalar shape
if
(
infer_shape
.
empty
())
{
return
infer_shape
;
}
if
(
format
==
kOpFormat_FRAC_NZ
)
{
return
trans
::
TransShapeToDevice
(
infer_shape
,
format
);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if
(
trans
::
IsNeedPadding
(
format
,
infer_shape
.
size
()))
{
infer_shape
=
trans
::
PaddingShapeTo4d
(
infer_shape
,
GetOutputReshapeType
(
node
,
output_idx
));
}
// else trans infer shape to 4d and then calculate device shape
return
trans
::
TransShapeToDevice
(
trans
::
TransShapeTo4d
(
infer_shape
),
format
);
return
trans
::
TransShapeToDevice
(
infer_shape
,
format
);
}
std
::
vector
<
size_t
>
AnfRuntimeAlgorithm
::
GetInputDeviceShape
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
auto
format
=
GetInputFormat
(
node
,
input_idx
);
auto
infer_shape
=
GetPrevNodeOutputInferShape
(
node
,
input_idx
);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if
(
format
==
kOpFormat_DEFAULT
||
format
==
kOpFormat_NC1KHKWHWC0
)
{
return
infer_shape
;
}
if
(
infer_shape
.
empty
())
{
return
infer_shape
;
}
if
(
format
==
kOpFormat_FRAC_NZ
)
{
return
trans
::
TransShapeToDevice
(
infer_shape
,
format
);
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if
(
trans
::
IsNeedPadding
(
format
,
infer_shape
.
size
()))
{
infer_shape
=
trans
::
PaddingShapeTo4d
(
infer_shape
,
GetInputReshapeType
(
node
,
input_idx
));
}
// else trans infer shape to 4d and then calculate device shape
return
trans
::
TransShapeToDevice
(
trans
::
TransShapeTo4d
(
infer_shape
),
format
);
return
trans
::
TransShapeToDevice
(
infer_shape
,
format
);
}
std
::
vector
<
kernel
::
Axis
>
AnfRuntimeAlgorithm
::
GetInputReshapeType
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
input_idx
>
GetInputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"The index:"
<<
input_idx
<<
" is out of range of the node's input size : "
<<
GetInputTensorNum
(
node
)
<<
"#node["
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
std
::
vector
<
kernel
::
Axis
>
result
;
if
(
!
build_info
->
GetInputReshapeType
(
input_idx
,
&
result
))
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to get the node's[ "
<<
node
->
DebugString
()
<<
"] reshape type !"
;
if
(
build_info
->
IsInputDefaultPadding
())
{
return
{};
}
return
result
;
return
build_info
->
GetInputReshapeType
(
input_idx
)
;
}
std
::
vector
<
kernel
::
Axis
>
AnfRuntimeAlgorithm
::
GetOutputReshapeType
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
output_idx
>
GetOutputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
<<
GetOutputTensorNum
(
node
)
<<
"#node[ "
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
std
::
vector
<
kernel
::
Axis
>
result
;
if
(
!
build_info
->
GetOutputReshapeType
(
output_idx
,
&
result
))
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to get the node's[ "
<<
node
->
DebugString
()
<<
"] reshape type !"
;
if
(
build_info
->
IsOutputDefaultPadding
())
{
return
{};
}
return
result
;
return
build_info
->
GetOutputReshapeType
(
output_idx
)
;
}
TypeId
AnfRuntimeAlgorithm
::
GetOutputInferDataType
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
...
...
@@ -463,6 +471,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod
TypeId
AnfRuntimeAlgorithm
::
GetOutputDeviceDataType
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
output_idx
>
GetOutputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
<<
GetOutputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -472,6 +484,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
TypeId
AnfRuntimeAlgorithm
::
GetInputDeviceDataType
(
const
AnfNodePtr
&
node
,
size_t
input_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
if
(
input_idx
>
GetInputTensorNum
(
node
))
{
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
input_idx
<<
"] is out of range of the node's input size [ "
<<
GetInputTensorNum
(
node
)
<<
"#node [ "
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
();
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -496,11 +512,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
+                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
+  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString()
+    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
...
@@ -517,11 +537,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
+                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
+  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString()
+    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
...
@@ -530,6 +554,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
+  if (output_idx > GetOutputTensorNum(node)) {
+    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
+                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
+  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->OutputAddrExist(output_idx);
...
@@ -769,22 +797,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index)
  return node->input(get_input_index);
}
+bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->isa<ValueNode>()) {
+    return false;
+  }
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  return kernel_info->is_feature_map();
+}
bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  if (!node->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature";
+    MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_node = cnode->input(input_index + 1);
-  auto node_with_index = VisitKernel(input_node, 0);
-  MS_EXCEPTION_IF_NULL(node_with_index.first);
-  if (node_with_index.first->isa<ValueNode>()) {
-    return false;
-  }
-  if (node_with_index.first->isa<Parameter>()) {
-    return !AnfAlgo::IsParameterWeight(node_with_index.first->cast<ParameterPtr>());
-  }
-  return true;
+  return IsFeatureMapOutput(input_node);
}
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
...
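A point worth noting in this hunk: IsFeatureMapInput no longer classifies the input by walking back through the graph with VisitKernel; it simply asks the node that produces the input whether its output is a feature map, and that answer is read from a flag cached on the producer's KernelInfo (set when the graph is built, see the kernel_graph.cc changes below). A rough sketch of the simplified flow, using hypothetical stand-in types rather than the real AnfNode/KernelInfo:

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-ins, only to show the new control flow.
struct KernelInfo {
  bool is_feature_map = false;
};

struct Node {
  bool is_value_node = false;
  std::shared_ptr<KernelInfo> kernel_info = std::make_shared<KernelInfo>();
  std::vector<std::shared_ptr<Node>> inputs;
};

// Value nodes are never feature maps; everything else reads the cached flag.
bool IsFeatureMapOutput(const std::shared_ptr<Node> &node) {
  if (node->is_value_node) {
    return false;
  }
  return node->kernel_info->is_feature_map;
}

// The input-side query now just delegates to the producer of that input.
bool IsFeatureMapInput(const std::shared_ptr<Node> &node, size_t input_index) {
  return IsFeatureMapOutput(node->inputs.at(input_index));
}

int main() {
  auto weight = std::make_shared<Node>();    // flag stays false, like a weight
  auto conv = std::make_shared<Node>();
  conv->kernel_info->is_feature_map = true;  // flag set at graph-build time
  auto node = std::make_shared<Node>();
  node->inputs = {conv, weight};
  std::cout << IsFeatureMapInput(node, 0) << IsFeatureMapInput(node, 1) << '\n';  // prints 10
}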
mindspore/ccsrc/session/anf_runtime_algorithm.h View file @ 5d225f93
...
@@ -101,7 +101,9 @@ class AnfRuntimeAlgorithm {
  static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
  // get input shapes which will built and run in device
  static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
+  // Get Input Padding Axis
  static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
+  // Get Output Padding Axis
  static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
  // get output data type inferred by ME of anf node
  static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
...
@@ -165,6 +167,9 @@ class AnfRuntimeAlgorithm {
  // get graph id
  static uint32_t GetGraphId(const AnfNode *node);
  static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
+  // charge if the node's output is a feature map output
+  static bool IsFeatureMapOutput(const AnfNodePtr &node);
+  // charge if the node's input is from a feature map output
  static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
  // get real input index for some tbe ops which input order is different between me and tbe impl
  static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
...
mindspore/ccsrc/session/ascend_session.cc View file @ 5d225f93
...
@@ -18,6 +18,7 @@
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
+#include "common/trans.h"
#include "device/kernel_runtime.h"
#include "device/ascend/kernel_select_ascend.h"
#include "device/ascend/kernel_build_ascend.h"
...
@@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor
  size_t tensor_size = front_tensor->data().nbytes();
  auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0);
  MS_EXCEPTION_IF_NULL(addr);
-  if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(),
-                              front_tensor->data_c(false))) {
+  if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size,
+                              front_tensor->data_type(), front_tensor->data_c(false))) {
    MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!";
  }
  MS_LOG(INFO) << "Finish!";
...
mindspore/ccsrc/session/kernel_graph.cc View file @ 5d225f93
...
@@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  // create kernel_info from new parameter
  auto kernel_info = std::make_shared<device::KernelInfo>();
+  // if the node only has the primitive(such as getNext) or the node's input has a feature map input
+  // then the node's output is a feature map output
+  if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
+                                        [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
+    kernel_info->SetFeatureMapFlag(true);
+  }
  cnode->set_kernel_info(kernel_info);
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
  return cnode;
...
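The rule the new block encodes: a CNode's output counts as a feature map when the node consists of nothing but its primitive (a data source such as getNext), or when any real input is itself a feature-map output. Reduced to its essentials below; the bool vector stands in for the per-input IsFeatureMapOutput answers, with inputs[0] (the primitive) already excluded:

#include <algorithm>
#include <iostream>
#include <vector>

// input_flags holds IsFeatureMapOutput(...) for each real input.
bool OutputIsFeatureMap(const std::vector<bool> &input_flags) {
  return input_flags.empty() ||  // primitive-only node, e.g. getNext
         std::any_of(input_flags.begin(), input_flags.end(), [](bool is_fm) { return is_fm; });
}

int main() {
  std::cout << OutputIsFeatureMap({}) << '\n';              // 1: data-source node
  std::cout << OutputIsFeatureMap({false, false}) << '\n';  // 0: only weights/constants
  std::cout << OutputIsFeatureMap({false, true}) << '\n';   // 1: one feature-map input suffices
}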
@@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  ParameterPtr new_parameter = add_parameter();
  MS_EXCEPTION_IF_NULL(new_parameter);
+  // create kernel_info form new parameter
+  auto kernel_info = std::make_shared<device::KernelInfo>();
  size_t output_tensor_num = 1;
  // if use default parameter = nullptr,it remarks create a new parameter from no parameter
  if (parameter == nullptr) {
    new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
+    kernel_info->SetFeatureMapFlag(true);
  } else {
    // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
    new_parameter->set_abstract(parameter->abstract());
    new_parameter->set_name(parameter->name());
-    if (parameter->has_default()) {
+    if (AnfAlgo::IsParameterWeight(parameter)) {
      new_parameter->set_default_param(parameter->default_param());
+      kernel_info->SetFeatureMapFlag(false);
+    } else {
+      kernel_info->SetFeatureMapFlag(true);
    }
    // if output is a tuple tensor,now can use for loop to handle tuple tensor
    output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
  }
-  // create kernel_info form new parameter
-  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_parameter->set_kernel_info(kernel_info);
  // create kernel_build_info for new parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
...
@@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
  AddValueNodeToGraph(new_value_node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
+  kernel_info->SetFeatureMapFlag(false);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // set the format of value_node to DEFAULT_FORMAT
...
@@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  new_value_node->set_abstract(value_node->abstract());
  // create kernel_info fo new value node
  auto kernel_info = std::make_shared<device::KernelInfo>();
+  kernel_info->SetFeatureMapFlag(false);
  new_value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
...
mindspore/ccsrc/session/session_basic.cc View file @ 5d225f93
...
@@ -20,6 +20,7 @@
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "operator/ops.h"
+#include "common/trans.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "session/anf_runtime_algorithm.h"
...
@@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->enable_pynative_infer()) {
    tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
-  } else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
+                                        LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                        tensor->data_c(true))) {
    MS_LOG(INFO) << "output sync device to host error!!!";
    tensor->set_dirty(false);
...
@@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
  param->set_abstract(abstract);
  return param;
...
@@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
    if (need_sync) {
      tensor->set_device_address(device_address);
      MS_EXCEPTION_IF_NULL(device_address);
-      if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()),
-                                            tensor->data_type(), tensor->data_c(false))) {
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                            tensor->data_c(false))) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
      }
...
@@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
-    if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                   tensor->data_c(true))) {
+    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
+                                   tensor->data_type(), tensor->data_c(true))) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
...
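All three sync sites in this file change the same way: the shape passed to SyncDeviceToHost/SyncHostToDevice now comes from trans::GetRuntimePaddingShape(node, index) instead of tensor->shape(), since the device buffer is described by the padded runtime shape, which can differ from the host tensor's inferred shape under the new padding strategy. The toy DeviceAddress below is a hypothetical stand-in that only illustrates why the device-side shape is the one that must be handed to the sync call:

#include <cstddef>
#include <iostream>
#include <vector>

// Toy DeviceAddress: in this sketch a sync is only consistent when the shape
// it is given describes the layout the device buffer was allocated with.
struct DeviceAddress {
  std::vector<size_t> device_shape;
  bool SyncHostToDevice(const std::vector<size_t> &shape, size_t /*bytes*/) const {
    return shape == device_shape;
  }
};

int main() {
  DeviceAddress addr{{1, 1, 1, 16}};        // allocated with the padded 4-D runtime shape
  std::vector<size_t> host_shape{16};       // tensor->shape() keeps the 1-D inferred shape
  std::vector<size_t> padded{1, 1, 1, 16};  // what GetRuntimePaddingShape would report
  std::cout << addr.SyncHostToDevice(host_shape, 0) << '\n';  // 0: host shape no longer matches
  std::cout << addr.SyncHostToDevice(padded, 0) << '\n';      // 1: matches the device layout
}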
mindspore/ccsrc/utils/utils.h View file @ 5d225f93
...
@@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = {
  kApplyRMSPropOpName,
};
-const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                 kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
  if (access(file_name.c_str(), F_OK) != 0) {
...
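The rename from kSpecialFormatSet to kNeedTransFormatSet makes the set's purpose explicit: these are the device-private layouts that require a TransData node when producer and consumer formats disagree. A hedged sketch of how such a set is typically consulted; the NeedInsertTransData helper and the literal format strings are illustrative assumptions, not the real pass code or the kOpFormat_* constant values:

#include <iostream>
#include <set>
#include <string>

// Illustrative format strings; the real code uses the kOpFormat_* constants.
const std::set<std::string> kNeedTransFormatSet = {"FRAC_Z", "NC1KHKWHWC0", "NC1HWC0", "FRAC_NZ", "C1HWNCoC0"};

// Hypothetical helper: a trans node is only worth inserting when the two
// sides disagree and at least one side uses a layout from the set.
bool NeedInsertTransData(const std::string &input_format, const std::string &output_format) {
  return input_format != output_format &&
         (kNeedTransFormatSet.count(input_format) != 0 || kNeedTransFormatSet.count(output_format) != 0);
}

int main() {
  std::cout << NeedInsertTransData("DefaultFormat", "NC1HWC0") << '\n';        // 1
  std::cout << NeedInsertTransData("DefaultFormat", "DefaultFormat") << '\n';  // 0
}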
tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc View file @ 5d225f93
...
@@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fus
  builder1.SetOutputsDeviceType({kNumberTypeFloat32});
  cast0->set_kernel_info(std::make_shared<device::KernelInfo>());
  cast1->set_kernel_info(std::make_shared<device::KernelInfo>());
+  cast0->set_abstract(x_abstract);
+  cast1->set_abstract(x_abstract);
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get());
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get());
...
tests/ut/cpp/session/anf_runtime_algorithm_test.cc View file @ 5d225f93
...
@@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  // test cnode node
-  auto parameter_one = kernel_graph->add_parameter();
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
+  auto parameter_two = kernel_graph->NewParameter();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two};
  auto add = kernel_graph->NewCNode(add_inputs);
  EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2);
...
@@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) {
TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
+  std::vector<size_t> shape = {1, 2, 3, 4};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto d_kernel_info = add->kernel_info();
...
@@ -266,8 +268,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
...
@@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) {
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // test parameter node as input
-  auto parameter_node = kernel_graph->add_parameter();
+  auto parameter_node = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_node);
  parameter_node->set_abstract(x_abstract);
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error);
...
@@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
-  auto parameter_one = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_one);
  parameter_one->set_abstract(x_abstract);
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_two = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_two);
  parameter_two->set_abstract(x_abstract);
-  auto parameter_third = kernel_graph->add_parameter();
+  auto parameter_third = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_third);
  parameter_third->set_abstract(x_abstract);
  // test cnode as input
...
@@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
  auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
...
tests/ut/cpp/session/kernel_graph_test.cc View file @ 5d225f93
...
@@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
  std::vector<int> shape = {2, 32, 224, 224};
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
-  auto x_parameter = kernel_graph->add_parameter();
+  auto x_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(x_parameter);
  x_parameter->set_name("x_parameter");
  x_parameter->set_abstract(abstract);
-  auto y_parameter = kernel_graph->add_parameter();
+  auto y_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(y_parameter);
  y_parameter->set_name("y_parameter");
  y_parameter->set_abstract(abstract);
...
@@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
  MS_EXCEPTION_IF_NULL(add);
  add->set_abstract(abstract);
-  auto z_parameter = kernel_graph->add_parameter();
+  auto z_parameter = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(z_parameter);
  z_parameter->set_name("z_parameter");
  z_parameter->set_abstract(abstract);
...
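The test-side switch from add_parameter() to NewParameter() follows directly from the kernel_graph.cc changes: NewCNode now calls AnfAlgo::IsFeatureMapOutput on every input, and that helper does MS_EXCEPTION_IF_NULL on the input's kernel_info, so a bare parameter created by add_parameter() (which attaches no KernelInfo) would make the test throw. A small sketch of the difference, using hypothetical minimal types:

#include <iostream>
#include <memory>

struct KernelInfo {
  bool is_feature_map = false;
};

struct Parameter {
  std::shared_ptr<KernelInfo> kernel_info;  // stays null unless explicitly created
};

// add_parameter-style: bare node, no kernel info attached.
std::shared_ptr<Parameter> AddParameter() { return std::make_shared<Parameter>(); }

// NewParameter-style: kernel info (and its feature-map flag) is set up front.
std::shared_ptr<Parameter> NewParameter() {
  auto p = std::make_shared<Parameter>();
  p->kernel_info = std::make_shared<KernelInfo>();
  return p;
}

int main() {
  std::cout << (AddParameter()->kernel_info == nullptr) << '\n';  // 1: feature-map queries would fail
  std::cout << (NewParameter()->kernel_info != nullptr) << '\n';  // 1: safe to query
}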