Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1f7bf1ad
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1f7bf1ad
编写于
7月 29, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(opr): fix the compatilibity of elemwise multitype new mode
GitOrigin-RevId: ee58271276ee4a31e11aa26c53c466f5c07dd019
上级
b3a7d149
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
244 addition
and
54 deletion
+244
-54
src/opr/impl/custom_opnode.sereg.h
src/opr/impl/custom_opnode.sereg.h
+10
-10
src/opr/impl/loop/forward_sereg.cpp
src/opr/impl/loop/forward_sereg.cpp
+3
-3
src/opr/impl/nn_int.sereg.h
src/opr/impl/nn_int.sereg.h
+66
-1
src/serialization/impl/opr_shallow_copy.cpp
src/serialization/impl/opr_shallow_copy.cpp
+23
-11
src/serialization/impl/serializer_oss_v2.cpp
src/serialization/impl/serializer_oss_v2.cpp
+5
-0
src/serialization/include/megbrain/serialization/sereg.h
src/serialization/include/megbrain/serialization/sereg.h
+33
-29
src/serialization/test/serializer_oss.cpp
src/serialization/test/serializer_oss.cpp
+104
-0
未找到文件。
src/opr/impl/custom_opnode.sereg.h
浏览文件 @
1f7bf1ad
...
@@ -46,7 +46,7 @@ mgb::cg::OperatorNodeBase* custom_loader(
...
@@ -46,7 +46,7 @@ mgb::cg::OperatorNodeBase* custom_loader(
static void entry() { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD( \
MGB_SEREG_OPR_INTL_CALL_ADD( \
cls, ::mgb::serialization::custom_dumper, \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader
);
\
::mgb::serialization::custom_loader
, true);
\
} \
} \
}; \
}; \
} \
} \
...
...
src/opr/impl/loop/forward_sereg.cpp
浏览文件 @
1f7bf1ad
...
@@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
...
@@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
}
}
void
LoopSerializer
::
reg_all
()
{
void
LoopSerializer
::
reg_all
()
{
MGB_SEREG_OPR_INTL_CALL_ADD
(
opr
::
Loop
,
dump_loop
,
load_loop
);
MGB_SEREG_OPR_INTL_CALL_ADD
(
opr
::
Loop
,
dump_loop
,
load_loop
,
true
);
MGB_SEREG_OPR_INTL_CALL_ADD
(
InputMaker
,
dump_input_maker
,
load_input_maker
);
MGB_SEREG_OPR_INTL_CALL_ADD
(
InputMaker
,
dump_input_maker
,
load_input_maker
,
true
);
MGB_SEREG_OPR_INTL_CALL_ADD
(
MGB_SEREG_OPR_INTL_CALL_ADD
(
CounterProvider
,
dump_counter_provider
,
load_counter_provider
);
CounterProvider
,
dump_counter_provider
,
load_counter_provider
,
true
);
MGB_SEREG_OPR_INTL_CALL_ADD_V2
(
MGB_SEREG_OPR_INTL_CALL_ADD_V2
(
opr
::
Loop
,
dump_loop
,
load_loop
,
nullptr
,
2
,
CURRENT_VERSION
);
opr
::
Loop
,
dump_loop
,
load_loop
,
nullptr
,
2
,
CURRENT_VERSION
);
...
...
src/opr/impl/nn_int.sereg.h
浏览文件 @
1f7bf1ad
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/serialization/sereg.h"
...
@@ -7,10 +8,74 @@ template <>
...
@@ -7,10 +8,74 @@ template <>
struct
OprMaker
<
opr
::
ElemwiseMultiType
,
0
>
struct
OprMaker
<
opr
::
ElemwiseMultiType
,
0
>
:
public
OprMakerVariadic
<
opr
::
ElemwiseMultiType
>
{};
:
public
OprMakerVariadic
<
opr
::
ElemwiseMultiType
>
{};
template
<
>
struct
OprLoadDumpImplV2
<
opr
::
ElemwiseMultiType
,
0
>
{
using
Opr
=
opr
::
ElemwiseMultiType
;
using
PersisParam
=
opr
::
ElemwiseMultiType
::
Param
;
using
PersisElemwseiParam
=
opr
::
Elemwise
::
Param
;
static
void
dump
(
OprDumpContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr
)
{
ctx
.
write_param
<
PersisParam
>
(
opr
.
cast_final_safe
<
Opr
>
().
param
());
}
static
cg
::
OperatorNodeBase
*
replace_opr
(
cg
::
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
inputs
)
{
auto
mode
=
opr
->
cast_final_safe
<
Opr
>
().
param
().
mode
;
auto
change_to_elemwise_mode
=
[
&
](
PersisParam
::
Mode
multitype_mode
)
{
if
(
multitype_mode
==
PersisParam
::
Mode
::
EQ
)
{
return
PersisElemwseiParam
::
Mode
::
EQ
;
}
else
if
(
multitype_mode
==
PersisParam
::
Mode
::
LT
)
{
return
PersisElemwseiParam
::
Mode
::
LT
;
}
else
if
(
multitype_mode
==
PersisParam
::
Mode
::
LEQ
)
{
return
PersisElemwseiParam
::
Mode
::
LEQ
;
}
mgb_assert
(
0
,
"no supported model."
);
};
if
(
PersisParam
::
Mode
::
EQ
==
mode
||
PersisParam
::
Mode
::
LT
==
mode
||
PersisParam
::
Mode
::
LEQ
==
mode
)
{
auto
elemwise_mode
=
change_to_elemwise_mode
(
mode
);
auto
elemiwse_out
=
opr
::
Elemwise
::
make
(
inputs
,
{
elemwise_mode
});
return
opr
::
TypeCvt
::
make
(
elemiwse_out
,
dtype
::
Bool
()).
node
()
->
owner_opr
();
}
else
if
(
PersisParam
::
Mode
::
NEQ
==
mode
)
{
auto
elemiwse_out
=
opr
::
Elemwise
::
make
(
inputs
,
{
PersisElemwseiParam
::
Mode
::
EQ
});
auto
bool_out
=
opr
::
TypeCvt
::
make
(
elemiwse_out
,
dtype
::
Bool
());
return
opr
::
Elemwise
::
make
({
bool_out
},
{
PersisElemwseiParam
::
Mode
::
NOT
})
.
node
()
->
owner_opr
();
}
else
if
(
PersisParam
::
Mode
::
ISNAN
==
mode
)
{
auto
elemiwse_out
=
opr
::
Elemwise
::
make
(
{
inputs
[
0
],
inputs
[
0
]},
{
PersisElemwseiParam
::
Mode
::
EQ
});
auto
bool_out
=
opr
::
TypeCvt
::
make
(
elemiwse_out
,
dtype
::
Bool
());
return
opr
::
Elemwise
::
make
({
bool_out
},
{
PersisElemwseiParam
::
Mode
::
NOT
})
.
node
()
->
owner_opr
();
}
else
if
(
PersisParam
::
Mode
::
ISINF
==
mode
)
{
auto
input_var
=
SymbolVar
{
inputs
[
0
]};
auto
inf_var
=
input_var
.
make_scalar
(
INFINITY
);
auto
float_out
=
opr
::
TypeCvt
::
make
(
inputs
[
0
],
dtype
::
Float32
());
auto
elemiwse_out
=
opr
::
Elemwise
::
make
(
{
float_out
,
inf_var
},
{
PersisElemwseiParam
::
Mode
::
EQ
});
return
opr
::
TypeCvt
::
make
(
elemiwse_out
,
dtype
::
Bool
()).
node
()
->
owner_opr
();
}
return
opr
;
}
static
cg
::
OperatorNodeBase
*
load
(
OprLoadContext
&
ctx
,
const
cg
::
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
return
OprMaker
<
opr
::
ElemwiseMultiType
,
0
>::
make
(
ctx
.
read_param
<
PersisParam
>
(),
inputs
,
ctx
.
graph
(),
config
);
}
};
}
// namespace serialization
}
// namespace serialization
namespace
opr
{
namespace
opr
{
MGB_SEREG_OPR
(
ElemwiseMultiType
,
0
);
MGB_SEREG_OPR_CONDITION
(
ElemwiseMultiType
,
0
,
false
);
MGB_SEREG_OPR_V2
(
ElemwiseMultiType
,
0
,
(
mgb
::
serialization
::
OprLoadDumpImplV2
<
opr
::
ElemwiseMultiType
,
0
>::
replace_opr
),
VERSION_1
,
VERSION_1
);
MGB_SEREG_OPR
(
AffineInt
,
3
);
MGB_SEREG_OPR
(
AffineInt
,
3
);
}
// namespace opr
}
// namespace opr
}
// namespace mgb
}
// namespace mgb
...
...
src/serialization/impl/opr_shallow_copy.cpp
浏览文件 @
1f7bf1ad
...
@@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
...
@@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
cg
::
OperatorNodeBase
*
serialization
::
copy_opr_shallow
(
cg
::
OperatorNodeBase
*
serialization
::
copy_opr_shallow
(
const
cg
::
OperatorNodeBase
&
opr
,
const
VarNodeArray
&
inputs
,
const
cg
::
OperatorNodeBase
&
opr
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
,
const
OprShallowCopyContext
&
ctx
)
{
const
OperatorNodeConfig
&
config
,
const
OprShallowCopyContext
&
ctx
)
{
auto
registry
=
OprRegistry
::
find_by_type
(
opr
.
dyn_typeinfo
());
OprShallowCopy
shallow_copy
=
nullptr
;
mgb_assert
(
if
(
auto
registry
=
OprRegistry
::
find_by_type
(
opr
.
dyn_typeinfo
()))
{
registry
,
"could not find OprReceiver to copy opr %s{%s}"
,
opr
.
cname
(),
shallow_copy
=
registry
->
shallow_copy
;
opr
.
dyn_typeinfo
()
->
name
);
}
else
{
shallow_copy
=
intl
::
copy_opr_shallow_default_impl
;
}
mgb_assert
(
inputs
.
size
()
==
opr
.
input
().
size
());
mgb_assert
(
inputs
.
size
()
==
opr
.
input
().
size
());
auto
dst_og
=
ctx
.
owner_graph
(
opr
,
inputs
);
auto
dst_og
=
ctx
.
owner_graph
(
opr
,
inputs
);
auto
do_copy
=
[
&
]()
{
auto
do_copy
=
[
&
]()
{
auto
nr_opr_before
=
opr
.
owner_graph
()
->
nr_oprs_in_graph
();
auto
nr_opr_before
=
opr
.
owner_graph
()
->
nr_oprs_in_graph
();
auto
ret
=
registry
->
shallow_copy
(
ctx
,
opr
,
inputs
,
config
);
auto
ret
=
shallow_copy
(
ctx
,
opr
,
inputs
,
config
);
if
(
dst_og
!=
opr
.
owner_graph
()
||
if
(
dst_og
!=
opr
.
owner_graph
()
||
opr
.
owner_graph
()
->
nr_oprs_in_graph
()
!=
nr_opr_before
)
{
opr
.
owner_graph
()
->
nr_oprs_in_graph
()
!=
nr_opr_before
)
{
...
@@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
...
@@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
const
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr
,
const
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
MGB_MARK_USED_VAR
(
ctx
);
MGB_MARK_USED_VAR
(
ctx
);
OprDumper
opr_dumper
=
nullptr
;
OprLoaderWrapper
opr_loader
=
nullptr
;
auto
registry
=
OprRegistry
::
find_by_type
(
opr
.
dyn_typeinfo
());
if
(
auto
registry
=
OprRegistry
::
find_by_type
(
opr
.
dyn_typeinfo
()))
{
opr_loader
=
registry
->
loader
;
opr_dumper
=
registry
->
dumper
;
}
else
{
auto
registryv2
=
OprRegistryV2
::
versioned_find_by_typeinfo
(
opr
.
dyn_typeinfo
(),
CURRENT_VERSION
);
opr_loader
=
registryv2
->
loader
;
opr_dumper
=
registryv2
->
dumper
;
}
mgb_assert
(
mgb_assert
(
registry
&&
registry
->
dumper
&&
registry
->
loader
,
opr_dumper
&&
opr_
loader
,
"can not shallow_copy operator %s{%s}: "
"can not shallow_copy operator %s{%s}: "
"no dumper/loader registered"
,
"no dumper/loader registered"
,
opr
.
cname
(),
opr
.
dyn_typeinfo
()
->
name
);
opr
.
cname
(),
opr
.
dyn_typeinfo
()
->
name
);
OprDumpContextMemory
dumper
;
OprDumpContextMemory
memory_
dumper
;
registry
->
dumper
(
dumper
,
opr
);
opr_dumper
(
memory_
dumper
,
opr
);
OprLoadContextMemory
loader
{
opr
.
owner_graph
(),
dumper
};
OprLoadContextMemory
loader
{
opr
.
owner_graph
(),
memory_
dumper
};
return
registry
->
loader
(
loader
,
inputs
,
config
).
opr
();
return
opr_
loader
(
loader
,
inputs
,
config
).
opr
();
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/serialization/impl/serializer_oss_v2.cpp
浏览文件 @
1f7bf1ad
...
@@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
...
@@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
auto
new_output_vars
=
output_vars
;
auto
new_output_vars
=
output_vars
;
if
(
!
config
.
no_change_graph
)
{
if
(
!
config
.
no_change_graph
)
{
new_output_vars
=
converter_all_opr_to_compatiable
(
output_vars
);
new_output_vars
=
converter_all_opr_to_compatiable
(
output_vars
);
mgb_assert
(
output_vars
.
size
()
==
new_output_vars
.
size
());
for
(
size_t
id
=
0
;
id
<
output_vars
.
size
();
id
++
)
{
auto
&
new_var
=
new_output_vars
[
id
];
new_var
.
rename
(
output_vars
[
id
].
node
()
->
name
());
}
}
}
auto
begin_pos
=
m_file
->
tell
();
auto
begin_pos
=
m_file
->
tell
();
...
...
src/serialization/include/megbrain/serialization/sereg.h
浏览文件 @
1f7bf1ad
...
@@ -151,7 +151,7 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
...
@@ -151,7 +151,7 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
//! call OprRegistryV2::versioned_add for new serialization which is compatiable
//! call OprRegistryV2::versioned_add for new serialization which is compatiable
//! with old serialization, convert is nullptr, this registry is just only for
//! with old serialization, convert is nullptr, this registry is just only for
//! varsion 1
//! varsion 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load
)
\
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load
, _registerv2)
\
do { \
do { \
::mgb::serialization::OprRegistry::add( \
::mgb::serialization::OprRegistry::add( \
{_cls::typeinfo(), \
{_cls::typeinfo(), \
...
@@ -161,10 +161,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
...
@@ -161,10 +161,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
_load, \
_load, \
{}, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
if (_registerv2) { \
::mgb::serialization::OprRegistryV2::versioned_add( \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
::mgb::VERSION_1, ::mgb::VERSION_1); \
} \
} while (0)
} while (0)
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
...
@@ -181,7 +183,7 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
...
@@ -181,7 +183,7 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
/*!
/*!
* \brief register opr serialization methods
* \brief register opr serialization methods
*/
*/
#define MGB_SEREG_OPR
(_cls, _arity)
\
#define MGB_SEREG_OPR
_CONDITION(_cls, _arity, _registerv2)
\
namespace { \
namespace { \
namespace ser = ::mgb::serialization; \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
struct _OprReg##_cls { \
...
@@ -192,12 +194,14 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
...
@@ -192,12 +194,14 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
} \
static void entry() { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader
);
\
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader
, _registerv2);
\
} \
} \
}; \
}; \
} \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
#define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true)
//! new dump/load function should implement in OprLoadDumpImplV2, _converter is
//! new dump/load function should implement in OprLoadDumpImplV2, _converter is
//! optional , if not implement pass nullptr
//! optional , if not implement pass nullptr
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
...
...
src/serialization/test/serializer_oss.cpp
浏览文件 @
1f7bf1ad
#include "megbrain/opr/nn_int.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/basic_arith_wrapper.h"
...
@@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
...
@@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
load
();
load
();
}
}
TEST
(
TestSerializer2
,
TestElemwiseMultiTypeLoadDump
)
{
auto
fname
=
GET_OUTPUT_FILE
(
GraphDumpFormat
::
FLATBUFFERS_V2
);
TensorShape
shape
{
3
};
auto
cn
=
CompNode
::
load
(
"xpu0"
);
std
::
shared_ptr
<
HostTensorND
>
host0
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
,
dtype
::
Float32
{});
std
::
shared_ptr
<
HostTensorND
>
host1
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
,
dtype
::
Float32
{});
HostTensorND
dst_truth
;
host0
->
ptr
<
float
>
()[
0
]
=
2
;
host0
->
ptr
<
float
>
()[
1
]
=
2
;
host0
->
ptr
<
float
>
()[
2
]
=
-
1
;
host1
->
ptr
<
float
>
()[
0
]
=
1
;
host1
->
ptr
<
float
>
()[
1
]
=
2
;
host1
->
ptr
<
float
>
()[
2
]
=
3
;
auto
dump
=
[
&
](
opr
::
ElemwiseMultiType
::
Param
::
Mode
mode
,
size_t
nr_opr
)
{
auto
graph
=
ComputingGraph
::
make
();
OperatorNodeConfig
config
;
config
.
name
(
"input0"
);
auto
h2d0
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host0
,
config
);
config
.
name
(
"input1"
);
auto
h2d1
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host1
,
config
);
auto
x
=
opr
::
ElemwiseMultiType
::
make
(
{
h2d0
,
h2d1
},
{
mode
},
OperatorNodeConfig
{
dtype
::
Bool
()});
x
.
rename
(
"out"
);
auto
func
=
graph
->
compile
({
make_callback_copy
(
x
,
dst_truth
)});
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
auto
rst
=
dumper
->
dump
({
x
});
func
->
execute
().
wait
();
ASSERT_EQ
(
rst
.
nr_opr
,
nr_opr
);
};
auto
load
=
[
&
]()
{
auto
loader
=
GraphLoader
::
make
(
InputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
auto
rst
=
loader
->
load
();
ASSERT_EQ
(
rst
.
tensor_map
.
size
(),
2
);
ASSERT_EQ
(
rst
.
output_var_map
.
count
(
"out"
),
1
);
HostTensorND
host_x
;
auto
func
=
rst
.
graph_compile
({
make_callback_copy
(
rst
.
output_var_list
[
0
],
host_x
)});
for
(
auto
&
input
:
rst
.
tensor_map
)
{
if
(
input
.
first
==
"input0"
)
{
input
.
second
->
copy_from
(
*
host0
).
sync
();
}
else
if
(
input
.
first
==
"input1"
)
{
input
.
second
->
copy_from
(
*
host1
).
sync
();
}
}
func
->
execute
().
wait
();
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
EXPECT_EQ
(
host_x
.
ptr
<
bool
>
()[
i
],
dst_truth
.
ptr
<
bool
>
()[
i
]);
}
};
dump
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
EQ
,
4
);
load
();
dump
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
LT
,
4
);
load
();
dump
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
LEQ
,
4
);
load
();
dump
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
NEQ
,
5
);
load
();
auto
dump_single_input
=
[
&
](
opr
::
ElemwiseMultiType
::
Param
::
Mode
mode
,
size_t
nr_opr
)
{
auto
graph
=
ComputingGraph
::
make
();
auto
h2d0
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host0
);
auto
x
=
opr
::
ElemwiseMultiType
::
make
(
{
h2d0
},
{
mode
},
OperatorNodeConfig
{
dtype
::
Bool
()});
x
.
rename
(
"out"
);
auto
func
=
graph
->
compile
({
make_callback_copy
(
x
,
dst_truth
)});
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
auto
rst
=
dumper
->
dump
({
x
});
func
->
execute
().
wait
();
ASSERT_EQ
(
rst
.
nr_opr
,
nr_opr
);
};
auto
load_single_input
=
[
&
]()
{
auto
loader
=
GraphLoader
::
make
(
InputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
auto
rst
=
loader
->
load
();
ASSERT_EQ
(
rst
.
tensor_map
.
size
(),
1
);
ASSERT_EQ
(
rst
.
output_var_map
.
count
(
"out"
),
1
);
HostTensorND
host_x
;
auto
func
=
rst
.
graph_compile
({
make_callback_copy
(
rst
.
output_var_list
[
0
],
host_x
)});
rst
.
tensor_map
.
begin
()
->
second
->
copy_from
(
*
host0
).
sync
();
func
->
execute
().
wait
();
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
EXPECT_EQ
(
host_x
.
ptr
<
bool
>
()[
i
],
dst_truth
.
ptr
<
bool
>
()[
i
]);
}
};
host0
->
ptr
<
float
>
()[
2
]
=
INFINITY
;
dump_single_input
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
ISINF
,
4
);
load_single_input
();
host0
->
ptr
<
float
>
()[
2
]
=
NAN
;
dump_single_input
(
opr
::
ElemwiseMultiType
::
Param
::
Mode
::
ISNAN
,
4
);
load_single_input
();
}
#endif
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录