Commit 53ef685e authored by J jackzhang235

add mlu layout transform file

Parent c5e83404
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/layout_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
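// Four kernel registrations follow: NHWC->NCHW and NCHW->NHWC layout
// transforms, each registered in an fp32 and an fp16 variant.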
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFloat)>,
def_layout_nhwc2nchw_fp32)
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFP16)>,
def_layout_nhwc2nchw_fp16)
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFloat)>,
def_layout_nchw2nhwc_fp32)
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFP16)>,
def_layout_nchw2nhwc_fp16)
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <string>
#include <vector>
#include "lite/backends/x86/math/math_function.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/layout_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {
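// Dispatch a rank-2, -3, or -4 transpose to the x86 math backend. `axis` is
// the permutation applied to the input dims, e.g. {0, 2, 3, 1} reorders NCHW
// into NHWC.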
template <lite::TargetType Target, typename T>
inline void LayoutTransCompute(const int dim,
const lite::Context<Target>& context,
const lite::Tensor& in,
lite::Tensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 2:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 2> trans2;
trans2(context, in, out, axis);
break;
case 3:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 3> trans3;
trans3(context, in, out, axis);
break;
case 4:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 4> trans4;
trans4(context, in, out, axis);
break;
default:
CHECK(0) << "Unsupported dim in mlu layout";
}
}
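// Kernel that transposes a tensor from NCHW to NHWC layout. The output is
// resized to the permuted shape before the transpose runs; e.g. dims
// {N, C, H, W} with axis {0, 2, 3, 1} become {N, H, W, C}.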
template <PrecisionType Precision>
class LayoutNchwToNhwcCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
public:
using param_t = operators::LayoutParam;
void Run() override {
auto& param = this->template Param<param_t>();
auto* x = param.x;
auto* out = param.y;
out->template mutable_data<float>();
auto x_dims = param.x->dims().size();
auto& context = this->ctx_->template As<X86Context>();
std::vector<int> axis;
switch (x_dims) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[1]});
break;
case 4:
axis = {0, 2, 3, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]});
break;
default:
CHECK(0) << "Unsupported dim in mlu layout nchw to nhwc";
}
    LayoutTransCompute<lite::TargetType::kX86, float>(
        x_dims, context, *x, out, axis);
}
std::string doc() const override {
return "Mlu layout transform nchw to nhwc";
}
};
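// Kernel that transposes a tensor from NHWC to NCHW layout, the inverse of
// the kernel above; for rank-4 input it uses the permutation {0, 3, 1, 2}.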
template <PrecisionType Precision>
class LayoutNhwcToNchwCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
public:
using param_t = operators::LayoutParam;
void Run() override {
auto& param = this->template Param<param_t>();
auto* x = param.x;
auto* out = param.y;
out->template mutable_data<float>();
auto x_dims = param.x->dims().size();
auto& context = this->ctx_->template As<X86Context>();
std::vector<int> axis;
switch (x_dims) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[1]});
break;
case 4:
axis = {0, 3, 1, 2};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]});
break;
default:
CHECK(0) << "Unsupported dim in mlu layout nhwc to nchw";
}
    LayoutTransCompute<lite::TargetType::kX86, float>(
        x_dims, context, *x, out, axis);
}
std::string doc() const override {
return "Mlu layout transform nhwc to nchw";
}
};
} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle