提交 7fb93348 编写于 作者: C chenjiaoAngel

pull new code . test=develop

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into conv_dw_5x5
......@@ -78,6 +78,7 @@ add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS
add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(clip_compute_arm ARM extra SRCS clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(pixel_shuffle_compute_arm ARM extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/pixel_shuffle_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Pixel-shuffle kernel for ARM (float).
//
// For every (batch, output-channel) pair the input is read sequentially and
// scattered into one output plane: the element at (sh, sw, h, w) of the
// consumed input span is stored at row h * upscale_factor + sh and column
// w * upscale_factor + sw of that plane.
//
// NOTE(review): assumes the output tensor is already shaped as
// {N, C / r^2, H * r, W * r} for an input {N, C, H, W} (r = upscale_factor);
// nothing in this function validates that -- confirm against the op's
// shape inference.
void PixelShuffleCompute::Run() {
  auto& param = Param<operators::PixelShuffleParam>();
  const float* src = param.x->data<float>();
  float* dst = param.output->mutable_data<float>();

  const int factor = param.upscale_factor;
  const int num = param.x->dims()[0];
  const int in_h = param.x->dims()[2];
  const int in_w = param.x->dims()[3];
  const int out_c = param.output->dims()[1];
  const int out_h = param.output->dims()[2];
  const int out_w = param.output->dims()[3];

  // Each iteration touches its own output plane, so the (n, c) pairs are
  // distributed across OpenMP threads.
#pragma omp parallel for
  for (int nc = 0; nc < num * out_c; nc++) {
    // One output plane holds out_h * out_w floats; the same number of input
    // floats is consumed for it, read strictly in order.
    const float* src_ptr = src + nc * out_h * out_w;
    float* dst_plane = dst + nc * out_h * out_w;
    for (int sh = 0; sh < factor; sh++) {
      for (int sw = 0; sw < factor; sw++) {
        for (int h = 0; h < in_h; h++) {
          // Start of the destination row for this (sh, sw, h) combination.
          float* dst_row = dst_plane + (h * factor + sh) * out_w + sw;
          for (int w = 0; w < in_w; w++) {
            dst_row[w * factor] = *src_ptr++;
          }
        }
      }
    }
  }
#ifdef LITE_WITH_PROFILE
  kernel_func_name_ = "pixel_shuffle_func";
#endif
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers the ARM PixelShuffleCompute class as the "def" kernel of the
// pixel_shuffle op (kARM target, kFloat precision, kNCHW layout), binding
// input "X" and output "Out" to ARM tensors.
REGISTER_LITE_KERNEL(pixel_shuffle,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::PixelShuffleCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include "lite/core/kernel.h"
#include "lite/operators/pixel_shuffle_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Kernel class for the pixel_shuffle operator on the ARM target with float
// precision. Only the declaration lives here; Run() is defined out of line.
class PixelShuffleCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
// Parameter struct type this kernel is associated with.
using param_t = operators::PixelShuffleParam;
// Executes the kernel.
void Run() override;
virtual ~PixelShuffleCompute() = default;
#ifdef LITE_WITH_PROFILE
// Copies the stored kernel function name into the profiler record `ch`.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
// Name handed to the profiler; starts out as a placeholder string.
std::string kernel_func_name_{"NotImplForPixelShuffle"};
#endif
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -218,10 +218,10 @@ int DensityPriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
i_split_shape_data[1] /= 2;
shape[0] = &i_split_shape_data[0];
shape[1] = &i_split_shape_data[0];
name[0] = static_cast<const char*>(
lite::subgraph::bm::UniqueName("bm_boxes").c_str());
name[1] = static_cast<const char*>(
lite::subgraph::bm::UniqueName("bm_boxes_var").c_str());
auto boxes_name = lite::subgraph::bm::UniqueName("bm_boxes");
auto var_name = lite::subgraph::bm::UniqueName("bm_var");
name[0] = static_cast<const char*>(boxes_name.c_str());
name[1] = static_cast<const char*>(var_name.c_str());
int split_size[2];
split_size[0] = shape[0][1];
split_size[1] = shape[1][1];
......
......@@ -242,6 +242,7 @@ int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
if (op_info->HasAttr("prior_num")) {
param.prior_num = op_info->GetAttr<int32_t>("prior_num");
}
param.min_max_aspect_ratios_order = false;
if (op_info->HasAttr("min_max_aspect_ratios_order")) {
param.min_max_aspect_ratios_order =
op_info->GetAttr<bool>("min_max_aspect_ratios_order");
......@@ -289,10 +290,10 @@ int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
i_split_shape_data[1] /= 2;
shape[0] = &i_split_shape_data[0];
shape[1] = &i_split_shape_data[0];
name[0] = static_cast<const char*>(
lite::subgraph::bm::UniqueName("bm_boxes").c_str());
name[1] = static_cast<const char*>(
lite::subgraph::bm::UniqueName("bm_boxes_var").c_str());
auto boxes_name = lite::subgraph::bm::UniqueName("bm_boxes");
auto var_name = lite::subgraph::bm::UniqueName("bm_var");
name[0] = static_cast<const char*>(boxes_name.c_str());
name[1] = static_cast<const char*>(var_name.c_str());
int split_size[2];
split_size[0] = shape[0][1];
split_size[1] = shape[1][1];
......
......@@ -58,9 +58,7 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
for (size_t i = 0; i < axes.size(); i++) {
begin_index[axes[i]] = starts[i];
end_index[axes[i]] = ends[i] > static_cast<int32_t>(input_dims.size())
? static_cast<int32_t>(input_dims.size())
: ends[i];
end_index[axes[i]] = ends[i];
begin_mask &= ~(1 << axes[i]);
end_mask &= ~(1 << axes[i]);
}
......
......@@ -23,9 +23,11 @@ add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kerne
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
add_kernel(pixel_shuffle_compute_host Host extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
lite_cc_test(test_pixel_shuffle_compute_host SRCS pixel_shuffle_compute.cc DEPS pixel_shuffle_compute_host)
lite_cc_test(test_one_hot_compute_host SRCS one_hot_compute_test.cc DEPS one_hot_compute_host)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/pixel_shuffle_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
// Pixel-shuffle kernel for the host target (float, single-threaded).
//
// For every (batch, output-channel) pair the input is read sequentially and
// scattered into one output plane: the element at (sh, sw, h, w) of the
// consumed input span is stored at row h * upscale_factor + sh and column
// w * upscale_factor + sw of that plane.
//
// NOTE(review): assumes the output tensor is already shaped as
// {N, C / r^2, H * r, W * r} for an input {N, C, H, W} (r = upscale_factor);
// nothing in this function validates that -- confirm against the op's
// shape inference.
void PixelShuffleCompute::Run() {
  auto& param = Param<operators::PixelShuffleParam>();
  const float* src = param.x->data<float>();
  float* dst = param.output->mutable_data<float>();

  const int factor = param.upscale_factor;
  const int num = param.x->dims()[0];
  const int in_h = param.x->dims()[2];
  const int in_w = param.x->dims()[3];
  const int out_c = param.output->dims()[1];
  const int out_h = param.output->dims()[2];
  const int out_w = param.output->dims()[3];

  for (int nc = 0; nc < num * out_c; nc++) {
    // One output plane holds out_h * out_w floats; the same number of input
    // floats is consumed for it, read strictly in order.
    const float* src_ptr = src + nc * out_h * out_w;
    float* dst_plane = dst + nc * out_h * out_w;
    for (int sh = 0; sh < factor; sh++) {
      for (int sw = 0; sw < factor; sw++) {
        for (int h = 0; h < in_h; h++) {
          // Start of the destination row for this (sh, sw, h) combination.
          float* dst_row = dst_plane + (h * factor + sh) * out_w + sw;
          for (int w = 0; w < in_w; w++) {
            dst_row[w * factor] = *src_ptr++;
          }
        }
      }
    }
  }
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers the host PixelShuffleCompute class as the "def" kernel of the
// pixel_shuffle op (kHost target, kFloat precision, kNCHW layout), binding
// input "X" and output "Out" to host tensors.
REGISTER_LITE_KERNEL(pixel_shuffle,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::PixelShuffleCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
// Kernel class for the pixel_shuffle operator on the host target with float
// precision. Only the declaration lives here; Run() is defined out of line.
class PixelShuffleCompute
: public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
// Parameter struct type this kernel is associated with.
// NOTE(review): operators::PixelShuffleParam is presumably made visible via
// op_registry.h; this header does not include the op header directly.
using param_t = operators::PixelShuffleParam;
// Executes the kernel.
void Run() override;
virtual ~PixelShuffleCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -25,6 +25,7 @@ bool PixelShuffleOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.upscale_factor);
const auto x_dims = param_.x->dims();
const auto upscale_factor = param_.upscale_factor;
CHECK_EQ_OR_FALSE(x_dims.size(), 4);
CHECK_EQ_OR_FALSE(x_dims[1] % (upscale_factor * upscale_factor), 0);
return true;
}
......
......@@ -36,6 +36,18 @@ class PixelShuffleOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pixel_shuffle"; }
#ifdef LITE_WITH_PROFILE
// Fills the profiler record `ch` for this op: input and output shapes
// rendered via ch->DimToStr, a remark made of the literal "upscale_factor"
// immediately followed by the factor value, and a MAC count fixed at 1.
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
auto input_dims = param_.x->dims();
auto output_dims = param_.output->dims();
ch->input_shape = ch->DimToStr(input_dims);
ch->output_shape = ch->DimToStr(output_dims);
ch->remark = "upscale_factor" + std::to_string(param_.upscale_factor);
// Constant placeholder value; no per-shape MAC estimate is computed here.
ch->macs = 1;
}
#endif
private:
mutable PixelShuffleParam param_;
};
......
......@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_ctc_align_compute SRCS ctc_align_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
# for training kernel
if (LITE_WITH_TRAIN)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
// Arena test case for the pixel_shuffle op.
//
// Feeds a random NCHW float tensor "X" of shape {n, c, h, w} to the op and
// compares the kernel's "Out" against a reference computed by RunBaseline().
class PixelShuffleComputeTester : public arena::TestCase {
 protected:
  // common attributes for this op.
  std::string input_ = "X";
  std::string output_ = "Out";
  int upscale_factor_ = 3;
  DDim dims_{{2, 27, 20, 30}};

 public:
  // place / alias select the kernel under test; upscale_factor is the op
  // attribute; n, c, h, w give the input shape.
  PixelShuffleComputeTester(const Place& place,
                            const std::string& alias,
                            int upscale_factor,
                            int n,
                            int c,
                            int h,
                            int w)
      : TestCase(place, alias), upscale_factor_(upscale_factor) {
    dims_ = DDim(std::vector<int64_t>({n, c, h, w}));
  }

  // Reference implementation, written as a closed-form index mapping so it
  // is independent of the pointer-walking loop used by the kernels:
  //   out[n, c, oh, ow] =
  //       x[n, c * r * r + (oh % r) * r + (ow % r), oh / r, ow / r]
  // with r = upscale_factor_.
  void RunBaseline(Scope* scope) override {
    auto* out = scope->NewTensor(output_);
    CHECK(out);

    const int64_t factor = upscale_factor_;
    const int64_t batch_size = dims_[0];
    const int64_t out_channels = dims_[1] / (factor * factor);
    const int64_t in_height = dims_[2];
    const int64_t in_width = dims_[3];
    const int64_t out_height = in_height * factor;
    const int64_t out_width = in_width * factor;

    std::vector<int64_t> output_shape{
        batch_size, out_channels, out_height, out_width};
    DDim output_dims(output_shape);
    out->Resize(output_dims);
    auto* output_data = out->mutable_data<float>();

    auto* x = scope->FindTensor(input_);
    const auto* x_data = x->data<float>();

    const int64_t in_plane = in_height * in_width;
    // One output plane has exactly as many elements as the factor * factor
    // input planes that feed it, so both sides advance by the same stride.
    const int64_t out_plane = out_height * out_width;

    for (int64_t nc = 0; nc < batch_size * out_channels; ++nc) {
      const float* in_group = x_data + nc * out_plane;
      float* out_ptr = output_data + nc * out_plane;
      for (int64_t oh = 0; oh < out_height; ++oh) {
        for (int64_t ow = 0; ow < out_width; ++ow) {
          // Which of the factor * factor input planes supplies this pixel.
          const int64_t sub_channel = (oh % factor) * factor + (ow % factor);
          const int64_t in_offset = sub_channel * in_plane +
                                    (oh / factor) * in_width + (ow / factor);
          out_ptr[oh * out_width + ow] = in_group[in_offset];
        }
      }
    }
  }

  // Describes the op instance under test: type, I/O names and attribute.
  void PrepareOpDesc(cpp::OpDesc* op_desc) override {
    op_desc->SetType("pixel_shuffle");
    op_desc->SetInput("X", {input_});
    op_desc->SetOutput("Out", {output_});
    op_desc->SetAttr("upscale_factor", upscale_factor_);
  }

  // Fills the input tensor with random values drawn from [-1, 1].
  void PrepareData() override {
    std::vector<float> din(dims_.production());
    fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
    SetCommonTensor(input_, dims_, din.data());
  }
};
// Precision test for pixel_shuffle. Only the ARM kernel is exercised, and
// only when LITE_WITH_ARM is defined; otherwise the test just logs.
TEST(PixelShuffle, precision) {
  LOG(INFO) << "test pixel_shuffle op";
#ifdef LITE_WITH_ARM
  LOG(INFO) << "test pixel_shuffle arm";
  Place place(TARGET(kARM));

  // Builds one test case and checks the kernel output against the baseline
  // with an absolute tolerance of 2e-5.
  auto check_case = [&place](int factor, int n, int c, int h, int w) {
    std::unique_ptr<arena::TestCase> tester(
        new PixelShuffleComputeTester(place, "def", factor, n, c, h, w));
    arena::Arena arena(std::move(tester), place, 2e-5);
    arena.TestPrecision();
  };

  for (int factor : {1, 2, 3, 4, 5}) {
    const int factor_sq = factor * factor;
    for (int n : {1, 3}) {
      // Input channel counts are multiples of factor^2 (3x and 6x).
      for (int c : {3 * factor_sq, 6 * factor_sq}) {
        for (int h : {9, 18}) {
          for (int w : {9, 18}) {
            check_case(factor, n, c, h, w);
          }
        }
      }
    }
  }
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册