提交 0d6c897b 编写于 作者: Z zhupengyang

fix elementwise_add unit test

上级 92b6051d
......@@ -41,10 +41,10 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t dinx0 = vaddq_f32(dinx0, diny0);
float32x4_t dinx1 = vaddq_f32(dinx1, diny1);
float32x4_t dinx2 = vaddq_f32(dinx2, diny2);
float32x4_t dinx3 = vaddq_f32(dinx3, diny3);
dinx0 = vaddq_f32(dinx0, diny0);
dinx1 = vaddq_f32(dinx1, diny1);
dinx2 = vaddq_f32(dinx2, diny2);
dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
......@@ -100,10 +100,10 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, diny_data);
din1 = vaddq_f32(din1, diny_data);
vst1q_f32(dout_ptr, r0);
vst1q_f32(dout_ptr + 4, r1);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
......@@ -111,16 +111,16 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, diny_data);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (p = 0; p < remain; p++) {
*dout_ptr = *dinx_ptr + diny_data;
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
dinx_ptr++;
din_ptr++;
}
}
}
......
......@@ -23,8 +23,8 @@ template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_axis<float>(const T* dinx, const T* diny, T* dout,
int batch, int channels, int num);
void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
} // namespace math
} // namespace arm
......
......@@ -100,15 +100,15 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -31,7 +31,7 @@ void ElementwiseAddCompute::Run() {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
if (axis == 0) {
if (x_dims.size() == y_dims.size()) {
lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production());
} else {
......
......@@ -44,6 +44,9 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int axis = param.axis;
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1;
int channels = 1;
int num = 1;
......@@ -61,9 +64,11 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = dout + offset;
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
dout_ptr[k] = din_ptr[k] + diny_data;
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
......@@ -79,18 +84,15 @@ TEST(elementwise_add, compute) {
for (auto h : {1, 3, 4, 11}) {
for (auto w : {1, 3, 4, 11}) {
for (auto axis : {-1, 0, 1, 2, 3}) {
for (auto yd{{n},
{c},
{h},
{w},
{n, c},
{c, h},
{h, w},
{n, c, h},
{c, h, w},
{n, c, h, w}}) {
for (auto yd :
{std::vector<int64_t>({n}), std::vector<int64_t>({c}),
std::vector<int64_t>({h}), std::vector<int64_t>({w}),
std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
std::vector<int64_t>({h, w}), std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(std::vector<int64_t>(yd));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
if (axis_t + y_dim.size() > 4) continue;
......@@ -102,26 +104,27 @@ TEST(elementwise_add, compute) {
x.Resize(x_dim);
y.Resize(y_dim);
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(x_dim);
output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i;
}
for (int i = 0; i < y.dims().production(); i++) {
for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i;
}
param.X = &x;
param.Y = &y;
param.axis = axis;
param.Out = &output;
softmax.SetParam(param);
softmax.Run();
elementwise_add.SetParam(param);
elementwise_add.Run();
param.Out = &output_ref;
elementwise_add_compute_ref<float>(param);
for (int i = 0; i < out.dims().production(); i++) {
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册