Commit 20001636 authored by liu zhengxi, committed by Xiaoyang LI

add stack op and add reduce_mean op and their unit tests (#1888)
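For orientation, here is a minimal self-contained sketch of what the two new ops compute (plain C++ on raw arrays, not the Paddle-Lite API; shapes and values are hypothetical): stack inserts a new dimension and interleaves its inputs along it, and reduce_mean averages over a chosen dimension.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Stack two [2, 3] inputs at axis 0 -> output shape [2, 2, 3].
  const int pre = 1, post = 6;  // pre = dims before axis, post = the rest
  std::vector<float> a{0, 1, 2, 3, 4, 5};
  std::vector<float> b{6, 7, 8, 9, 10, 11};
  std::vector<const float*> x{a.data(), b.data()};
  std::vector<float> y(12);  // stacked output
  int off = 0;
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < post; ++k) y[off++] = x[j][i * post + k];
  // reduce_mean over the new dim 0 recovers the element-wise average.
  std::vector<float> mean(post);
  for (int k = 0; k < post; ++k) mean[k] = (y[k] + y[post + k]) / 2.f;
  std::printf("mean[0]=%.1f mean[5]=%.1f\n", mean[0], mean[5]);  // 3.0 8.0
  return 0;
}
```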

Parent ecce1eff
......@@ -106,6 +106,8 @@ USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(stack, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
......
......@@ -69,6 +69,8 @@ USE_LITE_OP(shuffle_channel)
USE_LITE_OP(yolo_box)
USE_LITE_OP(bilinear_interp)
USE_LITE_OP(nearest_interp)
USE_LITE_OP(reduce_mean)
USE_LITE_OP(stack)
USE_LITE_OP(assign);
USE_LITE_OP(crop)
......
......@@ -103,6 +103,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
sequence_pool.cc
sequence_expand.cc
slice.cc
reduce_mean.cc
stack.cc
affine_channel.cc
anchor_generator.cc
DEPS ${lite_kernel_deps})
......
......@@ -48,6 +48,7 @@
#include "lite/arm/math/power.h"
#include "lite/arm/math/prior_box.h"
#include "lite/arm/math/reduce_max.h"
#include "lite/arm/math/reduce_mean.h"
#include "lite/arm/math/scale.h"
#include "lite/arm/math/sequence_expand.h"
#include "lite/arm/math/sequence_pool.h"
......@@ -58,6 +59,7 @@
#include "lite/arm/math/slice.h"
#include "lite/arm/math/softmax.h"
#include "lite/arm/math/split.h"
#include "lite/arm/math/stack.h"
#include "lite/arm/math/topk.h"
#include "lite/arm/math/yolo_box.h"
namespace paddle {
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/arm/math/reduce_mean.h"
#include "lite/arm/math/funcs.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void reduce_mean_n<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = channel_in * hw_size;
int data_index, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = 0.0;
// Accumulate every slice along n; dividing inside the loop folds the
// mean into the running sum.
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]) / num_in;
}
}
}
}
}
template <>
void reduce_mean_c<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = hw_size * channel_in;
int data_index, src_index0, src_index;
for (int n = 0; n < num_in; ++n) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] += static_cast<float>(src[src_index]) / channel_in;
}
}
}
}
}
template <>
void reduce_mean_h<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int cw_size = channel_in * width_in;
int chw_size = cw_size * height_in;
int hw_size = height_in * width_in;
int data_index, src_index, src_index0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]) / height_in;
}
}
}
}
}
template <>
void reduce_mean_w<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int ch_size = channel_in * height_in;
int hw_size = height_in * width_in;
int chw_size = ch_size * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]) / width_in;
}
}
}
}
}
template <>
void reduce_mean_all<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
float mean = 0.0;
int src_index;
int n_id, c_id;
int all = num_in * channel_in * height_in * width_in;
for (int n = 0; n < num_in; ++n) {
n_id = n * channel_in * height_in * width_in;
for (int c = 0; c < channel_in; ++c) {
c_id = c * height_in * width_in;
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
src_index = n_id + c_id + h * width_in + w;
mean += src[src_index] / all;
}
}
}
}
dst[0] = mean;
}
template <>
void reduce_mean_nc<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce n first.
DDimLite ddimA({1, channel_in, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_n(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_c(tmp_out, dst, 1, channel_in, height_in, width_in);
}
template <>
void reduce_mean_ch<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce c first
DDimLite ddimA({num_in, 1, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_c(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_h(tmp_out, dst, num_in, 1, height_in, width_in);
}
template <>
void reduce_mean_hw<float>(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce h first
DDimLite ddimA({num_in, channel_in, 1, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_h(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in);
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
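The two-axis helpers (reduce_mean_nc, reduce_mean_ch, reduce_mean_hw) chain two single-axis passes through a temporary tensor. That works because the mean commutes across axes: averaging over one axis and then the other equals averaging over both at once. A minimal numeric check of the identity (plain C++, hypothetical 2x2x1x1 data):

```cpp
#include <cstdio>

int main() {
  // A 2x2x1x1 tensor, n-major layout: {x[n=0,c=0], x[0,1], x[1,0], x[1,1]}.
  const float x[4] = {1.f, 2.f, 3.f, 4.f};
  // Reduce over n first (as reduce_mean_nc does), then over c.
  const float over_n[2] = {(x[0] + x[2]) / 2.f, (x[1] + x[3]) / 2.f};
  const float two_step = (over_n[0] + over_n[1]) / 2.f;
  // Reduce over both axes at once.
  const float one_step = (x[0] + x[1] + x[2] + x[3]) / 4.f;
  std::printf("%.2f == %.2f\n", two_step, one_step);  // 2.50 == 2.50
  return 0;
}
```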
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void reduce_mean_n(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_c(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_h(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_w(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_nc(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_ch(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_hw(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
template <typename T>
void reduce_mean_all(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/arm/math/stack.h"
#include <cstddef>
#include <cstring>
#include <utility>
#include <vector>
#include "lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void stack(std::vector<lite::Tensor *> x, lite::Tensor *y, int axis) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = x.size();
auto *y_data = y->mutable_data<float>();
std::vector<const float *> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<float>();
// Collapse the copy into `pre` outer steps; each step copies `post`
// contiguous floats from every input in turn.
int pre = 1, post = 1;
auto &dim = x[0]->dims();
for (int i = 0; i < axis; ++i) pre *= dim[i];
for (size_t i = axis; i < dim.size(); ++i) post *= dim[i];
auto x_data_arr = x_datas.data();
size_t x_offset = 0;
size_t y_offset = 0;
for (int i = 0; i < pre; i++) {
for (int j = 0; j < n; j++) {
std::memcpy(
y_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(float));
y_offset += post;
}
x_offset += post;
}
}
} /* namespace math */
} /* namespace arm */
} /* namespace lite */
} /* namespace paddle */
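To make the pre/post copy pattern in stack concrete, a standalone sketch on plain arrays (hypothetical values, no Paddle-Lite types): stacking two [2, 3] inputs at axis 1 copies post = 3 floats from each input in turn, pre = 2 times.

```cpp
#include <cstdio>
#include <cstring>

int main() {
  const float a[6] = {0, 1, 2, 3, 4, 5};
  const float b[6] = {10, 11, 12, 13, 14, 15};
  const float* x[2] = {a, b};
  float y[12];  // output shape [2, 2, 3]
  const int pre = 2, post = 3, n = 2;
  size_t x_off = 0, y_off = 0;
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      std::memcpy(y + y_off, x[j] + x_off, post * sizeof(float));
      y_off += post;
    }
    x_off += post;
  }
  // Expected: 0 1 2 | 10 11 12 | 3 4 5 | 13 14 15
  for (int i = 0; i < 12; ++i) std::printf("%g ", y[i]);
  std::printf("\n");
  return 0;
}
```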
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <vector>
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void stack(std::vector<lite::Tensor*> x, lite::Tensor* out, int axis);
} /* namespace math */
} /* namespace arm */
} /* namespace lite */
} /* namespace paddle */
......@@ -45,6 +45,8 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/reduce_mean_compute.h"
#include <string>
#include "lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ReduceMeanCompute::Run() {
auto& param = Param<operators::ReduceMeanParam>();
const float* input = param.X->data<float>();
auto x_dims = param.X->dims();
int x_rank = x_dims.size();
float* output = param.Out->mutable_data<float>();
auto dim = param.dim;
if (!dim.empty()) {
for (int i = 0; i < dim.size(); i++) {
if (dim[i] < 0) {
dim[i] += x_rank;
}
}
}
int n_in = x_dims[0];
int c_in = x_dims[1];
int h_in = x_dims[2];
int w_in = x_dims[3];
if (dim.size() == 0) {
lite::arm::math::reduce_mean_all(input, output, n_in, c_in, h_in, w_in);
} else if (dim.size() == 1) {
switch (dim[0]) {
case 0:
lite::arm::math::reduce_mean_n(input, output, n_in, c_in, h_in, w_in);
break;
case 1:
lite::arm::math::reduce_mean_c(input, output, n_in, c_in, h_in, w_in);
break;
case 2:
lite::arm::math::reduce_mean_h(input, output, n_in, c_in, h_in, w_in);
break;
case 3:
lite::arm::math::reduce_mean_w(input, output, n_in, c_in, h_in, w_in);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim.size() == 2) {
if (dim[0] == 0 && dim[1] == 1) {
lite::arm::math::reduce_mean_nc(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 1 && dim[1] == 2) {
lite::arm::math::reduce_mean_ch(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 2 && dim[1] == 3) {
lite::arm::math::reduce_mean_hw(input, output, n_in, c_in, h_in, w_in);
} else {
LOG(FATAL) << "invalid dim!!";
}
} else {
LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(reduce_mean,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ReduceMeanCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ReduceMeanCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ReduceMeanCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/stack_compute.h"
#include <vector>
#include "lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void StackCompute::Run() {
auto& param = Param<operators::StackParam>();
std::vector<lite::Tensor*> x = param.X;
lite::Tensor* out = param.Out;
int axis = param.axis;
lite::arm::math::stack(x, out, axis);
}
} /* namespace arm */
} /* namespace kernels */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_KERNEL(
stack, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::StackCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class StackCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~StackCompute() = default;
};
} /* namespace arm */
} /* namespace kernels */
} /* namespace lite */
} /* namespace paddle */
......@@ -59,6 +59,8 @@ add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS})
add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS})
add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
......
......@@ -121,6 +121,23 @@ struct MulGradParam {
int y_num_col_dims{1};
};
// For ReduceMean Op
struct ReduceMeanParam {
lite::Tensor* X{};
lite::Tensor* Out{};
std::vector<int> dim;
bool keep_dim{false};
};
// For Stack Op
struct StackParam {
std::vector<lite::Tensor*> X;
lite::Tensor* Out{};
int axis{0};
};
// For Power Op
struct PowerParam {
const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reduce_mean_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReduceMeanOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
auto dims = param_.dim;
auto x_dims = param_.X->dims();
int x_rank = x_dims.size();
if (dims.size() != 0) {
for (int i = 0; i < dims.size(); i++) {
if (dims[i] < 0) {
dims[i] = x_rank + dims[i];
}
// Negative entries were normalized above, so the valid range is [0, x_rank).
CHECK_OR_FALSE(dims[i] >= 0 && dims[i] < x_rank);
}
}
return true;
}
bool ReduceMeanOp::InferShape() const {
auto dims = param_.dim;
auto x_dims = param_.X->dims();
bool reduce_all = false;
bool keep_dim = param_.keep_dim;
auto x_rank = x_dims.size();
if (dims.size() != 0) {
for (int i = 0; i < dims.size(); i++) {
if (dims[i] < 0) {
dims[i] = x_rank + dims[i];
}
}
}
std::sort(dims.begin(), dims.end());
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int64_t> out_dims;
if (reduce_all) {
if (keep_dim) {
// Keep a size-1 entry for every input dim, e.g. [3, 2, 3, 4] -> [1, 1, 1, 1].
for (size_t i = 0; i < x_rank; ++i) out_dims.push_back(1);
} else {
out_dims.push_back(1);
}
param_.Out->Resize(DDim(out_dims));
} else {
for (int i = 0; i < x_dims.size(); i++) {
out_dims.push_back(x_dims[i]);
}
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
out_dims[dims[i]] = 1;
}
} else {
const int64_t kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
out_dims[dims[i]] = kDelFlag;
}
out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
out_dims.end());
}
param_.Out->Resize(DDim(out_dims));
if (dims[0] != 0) {
// Only pass LoD when not reducing on the first dim.
*param_.Out->mutable_lod() = param_.X->lod();
}
}
return true;
}
bool ReduceMeanOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.dim = opdesc.GetAttr<std::vector<int>>("dim");
if (opdesc.HasAttr("keep_dim")) {
param_.keep_dim = opdesc.GetAttr<bool>("keep_dim");
} else {
param_.keep_dim = false;
}
CHECK(param_.X);
CHECK(param_.Out);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(reduce_mean, paddle::lite::operators::ReduceMeanOp);
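The erase-remove idiom in InferShape drops the reduced dims in place when keep_dim is false. A minimal sketch of the same pattern on a plain vector (hypothetical shape and dims):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Input shape [3, 2, 3, 4], reduce over dims {1, 2}, keep_dim = false.
  std::vector<int64_t> out_dims{3, 2, 3, 4};
  const std::vector<int> dims{1, 2};
  const int64_t kDelFlag = -2;
  for (int d : dims) out_dims[d] = kDelFlag;  // mark reduced dims
  out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
                 out_dims.end());
  // Prints: 3 4
  for (auto d : out_dims) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}
```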
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ReduceMeanOp : public OpLite {
public:
ReduceMeanOp() {}
explicit ReduceMeanOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reduce_mean"; }
private:
mutable ReduceMeanParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/stack_op.h"
#include <cstddef>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool StackOp::CheckShape() const {
auto input = param_.X;
for (auto x : input) {
CHECK_OR_FALSE(x);
}
CHECK_OR_FALSE(param_.Out);
return true;
}
bool StackOp::InferShape() const {
auto input = param_.X;
auto input_dims = input[0]->dims();
int axis = param_.axis;
int rank = input_dims.size();
if (axis < 0) axis += (rank + 1);
auto vec = input_dims.Vectorize();
vec.insert(vec.begin() + axis, input.size());
param_.Out->Resize(vec);
return true;
}
bool StackOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto X = op_desc.Input("X");
auto Out = op_desc.Output("Out").front();
for (auto var : X) {
param_.X.emplace_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.axis = op_desc.GetAttr<int>("axis");
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_OP(stack, paddle::lite::operators::StackOp);
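StackOp::InferShape inserts the number of inputs as a new dimension at `axis`, with a negative axis counted from one past the last dim. A small sketch of the shape arithmetic (plain STL, hypothetical shapes):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> vec{1, 5, 6, 7};  // shape of each input
  const int n_inputs = 2;
  int axis = -1;  // negative axis counts from the end of the new rank
  const int rank = static_cast<int>(vec.size());
  if (axis < 0) axis += rank + 1;  // -1 -> 4: append the new dim at the back
  vec.insert(vec.begin() + axis, n_inputs);
  // Prints: 1 5 6 7 2
  for (auto d : vec) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}
```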
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class StackOp : public OpLite {
public:
StackOp() {}
explicit StackOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "stack"; }
private:
mutable StackParam param_;
};
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
......@@ -43,6 +43,8 @@ endif()
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
void reduce_mean_n(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = channel_in * hw_size;
int data_index, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]) / num_in;
}
}
}
}
}
void reduce_mean_c(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = hw_size * channel_in;
int data_index, src_index0, src_index;
for (int n = 0; n < num_in; ++n) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] += static_cast<float>(src[src_index]) / channel_in;
}
}
}
}
}
void reduce_mean_h(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int cw_size = channel_in * width_in;
int chw_size = cw_size * height_in;
int hw_size = height_in * width_in;
int data_index, src_index, src_index0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]) / height_in;
}
}
}
}
}
void reduce_mean_w(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int ch_size = channel_in * height_in;
int hw_size = height_in * width_in;
int chw_size = ch_size * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]) / width_in;
}
}
}
}
}
void reduce_mean_all(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
float mean = 0.0;
int src_index;
int n_id, c_id;
int all = num_in * channel_in * height_in * width_in;
for (int n = 0; n < num_in; ++n) {
n_id = n * channel_in * height_in * width_in;
for (int c = 0; c < channel_in; ++c) {
c_id = c * height_in * width_in;
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
src_index = n_id + c_id + h * width_in + w;
mean += src[src_index] / all;
}
}
}
}
dst[0] = mean;
}
void reduce_mean_nc(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce n first.
DDimLite ddimA({1, channel_in, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_n(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_c(tmp_out, dst, 1, channel_in, height_in, width_in);
}
void reduce_mean_ch(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce c first
DDimLite ddimA({num_in, 1, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_c(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_h(tmp_out, dst, num_in, 1, height_in, width_in);
}
void reduce_mean_hw(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce h first
DDimLite ddimA({num_in, channel_in, 1, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_mean_h(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_mean_w(tmp_out, dst, num_in, channel_in, 1, width_in);
}
class ReduceMeanComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "out";
std::vector<int> dim_{0};
DDim x_dims_{{3, 2, 3, 4}};
bool keep_dim_ = false;
bool reduce_all_ = false;
public:
ReduceMeanComputeTester(const Place& place,
const std::string& alias,
std::vector<int> dim,
bool keep_dim,
DDim x_dims)
: TestCase(place, alias),
dim_(dim),
keep_dim_(keep_dim),
x_dims_(x_dims) {}
void RunBaseline(Scope* scope) override {
auto* x = scope->FindMutableTensor(input_);
const auto* x_data = x->data<float>();
auto* out = scope->NewTensor(output_);
auto x_rank = x_dims_.size();
if (!dim_.empty()) {
for (int i = 0; i < dim_.size(); i++) {
if (dim_[i] < 0) {
dim_[i] += x_rank;
}
}
}
std::sort(dim_.begin(), dim_.end());
if (dim_.size() == 0) {
reduce_all_ = true;
}
std::vector<int64_t> out_dims;
if (reduce_all_) {
if (keep_dim_) {
// Keep a size-1 entry for every input dim.
for (size_t i = 0; i < x_rank; ++i) out_dims.push_back(1);
} else {
out_dims.push_back(1);
}
out->Resize(DDim(out_dims));
} else {
for (int i = 0; i < x_dims_.size(); i++) {
out_dims.push_back(x_dims_[i]);
}
if (keep_dim_) {
for (size_t i = 0; i < dim_.size(); ++i) {
out_dims[dim_[i]] = 1L;
}
} else {
int64_t kDelFlag = -2;
for (size_t i = 0; i < dim_.size(); ++i) {
out_dims[dim_[i]] = kDelFlag;
}
out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
out_dims.end());
}
out->Resize(DDim(out_dims));
}
auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (dim_.size() == 0) {
reduce_mean_all(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_.size() == 1) {
switch (dim_[0]) {
case 0:
reduce_mean_n(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 1:
reduce_mean_c(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 2:
reduce_mean_h(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 3:
reduce_mean_w(x_data, out_data, in_n, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_mean_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_mean_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_mean_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("reduce_mean");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("dim", dim_);
op_desc->SetAttr("keep_dim", keep_dim_);
}
void PrepareData() override {
std::vector<float> data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
data[i] = i * 1.0;
}
SetCommonTensor(input_, x_dims_, data.data());
}
};
void test_reduce_mean(Place place) {
std::vector<std::vector<int>> reduce_dim{
{0}, {1}, {2}, {3}, {0, 1}, {1, 2}, {2, 3}, {-2, -1}};
for (auto n : {1, 3}) {
for (auto c : {1, 2}) {
for (auto h : {1, 3}) {
for (auto w : {1, 3}) {
for (bool keep_dim : {false, true}) {
for (auto dim : reduce_dim) {
auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
std::unique_ptr<arena::TestCase> tester(
new ReduceMeanComputeTester(
place, "def", dim, keep_dim, x_dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
}
TEST(ReduceMean, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_reduce_mean(place);
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
void stack(std::vector<const lite::Tensor*> x, lite::Tensor* y, int axis) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int n = x.size();
auto* y_data = y->mutable_data<float>();
std::vector<const float*> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<float>();
int pre = 1, post = 1;
auto& dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
auto x_data_arr = x_datas.data();
size_t x_offset = 0;
size_t y_offset = 0;
for (int i = 0; i < pre; i++) {
for (int j = 0; j < n; j++) {
std::memcpy(
y_data + y_offset, x_data_arr[j] + x_offset, post * sizeof(float));
y_offset += post;
}
x_offset += post;
}
}
class StackComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input1_ = "X1";
std::string input2_ = "X2";
std::string output_ = "Out";
int axis_ = 0;
DDim dims_{{1, 5, 6, 7}};
public:
StackComputeTester(const Place& place, const std::string& alias, int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
std::vector<const lite::Tensor*> x;
x.emplace_back(scope->FindTensor(input1_));
x.emplace_back(scope->FindTensor(input2_));
auto input_dims = x[0]->dims();
int rank = input_dims.size();
if (axis_ < 0) axis_ += (rank + 1);
auto vec = input_dims.Vectorize();
vec.insert(vec.begin() + axis_, x.size());
out->Resize(vec);
stack(x, out, axis_);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("stack");
op_desc->SetInput("X", {input1_, input2_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.01;
}
SetCommonTensor(input1_, dims_, data.data());
SetCommonTensor(input2_, dims_, data.data());
}
};
void test_stack(Place place) {
for (int axis : {0, 1, 3}) {
std::unique_ptr<arena::TestCase> tester(
new StackComputeTester(place, "def", axis));
arena::Arena arena(std::move(tester), place, 2e-4);
arena.TestPrecision();
}
}
TEST(Stack, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_stack(place);
#endif
}
} // namespace lite
} // namespace paddle