提交 11a2a2a1 编写于 作者: N nhzlx

Merge branch 'hongming/arm-fix' of http://10.87.145.36/inference/paddlelite into xzl/incubate/lite

......@@ -9,6 +9,8 @@ stages:
- build_mobile
check:prebuilt:
tags:
- lite
stage: ci
script:
#- pip3 install pre-commit
......@@ -24,17 +26,21 @@ check:prebuilt:
- /root/.cache
build:server:
tags:
- lite
image: $SERVER_LITE_DOCKER_IMAGE
stage: build_server
cache:
key: server_thirdparty
paths:
- build/third_party
- /root/.ccache
script:
#- export http_proxy=http://172.19.57.45:3128
#- export https_proxy=http://172.19.57.45:3128
- export http_proxy=http://agent.baidu.com:8118
- export https_proxy=http://agent.baidu.com:8118
- apt install ccache
- export http_proxy=http://172.19.57.45:3128
- export https_proxy=http://172.19.57.45:3128
#- export http_proxy=http://agent.baidu.com:8118
#- export https_proxy=http://agent.baidu.com:8118
- mkdir -p build
- cd build
- ../paddle/fluid/lite/tools/build.sh cmake_x86
......@@ -49,6 +55,8 @@ build:server:
- check:prebuilt
build:mobile:
tags:
- lite
stage: build_mobile
image: $MOBILE_LITE_DOCKER_IMAGE
cache:
......@@ -56,7 +64,9 @@ build:mobile:
paths:
- $MOBILE_LITE_CACHE0
- $MOBILE_LITE_CACHE1
- /root/.ccache
script:
- apt install ccache
- export http_proxy=http://172.19.57.45:3128
- export https_proxy=http://172.19.57.45:3128
- ./paddle/fluid/lite/tools/build.sh build_test_arm
......
......@@ -166,6 +166,7 @@ if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
#include(external/zlib) # download, build, install gtest
include(external/protobuf) # download, build, install protobuf
include(external/eigen) # download eigen3
include(ccache) # set ccache for compilation
include(generic) # simplify cmake module
include(configure) # add paddle env configuration
......
......@@ -13,17 +13,25 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <chrono>
#include "paddle/fluid/lite/core/mir/passes.h"
#endif
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
void Run(const char* model_dir) {
using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); }
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
#endif
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)}});
......@@ -32,13 +40,19 @@ void Run(const char* model_dir) {
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({3, 224, 224})));
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 3 * 224 * 224; i++) {
data[i] = i;
for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = 1;
}
predictor.Run();
for (int i = 0; i < 10; i++) predictor.Run();
auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time();
std::cout << " predict cost: " << time_diff(time1, time2) / repeat << "ms"
<< std::endl;
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
......@@ -53,7 +67,7 @@ void Run(const char* model_dir) {
int main(int argc, char** argv) {
CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>";
paddle::lite::Run(argv[1]);
paddle::lite::Run(argv[1], 1);
return 0;
}
......@@ -66,7 +80,7 @@ USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(conv2d);
// USE_LITE_OP(batch_norm);
USE_LITE_OP(batch_norm);
USE_LITE_OP(relu);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(pool2d);
......@@ -85,7 +99,7 @@ USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
......
......@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t vsum0 = vaddq_f32(dinx0, diny0);
float32x4_t vsum1 = vaddq_f32(dinx1, diny1);
float32x4_t vsum2 = vaddq_f32(dinx2, diny2);
float32x4_t vsum3 = vaddq_f32(dinx3, diny3);
dinx0 = vaddq_f32(dinx0, diny0);
dinx1 = vaddq_f32(dinx1, diny1);
dinx2 = vaddq_f32(dinx2, diny2);
dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, vsum0);
vst1q_f32(dout_ptr + 4, vsum1);
vst1q_f32(dout_ptr + 8, vsum2);
vst1q_f32(dout_ptr + 12, vsum3);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
......@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
}
}
template <>
void elementwise_add_axis<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -22,6 +22,10 @@ namespace math {
template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -21,6 +21,7 @@ namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
USE_MIR_PASS(demo);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
......@@ -30,6 +31,7 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
#endif
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
......@@ -18,10 +18,10 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
a = fluid.layers.data(name="a", shape=[100], dtype='float32')
label = fluid.layers.data(name="label", shape=[100], dtype='float32')
a = fluid.layers.data(name="a", shape=[2], dtype='float32')
label = fluid.layers.data(name="label", shape=[10], dtype='float32')
a1 = fluid.layers.fc(input=a, size=500, act=None, bias_attr=False)
a1 = fluid.layers.fc(input=a, size=3, act=None, bias_attr=False)
cost = fluid.layers.square_error_cost(a1, label)
avg_cost = fluid.layers.mean(cost)
......@@ -36,7 +36,7 @@ exe.run(fluid.default_startup_program())
with open('startup_program.pb', 'wb') as f:
f.write(fluid.default_startup_program().desc.serialize_to_string())
data_1 = np.array(numpy.random.random([100, 100]), dtype='float32')
#data_1 = np.array(numpy.random.random([100, 100]), dtype='float32')
#fluid.default_main_program().desc.
......@@ -50,7 +50,7 @@ with open('main_program.pb', 'wb') as f:
#outs = exe.run(program=prog, feed={'a':data_1, }, fetch_list=[cost])
sys.exit(0)
#sys.exit(0)
fluid.io.save_inference_model("./model2", [a.name], [a1], exe)
print(numpy.array(outs))
#print(numpy.array(outs))
......@@ -51,8 +51,8 @@ class Optimizer {
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_fc_fuse_pass", //
"static_kernel_pick_pass", //
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
......
......@@ -100,15 +100,15 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -45,7 +45,7 @@ void conv_compute_ref(const operators::ConvParam& param) {
bias_data = param.bias->mutable_data<float>();
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = false; // TODO(hong19860320) param.relu
bool flag_relu = param.fuse_relu;
int num = input_dims[0];
int chout = output_dims[1];
......@@ -183,7 +183,8 @@ TEST(conv_arm, compute) {
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] =
......@@ -208,7 +209,7 @@ TEST(conv_arm, compute) {
}
param.bias = &bias;
}
// TODO(hong19860320) param.relu = flag_relu;
param.fuse_relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
......
......@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() {
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int n = param.X->dims().production();
lite::arm::math::elementwise_add(x_data, y_data, out_data, n);
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (x_dims.size() == y_dims.size()) {
lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production());
} else {
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch,
channels, num);
}
}
} // namespace arm
......
......@@ -41,40 +41,97 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
const dtype* x_data = param.X->data<const dtype>();
const dtype* y_data = param.Y->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>();
DDim dim = param.X->dims();
ASSERT_EQ(dim.data(), param.Out->dims().data());
for (int i = 0; i < dim.production(); i++) {
out_data[i] = x_data[i] + y_data[i];
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int axis = param.axis;
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
TEST(elementwise_add, compute) {
ElementwiseAddCompute elementwise_add;
operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
lite::Tensor x, y, out, out_ref;
x.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
y.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
out.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
out_ref.Resize(DDim(std::vector<int64_t>({2, 3, 4, 5})));
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* out_data = out.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = y_data[i] = i;
}
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 4, 11}) {
for (auto h : {1, 3, 4, 11}) {
for (auto w : {1, 3, 4, 11}) {
for (auto axis : {-1, 0, 1, 2, 3}) {
for (auto yd :
{std::vector<int64_t>({n}), std::vector<int64_t>({c}),
std::vector<int64_t>({h}), std::vector<int64_t>({w}),
std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
std::vector<int64_t>({h, w}), std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
param.X = &x;
param.Y = &y;
param.Out = &out;
elementwise_add.SetParam(param);
elementwise_add.Run();
if (axis_t + y_dim.size() > 4) continue;
bool flag = false;
for (int i = 0; i < y_dim.size(); i++) {
if (x_dim[i + axis_t] != y_dim[i]) flag = true;
}
if (flag) continue;
param.Out = &out_ref;
elementwise_add_compute_ref<float>(param);
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
x.Resize(x_dim);
y.Resize(y_dim);
output.Resize(x_dim);
output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i;
}
for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i;
}
param.X = &x;
param.Y = &y;
param.axis = axis;
param.Out = &output;
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &output_ref;
elementwise_add_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
......
......@@ -163,7 +163,7 @@ PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(pool, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::PoolCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -272,4 +272,4 @@ TEST(pool, retrive_op) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
......@@ -45,4 +45,6 @@ class ReluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ReluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) {
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", true);
desc.SetAttr("is_test", static_cast<int>(1));
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
......@@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) {
desc.SetOutput("VarianceOut", {"variance_out"});
desc.SetOutput("SavedMean", {"saved_mean"});
desc.SetOutput("SavedVariance", {"saved_variance"});
desc.SetAttr("is_test", false);
desc.SetAttr("is_test", static_cast<int>(0));
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
......
......@@ -56,23 +56,26 @@ class ConvOpLite : public OpLite {
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() != 0) {
if (bias_arguments.size() > 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias = bias_var->GetMutable<lite::Tensor>();
param_.bias =
const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
}
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto res_argument = op_desc.Input("ResidualData");
if (res_argument.size() != 0) {
auto residual_data_var = scope->FindVar(res_argument.front());
auto res_data_arguments = op_desc.Input("ResidualData");
if (res_data_arguments.size() > 0) {
auto residual_data_var = scope->FindVar(res_data_arguments.front());
if (residual_data_var != nullptr) {
param_.residualData = residual_data_var->GetMutable<lite::Tensor>();
param_.residualData = const_cast<lite::Tensor*>(
&(residual_data_var->Get<lite::Tensor>()));
}
}
}
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
return true;
}
......
......@@ -53,17 +53,25 @@ class PoolOpLite : public OpLite {
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.exclusive = op_desc.GetAttr<bool>("exclusive");
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
param_.use_quantizer = op_desc.GetAttr<bool>("use_quantizer");
if (op_desc.HasAttr("exclusive")) {
param_.exclusive = op_desc.GetAttr<bool>("exclusive");
}
if (op_desc.HasAttr("adaptive")) {
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
}
if (op_desc.HasAttr("ceil_mode")) {
param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
}
if (op_desc.HasAttr("use_quantizer")) {
param_.use_quantizer = op_desc.GetAttr<bool>("use_quantizer");
}
// param_.data_format = op_desc.GetAttr<bool>("data_format");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pool"; }
std::string DebugString() const override { return "pool2d"; }
private:
mutable PoolParam param_;
......
......@@ -38,7 +38,7 @@ TEST(pool_op_lite, test) {
// prepare op desc
cpp::OpDesc desc;
desc.SetType("pool");
desc.SetType("pool2d");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
......@@ -69,7 +69,7 @@ TEST(pool_op_lite, test) {
bool use_quantizer{false};
desc.SetAttr("use_quantizer", use_quantizer);
PoolOpLite pool("pool");
PoolOpLite pool("pool2d");
pool.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}});
pool.Attach(desc, &scope);
auto kernels = pool.CreateKernels({Place{TARGET(kARM), PRECISION(kFloat)}});
......@@ -86,5 +86,5 @@ TEST(pool_op_lite, test) {
} // namespace paddle
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
#endif
......@@ -37,7 +37,7 @@ bool SplitOp::InferShape() const {
const auto &sections = param_.sections;
const int outs_number = outs.size();
std::vector<lite::DDimHvy> outs_dims;
std::vector<lite::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册