From d765b4b0e2a5e47cdc05ab6265321760de2406f3 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Wed, 26 Aug 2020 13:01:50 +0800 Subject: [PATCH] fix space_to_batch_nd bug --- mindspore/lite/nnacl/fp32/space_to_batch.c | 181 +++++------- mindspore/lite/nnacl/fp32/space_to_batch.h | 21 +- mindspore/lite/src/lite_kernel.cc | 24 ++ mindspore/lite/src/lite_kernel.h | 14 +- mindspore/lite/src/ops/concat.cc | 6 +- mindspore/lite/src/ops/primitive_c.cc | 5 + mindspore/lite/src/ops/space_to_batch.cc | 4 +- mindspore/lite/src/ops/space_to_batch.h | 3 +- mindspore/lite/src/ops/space_to_batch_nd.cc | 52 +++- mindspore/lite/src/ops/space_to_batch_nd.h | 3 +- mindspore/lite/src/populate_parameter.cc | 22 +- .../runtime/kernel/arm/fp32/space_to_batch.cc | 136 ++++----- .../runtime/kernel/arm/fp32/space_to_batch.h | 16 +- .../arm/fp32/space_to_batch_fp32_tests.cc | 271 ++++++++++-------- 14 files changed, 404 insertions(+), 354 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/space_to_batch.c b/mindspore/lite/nnacl/fp32/space_to_batch.c index bc64d0663..0cd665de0 100644 --- a/mindspore/lite/nnacl/fp32/space_to_batch.c +++ b/mindspore/lite/nnacl/fp32/space_to_batch.c @@ -16,132 +16,79 @@ #include "nnacl/fp32/space_to_batch.h" #include "nnacl/arithmetic_common.h" #include "nnacl/errorcode.h" -#include "nnacl/fp32/concat.h" #include "nnacl/op_base.h" -int EnumElement(int *shape, int n_dims) { - int total = 1; - for (int i = 0; i < n_dims; i++) { - total *= shape[i]; - } - return total; -} - -void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm, - int *output_shape, int h_start, int h_end) { - const int stride0 = strides[perm[0]]; - const int stride1 = strides[perm[1]]; - const int stride2 = strides[perm[2]]; - const int stride3 = strides[perm[3]]; - const int stride4 = strides[perm[4]]; - const int out_stride0 = out_strides[0]; - const int out_stride1 = out_strides[1]; - const int out_stride2 = out_strides[2]; - const int out_stride3 = out_strides[3]; - const int out_stride4 = out_strides[4]; - const int output0 = output_shape[0]; - const int output2 = output_shape[2]; - const int output3 = output_shape[3]; - const int output4 = output_shape[4]; - - for (int i = 0; i < output0; ++i) { - int out_stride0_i = i * out_stride0; - int stride0_i = i * stride0; - for (int j = h_start; j < h_end; ++j) { - int out_stride1_j = j * out_stride1; - int stride1_j = j * stride1; - for (int k = 0; k < output2; ++k) { - int out_stride2_k = k * out_stride2; - int stride2_k = k * stride2; - for (int m = 0; m < output3; ++m) { - int out_stride3_m = m * out_stride3; - int stride3_m = m * stride3; - for (int n = 0; n < output4; ++n) { - int out_stride4_n = n * out_stride4; - int stride4_n = n * stride4; - memcpy(out_data + out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n, - in_data + stride0_i + stride1_j + stride2_k + stride3_m + stride4_n, stride4 * sizeof(float)); - } - } +void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape, + int *out_shape) { + int out_dim0 = out_shape[0]; + int out_dim1 = out_shape[1]; + int out_dim2 = out_shape[2]; + int copy_num = out_shape[3]; + int block_w = param->block_sizes_[1]; + int block_h = param->block_sizes_[0]; + int in_strides[4]; + ComputeStrides(in_shape, in_strides, 4); + int out_strides[4]; + ComputeStrides(out_shape, out_strides, 4); + size_t copy_size = copy_num * sizeof(float); + size_t out_offset = 0; + for (int n = 0; n < out_dim0; ++n) { + int in_n = n 
% in_shape[0]; + int32_t stride_w = (n / in_shape[0]) % block_w; + int32_t stride_h = (n / in_shape[0]) / block_w; + size_t in_offset0 = in_n * in_strides[0]; + for (int h = 0; h < out_dim1; ++h) { + size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1]; + for (int w = 0; w < out_dim2; ++w) { + size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2]; + memcpy(output + out_offset, input + in_offset2, copy_size); + out_offset += copy_num; } } } } -int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes, int h_start, - int h_end) { - int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0], - block_sizes[0], in_shape[2] / block_sizes[1], - block_sizes[1], in_shape[3]}; - int trans_out_shape[6] = { - in_shape[0], block_sizes[0], block_sizes[1], in_shape[1] / block_sizes[0], in_shape[2] / block_sizes[1], - in_shape[3]}; - int in_strides[C4NUM + 2]; - ComputeStrides(trans_in_shape, in_strides, shape_size + 2); - int out_strides[C4NUM + 2]; - ComputeStrides(trans_out_shape, out_strides, shape_size + 2); - - int perm[6] = {0, 2, 4, 1, 3, 5}; - TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape, h_start, h_end); - return NNACL_OK; -} - -void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]) { - float *tmp = padded_input; - (void)memcpy(tmp, input, param.num_elements_ * sizeof(float)); - float *target = tmp_space[0]; - float *tmp_zeros = tmp_space[1]; - float *tmp2 = NULL; - int cur_shape[param.n_dims_], cur_start_shape[param.n_dims_], cur_end_shape[param.n_dims_], - cur_target_shape[param.n_dims_]; - float *concat_inputs[3]; - int *concat_shapes[4]; - - for (int i = 0; i < param.n_dims_; i++) { - cur_shape[i] = param.in_shape_[i]; - cur_start_shape[i] = param.in_shape_[i]; - cur_end_shape[i] = param.in_shape_[i]; - cur_target_shape[i] = param.in_shape_[i]; - } - for (int i = 0; i < param.n_space_dims_; ++i) { - if (param.padded_in_shape_[i + 1] > param.in_shape_[i + 1]) { - int concat_idx = 0; - cur_target_shape[i + 1] = 0; - if (param.paddings_[2 * i] != 0) { - cur_start_shape[i + 1] = param.paddings_[2 * i]; - concat_inputs[concat_idx] = tmp_zeros; - concat_shapes[concat_idx++] = cur_start_shape; - cur_target_shape[i + 1] += cur_start_shape[i + 1]; +void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape, + const float *pedding_h_data, const float *pedding_w_data) { + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int out_w = out_shape[2]; + int out_c = out_shape[3]; + size_t ped_h_num = out_w * out_c; + size_t ped_h_size = ped_h_num * sizeof(float); + size_t ped_w_size = out_c * sizeof(float); + size_t out_offset = 0; + int in_strides[4]; + ComputeStrides(in_shape, in_strides, 4); + int out_strides[4]; + ComputeStrides(out_shape, out_strides, 4); + size_t copy_size = in_c * sizeof(float); + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_offset0 = i * in_strides[0]; + for (int pad_h_top = 0; pad_h_top < padding[0]; ++pad_h_top) { + memcpy(output + out_offset, pedding_h_data, ped_h_size); + out_offset += ped_h_num; + } + for (int j = 0; j < in_h; ++j) { + size_t in_offset1 = in_offset0 + j * in_strides[1]; + for (int pad_w_left = 0; pad_w_left < padding[2]; ++pad_w_left) { + memcpy(output + out_offset, pedding_w_data, ped_w_size); + out_offset += out_c; } - - concat_inputs[concat_idx] = tmp; - concat_shapes[concat_idx++] = cur_shape; - 
cur_target_shape[i + 1] += cur_shape[i + 1];
-      if (param.paddings_[2 * i + 1] != 0) {
-        cur_end_shape[i + 1] = param.paddings_[2 * i + 1];
-        concat_inputs[concat_idx] = tmp_zeros;
-        concat_shapes[concat_idx++] = cur_end_shape;
-        cur_target_shape[i + 1] += cur_end_shape[i + 1];
+      for (int k = 0; k < in_w; ++k) {
+        size_t in_offset2 = in_offset1 + k * in_strides[2];
+        memcpy(output + out_offset, input + in_offset2, copy_size);
+        out_offset += in_c;
+      }
+      for (int pad_w_right = 0; pad_w_right < padding[3]; ++pad_w_right) {
+        memcpy(output + out_offset, pedding_w_data, ped_w_size);
+        out_offset += out_c;
       }
-      concat_shapes[concat_idx] = cur_target_shape;
-      Concat((void **)concat_inputs, concat_idx, i + 1, concat_shapes, param.n_dims_, target);
-
-      tmp2 = tmp;
-      tmp = target;
-      target = tmp2;
-      cur_start_shape[i + 1] = cur_end_shape[i + 1] = cur_shape[i + 1] = concat_shapes[concat_idx][i + 1];
+    }
+    for (int pad_h_bottom = 0; pad_h_bottom < padding[1]; ++pad_h_bottom) {
+      memcpy(output + out_offset, pedding_h_data, ped_h_size);
+      out_offset += ped_h_num;
     }
   }
-  if (padded_input != tmp) {
-    memcpy(padded_input, tmp, param.num_elements_padded_ * sizeof(float));
-  }
-}
-
-int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end) {
-  if (input == NULL || output == NULL) {
-    return NNACL_NULL_PTR;
-  }
-  int ret =
-    SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_, h_start, h_end);
-  return ret;
 }
diff --git a/mindspore/lite/nnacl/fp32/space_to_batch.h b/mindspore/lite/nnacl/fp32/space_to_batch.h
index 5b19e7dfc..540640802 100644
--- a/mindspore/lite/nnacl/fp32/space_to_batch.h
+++ b/mindspore/lite/nnacl/fp32/space_to_batch.h
@@ -22,26 +22,17 @@
 typedef struct SpaceToBatchParameter {
   OpParameter op_parameter_;
-  int block_sizes_[8];
-  int paddings_[8];
-  int n_dims_;
-  int num_elements_;
-  int num_elements_padded_;
-  int n_space_dims_;
-  int in_shape_[8];
-  int padded_in_shape_[8];
   bool need_paddings_;
+  int block_sizes_[4];
+  int paddings_[4];
 } SpaceToBatchParameter;
 
 #ifdef __cplusplus
 extern "C" {
 #endif
-int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end);
-int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size, int h_start,
-                        int h_end);
-void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
-                      int *output_shape, int h_start, int h_end);
-void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]);
-int EnumElement(int *shape, int n_dims);
+void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape,
+                        int *out_shape);
+void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape,
+                               const float *pedding_h_data, const float *pedding_w_data);
#ifdef __cplusplus
}
#endif
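For readers following the new nnacl API, here is a minimal sketch of how the two routines are meant to be chained when padding is needed. The buffer sizes mirror the SpaceToBatchTest10 case in the test diff further below; everything outside the two nnacl calls is illustrative scaffolding, not code from this patch.

    // Sketch: pad a 1x4x4x2 NHWC tensor by 1 on every side of H and W, then
    // scatter the padded 1x6x6x2 tensor into 4 batches of 3x3x2 (block 2x2).
    #include <string.h>
    #include "nnacl/fp32/space_to_batch.h"

    void SpaceToBatchExample(const float *input /* 32 floats, NHWC 1x4x4x2 */,
                             float *output /* 72 floats, NHWC 4x3x3x2 */) {
      int in_shape[4] = {1, 4, 4, 2};
      int padded_shape[4] = {1, 6, 6, 2};
      int out_shape[4] = {4, 3, 3, 2};
      int paddings[4] = {1, 1, 1, 1};  // H top, H bottom, W left, W right
      float padded[1 * 6 * 6 * 2] = {0};
      float pad_row[6 * 2] = {0};   // one zero row of the padded plane: W * C floats
      float pad_pixel[2] = {0};     // one zero pixel: C floats
      SpaceToBatchParameter param;
      memset(&param, 0, sizeof(param));
      param.block_sizes_[0] = 2;  // block over H
      param.block_sizes_[1] = 2;  // block over W
      DoSpaceToBatchPaddingNHWC(input, padded, in_shape, paddings, padded_shape, pad_row, pad_pixel);
      DoSpaceToBatchNHWC(padded, output, &param, padded_shape, out_shape);
    }

The two-step split is the design point of this rewrite: padding is a plain row/pixel memcpy pass, and the batch scatter is a stride walk, replacing the old transpose-plus-concat machinery.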
diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc
index 5d0cda8eb..7d3edf0df 100644
--- a/mindspore/lite/src/lite_kernel.cc
+++ b/mindspore/lite/src/lite_kernel.cc
@@ -39,6 +39,30 @@ int LiteKernel::DecOutTensorRefCount() {
   return 0;
 }
 
+int LiteKernel::Prepare() {
+  if (!InferShapeDone()) {
+    (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
+    auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
+    if (ret != 0) {
+      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
+      MS_LOG(ERROR) << "InferShape fail!";
+      return ret;
+    }
+    ret = ReSize();
+    if (ret != 0) {
+      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
+      return ret;
+    }
+  }
+
+  auto &outputs = this->out_tensors();
+  for (auto *output : outputs) {
+    MS_ASSERT(output != nullptr);
+    output->MallocData();
+  }
+  return RET_OK;
+}
+
 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
   const std::vector<kernel::LiteKernel *> &kernels) {
   std::vector<kernel::LiteKernel *> input_kernels;
diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h
index 64682c3e7..260073286 100644
--- a/mindspore/lite/src/lite_kernel.h
+++ b/mindspore/lite/src/lite_kernel.h
@@ -75,19 +75,7 @@ class LiteKernel {
 
   virtual ~LiteKernel() = default;
 
-  virtual int Prepare() {
-    if (!InferShapeDone()) {
-      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
-      ReSize();
-    }
-
-    auto &outputs = this->out_tensors();
-    for (auto *output : outputs) {
-      MS_ASSERT(output != nullptr);
-      output->MallocData();
-    }
-    return RET_OK;
-  }
+  virtual int Prepare();
 
   virtual int Init() { return -1; }
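The old inline Prepare() discarded the results of InferShape() and ReSize(); the moved version propagates them and clears the infer flag on failure. The sketch below condenses that control flow with stand-in types (FakePrimitive is hypothetical, not a MindSpore Lite class) to show why clearing the flag matters: a kernel whose shapes could not be inferred stays marked as such and can be prepared again once real shapes are known, instead of allocating outputs from garbage dimensions.

    #include <cstdio>

    struct FakePrimitive {
      bool infer_flag = false;
      void SetInferFlag(bool flag) { infer_flag = flag; }
      int InferShape() { return -1; }  // pretend inference fails (e.g. input shape unknown)
    };

    int Prepare(FakePrimitive *p) {
      p->SetInferFlag(true);
      auto ret = p->InferShape();
      if (ret != 0) {
        p->SetInferFlag(false);  // leave the kernel marked "not inferred" so Prepare can retry
        return ret;
      }
      return 0;  // only now is it safe to ReSize() and allocate output buffers
    }

    int main() {
      FakePrimitive prim;
      std::printf("Prepare -> %d, infer_flag=%d\n", Prepare(&prim), prim.infer_flag);
      return 0;
    }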
diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc
index d2068dabd..f89a72141 100644
--- a/mindspore/lite/src/ops/concat.cc
+++ b/mindspore/lite/src/ops/concat.cc
@@ -96,17 +96,13 @@ int Concat::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector
   auto input0_data_type = inputs_.at(0)->data_type();
-  schema::Format input0_format = inputs_[0]->GetFormat();
   int output_axis_dim = input0_shape.at(axis);
   for (size_t i = 1; i < inputs_.size(); ++i) {
     if (inputs_.at(i)->data_type() != input0_data_type) {
       MS_LOG(ERROR) << "All inputs should have the same data type!";
       return RET_PARAM_INVALID;
     }
-    if (inputs_.at(i)->GetFormat() != input0_format) {
-      MS_LOG(ERROR) << "All input format should be the same!";
-      return RET_PARAM_INVALID;
-    }
+
     auto shape_tmp = inputs_.at(i)->shape();
     if (shape_tmp.size() != input0_shape.size()) {
       MS_LOG(ERROR) << "All inputs should have the same dim num!";
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index 92079026f..db88eeee7 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -17,6 +17,7 @@
 #include "src/ops/primitive_c.h"
 #include <memory>
 #include "src/ops/space_to_batch.h"
+#include "src/ops/space_to_batch_nd.h"
 #include "src/ops/conv2d.h"
 #include "src/ops/roi_pooling.h"
 #include "src/ops/topk.h"
@@ -414,6 +415,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
       return new BatchToSpace(primitive);
     case schema::PrimitiveType_SpaceToBatch:
       return new SpaceToBatch(primitive);
+    case schema::PrimitiveType_SpaceToBatchND:
+      return new SpaceToBatchND(primitive);
     case schema::PrimitiveType_BroadcastTo:
       return new BroadcastTo(primitive);
     case schema::PrimitiveType_DepthToSpace:
@@ -620,6 +623,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
       return new BatchToSpace(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_SpaceToBatch:
       return new SpaceToBatch(const_cast<schema::Primitive *>(primitive));
+    case schema::PrimitiveType_SpaceToBatchND:
+      return new SpaceToBatchND(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_BroadcastTo:
       return new BroadcastTo(const_cast<schema::Primitive *>(primitive));
     case schema::PrimitiveType_DepthToSpace:
diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc
index ea1712fb8..194245538 100644
--- a/mindspore/lite/src/ops/space_to_batch.cc
+++ b/mindspore/lite/src/ops/space_to_batch.cc
@@ -105,8 +105,8 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
   std::vector<int> output_shape(input_shape.size());
   output_shape[NHWC_N] = input_shape[NHWC_N] * (block_sizes_[NHWC_N] * block_sizes_[NHWC_H]);
-  output_shape[NHWC_H] = input_shape[NHWC_H] / block_sizes_[NHWC_N];
-  output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H];
+  output_shape[NHWC_H] = (input_shape[NHWC_H] + paddings_[0] + paddings_[1]) / block_sizes_[NHWC_N];
+  output_shape[NHWC_W] = (input_shape[NHWC_W] + paddings_[2] + paddings_[3]) / block_sizes_[NHWC_H];
   output_shape[NHWC_C] = input_shape[NHWC_C];
   outputs[0]->set_shape(output_shape);
   return RET_OK;
diff --git a/mindspore/lite/src/ops/space_to_batch.h b/mindspore/lite/src/ops/space_to_batch.h
index 865d5c357..5257a6e4d 100644
--- a/mindspore/lite/src/ops/space_to_batch.h
+++ b/mindspore/lite/src/ops/space_to_batch.h
@@ -36,7 +36,8 @@ class SpaceToBatch : public PrimitiveC {
 #else
   explicit SpaceToBatch(schema::Primitive *primitive) : PrimitiveC(primitive) {}
 #endif
-  int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
+  int InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) override;
+
   std::vector<int> GetBlockShape() const;
   std::vector<int> GetPaddings() const;
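The corrected formulas fold the paddings into H and W before dividing by the block sizes, which is the actual bug fix here: the old code divided the unpadded extents. A standalone sketch of the shape math (not library code), using the configuration the old Test3 exercised, a 1x4x4x1 input with block sizes {2, 2} and paddings {2, 0, 2, 2}: the padded plane is (4+2+0) x (4+2+2) = 6 x 8, so the output shape is {1*2*2, 6/2, 8/2, 1} = {4, 3, 4, 1}.

    #include <vector>

    // block = {block_h, block_w}, pad = {top, bottom, left, right}, in = NHWC.
    std::vector<int> SpaceToBatchOutShape(const std::vector<int> &in,
                                          const int block[2], const int pad[4]) {
      return {in[0] * block[0] * block[1],
              (in[1] + pad[0] + pad[1]) / block[0],
              (in[2] + pad[2] + pad[3]) / block[1],
              in[3]};
    }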
diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc
index 2606f6401..b716a6cfe 100644
--- a/mindspore/lite/src/ops/space_to_batch_nd.cc
+++ b/mindspore/lite/src/ops/space_to_batch_nd.cc
@@ -15,9 +15,17 @@
  */
 
 #include "src/ops/space_to_batch_nd.h"
+#include "src/common/common.h"
 
 namespace mindspore {
 namespace lite {
+namespace {
+constexpr int kSpaceToBatchNDOutputNum = 1;
+constexpr int kSpaceToBatchNDInputNum = 1;
+constexpr int kBlockSizesSize = 2;
+constexpr int kPaddingsSize = 4;
+}  // namespace
+
 #ifdef PRIMITIVE_WRITEABLE
 std::vector<int> SpaceToBatchND::GetBlockShape() const {
   return this->primitive_->value.AsSpaceToBatchND()->blockShape;
@@ -42,6 +50,48 @@ std::vector<int> SpaceToBatchND::GetPaddings() const {
   return std::vector<int>(fb_vector->begin(), fb_vector->end());
 }
-#endif
+#endif  // PRIMITIVE_WRITEABLE
+
+int SpaceToBatchND::InferShape(std::vector<lite::tensor::Tensor *> inputs,
+                               std::vector<lite::tensor::Tensor *> outputs) {
+  if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) {
+    MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
+    return 1;
+  }
+
+  auto input = inputs.at(0);
+  if (input->GetFormat() != schema::Format_NHWC) {
+    MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
+    return RET_ERROR;
+  }
+  outputs[0]->set_data_type(input->data_type());
+  outputs[0]->SetFormat(input->GetFormat());
+  if (!GetInferFlag()) {
+    return RET_OK;
+  }
+  auto input_shape = input->shape();
+  if (input_shape.size() != kDimension_4d) {
+    MS_LOG(ERROR) << "input shape dimension size only support " << kDimension_4d << " now!";
+    return RET_ERROR;
+  }
+  auto block_shape = GetBlockShape();
+  if (block_shape.size() != kBlockSizesSize) {
+    MS_LOG(ERROR) << "blockShape size != " << kBlockSizesSize;
+    return RET_ERROR;
+  }
+  auto pedding = GetPaddings();
+  if (pedding.size() != kPaddingsSize) {
+    MS_LOG(ERROR) << "pedding size should be " << kPaddingsSize;
+    return RET_ERROR;
+  }
+
+  std::vector<int> output_shape(input_shape.size());
+  output_shape[NHWC_N] = input_shape[NHWC_N] * block_shape[0] * block_shape[1];
+  output_shape[NHWC_H] = (input_shape[NHWC_H] + pedding[0] + pedding[1]) / block_shape[0];
+  output_shape[NHWC_W] = (input_shape[NHWC_W] + pedding[2] + pedding[3]) / block_shape[1];
+  output_shape[NHWC_C] = input_shape[NHWC_C];
+  outputs[0]->set_shape(output_shape);
+  return RET_OK;
+}
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/ops/space_to_batch_nd.h b/mindspore/lite/src/ops/space_to_batch_nd.h
index b0f5429e1..4b53a7715 100644
--- a/mindspore/lite/src/ops/space_to_batch_nd.h
+++ b/mindspore/lite/src/ops/space_to_batch_nd.h
@@ -18,8 +18,6 @@
 #define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_N_D_H_
 
 #include <vector>
-#include <set>
-#include <cmath>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
@@ -38,6 +36,7 @@ class SpaceToBatchND : public PrimitiveC {
 #endif
   std::vector<int> GetBlockShape() const;
   std::vector<int> GetPaddings() const;
+  int InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) override;
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc
index bf524d7a2..07d014d6b 100644
--- a/mindspore/lite/src/populate_parameter.cc
+++ b/mindspore/lite/src/populate_parameter.cc
@@ -20,6 +20,7 @@
 #include "schema/ops_generated.h"
 #include "src/ops/constant_of_shape.h"
 #include "src/ops/space_to_batch.h"
+#include "src/ops/space_to_batch_nd.h"
 #include "src/ops/conv2d.h"
 #include "src/ops/roi_pooling.h"
 #include "src/ops/topk.h"
@@ -1189,10 +1190,22 @@ OpParameter *PopulateSpaceToBatchParameter(const mindspore::lite::PrimitiveC *pr
   (void)memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int));
   auto paddings = ((mindspore::lite::SpaceToBatch *)primitive)->Paddings();
   (void)memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int));
-  auto in_shape = ((mindspore::lite::SpaceToBatch *)primitive)->InShape();
-  (void)memcpy(space_batch_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
-  auto padded_in_shape = ((mindspore::lite::SpaceToBatch *)primitive)->PaddedInShape();
-  (void)memcpy(space_batch_param->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int));
+  return reinterpret_cast<OpParameter *>(space_batch_param);
+}
+
+OpParameter *PopulateSpaceToBatchParameterND(const mindspore::lite::PrimitiveC *primitivec) {
+  auto *space_batch_param = new (std::nothrow) SpaceToBatchParameter();
+  if (space_batch_param == nullptr) {
+    MS_LOG(ERROR) << "new SpaceToBatchParameter failed.";
+    return nullptr;
+  }
+
+  mindspore::lite::SpaceToBatchND *primitive = (mindspore::lite::SpaceToBatchND *)primitivec;
+  space_batch_param->op_parameter_.type_ = primitive->Type();
+  auto block_sizes = primitive->GetBlockShape();
+  (void)memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int));
+  auto paddings = primitive->GetPaddings();
+  (void)memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int));
   return reinterpret_cast<OpParameter *>(space_batch_param);
 }
@@ -1525,6 +1538,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_BatchToSpace] = PopulateBatchToSpaceParameter;
   populate_parameter_funcs_[schema::PrimitiveType_SpaceToDepth] = PopulateSpaceToDepthParameter;
   populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatch] = PopulateSpaceToBatchParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatchND] = PopulateSpaceToBatchParameterND;
   populate_parameter_funcs_[schema::PrimitiveType_Crop] = PopulateCropParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Unsqueeze] = PopulateUnsqueezeParameter;
   populate_parameter_funcs_[schema::PrimitiveType_Flatten] = PopulateFlattenParameter;
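With the shape fields dropped from the struct, a populate function now only has to copy the block sizes and paddings; everything shape-dependent is recomputed at ReSize time. As a concrete reference, this sketch (not code from the patch) shows what the ND path leaves in the flattened parameter for the block 2x2, pad-1-everywhere configuration used by the Test9/Test10 cases below; need_paddings_ is left false here and is expected to be derived from paddings_ by the kernel's Init().

    #include <cstring>
    #include "nnacl/fp32/space_to_batch.h"

    SpaceToBatchParameter MakeExampleParam() {
      SpaceToBatchParameter param;
      std::memset(&param, 0, sizeof(param));
      int block_sizes[2] = {2, 2};     // what GetBlockShape() would return for such a model
      int paddings[4] = {1, 1, 1, 1};  // what GetPaddings() would return
      std::memcpy(param.block_sizes_, block_sizes, sizeof(block_sizes));
      std::memcpy(param.paddings_, paddings, sizeof(paddings));
      return param;
    }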
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
index 81d6e9dee..647372e6e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
@@ -29,8 +29,18 @@
 using mindspore::lite::RET_FORMAT_ERR;
 using mindspore::lite::RET_OK;
 using mindspore::lite::RET_OP_EXECUTE_FAILURE;
 using mindspore::schema::PrimitiveType_SpaceToBatch;
+using mindspore::schema::PrimitiveType_SpaceToBatchND;
 
 namespace mindspore::kernel {
+namespace {
+size_t EnumElement(int *shape, int n_dims) {
+  size_t total = 1;
+  for (int i = 0; i < n_dims; i++) {
+    total *= shape[i];
+  }
+  return total;
+}
+}  // namespace
 
 int SpaceToBatchCPUKernel::Init() {
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
   for (int i = 0; i < DIMENSION_4D; ++i) {
     if (param->paddings_[i] != 0) {
       param->need_paddings_ = true;
       break;
     }
   }
-  param->n_dims_ = DIMENSION_4D;
-  param->n_space_dims_ = SPACE_TO_BATCH_BLOCK_SIZES_SIZE;
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
   return ReSize();
 }
 
-int SpaceToBatchCPUKernel::SpaceToBatchParallel(int task_id) {
-  int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
-  if (num_unit_thread <= 0) {
-    return RET_OK;
+void SpaceToBatchCPUKernel::FreeTmpBuffer() {
+  if (pedding_h_data_ != nullptr) {
+    context_->allocator->Free(pedding_h_data_);
+    pedding_h_data_ = nullptr;
   }
-  int thread_offset = task_id * thread_h_stride_;
-  SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
-  auto ret = SpaceToBatch(input_ptr_, output_ptr_, *param, thread_offset, thread_offset + num_unit_thread);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "SpaceToDepth error task_id[" << task_id << "] error_code[" << ret << "]";
-    return RET_ERROR;
+  if (pedding_w_data_ != nullptr) {
+    context_->allocator->Free(pedding_w_data_);
+    pedding_w_data_ = nullptr;
   }
-  return RET_OK;
-}
-
-int SpaceToBatchRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
-  auto g_kernel = reinterpret_cast<SpaceToBatchCPUKernel *>(cdata);
-  auto ret = g_kernel->SpaceToBatchParallel(task_id);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "SpaceToBatchRun error task_id[" << task_id << "] error_code[" << ret << "]";
-    return RET_OP_EXECUTE_FAILURE;
+  if (pedding_input_ != nullptr) {
+    context_->allocator->Free(pedding_input_);
+    pedding_input_ = nullptr;
   }
-  return RET_OK;
 }
 
 int SpaceToBatchCPUKernel::ReSize() {
   if (in_tensors_[0]->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
     return RET_FORMAT_ERR;
   }
+  FreeTmpBuffer();
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
-  param->num_elements_ = EnumElement(param->in_shape_, param->n_dims_);
-  param->num_elements_padded_ = EnumElement(param->padded_in_shape_, param->n_dims_);
-  num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(kNHWC_H));
-  num_unit_ /= param->block_sizes_[0];
-  thread_h_num_ = MSMIN(thread_num_, num_unit_);
-  thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
+  if (!param->need_paddings_) {
+    return RET_OK;
+  }
+  auto input = in_tensors_[0];
+  auto in_shape = input->shape();
+  padded_in_shape_ = in_shape;
+  padded_in_shape_[1] = in_shape[1] + param->paddings_[0] + param->paddings_[1];
+  padded_in_shape_[2] = in_shape[2] + param->paddings_[2] + param->paddings_[3];
+  auto num_elements_padded = EnumElement(padded_in_shape_.data(), in_shape.size());
+  auto output_shape = out_tensors_[0]->shape();
+  auto pedding_h_size = output_shape[2] * output_shape[3] * sizeof(float);
+  pedding_h_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(pedding_h_size));
+  if (pedding_h_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc pedding h data fail!";
+    return RET_ERROR;
+  }
+  auto pedding_w_size = output_shape[3] * sizeof(float);
+  pedding_w_data_ = reinterpret_cast<float *>(context_->allocator->Malloc(pedding_w_size));
+  if (pedding_w_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc pedding w data fail!";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  pedding_input_ =
+    reinterpret_cast<float *>(context_->allocator->Malloc(num_elements_padded * sizeof(float)));
+  if (pedding_input_ == nullptr) {
+    MS_LOG(ERROR) << "malloc pedding buffer fail!";
+    return RET_ERROR;
+  }
+  memset(pedding_h_data_, 0, pedding_h_size);
+  memset(pedding_w_data_, 0, pedding_w_size);
   return RET_OK;
 }
 
@@ -96,54 +121,32 @@ int SpaceToBatchCPUKernel::Run() {
   }
   auto input = in_tensors_[0];
   auto output = out_tensors_[0];
-  input_ptr_ = reinterpret_cast<const float *>(input->Data());
-  output_ptr_ = reinterpret_cast<float *>(output->Data());
+  const float *input_ptr_ = reinterpret_cast<const float *>(input->Data());
+  float *output_ptr_ = reinterpret_cast<float *>(output->Data());
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
-
-  float *tmp_space[3] = {nullptr, nullptr, nullptr};
+  auto in_shape = input->shape();
+  auto out_shape = output->shape();
   if (param->need_paddings_) {
-    for (int i = 0; i < 3; ++i) {
-      tmp_space[i] =
-        reinterpret_cast<float *>(context_->allocator->Malloc(param->num_elements_padded_ * sizeof(float)));
-      (void)memset(tmp_space[i], 0, param->num_elements_padded_ * sizeof(float));
-      if (tmp_space[i] == nullptr) {
-        MS_LOG(ERROR) << "malloc tmp buffer fail!";
-        return RET_ERROR;
-      }
-    }
-    auto padded_input = tmp_space[0];
-    DoPadding(input_ptr_, padded_input, *param, tmp_space + 1);
-    input_ptr_ = padded_input;
-  }
-
-  if (input->GetFormat() == schema::Format_NHWC) {
-    ret = LiteBackendParallelLaunch(SpaceToBatchRun, this, thread_h_num_);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "SpaceToBatch error error_code[" << ret << "]";
-    }
+    DoSpaceToBatchPaddingNHWC(input_ptr_, pedding_input_, in_shape.data(), param->paddings_,
+                              padded_in_shape_.data(), pedding_h_data_, pedding_w_data_);
+    DoSpaceToBatchNHWC(pedding_input_, output_ptr_, param, padded_in_shape_.data(), out_shape.data());
+    return RET_OK;
   } else {
-    MS_LOG(ERROR) << "Only support NHWC now!";
-    ret = RET_FORMAT_ERR;
-  }
-  if (param->need_paddings_) {
-    for (int i = 0; i < 3; ++i) {
-      context_->allocator->Free(tmp_space[i]);
-    }
+    DoSpaceToBatchNHWC(input_ptr_, output_ptr_, param, in_shape.data(), out_shape.data());
+    return RET_OK;
   }
-
-  return ret;
 }  // namespace mindspore::kernel
 
 kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
-                                                     OpParameter *opParameter, const lite::Context *ctx,
+                                                     OpParameter *param, const lite::Context *ctx,
                                                      const kernel::KernelKey &desc,
                                                      const mindspore::lite::PrimitiveC *primitive) {
-  if (opParameter == nullptr) {
-    MS_LOG(ERROR) << "Input opParameter is nullptr!";
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "Input param is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) SpaceToBatchCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  auto *kernel = new (std::nothrow) SpaceToBatchCPUKernel(param, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new SpaceToBatchCPUKernel fail!";
     return nullptr;
   }
@@ -152,12 +155,13 @@ kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     delete kernel;
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    MS_LOG(ERROR) << "Init kernel failed, name: " << param->name_ << ", type: "
+                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(param->type_));
     return nullptr;
   }
   return kernel;
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatch, CpuSpaceToBatchFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatchND, CpuSpaceToBatchFp32KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
index f93de6cc8..2135d27c7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
@@ -25,22 +25,20 @@ class SpaceToBatchCPUKernel : public LiteKernel {
   SpaceToBatchCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                         const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
 
-  ~SpaceToBatchCPUKernel() = default;
+  ~SpaceToBatchCPUKernel() { FreeTmpBuffer(); }
 
   int Init() override;
   int ReSize() override;
   int Run() override;
-  int SpaceToBatchParallel(int task_id);
 
  private:
-  int thread_num_;
-  int thread_h_stride_;
-  int thread_h_num_;
-  int num_unit_;
-  const float *input_ptr_;
-  float *output_ptr_;
+  void FreeTmpBuffer();
+  float *pedding_input_ = nullptr;
+  float *pedding_h_data_ = nullptr;
+  float *pedding_w_data_ = nullptr;
+  std::vector<int> padded_in_shape_;
 };
 }  // namespace mindspore::kernel
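Before the rewritten tests, a worked example of the index mapping the first new case (SpaceToBatchTest4) exercises: with block sizes {2, 1} and no padding, output batch n reads input rows h * block_h + (n / in_batch), so batch 0 takes rows 0 and 2 and batch 1 takes rows 1 and 3. The standalone check below reproduces that expectation with a hypothetical main() harness rather than the project's gtest setup.

    #include <cstdio>
    #include "nnacl/fp32/space_to_batch.h"

    int main() {
      float in[16], out[16];
      for (int i = 0; i < 16; ++i) in[i] = static_cast<float>(i + 1);
      int in_shape[4] = {1, 4, 4, 1};
      int out_shape[4] = {2, 2, 4, 1};
      SpaceToBatchParameter param = {};
      param.block_sizes_[0] = 2;  // block_h
      param.block_sizes_[1] = 1;  // block_w
      DoSpaceToBatchNHWC(in, out, &param, in_shape, out_shape);
      for (int i = 0; i < 16; ++i) printf("%g ", out[i]);
      printf("\n");  // expected: 1 2 3 4 9 10 11 12 5 6 7 8 13 14 15 16
      return 0;
    }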
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc
index 69f944131..59eed62a1 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc
@@ -28,142 +28,175 @@ class SpaceToBatchTestFp32 : public mindspore::CommonTest {
   SpaceToBatchTestFp32() {}
 };
 
-void InitSpaceToBatchParameter(SpaceToBatchParameter *param) {
-  param->n_dims_ = 4;
-  param->n_space_dims_ = 2;
-
-  param->block_sizes_[0] = 2;
-  param->block_sizes_[1] = 2;
-
-  param->paddings_[0] = 2;
-  param->paddings_[1] = 0;
-  param->paddings_[2] = 2;
-  param->paddings_[3] = 2;
-
-  param->in_shape_[0] = 1;
-  param->in_shape_[1] = 4;
-  param->in_shape_[2] = 4;
-  param->in_shape_[3] = 1;
-
-  param->padded_in_shape_[0] = 1;
-  param->padded_in_shape_[1] = 6;
-  param->padded_in_shape_[2] = 8;
-  param->padded_in_shape_[3] = 1;
-
-  param->num_elements_ = 16;
-  param->num_elements_padded_ = 48;
-
-  param->need_paddings_ = true;
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest4) {
+  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8,
+                              9, 10, 11, 12, 13, 14, 15, 16};
+  const size_t kOutSize = 16;
+  std::vector<float> expect_out = {1, 2, 3, 4, 9, 10, 11, 12,
+                                   5, 6, 7, 8, 13, 14, 15, 16};
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 1};
+  std::vector<int> out_shape = {2, 2, 4, 1};
+  SpaceToBatchParameter param;
+  param.block_sizes_[0] = 2;
+  param.block_sizes_[1] = 1;
+  DoSpaceToBatchNHWC(input.data(), out, &param, in_shape.data(), out_shape.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
 }
 
-void InitSpaceToBatchParameter2(SpaceToBatchParameter *param) {
-  param->block_sizes_[0] = 2;
-  param->block_sizes_[1] = 2;
-
-  param->paddings_[0] = 2;
-  param->paddings_[1] = 0;
-  param->paddings_[2] = 2;
-  param->paddings_[3] = 2;
-
-  param->in_shape_[0] = 1;
-  param->in_shape_[1] = 4;
-  param->in_shape_[2] = 4;
-  param->in_shape_[3] = 1;
-
-  param->padded_in_shape_[0] = 1;
-  param->padded_in_shape_[1] = 6;
-  param->padded_in_shape_[2] = 8;
-  param->padded_in_shape_[3] = 1;
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest5) {
+  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8,
+                              9, 10, 11, 12, 13, 14, 15, 16};
+  size_t kOutSize = 16;
+  std::vector<float> expect_out = {1, 3, 5, 7, 9, 11, 13, 15,
+                                   2, 4, 6, 8, 10, 12, 14, 16};
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 1};
+  std::vector<int> out_shape = {2, 4, 2, 1};
+  SpaceToBatchParameter param;
+  param.block_sizes_[0] = 1;
+  param.block_sizes_[1] = 2;
+  DoSpaceToBatchNHWC(input.data(), out, &param, in_shape.data(), out_shape.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
 }
 
-TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest1) {
-  float input[16] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25};
-  const int out_size = 16;
-  float expect_out[16] = {1, 5, 18, 3, 2, 6, 10, 4, 10, 3, 11, 15, 20, 8, 55, 25};
-
-  float output[16];
-  int in_shape[4] = {1, 4, 4, 1};
-  int out_shape[4] = {4, 2, 2, 1};
-  int block_sizes[2] = {2, 2};
-  SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes, 0, 4 / 2);
-  for (int i = 0; i < out_size; ++i) {
-    std::cout << output[i] << " ";
-  }
-  std::cout << "\n";
-  CompareOutputData(output, expect_out, out_size, 0.000001);
-}
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest6) {
+  std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8,
+                              9, 10, 11, 12, 13, 14, 15, 16};
+  size_t kOutSize = 16;
+  std::vector<float> expect_out = {1, 3, 9, 11, 2, 4, 10, 12,
+                                   5, 7, 13, 15, 6, 8, 14, 16};
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 1};
+  std::vector<int> out_shape = {4, 2, 2, 1};
+  SpaceToBatchParameter param;
+  param.block_sizes_[0] = 2;
+  param.block_sizes_[1] = 2;
+  DoSpaceToBatchNHWC(input.data(), out, &param, in_shape.data(), out_shape.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
+}
 
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest7) {
+  std::vector<float> input = {1, 11, 2, 12, 3, 13, 4, 14,
+                              5, 15, 6, 16, 7, 17, 8, 18,
+                              9, 19, 10, 110, 11, 111, 12, 112,
+                              10, 11, 20, 12, 30, 13, 40, 14,
+                              50, 15, 60, 16, 70, 17, 80, 18,
+                              13, 113, 14, 114, 15, 115, 16, 116};
+  size_t kOutSize = 48;
+  std::vector<float> expect_out = {1, 11, 3, 13, 9, 19, 11, 111,
+                                   50, 15, 70, 17, 2, 12, 4, 14,
+                                   10, 110, 12, 112, 60, 16, 80, 18,
+                                   5, 15, 7, 17, 10, 11, 30, 13,
+                                   13, 113, 15, 115, 6, 16, 8, 18,
+                                   20, 12, 40, 14, 14, 114, 16, 116};
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 6, 4, 2};
+  std::vector<int> out_shape = {4, 3, 2, 2};
+  SpaceToBatchParameter param;
+  param.block_sizes_[0] = 2;
+  param.block_sizes_[1] = 2;
+  DoSpaceToBatchNHWC(input.data(), out, &param, in_shape.data(), out_shape.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
+}
 
-TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest2) {
-  SpaceToBatchParameter param;
-  InitSpaceToBatchParameter(&param);
-  float input[16] = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25};
-  const int out_size = 48;
-  float expect_out[48] = {0, 0, 0, 0, 0, 1, 5, 0, 0, 18, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 10, 4, 0,
-                          0, 0, 0, 0, 0, 10, 3, 0, 0, 11, 15, 0, 0, 0, 0, 0, 0, 20, 8, 0, 0, 55, 25, 0};
-  float output[48];
-  int in_shape[4] = {1, 4, 4, 1};
-  int out_shape[4] = {4, 3, 4, 1};
-  int block_sizes[2] = {2, 2};
-
-  float padded_input[48]{}, tmp[48]{}, tmp_zero[48]{};
-  float *tmp_space[3] = {padded_input, tmp, tmp_zero};
-  // DoPadding
-  DoPadding(input, padded_input, param, tmp_space + 1);
-
-  auto ret = SpaceToBatch((const float *)padded_input, output, param, 0, 4 / 2);
-  std::cout << "return " << ret << std::endl;
-  for (int i = 0; i < out_size; ++i) {
-    std::cout << output[i] << " ";
-  }
-  std::cout << "\n";
-  CompareOutputData(output, expect_out, out_size, 0.000001);
-}
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest8) {
+  std::vector<float> input = {1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8,
+                              9, -9, 10, -10, 11, -11, 12, -12, 13, -13, 14, -14, 15, -15, 16, -16};
+  std::vector<float> expect_out = {1, -1, 2, -2, 3, -3, 4, -4, 0, 0, 5, -5, 6, -6, 7, -7, 8, -8, 0, 0,
+                                   9, -9, 10, -10, 11, -11, 12, -12, 0, 0, 13, -13, 14, -14, 15, -15, 16, -16, 0, 0,
+                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  size_t kOutSize = 50;
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 2};
+  std::vector<int> out_shape = {1, 5, 5, 2};
+  std::vector<int> padding = {0, 1, 0, 1};
+  std::vector<float> pedding_h(10, 0);
+  std::vector<float> pedding_w(2, 0);
+  DoSpaceToBatchPaddingNHWC(input.data(), out, in_shape.data(), padding.data(), out_shape.data(), pedding_h.data(),
+                            pedding_w.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
+}
 
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest9) {
+  std::vector<float> input = {1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8,
+                              9, -9, 10, -10, 11, -11, 12, -12, 13, -13, 14, -14, 15, -15, 16, -16};
+  std::vector<float> expect_out = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                   0, 0, 1, -1, 2, -2, 3, -3, 4, -4, 0, 0,
+                                   0, 0, 5, -5, 6, -6, 7, -7, 8, -8, 0, 0,
+                                   0, 0, 9, -9, 10, -10, 11, -11, 12, -12, 0, 0,
+                                   0, 0, 13, -13, 14, -14, 15, -15, 16, -16, 0, 0,
+                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  size_t kOutSize = 72;
+  float out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 2};
+  std::vector<int> out_shape = {1, 6, 6, 2};
+  std::vector<int> padding = {1, 1, 1, 1};
+  std::vector<float> pedding_h(12, 0);
+  std::vector<float> pedding_w(2, 0);
+  DoSpaceToBatchPaddingNHWC(input.data(), out, in_shape.data(), padding.data(), out_shape.data(), pedding_h.data(),
+                            pedding_w.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
+}
 
-TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest3) {
-  SpaceToBatchParameter param;
-  InitSpaceToBatchParameter2(&param);
-  param.op_parameter_.type_ = schema::PrimitiveType_SpaceToBatch;
-
-  std::vector<float> input = {1, 2, 5, 6, 10, 20, 3, 8, 18, 10, 3, 4, 11, 55, 15, 25};
-  std::vector<int> in_shape = {1, 4, 4, 1};
-  lite::tensor::Tensor input_tensor;
-  input_tensor.SetData(input.data());
-  input_tensor.set_shape(in_shape);
-  input_tensor.SetFormat(schema::Format_NHWC);
-  input_tensor.set_data_type(kNumberTypeFloat32);
-  std::vector<lite::tensor::Tensor *> inputs_tensor;
-  inputs_tensor.emplace_back(&input_tensor);
-
-  const int out_size = 48;
-  float expect_out[48] = {0, 0, 0, 0, 0, 1, 5, 0, 0, 18, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 10, 4, 0,
-                          0, 0, 0, 0, 0, 10, 3, 0, 0, 11, 15, 0, 0, 0, 0, 0, 0, 20, 8, 0, 0, 55, 25, 0};
-  std::vector<float> output(48);
-  std::vector<int> out_shape = {4, 3, 4, 1};
-  lite::tensor::Tensor output_tensor;
-  output_tensor.SetData(output.data());
-  output_tensor.set_shape(out_shape);
-  output_tensor.SetFormat(schema::Format_NHWC);
-  output_tensor.set_data_type(kNumberTypeFloat32);
-  std::vector<lite::tensor::Tensor *> outputs_tensor;
-  outputs_tensor.emplace_back(&output_tensor);
-
-  lite::Context ctx;
-  ctx.thread_num_ = 2;
-  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToBatch};
-  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  ASSERT_NE(creator, nullptr);
-  kernel::LiteKernel *kernel =
-    creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&param), &ctx, desc, nullptr);
-  ASSERT_NE(kernel, nullptr);
-  kernel->Run();
-
-  for (int i = 0; i < out_size; ++i) {
-    std::cout << output[i] << " ";
-  }
-  std::cout << "\n";
-  CompareOutputData(output.data(), expect_out, out_size, 0.000001);
-  input_tensor.SetData(nullptr);
-  output_tensor.SetData(nullptr);
-}
+TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest10) {
+  std::vector<float> input = {1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8,
+                              9, -9, 10, -10, 11, -11, 12, -12, 13, -13, 14, -14, 15, -15, 16, -16};
+  std::vector<float> expect_out = {0, 0, 0, 0, 0, 0,
+                                   0, 0, 6, -6, 8, -8,
+                                   0, 0, 14, -14, 16, -16,
+                                   0, 0, 0, 0, 0, 0,
+                                   5, -5, 7, -7, 0, 0,
+                                   13, -13, 15, -15, 0, 0,
+                                   0, 0, 2, -2, 4, -4,
+                                   0, 0, 10, -10, 12, -12,
+                                   0, 0, 0, 0, 0, 0,
+                                   1, -1, 3, -3, 0, 0,
+                                   9, -9, 11, -11, 0, 0,
+                                   0, 0, 0, 0, 0, 0};
+  size_t kOutSize = 72;
+  float out[kOutSize];
+  float pedding_out[kOutSize];
+  std::vector<int> in_shape = {1, 4, 4, 2};
+  std::vector<int> pedding_out_shape = {1, 6, 6, 2};
+  std::vector<int> out_shape = {4, 3, 3, 2};
+  std::vector<int> padding = {1, 1, 1, 1};
+  std::vector<float> pedding_h(12, 0);
+  std::vector<float> pedding_w(2, 0);
+  DoSpaceToBatchPaddingNHWC(input.data(), pedding_out, in_shape.data(), padding.data(), pedding_out_shape.data(),
+                            pedding_h.data(), pedding_w.data());
+  SpaceToBatchParameter param;
+  param.block_sizes_[0] = 2;
+  param.block_sizes_[1] = 2;
+  DoSpaceToBatchNHWC(pedding_out, out, &param, pedding_out_shape.data(), out_shape.data());
+  for (int i = 0; i < kOutSize; ++i) {
+    std::cout << out[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(out, expect_out.data(), kOutSize, 0.000001);
+}
 }  // namespace mindspore
-- 
GitLab