提交 6011f510 编写于 作者: M Megvii Engine Team

style(all): fix clang-format for MGB_DEFINE inside another macro

GitOrigin-RevId: 8c2b6a2aed2645db9611c9875724f482d31556ea
上级 111fa975
......@@ -158,70 +158,71 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
MGB_DEFINE_OPR_CLASS(
ForceInplaceElemwise,
cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) // {
public:
struct Param {
using Mode = megdnn::Elemwise::Param::Mode;
Mode mode;
size_t inplace_index;
};
using Mode = Param::Mode;
ForceInplaceElemwise(
const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {})
: Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
m_param{param} {
for (auto* input : inputs) {
add_input({input});
struct Param {
using Mode = megdnn::Elemwise::Param::Mode;
Mode mode;
size_t inplace_index;
};
using Mode = Param::Mode;
ForceInplaceElemwise(
const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {})
: Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
m_param{param} {
for (auto* input : inputs) {
add_input({input});
}
add_output(None)
->set_fwd_in2out_writable_force(input(param.inplace_index))
.add_flag(VarNode::Flag::NO_MEM_RECLAIM);
}
add_output(None)
->set_fwd_in2out_writable_force(input(param.inplace_index))
.add_flag(VarNode::Flag::NO_MEM_RECLAIM);
}
static SymbolVar make(const VarNodeArray& inputs, Param param) {
return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
inputs, param);
}
static cg::OperatorNodeBase* shallow_copy(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config);
static SymbolVar make(const VarNodeArray& inputs, Param param) {
return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
inputs, param);
}
static cg::OperatorNodeBase* shallow_copy(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config);
protected:
NodeProp* do_make_node_prop() const override {
auto ret = Super::do_make_node_prop();
ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
return ret;
}
void create_megdnn_opr() override {
auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
opr->param().mode = m_param.mode;
set_megdnn_opr(std::move(opr));
}
void scn_do_execute() override {
auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
megdnn::TensorNDArray inputs_dnnnd;
for (auto* input : input()) {
inputs_dnnnd.push_back(to_dnnnd(input));
NodeProp* do_make_node_prop() const override {
auto ret = Super::do_make_node_prop();
ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
return ret;
}
mgb_assert(
input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
"ForceInplaceElemwise cannot be applied in internal tensor");
auto* out_dest = output(0);
auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
}
void init_output_static_infer_desc() override {
using namespace cg::static_infer;
void create_megdnn_opr() override {
auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
opr->param().mode = m_param.mode;
set_megdnn_opr(std::move(opr));
}
void scn_do_execute() override {
auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
megdnn::TensorNDArray inputs_dnnnd;
for (auto* input : input()) {
inputs_dnnnd.push_back(to_dnnnd(input));
}
mgb_assert(
input(m_param.inplace_index)
->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
"ForceInplaceElemwise cannot be applied in internal tensor");
auto* out_dest = output(0);
auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
}
void init_output_static_infer_desc() override {
using namespace cg::static_infer;
owner_graph()->static_infer_manager().register_shape_infer(
output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
}
owner_graph()->static_infer_manager().register_shape_infer(
output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
}
private:
Param m_param;
void record_execute_deps(ExecDependencyArray& deps) override {
record_megdnn_opr(deps);
}
Param m_param;
void record_execute_deps(ExecDependencyArray& deps) override {
record_megdnn_opr(deps);
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);
......
......@@ -1013,13 +1013,13 @@ using OprNodeArray = SmallVector<OperatorNodeBase*>;
*
* Note that opening brace is included
*/
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL;
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL;
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;
} // namespace cg
} // namespace mgb
......
......@@ -495,18 +495,18 @@ private:
} // namespace mgb
#define _MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \
class _name : public _base, ##__VA_ARGS__ { \
public: \
using Super = _tpl _base; \
\
#define MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \
class _name : public _base, ##__VA_ARGS__ { \
public: \
using Super = _tpl _base; \
\
private:
/*!
* \brief define a class which has Super defined to base
*/
#define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \
_MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__)
MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__)
/*!
* \brief define a class which has Super defined to base
......@@ -514,5 +514,5 @@ private:
* Used when this class is a template and base class has template
*/
#define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \
_MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__)
MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -99,7 +99,7 @@ float GraphPartitionProfiler::duration_in_usec() const {
* \brief An operator that indicates its input var node is contiguous
*/
// clang-format off
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) //{
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) // {
void scn_do_execute() override {};
void init_output_static_infer_desc() override;
void add_input_layout_constraint() override {
......
......@@ -20,38 +20,38 @@ namespace opr {
MGB_DEFINE_OPR_CLASS(
PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
public mixin::AlgoChooserHelper) //{
public mixin::AlgoChooserHelper) // {
public:
MGE_WIN_DECLSPEC_FUC PoolingForward(
VarNode* src, const Param& param, const ExecutionPolicy& policy,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
void init_output_static_infer_desc() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
MGE_WIN_DECLSPEC_FUC PoolingForward(
VarNode* src, const Param& param, const ExecutionPolicy& policy,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
void init_output_static_infer_desc() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
};
using Pooling = PoolingForward;
MGB_DEFINE_OPR_CLASS(
PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
public mixin::AlgoChooserHelper) //{
public mixin::AlgoChooserHelper) // {
public:
MGE_WIN_DECLSPEC_FUC PoolingBackward(
VarNode* src, VarNode* dst, VarNode* diff, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC PoolingBackward(
VarNode* src, VarNode* dst, VarNode* diff, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override final;
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override final;
};
} // namespace opr
......
......@@ -86,7 +86,7 @@ MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint_contig(OperatorNodeBase& o
//! called in constructor to add output vars
MGE_WIN_DECLSPEC_FUC void add_output_vars(
OperatorNodeBase& opr, size_t nr_output, bool add_workspace);
}
} // namespace megdnn_utils
/*!
* \brief mixin for infer workspace size based on input and output shapes
......@@ -344,34 +344,34 @@ private:
} // namespace mgb
//! define a megdnn opr wrapper class with 1 input for forward
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public: \
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public: \
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
}
//! define a megdnn opr wrapper class with 2 inputs for forward
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public: \
_name(VarNode* p0, VarNode* p1, const Param& param, \
const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, SymbolVar p1, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \
public: \
_name(VarNode* p0, VarNode* p1, const Param& param, \
const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, SymbolVar p1, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
}
//! define a megdnn opr wrapper class with 3 inputs for grad
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, const Param& param, \
const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, \
const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -40,25 +40,25 @@ protected:
};
/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar shape, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
static SymbolVar make( \
ComputingGraph& graph, const TensorShape& shape, \
const OperatorNodeConfig& config, const Param& param = {}) { \
return make( \
var_from_tensor_shape(graph, config, "rng", shape), param, config); \
} \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
} \
;
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \
SymbolVar shape, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
static SymbolVar make( \
ComputingGraph& graph, const TensorShape& shape, \
const OperatorNodeConfig& config, const Param& param = {}) { \
return make( \
var_from_tensor_shape(graph, config, "rng", shape), param, \
config); \
} \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
};
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)
......@@ -66,20 +66,19 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
#undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS
/* ================= RNG with input ================= */
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \
_INPUTS(SymbolVar), const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
} \
;
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \
_INPUTS(SymbolVar), const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
};
/* ================= 1 input ================= */
#define _INPUTS(preifx) preifx i0
......@@ -100,7 +99,7 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
} // intl
} // namespace intl
using UniformRNG = intl::UniformRNG;
using GaussianRNG = intl::GaussianRNG;
......@@ -111,16 +110,15 @@ using BetaRNG = intl::BetaRNG;
using ShuffleRNG = intl::ShuffleRNGForward;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
ShuffleRNGBackward,
intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{
ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // {
public:
ShuffleRNGBackward(
VarNode* out_diff, VarNode* indices, VarNode* result_shape, const Param& param,
const OperatorNodeConfig& config);
ShuffleRNGBackward(
VarNode* out_diff, VarNode* indices, VarNode* result_shape,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape,
const Param& param = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
} // namespace opr
......
......@@ -19,7 +19,8 @@ failed_files = Manager().list()
def process_file(file, clang_format, write):
source = open(file, "r").read()
source = re.sub(r"MGB_DEFINE(?P<r>(.|\n)*?)// +{", "class MGB_DEFINE\g<r>{", source)
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)
result = subprocess.check_output(
[
......@@ -33,6 +34,8 @@ def process_file(file, clang_format, write):
)
result = result.decode("utf-8")
if count:
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result)
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)
if write:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册