Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6011f510
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6011f510
编写于
1月 05, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
style(all): fix clang-format for MGB_DEFINE inside another macro
GitOrigin-RevId: 8c2b6a2aed2645db9611c9875724f482d31556ea
上级
111fa975
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
159 addition
and
157 deletion
+159
-157
imperative/src/impl/ops/elemwise.cpp
imperative/src/impl/ops/elemwise.cpp
+57
-56
src/core/include/megbrain/graph/operator_node.h
src/core/include/megbrain/graph/operator_node.h
+6
-6
src/core/include/megbrain/utils/metahelper.h
src/core/include/megbrain/utils/metahelper.h
+7
-7
src/gopt/impl/global_layout_transform/profiler_impl.cpp
src/gopt/impl/global_layout_transform/profiler_impl.cpp
+1
-1
src/opr/include/megbrain/opr/dnn/pooling.h
src/opr/include/megbrain/opr/dnn/pooling.h
+23
-23
src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
+21
-21
src/opr/include/megbrain/opr/rand.h
src/opr/include/megbrain/opr/rand.h
+40
-42
tools/format.py
tools/format.py
+4
-1
未找到文件。
imperative/src/impl/ops/elemwise.cpp
浏览文件 @
6011f510
...
...
@@ -158,70 +158,71 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
MGB_DEFINE_OPR_CLASS
(
ForceInplaceElemwise
,
cg
::
SingleCNOperatorNodeBaseT
<
opr
::
mixin
::
MegDNNOprHolder
>
)
//
{
cg
::
SingleCNOperatorNodeBaseT
<
opr
::
mixin
::
MegDNNOprHolder
>
)
//
{
public:
struct
Param
{
using
Mode
=
megdnn
::
Elemwise
::
Param
::
Mode
;
Mode
mode
;
size_t
inplace_index
;
};
using
Mode
=
Param
::
Mode
;
ForceInplaceElemwise
(
const
VarNodeArray
&
inputs
,
Param
param
,
OperatorNodeConfig
config
=
{})
:
Super
(
inputs
[
0
]
->
owner_graph
(),
config
,
"device_add_update"
,
inputs
),
m_param
{
param
}
{
for
(
auto
*
input
:
inputs
)
{
add_input
({
input
});
struct
Param
{
using
Mode
=
megdnn
::
Elemwise
::
Param
::
Mode
;
Mode
mode
;
size_t
inplace_index
;
};
using
Mode
=
Param
::
Mode
;
ForceInplaceElemwise
(
const
VarNodeArray
&
inputs
,
Param
param
,
OperatorNodeConfig
config
=
{})
:
Super
(
inputs
[
0
]
->
owner_graph
(),
config
,
"device_add_update"
,
inputs
),
m_param
{
param
}
{
for
(
auto
*
input
:
inputs
)
{
add_input
({
input
});
}
add_output
(
None
)
->
set_fwd_in2out_writable_force
(
input
(
param
.
inplace_index
))
.
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
}
add_output
(
None
)
->
set_fwd_in2out_writable_force
(
input
(
param
.
inplace_index
))
.
add_flag
(
VarNode
::
Flag
::
NO_MEM_RECLAIM
);
}
static
SymbolVar
make
(
const
VarNodeArray
&
inputs
,
Param
param
)
{
return
SymbolVar
{
inputs
[
0
]}.
insert_single_output_opr
<
ForceInplaceElemwise
>
(
inputs
,
param
);
}
static
cg
::
OperatorNodeBase
*
shallow_copy
(
const
serialization
::
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
const
VarNodeArray
&
inputs
,
Param
param
)
{
return
SymbolVar
{
inputs
[
0
]}.
insert_single_output_opr
<
ForceInplaceElemwise
>
(
inputs
,
param
);
}
static
cg
::
OperatorNodeBase
*
shallow_copy
(
const
serialization
::
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
);
protected:
NodeProp
*
do_make_node_prop
()
const
override
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_flag
(
NodeProp
::
Flag
::
FORCE_UPDATE_INPUT_VAR
);
return
ret
;
}
void
create_megdnn_opr
()
override
{
auto
opr
=
DnnOprCaller
<
megdnn
::
Elemwise
>::
create_operator
(
comp_node
());
opr
->
param
().
mode
=
m_param
.
mode
;
set_megdnn_opr
(
std
::
move
(
opr
));
}
void
scn_do_execute
()
override
{
auto
to_dnnnd
=
[
&
](
auto
*
var
)
{
return
var
->
dev_tensor
().
as_megdnn
();
};
megdnn
::
TensorNDArray
inputs_dnnnd
;
for
(
auto
*
input
:
input
())
{
inputs_dnnnd
.
push_back
(
to_dnnnd
(
input
));
NodeProp
*
do_make_node_prop
()
const
override
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_flag
(
NodeProp
::
Flag
::
FORCE_UPDATE_INPUT_VAR
);
return
ret
;
}
mgb_assert
(
input
(
m_param
.
inplace_index
)
->
contain_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
),
"ForceInplaceElemwise cannot be applied in internal tensor"
);
auto
*
out_dest
=
output
(
0
);
auto
*
opr
=
static_cast
<
megdnn
::
Elemwise
*>
(
megdnn_opr
());
opr
->
exec
(
std
::
move
(
inputs_dnnnd
),
to_dnnnd
(
out_dest
));
}
void
init_output_static_infer_desc
()
override
{
using
namespace
cg
::
static_infer
;
void
create_megdnn_opr
()
override
{
auto
opr
=
DnnOprCaller
<
megdnn
::
Elemwise
>::
create_operator
(
comp_node
());
opr
->
param
().
mode
=
m_param
.
mode
;
set_megdnn_opr
(
std
::
move
(
opr
));
}
void
scn_do_execute
()
override
{
auto
to_dnnnd
=
[
&
](
auto
*
var
)
{
return
var
->
dev_tensor
().
as_megdnn
();
};
megdnn
::
TensorNDArray
inputs_dnnnd
;
for
(
auto
*
input
:
input
())
{
inputs_dnnnd
.
push_back
(
to_dnnnd
(
input
));
}
mgb_assert
(
input
(
m_param
.
inplace_index
)
->
contain_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
),
"ForceInplaceElemwise cannot be applied in internal tensor"
);
auto
*
out_dest
=
output
(
0
);
auto
*
opr
=
static_cast
<
megdnn
::
Elemwise
*>
(
megdnn_opr
());
opr
->
exec
(
std
::
move
(
inputs_dnnnd
),
to_dnnnd
(
out_dest
));
}
void
init_output_static_infer_desc
()
override
{
using
namespace
cg
::
static_infer
;
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
output
(
0
),
ShapeInferDesc
::
make_identity
(
input
(
m_param
.
inplace_index
)));
}
owner_graph
()
->
static_infer_manager
().
register_shape_infer
(
output
(
0
),
ShapeInferDesc
::
make_identity
(
input
(
m_param
.
inplace_index
)));
}
private:
Param
m_param
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
{
record_megdnn_opr
(
deps
);
}
Param
m_param
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
{
record_megdnn_opr
(
deps
);
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ForceInplaceElemwise
);
...
...
src/core/include/megbrain/graph/operator_node.h
浏览文件 @
6011f510
...
...
@@ -1013,13 +1013,13 @@ using OprNodeArray = SmallVector<OperatorNodeBase*>;
*
* Note that opening brace is included
*/
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL;
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...)
\
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__)
\
MGB_DYN_TYPE_OBJ_FINAL_DECL;
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...)
\
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__)
\
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;
}
// namespace cg
}
// namespace mgb
...
...
src/core/include/megbrain/utils/metahelper.h
浏览文件 @
6011f510
...
...
@@ -495,18 +495,18 @@ private:
}
// namespace mgb
#define
_
MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \
class _name : public _base, ##__VA_ARGS__ {
\
public:
\
using Super = _tpl _base;
\
\
#define MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \
class _name : public _base, ##__VA_ARGS__ { \
public: \
using Super = _tpl _base; \
\
private:
/*!
* \brief define a class which has Super defined to base
*/
#define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \
_
MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__)
MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__)
/*!
* \brief define a class which has Super defined to base
...
...
@@ -514,5 +514,5 @@ private:
* Used when this class is a template and base class has template
*/
#define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \
_
MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__)
MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/gopt/impl/global_layout_transform/profiler_impl.cpp
浏览文件 @
6011f510
...
...
@@ -99,7 +99,7 @@ float GraphPartitionProfiler::duration_in_usec() const {
* \brief An operator that indicates its input var node is contiguous
*/
// clang-format off
MGB_DEFINE_OPR_CLASS
(
MarkInputContiguous
,
SingleCNOperatorNodeBase
)
//{
MGB_DEFINE_OPR_CLASS
(
MarkInputContiguous
,
SingleCNOperatorNodeBase
)
//
{
void
scn_do_execute
()
override
{};
void
init_output_static_infer_desc
()
override
;
void
add_input_layout_constraint
()
override
{
...
...
src/opr/include/megbrain/opr/dnn/pooling.h
浏览文件 @
6011f510
...
...
@@ -20,38 +20,38 @@ namespace opr {
MGB_DEFINE_OPR_CLASS
(
PoolingForward
,
intl
::
MegDNNOprWrapperFwd
<
megdnn
::
PoolingForward
>
,
public
mixin
::
AlgoChooserHelper
)
//
{
public
mixin
::
AlgoChooserHelper
)
//
{
public:
MGE_WIN_DECLSPEC_FUC
PoolingForward
(
VarNode
*
src
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
src
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
=
{},
const
OperatorNodeConfig
&
config
=
{});
void
init_output_static_infer_desc
()
override
;
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
;
MGE_WIN_DECLSPEC_FUC
PoolingForward
(
VarNode
*
src
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
src
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
=
{},
const
OperatorNodeConfig
&
config
=
{});
void
init_output_static_infer_desc
()
override
;
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
;
};
using
Pooling
=
PoolingForward
;
MGB_DEFINE_OPR_CLASS
(
PoolingBackward
,
intl
::
MegDNNOprWrapperBwd
<
megdnn
::
PoolingBackward
>
,
public
mixin
::
AlgoChooserHelper
)
//
{
public
mixin
::
AlgoChooserHelper
)
//
{
public:
MGE_WIN_DECLSPEC_FUC
PoolingBackward
(
VarNode
*
src
,
VarNode
*
dst
,
VarNode
*
diff
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
PoolingBackward
(
VarNode
*
src
,
VarNode
*
dst
,
VarNode
*
diff
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
src
,
SymbolVar
dst
,
SymbolVar
diff
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
=
{},
const
OperatorNodeConfig
&
config
=
{});
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
src
,
SymbolVar
dst
,
SymbolVar
diff
,
const
Param
&
param
,
const
ExecutionPolicy
&
policy
=
{},
const
OperatorNodeConfig
&
config
=
{});
MGE_WIN_DECLSPEC_FUC
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
final
;
MGE_WIN_DECLSPEC_FUC
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
final
;
};
}
// namespace opr
...
...
src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
浏览文件 @
6011f510
...
...
@@ -86,7 +86,7 @@ MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint_contig(OperatorNodeBase& o
//! called in constructor to add output vars
MGE_WIN_DECLSPEC_FUC
void
add_output_vars
(
OperatorNodeBase
&
opr
,
size_t
nr_output
,
bool
add_workspace
);
}
}
// namespace megdnn_utils
/*!
* \brief mixin for infer workspace size based on input and output shapes
...
...
@@ -344,34 +344,34 @@ private:
}
// namespace mgb
//! define a megdnn opr wrapper class with 1 input for forward
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public: \
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name)
\
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>)
\
public: \
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
}
//! define a megdnn opr wrapper class with 2 inputs for forward
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public:
\
_name(VarNode* p0, VarNode* p1, const Param& param,
\
const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar p0, SymbolVar p1, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name)
\
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>)
\
public:
\
_name(VarNode* p0, VarNode* p1, const Param& param,
\
const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar p0, SymbolVar p1, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
}
//! define a megdnn opr wrapper class with 3 inputs for grad
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, const Param& param,
\
const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2,
\
const Param& param, const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/include/megbrain/opr/rand.h
浏览文件 @
6011f510
...
...
@@ -40,25 +40,25 @@ protected:
};
/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG)
\
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>)
\
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
\
\
public:
\
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar shape, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
static SymbolVar make(
\
ComputingGraph& graph, const TensorShape& shape,
\
const OperatorNodeConfig& config, const Param& param = {}) {
\
return make(
\
var_from_tensor_shape(graph, config, "rng", shape), param, config);
\
}
\
void init_output_static_infer_desc() override;
\
void scn_do_execute() override;
\
}
\
;
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
\
\
public:
\
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config);
\
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
\
SymbolVar shape, const Param& param = {},
\
const OperatorNodeConfig& config = {});
\
static SymbolVar make(
\
ComputingGraph& graph, const TensorShape& shape,
\
const OperatorNodeConfig& config, const Param& param = {}) {
\
return make(
\
var_from_tensor_shape(graph, config, "rng", shape), param,
\
config);
\
}
\
void init_output_static_infer_desc() override;
\
void scn_do_execute() override;
\
}
;
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS
(
UniformRNG
)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS
(
GaussianRNG
)
...
...
@@ -66,20 +66,19 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
#undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS
/* ================= RNG with input ================= */
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \
_INPUTS(SymbolVar), const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
} \
;
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \
_INPUTS(SymbolVar), const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
};
/* ================= 1 input ================= */
#define _INPUTS(preifx) preifx i0
...
...
@@ -100,7 +99,7 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
}
// intl
}
//
namespace
intl
using
UniformRNG
=
intl
::
UniformRNG
;
using
GaussianRNG
=
intl
::
GaussianRNG
;
...
...
@@ -111,16 +110,15 @@ using BetaRNG = intl::BetaRNG;
using
ShuffleRNG
=
intl
::
ShuffleRNGForward
;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT
(
ShuffleRNGBackward
,
intl
::
MegDNNOprWrapperBwd
<
megdnn
::
ShuffleRNGBackward
>
)
//{
ShuffleRNGBackward
,
intl
::
MegDNNOprWrapperBwd
<
megdnn
::
ShuffleRNGBackward
>
)
// {
public:
ShuffleRNGBackward
(
VarNode
*
out_diff
,
VarNode
*
indices
,
VarNode
*
result_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
ShuffleRNGBackward
(
VarNode
*
out_diff
,
VarNode
*
indices
,
VarNode
*
result_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
out_diff
,
SymbolVar
indices
,
SymbolVar
result_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
out_diff
,
SymbolVar
indices
,
SymbolVar
result_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
};
}
// namespace opr
...
...
tools/format.py
浏览文件 @
6011f510
...
...
@@ -19,7 +19,8 @@ failed_files = Manager().list()
def
process_file
(
file
,
clang_format
,
write
):
source
=
open
(
file
,
"r"
).
read
()
source
=
re
.
sub
(
r
"MGB_DEFINE(?P<r>(.|\n)*?)// +{"
,
"class MGB_DEFINE\g<r>{"
,
source
)
source
=
re
.
sub
(
r
"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{"
,
r
"class MGB_DEFINE\g<r>{"
,
source
)
source
,
count
=
re
.
subn
(
r
"(?<!#define )MGB_DEFINE(.*) +\\"
,
r
"class MGB_DEFINE\1{\\"
,
source
)
result
=
subprocess
.
check_output
(
[
...
...
@@ -33,6 +34,8 @@ def process_file(file, clang_format, write):
)
result
=
result
.
decode
(
"utf-8"
)
if
count
:
result
=
re
.
sub
(
r
"class MGB_DEFINE(.*){( *)\\"
,
r
"MGB_DEFINE\1\2 \\"
,
result
)
result
=
re
.
sub
(
r
"class MGB_DEFINE((.|\n)*?){"
,
r
"MGB_DEFINE\1// {"
,
result
)
if
write
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录