Commit babde352 authored by juncaipeng, committed by GitHub

add cast from uint8 to float, test=develop (#2080)

* add cast from uint8 to float, test=develop
Parent bad8b8cc
@@ -53,6 +53,12 @@ void CastCompute::Run() {
     // float>);
     // todo: the input type actually is float.
     memcpy(out_data, x_data_begin, sizeof(float) * param.X->numel());
+  } else if (param.in_dtype == 20 && param.out_dtype == 5) {  // uint8->float32
+    const unsigned char* x_data_begin = param.X->data<unsigned char>();
+    const unsigned char* x_data_end = x_data_begin + param.X->numel();
+    float* out_data = param.Out->mutable_data<float>();
+    std::transform(
+        x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
   } else {
     LOG(FATAL) << "other has not been implemented";
   }
......
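Note on the kernel change above: the new branch converts each uint8 element to float32 by walking the input buffer with std::transform and a TransOp<unsigned char, float> functor. Below is a minimal, self-contained sketch of the same pattern; the TransOp defined here is only a stand-in for illustration, since the kernel's actual functor is declared outside this hunk.

// Standalone sketch of the uint8 -> float32 element-wise cast used above.
// "TransOp" is a local stand-in for the kernel's casting functor (assumption).
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename InType, typename OutType>
OutType TransOp(InType in) {
  return static_cast<OutType>(in);
}

int main() {
  std::vector<unsigned char> x = {0, 1, 127, 255};
  std::vector<float> out(x.size());
  // Element-wise cast, mirroring the std::transform call in CastCompute::Run().
  std::transform(x.begin(), x.end(), out.begin(), TransOp<unsigned char, float>);
  for (float v : out) printf("%.1f ", v);  // prints: 0.0 1.0 127.0 255.0
  return 0;
}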
@@ -27,7 +27,7 @@ class CastComputeTester : public arena::TestCase {
   std::string output_ = "out";
   int in_dtype_;
   int out_dtype_;
-  DDim x_dims_{{2, 2, 2, 2}};
+  DDim x_dims_{{2, 2}};
  public:
   CastComputeTester(const Place& place,
@@ -41,27 +41,32 @@ class CastComputeTester : public arena::TestCase {
     CHECK(out);
     out->Resize(x_dims_);
-    if (out_dtype_ == 5 && in_dtype_ == 21) {
+    if (out_dtype_ == 5 && in_dtype_ == 20) {
+      auto* x = scope->FindTensor(input_);
+      auto* x_data = x->data<unsigned char>();
+      auto* output_data = out->mutable_data<float>();
+      for (int i = 0; i < x_dims_.production(); i++) {
+        *output_data = static_cast<float>(*x_data);
+        output_data++;
+        x_data++;
+      }
+    } else if (out_dtype_ == 5 && in_dtype_ == 21) {
       auto* output_data = out->mutable_data<float>();
       auto* x = scope->FindTensor(input_);
       auto* x_data = x->data<char>();
-      auto* output_data_tmp = output_data;
-      auto* x_data_tmp = x_data;
       for (int i = 0; i < x_dims_.production(); i++) {
-        *output_data_tmp = static_cast<float>(*x_data_tmp);
-        output_data_tmp++;
-        x_data_tmp++;
+        *output_data = static_cast<float>(*x_data);
+        output_data++;
+        x_data++;
       }
     } else if (out_dtype_ == 5 && in_dtype_ == 2) {
       auto* output_data = out->mutable_data<float>();
       auto* x = scope->FindTensor(input_);
       auto* x_data = x->data<int32_t>();
-      auto* output_data_tmp = output_data;
-      auto* x_data_tmp = x_data;
       for (int i = 0; i < x_dims_.production(); i++) {
-        *output_data_tmp = static_cast<float>(*x_data_tmp);
-        output_data_tmp++;
-        x_data_tmp++;
+        *output_data = static_cast<float>(*x_data);
+        output_data++;
+        x_data++;
       }
     }
   }
@@ -75,7 +80,13 @@ class CastComputeTester : public arena::TestCase {
   }
   void PrepareData() override {
-    if (in_dtype_ == 21) {
+    if (in_dtype_ == 20) {
+      std::vector<unsigned char> x_data(x_dims_.production());
+      for (int i = 0; i < x_dims_.production(); i++) {
+        x_data[i] = static_cast<unsigned char>(i % 128);
+      }
+      SetCommonTensor(input_, x_dims_, x_data.data());
+    } else if (in_dtype_ == 21) {
       std::vector<char> x_data(x_dims_.production());
       for (int i = 0; i < x_dims_.production(); i++) {
         float sign = i % 3 == 0 ? -1.0f : 1.0f;
@@ -101,7 +112,7 @@ TEST(Cast, precision) {
   Place place(TARGET(kARM));
   std::unique_ptr<arena::TestCase> tester(
-      new CastComputeTester(place, "def", 21, 5));
+      new CastComputeTester(place, "def", 20, 5));
   arena::Arena arena(std::move(tester), place, 2e-5);
   arena.TestPrecision();
......
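The bare integers used for dtypes in this test (2, 5, 20, 21) appear to follow the framework's VarType::Type enumeration; that mapping is an assumption here and should be checked against framework.proto. A small sketch listing only the codes exercised by this diff:

// Hedged sketch of the dtype codes as they appear in the test above.
// Assumption: these integers mirror the framework's VarType::Type values.
enum class CastDType : int {
  kInt32 = 2,    // in_dtype_ == 2  : int32 input
  kFloat32 = 5,  // out_dtype_ == 5 : float32 output
  kUInt8 = 20,   // in_dtype_ == 20 : uint8 input (the case added by this commit)
  kInt8 = 21,    // in_dtype_ == 21 : int8 input
};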