提交 822f7a4d 编写于 作者: L liuqi

Support half-type const tensor.

上级 07f8ff18
......@@ -23,6 +23,10 @@ unique_ptr<Tensor> Serializer::Deserialize(const ConstTensor &proto,
tensor->Resize(dims);
switch (proto.data_type()) {
case DT_HALF:
tensor->Copy<half>(reinterpret_cast<const half*>(proto.data()),
proto.data_size());
break;
case DT_FLOAT:
tensor->Copy<float>(reinterpret_cast<const float *>(proto.data()),
proto.data_size());
......
......@@ -4,7 +4,6 @@
#include "gtest/gtest.h"
#include "mace/ops/ops_test_util.h"
#include "mace/kernels/opencl/helper.h"
using namespace mace;
......@@ -130,3 +129,43 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector<index_t>
TEST(BufferToImageTest, ArgFloatToHalfSmall) {
TestDiffTypeBidirectionTransform<DeviceType::OPENCL, half>(kernels::ARGUMENT, {11});
}
template<DeviceType D, typename T>
void TestStringHalfBidirectionTransform(const int type,
const std::vector<index_t> &input_shape,
const unsigned char *input_data) {
OpsTestNet net;
OpDefBuilder("BufferToImage", "BufferToImageTest")
.Input("Input")
.Output("B2IOutput")
.AddIntArg("buffer_type", type)
.AddIntArg("T", DataTypeToEnum<T>::value)
.Finalize(net.NewOperatorDef());
const half *h_data = reinterpret_cast<const half*>(input_data);
net.AddInputFromArray<D, half>("Input", input_shape, std::vector<half>(h_data, h_data+2));
// Run
net.RunOp(D);
OpDefBuilder("ImageToBuffer", "ImageToBufferTest")
.Input("B2IOutput")
.Output("I2BOutput")
.AddIntArg("buffer_type", type)
.AddIntArg("T", DataTypeToEnum<T>::value)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Check
ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2);
}
TEST(BufferToImageTest, ArgStringHalfToHalfSmall) {
const unsigned char input_data[] = {0xCD, 0x3C, 0x33, 0x40,};
TestStringHalfBidirectionTransform<DeviceType::OPENCL, half>(kernels::ARGUMENT,
{2},
input_data);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册