basic_values.cpp 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
#include "megbrain/imperative/basic_values.h"

namespace mgb {
namespace imperative {

// Renders the shape as "ValueShape" followed by the base ValueShape text.
std::string ShapeValue::to_string() const {
    const std::string shape_repr = ValueShape::to_string();
    return ssprintf("ValueShape%s", shape_repr.c_str());
}

// A comp-node value prints exactly as its CompNode base does.
std::string CompNodeValue::to_string() const {
    std::string repr = CompNode::to_string();
    return repr;
}

// Renders the wrapped boolean as the literal "true" or "false".
std::string BoolValue::to_string() const {
    if (*m_value) {
        return "true";
    }
    return "false";
}

// Describes the host storage by the comp node it lives on.
std::string HostStorage::to_string() const {
    const std::string device_repr = comp_node().to_string();
    return ssprintf("HostStorage{device=%s}", device_repr.c_str());
}

// Describes the device storage by the comp node it lives on.
std::string DeviceStorage::to_string() const {
    const std::string device_repr = comp_node().to_string();
    return ssprintf("DeviceStorage{device=%s}", device_repr.c_str());
}

// Summarizes a host value as its device, dtype name and shape.
std::string HostValue::to_string() const {
    const std::string device_repr = device().to_string();
    const std::string shape_repr = m_shape.to_string();
    const char* dtype_repr = m_dtype.name();
    return ssprintf(
            "HostValue{device=%s, dtype=%s, shape=%s}", device_repr.c_str(),
            dtype_repr, shape_repr.c_str());
}

// Builds a HostTensorND over this value's storage (m_storage) with a layout
// made from the value's shape and dtype.
//
// A scalar value has no tensor shape of its own here, so it is materialized
// with shape {1}. Callers must opt into that by passing allow_scalar=true;
// otherwise the assertion fires.
//
// NOTE(review): `reset(m_storage, ...)` presumably shares the storage rather
// than copying it. Confirm against HostTensorND::reset.
HostTensorND HostValue::as_nd(bool allow_scalar) const {
    HostTensorND nd;
    TensorShape tensor_shape;
    if (m_shape.is_scalar()) {
        mgb_assert(
                allow_scalar,
                "HostValue::as_nd: value is a scalar but allow_scalar is false");
        tensor_shape = TensorShape{1};
    } else {
        tensor_shape = m_shape.as_tensor_shape();
    }
    nd.reset(m_storage, {tensor_shape, dtype()});
    return nd;
}

// Summarizes a device value as its device, dtype name and shape.
std::string DeviceValue::to_string() const {
    const std::string device_repr = device().to_string();
    const std::string shape_repr = m_shape.to_string();
    const char* dtype_repr = m_dtype.name();
    return ssprintf(
            "DeviceValue{device=%s, dtype=%s, shape=%s}", device_repr.c_str(),
            dtype_repr, shape_repr.c_str());
}

// Builds a DeviceTensorND over this value's storage (m_storage) with a layout
// made from the value's shape and dtype.
//
// A scalar value has no tensor shape of its own here, so it is materialized
// with shape {1}. Callers must opt into that by passing allow_scalar=true;
// otherwise the assertion fires.
//
// NOTE(review): `reset(m_storage, ...)` presumably shares the storage rather
// than copying it. Confirm against DeviceTensorND::reset.
DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const {
    DeviceTensorND nd;
    TensorShape tensor_shape;
    if (m_shape.is_scalar()) {
        mgb_assert(
                allow_scalar,
                "DeviceValue::as_nd: value is a scalar but allow_scalar is false");
        tensor_shape = TensorShape{1};
    } else {
        tensor_shape = m_shape.as_tensor_shape();
    }
    nd.reset(m_storage, {tensor_shape, dtype()});
    return nd;
}

// Identifies a function value by the name of its target type.
std::string FunctionValue::to_string() const {
    auto&& target = target_type();
    return ssprintf("FunctionValue{type=%s}", target.name());
}

std::string DTypeValue::to_string() const {
    return DType::name();
}

// Renders the wrapped string in quoted form via imperative::quoted.
//
// The original C-style cast `(std::string&)*this` silently stripped the const
// qualifier of `*this` inside this const member function. A static_cast to a
// const reference of the std::string base expresses the same conversion
// without discarding const.
std::string StringValue::to_string() const {
    const std::string& text = static_cast<const std::string&>(*this);
    return imperative::quoted(text);
}

// Renders an error value together with its message text.
std::string ErrorValue::to_string() const {
    const auto& msg = message();
    return ssprintf("ErrorValue{message=%s}", msg.c_str());
}

}  // namespace imperative
}  // namespace mgb