// basic_values.cpp
#include "megbrain/imperative/basic_values.h"

namespace mgb {
namespace imperative {

std::string ShapeValue::to_string() const {
    // Text produced by ValueShape::to_string(), prefixed with "ValueShape".
    const auto shape_text = ValueShape::to_string();
    return ssprintf("ValueShape%s", shape_text.c_str());
}

std::string CompNodeValue::to_string() const {
    // No extra decoration: the CompNode text is the whole representation.
    std::string comp_node_text = CompNode::to_string();
    return comp_node_text;
}

std::string BoolValue::to_string() const {
15
    return (*this) ? "true" : "false";
16 17
}

18 19 20 21
std::string IntegerValue::to_string() const {
    return std::to_string((int)*this);
}

22 23 24 25 26 27 28 29 30 31 32
std::string HostStorage::to_string() const {
    return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str());
}

std::string DeviceStorage::to_string() const {
    // Report the comp node this storage lives on; contents are not printed.
    const auto device_name = comp_node().to_string();
    return ssprintf("DeviceStorage{device=%s}", device_name.c_str());
}

std::string HostValue::to_string() const {
    return ssprintf(
            "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
33
            dtype().name(), shape().to_string().c_str());
34 35
}

36
HostTensorND HostTensor::as_nd(bool allow_scalar) const {
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
    HostTensorND nd;
    TensorShape tensor_shape;
    if (m_shape.is_scalar()) {
        mgb_assert(allow_scalar);
        tensor_shape = TensorShape{1};
    } else {
        tensor_shape = m_shape.as_tensor_shape();
    }
    nd.reset(m_storage, {tensor_shape, dtype()});
    return nd;
}

std::string DeviceValue::to_string() const {
    return ssprintf(
            "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
52
            dtype().name(), shape().to_string().c_str());
53 54
}

55
DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const {
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
    DeviceTensorND nd;
    TensorShape tensor_shape;
    if (m_shape.is_scalar()) {
        mgb_assert(allow_scalar);
        tensor_shape = TensorShape{1};
    } else {
        tensor_shape = m_shape.as_tensor_shape();
    }
    nd.reset(m_storage, {tensor_shape, dtype()});
    return nd;
}

std::string FunctionValue::to_string() const {
    // Only the name reported by target_type() is shown, not the callable itself.
    // NOTE(review): presumably target_type() is the std::type_info of the
    // wrapped callable; confirm.
    const auto type_name = target_type().name();
    return ssprintf("FunctionValue{type=%s}", type_name);
}

std::string DTypeValue::to_string() const {
    // The dtype's own name is the complete textual representation.
    const auto dtype_name = DType::name();
    return dtype_name;
}

std::string StringValue::to_string() const {
    // Return the wrapped string passed through imperative::quoted.
    // The cast target is a const reference so the constness of *this (this is a
    // const member function) is preserved rather than silently cast away.
    return imperative::quoted(static_cast<const std::string&>(*this));
}

std::string ErrorValue::to_string() const {
    // The stored message is inserted as-is into the braces (no quoting added).
    const auto& error_message = message();
    return ssprintf("ErrorValue{message=%s}", error_message.c_str());
}

}  // namespace imperative
}  // namespace mgb