提交 0d6c897b 编写于 作者: Z zhupengyang

fix elementwise_add unit test

上级 92b6051d
...@@ -41,10 +41,10 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout, ...@@ -41,10 +41,10 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t diny2 = vld1q_f32(diny_ptr + 8); float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12); float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
float32x4_t dinx0 = vaddq_f32(dinx0, diny0); dinx0 = vaddq_f32(dinx0, diny0);
float32x4_t dinx1 = vaddq_f32(dinx1, diny1); dinx1 = vaddq_f32(dinx1, diny1);
float32x4_t dinx2 = vaddq_f32(dinx2, diny2); dinx2 = vaddq_f32(dinx2, diny2);
float32x4_t dinx3 = vaddq_f32(dinx3, diny3); dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, dinx0); vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1); vst1q_f32(dout_ptr + 4, dinx1);
...@@ -100,10 +100,10 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny, ...@@ -100,10 +100,10 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
if (remain >= 8) { if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr); float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4); float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, diny_data); din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, diny_data); din1 = vaddq_f32(din1, rb);
vst1q_f32(dout_ptr, r0); vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, r1); vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8; din_ptr += 8;
dout_ptr += 8; dout_ptr += 8;
remain -= 8; remain -= 8;
...@@ -111,16 +111,16 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny, ...@@ -111,16 +111,16 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
if (remain >= 4) { if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr); float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb); din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, diny_data); vst1q_f32(dout_ptr, din0);
din_ptr += 4; din_ptr += 4;
dout_ptr += 4; dout_ptr += 4;
remain -= 4; remain -= 4;
} }
if (remain > 0) { if (remain > 0) {
for (p = 0; p < remain; p++) { for (int p = 0; p < remain; p++) {
*dout_ptr = *dinx_ptr + diny_data; *dout_ptr = *din_ptr + diny_data;
dout_ptr++; dout_ptr++;
dinx_ptr++; din_ptr++;
} }
} }
} }
......
...@@ -23,8 +23,8 @@ template <typename T> ...@@ -23,8 +23,8 @@ template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num); void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T> template <typename T>
void elementwise_add_axis<float>(const T* dinx, const T* diny, T* dout, void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch,
int batch, int channels, int num); int channels, int num);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
...@@ -100,15 +100,15 @@ void ConvCompute::Run() { ...@@ -100,15 +100,15 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def) paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def) paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) // .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -31,7 +31,7 @@ void ElementwiseAddCompute::Run() { ...@@ -31,7 +31,7 @@ void ElementwiseAddCompute::Run() {
if (axis < 0) { if (axis < 0) {
axis = x_dims.size() - y_dims.size(); axis = x_dims.size() - y_dims.size();
} }
if (axis == 0) { if (x_dims.size() == y_dims.size()) {
lite::arm::math::elementwise_add(x_data, y_data, out_data, lite::arm::math::elementwise_add(x_data, y_data, out_data,
x_dims.production()); x_dims.production());
} else { } else {
......
...@@ -44,6 +44,9 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { ...@@ -44,6 +44,9 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int axis = param.axis; int axis = param.axis;
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1; int batch = 1;
int channels = 1; int channels = 1;
int num = 1; int num = 1;
...@@ -61,9 +64,11 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { ...@@ -61,9 +64,11 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
int offset = (i * channels + j) * num; int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset; const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j]; const dtype diny_data = y_data[j];
dtype* dout_ptr = dout + offset; dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) { for (int k = 0; k < num; ++k) {
dout_ptr[k] = din_ptr[k] + diny_data; *dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
} }
} }
} }
...@@ -79,18 +84,15 @@ TEST(elementwise_add, compute) { ...@@ -79,18 +84,15 @@ TEST(elementwise_add, compute) {
for (auto h : {1, 3, 4, 11}) { for (auto h : {1, 3, 4, 11}) {
for (auto w : {1, 3, 4, 11}) { for (auto w : {1, 3, 4, 11}) {
for (auto axis : {-1, 0, 1, 2, 3}) { for (auto axis : {-1, 0, 1, 2, 3}) {
for (auto yd{{n}, for (auto yd :
{c}, {std::vector<int64_t>({n}), std::vector<int64_t>({c}),
{h}, std::vector<int64_t>({h}), std::vector<int64_t>({w}),
{w}, std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
{n, c}, std::vector<int64_t>({h, w}), std::vector<int64_t>({n, c, h}),
{c, h}, std::vector<int64_t>({c, h, w}),
{h, w}, std::vector<int64_t>({n, c, h, w})}) {
{n, c, h},
{c, h, w},
{n, c, h, w}}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w})); auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(std::vector<int64_t>(yd)); auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis; int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
if (axis_t + y_dim.size() > 4) continue; if (axis_t + y_dim.size() > 4) continue;
...@@ -102,26 +104,27 @@ TEST(elementwise_add, compute) { ...@@ -102,26 +104,27 @@ TEST(elementwise_add, compute) {
x.Resize(x_dim); x.Resize(x_dim);
y.Resize(y_dim); y.Resize(y_dim);
output.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); output.Resize(x_dim);
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w}))); output_ref.Resize(x_dim);
auto* x_data = x.mutable_data<float>(); auto* x_data = x.mutable_data<float>();
auto* y_data = y.mutable_data<float>();
auto* output_data = output.mutable_data<float>(); auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>(); auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) { for (int i = 0; i < x_dim.production(); i++) {
x_data[i] = i; x_data[i] = i;
} }
for (int i = 0; i < y.dims().production(); i++) { for (int i = 0; i < y_dim.production(); i++) {
y_data[i] = i; y_data[i] = i;
} }
param.X = &x; param.X = &x;
param.Y = &y; param.Y = &y;
param.axis = axis; param.axis = axis;
param.Out = &output; param.Out = &output;
softmax.SetParam(param); elementwise_add.SetParam(param);
softmax.Run(); elementwise_add.Run();
param.Out = &output_ref; param.Out = &output_ref;
elementwise_add_compute_ref<float>(param); elementwise_add_compute_ref<float>(param);
for (int i = 0; i < out.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册