// basic_operators.cpp
#include "megbrain/imperative/basic_operators.h"

#include "megbrain/imperative/basic_values.h"

namespace mgb {
namespace imperative {

// Describes this ApplyOp by delegating to the wrapped op's own to_string().
std::string ApplyOp::to_string() const {
    std::string description = m_op.to_string();
    return description;
}

// Renders the requested attribute as "GetAttr{attr=<name>}".
// Known attribute values are printed by name; any other value of m_attr falls
// back to its std::to_string form.
std::string GetAttr::to_string() const {
    std::string attr_name;
    switch (m_attr) {
        case None:
            attr_name = "None";
            break;
        case DType:
            attr_name = "DType";
            break;
        case Device:
            attr_name = "Device";
            break;
        case Shape:
            attr_name = "Shape";
            break;
        case Value:
            attr_name = "Value";
            break;
        case Data:
            attr_name = "Data";
            break;
        default:
            attr_name = std::to_string(m_attr);
            break;
    }
    return ssprintf("GetAttr{attr=%s}", attr_name.c_str());
}

/// Constructs a CreateTensor op from explicitly specified attributes.
///
/// Each argument is copied verbatim into the corresponding member
/// (m_kind, m_device, m_dtype, m_shape, m_format). No validation is done here,
/// unlike the TensorLayout overload, which asserts the layout is contiguous.
///
/// @param kind   stored as m_kind and reported by kind().
/// @param device stored as m_device and reported by device().
/// @param dtype  stored as m_dtype and reported by dtype().
/// @param shape  stored as m_shape and reported by shape().
/// @param format stored as m_format.
CreateTensor::CreateTensor(
        Kind kind, CompNode device, DType dtype, ValueShape shape, Format format)
        : m_kind(kind),
          m_device(device),
          m_dtype(dtype),
          m_shape(shape),
          m_format(format) {}

/// Constructs a CreateTensor op whose dtype and shape are taken from a layout.
///
/// Only the layout's dtype and its shape (via ValueShape::from) are recorded.
/// The format is always Format::Type::DEFAULT. The layout must therefore be
/// contiguous (or empty); this is enforced with mgb_assert.
///
/// @param kind   stored as m_kind.
/// @param device stored as m_device.
/// @param layout source of dtype and shape; must be contiguous or empty.
CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
        : m_kind(kind),
          m_device(device),
          m_dtype(layout.dtype),
          m_shape(ValueShape::from(layout)),
          m_format(Format::Type::DEFAULT) {
    mgb_assert(
            layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
}

/// Sorts the op's inputs into an optional host value and an optional device value.
///
/// Each input must be a HostStorage or a DeviceStorage. At most one of each
/// kind is accepted, and at least one must be present. Every storage is wrapped
/// with the layout built from this op's shape() and dtype().
///
/// @param inputs storages supplying the tensor contents.
/// @return Args with `host` and/or `device` populated and `kind` set to kind().
/// @throws MegBrainError (via mgb_throw) for an input of any other type.
///         mgb_assert also fails on a duplicated storage kind, on no storage
///         at all, or on a resulting layout with ndim == 0.
auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
    Args result;
    for (auto&& input : inputs) {
        if (auto host_storage = input.as_ref<HostStorage>()) {
            mgb_assert(!result.host, "duplicated host value");
            result.host.emplace();
            result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.host->layout().ndim, "invalid shape");
        } else if (auto device_storage = input.as_ref<DeviceStorage>()) {
            mgb_assert(!result.device, "duplicated device value");
            // NOTE(review): the tensor is constructed on device() with the op's
            // shape/dtype and then immediately reset onto the given storage;
            // confirm whether the initial shape/dtype are still required.
            result.device.emplace(device(), shape().as_tensor_shape(), dtype());
            result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.device->layout().ndim, "invalid shape");
        } else {
            mgb_throw(
                    MegBrainError,
                    "unknown input type, expects HostStorage or DeviceStorage, got "
                    "%s",
                    input.to_string().c_str());
        }
    }
    mgb_assert(
            result.host || result.device, "require at least one of host/device value");
    result.kind = kind();
    return result;
}

/// Renders every attribute of the op, e.g.
/// "CreateTensor{kind=<int>, device=<...>, dtype=<...>, shape=<...>, format=<...>}".
/// The kind is printed as its integer value.
std::string CreateTensor::to_string() const {
    return ssprintf(
            "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}",
            static_cast<int>(m_kind), m_device.to_string().c_str(), m_dtype.name(),
            m_shape.to_string().c_str(), m_format.to_string().c_str());
}

// Renders the command as "DTRCommandValue{kind=<n>}", where n is the integer
// value of m_kind.
std::string DTRCommand::to_string() const {
    const int kind_value = static_cast<int>(m_kind);
    return ssprintf("DTRCommandValue{kind=%d}", kind_value);
}

// Fixed description; the op carries no printable attributes.
std::string CreateNode::to_string() const {
    return std::string("CreateNode");
}

// Fixed description; the op carries no printable attributes.
std::string GetName::to_string() const {
    return std::string("GetName{}");
}

// Renders the op as "RenameValue{name=<quoted name>}", where the name is
// passed through imperative::quoted before formatting.
std::string RenameValue::to_string() const {
    const auto quoted_name = imperative::quoted(m_name);
    return ssprintf("RenameValue{name=%s}", quoted_name.c_str());
}

// Fixed description; the op carries no printable attributes.
std::string IsScalar::to_string() const {
    return std::string("IsScalar");
}

// Fixed description; the op carries no printable attributes.
std::string GetFormat::to_string() const {
    return std::string("GetFormat{}");
}

// Renders the op as "SetFormat{format=<format>}", using m_format's own
// to_string() for the format text.
std::string SetFormat::to_string() const {
    const auto format_text = m_format.to_string();
    return ssprintf("SetFormat{format=%s}", format_text.c_str());
}

// Fixed description; the op carries no printable attributes.
std::string GetVarVal::to_string() const {
    return std::string("GetVarVal");
}
}  // namespace imperative
}  // namespace mgb