提交 5030c087 编写于 作者: L liuqi

Fix bug of space to batch.

上级 20c2d127
...@@ -11,43 +11,25 @@ __kernel void space_to_batch(__read_only image2d_t space_data, ...@@ -11,43 +11,25 @@ __kernel void space_to_batch(__read_only image2d_t space_data,
__private const int batch_height, __private const int batch_height,
__private const int batch_width) { __private const int batch_width) {
const int chan_idx = get_global_id(0); const int chan_idx = get_global_id(0);
const int batch_w_idx = mul24(get_global_id(1), 4); const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2); const int batch_hb_idx = get_global_id(2);
const int batch_b_idx = batch_hb_idx / batch_height; const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height; const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width); const int block_size = block_height * block_width;
const int space_b_idx = batch_b_idx / block_size; const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size; const int remaining_batch_idx = batch_b_idx % block_size;
const int space_h_idx = (remaining_batch_idx / block_width) + const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height; batch_h_idx * block_height - padding_height;
int space_w_idx = (remaining_batch_idx % block_width) + const int space_w_idx = (remaining_batch_idx % block_width) +
mul24(batch_w_idx, block_width) - padding_width; batch_w_idx * block_width - padding_width;
int2 space_coord = (int2)(mul24(chan_idx, space_width) + space_w_idx, int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
mul24(space_b_idx, space_height) + space_h_idx); space_b_idx * space_height + space_h_idx);
DATA_TYPE4 value = READ_IMAGET(space_data, SAMPLER, space_coord); DATA_TYPE4 value = READ_IMAGET(space_data, SAMPLER, space_coord);
int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx); int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
WRITE_IMAGET(batch_data, batch_coord, value);
space_coord.x += block_width;
value = READ_IMAGET(space_data, SAMPLER, space_coord);
batch_coord.x += 1;
WRITE_IMAGET(batch_data, batch_coord, value); WRITE_IMAGET(batch_data, batch_coord, value);
} }
...@@ -62,42 +44,24 @@ __kernel void batch_to_space(__read_only image2d_t batch_data, ...@@ -62,42 +44,24 @@ __kernel void batch_to_space(__read_only image2d_t batch_data,
__private const int batch_height, __private const int batch_height,
__private const int batch_width) { __private const int batch_width) {
const int chan_idx = get_global_id(0); const int chan_idx = get_global_id(0);
const int batch_w_idx = mul24(get_global_id(1), 4); const int batch_w_idx = get_global_id(1);
const int batch_hb_idx = get_global_id(2); const int batch_hb_idx = get_global_id(2);
const int batch_b_idx = batch_hb_idx / batch_height; const int batch_b_idx = batch_hb_idx / batch_height;
const int batch_h_idx = batch_hb_idx % batch_height; const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width); const int block_size = block_height * block_width;
const int space_b_idx = batch_b_idx / block_size; const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size; const int remaining_batch_idx = batch_b_idx % block_size;
const int space_h_idx = (remaining_batch_idx / block_width) + const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height; batch_h_idx * block_height - padding_height;
const int space_w_idx = (remaining_batch_idx % block_width) + const int space_w_idx = (remaining_batch_idx % block_width) +
mul24(batch_w_idx, block_width) - padding_width; batch_w_idx * block_width - padding_width;
int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx); int2 batch_coord = (int2)(chan_idx * batch_width + batch_w_idx, batch_hb_idx);
DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord); DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
int2 space_coord = (int2)(mul24(chan_idx, space_width) + space_w_idx, int2 space_coord = (int2)(chan_idx * space_width + space_w_idx,
mul24(space_b_idx, space_height) + space_h_idx); space_b_idx * space_height + space_h_idx);
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
WRITE_IMAGET(space_data, space_coord, value);
batch_coord.x += 1;
value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
space_coord.x += block_width;
WRITE_IMAGET(space_data, space_coord, value); WRITE_IMAGET(space_data, space_coord, value);
} }
...@@ -54,7 +54,6 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor ...@@ -54,7 +54,6 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2))); s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3)); const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3));
// const uint32_t width_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(2));
const uint32_t gws[3] = {chan_blk, const uint32_t gws[3] = {chan_blk,
static_cast<uint32_t>(batch_tensor->dim(2)), static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))}; static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
......
...@@ -23,12 +23,12 @@ class BatchToSpaceNDOp : public Operator<D, T> { ...@@ -23,12 +23,12 @@ class BatchToSpaceNDOp : public Operator<D, T> {
true) {} true) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *batch_tensor = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT); Tensor *space_tensor= this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0); std::vector<index_t> output_shape(4, 0);
BatchToSpaceHelper(input_tensor, output, output_shape); BatchToSpaceHelper(batch_tensor, space_tensor, output_shape);
functor_(output, output_shape, const_cast<Tensor *>(input_tensor), future); functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
return true; return true;
} }
......
...@@ -24,12 +24,12 @@ class SpaceToBatchNDOp : public Operator<D, T> { ...@@ -24,12 +24,12 @@ class SpaceToBatchNDOp : public Operator<D, T> {
false) {} false) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *space_tensor= this->Input(INPUT);
Tensor *output = this->Output(OUTPUT); Tensor *batch_tensor= this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0); std::vector<index_t> output_shape(4, 0);
SpaceToBatchHelper(input_tensor, output, output_shape); SpaceToBatchHelper(space_tensor, batch_tensor, output_shape);
functor_(const_cast<Tensor *>(input_tensor), output_shape, output, future); functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
return true; return true;
} }
......
...@@ -131,6 +131,20 @@ TEST(SpaceToBatchTest, SmallDataWithTwoPadding) { ...@@ -131,6 +131,20 @@ TEST(SpaceToBatchTest, SmallDataWithTwoPadding) {
); );
} }
TEST(SpaceToBatchTest, SmallDataWithLargeImage) {
TestBidirectionalTransform<float>({1, 2, 10, 1},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
{2, 2},
{0, 0, 0, 0},
{4, 1, 5, 1},
{1, 3, 5, 7, 9,
2, 4, 6, 8, 10,
11, 13, 15, 17, 19,
12, 14, 16, 18, 20}
);
}
TEST(SpaceToBatchTest, MultiChannelData) { TEST(SpaceToBatchTest, MultiChannelData) {
TestBidirectionalTransform<float>({1, 2, 2, 3}, TestBidirectionalTransform<float>({1, 2, 2, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册