Commit 41a905c9 authored by Liangliang He

Merge branch 'spacetobatch' into 'master'

Finish space-to-batch and add side-GCN validation.

See merge request !186
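For reference, the op follows TensorFlow's SpaceToBatchND semantics: block-sized tiles of the spatial dimensions are folded into the batch dimension, after padding. Below is a minimal standalone sketch of the NHWC shape arithmetic, matching the SmallData unit test in this MR; the helper name is illustrative, not part of the MACE API.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// space [N, H, W, C] -> batch [N*bh*bw, (H+pt+pb)/bh, (W+pl+pr)/bw, C]
std::vector<int64_t> SpaceToBatchShape(const std::vector<int64_t> &space,
                                       const std::vector<int> &block,  // {bh, bw}
                                       const std::vector<int> &pad) {  // {pt, pb, pl, pr}
  const int64_t padded_h = space[1] + pad[0] + pad[1];
  const int64_t padded_w = space[2] + pad[2] + pad[3];
  assert(padded_h % block[0] == 0 && padded_w % block[1] == 0);
  return {space[0] * block[0] * block[1],
          padded_h / block[0],
          padded_w / block[1],
          space[3]};
}

int main() {
  // Matches the SmallData test below: {1, 2, 2, 1} with 2x2 blocks
  // and no padding becomes {4, 1, 1, 1}.
  for (int64_t d : SpaceToBatchShape({1, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}))
    std::cout << d << " ";  // prints: 4 1 1 1
  std::cout << "\n";
}
```

BatchToSpaceND inverts this: it multiplies the spatial dims by the block sizes, subtracts the crops, and divides the batch dimension by bh*bw.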
@@ -493,12 +493,11 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type):
   ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT);
   net_ = std::move(CreateNet(*net_def, ws_.get(), device_type));
 }
-MaceEngine::~MaceEngine(){}
+MaceEngine::~MaceEngine() = default;
 bool MaceEngine::Run(const float *input,
                     const std::vector<index_t> &input_shape,
                     float *output) {
   MACE_CHECK(output != nullptr, "output ptr cannot be NULL");
   Tensor *input_tensor =
       ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT);
   input_tensor->Resize(input_shape);
@@ -518,6 +517,7 @@ bool MaceEngine::Run(const float *input,
   auto shape = output_tensor->shape();
   int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
+  // TODO: check for overflow exception.
   std::memcpy(output, output_tensor->data<float>(),
               output_size * sizeof(float));
   return true;
...
@@ -310,6 +310,9 @@ class MaceEngine {
   bool Run(const float *input,
            const std::vector<int64_t> &input_shape,
            float *output);
+  MaceEngine(const MaceEngine&) = delete;
+  MaceEngine &operator=(const MaceEngine&) = delete;

  private:
   DeviceType device_type_;
   std::unique_ptr<Workspace> ws_;
...
@@ -173,9 +173,13 @@ int main(int argc, char **argv) {
   // load input
   ifstream in_file(input_file, ios::in | ios::binary);
-  in_file.read(reinterpret_cast<char *>(input_data.get()),
-               input_size * sizeof(float));
-  in_file.close();
+  if (in_file.is_open()) {
+    in_file.read(reinterpret_cast<char *>(input_data.get()),
+                 input_size * sizeof(float));
+    in_file.close();
+  } else {
+    LOG(ERROR) << "Open input file failed";
+  }

   // Init model
   VLOG(0) << "Run init";
...
@@ -64,7 +64,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
                                    image_shape[1],
                                    1};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel);
-  const std::vector<uint32_t> lws = {kwg_size, 1, 1};
+  const std::vector<uint32_t> lws = {16, 64, 1};
   cl::Event event;
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       b2f_kernel, cl::NullRange,
...
#include <common.h>
// Supported data type: all
__kernel void space_to_batch(__global DATA_TYPE *space_data_ptr,
                             __global const int *block_shape_ptr,
                             __global const int *paddings_ptr,
                             __private const int space_batch,
                             __private const int space_channel,
                             __private const int space_height,
                             __private const int space_width,
                             __private const int batch_height,
                             __private const int batch_width,
                             __private const int b2s,
                             __global DATA_TYPE* batch_data_ptr) {
  int batch_idx = get_global_id(0);
  int batch_channel_idx = get_global_id(1);
  int batch_pixel_idx = get_global_id(2);

  const int block_height = block_shape_ptr[0];
  const int block_width = block_shape_ptr[1];
  const int padding_height_start = paddings_ptr[0];
  const int padding_width_start = paddings_ptr[2];

  const int batch_pixel_height_idx = batch_pixel_idx / batch_width;
  const int batch_pixel_width_idx = batch_pixel_idx % batch_width;

  const int block_size = block_height * block_width;
  const int space_idx = batch_idx / block_size;
  const int remaining_batch_idx = batch_idx % block_size;
  int space_pixel_height_idx = (remaining_batch_idx / block_width) +
                               batch_pixel_height_idx * block_height;
  int space_pixel_width_idx = (remaining_batch_idx % block_width) +
                              batch_pixel_width_idx * block_width;

  const int batch_data_offset = batch_idx * (space_channel * batch_height * batch_width) +
                                (batch_channel_idx * batch_height * batch_width) +
                                batch_pixel_height_idx * batch_width +
                                batch_pixel_width_idx;

  space_pixel_height_idx -= padding_height_start;
  space_pixel_width_idx -= padding_width_start;
  const int space_data_offset = space_idx * (space_channel * space_height * space_width) +
                                (batch_channel_idx * space_height * space_width) +
                                space_pixel_height_idx * space_width +
                                space_pixel_width_idx;
  if (space_pixel_height_idx < 0 || space_pixel_height_idx >= space_height ||
      space_pixel_width_idx < 0 || space_pixel_width_idx >= space_width) {
    if (!b2s) {
      *(batch_data_ptr + batch_data_offset) = 0;
    }
  } else {
    if (b2s) {
      *(space_data_ptr + space_data_offset) = *(batch_data_ptr + batch_data_offset);
    } else {
      *(batch_data_ptr + batch_data_offset) = *(space_data_ptr + space_data_offset);
    }
  }
}
#include <common.h>

__kernel void space_to_batch(__read_only image2d_t space_data,
                             __write_only image2d_t batch_data,
                             __private const int block_height,
                             __private const int block_width,
                             __private const int padding_height,
                             __private const int padding_width,
                             __private const int space_height,
                             __private const int space_width,
                             __private const int batch_height,
                             __private const int batch_width) {
  const int chan_idx = get_global_id(0);
  const int batch_w_idx = get_global_id(1);
  const int batch_hb_idx = get_global_id(2);

  const int batch_b_idx = batch_hb_idx / batch_height;
  const int batch_h_idx = batch_hb_idx % batch_height;

  const int block_size = mul24(block_height, block_width);
  const int space_b_idx = batch_b_idx / block_size;
  const int remaining_batch_idx = batch_b_idx % block_size;
  const int space_h_idx = (remaining_batch_idx / block_width) +
                          mul24(batch_h_idx, block_height) - padding_height;
  const int space_w_idx = (remaining_batch_idx % block_width) +
                          mul24(batch_w_idx, block_width) - padding_width;

  const int space_coord_x = select(mul24(chan_idx, space_width) + space_w_idx,
                                   -1,
                                   space_w_idx < 0 || space_w_idx >= space_width);
  const int space_coord_y = select(mul24(space_b_idx, space_height) + space_h_idx,
                                   -1,
                                   space_h_idx < 0 || space_h_idx >= space_height);
  int2 space_coord = (int2)(space_coord_x,
                            space_coord_y);
  DATA_TYPE4 value = READ_IMAGET(space_data, SAMPLER, space_coord);

  int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx);
  WRITE_IMAGET(batch_data, batch_coord, value);
}

__kernel void batch_to_space(__read_only image2d_t batch_data,
                             __write_only image2d_t space_data,
                             __private const int block_height,
                             __private const int block_width,
                             __private const int padding_height,
                             __private const int padding_width,
                             __private const int space_height,
                             __private const int space_width,
                             __private const int batch_height,
                             __private const int batch_width) {
  const int chan_idx = get_global_id(0);
  const int batch_w_idx = get_global_id(1);
  const int batch_hb_idx = get_global_id(2);

  const int batch_b_idx = batch_hb_idx / batch_height;
  const int batch_h_idx = batch_hb_idx % batch_height;

  const int block_size = mul24(block_height, block_width);
  const int space_b_idx = batch_b_idx / block_size;
  const int remaining_batch_idx = batch_b_idx % block_size;
  const int space_h_idx = (remaining_batch_idx / block_width) +
                          mul24(batch_h_idx, block_height) - padding_height;
  const int space_w_idx = (remaining_batch_idx % block_width) +
                          mul24(batch_w_idx, block_width) - padding_width;

  if (0 <= space_w_idx && space_w_idx < space_width &&
      0 <= space_h_idx && space_h_idx < space_height) {
    int2 batch_coord = (int2)(mul24(chan_idx, batch_width) + batch_w_idx, batch_hb_idx);
    DATA_TYPE4 value = READ_IMAGET(batch_data, SAMPLER, batch_coord);
    int2 space_coord = (int2)(mul24(chan_idx, space_width) + space_w_idx,
                              space_b_idx * space_height + space_h_idx);
    WRITE_IMAGET(space_data, space_coord, value);
  }
}
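A note on the coordinate math above: the rewritten kernels read and write OpenCL 2D images rather than flat buffers, and, judging from the mul24 expressions and the DATA_TYPE4 loads, each texel packs four consecutive channels, with x spanning (channel block, width) and y spanning (batch, height). A small illustrative mapping, assuming that layout (this helper is not MACE code):

```cpp
#include <iostream>

struct ImageCoord { int x; int y; int lane; };  // lane = channel within the texel

// Map an NHWC index to the assumed image coordinate.
ImageCoord NhwcToImage(int n, int h, int w, int c, int height, int width) {
  const int chan_blk = c / 4;      // which DATA_TYPE4 texel column block
  return {chan_blk * width + w,    // matches mul24(chan_idx, space_width) + space_w_idx
          n * height + h,          // matches mul24(space_b_idx, space_height) + space_h_idx
          c % 4};
}

int main() {
  ImageCoord p = NhwcToImage(/*n=*/1, /*h=*/0, /*w=*/2, /*c=*/5,
                             /*height=*/3, /*width=*/4);
  std::cout << p.x << "," << p.y << " lane " << p.lane << "\n";  // 6,3 lane 1
}
```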
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
#define MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/space_to_batch.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
template <>
void SpaceToBatchFunctor<DeviceType::OPENCL, float>::operator()(Tensor *space_tensor,
const Tensor *block_shape_tensor,
const Tensor *paddings_tensor,
Tensor *batch_tensor,
StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(space_tensor->dtype()));
auto s2b_kernel = runtime->BuildKernel("space_to_batch", "space_to_batch", built_options);
uint32_t idx = 0;
s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(space_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(block_shape_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(paddings_tensor->buffer())));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(0)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(3)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(3)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(b2s_));
s2b_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(batch_tensor->buffer())));
const uint32_t gws[3] = {static_cast<uint32_t>(batch_tensor->dim(0)),
static_cast<uint32_t>(batch_tensor->dim(1)),
static_cast<uint32_t>(batch_tensor->dim(2) * batch_tensor->dim(3))};
const uint32_t lws[3] = {static_cast<uint32_t>(1),
static_cast<uint32_t>(8),
static_cast<uint32_t>(128)};
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
s2b_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
#define MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/space_to_batch.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor,
const std::vector<index_t> &output_shape,
Tensor *batch_tensor,
StatsFuture *future) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
const char *kernel_name = nullptr;
if (b2s_) {
space_tensor->ResizeImage(output_shape, output_image_shape);
kernel_name = "batch_to_space";
} else {
batch_tensor->ResizeImage(output_shape, output_image_shape);
kernel_name = "space_to_batch";
}
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
auto s2b_kernel = runtime->BuildKernel("space_to_batch", kernel_name, built_options);
uint32_t idx = 0;
if (b2s_) {
s2b_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(batch_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(space_tensor->buffer())));
} else {
s2b_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(space_tensor->buffer())));
s2b_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(batch_tensor->buffer())));
}
s2b_kernel.setArg(idx++, block_shape_[0]);
s2b_kernel.setArg(idx++, block_shape_[1]);
s2b_kernel.setArg(idx++, paddings_[0]);
s2b_kernel.setArg(idx++, paddings_[2]);
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(1)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(space_tensor->dim(2)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(1)));
s2b_kernel.setArg(idx++, static_cast<int32_t>(batch_tensor->dim(2)));
const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3));
const uint32_t gws[3] = {chan_blk,
static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
const std::vector<uint32_t> lws = {8, 16, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(s2b_kernel);
auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(3, 0);
local_ws[0] = std::min<uint32_t>(chan_blk, kwg_size);
local_ws[1] = std::min<uint32_t>(32, kwg_size / local_ws[0]);
local_ws[2] = std::min<uint32_t>(32, kwg_size / (local_ws[0] * local_ws[1]));
return {{local_ws[0], local_ws[1], local_ws[2]},
{4, 32, 8},
{4, 64, 4},
{4, 128, 2},
{8, 16, 8},
{8, 32, 4},
{8, 64, 2},
{16, 8, 8},
{16, 16, 4},
{16, 32, 2},
{32, 8, 4},
{32, 16, 2},
{64, 4, 4}};
};
cl::Event event;
auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
s2b_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << kernel_name << "_"
<< batch_tensor->dim(0) << "_"
<< batch_tensor->dim(1) << "_"
<< batch_tensor->dim(2) << "_"
<< batch_tensor->dim(3);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
}
template struct SpaceToBatchFunctor<DeviceType::OPENCL, float>;
template struct SpaceToBatchFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_OPENCL_SPACE_TO_BATCH_H_
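The TuneOrRun call above presumably times each candidate local work-group size produced by params_generator using the OpenCLProfilingTimer and caches the winner under the kernel/shape key built in ss. A rough standalone sketch of that pick-the-fastest pattern (illustrative only, not the actual Tuner implementation):

```cpp
#include <chrono>
#include <functional>
#include <limits>
#include <vector>

// Run each candidate once, keep the fastest. The real tuner would also
// cache the chosen parameters keyed by a kernel/shape string.
template <typename Param>
Param PickFastest(const std::vector<Param> &candidates,
                  const std::function<void(const Param &)> &run) {
  Param best = candidates.front();
  double best_us = std::numeric_limits<double>::max();
  for (const Param &p : candidates) {
    auto start = std::chrono::steady_clock::now();
    run(p);  // e.g. enqueue the kernel with this local work-group size and wait
    double us = std::chrono::duration<double, std::micro>(
        std::chrono::steady_clock::now() - start).count();
    if (us < best_us) { best_us = us; best = p; }
  }
  return best;
}

int main() {
  std::vector<int> candidates = {1, 2, 4, 8};
  int best = PickFastest<int>(candidates, [](const int &) { /* no-op stand-in */ });
  (void)best;
}
```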
@@ -12,27 +12,46 @@
 namespace mace {
 namespace kernels {

-template <DeviceType D, typename T>
-struct SpaceToBatchFunctor {
-  SpaceToBatchFunctor(const bool b2s = false): b2s_(b2s){}
+struct SpaceToBatchFunctorBase {
+  SpaceToBatchFunctorBase(const std::vector<int> &paddings,
+                          const std::vector<int> &block_shape,
+                          bool b2s)
+      : paddings_(paddings.begin(), paddings.end()),
+        block_shape_(block_shape.begin(), block_shape.end()),
+        b2s_(b2s) {}
+
+  std::vector<int> paddings_;
+  std::vector<int> block_shape_;
+  bool b2s_;
+};

-  void operator()(Tensor *input_tensor,
-                  const Tensor *block_shape_tensor,
-                  const Tensor *paddings_tensor,
-                  Tensor *output_tensor,
+template <DeviceType D, typename T>
+struct SpaceToBatchFunctor : SpaceToBatchFunctorBase {
+  SpaceToBatchFunctor(const std::vector<int> &paddings,
+                      const std::vector<int> &block_shape,
+                      bool b2s)
+      : SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
+
+  void operator()(Tensor *space_tensor,
+                  const std::vector<index_t> &output_shape,
+                  Tensor *batch_tensor,
                   StatsFuture *future) {
     MACE_NOT_IMPLEMENTED;
   }
-
-  bool b2s_;
 };

-template <>
-void SpaceToBatchFunctor<DeviceType::OPENCL, float>::operator()(Tensor *input_tensor,
-                                                                const Tensor *block_shape_tensor,
-                                                                const Tensor *paddings_tensor,
-                                                                Tensor *output,
-                                                                StatsFuture *future);
+template <typename T>
+struct SpaceToBatchFunctor<DeviceType::OPENCL, T> : SpaceToBatchFunctorBase {
+  SpaceToBatchFunctor(const std::vector<int> &paddings,
+                      const std::vector<int> &block_shape,
+                      bool b2s)
+      : SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
+
+  void operator()(Tensor *space_tensor,
+                  const std::vector<index_t> &output_shape,
+                  Tensor *batch_tensor,
+                  StatsFuture *future);
+};

 } // namespace kernels
 } // namespace mace
...
@@ -10,5 +10,9 @@
 REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchToSpaceND")
                             .TypeConstraint<float>("T")
                             .Build(),
                         BatchToSpaceNDOp<DeviceType::OPENCL, float>);
+REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchToSpaceND")
+                            .TypeConstraint<half>("T")
+                            .Build(),
+                        BatchToSpaceNDOp<DeviceType::OPENCL, half>);
 } // namespace mace
@@ -12,63 +12,58 @@
 namespace mace {

-static void BatchToSpaceHelper(const Tensor *input_tensor,
-                               const Tensor *block_shape_tensor,
-                               const Tensor *cropped_tensor,
-                               Tensor *output) {
-  MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
-  MACE_CHECK(block_shape_tensor->dim_size() == 1, "Block's shape should be 1D");
-  MACE_CHECK(cropped_tensor->dim_size() == 2, "Paddings' shape should be 2D");
-
-  const index_t block_dims = block_shape_tensor->dim(0);
-  MACE_CHECK(block_dims == cropped_tensor->dim(0) && 2 == cropped_tensor->dim(1));
-  // TODO change tensor to attribute if needed based on the benchmark
-  Tensor::MappingGuard block_shape_tensor_mapper(block_shape_tensor);
-  Tensor::MappingGuard cropped_tensor_mapper(cropped_tensor);
-  const int *block_shape_ptr = block_shape_tensor->data<int>();
-  const int *cropped_ptr = cropped_tensor->data<int>();
-  std::vector<index_t> output_shape(4, 0);
-  index_t block_shape_product = 1;
-  for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
-    MACE_CHECK(block_shape_ptr[block_dim] > 1, "block_shape's value should be great to 1");
-    const index_t block_shape_value = block_shape_ptr[block_dim];
-    const index_t cropped_input_size = input_tensor->dim(block_dim + 2) * block_shape_value
-        - *cropped_ptr
-        - *(cropped_ptr+1);
-    MACE_CHECK(cropped_input_size >= 0,
-               "cropped size must be non-negative");
-    block_shape_product *= block_shape_value;
-    output_shape[block_dim+2] = cropped_input_size;
-    cropped_ptr += 2;
-  }
-  output_shape[0] = input_tensor->dim(0) / block_shape_product;
-  output_shape[1] = input_tensor->dim(1);
-  output->Resize(output_shape);
-}
-
-template <DeviceType D, typename T>
-class BatchToSpaceNDOp: public Operator<D, T> {
+template<DeviceType D, typename T>
+class BatchToSpaceNDOp : public Operator<D, T> {
  public:
   BatchToSpaceNDOp(const OperatorDef &op_def, Workspace *ws)
-      : Operator<D, T>(op_def, ws), functor_(true) {}
+      : Operator<D, T>(op_def, ws),
+        functor_(
+            OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}),
+            OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
+            true) {}

   bool Run(StatsFuture *future) override {
-    const Tensor *input_tensor = this->Input(INPUT);
-    const Tensor *block_shape_tensor = this->Input(BLOCK_SHAPE);
-    const Tensor *cropped_tensor = this->Input(CROPS);
-    Tensor *output = this->Output(OUTPUT);
+    const Tensor *batch_tensor = this->Input(INPUT);
+    Tensor *space_tensor = this->Output(OUTPUT);

-    BatchToSpaceHelper(input_tensor, block_shape_tensor, cropped_tensor, output);
-    functor_(output, block_shape_tensor, cropped_tensor, const_cast<Tensor*>(input_tensor), future);
+    std::vector<index_t> output_shape(4, 0);
+    CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
+    functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
     return true;
   }

+ private:
+  inline void CalculateOutputShape(const Tensor *input_tensor,
+                                   Tensor *output,
+                                   index_t *output_shape) {
+    auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
+    auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
+    MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
+    MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
+    MACE_CHECK(crops.size() == 4, "Crops' shape should be 2D");
+
+    const index_t block_dims = block_shape.size();
+    index_t block_shape_product = 1;
+    for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
+      MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be greater than 1");
+      const index_t block_shape_value = block_shape[block_dim];
+      const index_t cropped_input_size = input_tensor->dim(block_dim + 1) * block_shape_value
+          - crops[block_dim * 2]
+          - crops[block_dim * 2 + 1];
+      MACE_CHECK(cropped_input_size >= 0,
+                 "cropped size must be non-negative");
+      block_shape_product *= block_shape_value;
+      output_shape[block_dim + 1] = cropped_input_size;
+    }
+    output_shape[0] = input_tensor->dim(0) / block_shape_product;
+    output_shape[3] = input_tensor->dim(3);
+  }
+
  private:
   kernels::SpaceToBatchFunctor<D, T> functor_;

 protected:
-  OP_INPUT_TAGS(INPUT, BLOCK_SHAPE, CROPS);
+  OP_INPUT_TAGS(INPUT);
   OP_OUTPUT_TAGS(OUTPUT);
 };
...
@@ -9,23 +9,19 @@
 namespace mace {
 template <DeviceType D, typename T>
 static void BMBatchToSpace(
-    int iters, int batch, int channels, int height, int width) {
+    int iters, int batch, int channels, int height, int width, int arg) {
   mace::testing::StopTiming();

   OpsTestNet net;
+  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
   OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
-      .Input("Input")
-      .Input("BlockShape")
-      .Input("Crops")
-      .Output("Output")
+      .Input("InputImage")
+      .Output("OutputImage")
+      .AddIntsArg("crops", {0, 0, 0, 0})
+      .AddIntsArg("block_shape", {arg, arg})
       .Finalize(net.NewOperatorDef());
-
-  // Add input data
-  net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
-  net.AddInputFromArray<D, int>(
-      "BlockShape", {2}, {2, 2});
-  net.AddInputFromArray<D, int>("Crops", {2, 2}, {0,1,0,1});

   // Warm-up
   for (int i = 0; i < 5; ++i) {
     net.RunOp(D);
@@ -39,18 +35,20 @@ static void BMBatchToSpace(
   net.Sync();
 }

-#define BM_BATCH_TO_SPACE_MACRO(N, C, H, W, TYPE, DEVICE) \
-  static void BM_BATCH_TO_SPACE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+#define BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, TYPE, DEVICE) \
+  static void BM_BATCH_TO_SPACE_##N##_##H##_##W##_##C##_##ARG##_##TYPE##_##DEVICE( \
       int iters) { \
     const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
     mace::testing::ItemsProcessed(tot); \
     mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
-    BMBatchToSpace<DEVICE, TYPE>(iters, N, C, H, W); \
+    BMBatchToSpace<DEVICE, TYPE>(iters, N, C, H, W, ARG); \
   } \
-  BENCHMARK(BM_BATCH_TO_SPACE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+  BENCHMARK(BM_BATCH_TO_SPACE_##N##_##H##_##W##_##C##_##ARG##_##TYPE##_##DEVICE)

-#define BM_BATCH_TO_SPACE(N, C, H, W, TYPE) \
-  BM_BATCH_TO_SPACE_MACRO(N, C, H, W, TYPE, OPENCL);
+#define BM_BATCH_TO_SPACE(N, H, W, C, ARG, TYPE) \
+  BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, TYPE, OPENCL);

-BM_BATCH_TO_SPACE(128, 128, 8, 8, float);
+BM_BATCH_TO_SPACE(128, 8, 8, 128, 2, float);
+BM_BATCH_TO_SPACE(4, 128, 128, 32, 2, float);
+BM_BATCH_TO_SPACE(16, 64, 64, 32, 4, float);
 } // namespace mace
\ No newline at end of file
@@ -322,9 +322,18 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
     Tensor::MappingGuard y_mapper(&y);
     auto a = x.data<EXP_TYPE>();
     auto b = y.data<RES_TYPE>();
-    for (int i = 0; i < x.size(); ++i) {
-      EXPECT_NEAR(a[i], b[i], abs_err) << "a = " << a << " b = " << b
-                                       << " index = " << i;
-    }
+    for (int n = 0; n < x.dim(0); ++n) {
+      for (int h = 0; h < x.dim(1); ++h) {
+        for (int w = 0; w < x.dim(2); ++w) {
+          for (int c = 0; c < x.dim(3); ++c) {
+            EXPECT_NEAR(*a, *b, abs_err) << "with index = ["
+                                         << n << ", " << h << ", "
+                                         << w << ", " << c << "]";
+            a++;
+            b++;
+          }
+        }
+      }
+    }
   }
...
@@ -10,5 +10,9 @@
 REGISTER_OPENCL_OPERATOR(OpKeyBuilder("SpaceToBatchND")
                             .TypeConstraint<float>("T")
                             .Build(),
                         SpaceToBatchNDOp<DeviceType::OPENCL, float>);
+REGISTER_OPENCL_OPERATOR(OpKeyBuilder("SpaceToBatchND")
+                            .TypeConstraint<half>("T")
+                            .Build(),
+                        SpaceToBatchNDOp<DeviceType::OPENCL, half>);
 } // namespace mace
@@ -12,62 +12,59 @@
 namespace mace {

-static void SpaceToBatchHelper(const Tensor *input_tensor,
-                               const Tensor *block_shape_tensor,
-                               const Tensor *paddings_tensor,
-                               Tensor *output) {
-  MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
-  MACE_CHECK(block_shape_tensor->dim_size() == 1, "Block's shape should be 1D");
-  MACE_CHECK(paddings_tensor->dim_size() == 2, "Paddings' shape should be 2D");
-
-  const index_t block_dims = block_shape_tensor->dim(0);
-  MACE_CHECK(block_dims == paddings_tensor->dim(0) && 2 == paddings_tensor->dim(1));
-  Tensor::MappingGuard block_shape_tensor_mapper(block_shape_tensor);
-  Tensor::MappingGuard padding_tensor_mapper(paddings_tensor);
-  const int *block_shape_ptr = block_shape_tensor->data<int>();
-  const int *paddings_ptr = paddings_tensor->data<int>();
-  std::vector<index_t> output_shape(4, 0);
-  index_t block_shape_product = 1;
-  for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
-    MACE_CHECK(block_shape_ptr[block_dim] > 1, "block_shape's value should be great to 1");
-    const index_t block_shape_value = block_shape_ptr[block_dim];
-    const index_t padded_input_size = input_tensor->dim(block_dim + 2)
-        + *paddings_ptr
-        + *(paddings_ptr+1);
-    MACE_CHECK(padded_input_size % block_shape_value == 0,
-               "padded input is not divisible by block_shape");
-    block_shape_product *= block_shape_value;
-    output_shape[block_dim+2] = padded_input_size / block_shape_value;
-    paddings_ptr += 2;
-  }
-  output_shape[0] = input_tensor->dim(0) * block_shape_product;
-  output_shape[1] = input_tensor->dim(1);
-  output->Resize(output_shape);
-}
-
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 class SpaceToBatchNDOp : public Operator<D, T> {
  public:
   SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
-      : Operator<D, T>(op_def, ws) {}
+      : Operator<D, T>(op_def, ws),
+        functor_(
+            OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}),
+            OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
+            false) {}

   bool Run(StatsFuture *future) override {
-    const Tensor *input_tensor = this->Input(INPUT);
-    const Tensor *block_shape_tensor = this->Input(BLOCK_SHAPE);
-    const Tensor *paddings_tensor = this->Input(PADDINGS);
-    Tensor *output = this->Output(OUTPUT);
+    const Tensor *space_tensor = this->Input(INPUT);
+    Tensor *batch_tensor = this->Output(OUTPUT);

-    SpaceToBatchHelper(input_tensor, block_shape_tensor, paddings_tensor, output);
-    functor_(const_cast<Tensor*>(input_tensor), block_shape_tensor, paddings_tensor, output, future);
+    std::vector<index_t> output_shape(4, 0);
+    CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
+    functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
     return true;
   }

+ private:
+  inline void CalculateOutputShape(const Tensor *input_tensor,
+                                   Tensor *output,
+                                   index_t *output_shape) {
+    auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
+    auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
+    MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
+    MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
+    MACE_CHECK(paddings.size() == 4, "Paddings' shape should be 2D");
+
+    const index_t block_dims = block_shape.size();
+    index_t block_shape_product = 1;
+    for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
+      MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be greater than 1");
+      const index_t block_shape_value = block_shape[block_dim];
+      const index_t padded_input_size = input_tensor->dim(block_dim + 1)
+          + paddings[block_dim * 2]
+          + paddings[block_dim * 2 + 1];
+      MACE_CHECK(padded_input_size % block_shape_value == 0,
+                 "padded input ", padded_input_size, " is not divisible by block_shape");
+      block_shape_product *= block_shape_value;
+      output_shape[block_dim + 1] = padded_input_size / block_shape_value;
+    }
+    output_shape[0] = input_tensor->dim(0) * block_shape_product;
+    output_shape[3] = input_tensor->dim(3);
+  }
+
  private:
   kernels::SpaceToBatchFunctor<D, T> functor_;

 protected:
-  OP_INPUT_TAGS(INPUT, BLOCK_SHAPE, PADDINGS);
+  OP_INPUT_TAGS(INPUT);
   OP_OUTPUT_TAGS(OUTPUT);
 };
...
@@ -9,23 +9,20 @@
 namespace mace {
 template <DeviceType D, typename T>
 static void BMSpaceToBatch(
-    int iters, int batch, int channels, int height, int width) {
+    int iters, int batch, int height, int width, int channels, int shape) {
   mace::testing::StopTiming();

   OpsTestNet net;
+  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
   OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
-      .Input("Input")
-      .Input("BlockShape")
-      .Input("Padding")
-      .Output("Output")
+      .Input("InputImage")
+      .Output("OutputImage")
+      .AddIntsArg("paddings", {shape, shape, shape, shape})
+      .AddIntsArg("block_shape", {shape, shape})
       .Finalize(net.NewOperatorDef());
-
-  // Add input data
-  net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
-  net.AddInputFromArray<D, int>(
-      "BlockShape", {2}, {2, 2});
-  net.AddInputFromArray<D, int>("Padding", {2, 2}, {2,3,2,3});

   // Warm-up
   for (int i = 0; i < 5; ++i) {
     net.RunOp(D);
@@ -39,18 +36,20 @@ static void BMSpaceToBatch(
   net.Sync();
 }

-#define BM_SPACE_TO_BATCH_MACRO(N, C, H, W, TYPE, DEVICE) \
-  static void BM_SPACE_TO_BATCH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+#define BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, TYPE, DEVICE) \
+  static void BM_SPACE_TO_BATCH_##N##_##H##_##W##_##C##_##SHAPE##_##TYPE##_##DEVICE( \
       int iters) { \
     const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
     mace::testing::ItemsProcessed(tot); \
     mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
-    BMSpaceToBatch<DEVICE, TYPE>(iters, N, C, H, W); \
+    BMSpaceToBatch<DEVICE, TYPE>(iters, N, H, W, C, SHAPE); \
   } \
-  BENCHMARK(BM_SPACE_TO_BATCH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+  BENCHMARK(BM_SPACE_TO_BATCH_##N##_##H##_##W##_##C##_##SHAPE##_##TYPE##_##DEVICE)

-#define BM_SPACE_TO_BATCH(N, C, H, W, TYPE) \
-  BM_SPACE_TO_BATCH_MACRO(N, C, H, W, TYPE, OPENCL);
+#define BM_SPACE_TO_BATCH(N, H, W, C, SHAPE, TYPE) \
+  BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, TYPE, OPENCL);

-BM_SPACE_TO_BATCH(128, 128, 15, 15, float);
+BM_SPACE_TO_BATCH(128, 16, 16, 128, 2, float);
+BM_SPACE_TO_BATCH(1, 256, 256, 32, 2, float);
+BM_SPACE_TO_BATCH(1, 256, 256, 32, 4, float);
 } // namespace mace
\ No newline at end of file
@@ -4,79 +4,70 @@
 #include "gtest/gtest.h"
 #include "mace/ops/ops_test_util.h"
+#include <fstream>

 using namespace mace;

-template <DeviceType D>
+template<DeviceType D>
 void RunSpaceToBatch(const std::vector<index_t> &input_shape,
                      const std::vector<float> &input_data,
-                     const std::vector<index_t> &block_shape_shape,
                      const std::vector<int> &block_shape_data,
-                     const std::vector<index_t> &padding_shape,
                      const std::vector<int> &padding_data,
                      const Tensor *expected) {
   OpsTestNet net;
-  OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
-      .Input("Input")
-      .Input("BlockShape")
-      .Input("Padding")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-
-  // Add input data
   net.AddInputFromArray<D, float>(
       "Input", input_shape, input_data);
-  net.AddInputFromArray<D, int>(
-      "BlockShape", block_shape_shape, block_shape_data);
-  net.AddInputFromArray<D, int>("Padding", padding_shape, padding_data);
+
+  BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
+  OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
+      .Input("InputImage")
+      .Output("OutputImage")
+      .AddIntsArg("paddings", padding_data)
+      .AddIntsArg("block_shape", block_shape_data)
+      .Finalize(net.NewOperatorDef());

   // Run
   net.RunOp(D);
+  ImageToBuffer<D, float>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);

   // Check
   ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
 }

-template <DeviceType D>
+template<DeviceType D>
 void RunBatchToSpace(const std::vector<index_t> &input_shape,
                      const std::vector<float> &input_data,
-                     const std::vector<index_t> &block_shape_shape,
                      const std::vector<int> &block_shape_data,
-                     const std::vector<index_t> &crops_shape,
                      const std::vector<int> &crops_data,
                      const Tensor *expected) {
   OpsTestNet net;
-  OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
-      .Input("Input")
-      .Input("BlockShape")
-      .Input("Crops")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-
   // Add input data
   net.AddInputFromArray<D, float>(
       "Input", input_shape, input_data);
-  net.AddInputFromArray<D, int>(
-      "BlockShape", block_shape_shape, block_shape_data);
-  net.AddInputFromArray<D, int>("Crops", crops_shape, crops_data);
+
+  BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
+  OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
+      .Input("InputImage")
+      .Output("OutputImage")
+      .AddIntsArg("crops", crops_data)
+      .AddIntsArg("block_shape", block_shape_data)
+      .Finalize(net.NewOperatorDef());

   // Run
   net.RunOp(D);
+  ImageToBuffer<D, float>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);

   // Check
   ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
 }

-template <typename T>
-void TestBidirectionTransform(const std::vector<index_t> &space_shape,
-                              const std::vector<float> &space_data,
-                              const std::vector<index_t> &block_shape,
-                              const std::vector<int> &block_data,
-                              const std::vector<index_t> &padding_shape,
-                              const std::vector<int> &padding_data,
-                              const std::vector<index_t> &batch_shape,
-                              const std::vector<float> &batch_data) {
+template<typename T>
+void TestBidirectionalTransform(const std::vector<index_t> &space_shape,
+                                const std::vector<float> &space_data,
+                                const std::vector<int> &block_data,
+                                const std::vector<int> &padding_data,
+                                const std::vector<index_t> &batch_shape,
+                                const std::vector<float> &batch_data) {
   auto space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
                                                     DataTypeToEnum<T>::v()));
@@ -101,99 +92,157 @@ void TestBidirectionTransform(const std::vector<index_t> &space_shape,
   }

   RunSpaceToBatch<DeviceType::OPENCL>(space_shape, space_data,
-                                      block_shape, block_data,
-                                      padding_shape, padding_data,
+                                      block_data,
+                                      padding_data,
                                       batch_tensor.get());

   RunBatchToSpace<DeviceType::OPENCL>(batch_shape, batch_data,
-                                      block_shape, block_data,
-                                      padding_shape, padding_data,
+                                      block_data,
+                                      padding_data,
                                       space_tensor.get());
 }

 TEST(SpaceToBatchTest, SmallData) {
-  TestBidirectionTransform<float>({1, 1, 2, 2},
-                                  {1,2,3,4},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {0, 0, 0, 0},
-                                  {4,1,1,1},
-                                  {1,2,3,4}
+  TestBidirectionalTransform<float>({1, 2, 2, 1},
+                                    {1, 2, 3, 4},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {4, 1, 1, 1},
+                                    {1, 2, 3, 4}
   );
 }

 TEST(SpaceToBatchTest, SmallDataWithOnePadding) {
-  TestBidirectionTransform<float>({1, 1, 2, 2},
-                                  {1,2,3,4},
-                                  {2},
-                                  {3, 3},
-                                  {2, 2},
-                                  {1, 0, 1, 0},
-                                  {9,1,1,1},
-                                  {0,0,0,0,1,2,0,3,4}
+  TestBidirectionalTransform<float>({1, 2, 2, 1},
+                                    {1, 2, 3, 4},
+                                    {3, 3},
+                                    {1, 0, 1, 0},
+                                    {9, 1, 1, 1},
+                                    {0, 0, 0, 0, 1, 2, 0, 3, 4}
   );
 }

 TEST(SpaceToBatchTest, SmallDataWithTwoPadding) {
-  TestBidirectionTransform<float>({1, 1, 2, 2},
-                                  {1,2,3,4},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {1, 1, 1, 1},
-                                  {4,1,2,2},
-                                  {0,0,0,4,0,0,3,0,0,2,0,0,1,0,0,0}
+  TestBidirectionalTransform<float>({1, 2, 2, 1},
+                                    {1, 2, 3, 4},
+                                    {2, 2},
+                                    {1, 1, 1, 1},
+                                    {4, 2, 2, 1},
+                                    {0, 0, 0, 4, 0, 0, 3, 0, 0, 2, 0, 0, 1, 0, 0, 0}
   );
 }

+TEST(SpaceToBatchTest, SmallDataWithLargeImage) {
+  TestBidirectionalTransform<float>({1, 2, 10, 1},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+                                     11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {4, 1, 5, 1},
+                                    {1, 3, 5, 7, 9,
+                                     2, 4, 6, 8, 10,
+                                     11, 13, 15, 17, 19,
+                                     12, 14, 16, 18, 20}
+  );
+}
+
 TEST(SpaceToBatchTest, MultiChannelData) {
-  TestBidirectionTransform<float>({1, 3, 2, 2},
-                                  {1,2,3,4,5,6,7,8,9,10,11,12},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {0, 0, 0, 0},
-                                  {4,3,1,1},
-                                  {1,5,9,2,6,10,3,7,11,4,8,12}
+  TestBidirectionalTransform<float>({1, 2, 2, 3},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {4, 1, 1, 3},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
   );
 }

 TEST(SpaceToBatchTest, LargerMultiChannelData) {
-  TestBidirectionTransform<float>({1, 1, 4, 4},
-                                  {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {0, 0, 0, 0},
-                                  {4,1,2,2},
-                                  {1,3,9,11,2,4,10,12,5,7,13,15,6,8,14,16}
+  TestBidirectionalTransform<float>({1, 4, 4, 1},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {4, 2, 2, 1},
+                                    {1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16}
   );
 }

 TEST(SpaceToBatchTest, MultiBatchData) {
-  TestBidirectionTransform<float>({2, 1, 2, 4},
-                                  {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {0, 0, 0, 0},
-                                  {8,1,1,2},
-                                  {1,3,2,4,5,7,6,8,9,11,10,12,13,15,14,16}
+  TestBidirectionalTransform<float>({2, 2, 4, 1},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {8, 1, 2, 1},
+                                    {1, 3, 2, 4, 5, 7, 6, 8, 9, 11, 10, 12, 13, 15, 14, 16}
   );
 }

 TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
-  TestBidirectionTransform<float>({2, 2, 2, 4},
-                                  {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
-                                   17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32},
-                                  {2},
-                                  {2, 2},
-                                  {2, 2},
-                                  {0, 0, 0, 0},
-                                  {8,2,1,2},
-                                  {1,3,9,11,2,4,10,12,5,7,13,15,6,8,14,16,
-                                   17,19,25,27,18,20,26,28,21,23,29,31,22,24,30,32}
+  TestBidirectionalTransform<float>({2, 2, 4, 2},
+                                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                                     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
+                                    {2, 2},
+                                    {0, 0, 0, 0},
+                                    {8, 1, 2, 2},
+                                    {1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16,
+                                     17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32}
   );
 }
//TEST(SpaceTobatchTest, CompareTF) {
//
// const std::string space_file = "/data/local/tmp/test/input";
// const std::string batch_file = "/data/local/tmp/test/output";
// const std::vector<index_t> space_shape = {1, 256, 256, 32};
// const int space_size = std::accumulate(space_shape.begin(), space_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> batch_shape = {4, 130, 130, 32};
// const int batch_size = std::accumulate(batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
//
// auto space_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// space_tensor->Resize(space_shape);
// std::vector<float> space_data(space_size, 0.0);
// std::ifstream in_file(space_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(space_data.data()),
// space_size * sizeof(float));
// in_file.close();
// Tensor::MappingGuard space_mapper(space_tensor.get());
// float *space_ptr = space_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(space_tensor->size()) == space_data.size())
// << "Space tensor size:" << space_tensor->size()
// << ", space data size:" << space_data.size();
// memcpy(space_ptr, space_data.data(), space_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open space file failed";
// }
//
// auto batch_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// std::vector<float> batch_data(batch_size, 0.0);
// batch_tensor->Resize(batch_shape);
// {
// std::ifstream in_file(batch_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(batch_data.data()),
// batch_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open batch file failed";
// }
// Tensor::MappingGuard batch_mapper(batch_tensor.get());
// float *batch_ptr = batch_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(batch_tensor->size()) == batch_data.size());
// memcpy(batch_ptr, batch_data.data(), batch_data.size() * sizeof(float));
// }
//
// RunSpaceToBatch<DeviceType::OPENCL>(space_shape, space_data,
// {2, 2},
// {2, 2, 2, 2},
// batch_tensor.get());
//
// RunBatchToSpace<DeviceType::OPENCL>(batch_shape, batch_data,
// {2, 2},
// {2, 2, 2, 2},
// space_tensor.get());
//}
@@ -363,6 +363,27 @@ class TFConverter(object):
       self.net_def.op.extend([op_def])
     self.resolved_ops[op.name] = 1

+  def convert_space_to_batch(self, op, b2s):
+    op_def = self.net_def.op.add()
+    arg = op_def.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    op_def.name = op.name
+    op_def.type = op.type
+    op_def.input.extend([op.inputs[0].name])
+    op_def.output.extend([output.name for output in op.outputs])
+    size_arg = op_def.arg.add()
+    size_arg.name = 'block_shape'
+    size_arg.ints.extend(get_input_tensor(op, 1).eval().astype(np.int32).flat)
+    size_arg = op_def.arg.add()
+    if b2s:
+      size_arg.name = 'crops'
+    else:
+      size_arg.name = 'paddings'
+    size_arg.ints.extend(get_input_tensor(op, 2).eval().astype(np.int32).flat)
+    self.add_output_shape(op.outputs, op_def)
+    self.resolved_ops[op.name] = 1
+
   def convert_normal_op(self, op):
     op_def = self.net_def.op.add()
     arg = op_def.arg.add()
@@ -405,7 +426,11 @@ class TFConverter(object):
         self.convert_resize_bilinear(op)
       elif op.type == 'BiasAdd':
         self.convert_bias_add(op)
-      elif op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND']:
+      elif op.type == 'SpaceToBatchND':
+        self.convert_space_to_batch(op, False)
+      elif op.type == 'BatchToSpaceND':
+        self.convert_space_to_batch(op, True)
+      elif op.type in ['Relu']:
         self.convert_normal_op(op)
       else:
         raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
...
-TF_INPUT_NODE=input
-TF_OUTPUT_NODE=GCN/br_result_2/fcn_br
\ No newline at end of file
+TF_INPUT_NODE=input_node
+TF_OUTPUT_NODE=GCN/br_result_x/fcn_br
\ No newline at end of file
@@ -2,7 +2,7 @@
 # Must run at root dir of mace project.
 set +x
 Usage() {
-  echo 'Usage: bash tools/validate_gcn.sh tf_model_path image_size [tuning]'
+  echo 'Usage: bash tools/validate_gcn.sh tools/gcn.config tf_model_path image_size [tuning]'
 }

 if [ $# -lt 2 ];then
@@ -10,8 +10,10 @@ if [ $# -lt 2 ];then
   exit -1
 fi

+source $1
+
 VLOG_LEVEL=0
-TF_MODEL_FILE_PATH=$1
+TF_MODEL_FILE_PATH=$2
 MODEL_DIR=$(dirname ${TF_MODEL_FILE_PATH})
 MACE_SOURCE_DIR=`/bin/pwd`
 MACE_MODEL_NAME='mace_model.pb'
@@ -20,14 +22,14 @@ OUTPUT_FILE_NAME='gcn.out'
 OUTPUT_LIST_FILE='gcn.list'
 PHONE_DATA_DIR="/data/local/tmp/${MACE_MODEL_NAME}"
 KERNEL_DIR="${PHONE_DATA_DIR}/cl/"
-IMAGE_SIZE=$2
+IMAGE_SIZE=$3
 MODEL_TAG=GCN${IMAGE_SIZE}
 CODEGEN_DIR=${MACE_SOURCE_DIR}/mace/codegen
 MODEL_CODEGEN_DIR=${CODEGEN_DIR}/models/gcn-$IMAGE_SIZE
 CL_CODEGEN_DIR=${CODEGEN_DIR}/opencl
 CL_BIN_DIR=${CODEGEN_DIR}/opencl_bin
 TUNING_CODEGEN_DIR=${CODEGEN_DIR}/tuning
-TUNING_OR_NOT=${3:-0}
+TUNING_OR_NOT=${4:-0}
 VERSION_SOURCE_PATH=${CODEGEN_DIR}/version

 build_and_run()
@@ -87,8 +89,8 @@ rm -rf ${MODEL_CODEGEN_DIR}
 mkdir -p ${MODEL_CODEGEN_DIR}
 bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
                                          --output=${MODEL_CODEGEN_DIR}/mace_gcn${IMAGE_SIZE}.cc \
-                                         --input_node=input \
-                                         --output_node=GCN/br_result_2/fcn_br \
+                                         --input_node=${TF_INPUT_NODE} \
+                                         --output_node=${TF_OUTPUT_NODE} \
                                          --data_type=DT_HALF \
                                          --runtime=gpu \
                                          --output_type=source \
@@ -129,7 +131,7 @@ echo "Step 9: Validate the result"
 python tools/validate.py --model_file ${TF_MODEL_FILE_PATH} \
                          --input_file ${MODEL_DIR}/${INPUT_FILE_NAME} \
                          --mace_out_file ${MODEL_DIR}/${OUTPUT_FILE_NAME} \
-                         --input_node input \
-                         --output_node GCN/br_result_2/fcn_br\
+                         --input_node ${TF_INPUT_NODE} \
+                         --output_node ${TF_OUTPUT_NODE} \
                          --input_shape "${IMAGE_SIZE},${IMAGE_SIZE},3" \
                          --output_shape "1,${IMAGE_SIZE},${IMAGE_SIZE},2"