Commit 4563cc60 authored by 李寅

Add space-to-batch / batch-to-space CPU implementation

Parent 66a49116
......@@ -107,6 +107,8 @@ inline std::ostream &operator<<(std::ostream &os, unsigned char c) {
}
} // namespace numerical_chars
enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4 };
class Tensor {
public:
Tensor(Allocator *alloc, DataType type)
......
......@@ -8,6 +8,7 @@ __kernel void space_to_batch(KERNEL_ERROR_PARAMS
__private const int block_width,
__private const int padding_height,
__private const int padding_width,
__private const int batch_size,
__private const int space_height,
__private const int space_width,
__private const int batch_height,
......@@ -27,8 +28,8 @@ __kernel void space_to_batch(KERNEL_ERROR_PARAMS
const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width);
const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size;
const int space_b_idx = batch_b_idx % batch_size;
const int remaining_batch_idx = batch_b_idx / batch_size;
const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height;
const int space_w_idx = (remaining_batch_idx % block_width) +
......@@ -57,6 +58,7 @@ __kernel void batch_to_space(KERNEL_ERROR_PARAMS
__private const int block_width,
__private const int padding_height,
__private const int padding_width,
__private const int batch_size,
__private const int space_height,
__private const int space_width,
__private const int batch_height,
......@@ -76,8 +78,8 @@ __kernel void batch_to_space(KERNEL_ERROR_PARAMS
const int batch_h_idx = batch_hb_idx % batch_height;
const int block_size = mul24(block_height, block_width);
const int space_b_idx = batch_b_idx / block_size;
const int remaining_batch_idx = batch_b_idx % block_size;
const int space_b_idx = batch_b_idx % batch_size;
const int remaining_batch_idx = batch_b_idx / batch_size;
const int space_h_idx = (remaining_batch_idx / block_width) +
mul24(batch_h_idx, block_height) - padding_height;
const int space_w_idx = (remaining_batch_idx % block_width) +
......
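Note: the kernel change above swaps the batch-index decomposition from block_size-based to batch_size-based, so the space-tensor batch now varies fastest along the batch dimension of the batch-space tensor. A minimal host-side C++ restatement of the new decomposition (illustrative only, not part of the commit; names are assumed):

#include <cassert>

// Decompose a batch-space batch index into (space batch, block cell), the way the
// updated space_to_batch / batch_to_space kernels now do. batch_size is the number
// of batches in the space tensor; the block-cell index later splits into
// (cell / block_width, cell % block_width) for the height/width offsets.
struct BlockCell {
  int space_b_idx;          // batch index within the space tensor
  int remaining_batch_idx;  // linear block-cell index (tile_h * block_width + tile_w)
};

inline BlockCell Decompose(int batch_b_idx, int batch_size) {
  return {batch_b_idx % batch_size, batch_b_idx / batch_size};
}

int main() {
  // With 2 space batches and a 2x2 block, batch index 5 maps to space batch 1,
  // block cell 2 (tile_h = 1, tile_w = 0).
  BlockCell cell = Decompose(5, 2);
  assert(cell.space_b_idx == 1 && cell.remaining_batch_idx == 2);
  return 0;
}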
......@@ -27,9 +27,19 @@ namespace kernels {
template <typename T>
void SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
Tensor *space_tensor,
const std::vector<index_t> &output_shape,
Tensor *batch_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4, 0);
if (b2s_) {
CalculateBatchToSpaceOutputShape(batch_tensor,
DataFormat::NHWC,
output_shape.data());
} else {
CalculateSpaceToBatchOutputShape(space_tensor,
DataFormat::NHWC,
output_shape.data());
}
const char *kernel_name = nullptr;
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
......@@ -97,6 +107,7 @@ void SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
kernel_.setArg(idx++, block_shape_[1]);
kernel_.setArg(idx++, paddings_[0]);
kernel_.setArg(idx++, paddings_[2]);
kernel_.setArg(idx++, static_cast<int32_t>(space_tensor->dim(0)));
kernel_.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(1)));
......
......@@ -33,31 +33,204 @@ struct SpaceToBatchFunctorBase {
SpaceToBatchFunctorBase(const std::vector<int> &paddings,
const std::vector<int> &block_shape,
bool b2s)
: paddings_(paddings.begin(), paddings.end()),
block_shape_(block_shape.begin(), block_shape.end()),
b2s_(b2s) {}
: paddings_(paddings.begin(), paddings.end()),
block_shape_(block_shape.begin(), block_shape.end()),
b2s_(b2s) {
MACE_CHECK(
block_shape.size() == 2 && block_shape[0] > 1 && block_shape[1] > 1,
"Block's shape should be 1D, and greater than 1");
MACE_CHECK(paddings.size() == 4, "Paddings' shape should be 2D");
}
std::vector<int> paddings_;
std::vector<int> block_shape_;
bool b2s_;
protected:
void CalculateSpaceToBatchOutputShape(const Tensor *input_tensor,
const DataFormat data_format,
index_t *output_shape) {
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
index_t batch = input_tensor->dim(0);
index_t channels = 0;
index_t height = 0;
index_t width = 0;
if (data_format == DataFormat::NHWC) {
height = input_tensor->dim(1);
width = input_tensor->dim(2);
channels = input_tensor->dim(3);
} else if (data_format == DataFormat::NCHW) {
height = input_tensor->dim(2);
width = input_tensor->dim(3);
channels = input_tensor->dim(1);
} else {
MACE_NOT_IMPLEMENTED;
}
index_t padded_height = height + paddings_[0] + paddings_[1];
index_t padded_width = width + paddings_[2] + paddings_[3];
MACE_CHECK(padded_height % block_shape_[0] == 0, "padded input height",
padded_height, " is not divisible by block height");
MACE_CHECK(padded_width % block_shape_[1] == 0, "padded input width",
padded_height, " is not divisible by block width");
index_t new_batch = batch * block_shape_[0] * block_shape_[1];
index_t new_height = padded_height / block_shape_[0];
index_t new_width = padded_width / block_shape_[1];
if (data_format == DataFormat::NHWC) {
output_shape[0] = new_batch;
output_shape[1] = new_height;
output_shape[2] = new_width;
output_shape[3] = channels;
} else {
output_shape[0] = new_batch;
output_shape[1] = channels;
output_shape[2] = new_height;
output_shape[3] = new_width;
}
}
void CalculateBatchToSpaceOutputShape(const Tensor *input_tensor,
const DataFormat data_format,
index_t *output_shape) {
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
index_t batch = input_tensor->dim(0);
index_t channels = 0;
index_t height = 0;
index_t width = 0;
if (data_format == DataFormat::NHWC) {
height = input_tensor->dim(1);
width = input_tensor->dim(2);
channels = input_tensor->dim(3);
} else if (data_format == DataFormat::NCHW) {
height = input_tensor->dim(2);
width = input_tensor->dim(3);
channels = input_tensor->dim(1);
} else {
MACE_NOT_IMPLEMENTED;
}
index_t new_batch = batch / block_shape_[0] / block_shape_[1];
index_t new_height = height * block_shape_[0] - paddings_[0] - paddings_[1];
index_t new_width = width * block_shape_[1] - paddings_[2] - paddings_[3];
if (data_format == DataFormat::NHWC) {
output_shape[0] = new_batch;
output_shape[1] = new_height;
output_shape[2] = new_width;
output_shape[3] = channels;
} else {
output_shape[0] = new_batch;
output_shape[1] = channels;
output_shape[2] = new_height;
output_shape[3] = new_width;
}
}
};
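For reference, a quick standalone check of the shape formulas implemented above (illustrative only; the values and names below are not from the commit): an NHWC input of {1, 4, 4, 1} with block {2, 2} and no padding yields a batch-space shape of {4, 2, 2, 1}, and BatchToSpace with the same arguments inverts it.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t batch = 1, height = 4, width = 4, channels = 1;
  const int64_t block_h = 2, block_w = 2;
  const int64_t pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;

  // SpaceToBatch: pad, then fold each block cell into the batch dimension.
  const int64_t padded_h = height + pad_top + pad_bottom;  // 4
  const int64_t padded_w = width + pad_left + pad_right;   // 4
  const int64_t out_batch = batch * block_h * block_w;     // 4
  const int64_t out_h = padded_h / block_h;                // 2
  const int64_t out_w = padded_w / block_w;                // 2
  assert(out_batch == 4 && out_h == 2 && out_w == 2 && channels == 1);

  // BatchToSpace with the same block and paddings maps {4, 2, 2, 1} back to {1, 4, 4, 1}.
  assert(out_batch / (block_h * block_w) == batch);
  assert(out_h * block_h - pad_top - pad_bottom == height);
  assert(out_w * block_w - pad_left - pad_right == width);
  return 0;
}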
template <DeviceType D, typename T>
struct SpaceToBatchFunctor : SpaceToBatchFunctorBase {
template<DeviceType D, typename T>
struct SpaceToBatchFunctor;
template<>
struct SpaceToBatchFunctor<DeviceType::CPU, float> : SpaceToBatchFunctorBase {
SpaceToBatchFunctor(const std::vector<int> &paddings,
const std::vector<int> &block_shape,
bool b2s)
: SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
: SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
void operator()(Tensor *space_tensor,
const std::vector<index_t> &output_shape,
Tensor *batch_tensor,
StatsFuture *future) {
MACE_UNUSED(space_tensor);
MACE_UNUSED(output_shape);
MACE_UNUSED(batch_tensor);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
std::vector<index_t> output_shape(4, 0);
if (b2s_) {
CalculateBatchToSpaceOutputShape(batch_tensor,
DataFormat::NCHW,
output_shape.data());
space_tensor->Resize(output_shape);
} else {
CalculateSpaceToBatchOutputShape(space_tensor,
DataFormat::NCHW,
output_shape.data());
batch_tensor->Resize(output_shape);
}
Tensor::MappingGuard input_guard(space_tensor);
Tensor::MappingGuard output_guard(batch_tensor);
if (b2s_) {
const float *input_data = batch_tensor->data<float>();
float *output_data = space_tensor->mutable_data<float>();
index_t in_height = batch_tensor->dim(2);
index_t in_width = batch_tensor->dim(3);
index_t out_batches = space_tensor->dim(0);
index_t channels = space_tensor->dim(1);
index_t out_height = space_tensor->dim(2);
index_t out_width = space_tensor->dim(3);
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < out_batches; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < out_height; ++h) {
const index_t in_h = (h + paddings_[0]) / block_shape_[0];
const index_t tile_h = (h + paddings_[0]) % block_shape_[0];
for (index_t w = 0; w < out_width; ++w) {
const index_t in_w = (w + paddings_[2]) / block_shape_[1];
const index_t tile_w = (w + paddings_[2]) % block_shape_[1];
const index_t
in_b = (tile_h * block_shape_[1] + tile_w) * out_batches + b;
output_data[((b * channels + c) * out_height + h) * out_width
+ w] =
input_data[
((in_b * channels + c) * in_height + in_h) * in_width
+ in_w];
}
}
}
}
} else {
const float *input_data = space_tensor->data<float>();
float *output_data = batch_tensor->mutable_data<float>();
index_t in_batches = space_tensor->dim(0);
index_t in_height = space_tensor->dim(2);
index_t in_width = space_tensor->dim(3);
index_t out_batches = batch_tensor->dim(0);
index_t channels = batch_tensor->dim(1);
index_t out_height = batch_tensor->dim(2);
index_t out_width = batch_tensor->dim(3);
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < out_batches; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t in_b = b % in_batches;
const index_t tile_h = b / in_batches / block_shape_[1];
const index_t tile_w = b / in_batches % block_shape_[1];
for (index_t h = 0; h < out_height; ++h) {
const index_t in_h = h * block_shape_[0] + tile_h - paddings_[0];
for (index_t w = 0; w < out_width; ++w) {
const index_t in_w = w * block_shape_[1] + tile_w - paddings_[2];
if (in_h >= 0 && in_w >= 0 && in_h < in_height
&& in_w < in_width) {
output_data[((b * channels + c) * out_height + h) * out_width
+ w] =
input_data[
((in_b * channels + c) * in_height + in_h) * in_width
+ in_w];
} else {
output_data[((b * channels + c) * out_height + h) * out_width
+ w] = 0;
}
}
}
}
}
}
}
};
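The CPU space-to-batch loop above works backwards from each output (batch-space) coordinate to an input (space) coordinate. A minimal sketch of that mapping, pulled out for readability (names are illustrative, not MACE API):

#include <cstdint>
#include <vector>

struct SpaceCoord {
  int64_t in_b, in_h, in_w;  // negative or out-of-range h/w means a padded zero
};

// Mirrors the space-to-batch branch of the CPU functor: out_b selects the input
// batch (fastest-varying) and the block cell, which then offsets h and w.
SpaceCoord MapToSpace(int64_t out_b, int64_t out_h, int64_t out_w,
                      int64_t in_batches,
                      const std::vector<int> &block_shape,  // {block_h, block_w}
                      const std::vector<int> &paddings) {   // {top, bottom, left, right}
  const int64_t in_b = out_b % in_batches;
  const int64_t tile_h = out_b / in_batches / block_shape[1];
  const int64_t tile_w = out_b / in_batches % block_shape[1];
  const int64_t in_h = out_h * block_shape[0] + tile_h - paddings[0];
  const int64_t in_w = out_w * block_shape[1] + tile_w - paddings[2];
  return {in_b, in_h, in_w};
}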
......@@ -70,7 +243,6 @@ struct SpaceToBatchFunctor<DeviceType::GPU, T> : SpaceToBatchFunctorBase {
: SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
void operator()(Tensor *space_tensor,
const std::vector<index_t> &output_shape,
Tensor *batch_tensor,
StatsFuture *future);
......
......@@ -18,6 +18,11 @@ namespace mace {
namespace ops {
void Register_BatchToSpaceND(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchToSpaceND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
BatchToSpaceNDOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchToSpaceND")
.Device(DeviceType::GPU)
......@@ -29,8 +34,6 @@ void Register_BatchToSpaceND(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
BatchToSpaceNDOp<DeviceType::GPU, half>);
#else
MACE_UNUSED(op_registry);
#endif // MACE_ENABLE_OPENCL
}
......
......@@ -36,41 +36,11 @@ class BatchToSpaceNDOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
CalculateOutputShape(batch_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor),
functor_(space_tensor, const_cast<Tensor *>(batch_tensor),
future);
return true;
}
private:
inline void CalculateOutputShape(const Tensor *input_tensor,
index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape =
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
MACE_CHECK(crops.size() == 4, "Crops' shape should be 2D");
const index_t block_dims = block_shape.size();
index_t block_shape_product = 1;
for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
MACE_CHECK(block_shape[block_dim] > 1,
"block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t cropped_input_size =
input_tensor->dim(block_dim + 1) * block_shape_value -
crops[block_dim * 2] - crops[block_dim * 2 + 1];
MACE_CHECK(cropped_input_size >= 0, "cropped size must be non-negative");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = cropped_input_size;
}
output_shape[0] = input_tensor->dim(0) / block_shape_product;
output_shape[3] = input_tensor->dim(3);
}
private:
kernels::SpaceToBatchFunctor<D, T> functor_;
......
......@@ -35,8 +35,6 @@ namespace mace {
namespace ops {
namespace test {
enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4 };
class OpDefBuilder {
public:
OpDefBuilder(const char *type, const std::string &name) {
......
......@@ -18,6 +18,11 @@ namespace mace {
namespace ops {
void Register_SpaceToBatchND(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToBatchND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
SpaceToBatchNDOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToBatchND")
.Device(DeviceType::GPU)
......@@ -30,8 +35,6 @@ void Register_SpaceToBatchND(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
SpaceToBatchNDOp<DeviceType::GPU, half>);
#else
MACE_UNUSED(op_registry);
#endif // MACE_ENABLE_OPENCL
}
......
......@@ -37,43 +37,11 @@ class SpaceToBatchNDOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *space_tensor = this->Input(INPUT);
Tensor *batch_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
CalculateOutputShape(space_tensor, output_shape.data());
functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor,
functor_(const_cast<Tensor *>(space_tensor), batch_tensor,
future);
return true;
}
private:
inline void CalculateOutputShape(const Tensor *input_tensor,
index_t *output_shape) {
auto paddings =
OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
auto block_shape =
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
MACE_CHECK(paddings.size() == 4, "Paddings' shape should be 2D");
const index_t block_dims = block_shape.size();
index_t block_shape_product = 1;
for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
MACE_CHECK(block_shape[block_dim] > 1,
"block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t padded_input_size = input_tensor->dim(block_dim + 1) +
paddings[block_dim * 2] +
paddings[block_dim * 2 + 1];
MACE_CHECK(padded_input_size % block_shape_value == 0, "padded input ",
padded_input_size, " is not divisible by block_shape");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = padded_input_size / block_shape_value;
}
output_shape[0] = input_tensor->dim(0) * block_shape_product;
output_shape[3] = input_tensor->dim(3);
}
private:
kernels::SpaceToBatchFunctor<D, T> functor_;
......
......@@ -31,20 +31,40 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape,
OpsTestNet net;
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
if (D == GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntsArg("paddings", padding_data)
.AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef());
} else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("paddings", padding_data)
.AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef());
}
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
if (D == GPU) {
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
}
// Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
}
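Since the CPU functor operates on NCHW tensors, the test above transposes the NHWC input to NCHW before running the op and transposes the result back afterwards. A rough sketch of that index permutation (this is not MACE's TransformDataFormat implementation, only an illustration of the layout change):

#include <cstdint>
#include <vector>

// NHWC -> NCHW copy of a dense float tensor; the reverse direction swaps the
// src/dst indexing expressions.
std::vector<float> NhwcToNchw(const std::vector<float> &src,
                              int64_t n, int64_t h, int64_t w, int64_t c) {
  std::vector<float> dst(src.size());
  for (int64_t ni = 0; ni < n; ++ni)
    for (int64_t hi = 0; hi < h; ++hi)
      for (int64_t wi = 0; wi < w; ++wi)
        for (int64_t ci = 0; ci < c; ++ci)
          dst[((ni * c + ci) * h + hi) * w + wi] =
              src[((ni * h + hi) * w + wi) * c + ci];
  return dst;
}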
......@@ -59,20 +79,40 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape,
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
if (D == GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntsArg("crops", crops_data)
.AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef());
} else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("crops", crops_data)
.AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef());
}
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
if (D == GPU) {
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
}
// Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
}
......@@ -108,9 +148,13 @@ void TestBidirectionalTransform(const std::vector<index_t> &space_shape,
RunSpaceToBatch<DeviceType::GPU>(space_shape, space_data, block_data,
padding_data, batch_tensor.get());
RunSpaceToBatch<DeviceType::CPU>(space_shape, space_data, block_data,
padding_data, batch_tensor.get());
RunBatchToSpace<DeviceType::GPU>(batch_shape, batch_data, block_data,
padding_data, space_tensor.get());
RunBatchToSpace<DeviceType::CPU>(batch_shape, batch_data, block_data,
padding_data, space_tensor.get());
}
} // namespace
......@@ -156,7 +200,7 @@ TEST(SpaceToBatchTest, MultiBatchData) {
TestBidirectionalTransform<float>(
{2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
{2, 2}, {0, 0, 0, 0}, {8, 1, 2, 1},
{1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16});
{1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16});
}
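The expected data above changes because the result's batch ordering now follows the new decomposition: output batch b takes input batch b % N and block cell b / N (N = number of input batches), instead of the old b / block_size and b % block_size. A small standalone check of the new ordering for this test's {2, 2, 4, 1} input (illustrative only; it mirrors, not reuses, the kernel code):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Space tensor {2, 2, 4, 1} flattened in NHWC order; with channels == 1 the
  // flattening reduces to (batch, h, w). Values are 1..16 as in the test.
  const std::vector<int> space = {1, 2, 3, 4, 5, 6, 7, 8,
                                  9, 10, 11, 12, 13, 14, 15, 16};
  const int in_batches = 2, in_h = 2, in_w = 4;
  const int block_h = 2, block_w = 2;
  std::vector<int> batch;  // result shape {8, 1, 2, 1}
  for (int b = 0; b < in_batches * block_h * block_w; ++b) {
    const int in_b = b % in_batches;            // input batch varies fastest
    const int tile_h = (b / in_batches) / block_w;
    const int tile_w = (b / in_batches) % block_w;
    for (int h = 0; h < in_h / block_h; ++h)
      for (int w = 0; w < in_w / block_w; ++w)
        batch.push_back(space[(in_b * in_h + h * block_h + tile_h) * in_w +
                              w * block_w + tile_w]);
  }
  const std::vector<int> expected = {1, 3, 9, 11, 2, 4, 10, 12,
                                     5, 7, 13, 15, 6, 8, 14, 16};
  assert(batch == expected);  // matches the updated expected data in the test
  return 0;
}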
TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
......@@ -165,8 +209,8 @@ TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
{2, 2}, {0, 0, 0, 0}, {8, 1, 2, 2},
{1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16,
17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32});
{1, 2, 5, 6, 17, 18, 21, 22, 3, 4, 7, 8, 19, 20, 23, 24,
9, 10, 13, 14, 25, 26, 29, 30, 11, 12, 15, 16, 27, 28, 31, 32});
}
} // namespace test
......