diff --git a/lite/kernels/mlu/layout_compute.cc b/lite/kernels/mlu/layout_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..85d89c258cd21dfd1b847d325308d5653ffa3be7 --- /dev/null +++ b/lite/kernels/mlu/layout_compute.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.ddNod +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/mlu/layout_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace mlu {} // namespace mlu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + layout, + kMLU, + kFloat, + kNHWC, + paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute, + def_layout_nhwc2nchw_fp32) + .BindInput("Inputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Outputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + layout, + kMLU, + kFP16, + kNHWC, + paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute, + def_layout_nhwc2nchw_fp16) + .BindInput("Inputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Outputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + layout, + kMLU, + kFloat, + kNHWC, + paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute, + def_layout_nchw2nhwc_fp32) + .BindInput("Inputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Outputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + layout, + kMLU, + kFP16, + kNHWC, + paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute, + def_layout_nchw2nhwc_fp16) + .BindInput("Inputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Outputs", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); diff --git a/lite/kernels/mlu/layout_compute.h b/lite/kernels/mlu/layout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..48fa485163139b5c382b770adc4025c3bbe99822 --- /dev/null +++ b/lite/kernels/mlu/layout_compute.h @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/backends/x86/math/math_function.h" +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" +#include "lite/operators/layout_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace mlu { + +template +inline void LayoutTransCompute(const int dim, + const lite::Context& context, + const lite::Tensor& in, + lite::Tensor* out, + const std::vector& axis) { + switch (dim) { + case 2: + paddle::lite::x86::math::Transpose trans2; + trans2(context, in, out, axis); + break; + case 3: + paddle::lite::x86::math::Transpose trans3; + trans3(context, in, out, axis); + break; + case 4: + paddle::lite::x86::math::Transpose trans4; + trans4(context, in, out, axis); + break; + default: + CHECK(0) << ("Unsupport dim in mlu layout"); + } +} + +template +class LayoutNchwToNhwcCompute + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void Run() override { + auto& param = this->template Param(); + auto* x = param.x; + auto* out = param.y; + out->template mutable_data(); + auto x_dims = param.x->dims().size(); + auto& context = this->ctx_->template As(); + + std::vector axis; + switch (x_dims) { + case 2: + axis = {0, 1}; + break; + case 3: + axis = {0, 2, 1}; + out->Resize(std::vector{ + out->dims()[0], out->dims()[2], out->dims()[1]}); + break; + case 4: + axis = {0, 2, 3, 1}; + out->Resize(std::vector{ + out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]}); + break; + default: + CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc"; + } + + LayoutTransCompute( + x_dims, context, *x, out, axis); + ) + } + + std::string doc() const override { + return "Mlu layout transform nchw to nhwc"; + } +}; + +template +class LayoutNhwcToNchwCompute + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void Run() override { + auto& param = this->template Param(); + auto* x = param.x; + auto* out = param.y; + out->template mutable_data(); + auto x_dims = param.x->dims().size(); + auto& context = this->ctx_->template As(); + + std::vector axis; + switch (x_dims) { + case 2: + axis = {0, 1}; + break; + case 3: + axis = {0, 2, 1}; + out->Resize(std::vector{ + out->dims()[0], out->dims()[2], out->dims()[1]}); + break; + case 4: + axis = {0, 3, 1, 2}; + out->Resize(std::vector{ + out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]}); + break; + default: + CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw"; + } + + LayoutTransCompute( + x_dims, context, *x, out, axis); + ) + } + + std::string doc() const override { + return "Mlu layout transform nhwc to nchw"; + } +}; + +} // namespace mlu +} // namespace kernels +} // namespace lite +} // namespace paddle