diff --git a/mace/core/serializer.cc b/mace/core/serializer.cc index 5d24e9b7b24b682751c86e7032f2b847eca6df03..6e96f34c32172acbc722d7777926d014bbdfa404 100644 --- a/mace/core/serializer.cc +++ b/mace/core/serializer.cc @@ -23,6 +23,10 @@ unique_ptr Serializer::Deserialize(const ConstTensor &proto, tensor->Resize(dims); switch (proto.data_type()) { + case DT_HALF: + tensor->Copy(reinterpret_cast(proto.data()), + proto.data_size()); + break; case DT_FLOAT: tensor->Copy(reinterpret_cast(proto.data()), proto.data_size()); diff --git a/mace/ops/buffer_to_image_test.cc b/mace/ops/buffer_to_image_test.cc index 760103d0f40886e4c2b34fecb6a94ed2dee37fdc..34c7d16f8c1421e4a295896b8bef8bf88625c4d1 100644 --- a/mace/ops/buffer_to_image_test.cc +++ b/mace/ops/buffer_to_image_test.cc @@ -4,7 +4,6 @@ #include "gtest/gtest.h" #include "mace/ops/ops_test_util.h" -#include "mace/kernels/opencl/helper.h" using namespace mace; @@ -130,3 +129,43 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector TEST(BufferToImageTest, ArgFloatToHalfSmall) { TestDiffTypeBidirectionTransform(kernels::ARGUMENT, {11}); } + +template +void TestStringHalfBidirectionTransform(const int type, + const std::vector &input_shape, + const unsigned char *input_data) { + OpsTestNet net; + OpDefBuilder("BufferToImage", "BufferToImageTest") + .Input("Input") + .Output("B2IOutput") + .AddIntArg("buffer_type", type) + .AddIntArg("T", DataTypeToEnum::value) + .Finalize(net.NewOperatorDef()); + + const half *h_data = reinterpret_cast(input_data); + + net.AddInputFromArray("Input", input_shape, std::vector(h_data, h_data+2)); + + // Run + net.RunOp(D); + + OpDefBuilder("ImageToBuffer", "ImageToBufferTest") + .Input("B2IOutput") + .Output("I2BOutput") + .AddIntArg("buffer_type", type) + .AddIntArg("T", DataTypeToEnum::value) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + // Check + ExpectTensorNear(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2); +} + +TEST(BufferToImageTest, ArgStringHalfToHalfSmall) { + const unsigned char input_data[] = {0xCD, 0x3C, 0x33, 0x40,}; + TestStringHalfBidirectionTransform(kernels::ARGUMENT, + {2}, + input_data); +}