Commit 07329c60 authored by ZhenWang

add int8_t type pooling support.

Parent 475d7e91
@@ -23,20 +23,22 @@ namespace paddle_mobile {
namespace operators {
using framework::Tensor;

-inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
-                      std::vector<int> strides, std::vector<int> paddings,
-                      const Tensor *in_x, Tensor *out) {
template <typename T, typename S>
void PoolBasic(std::string pooling_type, std::vector<int> ksize,
               std::vector<int> strides, std::vector<int> paddings,
               const Tensor *in_x, Tensor *out) {
  if (pooling_type == "max") {
-    math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
-    math::MaxPool<float> pool_process;
    math::PoolFunctor<CPU, math::MaxPool<T>, T> pool2d_forward;
    math::MaxPool<T> pool_process;
    pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
  } else if (pooling_type == "avg") {
-    math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
-    math::AvgPool<float> pool_process;
    math::PoolFunctor<CPU, math::AvgPool<T, S>, T> pool2d_forward;
    math::AvgPool<T, S> pool_process;
    pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
  }
}

template <typename P>
void PoolCompute(const PoolParam<CPU> &param) {
  const Tensor *in_x = param.Input();
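Editor's note: `PoolBasic` is now templated on `T` (the tensor element type) and `S` (the type `AvgPool` accumulates into). A minimal call sketch of the two instantiations this commit uses — `in` and `out` are assumed to be already prepared `framework::Tensor` objects, the shapes are illustrative:

  // Hedged sketch, not part of the diff.
  // float path: accumulate averages in float.
  PoolBasic<float, float>("avg", {2, 2}, {2, 2}, {0, 0}, &in, &out);
  // int8 path: sum int8 inputs into int32_t so a large window cannot overflow.
  PoolBasic<int8_t, int32_t>("avg", {7, 7}, {1, 1}, {0, 0}, &in, &out);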
@@ -52,50 +54,65 @@ void PoolCompute(const PoolParam<CPU> &param) {
    LOG(paddle_mobile::LogLevel::kLOG_ERROR)
        << "Pool op only supports 2D and 3D input.";
  }
  if (param.isGlobalPooling()) {
    for (size_t i = 0; i < ksize.size(); ++i) {
      paddings[i] = 0;
      ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
    }
  }
-  if (ksize[0] == 3 && ksize[0] == ksize[1]) {
-    if (pooling_type == "max") {
-      if (strides[0] == strides[1] && strides[0] == 1 &&
-          paddings[0] == paddings[1] && paddings[1] == 1) {
-        math::Pool3x3Maxs1p1(in_x, out);
-      } else {
-        math::Pool3x3Max(strides, paddings, in_x, out);
-      }
-    } else if (pooling_type == "avg") {
-      if (strides[0] == strides[1] && strides[0] == 1 &&
-          paddings[0] == paddings[1] && paddings[1] == 1) {
-        math::Pool3x3Avgs1p1(in_x, out);
-      } else {
-        math::Pool3x3Avg(strides, paddings, in_x, out);
-      }
-    }
  if (in_x->type() == typeid(int8_t)) {
    if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) {
      if (strides[0] == strides[1] && strides[0] == 1) {
        math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]);
      } else if (strides[0] == strides[1] && strides[0] == 2) {
        math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]);
      } else {
        math::Pool3x3Max_int8(strides, paddings, in_x, out);
      }
    } else {
      PoolBasic<int8_t, int32_t>(pooling_type, ksize, strides, paddings, in_x,
                                 out);
    }
  } else {
    if (ksize[0] == 3 && ksize[0] == ksize[1]) {
      if (pooling_type == "max") {
        if (strides[0] == strides[1] && strides[0] == 1 &&
            paddings[0] == paddings[1] && paddings[1] == 1) {
          math::Pool3x3Maxs1p1(in_x, out);
        } else {
          math::Pool3x3Max(strides, paddings, in_x, out);
        }
      } else if (pooling_type == "avg") {
        if (strides[0] == strides[1] && strides[0] == 1 &&
            paddings[0] == paddings[1] && paddings[1] == 1) {
          math::Pool3x3Avgs1p1(in_x, out);
        } else {
          math::Pool3x3Avg(strides, paddings, in_x, out);
        }
      }
-  } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
-             strides[0] == strides[1] && paddings[0] == paddings[1] &&
-             paddings[1] == 0) {
    } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
               strides[0] == strides[1] && paddings[0] == paddings[1] &&
               paddings[1] == 0) {
#if __ARM_NEON
#if __aarch64__
-    PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x, out);
#else
      /// todo: fix bug in Pool2x2
      if (pooling_type == "max") {
        math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
      } else if (pooling_type == "avg") {
        math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
      }
#endif
#else
-    PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x, out);
#endif  // __ARM_NEON
-  } else {
-    PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
-  }
    } else {
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
                              out);
    }
  }
}
......
@@ -38,6 +38,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
  const int input_width = static_cast<int>(input->dims()[3]);
  const int output_height = static_cast<int>(output->dims()[2]);
  const int output_width = static_cast<int>(output->dims()[3]);
  output->mutable_data<float>();
  const int hxw = input_height * input_width;
@@ -472,7 +473,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
  const int inputdata_channel_stride = h_in * w_in;
  const int input_batch_stride = output_channels * inputdata_channel_stride;
  const int output_batch_stride = output_channels * outputdata_channel_stride;
-  float *out_data = output->data<float>();
  float *out_data = output->mutable_data<float>();
  const float *input_data = input->data<float>();
  for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
......
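Editor's note on the two hunks above: `Tensor::data<T>()` assumes the buffer already exists, while `mutable_data<T>()` allocates it on first use, so output tensors must go through `mutable_data` before the first write. A hedged sketch of the pattern (Tensor API as used elsewhere in this commit; the shape is illustrative):

  framework::Tensor out;
  out.Resize(framework::make_ddim({1, 8, 56, 56}));  // illustrative shape
  float *out_data = out.mutable_data<float>();       // allocates the buffer
  // Writing through out.data<float>() without the call above is the bug
  // these hunks fix.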
@@ -28,15 +28,21 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
-using framework::Tensor;
-using std::vector;
-void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
-void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
-void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
-                Tensor *output);
-void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *in_x,
-                Tensor *out);
void Pool3x3Avgs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Maxs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Max(std::vector<int> strides, std::vector<int> paddings,
                const framework::Tensor *input, framework::Tensor *output);

void Pool3x3Avg(std::vector<int> strides, std::vector<int> paddings,
                const framework::Tensor *in_x, framework::Tensor *out);

void Pool3x3Maxs1_int8(const framework::Tensor *input,
                       framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Maxs2_int8(const framework::Tensor *input,
                       framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Max_int8(const std::vector<int> &strides,
                     const std::vector<int> &paddings,
                     const framework::Tensor *input, framework::Tensor *output);
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
......
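For orientation, a hedged call sketch of the int8 entry points declared above (`input` and `output` are assumed int8 NCHW tensors whose shapes the caller has already set):

  // 3x3 max pooling, stride 1, with explicit per-axis paddings:
  math::Pool3x3Maxs1_int8(&input, &output, /*pad_h=*/1, /*pad_w=*/1);
  // 3x3 max pooling, stride 2:
  math::Pool3x3Maxs2_int8(&input, &output, /*pad_h=*/0, /*pad_w=*/0);
  // generic fallback for other 3x3 strides:
  math::Pool3x3Max_int8(/*strides=*/{3, 3}, /*paddings=*/{0, 0}, &input, &output);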
This diff is collapsed.
@@ -70,15 +70,15 @@ class PoolFunctor<CPU, PoolProcess, T> {
          int wend = std::min(wstart + ksize_width, input_width);
          wstart = std::max(wstart, 0);
-          T ele = pool_process.initial();
          auto ele = pool_process.initial();
          for (int h = hstart; h < hend; ++h) {
            for (int w = wstart; w < wend; ++w) {
              pool_process.compute(input_data[h * input_width + w], &ele);
            }
          }
          int pool_size = (hend - hstart) * (wend - wstart);
-          pool_process.finalize(static_cast<T>(pool_size), &ele);
-          output_data[ph * output_width + pw] = ele;
          pool_process.finalize(static_cast<float>(pool_size), &ele);
          output_data[ph * output_width + pw] = static_cast<T>(ele);
        }
      }
      input_data += input_stride;
@@ -88,8 +88,10 @@ class PoolFunctor<CPU, PoolProcess, T> {
  }
};

-template class PoolFunctor<CPU, math::AvgPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<float, float>, float>;
template class PoolFunctor<CPU, math::MaxPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<int8_t, int32_t>, int8_t>;
template class PoolFunctor<CPU, math::MaxPool<int8_t>, int8_t>;

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
......
@@ -16,6 +16,8 @@ limitations under the License. */

#pragma once

#include <climits>
#include <cmath>
#include "common/log.h"
#include "framework/tensor.h"
#include "pool_2x2.h"
@@ -37,24 +39,42 @@ namespace math {
 * in pool pooling, and finally takes the average.
 * MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
 */
-template <class T>
template <typename T>
class MaxPool {
 public:
-  inline T initial() { return static_cast<T>(-FLT_MAX); }
  inline T initial() {
    if (typeid(T) == typeid(int8_t)) {
      return static_cast<T>(-SCHAR_MAX);
    }
    return static_cast<T>(-FLT_MAX);
  }

  inline void compute(const T &x, T *y) { *y = *y > x ? *y : x; }

  inline void finalize(const T &pool_field, T *y) {}
};

-template <class T>
template <typename Itype, typename Otype>
class AvgPool {
 public:
-  inline T initial() { return static_cast<T>(0); }
  inline Otype initial() { return static_cast<Otype>(0); }

-  inline void compute(const T &x, T *y) { *y += x; }
  inline void compute(const Itype &x, Otype *y) { *y += x; }

-  inline void finalize(const T &pool_field, T *y) { *y /= pool_field; }
  inline void finalize(const float &pool_field, Otype *y) {
    if (typeid(Itype) == typeid(int8_t)) {
      float tmp = *y / pool_field;
      if (tmp > SCHAR_MAX) {
        *y = SCHAR_MAX;
      } else if (tmp < -SCHAR_MAX) {
        *y = -SCHAR_MAX;
      } else {
        *y = static_cast<Otype>(std::round(tmp));
      }
    } else {
      *y /= pool_field;
    }
  }
};

template <typename DeviceType, typename PoolProcess, typename T>
......
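A worked example of the int8 `finalize` above, with `Itype = int8_t` and `Otype = int32_t`: a full 3x3 window of inputs at 127 sums to 1143, `1143 / 9.0f = 127.0` lies inside `±SCHAR_MAX`, so the result is `std::round(127.0) = 127`; a quotient outside `±127` would be clamped, keeping results in the symmetric int8 range. `MaxPool<int8_t>::initial()` likewise starts from `-SCHAR_MAX` (-127), the lower bound of that symmetric range. A standalone sketch of the same arithmetic (editor's illustration, not commit code):

  #include <climits>
  #include <cmath>
  #include <cstdint>

  // Mirrors AvgPool<int8_t, int32_t>::finalize: divide the int32 window sum
  // by the (float) window size, round, and saturate to the symmetric range.
  int32_t AvgInt8(int32_t sum, float pool_field) {
    float tmp = sum / pool_field;
    if (tmp > SCHAR_MAX) return SCHAR_MAX;    // clamp to +127
    if (tmp < -SCHAR_MAX) return -SCHAR_MAX;  // clamp to -127
    return static_cast<int32_t>(std::round(tmp));
  }
  // AvgInt8(1143, 9.f) == 127, AvgInt8(-1143, 9.f) == -127.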
@@ -269,8 +269,8 @@ if (NOT FOUND_MATCH)
    #gen test
-    ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
-    target_link_libraries(test-pool paddle-mobile)
    ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
    target_link_libraries(test-pool-op paddle-mobile)

    #gen test
    ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
......
@@ -251,7 +251,7 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
  attrs["groups"].Set<int>(1);
  attrs["axis"].Set<int>(0);

-  auto *op = new operators::FusionConvAddReluInt8Op<CPU, int8_t>(
  auto *op = new operators::FusionConvAddReluInt8Op<CPU, T>(
      "fusion_conv_add_relu_int8", inputs, outputs, attrs, scope);
  op->InferShape();
  op->Init();
......
@@ -12,30 +12,279 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include "../test_include.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
#include "operators/pool_op.h"

-int main() {
-  paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
-  auto program = loader.Load(std::string(g_googlenet));
-  if (program.originProgram == nullptr) {
-    DLOG << "program read file";
-  }
-  Executor4Test<paddle_mobile::CPU,
-                paddle_mobile::operators::PoolOp<paddle_mobile::CPU, float>>
-      executor(program, "pool2d");
-  paddle_mobile::framework::Tensor input;
-  SetupTensor<float>(&input, {1, 64, 112, 112}, static_cast<float>(0),
-                     static_cast<float>(1));
-  auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 56, 56});
-  auto output =
-      executor.Predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim);
-  float *output_ptr = output->data<float>();
-  for (int j = 0; j < output->numel(); ++j) {
-    DLOG << " value of output: " << output_ptr[j];
-  }
-  return 0;
-}

namespace paddle_mobile {
static int PoolOutputSize(int input_size, int filter_size, int padding,
                          int stride, bool ceil_mode) {
  int output_size;
  if (!ceil_mode) {
    output_size = (input_size - filter_size + 2 * padding) / stride + 1;
  } else {
    output_size =
        (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
  }
  return output_size;
}
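// Editor's sanity check (not in the commit): with input_size = 32, filter = 3,
// padding = 1, stride = 2 the formulas above give
//   floor mode: (32 - 3 + 2) / 2 + 1 = 31 / 2 + 1 = 16
//   ceil mode:  (32 - 3 + 2 + 2 - 1) / 2 + 1 = 32 / 2 + 1 = 17
// i.e. PoolOutputSize(32, 3, 1, 2, false) == 16 and
//      PoolOutputSize(32, 3, 1, 2, true) == 17.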
template <typename T>
static void PoolAvgPad0(std::vector<int> ksize, std::vector<int> strides,
const framework::Tensor *input,
framework::Tensor *out) {
const int32_t batch_size = input->dims()[0];
const int32_t input_c = input->dims()[1];
const int32_t input_h = input->dims()[2];
const int32_t input_w = input->dims()[3];
const int32_t out_c = out->dims()[1];
const int32_t out_h = out->dims()[2];
const int32_t out_w = out->dims()[3];
const int32_t kernel_h = ksize[0];
const int32_t kernel_w = ksize[1];
const int32_t stride_h = strides[0];
const int32_t stride_w = strides[1];
const int32_t inputdata_channel_stride = input_h * input_w;
const int32_t input_batch_stride = input_c * inputdata_channel_stride;
const int32_t outputdata_channel_stride = out_h * out_w;
const int32_t output_batch_stride = out_c * outputdata_channel_stride;
T *out_data = out->mutable_data<T>();
const T *input_data = input->data<T>();
const T **rows = new const T *[kernel_h];
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < out_c; ++j) {
const T *img_in = input_data + j * inputdata_channel_stride;
T *img_out = out_data + j * outputdata_channel_stride;
for (int k = 0; k < out_h; ++k) {
for (int m = 0; m < kernel_h; ++m) {
rows[m] = img_in + (stride_h * k + m) * input_w;
}
int32_t left = out_w;
while (left > 0) {
float tmp = 0;
for (int m = 0; m < kernel_h; ++m) {
for (int l = 0; l < kernel_w; ++l) {
tmp += rows[m][l];
}
}
if (typeid(T) == typeid(int8_t)) {
tmp = tmp / (kernel_h * kernel_w);
if (tmp < -127) {
*img_out = -127;
} else if (tmp > 127) {
*img_out = 127;
} else {
*img_out = static_cast<T>(std::round(tmp));
}
} else {
*img_out = static_cast<T>(tmp / (kernel_h * kernel_w));
}
for (int m = 0; m < kernel_h; ++m) {
rows[m] += stride_w;
}
img_out++;
left--;
}
}
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
delete[] rows;
}
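// Editor's usage sketch (not in the commit): PoolAvgPad0 is the reference
// implementation for the zero-padding average cases below, e.g. for an int8
// 1x8x32x32 input with a 7x7 kernel and stride 1:
//   framework::Tensor ref_out;
//   ref_out.mutable_data<int8_t>(framework::make_ddim({1, 8, 26, 26}));
//   PoolAvgPad0<int8_t>({7, 7}, {1, 1}, &input, &ref_out);  // 26 = (32-7)/1+1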
template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
int Stride>
int TestPoolOp(int in_channels, int in_height, int in_width) {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
bool ceil_mode = CeilMode != 0;
std::string pooling_type = (PoolType == 0 ? "max" : "avg");
int batch_size = 1;
int input_c = in_channels;
int input_h = in_height;
int input_w = in_width;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
std::vector<int64_t> output_shape_v({batch_size, input_c});
output_shape_v.push_back(
PoolOutputSize(input_h, kernel_h, pad_h, stride_h, ceil_mode));
output_shape_v.push_back(
PoolOutputSize(input_w, kernel_w, pad_w, stride_w, ceil_mode));
framework::DDim output_shape = framework::make_ddim(output_shape_v);
  VariableNameMap inputs;
  VariableNameMap outputs;
  auto scope = std::make_shared<framework::Scope>();
  inputs["X"] = std::vector<std::string>({"input"});
  outputs["Out"] = std::vector<std::string>({"output"});

  auto input_var = scope.get()->Var("input");
  auto input = input_var->template GetMutable<framework::LoDTensor>();
  SetupTensor<T>(input, input_shape, -127, 127);

  auto output_var = scope.get()->Var("output");

  framework::AttributeMap attrs;
  attrs["pooling_type"].SetString(pooling_type);
attrs["ksize"].Set<vector<int>>(std::vector<int>({kernel_h, kernel_w}));
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["ceil_mode"].Set<bool>(false);
attrs["global_pooling"].Set<bool>(false);
auto *op = new operators::PoolOp<CPU, float>("pool2d", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape);
if (pooling_type == "avg" && pad_h == 0 && pad_h == pad_w) {
PoolAvgPad0<T>(std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, input, &output_cmp);
} else {
if (typeid(T) == typeid(int8_t)) {
operators::PoolBasic<int8_t, int32_t>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
} else {
operators::PoolBasic<float, float>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
}
}
// compare results
int eq = 0;
int neq = 0;
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"The execution of test_pool_op is failed!");
if (output_data[i] == output_cmp_data[i]) {
++eq;
} else {
++neq;
}
  }
  std::cout << "eq = " << eq << ", neq = " << neq << std::endl;
  delete op;
  return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 2>(in_channels, in_height,
in_width);
  // kernel = 3, pad = 2, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 3, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 4
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
}
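Editor's usage note: with the CMake rename above, the harness builds as `test-pool-op` and takes the input shape on the command line (e.g. `./test-pool-op 8 32 32`); each case logs its configuration and then prints the matching/mismatching element counts as `eq`/`neq`.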