提交 ec584890 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Various bugfixes to the core float16 support. (These started showing up
when I converted ops and added unit tests.)
Change: 119189211
上级 4a3536f8
......@@ -269,6 +269,7 @@ template <>
struct ProtoHelper<bfloat16> {
typedef Helper<float>::RepeatedFieldType FieldType;
static const bfloat16* Begin(const TensorProto& proto) {
// TODO: Isn't this wrong, given that int_val is 32 bits long?
return reinterpret_cast<const bfloat16*>(proto.int_val().data());
}
static size_t NumElements(const TensorProto& proto) {
......@@ -284,17 +285,10 @@ struct ProtoHelper<bfloat16> {
// Serialization helper for Eigen::half tensors. Halves are opaque to the
// protobuf, so the raw 16-bit payload (data[i].x) is stored one element per
// entry of a repeated int32 field.
// NOTE(review): this hunk appears to be a diff rendering with the +/- markers
// stripped — the int_val lines look like the pre-change version and the
// half_val lines the post-change version; only one pair should survive.
// Confirm against the upstream commit before relying on this text.
template <>
struct ProtoHelper<Eigen::half> {
typedef Helper<float>::RepeatedFieldType FieldType;
static const Eigen::half* Begin(const TensorProto& proto) {
// Reinterprets the repeated int32 storage as packed half values.
return reinterpret_cast<const Eigen::half*>(proto.int_val().data());
}
static size_t NumElements(const TensorProto& proto) {
return proto.int_val().size();
}
static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
// Reserve up front so AddAlreadyReserved below never reallocates.
proto->mutable_int_val()->Reserve(n);
proto->mutable_half_val()->Reserve(n);
for (size_t i = 0; i < n; ++i) {
// data[i].x is the raw 16-bit representation of the half value.
proto->mutable_int_val()->AddAlreadyReserved(data[i].x);
proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
}
}
};
......@@ -345,6 +339,29 @@ TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
return buf;
}
// fp16 is opaque to the protobuf, so we deserialize these identical to uint16
// but with data stored in half_val instead of int_val (ie., we don't use
// ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64 n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* dst = buf->template base<uint16>();
  const int64 available = in.half_val().size();
  // Copy as many serialized elements as the buffer can hold.
  const int64 copied = n < available ? n : available;
  std::copy_n(in.half_val().begin(), copied, dst);
  if (copied < n) {
    // Pad the remainder: repeat the last deserialized element when the proto
    // held at least one value, otherwise zero-fill the whole buffer.
    const uint16 pad = copied > 0 ? dst[copied - 1] : static_cast<uint16>(0);
    std::fill_n(dst + copied, n - copied, pad);
  }
  return buf;
}
// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
......
......@@ -90,6 +90,7 @@ typedef enum {
TF_QUINT16 = 16, // Quantized uint16
TF_UINT16 = 17,
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
} TF_DataType;
// --------------------------------------------------------------------------
......
......@@ -93,6 +93,9 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
int pyarray_type = PyArray_TYPE(array);
PyArray_Descr* descr = PyArray_DESCR(array);
switch (pyarray_type) {
case NPY_FLOAT16:
*out_tf_datatype = TF_HALF;
break;
case NPY_FLOAT32:
*out_tf_datatype = TF_FLOAT;
break;
......@@ -144,6 +147,9 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
int* out_pyarray_type) {
switch (tf_datatype) {
case TF_HALF:
*out_pyarray_type = NPY_FLOAT16;
break;
case TF_FLOAT:
*out_pyarray_type = NPY_FLOAT32;
break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册