diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index 3bbdeffd0befd82baa263a72ec689eb90cfd8f53..cdff2af50d3558160fd6190dc27ae19fab94502d 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -45,30 +45,6 @@ add_custom_command( add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) -##################### generate opdef c header and python binding ############## - -set(OP_DEF_HEADER_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/include) -file(MAKE_DIRECTORY ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef) -set(OP_DEF_HEADER ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef/all.h) -set(OP_DEF_PYTHON_BINDING_OUT_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/src) -file(MAKE_DIRECTORY ${OP_DEF_PYTHON_BINDING_OUT_DIR}) -set(OP_DEF_PYTHON_BINDING ${OP_DEF_PYTHON_BINDING_OUT_DIR}/opdef.inl) -set(OP_PARAM_DEF ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py) -set(GEN_OP_DEF_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_op_defs.py) - -add_custom_command( - OUTPUT ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING} - COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} ${OP_DEF_HEADER} - COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} -t py ${OP_PARAM_DEF} ${OP_DEF_PYTHON_BINDING} - DEPENDS ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} - VERBATIM -) - -add_custom_target(gen_op_def_internal DEPENDS ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING}) -add_library(gen_op_def INTERFACE) -target_include_directories(gen_op_def INTERFACE ${OP_DEF_HEADER_OUT_DIR} ${OP_DEF_PYTHON_BINDING_OUT_DIR}) -add_dependencies(gen_op_def gen_op_def_internal) - ##################### end of opdef generation ######################### set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) @@ -77,9 +53,9 @@ add_custom_target(_version_ld SOURCES ${VERSION_SCRIPT}) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11) pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS}) if (APPLE OR MSVC OR WIN32) - target_link_libraries(${MODULE_NAME} PRIVATE gen_op_def megbrain megdnn) + target_link_libraries(${MODULE_NAME} PRIVATE megbrain megdnn) else() - target_link_libraries(${MODULE_NAME} PRIVATE gen_op_def megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT}) + target_link_libraries(${MODULE_NAME} PRIVATE megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT}) endif() if (MGE_WITH_DISTRIBUTED) message("Imperative configured to link megray") diff --git a/imperative/python/megengine/functional/distributed.py b/imperative/python/megengine/functional/distributed.py index 92e93f84601a833e864fb34c6d3ddf24c7dc5434..ebb81cf7a2ab8d094a85343d8fa81d27aa6349f2 100644 --- a/imperative/python/megengine/functional/distributed.py +++ b/imperative/python/megengine/functional/distributed.py @@ -8,7 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Optional, Tuple -from ..core._imperative_rt.ops import CollectiveCommDefModeEnum +from ..core._imperative_rt.ops import CollectiveCommMode from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn from ..core.autodiff.grad import ( Tracer, @@ -135,7 +135,7 @@ def reduce_sum( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.REDUCE_SUM + mode = CollectiveCommMode.REDUCE_SUM return collective_comm(inp, mode, group, device) @@ -148,7 +148,7 @@ def broadcast( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.BROADCAST + mode = CollectiveCommMode.BROADCAST return collective_comm(inp, mode, group, device) @@ -161,7 +161,7 @@ def all_gather( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.ALL_GATHER + mode = CollectiveCommMode.ALL_GATHER return collective_comm(inp, mode, group, device) @@ -174,7 +174,7 @@ def reduce_scatter_sum( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.REDUCE_SCATTER_SUM + mode = CollectiveCommMode.REDUCE_SCATTER_SUM return collective_comm(inp, mode, group, device) @@ -187,7 +187,7 @@ def all_reduce_sum( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.ALL_REDUCE_SUM + mode = CollectiveCommMode.ALL_REDUCE_SUM return collective_comm(inp, mode, group, device) @@ -200,7 +200,7 @@ def all_reduce_max( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.ALL_REDUCE_MAX + mode = CollectiveCommMode.ALL_REDUCE_MAX return collective_comm(inp, mode, group, device) @@ -213,7 +213,7 @@ def all_reduce_min( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.ALL_REDUCE_MIN + mode = CollectiveCommMode.ALL_REDUCE_MIN return collective_comm(inp, mode, group, device) @@ -226,7 +226,7 @@ def gather( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.GATHER + mode = CollectiveCommMode.GATHER return collective_comm(inp, mode, group, device) @@ -239,7 +239,7 @@ def scatter( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.SCATTER + mode = CollectiveCommMode.SCATTER return collective_comm(inp, mode, group, device) @@ -252,7 +252,7 @@ def all_to_all( :param group: communication group :param device: execute placement """ - mode = CollectiveCommDefModeEnum.ALL_TO_ALL + mode = CollectiveCommMode.ALL_TO_ALL return collective_comm(inp, mode, group, device) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index b7f1e01e799355006e6948b05bb1503a400adec7..abedeeea0ac49f113d0b622a9a82b218ef090b73 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -23,7 +23,7 @@ namespace py = pybind11; void init_ops(py::module m) { - #include "opdef.inl" + using namespace mgb::imperative; py::class_, OpDef>(m, "OprAttr") .def(py::init<>()) @@ -43,6 +43,21 @@ void init_ops(py::module m) { py::class_, OpDef>(m, "GetVarShape") .def(py::init()); +#define V(m) .value(#m, CollectiveComm::Mode::m) + py::enum_(m, "CollectiveCommMode") + V(REDUCE_SUM) + V(BROADCAST) + V(ALL_GATHER) + V(REDUCE_SCATTER_SUM) + V(ALL_REDUCE_SUM) + V(ALL_REDUCE_MAX) + V(ALL_REDUCE_MIN) + V(ALL_REDUCE_PROD) + V(GATHER) + V(SCATTER) + V(ALL_TO_ALL); +#undef V + py::class_, OpDef>(m, "CollectiveComm") .def(py::init<>()) .def_readwrite("key", &CollectiveComm::key) diff --git a/imperative/python/tools/gen_op_defs.py b/imperative/python/tools/gen_op_defs.py deleted file mode 100755 index e892a0f5d34c66a2e7853d3a0190d479d2389307..0000000000000000000000000000000000000000 --- a/imperative/python/tools/gen_op_defs.py +++ /dev/null @@ -1,504 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import argparse -import collections -import textwrap -import os -import hashlib -import struct - -class member_defs: - """contain classes to define members of an opr param""" - - Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt', - 'cppjson', 'cname_attr']) - Dtype.__new__.__defaults__ = ('', ) - uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt') - uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt', - 'alignas(sizeof(uint64_t)) ') - int32 = Dtype('int32_t', 'int', 'i', 'NumberInt') - float32 = Dtype('float', 'float', 'f', 'Number') - float64 = Dtype('double', 'float', 'd', 'Number') - dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number') - bool = Dtype('bool', 'bool', '?', 'Bool') - - class Base: - pass - - - class Doc: - """wrap an identifier to associate document - - note: if the doc starts with a linebreak, it would not be reforamtted. - """ - __slots__ = ['id', 'doc'] - - def __init__(self, id_, doc): - assert isinstance(id_, str) and isinstance(doc, str), (id_, doc) - self.id = id_ - self.doc = doc - - @property - def no_reformat(self): - """whether reformat is disallowed for this doc string""" - return self.doc.startswith('\n') - - @property - def raw_lines(self): - """the doc lines when ``no_format`` is true""" - ret = self.doc.split('\n') - assert not ret[0] - return ret[1:] - - @classmethod - def make(cls, v): - """make doc object from str or doc""" - if isinstance(v, cls): - return v - assert isinstance(v, str) - return cls(v, '') - - def __str__(self): - return self.id - - def __eq__(self, rhs): - if isinstance(rhs, str): - return self.id == rhs - return (isinstance(rhs, Doc) and - (self.id, self.doc) == (rhs.id, rhs.doc)) - - - class Enum(Base): - """define an enum; the result would contain both an enum class def and its - corresponding data field - - :param default: index of default member value - - :attr name_field: name of the data field of this enum in the param - struct - :attr member_alias: list of (member, alias) pairs - """ - __slots__ = ['name', 'name_field', 'members', 'default', - 'member_alias'] - - all_enums = {} - """(param_name, name) => enum""" - - def __init__(self, param_name, name, name_field, members, default, - member_alias): - name = member_defs.Doc.make(name) - assert name.id[0].isupper() - members = tuple(map(member_defs.Doc.make, members)) - if isinstance(default, str): - if default not in name_field: - raise ValueError( - "Default value '{}' does not exist.".format(default)) - default = name_field.index(default) - assert isinstance(default, int) - self.name = name - self.name_field = self.get_name_field(name.id, name_field) - self.members = members - self.default = default - - self.all_enums[(param_name, name.id)] = self - - assert isinstance(member_alias, list) - self.member_alias = member_alias - - @classmethod - def get_name_field(cls, name, name_field): - if name_field is None: - name_field = name[0].lower() + name[1:] - assert isinstance(name_field, str) - return name_field - - class Field(Base): - """define a normal data field""" - __slots__ = ['name', 'dtype', 'default'] - - def __init__(self, name, dtype, default): - assert isinstance(dtype, member_defs.Dtype) - self.name = member_defs.Doc.make(name) - self.dtype = dtype - self.default = default - - class Const(Base): - """define a const data field""" - __slots__ = ['name', 'dtype', 'default'] - - def __init__(self, name, dtype, default): - assert isinstance(dtype, member_defs.Dtype) - self.name = member_defs.Doc.make(name) - self.dtype = dtype - self.default = default - - class EnumAlias(Base): - """alias of enum type from another param""" - __slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default'] - - def __init__(self, name, name_field, src_class, src_name, default): - self.name = name - self.name_field = member_defs.Enum.get_name_field(name, name_field) - self.src_class = src_class - if src_name is None: - src_name = name - self.src_name = src_name - self.default = default - - @property - def src_enum(self): - """source Enum class""" - return member_defs.Enum.all_enums[(self.src_class, self.src_name)] - - def get_default(self): - """get default index; fallback to src index if default is not - set""" - if self.default is None: - return self.src_enum.default - return self.default - - -class ParamDef: - """""" - __all_tags = set() - all_param_defs = [] - - __slots__ = ['name', 'members', 'tag', 'is_legacy'] - - def __init__(self, name, doc='', *, version=0, is_legacy=False): - self.members = [] - self.all_param_defs.append(self) - h = hashlib.sha256(name.encode('utf-8')) - if version: - h.update(struct.pack(' 0: - self._indent() - - -class PyWriter(IndentWriterBase): - - _static_members = None - _non_static_members = None - _enums = None - _enum_map = None - - def __call__(self, fout, defs): - super().__call__(fout) - self._enum_map = {} - self._write('// %s', self._get_header()) - self._write('#include "megbrain/imperative/opdef/all.h"') - self._write('') - self._write('using namespace mgb::imperative;') - self._write('') - self._process(defs) - - def _on_param_begin(self, p): - self._enums = [] - self._non_static_members = [] - self._static_members = [] - - def _reg_enum_single(self, cur_def, e): - alias = None - if isinstance(e, member_defs.Enum): - src = e - else: - assert isinstance(e, member_defs.EnumAlias) - src = e.src_enum - alias = e - - src_py_name = self._enum_map.get(src, None) - if src_py_name is not None: - py_name = '{}{}Enum'.format(cur_def, src.name if alias is None else alias.name) - self._write('m.attr("{}") = m.attr("{}");\n'.format(py_name, src_py_name)) - return - - if alias is None: - enum_name = str(src.name) - else: - enum_name = str(alias.name) - c_name = 'opdef::{}::{}'.format(cur_def, enum_name) - py_name = '{}{}Enum'.format(cur_def, enum_name) - self._write('py::enum_<{}>(m, "{}")'.format(c_name, py_name), indent=1) - for i in src.members: - self._write('.value("{0}", {1}::{0})'.format(i, c_name)) - self._write(';\n', indent=-1) - self._enum_map[src] = py_name - - def _on_param_end(self, p): - cur_def = '{}Def'.format(p.name) - for e in self._enums: - self._reg_enum_single(cur_def, e) - self._write('py::class_(m, "{0}")'.format(cur_def), indent=1) - # TODO: use ctor with given default value - self._write('.def(py::init<>())') - for i in self._static_members: - assert isinstance(i, member_defs.Const) - self._write('.def_property_readonly_static("{0}", []() {{ return opdef::{1}::{0}; }})'.format(i.name, cur_def)) - for i in self._non_static_members: - fname = None - if isinstance(i, member_defs.Field): - fname = i.name - else: - assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) - fname = i.name_field - self._write('.def_readwrite("{0}", &opdef::{1}::{0})'.format(fname, cur_def)) - self._write(';\n', indent=-1) - - - def _on_member_enum(self, e,): - self._enums.append(e) - self._non_static_members.append(e) - - def _on_member_enum_alias(self, e): - self._enums.append(e) - self._non_static_members.append(e) - - def _on_member_field(self, f): - self._non_static_members.append(f) - - def _on_const_field(self, f): - self._static_members.append(f) - - -class CPPWriter(IndentWriterBase): - _param_namespace = 'opdef' - - _ctor_args = None - """list of (text in func param, var name); func param name must be var name - appended by an underscore""" - _non_static_members = None - - def __call__(self, fout, defs): - super().__call__(fout) - self._write('// %s', self._get_header()) - self._write('#pragma once') - self._write('#include "megdnn.h"') - # which defined in megbrain/tools/param_defs/mgb_opr_param_defs.py - self._write('#include "megbrain/opr/param_defs.h"') - self._write('#include ') - self._write('namespace mgb {') - self._write('namespace imperative {') - self._write('namespace %s {', self._param_namespace) - self._write('namespace {') - self._write('#include "megdnn/dtype.h"') - self._write('using DTypeEnum = megdnn::DTypeEnum;') - self._write('} // anonymous namespace') - self._process(defs) - self._write('} // namespace %s', self._param_namespace) - self._write('} // namespace imperative') - self._write('} // namespace mgb') - self._write('// vim: syntax=cpp.doxygen') - - def _on_param_begin(self, p): - self._write('struct %sDef {', p.name, indent=1) - self._ctor_args = [] - self._non_static_members = [] - - def _add_ctor_args(self, typename, default, varname): - self._ctor_args.append(( - '{} {}_={}'.format(typename, varname, default), - varname)) - - def _on_param_end(self, p): - ''' - MegDNN param structures are not packed and we need to initialize the structure - paddings to zero or it would break MegBrain hash system. We do memset(0) in default - ctor and use a trick, wrapping non-static members in a anonymous union which would - copy the object representation in its default copy/move ctor, for copy/move ctor. - > The implicitly-defined copy/move constructor for a non-union class X performs - > a memberwise copy/move of its bases and members. [class.copy.ctor 14] - > The implicitly-defined copy/move constructor for a union X copies the object - > representation (6.9) of X. [class.copy.ctor 15] - ''' - if self._non_static_members: - self._write('union { struct {') - for i in self._non_static_members: - if isinstance(i, member_defs.Field): - self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name) - else: - assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) - self._write('%s %s;', i.name, i.name_field) - self._write('}; };') - param_list = [] - if self._ctor_args: - pdefs, varnames = zip(*self._ctor_args) - self._write('%sDef(%s) {', p.name, ', '.join(pdefs), indent=1) - self._write('memset(this, 0, sizeof(*this));') - for var in varnames: - self._write('this->%s = %s_;', var, var) - param_list.append(str(var)) - self._write('}', indent=-1) - self._write('megdnn::param::%s param() {', self._cur_class, indent=1) - self._write('return {%s};', ','.join(param_list)) - self._write('}', indent=-1) - self._write('};\n', indent=-1) - - - def __on_member_enum(self, e, default_value): - self._write('using %s = megdnn::param::%s::%s;', e.name, self._cur_class, e.name) - self._non_static_members.append(e) - self._add_ctor_args(e.name, default_value, e.name_field) - - def _on_member_enum(self, e,): - self.__on_member_enum(e, '{}::{}'.format(e.name, e.members[e.default])) - - def _on_member_enum_alias(self, e): - self.__on_member_enum(e, '{}::{}'.format(e.name, e.src_enum.members[e.get_default()])) - - def _on_member_field(self, f): - self._non_static_members.append(f) - self._add_ctor_args(f.dtype.cname, f.default, f.name) - - def _on_const_field(self, f): - if 'int' in f.dtype.cname: - self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) - else: - self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) - -def main(): - parser = argparse.ArgumentParser( - 'generate opr param defs from description file') - parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++', - help='output type') - parser.add_argument('input') - parser.add_argument('output') - args = parser.parse_args() - - with open(args.input) as fin: - inputs = fin.read() - exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) - input_hash = hashlib.sha256() - input_hash.update(inputs.encode(encoding='UTF-8')) - input_hash = input_hash.hexdigest() - - if args.type == 'py': - writer = PyWriter() - else: - writer = CPPWriter() - - with open(args.output, 'w') as fout: - writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) - -if __name__ == '__main__': - main() diff --git a/imperative/src/include/megbrain/imperative.h b/imperative/src/include/megbrain/imperative.h index 4d2ec50e670b2a69f0b7e5145c1bc497a9d89722..ca62217d446813731486b17bfeea9ddf89f7a667 100644 --- a/imperative/src/include/megbrain/imperative.h +++ b/imperative/src/include/megbrain/imperative.h @@ -13,6 +13,5 @@ #include "megbrain/imperative/physical_tensor.h" #include "megbrain/imperative/op_def.h" -#include "megbrain/imperative/opdef/all.h" // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/include/megbrain/imperative/ops/collective_comm.h b/imperative/src/include/megbrain/imperative/ops/collective_comm.h index 9ec67af89fd1486ceb7bdaecf28acfaa9cef9270..0bc4ab492e2634f8f4a639abf4c8403956d0c160 100644 --- a/imperative/src/include/megbrain/imperative/ops/collective_comm.h +++ b/imperative/src/include/megbrain/imperative/ops/collective_comm.h @@ -21,11 +21,13 @@ class CollectiveComm : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; public: + using Mode = megdnn::param::CollectiveComm::Mode; + CollectiveComm() = default; CollectiveComm(const std::string& key_, size_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, const std::string& addr_, uint32_t port_, - const megdnn::param::CollectiveComm::Mode& mode_, + const Mode& mode_, const DType& dtype_, const std::string& backend_, const std::string& comp_node_) : key(key_), @@ -46,7 +48,7 @@ public: bool local_grad; std::string addr; uint32_t port; - megdnn::param::CollectiveComm::Mode mode; + Mode mode; DType dtype; std::string backend; std::string comp_node; diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 7e50124a30927a35092010333b9b5a287e5ae55e..5599b6ca46b9b95e6e55cc191a2dd613371fc088 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -18,7 +18,7 @@ target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHO target_compile_definitions(imperative_test PRIVATE MODULE_NAME=C) target_compile_options(imperative_test PRIVATE -Wno-unused-parameter) -set(LINK_LIBS megbrain megdnn gtest pybind11::embed gen_op_def) +set(LINK_LIBS megbrain megdnn gtest pybind11::embed) if(MGE_WITH_CUDA) list(APPEND LINK_LIBS cudart) endif()