提交 6dbc2f29 编写于 作者: L liu zhengxi 提交者: Xiaoyang LI

Modify the slice op, stack op, reduce_mean op (#1921)

* add stack op and add reduce_mean op and their unit tests

* modify stack op output name and modify the for loop in reduce_mean op

* add HasAttr for slice op
上级 029971b4
......@@ -36,7 +36,7 @@ void reduce_mean_n<float>(const float* src,
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int n = 1; n < num_in; ++n) {
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]) / num_in;
}
......@@ -61,7 +61,7 @@ void reduce_mean_c<float>(const float* src,
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int c = 1; c < channel_in; ++c) {
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] += static_cast<float>(src[src_index]) / channel_in;
}
......@@ -87,7 +87,7 @@ void reduce_mean_h<float>(const float* src,
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 1; h < height_in; ++h) {
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]) / height_in;
}
......@@ -115,7 +115,7 @@ void reduce_mean_w<float>(const float* src,
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 1; w < width_in; ++w) {
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]) / width_in;
}
......
......@@ -38,5 +38,5 @@ void StackCompute::Run() {
REGISTER_LITE_KERNEL(
stack, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::StackCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -79,7 +79,9 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
param_.starts = opdesc.GetAttr<std::vector<int>>("starts");
param_.ends = opdesc.GetAttr<std::vector<int>>("ends");
param_.decrease_axis = opdesc.GetAttr<std::vector<int>>("decrease_axis");
if (opdesc.HasAttr("decrease_axis")) {
param_.decrease_axis = opdesc.GetAttr<std::vector<int>>("decrease_axis");
}
return true;
}
......
......@@ -46,7 +46,7 @@ bool StackOp::InferShape() const {
bool StackOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto X = op_desc.Input("X");
auto Out = op_desc.Output("Out").front();
auto Out = op_desc.Output("Y").front();
for (auto var : X) {
param_.X.emplace_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
......
......@@ -34,7 +34,7 @@ void reduce_mean_n(const float* src,
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int n = 1; n < num_in; ++n) {
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] += static_cast<float>(src[src_index]) / num_in;
}
......@@ -58,7 +58,7 @@ void reduce_mean_c(const float* src,
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = 0.0;
for (int c = 1; c < channel_in; ++c) {
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] += static_cast<float>(src[src_index]) / channel_in;
}
......@@ -83,7 +83,7 @@ void reduce_mean_h(const float* src,
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = 0.0;
for (int h = 1; h < height_in; ++h) {
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] += static_cast<float>(src[src_index]) / height_in;
}
......@@ -110,7 +110,7 @@ void reduce_mean_w(const float* src,
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = 0.0;
for (int w = 1; w < width_in; ++w) {
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] += static_cast<float>(src[src_index]) / width_in;
}
......
......@@ -77,7 +77,7 @@ class StackComputeTester : public arena::TestCase {
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("stack");
op_desc->SetInput("X", {input1_, input2_});
op_desc->SetOutput("Out", {output_});
op_desc->SetOutput("Y", {output_});
op_desc->SetAttr("axis", axis_);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册