提交 364b9e5c 编写于 作者: L liuqi

Fix space_to_batch bugs of coordinates beyond boundary.

上级 d1c3fef4
......@@ -173,9 +173,13 @@ int main(int argc, char **argv) {
// load input
ifstream in_file(input_file, ios::in | ios::binary);
in_file.read(reinterpret_cast<char *>(input_data.get()),
input_size * sizeof(float));
in_file.close();
if (in_file.is_open()) {
in_file.read(reinterpret_cast<char *>(input_data.get()),
input_size * sizeof(float));
in_file.close();
} else {
LOG(ERROR) << "Open input file failed";
}
// Init model
VLOG(0) << "Run init";
......
......@@ -25,8 +25,14 @@ __kernel void space_to_batch(__read_only image2d_t space_data,
const int space_w_idx = (remaining_batch_idx % block_width) +
batch_w_idx * block_width - padding_width;
int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
space_b_idx * space_height + space_h_idx);
const int space_coord_x = select(chan_idx * space_width + space_w_idx,
-1,
space_w_idx < 0 || space_w_idx >= space_width);
const int space_coord_y = select(space_b_idx * space_height + space_h_idx,
-1,
space_h_idx < 0 || space_h_idx >= space_height);
int2 space_coord = (int2)(space_coord_x,
space_coord_y);
DATA_TYPE4 value = READ_IMAGET(space_data, SAMPLER, space_coord);
int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
......@@ -58,10 +64,13 @@ __kernel void batch_to_space(__read_only image2d_t batch_data,
const int space_w_idx = (remaining_batch_idx % block_width) +
batch_w_idx * block_width - padding_width;
int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
if (0 <= space_w_idx && space_w_idx < space_width &&
0 <= space_h_idx && space_h_idx < space_height) {
int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
space_b_idx * space_height + space_h_idx);
WRITE_IMAGET(space_data, space_coord, value);
int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
space_b_idx * space_height + space_h_idx);
WRITE_IMAGET(space_data, space_coord, value);
}
}
......@@ -322,9 +322,18 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
EXPECT_NEAR(a[i], b[i], abs_err) << "a = " << a << " b = " << b
<< " index = " << i;
for (int n = 0; n < x.dim(0); ++n) {
for (int h = 0; h < x.dim(1); ++h) {
for (int w = 0; w < x.dim(2); ++w) {
for (int c = 0; c < x.dim(3); ++c) {
EXPECT_NEAR(*a, *b, abs_err) << "with index = ["
<< n << ", " << h << ", "
<< w << ", " << c << "]";
a++;
b++;
}
}
}
}
}
......
......@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "mace/ops/ops_test_util.h"
#include <fstream>
using namespace mace;
......@@ -187,3 +188,61 @@ TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
);
}
//TEST(SpaceTobatchTest, CompareTF) {
//
// const std::string space_file = "/data/local/tmp/test/input";
// const std::string batch_file = "/data/local/tmp/test/output";
// const std::vector<index_t> space_shape = {1, 256, 256, 32};
// const int space_size = std::accumulate(space_shape.begin(), space_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> batch_shape = {4, 130, 130, 32};
// const int batch_size = std::accumulate(batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
//
// auto space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// space_tensor->Resize(space_shape);
// std::vector<float> space_data(space_size, 0.0);
// std::ifstream in_file(space_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(space_data.data()),
// space_size * sizeof(float));
// in_file.close();
// Tensor::MappingGuard space_mapper(space_tensor.get());
// float *space_ptr = space_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(space_tensor->size()) == space_data.size())
// << "Space tensor size:" << space_tensor->size()
// << ", space data size:" << space_data.size();
// memcpy(space_ptr, space_data.data(), space_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open space file failed";
// }
//
// auto batch_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// std::vector<float> batch_data(batch_size, 0.0);
// batch_tensor->Resize(batch_shape);
// {
// std::ifstream in_file(batch_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(batch_data.data()),
// batch_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open batch file failed";
// }
// Tensor::MappingGuard batch_mapper(batch_tensor.get());
// float *batch_ptr = batch_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(batch_tensor->size()) == batch_data.size());
// memcpy(batch_ptr, batch_data.data(), batch_data.size() * sizeof(float));
// }
//
// RunSpaceToBatch<DeviceType::OPENCL>(space_shape, space_data,
// {2, 2},
// {2, 2, 2, 2},
// batch_tensor.get());
//
// RunBatchToSpace<DeviceType::OPENCL>(batch_shape, batch_data,
// {2, 2},
// {2, 2, 2, 2},
// space_tensor.get());
//}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册