提交 cb752a69 编写于 作者: S songhonglei413

modify op_roi_pooling

上级 e6e67c4b
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/roi_pooling.h" #include "src/runtime/kernel/arm/fp32/roi_pooling.h"
#include "src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h"
#include <vector> #include <vector>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
...@@ -35,10 +36,35 @@ int ROIPoolingCPUKernel::Init() { ...@@ -35,10 +36,35 @@ int ROIPoolingCPUKernel::Init() {
return ReSize(); return ReSize();
} }
int ROIPoolingCPUKernel::ReSize() { return RET_OK; } int ROIPoolingCPUKernel::ReSize() {
auto in_shape = in_tensors_.front()->shape();
auto out_shape = out_tensors_.front()->shape();
int ndims = in_shape.size();
if (ndims > 4) {
MS_LOG(ERROR) << "ROIPooling ReSzie error ,shape dim greater than 4!";
return RET_ERROR;
}
param_->ndim_ = ndims;
param_->input_n_ = in_shape[0];
param_->input_h_ = in_shape[1];
param_->input_w_ = in_shape[2];
param_->input_c_ = in_shape[3];
param_->output_n_ = out_shape[0];
param_->output_h_ = out_shape[1];
param_->output_w_ = out_shape[2];
param_->output_c_ = out_shape[3];
param_->in_strides_[ndims - 1] = 1;
param_->out_strides_[ndims - 1] = 1;
for (int i = ndims - 2; i >= 0; --i) {
param_->in_strides_[i] = in_shape[i + 1] * param_->in_strides_[i + 1];
param_->out_strides_[i] = out_shape[i + 1] * param_->out_strides_[i + 1];
}
param_->thread_num_ = MSMIN(param_->op_parameter_.thread_num_, out_shape[0]);
return RET_OK;
}
int ROIPoolingCPUKernel::DoExecute(int task_id) { int ROIPoolingCPUKernel::DoExecute(int task_id) {
auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, in_shape_, out_shape_, dim_, task_id, param_); auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, task_id, param_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]"; MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]";
return ret; return ret;
...@@ -65,11 +91,7 @@ int ROIPoolingCPUKernel::Run() { ...@@ -65,11 +91,7 @@ int ROIPoolingCPUKernel::Run() {
in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->Data()); in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->Data());
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->Data()); out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->Data());
roi_ptr_ = reinterpret_cast<float *>(in_tensors_.at(1)->Data()); roi_ptr_ = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
in_shape_ = reinterpret_cast<const int *>(in_tensors_.front()->shape().data()); ret = LiteBackendParallelLaunch(ROIPoolingRun, this, param_->thread_num_);
out_shape_ = reinterpret_cast<const int *>(out_tensors_.front()->shape().data());
dim_ = in_tensors_.front()->shape().size();
thread_count_ = 1;
ret = LiteBackendParallelLaunch(ROIPoolingRun, this, thread_count_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ROIPooling error: error_code[" << ret << "]"; MS_LOG(ERROR) << "ROIPooling error: error_code[" << ret << "]";
return ret; return ret;
......
...@@ -40,11 +40,7 @@ class ROIPoolingCPUKernel : public LiteKernel { ...@@ -40,11 +40,7 @@ class ROIPoolingCPUKernel : public LiteKernel {
float *in_ptr_; float *in_ptr_;
float *out_ptr_; float *out_ptr_;
float *roi_ptr_; float *roi_ptr_;
const int *in_shape_;
const int *out_shape_;
ROIPoolingParameter *param_; ROIPoolingParameter *param_;
int dim_;
int thread_count_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -16,29 +16,31 @@ ...@@ -16,29 +16,31 @@
#include "nnacl/fp32/roi_pooling.h" #include "nnacl/fp32/roi_pooling.h"
#include <math.h> #include <math.h>
#include <string.h>
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, const int *out_shape, int dim, int tid, int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingParameter *param) {
ROIPoolingParameter *param) { int num_rois = param->output_n_;
int num_rois = out_shape[kNHWC_N]; int units = UP_DIV(num_rois, param->thread_num_);
int batch_size = in_shape[kNHWC_N]; int roi_st = tid * units;
int height_ = in_shape[kNHWC_H]; int roi_end = MSMIN(num_rois, roi_st + units);
int width_ = in_shape[kNHWC_W]; if (roi_st >= num_rois) {
int channels_ = in_shape[kNHWC_C]; return NNACL_OK;
}
int batch_size = param->input_n_;
int height_ = param->input_h_;
int width_ = param->input_w_;
int channels_ = param->input_c_;
int scale = param->scale_; int scale = param->scale_;
int pooled_height = param->pooledH_; int pooled_height = param->pooledH_;
int pooled_width = param->pooledW_; int pooled_width = param->pooledW_;
int in_stride[DIMENSION_4D]; int *in_strides = &(param->in_strides_);
int out_stride[DIMENSION_4D]; int *out_strides = &(param->out_strides_);
const int roi_stride = 5; int roi_stride = 5;
in_stride[DIMENSION_4D - 1] = 1; int roi_ind_st = roi_st * roi_stride;
out_stride[DIMENSION_4D - 1] = 1; float *max_c = malloc(channels_ * sizeof(float));
for (int i = dim - 2; i >= 0; --i) { for (int i = roi_st; i < roi_end; ++i) {
in_stride[i] = in_stride[i + 1] * in_shape[i + 1];
out_stride[i] = out_stride[i + 1] * out_shape[i + 1];
}
int roi_ind_st = 0;
for (int i = 0; i < num_rois; ++i) {
int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index
if (roi_batch_ind >= batch_size) { if (roi_batch_ind >= batch_size) {
return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
...@@ -53,44 +55,46 @@ int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, c ...@@ -53,44 +55,46 @@ int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, c
float bin_size_h = (float)roi_height / (float)pooled_height; float bin_size_h = (float)roi_height / (float)pooled_height;
float bin_size_w = (float)roi_width / (float)pooled_width; float bin_size_w = (float)roi_width / (float)pooled_width;
float *batch_data = in_ptr + in_stride[kNHWC_N] * roi_batch_ind; float *batch_data = in_ptr + in_strides[kNHWC_N] * roi_batch_ind;
int out_ind = i * out_stride[0];
for (int c = kNHWC_N; c < channels_; ++c) {
float max_v = -__FLT_MAX__;
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int pooled_index =
i * out_stride[kNHWC_N] + ph * out_stride[kNHWC_H] + pw * out_stride[kNHWC_W] + c * out_stride[kNHWC_C];
int hstart = (int)floorf(ph * bin_size_h); // block xi_1
int wstart = (int)floorf(pw * bin_size_w); // block yi_1
int hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2
int wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2
hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_); int out_ind = i * out_strides[0];
hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_); for (int ph = 0; ph < pooled_height; ++ph) {
wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_); for (int pw = 0; pw < pooled_width; ++pw) {
wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_); int hstart = (int)floorf(ph * bin_size_h); // block xi_1
int wstart = (int)floorf(pw * bin_size_w); // block yi_1
int hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2
int wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2
hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_);
hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_);
wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_);
wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_);
for (int j = 0; j < channels_; ++j) {
max_c[j] = -__FLT_MAX__;
bool is_empty = (hend <= hstart) || (wend <= wstart); bool is_empty = (hend <= hstart) || (wend <= wstart);
if (is_empty) { if (is_empty) {
max_v = 0; max_c[j] = 0;
} }
int bd_index = c * in_stride[kNHWC_C] + hstart * in_stride[kNHWC_H]; }
for (int h = hstart; h < hend; ++h) { int pooled_index = i * out_strides[0] + ph * out_strides[1] + pw * out_strides[2];
int wi = bd_index + wstart * in_stride[kNHWC_W]; int bd_index = hstart * in_strides[1];
for (int w = wstart; w < wend; ++w) { for (int h = hstart; h < hend; ++h) {
max_v = MSMAX(batch_data[wi], max_v); int wi = bd_index + wstart * in_strides[2];
// printf("bd:index: %d, data: %f, max_v: %f\n",wi,batch_data[wi],max_v); for (int w = wstart; w < wend; ++w) {
wi += in_stride[kNHWC_W]; for (int c = 0; c < channels_; ++c) {
max_c[c] = MSMAX(batch_data[wi + c], max_c[c]);
} }
bd_index += in_stride[kNHWC_H]; wi += in_strides[2];
} } // in_w end;
out_ptr[pooled_index] = max_v; bd_index += in_strides[1];
} // in_h end
for (int j = 0; j < channels_; ++j) {
out_ptr[pooled_index + j] = max_c[j];
} }
} }
} }
roi_ind_st += roi_stride; roi_ind_st += roi_stride;
} }
free(max_c);
return NNACL_OK; return NNACL_OK;
} }
...@@ -20,16 +20,27 @@ ...@@ -20,16 +20,27 @@
typedef struct ROIPoolingParameter { typedef struct ROIPoolingParameter {
OpParameter op_parameter_; OpParameter op_parameter_;
int in_strides_[DIMENSION_4D];
int out_strides_[DIMENSION_4D];
float scale_;
int ndim_;
int input_w_;
int input_h_;
int input_n_;
int input_c_;
int output_w_;
int output_h_;
int output_n_;
int output_c_;
int thread_num_;
int pooledW_; int pooledW_;
int pooledH_; int pooledH_;
float scale_;
} ROIPoolingParameter; } ROIPoolingParameter;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, const int *out_shape, int dim, int tid, int ROIPooling(float *in_ptr, float *out_ptr, float *roi, int tid, ROIPoolingParameter *param);
ROIPoolingParameter *param);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -57,10 +57,10 @@ TEST_F(TestROIPoolingFp32, Simple) { ...@@ -57,10 +57,10 @@ TEST_F(TestROIPoolingFp32, Simple) {
param->pooledH_ = 2; param->pooledH_ = 2;
float a[] = {1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35, float a[] = {1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35,
1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35}; 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35};
float b[] = {0, 1, 1, 3, 4, 1, 1, 1, 3, 4}; float b[] = {0, 1, 1, 3, 4};
std::vector<int> a_shape = {2, 4, 5, 1}; std::vector<int> a_shape = {1, 4, 5, 2};
std::vector<int> b_shape = {2, 5}; std::vector<int> b_shape = {2, 5};
std::vector<int> c_shape = {2, 2, 2, 1}; std::vector<int> c_shape = {1, 2, 2, 2};
int total_size = ROIPoolingTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); int total_size = ROIPoolingTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
auto ctx = new lite::Context; auto ctx = new lite::Context;
ctx->thread_num_ = 3; ctx->thread_num_ = 3;
...@@ -68,7 +68,7 @@ TEST_F(TestROIPoolingFp32, Simple) { ...@@ -68,7 +68,7 @@ TEST_F(TestROIPoolingFp32, Simple) {
new kernel::ROIPoolingCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr); new kernel::ROIPoolingCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
op->Init(); op->Init();
op->Run(); op->Run();
float correct[] = {23, 25, 33, 35, 23, 25, 33, 35}; float correct[] = {25, 31, 34, 35, 25, 31, 34, 35};
float *output = reinterpret_cast<float *>(outputs_[0]->Data()); float *output = reinterpret_cast<float *>(outputs_[0]->Data());
for (int i = 0; i < 8; ++i) printf("%f ", output[i]); for (int i = 0; i < 8; ++i) printf("%f ", output[i]);
printf("\n"); printf("\n");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册