提交 5030c087 编写于 作者: L liuqi

Fix a bug in space-to-batch.

上级 20c2d127
......@@ -11,43 +11,25 @@ __kernel void space_to_batch(__read_only image2d_t space_data,
__private const int batch_height,
__private const int batch_width) {
const int chan_idx = get_global_id(0);
const int batch_w_idx = mul24(get_global_id(1), 4);
const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2);
const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width);
const int block_size = block_height * block_width;
const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size;
const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height;
int space_w_idx = (remaining_batch_idx % block_width) +
mul24(batch_w_idx, block_width) - padding_width;
batch_h_idx * block_height - padding_height;
const int space_w_idx = (remaining_batch_idx % block_width) +
batch_w_idx * block_width - padding_width;
int2 space_coord = (int2)(mul24(chan_idx, space_width) + space_w_idx,
mul24(space_b_idx, space_height) + space_h_idx);
int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
space_b_idx * space_height + space_h_idx);
DATA_TYPE4 value = READ_IMAGET(space_data, SAMPLER, space_coord);
int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx);
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
WRITE_IMAGET(batch_data, batch_coord, value);
}
......@@ -62,42 +44,24 @@ __kernel void batch_to_space(__read_only image2d_t batch_data,
__private const int batch_height,
__private const int batch_width) {
const int chan_idx = get_global_id(0);
const int batch_w_idx = mul24(get_global_id(1), 4);
const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2);
const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width);
const int block_size = block_height * block_width;
const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size;
const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height;
batch_h_idx * block_height - padding_height;
const int space_w_idx = (remaining_batch_idx % block_width) +
mul24(batch_w_idx, block_width) - padding_width;
batch_w_idx * block_width - padding_width;
int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx);
int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
int2 space_coord = (int2)(mul24(chan_idx, space_width) + space_w_idx,
mul24(space_b_idx, space_height) + space_h_idx);
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
space_b_idx * space_height + space_h_idx);
WRITE_IMAGET(space_data, space_coord, value);
}
......@@ -54,7 +54,6 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3));
// const uint32_t width_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(2));
const uint32_t gws[3] = {chan_blk,
static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
......
......@@ -23,12 +23,12 @@ class BatchToSpaceNDOp : public Operator<D, T> {
true) {}
// Runs the batch-to-space op: fetches the INPUT and OUTPUT tensors, passes them
// with a 4-element, zero-initialized output_shape to BatchToSpaceHelper, then
// invokes functor_ with the output tensor, output_shape, the input tensor and
// `future`. Always returns true.
// NOTE(review): this listing looks like a rendered diff with its +/- markers
// stripped, so the old (input_tensor/output) and new (batch_tensor/
// space_tensor) forms of each statement both appear below — confirm against
// the real file before relying on it.
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor= this->Output(OUTPUT);
// presumably filled in by BatchToSpaceHelper — verify against its definition
std::vector<index_t> output_shape(4, 0);
BatchToSpaceHelper(input_tensor, output, output_shape);
// The input tensor's constness is cast away for the functor call;
// NOTE(review): confirm the functor does not write to it.
functor_(output, output_shape, const_cast<Tensor *>(input_tensor), future);
BatchToSpaceHelper(batch_tensor, space_tensor, output_shape);
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
return true;
}
......
......@@ -24,12 +24,12 @@ class SpaceToBatchNDOp : public Operator<D, T> {
false) {}
// Runs the space-to-batch op: fetches the INPUT and OUTPUT tensors, passes them
// with a 4-element, zero-initialized output_shape to SpaceToBatchHelper, then
// invokes functor_ with the input tensor, output_shape, the output tensor and
// `future`. Always returns true.
// NOTE(review): this listing looks like a rendered diff with its +/- markers
// stripped, so the old (input_tensor/output) and new (space_tensor/
// batch_tensor) forms of each statement both appear below — confirm against
// the real file before relying on it.
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
const Tensor *space_tensor= this->Input(INPUT);
Tensor *batch_tensor= this->Output(OUTPUT);
// presumably filled in by SpaceToBatchHelper — verify against its definition
std::vector<index_t> output_shape(4, 0);
SpaceToBatchHelper(input_tensor, output, output_shape);
// The input tensor's constness is cast away for the functor call;
// NOTE(review): confirm the functor does not write to it.
functor_(const_cast<Tensor *>(input_tensor), output_shape, output, future);
SpaceToBatchHelper(space_tensor, batch_tensor, output_shape);
functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
return true;
}
......
......@@ -131,6 +131,20 @@ TEST(SpaceToBatchTest, SmallDataWithTwoPadding) {
);
}
// Exercises TestBidirectionalTransform<float> with a {1, 2, 10, 1} shape
// holding the values 1..20, a {2, 2} pair, four zeros, and a {4, 1, 5, 1}
// shape holding the same 20 values regrouped into four runs of five
// (odd/even values of each input row).
// NOTE(review): the arguments look like (space shape, space data, block shape,
// paddings, batch shape, batch data) — inferred from the literal values and
// the helper's name only; confirm against TestBidirectionalTransform.
TEST(SpaceToBatchTest, SmallDataWithLargeImage) {
TestBidirectionalTransform<float>({1, 2, 10, 1},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
{2, 2},
{0, 0, 0, 0},
{4, 1, 5, 1},
{1, 3, 5, 7, 9,
2, 4, 6, 8, 10,
11, 13, 15, 17, 19,
12, 14, 16, 18, 20}
);
}
TEST(SpaceToBatchTest, MultiChannelData) {
TestBidirectionalTransform<float>({1, 2, 2, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册