提交 cdd63eb4 编写于 作者: H Hong Ming

fix ARM kernel of pool2d op for armv7/v8

因为 它太大了无法显示 source diff 。你可以改为 查看blob
...@@ -25,7 +25,7 @@ namespace arm { ...@@ -25,7 +25,7 @@ namespace arm {
namespace math { namespace math {
// !pooling fp32 Op // !pooling fp32 Op
void pooling_basic(const void* din, void* dout, int num, int chout, int hout, void pooling_basic(const float* din, float* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win, int wout, int chin, int hin, int win,
const std::vector<int>& ksize, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -33,77 +33,39 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout, ...@@ -33,77 +33,39 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout,
bool exclusive, bool adaptive, bool ceil_mode, bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type); bool use_quantizer, const std::string& pooling_type);
void pooling_global(const void* din, void* dout, int num, int chout, int hout, void pooling_global_max(const float* din, float* dout, int num, int chout,
int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win);
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, void pooling_global_avg(const float* din, float* dout, int num, int chout,
int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win);
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win);
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize, bool exclusive);
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win);
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize, bool exclusive);
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win);
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize, bool exclusive);
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win, int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize, bool exclusive);
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
...@@ -48,120 +48,96 @@ void PoolCompute::Run() { ...@@ -48,120 +48,96 @@ void PoolCompute::Run() {
bool use_quantizer = param.use_quantizer; bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format; std::string& data_format = param.data_format;
if (param.global_pooling) { bool kps_equal = (ksize[0] == ksize[1]) && (strides[0] == strides[1]) &&
(paddings[0] == paddings[1]);
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0; paddings[i] = 0;
ksize[i] = static_cast<int>(in_dims[i + 2]); ksize[i] = static_cast<int>(in_dims[i + 2]);
} }
if (pooling_type == "max") {
lite::arm::math::pooling_global_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling_global_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling_global_avg(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling_global_ave";
return;
} }
} else {
#if 0 if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
for (int i = 0; i < in_dims.size(); ++i) {
LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i];
}
for (int i = 0; i < out_dims.size(); ++i) {
LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i];
}
for (int i = 0; i < ksize.size(); ++i) {
LOG(INFO) << "ksize[" << i << "]:" << ksize[i];
}
for (int i = 0; i < strides.size(); ++i) {
LOG(INFO) << "strides[" << i << "]:" << strides[i];
}
for (int i = 0; i < paddings.size(); ++i) {
LOG(INFO) << "paddings[" << i << "]:" << paddings[i];
}
LOG(INFO) << "global_pooling:" << global_pooling;
LOG(INFO) << "exclusive:" << exclusive;
LOG(INFO) << "adaptive:" << adaptive;
LOG(INFO) << "ceil_mode:" << ceil_mode;
LOG(INFO) << "use_quantizer:" << use_quantizer;
LOG(INFO) << "data_format:" << data_format;
LOG(INFO) << "din:" << din;
LOG(INFO) << "dout:" << dout;
#endif
// global
if (global_pooling == true) {
lite::arm::math::pooling_global(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1]) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max( lite::arm::math::pooling2x2s2_max(din, dout, out_dims[0], out_dims[1],
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], out_dims[2], out_dims[3], in_dims[1],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[2], in_dims[3]);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling2x2s2_max";
pooling_type); return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_ave( lite::arm::math::pooling2x2s2_avg(din, dout, out_dims[0], out_dims[1],
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], out_dims[2], out_dims[3], in_dims[1],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[2], in_dims[3], exclusive);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling2x2s2_avg";
pooling_type); return;
} }
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
strides[0] == strides[1] && paddings[0] == 1) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max( lite::arm::math::pooling3x3s1p1_max(din, dout, out_dims[0], out_dims[1],
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3]);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invokingpooling3x3s1p1_max";
pooling_type); return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_ave( lite::arm::math::pooling3x3s1p1_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3], exclusive);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling3x3s1p1_avg";
pooling_type); return;
} }
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
strides[0] == strides[1] && paddings[0] == 0) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max( lite::arm::math::pooling3x3s2p0_max(din, dout, out_dims[0], out_dims[1],
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3]);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "pooling3x3s2p0_max";
pooling_type); return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_ave( lite::arm::math::pooling3x3s2p0_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3], exclusive);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling3x3s2p0_avg";
pooling_type); return;
} }
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
strides[0] == strides[1] && paddings[0] == 1) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max( lite::arm::math::pooling3x3s2p1_max(din, dout, out_dims[0], out_dims[1],
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3]);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling3x3s2p1_max";
pooling_type); return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_ave( lite::arm::math::pooling3x3s2p1_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, in_dims[1], in_dims[2], in_dims[3], exclusive);
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, VLOG(3) << "invoking pooling3x3s2p1_avg";
pooling_type); return;
} }
} else {
lite::arm::math::pooling_basic(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} }
return; }
lite::arm::math::pooling_basic(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3], ksize, strides, paddings, global_pooling,
exclusive, adaptive, ceil_mode, use_quantizer, pooling_type);
VLOG(3) << "invoking pooling_basic";
} }
TargetType PoolCompute::target() const { return TARGET(kARM); }
PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -29,9 +29,6 @@ class PoolCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -29,9 +29,6 @@ class PoolCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void PrepareForRun() override; void PrepareForRun() override;
void Run() override; void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~PoolCompute() = default; virtual ~PoolCompute() = default;
}; };
......
...@@ -101,94 +101,65 @@ void pool_compute_ref(const operators::PoolParam& param) { ...@@ -101,94 +101,65 @@ void pool_compute_ref(const operators::PoolParam& param) {
int pad_w = paddings[1]; int pad_w = paddings[1];
if (global_pooling == true) { if (global_pooling == true) {
ksize[0] = in_h;
ksize[1] = in_w;
for (int n = 0; n < in_n; ++n) { for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) { for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * in_c * in_h * in_w + c * in_h * in_w; const float* src = src_ptr + n * size_in_n + c * size_in_c;
float res = src[0]; float res = src[0];
if (pooling_type == "max") { if (pooling_type == "max") {
for (int i = 1; i < in_h * in_w; ++i) { for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i]; float cur_val = src[i];
res = cur_val > res ? cur_val : res; res = cur_val > res ? cur_val : res;
} }
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
for (int i = 1; i < in_h * in_w; ++i) { for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i]; float cur_val = src[i];
res += cur_val; res += cur_val;
} }
res /= (in_h * in_w); res /= size_in_c;
}
dst_ptr[n * in_c * out_h * out_w + c] = res;
} }
dst_ptr[n * size_out_n + c] = res;
} }
return;
} }
} else {
for (int ind_n = 0; ind_n < in_n; ++ind_n) { for (int n = 0; n < in_n; ++n) {
for (int ind_c = 0; ind_c < in_c; ++ind_c) { for (int c = 0; c < in_c; ++c) {
for (int ind_h = 0; ind_h < out_h; ++ind_h) { for (int h = 0; h < out_h; ++h) {
int sh = ind_h * stride_h; int sh = h * stride_h;
int eh = sh + window_h; int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int w = 0; w < out_w; ++w) {
for (int ind_w = 0; ind_w < out_w; ++ind_w) { int sw = w * stride_w;
int sw = ind_w * stride_w;
int ew = sw + window_w; int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
int pooling_size = (ew - sw) * (eh - sh);
float result = static_cast<float>(0); if (pooling_size == 0) continue;
float res = 0.f;
int dst_ind =
ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w;
for (int kh = sh; kh < eh; ++kh) { for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) { for (int kw = sw; kw < ew; ++kw) {
int src_ind = int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw;
ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) { if (kh == sh && kw == sw) {
result = src_ptr[src_ind]; res = src_ptr[src_idx];
} else { } else {
if (pooling_type == "max") { if (pooling_type == "max") {
result = res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx];
result >= src_ptr[src_ind] ? result : src_ptr[src_ind];
}
if (pooling_type == "avg" && exclusive == false) {
// Pooling_average_include_padding
result += src_ptr[src_ind];
} }
if (pooling_type == "avg" && exclusive == true) { if (pooling_type == "avg") {
// Pooling_average_include_padding res += src_ptr[src_idx];
result += src_ptr[src_ind];
} }
} }
} }
} }
if (pooling_type == "avg" && exclusive == false) { if (pooling_type == "avg") {
// Pooling_average_include_padding if (exclusive) {
// result /= param.window_h * param.window_w; res /= pooling_size;
// LOG(ERROR)<<"cpu"<<param.window_h * param.window_w; } else {
int bh = window_h; res /= window_h * window_w;
int bw = window_w;
if (ew == in_w) {
bw = sw + window_w >= in_w + pad_w ? in_w + pad_w : sw + window_w;
bw -= sw;
}
if (eh == in_h) {
bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h;
bh -= sh;
} }
result /= bh * bw;
} }
if (pooling_type == "avg" && exclusive == true) { dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res;
// Pooling_average_exclude_padding
result /= (ew - sw) * (eh - sh);
} }
dst_ptr[dst_ind] = result;
} }
} }
} }
...@@ -209,41 +180,54 @@ TEST(pool_arm, compute) { ...@@ -209,41 +180,54 @@ TEST(pool_arm, compute) {
lite::Tensor output; lite::Tensor output;
lite::Tensor output_ref; lite::Tensor output_ref;
for (auto pooling_type : {"avg", "max"}) { // speedup for ci
for (auto global_pooling : {true}) { for (auto pooling_type : {"max", "avg"}) {
// for (auto ksize: {3}) { // TODO(yuanshuai): ksize enable 2, 3 for (auto ceil_mode : {true, false}) {
for (auto global_pooling : {true, false}) {
for (auto exclusive : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) { for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) { for (auto pad : {0, 1}) {
for (auto n : {1, 3, 4, 11}) { for (auto n : {1, 2}) {
for (auto c : {1, 3, 11 /* ,1024 */}) { // speedup for ci for (auto c : {1, 3}) {
#if 1
for (auto h : {2, 3, 4, 11}) { for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) { for (auto w : {2, 3, 4, 11}) {
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h #else
<< " w:" << w // << " ksize:" << ksize for (int h = 2; h < 25; h++) {
for (int w = 2; w < 25; w++) {
#endif
VLOG(3) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad << " stride:" << stride << " pad:" << pad
<< " pooling_type:" << pooling_type << " exclusive:" << exclusive
<< " global_pooling:" << global_pooling; << " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output // init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>(); auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); ++i) { for (int i = 0; i < x.dims().production(); ++i) {
x_data[i] = i; float sign = i % 3 == 0 ? -0.03 : 0.05f;
x_data[i] = sign * (i % 128);
} }
// fill param // fill param
param.x = &x; param.x = &x;
param.output = &output; param.output = &output;
param.pooling_type = pooling_type; param.pooling_type = pooling_type;
// param.ksize = {ksize, ksize}; //TODO(yuanshuai): ksize if (global_pooling) {
// enable
param.ksize = {h, w}; param.ksize = {h, w};
} else {
param.ksize = {ksize, ksize};
}
param.global_pooling = global_pooling; param.global_pooling = global_pooling;
param.strides = {stride, stride}; param.strides = {stride, stride};
param.paddings = {pad, pad}; param.paddings = {pad, pad};
param.exclusive = true; param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false; param.adaptive = false;
param.ceil_mode = false;
param.use_quantizer = false; param.use_quantizer = false;
const std::vector<int64_t>& output_shape = const std::vector<int64_t>& output_shape =
...@@ -251,6 +235,14 @@ TEST(pool_arm, compute) { ...@@ -251,6 +235,14 @@ TEST(pool_arm, compute) {
output.Resize(DDim(output_shape)); output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape)); output_ref.Resize(DDim(output_shape));
auto* output_data = output.mutable_data<float>();
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); ++i) {
output_data[i] = -2;
output_ref_data[i] = -2;
}
// compute // compute
pool.SetParam(param); pool.SetParam(param);
pool.Run(); pool.Run();
...@@ -260,41 +252,24 @@ TEST(pool_arm, compute) { ...@@ -260,41 +252,24 @@ TEST(pool_arm, compute) {
pool_compute_ref(param); pool_compute_ref(param);
// compare // compare
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
float tmp = output_data[i] - output_ref_data[i];
tmp = tmp < 0 ? -tmp : tmp;
if (tmp > 1e-5) {
std::cout << "output_data[0]:" << output_data[0]
<< std::endl;
std::cout << "output_ref_data[0]:" << output_ref_data[0]
<< std::endl;
std::cout
<< "x.dims().production():" << x.dims().production()
<< std::endl;
for (int ii = 0; ii < x.dims().production(); ++ii) {
std::cout << x_data[ii] << " ";
}
std::cout;
exit(0);
}
} }
VLOG(3) << "compare pass"; VLOG(3) << "compare pass";
} }
} }
} }
} }
} // pad }
} // stride }
//} // ksize TODO(yuanshuai): ksize enable }
} // global_pooling }
} // pooling_type }
}
}
} }
TEST(pool, retrive_op) { TEST(pool_arm, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( auto pool = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"pool2d"); "pool2d");
ASSERT_FALSE(pool.empty()); ASSERT_FALSE(pool.empty());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册