提交 273e6fae 编写于 作者: J jiaopu 提交者: jackzhang235

add interpolate in x86

上级 2bec4623
......@@ -64,6 +64,7 @@ add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite
add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
add_kernel(yolo_box_compute_x86 X86 basic SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps})
add_kernel(interpolate_compute_x86 X86 basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
......@@ -104,3 +105,5 @@ lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compu
lite_cc_test(test_leaky_relu_compute_x86 SRCS leaky_relu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_yolo_box_compute_x86 SRCS yolo_box_compute_test.cc DEPS
yolo_box_compute_x86)
lite_cc_test(test_nearest_interp_comute_x86 SRCS interpolate_compute_test.cc
DEPS interpolate_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/interpolate_compute.h"
REGISTER_LITE_KERNEL(nearest_interp,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::InterpolateCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/interpolate_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
inline void nearest_interp(const float* src,
int w_in,
int h_in,
float* dst,
int w_out,
int h_out,
bool with_align) {
float scale_w_new = (with_align)
? (static_cast<float>(w_in - 1) / (w_out - 1))
: (static_cast<float>(w_in) / (w_out));
float scale_h_new = (with_align)
? (static_cast<float>(h_in - 1) / (h_out - 1))
: (static_cast<float>(h_in) / (h_out));
if (with_align) {
for (int h = 0; h < h_out; ++h) {
float* dst_p = dst + h * w_out;
int near_y = static_cast<int>(scale_h_new * h + 0.5);
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w + 0.5);
*dst_p++ = src[near_y * w_in + near_x];
}
}
} else {
for (int h = 0; h < h_out; ++h) {
float* dst_p = dst + h * w_out;
int near_y = static_cast<int>(scale_h_new * h);
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w);
*dst_p++ = src[near_y * w_in + near_x];
}
}
}
}
inline std::vector<int> get_new_shape(
std::vector<const lite::Tensor*> list_new_shape_tensor) {
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
return vec_new_shape;
}
class InterpolateCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::InterpolateParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
int in_h = param.X->dims()[2];
int in_w = param.X->dims()[3];
if (param.SizeTensor.size() > 0) {
auto new_size = get_new_shape(param.SizeTensor);
param.out_h = new_size[0];
param.out_w = new_size[1];
} else {
auto scale_tensor = param.Scale;
if (scale_tensor != nullptr) {
auto* scale_data = param.Scale->mutable_data<float>();
param.scale = scale_data[0];
}
if (param.scale > 0) {
param.out_h = static_cast<int>(in_h * param.scale);
param.out_w = static_cast<int>(in_w * param.scale);
}
if (param.OutSize != nullptr) {
auto* outsize_data = param.OutSize->mutable_data<float>();
param.out_h = outsize_data[0];
param.out_w = outsize_data[1];
}
}
int num_cout = param.X->dims()[0];
int c_cout = param.X->dims()[1];
param.Out->Resize({num_cout, c_cout, param.out_h, param.out_w});
float* dout = param.Out->mutable_data<float>();
const float* din = param.X->data<float>();
int out_num = param.Out->dims()[0];
int out_c = param.Out->dims()[1];
int count = out_num * out_c;
int out_h = param.Out->dims()[2];
int out_w = param.Out->dims()[3];
int spatial_in = in_h * in_w;
int spatial_out = out_h * out_w;
#pragma omp parallel for
for (int i = 0; i < count; ++i) {
nearest_interp(din + spatial_in * i,
in_w,
in_h,
dout + spatial_out * i,
out_w,
out_h,
param.align_corners);
}
}
virtual ~InterpolateCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/interpolate_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
void NearestInterpRef(lite::Tensor* input,
lite::Tensor* output,
bool with_align) {
int hin = input->dims()[2];
int win = input->dims()[3];
int channels = input->dims()[1];
int num = input->dims()[0];
int hout = output->dims()[2];
int wout = output->dims()[3];
float scale_w = (with_align) ? (static_cast<float>(win - 1) / (wout - 1))
: (static_cast<float>(win) / (wout));
float scale_h = (with_align) ? (static_cast<float>(hin - 1) / (hout - 1))
: (static_cast<float>(hin) / (hout));
const float* src = input->data<float>();
float* dst = output->mutable_data<float>();
int dst_stride_w = 1;
int dst_stride_h = wout;
int dst_stride_c = wout * hout;
int dst_stride_batch = wout * hout * channels;
int src_stride_w = 1;
int src_stride_h = win;
int src_stride_c = win * hin;
int src_stride_batch = win * hin * channels;
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
int src_index = n * src_stride_batch + c * src_stride_c;
for (int h = 0; h < hout; ++h) {
for (int w = 0; w < wout; ++w) {
int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5)
: static_cast<int>(scale_w * w);
fw = (fw < 0) ? 0 : fw;
int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5)
: static_cast<int>(scale_h * h);
fh = (fh < 0) ? 0 : fh;
int w_start = static_cast<int>(fw);
int h_start = static_cast<int>(fh);
int dst_index = n * dst_stride_batch + c * dst_stride_c +
h * dst_stride_h + w * dst_stride_w;
dst[dst_index] =
src[src_index + w_start * src_stride_w + h_start * src_stride_h];
}
}
}
}
}
TEST(interpolate_x86, retrive_op) {
auto interpolate =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"nearest_interp");
ASSERT_FALSE(interpolate.empty());
ASSERT_TRUE(interpolate.front());
}
TEST(interpolate_x86, init) {
InterpolateCompute interpolate;
ASSERT_EQ(interpolate.precision(), PRECISION(kFloat));
ASSERT_EQ(interpolate.target(), TARGET(kX86));
}
TEST(interpolate_x86, run_test) {
lite::Tensor X, OutSize, Out, Out_base;
operators::InterpolateParam param;
InterpolateCompute interpolate;
int n = 1, c = 3, in_h = 40, in_w = 40;
int out_h = 80, out_w = 80;
float scale = 2.0;
param.out_h = out_h;
param.out_w = out_w;
param.scale = scale;
param.align_corners = false;
X.Resize({n, c, in_h, in_w});
OutSize.Resize({2});
Out.Resize({n, c, out_h, out_w});
Out_base.Resize({n, c, out_h, out_w});
auto* out_data = Out.mutable_data<float>();
auto* out_base_data = Out_base.mutable_data<float>();
auto* x_data = X.mutable_data<float>();
auto* outsize_data = OutSize.mutable_data<float>();
for (int i = 0; i < X.dims().production(); i++) {
x_data[i] = i + 5.0;
}
outsize_data[0] = out_h;
outsize_data[1] = out_w;
param.X = &X;
param.OutSize = &OutSize;
param.Out = &Out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
interpolate.SetContext(std::move(ctx));
interpolate.SetParam(std::move(param));
interpolate.Run();
NearestInterpRef(&X, &Out_base, false);
for (int i = 0; i < Out.dims().production(); i++) {
LOG(INFO) << out_data[i];
EXPECT_NEAR(out_data[i], out_base_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(nearest_interp, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册