未验证 提交 2f317a5e 编写于 作者: X xiebaiyuan 提交者: GitHub

[OPENCL] add pixel_shuffle operator, test=develop (#3806)

上级 80285e85
...@@ -111,6 +111,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal ...@@ -111,6 +111,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS}) add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS}) add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS}) add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS})
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pixel_shuffle_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool PixelShuffleOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.upscale_factor);
const auto x_dims = param_.x->dims();
const auto upscale_factor = param_.upscale_factor;
CHECK_EQ_OR_FALSE(x_dims[1] % (upscale_factor * upscale_factor), 0);
return true;
}
bool PixelShuffleOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims();
const auto upscale_factor = param_.upscale_factor;
auto output_dims = x_dims;
output_dims[0] = x_dims[0];
output_dims[1] = x_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = x_dims[2] * upscale_factor;
output_dims[3] = x_dims[3] * upscale_factor;
param_.output->Resize(output_dims);
return true;
}
bool PixelShuffleOpLite::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto input = opdesc.Input("X").front();
auto out = opdesc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
if (opdesc.HasAttr("upscale_factor")) {
param_.upscale_factor = opdesc.GetAttr<int>("upscale_factor");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(pixel_shuffle, paddle::lite::operators::PixelShuffleOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class PixelShuffleOpLite : public OpLite {
public:
PixelShuffleOpLite() {}
explicit PixelShuffleOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pixel_shuffle"; }
private:
mutable PixelShuffleParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册