From babde3527b5875ccd7131bc758abdb3c1e3e5e62 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 23 Sep 2019 14:20:22 +0800 Subject: [PATCH] add cast from uint8 to float, test=develop (#2080) * add cast from uint8 to float, test=develop --- lite/kernels/arm/cast_compute.cc | 6 ++++ lite/tests/kernels/cast_compute_test.cc | 39 ++++++++++++++++--------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index 5192eee0b1..87afbae153 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -53,6 +53,12 @@ void CastCompute::Run() { // float>); // todo: the input type actually is float. memcpy(out_data, x_data_begin, sizeof(float) * param.X->numel()); + } else if (param.in_dtype == 20 && param.out_dtype == 5) { // uint8->float32 + const unsigned char* x_data_begin = param.X->data(); + const unsigned char* x_data_end = x_data_begin + param.X->numel(); + float* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented"; } diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index db69d866c9..e738b67a71 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -27,7 +27,7 @@ class CastComputeTester : public arena::TestCase { std::string output_ = "out"; int in_dtype_; int out_dtype_; - DDim x_dims_{{2, 2, 2, 2}}; + DDim x_dims_{{2, 2}}; public: CastComputeTester(const Place& place, @@ -41,27 +41,32 @@ class CastComputeTester : public arena::TestCase { CHECK(out); out->Resize(x_dims_); - if (out_dtype_ == 5 && in_dtype_ == 21) { + if (out_dtype_ == 5 && in_dtype_ == 20) { + auto* x = scope->FindTensor(input_); + auto* x_data = x->data(); + auto* output_data = out->mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + *output_data = static_cast(*x_data); + output_data++; + x_data++; + } + } else if (out_dtype_ == 5 && in_dtype_ == 21) { auto* output_data = out->mutable_data(); auto* x = scope->FindTensor(input_); auto* x_data = x->data(); - auto* output_data_tmp = output_data; - auto* x_data_tmp = x_data; for (int i = 0; i < x_dims_.production(); i++) { - *output_data_tmp = static_cast(*x_data_tmp); - output_data_tmp++; - x_data_tmp++; + *output_data = static_cast(*x_data); + output_data++; + x_data++; } } else if (out_dtype_ == 5 && in_dtype_ == 2) { auto* output_data = out->mutable_data(); auto* x = scope->FindTensor(input_); auto* x_data = x->data(); - auto* output_data_tmp = output_data; - auto* x_data_tmp = x_data; for (int i = 0; i < x_dims_.production(); i++) { - *output_data_tmp = static_cast(*x_data_tmp); - output_data_tmp++; - x_data_tmp++; + *output_data = static_cast(*x_data); + output_data++; + x_data++; } } } @@ -75,7 +80,13 @@ class CastComputeTester : public arena::TestCase { } void PrepareData() override { - if (in_dtype_ == 21) { + if (in_dtype_ == 20) { + std::vector x_data(x_dims_.production()); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = static_cast(i % 128); + } + SetCommonTensor(input_, x_dims_, x_data.data()); + } else if (in_dtype_ == 21) { std::vector x_data(x_dims_.production()); for (int i = 0; i < x_dims_.production(); i++) { float sign = i % 3 == 0 ? -1.0f : 1.0f; @@ -101,7 +112,7 @@ TEST(Cast, precision) { Place place(TARGET(kARM)); std::unique_ptr tester( - new CastComputeTester(place, "def", 21, 5)); + new CastComputeTester(place, "def", 20, 5)); arena::Arena arena(std::move(tester), place, 2e-5); arena.TestPrecision(); -- GitLab