提交 54f421c6 编写于 作者: H hong19860320

fix ARM kernel for pool2d op

test=develop
上级 ee0e12fe
因为 它太大了无法显示 source diff 。你可以改为 查看blob
......@@ -25,7 +25,7 @@ namespace arm {
namespace math {
// !pooling fp32 Op
void pooling_basic(const void* din, void* dout, int num, int chout, int hout,
void pooling_basic(const float* din, float* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
......@@ -33,77 +33,39 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling_global(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling_global_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling_global_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive);
void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout,
void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
bool exclusive);
void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout,
void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
bool exclusive);
void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout,
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
bool exclusive);
} // namespace math
} // namespace arm
......
......@@ -48,120 +48,96 @@ void PoolCompute::Run() {
bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format;
if (param.global_pooling) {
bool kps_equal = (ksize[0] == ksize[1]) && (strides[0] == strides[1]) &&
(paddings[0] == paddings[1]);
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_dims[i + 2]);
}
}
#if 0
for (int i = 0; i < in_dims.size(); ++i) {
LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i];
}
for (int i = 0; i < out_dims.size(); ++i) {
LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i];
}
for (int i = 0; i < ksize.size(); ++i) {
LOG(INFO) << "ksize[" << i << "]:" << ksize[i];
}
for (int i = 0; i < strides.size(); ++i) {
LOG(INFO) << "strides[" << i << "]:" << strides[i];
}
for (int i = 0; i < paddings.size(); ++i) {
LOG(INFO) << "paddings[" << i << "]:" << paddings[i];
}
LOG(INFO) << "global_pooling:" << global_pooling;
LOG(INFO) << "exclusive:" << exclusive;
LOG(INFO) << "adaptive:" << adaptive;
LOG(INFO) << "ceil_mode:" << ceil_mode;
LOG(INFO) << "use_quantizer:" << use_quantizer;
LOG(INFO) << "data_format:" << data_format;
LOG(INFO) << "din:" << din;
LOG(INFO) << "dout:" << dout;
#endif
// global
if (global_pooling == true) {
lite::arm::math::pooling_global(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1]) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
lite::arm::math::pooling_global_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling_global_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 &&
strides[0] == strides[1] && paddings[0] == 1) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == 0) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == 1) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
lite::arm::math::pooling_global_avg(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling_global_ave";
return;
}
} else {
lite::arm::math::pooling_basic(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling2x2s2_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3], exclusive);
VLOG(3) << "invoking pooling2x2s2_avg";
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3]);
VLOG(3) << "invokingpooling3x3s1p1_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], exclusive);
VLOG(3) << "invoking pooling3x3s1p1_avg";
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3]);
VLOG(3) << "pooling3x3s2p0_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], exclusive);
VLOG(3) << "invoking pooling3x3s2p0_avg";
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, dout, out_dims[0], out_dims[1],
out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3]);
VLOG(3) << "invoking pooling3x3s2p1_max";
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_avg(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], exclusive);
VLOG(3) << "invoking pooling3x3s2p1_avg";
return;
}
}
}
return;
lite::arm::math::pooling_basic(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], in_dims[1],
in_dims[2], in_dims[3], ksize, strides, paddings, global_pooling,
exclusive, adaptive, ceil_mode, use_quantizer, pooling_type);
VLOG(3) << "invoking pooling_basic";
}
TargetType PoolCompute::target() const { return TARGET(kARM); }
PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -29,9 +29,6 @@ class PoolCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void PrepareForRun() override;
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~PoolCompute() = default;
};
......
......@@ -101,94 +101,65 @@ void pool_compute_ref(const operators::PoolParam& param) {
int pad_w = paddings[1];
if (global_pooling == true) {
ksize[0] = in_h;
ksize[1] = in_w;
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * in_c * in_h * in_w + c * in_h * in_w;
const float* src = src_ptr + n * size_in_n + c * size_in_c;
float res = src[0];
if (pooling_type == "max") {
for (int i = 1; i < in_h * in_w; ++i) {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res = cur_val > res ? cur_val : res;
}
} else if (pooling_type == "avg") {
for (int i = 1; i < in_h * in_w; ++i) {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res += cur_val;
}
res /= (in_h * in_w);
res /= size_in_c;
}
dst_ptr[n * in_c * out_h * out_w + c] = res;
dst_ptr[n * size_out_n + c] = res;
}
}
return;
}
for (int ind_n = 0; ind_n < in_n; ++ind_n) {
for (int ind_c = 0; ind_c < in_c; ++ind_c) {
for (int ind_h = 0; ind_h < out_h; ++ind_h) {
int sh = ind_h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int ind_w = 0; ind_w < out_w; ++ind_w) {
int sw = ind_w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
float result = static_cast<float>(0);
int dst_ind =
ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_ind =
ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
result = src_ptr[src_ind];
} else {
if (pooling_type == "max") {
result =
result >= src_ptr[src_ind] ? result : src_ptr[src_ind];
}
if (pooling_type == "avg" && exclusive == false) {
// Pooling_average_include_padding
result += src_ptr[src_ind];
}
if (pooling_type == "avg" && exclusive == true) {
// Pooling_average_include_padding
result += src_ptr[src_ind];
} else {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
for (int h = 0; h < out_h; ++h) {
int sh = h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int w = 0; w < out_w; ++w) {
int sw = w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
int pooling_size = (ew - sw) * (eh - sh);
if (pooling_size == 0) continue;
float res = 0.f;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
res = src_ptr[src_idx];
} else {
if (pooling_type == "max") {
res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx];
}
if (pooling_type == "avg") {
res += src_ptr[src_idx];
}
}
}
}
}
if (pooling_type == "avg" && exclusive == false) {
// Pooling_average_include_padding
// result /= param.window_h * param.window_w;
// LOG(ERROR)<<"cpu"<<param.window_h * param.window_w;
int bh = window_h;
int bw = window_w;
if (ew == in_w) {
bw = sw + window_w >= in_w + pad_w ? in_w + pad_w : sw + window_w;
bw -= sw;
}
if (eh == in_h) {
bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h;
bh -= sh;
if (pooling_type == "avg") {
if (exclusive) {
res /= pooling_size;
} else {
res /= window_h * window_w;
}
}
result /= bh * bw;
}
if (pooling_type == "avg" && exclusive == true) {
// Pooling_average_exclude_padding
result /= (ew - sw) * (eh - sh);
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res;
}
dst_ptr[dst_ind] = result;
}
}
}
......@@ -209,92 +180,96 @@ TEST(pool_arm, compute) {
lite::Tensor output;
lite::Tensor output_ref;
for (auto pooling_type : {"avg", "max"}) {
for (auto global_pooling : {true}) {
// for (auto ksize: {3}) { // TODO(yuanshuai): ksize enable 2, 3
for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) {
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11 /* ,1024 */}) { // speedup for ci
for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) {
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w // << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " pooling_type:" << pooling_type
<< " global_pooling:" << global_pooling;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); ++i) {
x_data[i] = i;
}
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
// param.ksize = {ksize, ksize}; //TODO(yuanshuai): ksize
// enable
param.ksize = {h, w};
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
param.exclusive = true;
param.adaptive = false;
param.ceil_mode = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
// compute
pool.SetParam(param);
pool.Run();
// compute ref
param.output = &output_ref;
pool_compute_ref(param);
// compare
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
float tmp = output_data[i] - output_ref_data[i];
tmp = tmp < 0 ? -tmp : tmp;
if (tmp > 1e-5) {
std::cout << "output_data[0]:" << output_data[0]
<< std::endl;
std::cout << "output_ref_data[0]:" << output_ref_data[0]
<< std::endl;
std::cout
<< "x.dims().production():" << x.dims().production()
<< std::endl;
for (int ii = 0; ii < x.dims().production(); ++ii) {
std::cout << x_data[ii] << " ";
// speedup for ci
for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {true, false}) {
for (auto global_pooling : {true, false}) {
for (auto exclusive : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) {
for (auto pad : {0, 1}) {
for (auto n : {1, 2}) {
for (auto c : {1, 3}) {
#if 1
for (auto h : {2, 3, 4, 11}) {
for (auto w : {2, 3, 4, 11}) {
#else
for (int h = 2; h < 25; h++) {
for (int w = 2; w < 25; w++) {
#endif
VLOG(3) << "n:" << n << " c:" << c << " h:" << h
<< " w:" << w << " ksize:" << ksize
<< " stride:" << stride << " pad:" << pad
<< " exclusive:" << exclusive
<< " global_pooling:" << global_pooling
<< " ceil_mode: " << ceil_mode
<< " pooling_type:" << pooling_type;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); ++i) {
float sign = i % 3 == 0 ? -0.03 : 0.05f;
x_data[i] = sign * (i % 128);
}
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
if (global_pooling) {
param.ksize = {h, w};
} else {
param.ksize = {ksize, ksize};
}
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
param.exclusive = exclusive;
param.ceil_mode = ceil_mode;
param.adaptive = false;
param.use_quantizer = false;
const std::vector<int64_t>& output_shape =
compute_output_shape(&param);
output.Resize(DDim(output_shape));
output_ref.Resize(DDim(output_shape));
auto* output_data = output.mutable_data<float>();
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); ++i) {
output_data[i] = -2;
output_ref_data[i] = -2;
}
// compute
pool.SetParam(param);
pool.Run();
// compute ref
param.output = &output_ref;
pool_compute_ref(param);
// compare
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
}
VLOG(3) << "compare pass";
}
std::cout;
exit(0);
}
}
VLOG(3) << "compare pass";
}
}
}
}
} // pad
} // stride
//} // ksize TODO(yuanshuai): ksize enable
} // global_pooling
} // pooling_type
}
}
}
}
}
TEST(pool, retrive_op) {
TEST(pool_arm, retrive_op) {
auto pool = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"pool2d");
ASSERT_FALSE(pool.empty());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册