Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1d6c76f3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d6c76f3
编写于
8月 18, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
board tensor for pynative infer
上级
c1c30a44
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
77 addition
and
48 deletion
+77
-48
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+12
-5
mindspore/core/abstract/abstract_value.cc
mindspore/core/abstract/abstract_value.cc
+32
-24
mindspore/core/abstract/abstract_value.h
mindspore/core/abstract/abstract_value.h
+33
-19
未找到文件。
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
1d6c76f3
...
...
@@ -285,12 +285,12 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe
void
PynativeInfer
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
py_args
,
OpExecInfo
*
const
op_exec_info
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
)
{
MS_LOG
(
DEBUG
)
<<
"prim "
<<
prim
->
name
()
<<
"
input infer
"
<<
mindspore
::
ToString
(
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"prim "
<<
prim
->
name
()
<<
"
input infer
"
<<
mindspore
::
ToString
(
args_spec_list
);
prim
->
BeginRecordAddAttr
();
AbstractBasePtr
infer_res
=
EvalOnePrim
(
prim
,
args_spec_list
)
->
abstract
();
prim
->
EndRecordAddAttr
();
op_exec_info
->
abstract
=
infer_res
;
MS_LOG
(
DEBUG
)
<<
"prim "
<<
prim
->
name
()
<<
"infer result "
<<
op_exec_info
->
abstract
->
ToString
();
MS_LOG
(
DEBUG
)
<<
"prim "
<<
prim
->
name
()
<<
"
infer result "
<<
op_exec_info
->
abstract
->
ToString
();
}
OpExecInfoPtr
GenerateOpExecInfo
(
const
py
::
args
&
args
)
{
...
...
@@ -632,7 +632,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
auto
obj
=
op_exec_info
->
op_inputs
[
i
];
bool
op_mask
=
py
::
hasattr
(
obj
,
"__parameter__"
);
(
*
op_masks
).
push_back
(
op_mask
);
MS_LOG
(
DEBUG
)
<<
"gen args i "
<<
i
<<
op_exec_info
->
op_name
<<
" op mask"
<<
op_mask
<<
"grad_flag_"
<<
grad_flag_
;
MS_LOG
(
DEBUG
)
<<
"gen "
<<
op_exec_info
->
op_name
<<
" arg "
<<
i
<<
": op mask "
<<
op_mask
<<
" grad_flag_ "
<<
grad_flag_
;
AnfNodePtr
node
=
nullptr
;
abstract
::
AbstractBasePtr
abs
=
nullptr
;
...
...
@@ -646,11 +647,17 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if
(
node
!=
nullptr
&&
node
->
abstract
()
!=
nullptr
)
{
abs
=
node
->
abstract
();
}
MS_LOG
(
DEBUG
)
<<
prim
->
ToString
()
<<
" abs is nullptr "
<<
(
abs
==
nullptr
)
<<
" is_const_value "
<<
prim
->
is_const_value
();
if
(
abs
==
nullptr
||
prim
->
is_const_value
())
{
MS_LOG
(
DEBUG
)
<<
"MakeCnode get node no in map"
<<
id
;
ValuePtr
input_value
=
PyAttrValue
(
obj
);
bool
broaden
=
!
prim
->
is_const_value
()
&&
input_value
->
isa
<
tensor
::
Tensor
>
();
abs
=
abstract
::
FromValueInside
(
input_value
,
broaden
);
abs
=
input_value
->
ToAbstract
();
if
(
!
prim
->
is_const_value
())
{
auto
config
=
abstract
::
AbstractBase
::
kBroadenTensorOnly
;
abs
=
abs
->
Broaden
(
config
);
MS_LOG
(
DEBUG
)
<<
"broaden for "
<<
prim
->
ToString
()
<<
" "
<<
config
;
}
node_abs_map_
[
id
]
=
abs
;
}
(
*
args_spec_list
).
push_back
(
abs
);
...
...
mindspore/core/abstract/abstract_value.cc
浏览文件 @
1d6c76f3
...
...
@@ -66,9 +66,12 @@ ValuePtr AbstractBase::BuildValue() const {
return
value_
;
}
AbstractBasePtr
AbstractBase
::
Broaden
()
const
{
AbstractBasePtr
AbstractBase
::
Broaden
(
uint8_t
config
)
const
{
AbstractBasePtr
clone
=
Clone
();
clone
->
set_value
(
kAnyValue
);
auto
not_broaden
=
config
&
(
kBroadenTensorOnly
|
kBroadenParameterOnly
);
if
(
not_broaden
==
0
)
{
clone
->
set_value
(
kAnyValue
);
}
return
clone
;
}
...
...
@@ -85,7 +88,7 @@ std::string AbstractBase::ToString() const {
return
buffer
.
str
();
}
AbstractBasePtr
AbstractScalar
::
Broaden
(
)
const
{
return
AbstractBase
::
Broaden
(
);
}
AbstractBasePtr
AbstractScalar
::
Broaden
(
uint8_t
config
)
const
{
return
AbstractBase
::
Broaden
(
config
);
}
AbstractBasePtr
AbstractScalar
::
Join
(
const
AbstractBasePtr
&
other
)
{
MS_EXCEPTION_IF_NULL
(
other
);
...
...
@@ -224,11 +227,11 @@ AbstractBasePtrList AbstractSequeue::ElementsClone() const {
return
ele_list
;
}
AbstractBasePtrList
AbstractSequeue
::
ElementsBroaden
()
const
{
AbstractBasePtrList
AbstractSequeue
::
ElementsBroaden
(
uint8_t
config
)
const
{
AbstractBasePtrList
ele_list
;
for
(
const
auto
&
ele
:
elements_
)
{
MS_EXCEPTION_IF_NULL
(
ele
);
AbstractBasePtr
broadend
=
ele
->
Broaden
();
AbstractBasePtr
broadend
=
ele
->
Broaden
(
config
);
ele_list
.
push_back
(
broadend
);
}
return
ele_list
;
...
...
@@ -376,13 +379,13 @@ AbstractBasePtr AbstractSlice::Clone() const {
return
std
::
make_shared
<
AbstractSlice
>
(
start
,
stop
,
step
);
}
AbstractBasePtr
AbstractSlice
::
Broaden
()
const
{
AbstractBasePtr
AbstractSlice
::
Broaden
(
uint8_t
config
)
const
{
MS_EXCEPTION_IF_NULL
(
start_
);
MS_EXCEPTION_IF_NULL
(
stop_
);
MS_EXCEPTION_IF_NULL
(
step_
);
AbstractBasePtr
start
=
start_
->
Broaden
();
AbstractBasePtr
stop
=
stop_
->
Broaden
();
AbstractBasePtr
step
=
step_
->
Broaden
();
AbstractBasePtr
start
=
start_
->
Broaden
(
config
);
AbstractBasePtr
stop
=
stop_
->
Broaden
(
config
);
AbstractBasePtr
step
=
step_
->
Broaden
(
config
);
return
std
::
make_shared
<
AbstractSlice
>
(
start
,
stop
,
step
);
}
...
...
@@ -506,12 +509,15 @@ AbstractBasePtr AbstractTensor::Clone() const {
return
clone
;
}
AbstractBasePtr
AbstractTensor
::
Broaden
()
const
{
AbstractBasePtr
AbstractTensor
::
Broaden
(
uint8_t
config
)
const
{
MS_EXCEPTION_IF_NULL
(
element_
);
auto
broaden
=
std
::
make_shared
<
AbstractTensor
>
(
element_
->
Broaden
());
auto
shp
=
shape
();
broaden
->
set_shape
(
shp
->
Clone
());
broaden
->
set_value
(
kAnyValue
);
auto
not_broaden
=
config
&
kBroadenParameterOnly
;
if
(
not_broaden
==
0
)
{
broaden
->
set_value
(
kAnyValue
);
}
return
broaden
;
}
...
...
@@ -585,12 +591,12 @@ AbstractBasePtr AbstractDictionary::Clone() const {
return
std
::
make_shared
<
AbstractDictionary
>
(
kv
);
}
AbstractBasePtr
AbstractDictionary
::
Broaden
()
const
{
AbstractBasePtr
AbstractDictionary
::
Broaden
(
uint8_t
config
)
const
{
std
::
vector
<
AbstractAttribute
>
kv
;
(
void
)
std
::
transform
(
key_values_
.
begin
(),
key_values_
.
end
(),
std
::
back_inserter
(
kv
),
[](
const
AbstractAttribute
&
item
)
{
[
config
](
const
AbstractAttribute
&
item
)
{
MS_EXCEPTION_IF_NULL
(
item
.
second
);
return
std
::
make_pair
(
item
.
first
,
item
.
second
->
Broaden
());
return
std
::
make_pair
(
item
.
first
,
item
.
second
->
Broaden
(
config
));
});
return
std
::
make_shared
<
AbstractDictionary
>
(
kv
);
}
...
...
@@ -711,11 +717,11 @@ AbstractBasePtr AbstractClass::Clone() const {
return
std
::
make_shared
<
AbstractClass
>
(
tag_
,
attributes_clone
,
methods_
);
}
AbstractBasePtr
AbstractClass
::
Broaden
()
const
{
AbstractBasePtr
AbstractClass
::
Broaden
(
uint8_t
config
)
const
{
std
::
vector
<
AbstractAttribute
>
attributes_clone
;
for
(
auto
attr
:
attributes_
)
{
MS_EXCEPTION_IF_NULL
(
attr
.
second
);
AbstractBasePtr
clone
=
attr
.
second
->
Broaden
();
AbstractBasePtr
clone
=
attr
.
second
->
Broaden
(
config
);
AbstractAttribute
elem
(
attr
.
first
,
clone
);
attributes_clone
.
push_back
(
elem
);
}
...
...
@@ -843,9 +849,8 @@ TypePtr AbstractRef::BuildType() const {
}
bool
AbstractRef
::
operator
==
(
const
AbstractRef
&
other
)
const
{
return
(
*
ref_
==
*
other
.
ref_
)
&&
(
need_cast_
==
other
.
need_cast_
)
&&
return
(
*
ref_
==
*
other
.
ref_
)
&&
(
need_cast_
==
other
.
need_cast_
)
&&
(
*
ref_key_
==
*
other
.
ref_key_
)
&&
(
!
need_cast_
||
(
*
target_type_
==
*
other
.
target_type_
));
// not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
}
bool
AbstractRef
::
operator
==
(
const
AbstractBase
&
other
)
const
{
...
...
@@ -921,9 +926,12 @@ std::string AbstractNone::ToString() const {
ValuePtr
AbstractNone
::
RealBuildValue
()
const
{
return
kNone
;
}
AbstractBasePtr
AbstractRefKey
::
Broaden
()
const
{
AbstractBasePtr
AbstractRefKey
::
Broaden
(
uint8_t
config
)
const
{
auto
refkey
=
std
::
make_shared
<
AbstractRefKey
>
();
refkey
->
set_value
(
kAnyValue
);
auto
not_broaden
=
config
&
(
kBroadenTensorOnly
|
kBroadenParameterOnly
);
if
(
not_broaden
==
0
)
{
refkey
->
set_value
(
kAnyValue
);
}
return
refkey
;
}
...
...
@@ -1016,9 +1024,9 @@ AbstractBasePtr AbstractKeywordArg::Clone() const {
return
std
::
make_shared
<
AbstractKeywordArg
>
(
arg_name_
,
arg_value_
->
Clone
());
}
AbstractBasePtr
AbstractKeywordArg
::
Broaden
()
const
{
AbstractBasePtr
AbstractKeywordArg
::
Broaden
(
uint8_t
config
)
const
{
MS_EXCEPTION_IF_NULL
(
arg_value_
);
return
std
::
make_shared
<
AbstractKeywordArg
>
(
arg_name_
,
arg_value_
->
Broaden
());
return
std
::
make_shared
<
AbstractKeywordArg
>
(
arg_name_
,
arg_value_
->
Broaden
(
config
));
}
std
::
size_t
AbstractKeywordArg
::
hash
()
const
{
...
...
@@ -1123,7 +1131,7 @@ AbstractBasePtr AbstractRowTensor::Clone() const {
return
clone
;
}
AbstractBasePtr
AbstractRowTensor
::
Broaden
()
const
{
AbstractBasePtr
AbstractRowTensor
::
Broaden
(
uint8_t
config
)
const
{
MS_EXCEPTION_IF_NULL
(
element
());
auto
broaden
=
std
::
make_shared
<
AbstractRowTensor
>
(
element
()
->
Broaden
());
auto
shp
=
shape
();
...
...
@@ -1182,7 +1190,7 @@ AbstractBasePtr AbstractSparseTensor::Clone() const {
return
clone
;
}
AbstractBasePtr
AbstractSparseTensor
::
Broaden
()
const
{
AbstractBasePtr
AbstractSparseTensor
::
Broaden
(
uint8_t
config
)
const
{
MS_EXCEPTION_IF_NULL
(
element
());
auto
broaden
=
std
::
make_shared
<
AbstractSparseTensor
>
(
element
()
->
Broaden
());
auto
shp
=
shape
();
...
...
mindspore/core/abstract/abstract_value.h
浏览文件 @
1d6c76f3
...
...
@@ -69,7 +69,14 @@ class AbstractBase : public Base {
virtual
TypePtr
BuildType
()
const
=
0
;
virtual
BaseShapePtr
BuildShape
()
const
{
return
kNoShape
;
}
virtual
AbstractBasePtr
Clone
()
const
=
0
;
virtual
AbstractBasePtr
Broaden
()
const
;
// mask for Broaden config
inline
static
const
uint8_t
kBroadenTensorOnly
=
1
;
inline
static
const
uint8_t
kBroadenParameterOnly
=
2
;
// Each bit for on config.
// 00000001 -> 1: only boarden tensor
// 00000010 -> 2: only boarden parameter
virtual
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
;
virtual
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
)
{
return
shared_from_base
<
AbstractBase
>
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
shared_ptr
<
AbstractBase
>
&
a
)
{
...
...
@@ -108,7 +115,7 @@ class AbstractScalar : public AbstractBase {
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractScalar
>
(
GetValueTrack
(),
GetTypeTrack
()
->
Clone
());
}
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
};
using
AbstractScalarPtr
=
std
::
shared_ptr
<
AbstractScalar
>
;
...
...
@@ -128,7 +135,7 @@ class AbstractType : public AbstractBase {
TypePtr
BuildType
()
const
override
{
return
std
::
make_shared
<
TypeType
>
();
}
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
{
return
Clone
();
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
Clone
();
}
};
using
AbstractTypePtr
=
std
::
shared_ptr
<
AbstractType
>
;
...
...
@@ -143,7 +150,7 @@ class AbstractError : public AbstractBase {
MS_DECLARE_PARENT
(
AbstractError
,
AbstractBase
)
TypePtr
BuildType
()
const
override
{
return
std
::
make_shared
<
Problem
>
();
}
AbstractBasePtr
Broaden
()
const
override
{
return
Clone
();
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
Clone
();
}
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractError
>
(
GetValueTrack
()
->
cast
<
StringImmPtr
>
(),
node_
);
...
...
@@ -180,7 +187,7 @@ class AbstractFunction : public AbstractBase {
TypePtr
BuildType
()
const
override
{
return
std
::
make_shared
<
Function
>
();
}
AbstractBasePtr
Clone
()
const
override
{
return
Copy
();
}
// For Function, no need to broaden.
AbstractBasePtr
Broaden
()
const
override
{
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
const_cast
<
AbstractFunction
*>
(
this
)
->
shared_from_base
<
AbstractFunction
>
();
}
virtual
AbstractFunctionPtr
Copy
()
const
=
0
;
...
...
@@ -209,7 +216,7 @@ class AbstractKeywordArg : public AbstractBase {
TypePtr
BuildType
()
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
std
::
size_t
hash
()
const
override
;
bool
operator
==
(
const
AbstractKeywordArg
&
other
)
const
;
...
...
@@ -275,7 +282,7 @@ class AbstractTensor : public AbstractUndetermined {
TypePtr
BuildType
()
const
override
;
BaseShapePtr
BuildShape
()
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
BroadenWithShape
()
const
;
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
final
;
int
format
()
const
{
return
this
->
format_
;
}
...
...
@@ -312,7 +319,7 @@ class AbstractSequeue : public AbstractBase {
TypePtrList
ElementsType
()
const
;
BaseShapePtrList
ElementsShape
()
const
;
AbstractBasePtrList
ElementsClone
()
const
;
AbstractBasePtrList
ElementsBroaden
()
const
;
AbstractBasePtrList
ElementsBroaden
(
uint8_t
config
=
0
)
const
;
template
<
typename
T
>
ValuePtr
ElementsBuildValue
()
const
;
...
...
@@ -345,7 +352,9 @@ class AbstractTuple : public AbstractSequeue {
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractTuple
>
(
ElementsClone
());
}
AbstractBasePtr
Broaden
()
const
override
{
return
std
::
make_shared
<
AbstractTuple
>
(
ElementsBroaden
());
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
std
::
make_shared
<
AbstractTuple
>
(
ElementsBroaden
(
config
));
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
{
return
ElementsJoin
<
AbstractTuple
>
(
other
);
}
...
...
@@ -372,7 +381,9 @@ class AbstractList : public AbstractSequeue {
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractList
>
(
ElementsClone
());
}
AbstractBasePtr
Broaden
()
const
override
{
return
std
::
make_shared
<
AbstractList
>
(
ElementsBroaden
());
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
std
::
make_shared
<
AbstractList
>
(
ElementsBroaden
(
config
));
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
{
return
ElementsJoin
<
AbstractList
>
(
other
);
}
...
...
@@ -403,7 +414,7 @@ class AbstractClass : public AbstractBase {
AbstractBasePtr
GetAttribute
(
const
std
::
string
&
name
);
ValuePtr
GetMethod
(
const
std
::
string
&
name
);
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
std
::
string
ToString
()
const
override
;
Named
tag
()
const
{
return
tag_
;
}
std
::
size_t
hash
()
const
override
;
...
...
@@ -428,7 +439,7 @@ class AbstractDictionary : public AbstractBase {
bool
operator
==
(
const
AbstractDictionary
&
other
)
const
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
std
::
string
ToString
()
const
override
;
std
::
size_t
hash
()
const
override
;
std
::
size_t
size
()
const
{
return
key_values_
.
size
();
}
...
...
@@ -452,7 +463,7 @@ class AbstractSlice : public AbstractBase {
bool
operator
==
(
const
AbstractSlice
&
other
)
const
;
bool
operator
==
(
const
AbstractBase
&
other
)
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
std
::
string
ToString
()
const
override
;
std
::
size_t
hash
()
const
override
;
AbstractBasePtr
start
()
const
{
return
start_
;
}
...
...
@@ -478,7 +489,9 @@ class AbstractJTagged : public AbstractBase {
TypePtr
BuildType
()
const
override
;
AbstractBasePtr
Clone
()
const
override
{
return
std
::
make_shared
<
AbstractJTagged
>
(
element_
->
Clone
());
}
AbstractBasePtr
Broaden
()
const
override
{
return
std
::
make_shared
<
AbstractJTagged
>
(
element_
->
Broaden
());
}
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
return
std
::
make_shared
<
AbstractJTagged
>
(
element_
->
Broaden
(
config
));
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
bool
operator
==
(
const
AbstractJTagged
&
other
)
const
;
...
...
@@ -558,7 +571,7 @@ class AbstractRefKey : public AbstractBase {
}
RefKeyPtr
ref_key_value
()
const
{
return
ref_key_value_
;
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
std
::
string
ToString
()
const
override
;
private:
...
...
@@ -588,8 +601,9 @@ class AbstractRef : public AbstractBase {
inline
RefKeyPtr
ref_key_value
()
const
{
return
ref_key_value_
;
}
inline
TypePtr
target_type
()
const
{
return
target_type_
;
}
inline
bool
need_cast
()
const
{
return
need_cast_
;
}
AbstractBasePtr
Broaden
()
const
override
{
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Broaden
(),
ref_
->
Broaden
(),
need_cast_
,
target_type_
);
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
{
// always broaden for ref
return
std
::
make_shared
<
AbstractRef
>
(
ref_key_
->
Broaden
(
config
),
ref_
->
Broaden
(),
need_cast_
,
target_type_
);
}
AbstractBasePtr
Join
(
const
AbstractBasePtr
&
other
)
override
;
std
::
size_t
hash
()
const
override
{
...
...
@@ -636,7 +650,7 @@ class AbstractRowTensor : public AbstractUndetermined {
void
set_dense_shape
(
const
AbstractTuplePtr
&
dense_shape
)
{
dense_shape_
=
dense_shape
;
}
TypePtr
BuildType
()
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
BroadenWithShape
()
const
;
std
::
string
ToString
()
const
override
;
...
...
@@ -665,7 +679,7 @@ class AbstractSparseTensor : public AbstractUndetermined {
void
set_dense_shape
(
const
AbstractTuplePtr
&
dense_shape
)
{
dense_shape_
=
dense_shape
;
}
TypePtr
BuildType
()
const
override
;
AbstractBasePtr
Clone
()
const
override
;
AbstractBasePtr
Broaden
()
const
override
;
AbstractBasePtr
Broaden
(
uint8_t
config
=
0
)
const
override
;
AbstractBasePtr
BroadenWithShape
()
const
;
std
::
string
ToString
()
const
override
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录