diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h
index 4f3a9a3f9cac950dfd6bb217f3a89434f9d63860..ade2030f8dbb926b9a1bc8e9901764fdbfb808c3 100644
--- a/dnn/include/megdnn/dtype.h
+++ b/dnn/include/megdnn/dtype.h
@@ -3,6 +3,7 @@
 #include
 #include
+#include <complex>
 #include
 #include

@@ -31,7 +32,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_DTYPE_NAME(cb) \
     cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(IntB1) cb(IntB2) cb(IntB4) \
             cb(Byte) DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) \
-                    cb(UintB4) cb(Bool) cb(Uint16)
+                    cb(UintB4) cb(Bool) cb(Uint16) cb(Complex64)

 /*!
  * \brief iterate through each full byte dtype
@@ -39,7 +40,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \
     cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(Byte) \
             DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) cb(Bool) \
-                    cb(Uint16)
+                    cb(Uint16) cb(Complex64)

 /*!
  * \brief iterate through each fractional byte dtype
@@ -314,6 +315,7 @@ typedef bool dt_bool;
 typedef uint16_t dt_uint16;
 DNN_INC_FLOAT16(typedef half_float::half dt_float16;)
 DNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
+typedef std::complex<float> dt_complex64;

 #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
 #if MEGDNN_CC_HOST
@@ -341,6 +343,7 @@ struct DTypeEnum {
 #endif
     Bool = 12,
     Uint16 = 13,
+    Complex64 = 14,
 #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
 #define D(_name) _name,
     MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
@@ -356,12 +359,28 @@ DTypeEnum(uint32_t e) : ev(e) {}

 #if MEGDNN_CC_HOST
 //! dtype numeric category fo
-enum class DTypeCategory : int { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL };
+enum class DTypeCategory : int {
+    OTHER,
+    FLOAT,
+    INT,
+    LOWBIT,
+    QUANTIZED,
+    BOOL,
+    COMPLEX,
+};
 //! dtype signedness
 enum class DTypeSignedness : int { OTHER, UNSIGNED, SIGNED };
 #else
 struct DTypeCategory {
-    enum Ev { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL };
+    enum Ev {
+        OTHER,
+        FLOAT,
+        INT,
+        LOWBIT,
+        QUANTIZED,
+        BOOL,
+        COMPLEX,
+    };
     int ev;
 };
 struct DTypeSignedness {
@@ -447,6 +466,15 @@ public:

     bool is_low_bit() const { return low_bit() != 0; }

+    bool is_complex() const {
+        return
+#if MEGDNN_CC_HOST
+                m_trait->category == DTypeCategory::COMPLEX;
+#else
+                m_trait->category.ev == DTypeCategory::Ev::COMPLEX;
+#endif
+    }
+
     bool is_quantized_lowbit() const {
         return low_bit() != 0 &&
 #if MEGDNN_CC_HOST
@@ -665,6 +693,11 @@ struct DTypeTrait<dtype::Byte> {
     MEGDNN_DEF_DT_BASIC_FIELDS(Byte, dt_byte, OTHER, OTHER, 0, false);
 };

+template <>
+struct DTypeTrait<dtype::Complex64> {
+    MEGDNN_DEF_DT_BASIC_FIELDS(Complex64, dt_complex64, COMPLEX, SIGNED, 0, false);
+};
+
 #define MEGDNN_DEF_FRACTION_DT(_name, b) \
     template <> \
     struct DTypeTrait<dtype::_name##b> { \
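Note on the layout choice: dt_complex64 is std::complex<float>, i.e. an interleaved (real, imag) pair of float32, and the helper.cpp hunk below pairs the new enum with NumPy's NPY_COMPLEX64. A hypothetical end-to-end sketch of what this enables once the Python-side plumbing further down is in place; the printed dtype and values assume the GetAttr handling added in complex.h:

    import numpy as np
    import megengine as mge

    x = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
    t = mge.Tensor(x)   # split into a (real, imag) pair of float32 tensors
    print(t.dtype)      # complex64, reported via GetAttr::DType
    print(t.numpy())    # [1.+2.j 3.-4.j], re-interleaved via GetAttr::Value
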
"cond_take", + "copy", "cumsum", "diag", "expand_dims", @@ -35,21 +39,24 @@ __all__ = [ "full", "full_like", "gather", + "imag", "linspace", "meshgrid", "ones", "ones_like", + "polar", "repeat", "reshape", "roll", + "scatter", "split", "squeeze", "stack", - "scatter", + "swapaxes", "tile", - "copy", "transpose", - "swapaxes", + "complex", + "real", "where", "zeros", "zeros_like", @@ -417,6 +424,26 @@ def ones_like(inp: Tensor) -> Tensor: return full_like(inp, 1.0) +def polar(abs: Tensor, angle: Tensor) -> Tensor: + return create_complex(abs * cos(angle), abs * sin(angle)) + + +def complex(real: Tensor, imag: Tensor) -> Tensor: + if not isinstance(real, Tensor): + real = Tensor(real) + if not isinstance(imag, Tensor): + imag = Tensor(imag) + return create_complex(real, imag) + + +def real(complex: Tensor) -> Tensor: + return get_real(complex) + + +def imag(complex: Tensor) -> Tensor: + return get_imag(complex) + + def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: r"""Returns a tensor filled with given value with the same shape as input tensor. diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index fc4848902402a8491ff5ac2f95059c10ab98e9cd..0f06ea4fe97374eb7aab87836a7f9874bf114bdd 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -160,7 +160,8 @@ int to_mgb_supported_dtype_raw(int dtype) { #define FOREACH_NPY_DTYPE_PAIR(cb) \ cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \ cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \ - cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) + cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) \ + cb(Complex64, NPY_COMPLEX64) #define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \ FOREACH_NPY_DTYPE_PAIR(cb) \ diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 887dbd1342c2f48b98bfd3921da33deaf536aa65..63de260e3954dbb7cf25530546553b27ff94fa61 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -2,11 +2,13 @@ #include "megbrain/dtype.h" #include "megbrain/imperative/backtrace.h" #include "megbrain/imperative/cpp_cupti.h" +#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/transformation.h" +#include "megbrain/imperative/transformations/complex.h" #include "megbrain/imperative/transformations/dim_expansion.h" #include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/eval.h" @@ -826,6 +828,10 @@ void init_tensor(py::module m) { .register_at( std::make_shared()) .release()); + MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); auto format_trans = std::make_shared(); MGB_MARK_USED_VAR( transformations.register_at(format_trans).release()); @@ -1460,6 +1466,31 @@ void init_tensor(py::module m) { [format_trans]() { return format_trans->get_auto_convert(); }); py::register_exception(m, "TraceError"); + + m.def("create_complex", [](py::object real, py::object imag) { + return TensorWrapper::make( + py_tensor_type, + imperative::apply( + CreateComplex(), + TensorWrapper::try_cast(real.ptr())->m_tensor->data(), + TensorWrapper::try_cast(imag.ptr())->m_tensor->data())[0]); + }); + + m.def("get_real", [](py::object complex) { + return TensorWrapper::make( + py_tensor_type, + imperative::apply( + GetReal(), + 
diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp
index fc4848902402a8491ff5ac2f95059c10ab98e9cd..0f06ea4fe97374eb7aab87836a7f9874bf114bdd 100644
--- a/imperative/python/src/helper.cpp
+++ b/imperative/python/src/helper.cpp
@@ -160,7 +160,8 @@ int to_mgb_supported_dtype_raw(int dtype) {
 #define FOREACH_NPY_DTYPE_PAIR(cb) \
     cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \
             cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \
-                    cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL)
+                    cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) \
+                    cb(Complex64, NPY_COMPLEX64)

 #define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
     FOREACH_NPY_DTYPE_PAIR(cb) \
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 887dbd1342c2f48b98bfd3921da33deaf536aa65..63de260e3954dbb7cf25530546553b27ff94fa61 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -2,11 +2,13 @@
 #include "megbrain/dtype.h"
 #include "megbrain/imperative/backtrace.h"
 #include "megbrain/imperative/cpp_cupti.h"
+#include "megbrain/imperative/dispatch.h"
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
 #include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/transformations/complex.h"
 #include "megbrain/imperative/transformations/dim_expansion.h"
 #include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
@@ -826,6 +828,10 @@ void init_tensor(py::module m) {
                               .register_at<Segment::DimExpansion>(
                                       std::make_shared<DimExpansionTransformation>())
                               .release());
+    MGB_MARK_USED_VAR(transformations
+                              .register_at<Segment::Complex>(
+                                      std::make_shared<ComplexTransformation>())
+                              .release());
     auto format_trans = std::make_shared<FormatTransformation>();
     MGB_MARK_USED_VAR(
             transformations.register_at<Segment::Format>(format_trans).release());
@@ -1460,6 +1466,31 @@ void init_tensor(py::module m) {
             [format_trans]() { return format_trans->get_auto_convert(); });

     py::register_exception<TraceError>(m, "TraceError");
+
+    m.def("create_complex", [](py::object real, py::object imag) {
+        return TensorWrapper::make(
+                py_tensor_type,
+                imperative::apply(
+                        CreateComplex(),
+                        TensorWrapper::try_cast(real.ptr())->m_tensor->data(),
+                        TensorWrapper::try_cast(imag.ptr())->m_tensor->data())[0]);
+    });
+
+    m.def("get_real", [](py::object complex) {
+        return TensorWrapper::make(
+                py_tensor_type,
+                imperative::apply(
+                        GetReal(),
+                        TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
+    });
+
+    m.def("get_imag", [](py::object complex) {
+        return TensorWrapper::make(
+                py_tensor_type,
+                imperative::apply(
+                        GetImag(),
+                        TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
+    });
 }

 #undef MGE_PY_INTERFACE
diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index eab8bb0ddc69ea62d785ea14e432762aa81f8c3e..8f27b9d45e63fe50ff29c6cd84314bf1c3f34f1f 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -19,15 +19,17 @@ public:
         GroupComm,
         DTypePromote,
         DimExpansion,
+        Complex,
         Format,
         Grad,
         Scalar,
         Symbol,
         Trace,
         Eval,
+        SEGMENT_COUNT,
     };

-    std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, SEGMENT_COUNT> segments;

private:
    template <Segment segment>
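The ComplexTransformation below keeps complex tensors in planar form: a value of the Complex object type is simply a (real, imag) pair of float32 tensors, and operators are dispatched per component. Interleaved Complex64 storage appears only at host boundaries (CreateTensor and GetAttr::Value), where the buffer is viewed as float32 with an extra trailing dimension of 2; element 0 of that dimension is the real plane, element 1 the imag plane. The same view expressed in NumPy, as an illustration only:

    import numpy as np

    c = np.zeros(4, dtype=np.complex64)
    f = c.view(np.float32).reshape(4, 2)  # trailing dim of 2: [:, 0] real, [:, 1] imag
    f[:, 0] = [1, 2, 3, 4]
    f[:, 1] = [5, 6, 7, 8]
    print(c)  # [1.+5.j 2.+6.j 3.+7.j 4.+8.j]

make_complex_tensor and the CreateTensor branch below do the same with TensorLayout: append a dimension of 2 with contiguous strides, then drop or shrink it again so the remaining doubled strides select every other float32 element, at element offset 0 for the real plane and offset 1 for the imag plane.
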
diff --git a/imperative/src/include/megbrain/imperative/transformations/complex.h b/imperative/src/include/megbrain/imperative/transformations/complex.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea328551a0f041a64496d1b091cc69a1a6f04ac1
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/complex.h
@@ -0,0 +1,330 @@
+#pragma once
+
+#include
+
+#include "megbrain/common.h"
+#include "megbrain/exception.h"
+#include "megbrain/imperative/basic_operators.h"
+#include "megbrain/imperative/basic_values.h"
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/operator.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/transformation.h"
+#include "megbrain/imperative/utils/helper.h"
+#include "megbrain/imperative/utils/span.h"
+#include "megbrain/imperative/value.h"
+#include "megdnn/thin/small_vector.h"
+
+namespace mgb {
+namespace imperative {
+
+class ComplexTensor final : public ObjectValue<ComplexTensor> {
+private:
+    ValueRef m_real;
+    ValueRef m_imag;
+
+public:
+    ComplexTensor(ValueRef real, ValueRef imag) : m_real(real), m_imag(imag) {}
+
+    std::string to_string() const override {
+        return ssprintf(
+                "ComplexTensor{m_real=%s, m_imag=%s}", m_real.to_string().c_str(),
+                m_imag.to_string().c_str());
+    }
+
+    DTypeValue::ref_t dtype() const {
+        auto dtype = m_real.dtype();
+        mgb_assert(dtype == m_imag.dtype());
+        return dtype;
+    }
+
+    const ValueRef& real() const { return m_real; }
+
+    const ValueRef imag() const { return m_imag; }
+
+    /**
+     * \brief clear all states of this value
+     *
+     */
+    void clear() override {
+        m_real = {};
+        m_imag = {};
+    }
+};
+
+class CreateComplex final : public OperatorImpl<CreateComplex> {
+public:
+    std::string to_string() const override { return "CreateComplex"; }
+
+    std::string raw_type() const override { return "CreateComplex"; }
+};
+
+class GetReal final : public OperatorImpl<GetReal> {
+public:
+    std::string to_string() const override { return "GetReal"; }
+
+    std::string raw_type() const override { return "GetReal"; }
+};
+
+class GetImag final : public OperatorImpl<GetImag> {
+public:
+    std::string to_string() const override { return "GetImag"; }
+
+    std::string raw_type() const override { return "GetImag"; }
+};
+
+class ComplexTransformation final : public Transformation {
+private:
+    ObjectType<ComplexTensor> m_complex_type{"Complex"};
+
+public:
+    std::string name() const override { return "ComplexTransformation"; }
+
+    HostTensorND make_complex_tensor(HostTensorND real, HostTensorND imag) {
+        mgb_assert(real.shape().eq_shape(imag.shape()));
+        mgb_assert(
+                real.dtype() == dtype::Float32() && imag.dtype() == dtype::Float32());
+        mgb_assert(real.comp_node() == imag.comp_node());
+        HostTensorND complex{real.comp_node(), real.shape(), dtype::Complex64()};
+        TensorShape f32_shape = complex.shape();
+        f32_shape[f32_shape.ndim++] = 2;
+        TensorLayout f32_layout = {f32_shape, dtype::Float32()};
+        f32_layout.init_contiguous_stride();
+        HostTensorND f32{complex.comp_node(), f32_layout};
+        f32.storage(complex.storage());
+        TensorLayout real_layout = f32_layout;
+        real_layout.ndim--;
+        TensorLayout imag_layout = real_layout;
+        // mgb_assert(!real_layout.is_contiguous());
+        // mgb_assert(!imag_layout.is_contiguous());
+        f32.sub(SubTensorSpec::make_from_layout(real_layout)).copy_from_fixlayout(real);
+        f32.sub(SubTensorSpec::make_from_offset_elem(imag_layout, 1))
+                .copy_from_fixlayout(imag);
+        return complex;
+    }
+
+    ValueRefList apply_complex_mask(
+            const ApplyOp& apply_op, Span<ValueRef> inputs, Span<bool> mask) {
+        ValueRefList real_list(inputs.size());
+        ValueRefList imag_list(inputs.size());
+        bool any_complex = false;
+        bool all_complex = true;
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            if (auto* complex = inputs[i].as(m_complex_type)) {
+                mgb_assert(mask[i], "unexpected complex");
+                any_complex = true;
+                real_list[i] = complex->real();
+                imag_list[i] = complex->imag();
+            } else {
+                real_list[i] = inputs[i];
+                if (mask[i]) {
+                    all_complex = false;
+                } else {
+                    imag_list[i] = inputs[i];
+                }
+            }
+        }
+        if (!any_complex) {
+            // no complex
+            return imperative::apply(apply_op, real_list);
+        } else {
+            // all complex
+            mgb_assert(all_complex, "only some inputs are complex");
+            auto reals = imperative::apply(apply_op, real_list);
+            auto imags = imperative::apply(apply_op, imag_list);
+            mgb_assert(reals.size() == imags.size());
+            ValueRefList results(reals.size());
+            for (size_t i = 0; i < results.size(); ++i) {
+                results[i] = m_complex_type.make(reals[i], imags[i]);
+            }
+            return results;
+        }
+    }
+
+    ValueRefList apply_complex_real(const ApplyOp& apply_op, Span<ValueRef> inputs) {
+        ValueRefList real_list(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            if (auto* complex = inputs[i].as(m_complex_type)) {
+                real_list[i] = complex->real();
+            } else {
+                real_list[i] = inputs[i];
+            }
+        }
+        return imperative::apply(apply_op, real_list);
+    }
+
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override {
+        if (auto* create_complex = op.as<CreateComplex>()) {
+            auto [real, imag] = inputs.as_array<2>();
+            auto dtype_real = real.dtype();
+            auto dtype_imag = imag.dtype();
+            mgb_assert(
+                    *dtype_real == *dtype_imag, "dtype mismatch: %s vs %s",
+                    dtype_real->name(), dtype_imag->name());
+            return {m_complex_type.make(real, imag)};
+        } else if (auto* create_tensor = op.as<CreateTensor>()) {
+            if (create_tensor->dtype().is_complex()) {
+                auto args = create_tensor->parse(inputs);
+                mgb_assert(!args.device);
+                auto& host = *args.host;
+                // reinterpret_cast to f32
+                mgb_assert(host.layout().is_physical_contiguous());
+                mgb_assert(host.dtype() == dtype::Complex64());
+                TensorShape f32_shape = host.shape();
+                f32_shape[f32_shape.ndim++] = 2;
+                TensorLayout f32_layout = {f32_shape, dtype::Float32()};
+                HostTensorND f32_host = {host.comp_node(), f32_layout};
+                f32_host.storage(host.storage());
+                // take real slice and imag slice
+                auto real_layout = f32_layout;
+                real_layout[real_layout.ndim - 1] = 1;
+                auto imag_layout = real_layout;
+                auto real_host =
+                        f32_host.sub(SubTensorSpec::make_from_layout(real_layout));
+                auto imag_host = f32_host.sub(
+                        SubTensorSpec::make_from_offset_elem(imag_layout, 1));
+                // create real and imag
+                auto real = imperative::apply(
+                        CreateTensor(
+                                create_tensor->kind(), create_tensor->device(),
+                                real_layout),
+                        HostStorage::make(real_host.storage()))[0];
+                auto imag = imperative::apply(
+                        CreateTensor(
+                                create_tensor->kind(), create_tensor->device(),
+                                imag_layout),
+                        HostStorage::make(imag_host.storage()))[0];
+                return {m_complex_type.make(real, imag)};
+            } else {
+                return imperative::apply(op, inputs);
+            }
+        }
+        bool any_complex = false;
+        for (auto&& input : inputs) {
+            if (input.is(m_complex_type)) {
+                any_complex = true;
+                break;
+            }
+        }
+        if (!any_complex) {
+            return imperative::apply(op, inputs);
+        }
+        if (auto* apply_op = op.as<ApplyOp>()) {
+            // TODO: handle apply op
+            // see https://zhuanlan.zhihu.com/p/627536105
+            if (auto* elemwise = apply_op->op().try_cast_final<Elemwise>()) {
+                switch (elemwise->mode) {
+                    case Elemwise::Mode::MUL: {
+                        auto* complex_a = inputs[0].as(m_complex_type);
+                        auto* complex_b = inputs[1].as(m_complex_type);
+                        auto& mul = *apply_op;
+                        if (complex_a && complex_b) {
+                            // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
+                            auto add = Elemwise::make(Elemwise::Mode::ADD);
+                            auto sub = Elemwise::make(Elemwise::Mode::SUB);
+                            auto real = imperative::apply(
+                                    *sub,
+                                    imperative::apply(
+                                            mul, complex_a->real(),
+                                            complex_b->real())[0],
+                                    imperative::apply(
+                                            mul, complex_a->imag(),
+                                            complex_b->imag())[0])[0];
+                            auto imag = imperative::apply(
+                                    *add,
+                                    imperative::apply(
+                                            mul, complex_a->real(),
+                                            complex_b->imag())[0],
+                                    imperative::apply(
+                                            mul, complex_a->imag(),
+                                            complex_b->real())[0])[0];
+                            return {m_complex_type.make(real, imag)};
+                        } else if (complex_a) {
+                            auto real = imperative::apply(
+                                    mul, complex_a->real(), inputs[1])[0];
+                            auto imag = imperative::apply(
+                                    mul, complex_a->imag(), inputs[1])[0];
+                            return {m_complex_type.make(real, imag)};
+                        } else if (complex_b) {
+                            auto real = imperative::apply(
+                                    mul, complex_b->real(), inputs[0])[0];
+                            auto imag = imperative::apply(
+                                    mul, complex_b->imag(), inputs[0])[0];
+                            return {m_complex_type.make(real, imag)};
+                        } else {
+                            mgb_assert(0);
+                        }
+                    }
+                    case Elemwise::Mode::ADD:
+                    case Elemwise::Mode::SUB: {
+                        bool mask[2] = {true, true};
+                        return apply_complex_mask(*apply_op, inputs, {mask, 2});
+                    }
+                    case Elemwise::Mode::NEGATE: {
+                        bool mask[1] = {true};
+                        return apply_complex_mask(*apply_op, inputs, {mask, 1});
+                    }
+                    default: {
+                        mgb_assert(0, "unsupported elemwise mode");
+                    }
+                }
+            } else if (auto* reshape = apply_op->op().try_cast_final<Reshape>()) {
+                SmallVector<bool> mask(inputs.size(), false);
+                mask[0] = true;
+                return apply_complex_mask(*apply_op, inputs, mask);
+            } else if (auto* subtensor = apply_op->op().try_cast_final<Subtensor>()) {
+                SmallVector<bool> mask(inputs.size(), false);
+                mask[0] = true;
+                return apply_complex_mask(*apply_op, inputs, mask);
+            } else if (auto* get_shape = apply_op->op().try_cast_final<GetVarShape>()) {
+                return apply_complex_real(*apply_op, inputs);
+            } else {
+                mgb_assert(0, "unsupported operator");
+            }
+        } else if (auto* get_attr = op.as<GetAttr>()) {
+            // TODO: handle get attr
+            auto&& input = inputs[0].as_ref(m_complex_type);
+            switch (get_attr->attr()) {
+                case GetAttr::DType:
+                    switch (input->dtype()->enumv()) {
+                        case DTypeEnum::Float32: {
+                            return {DTypeValue::make(dtype::Complex64())};
+                        }
+                        default:
+                            mgb_assert(
+                                    0, "unsupported dtype %s", input->dtype()->name());
+                    }
+                case GetAttr::Device:
+                case GetAttr::Shape:
+                    return imperative::apply(op, input->real());
+                case GetAttr::Value: {
+                    auto complex = make_complex_tensor(
+                            input->real().numpy()->as_nd(),
+                            input->imag().numpy()->as_nd());
+                    return {HostValue::make(complex)};
+                }
+                default:
+                    mgb_throw(
+                            MegBrainError, "unsupported %s for complex",
+                            get_attr->to_string().c_str());
+            }
+        } else if (auto* get_real = op.as<GetReal>()) {
+            auto&& input = inputs[0].as_ref(m_complex_type);
+            return {input->real()};
+        } else if (auto* get_imag = op.as<GetImag>()) {
+            auto&& input = inputs[0].as_ref(m_complex_type);
+            return {input->imag()};
+        }
+        mgb_throw(
+                MegBrainError, "unsupported op for complex: %s",
+                op.to_string().c_str());
+    }
+
+    ValueRef unwrap(ValueRef value) override {
+        mgb_assert(!value.is(m_complex_type), "cannot unwrap complex value");
+        return value;
+    }
+};
+
+}  // namespace imperative
+}  // namespace mgb
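A quick numerical check of the MUL decomposition above, (a+bi)(c+di) = (ac-bd) + (ad+bc)i, written against the functional API from tensor.py; this assumes tensor operator overloading routes Elemwise MUL through the transformation:

    import megengine.functional as F
    from megengine import Tensor

    x = F.complex(Tensor([1.0]), Tensor([2.0]))  # 1+2j
    y = F.complex(Tensor([3.0]), Tensor([4.0]))  # 3+4j
    z = x * y
    print(F.real(z).numpy(), F.imag(z).numpy())  # [-5.] [10.]
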
diff --git a/src/opr/impl/loop/forward.cpp b/src/opr/impl/loop/forward.cpp
index a370d0f2658fc3c7c2be831df19517373b9cfa6c..ce43a930e1d8444cdd3a9a442c7783e57c6277b9 100644
--- a/src/opr/impl/loop/forward.cpp
+++ b/src/opr/impl/loop/forward.cpp
@@ -400,6 +400,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
                 break;
             case DTypeEnum::Bool:
                 break;
+            case DTypeEnum::Complex64:
+                break;
 #define cb(x) \
     case DTypeEnum::x: \
         break;
diff --git a/src/opr/impl/loop/impl.cpp b/src/opr/impl/loop/impl.cpp
index 446309273e4584e683ddf53ef707a3946e45955e..286697e5d7904b2904df1031d8a2d9e4eb2fc47d 100644
--- a/src/opr/impl/loop/impl.cpp
+++ b/src/opr/impl/loop/impl.cpp
@@ -235,6 +235,8 @@ public:
                 break;
             case DTypeEnum::Uint16:
                 break;
+            case DTypeEnum::Complex64:
+                break;
 #define cb(_dt) \
     case DTypeEnum::_dt: \
         break;
diff --git a/src/serialization/impl/dtype.fbs b/src/serialization/impl/dtype.fbs
index a387d05c0bd270bdb84c0d6ffeec3cd079f59c25..6da55fed57af95ff8134f927e48e27d325cc495d 100644
--- a/src/serialization/impl/dtype.fbs
+++ b/src/serialization/impl/dtype.fbs
@@ -24,6 +24,7 @@ enum DTypeEnum : byte {
     Bool,
     Uint16,
     QuantizedS1,
+    Complex64,
 }

 table LinearQuantizationParam {
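Two remarks on the remaining hunks. The Loop operator only needs empty cases to whitelist the new enum value in its dtype switches. And in dtype.fbs, Complex64 is appended after QuantizedS1 rather than inserted next to the other base types: FlatBuffers enums are serialized by numeric value, so reordering existing entries would silently re-map dtypes in previously saved models.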