diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 7a62ad980560d46bf425dc41fd3b035ee693a7f8..76f223c3d591a0a03dddde021d6ca2a2dc977c68 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -111,6 +111,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS}) add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS}) add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS}) +add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/pixel_shuffle_op.cc b/lite/operators/pixel_shuffle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..40f564bdd6d2699bafe497bdfded21ea4f3956a3 --- /dev/null +++ b/lite/operators/pixel_shuffle_op.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/pixel_shuffle_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool PixelShuffleOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(param_.upscale_factor); + const auto x_dims = param_.x->dims(); + const auto upscale_factor = param_.upscale_factor; + CHECK_EQ_OR_FALSE(x_dims[1] % (upscale_factor * upscale_factor), 0); + return true; +} + +bool PixelShuffleOpLite::InferShapeImpl() const { + const auto x_dims = param_.x->dims(); + const auto upscale_factor = param_.upscale_factor; + auto output_dims = x_dims; + output_dims[0] = x_dims[0]; + output_dims[1] = x_dims[1] / (upscale_factor * upscale_factor); + output_dims[2] = x_dims[2] * upscale_factor; + output_dims[3] = x_dims[3] * upscale_factor; + param_.output->Resize(output_dims); + return true; +} + +bool PixelShuffleOpLite::AttachImpl(const cpp::OpDesc& opdesc, + lite::Scope* scope) { + auto input = opdesc.Input("X").front(); + auto out = opdesc.Output("Out").front(); + + param_.x = scope->FindVar(input)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + + if (opdesc.HasAttr("upscale_factor")) { + param_.upscale_factor = opdesc.GetAttr("upscale_factor"); + } + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(pixel_shuffle, paddle::lite::operators::PixelShuffleOpLite); diff --git a/lite/operators/pixel_shuffle_op.h b/lite/operators/pixel_shuffle_op.h new file mode 100644 index 0000000000000000000000000000000000000000..63efd8df778c6d92bc448f795c19ff5bffba62c8 --- /dev/null +++ b/lite/operators/pixel_shuffle_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class PixelShuffleOpLite : public OpLite { + public: + PixelShuffleOpLite() {} + explicit PixelShuffleOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "pixel_shuffle"; } + + private: + mutable PixelShuffleParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle