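// Implementations of the string/debug helpers for the basic operator value
// types declared in "megbrain/imperative/basic_operators.h".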
#include "megbrain/imperative/basic_operators.h"

#include "megbrain/imperative/basic_values.h"

namespace mgb {
namespace imperative {

std::string ApplyOp::to_string() const {
    return m_op.to_string();
}

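// Renders the requested attribute as a human-readable name. Unknown enum
// values fall back to their numeric representation; `buffer` is declared
// outside the lambda so the returned c_str() remains valid when ssprintf runs.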
std::string GetAttr::to_string() const {
    std::string buffer;
    const char* attr_name = ([&] {
        switch (m_attr) {
            case None:
                return "None";
            case DType:
                return "DType";
            case Device:
                return "Device";
            case Shape:
                return "Shape";
            case Value:
                return "Value";
            case Data:
                return "Data";
            default:
                buffer = std::to_string(m_attr);
                return buffer.c_str();
        }
    })();
    return ssprintf("GetAttr{attr=%s}", attr_name);
}

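// CreateTensor only records kind/device/dtype/shape; the layout-based
// constructor therefore requires a contiguous (or empty) layout, since any
// non-trivial stride information would otherwise be lost.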
CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape)
        : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {}

CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
        : m_kind(kind),
          m_device(device),
          m_dtype(layout.dtype),
          m_shape(ValueShape::from(layout)) {
    mgb_assert(
            layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
}

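// Splits the raw input span into host/device storages. Each kind may appear
// at most once; the storage is re-wrapped with this operator's shape and
// dtype, and at least one of host/device must be supplied.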
auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
    Args result;
    for (auto&& input : inputs) {
        if (auto host_storage = input.as_ref<HostStorage>()) {
            mgb_assert(!result.host, "duplicated host value");
            result.host.emplace();
            result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.host->layout().ndim, "invalid shape");
        } else if (auto device_storage = input.as_ref<DeviceStorage>()) {
            mgb_assert(!result.device, "duplicated device value");
            result.device.emplace(device(), shape().as_tensor_shape(), dtype());
            result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()});
            mgb_assert(result.device->layout().ndim, "invalid shape");
        } else {
            mgb_throw(
                    MegBrainError,
                    "unknown input type, expects HostStorage or DeviceStorage, got "
                    "%s",
                    input.to_string().c_str());
        }
    }
    mgb_assert(
            result.host || result.device, "require at least one of host/device value");
    result.kind = kind();
    return result;
}

std::string CreateTensor::to_string() const {
    return ssprintf(
            "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind,
            m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str());
}

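// The remaining operators only need a short textual tag when printed.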
std::string DTRCommand::to_string() const {
    return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind);
}

std::string CreateNode::to_string() const {
    return "CreateNode";
}

std::string GetName::to_string() const {
    return "GetName{}";
}

std::string RenameValue::to_string() const {
    return ssprintf("RenameValue{name=%s}", imperative::quoted(m_name).c_str());
}

std::string IsScalar::to_string() const {
    return "IsScalar";
}

std::string GetVarVal::to_string() const {
    return "GetVarVal";
}

}  // namespace imperative
}  // namespace mgb