提交 67a2f0ea 编写于 作者: Z ZhenWang

add int8_t type pooling support.

上级 ca3c0309
......@@ -23,20 +23,22 @@ namespace paddle_mobile {
namespace operators {
using framework::Tensor;
inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
template <typename T, typename S>
void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
if (pooling_type == "max") {
math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
math::MaxPool<float> pool_process;
math::PoolFunctor<CPU, math::MaxPool<T>, T> pool2d_forward;
math::MaxPool<T> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
} else if (pooling_type == "avg") {
math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
math::AvgPool<float> pool_process;
math::PoolFunctor<CPU, math::AvgPool<T, S>, T> pool2d_forward;
math::AvgPool<T, S> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
}
}
template <typename P>
void PoolCompute(const PoolParam<CPU> &param) {
const Tensor *in_x = param.Input();
......@@ -52,50 +54,65 @@ void PoolCompute(const PoolParam<CPU> &param) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "Pool op only supports 2D and 3D input.";
}
if (param.isGlobalPooling()) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
if (in_x->type() == typeid(int8_t)) {
if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) {
if (strides[0] == strides[1] && strides[0] == 1) {
math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]);
} else if (strides[0] == strides[1] && strides[0] == 2) {
math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
math::Pool3x3Max_int8(strides, paddings, in_x, out);
}
} else {
PoolBasic<int8_t, int32_t>(pooling_type, ksize, strides, paddings, in_x,
out);
}
} else {
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
#if __ARM_NEON
#if __aarch64__
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x, out);
#else
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
}
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
}
#endif
#else
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x, out);
#endif // __ARM_NEON
} else {
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
} else {
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
}
}
}
......
......@@ -38,6 +38,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
output->mutable_data<float>();
const int hxw = input_height * input_width;
......@@ -472,7 +473,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>();
float *out_data = output->mutable_data<float>();
const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
......
......@@ -28,15 +28,21 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output);
void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *in_x,
Tensor *out);
void Pool3x3Avgs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Maxs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Max(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Avg(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *in_x, framework::Tensor *out);
void Pool3x3Maxs1_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Maxs2_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Max_int8(const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *input, framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
此差异已折叠。
......@@ -70,15 +70,15 @@ class PoolFunctor<CPU, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_process.initial();
auto ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[ph * output_width + pw] = ele;
pool_process.finalize(static_cast<float>(pool_size), &ele);
output_data[ph * output_width + pw] = static_cast<T>(ele);
}
}
input_data += input_stride;
......@@ -88,8 +88,10 @@ class PoolFunctor<CPU, PoolProcess, T> {
}
};
template class PoolFunctor<CPU, math::AvgPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<float, float>, float>;
template class PoolFunctor<CPU, math::MaxPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<int8_t, int32_t>, int8_t>;
template class PoolFunctor<CPU, math::MaxPool<int8_t>, int8_t>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#pragma once
#include <climits>
#include <cmath>
#include "common/log.h"
#include "framework/tensor.h"
#include "pool_2x2.h"
......@@ -37,24 +39,42 @@ namespace math {
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <class T>
template <typename T>
class MaxPool {
public:
inline T initial() { return static_cast<T>(-FLT_MAX); }
inline T initial() {
if (typeid(T) == typeid(int8_t)) {
return static_cast<T>(-SCHAR_MAX);
}
return static_cast<T>(-FLT_MAX);
}
inline void compute(const T &x, T *y) { *y = *y > x ? *y : x; }
inline void finalize(const T &pool_field, T *y) {}
};
template <class T>
template <typename Itype, typename Otype>
class AvgPool {
public:
inline T initial() { return static_cast<T>(0); }
inline void compute(const T &x, T *y) { *y += x; }
inline void finalize(const T &pool_field, T *y) { *y /= pool_field; }
inline Otype initial() { return static_cast<Otype>(0); }
inline void compute(const Itype &x, Otype *y) { *y += x; }
inline void finalize(const float &pool_field, Otype *y) {
if (typeid(Itype) == typeid(int8_t)) {
float tmp = *y / pool_field;
if (tmp > SCHAR_MAX) {
*y = SCHAR_MAX;
} else if (tmp < -SCHAR_MAX) {
*y = -SCHAR_MAX;
} else {
*y = static_cast<Otype>(std::round(tmp));
}
} else {
*y /= pool_field;
}
}
};
template <typename DeviceType, typename PoolProcess, typename T>
......
......@@ -269,8 +269,8 @@ if (NOT FOUND_MATCH)
#gen test
ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool paddle-mobile)
ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool-op paddle-mobile)
#gen test
ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
......
......@@ -251,7 +251,7 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
attrs["groups"].Set<int>(1);
attrs["axis"].Set<int>(0);
auto *op = new operators::FusionConvAddReluInt8Op<CPU, int8_t>(
auto *op = new operators::FusionConvAddReluInt8Op<CPU, T>(
"fusion_conv_add_relu_int8", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
......
......@@ -12,30 +12,279 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
#include "operators/pool_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_googlenet));
if (program.originProgram == nullptr) {
DLOG << "program read file";
namespace paddle_mobile {
static int PoolOutputSize(int input_size, int filter_size, int padding,
int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
template <typename T>
static void PoolAvgPad0(std::vector<int> ksize, std::vector<int> strides,
const framework::Tensor *input,
framework::Tensor *out) {
const int32_t batch_size = input->dims()[0];
const int32_t input_c = input->dims()[1];
const int32_t input_h = input->dims()[2];
const int32_t input_w = input->dims()[3];
const int32_t out_c = out->dims()[1];
const int32_t out_h = out->dims()[2];
const int32_t out_w = out->dims()[3];
const int32_t kernel_h = ksize[0];
const int32_t kernel_w = ksize[1];
const int32_t stride_h = strides[0];
const int32_t stride_w = strides[1];
const int32_t inputdata_channel_stride = input_h * input_w;
const int32_t input_batch_stride = input_c * inputdata_channel_stride;
const int32_t outputdata_channel_stride = out_h * out_w;
const int32_t output_batch_stride = out_c * outputdata_channel_stride;
T *out_data = out->mutable_data<T>();
const T *input_data = input->data<T>();
const T **rows = new const T *[kernel_h];
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < out_c; ++j) {
const T *img_in = input_data + j * inputdata_channel_stride;
T *img_out = out_data + j * outputdata_channel_stride;
for (int k = 0; k < out_h; ++k) {
for (int m = 0; m < kernel_h; ++m) {
rows[m] = img_in + (stride_h * k + m) * input_w;
}
int32_t left = out_w;
while (left > 0) {
float tmp = 0;
for (int m = 0; m < kernel_h; ++m) {
for (int l = 0; l < kernel_w; ++l) {
tmp += rows[m][l];
}
}
if (typeid(T) == typeid(int8_t)) {
tmp = tmp / (kernel_h * kernel_w);
if (tmp < -127) {
*img_out = -127;
} else if (tmp > 127) {
*img_out = 127;
} else {
*img_out = static_cast<T>(std::round(tmp));
}
} else {
*img_out = static_cast<T>(tmp / (kernel_h * kernel_w));
}
for (int m = 0; m < kernel_h; ++m) {
rows[m] += stride_w;
}
img_out++;
left--;
}
}
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
delete[] rows;
}
template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
int Stride>
int TestPoolOp(int in_channels, int in_height, int in_width) {
int kernel_h = Kernel;
int kernel_w = Kernel;
int pad_h = Pad;
int pad_w = Pad;
int stride_h = Stride;
int stride_w = Stride;
bool ceil_mode = CeilMode != 0;
std::string pooling_type = (PoolType == 0 ? "max" : "avg");
int batch_size = 1;
int input_c = in_channels;
int input_h = in_height;
int input_w = in_width;
framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w});
std::vector<int64_t> output_shape_v({batch_size, input_c});
output_shape_v.push_back(
PoolOutputSize(input_h, kernel_h, pad_h, stride_h, ceil_mode));
output_shape_v.push_back(
PoolOutputSize(input_w, kernel_w, pad_w, stride_w, ceil_mode));
framework::DDim output_shape = framework::make_ddim(output_shape_v);
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::PoolOp<paddle_mobile::CPU, float>>
executor(program, "pool2d");
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 64, 112, 112}, static_cast<float>(0),
static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 56, 56});
auto output =
executor.Predict(input, "conv2d_0.tmp_1", "pool2d_0.tmp_0", out_ddim);
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(input, input_shape, -127, 127);
float *output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["pooling_type"].SetString(pooling_type);
attrs["ksize"].Set<vector<int>>(std::vector<int>({kernel_h, kernel_w}));
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["ceil_mode"].Set<bool>(false);
attrs["global_pooling"].Set<bool>(false);
auto *op = new operators::PoolOp<CPU, float>("pool2d", inputs, outputs, attrs,
scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape);
if (pooling_type == "avg" && pad_h == 0 && pad_h == pad_w) {
PoolAvgPad0<T>(std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, input, &output_cmp);
} else {
if (typeid(T) == typeid(int8_t)) {
operators::PoolBasic<int8_t, int32_t>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
} else {
operators::PoolBasic<float, float>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
}
}
// compare results
int eq = 0;
int neq = 0;
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"The execution of test_pool_op is failed!");
if (output_data[i] == output_cmp_data[i]) {
++eq;
} else {
++neq;
}
}
std::cout << "eq = " << eq << ", neq = " << neq << std::endl;
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
if (argc < 4) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-pool-op in_channels in_height in_width \n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 3, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 4
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册