Unverified commit a2547758, authored by M MaxwellDing, committed by GitHub

feat: Mlu cast kernel (#111)

Parent 0e1f6cb0
......@@ -30,6 +30,8 @@ namespace mir {
static thread_local int g_stream_id = 0;
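// ENABLE_HOST_CAST selects where the inserted precision/layout casts run
// (inferred from the #if branches below): true keeps them on the host target
// with io_copy inserted last; false inserts io_copy first and runs the casts
// on the MLU device.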
#define ENABLE_HOST_CAST false
Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
const std::string& cast_arg_name,
SSAGraph* graph,
......@@ -77,8 +79,17 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
for (auto& kernel : kernels) {
if (op_type == "cast") {
const Type* in_arg_ty = kernel->GetInputDeclType("X");
#if !ENABLE_HOST_CAST
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
#endif
if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type)) {
#if !ENABLE_HOST_CAST
PrecisionCompatibleTo(*out_arg_ty, *cast_type) &&
TargetCompatibleTo(*out_arg_ty, *cast_type)
#else
DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type)
#endif
) {
is_found = true;
}
} else if (op_type == "layout") {
......@@ -86,6 +97,11 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
DataLayoutCompatible(*out_arg_ty, *cast_type) &&
#if !ENABLE_HOST_CAST
TargetCompatibleTo(*out_arg_ty, *cast_type) &&
#else
TargetCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
#endif
// for first conv
PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
......@@ -95,8 +111,10 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TargetCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
TargetCompatibleTo(*out_arg_ty, *cast_type) &&
PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatible(*out_arg_ty, *cast_type)) {
#if ENABLE_HOST_CAST
PrecisionCompatible(*out_arg_ty, *cast_type) &&
#endif
PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else {
......@@ -170,7 +188,15 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
for (auto& kernel : kernels) {
if (op_type == "cast") {
const Type* in_arg_ty = kernel->GetInputDeclType("X");
if (PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
#if !ENABLE_HOST_CAST
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
#endif
if (
#if !ENABLE_HOST_CAST
PrecisionCompatibleTo(*out_arg_ty, *cur_node->AsArg().type) &&
TargetCompatibleTo(*in_arg_ty, *cast_type) &&
#endif
PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "layout") {
......@@ -178,6 +204,11 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type) &&
#if !ENABLE_HOST_CAST
TargetCompatibleTo(*in_arg_ty, *cast_type) &&
#else
TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type) &&
#endif
PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
is_found = true;
}
......@@ -186,8 +217,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TargetCompatibleTo(*in_arg_ty, *cast_type) &&
TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type) &&
#if !ENABLE_HOST_CAST
PrecisionCompatible(*out_arg_ty, *cur_node->AsArg().type)
#else
PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatible(*out_arg_ty, *cast_type)) {
PrecisionCompatible(*out_arg_ty, *cast_type)
#endif
) {
is_found = true;
}
} else {
......@@ -214,8 +250,7 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
Node* head_node,
Node* inst_node,
const Type* inst_type,
bool use_mlu_cast) {
const Type* inst_type) {
const auto* head_type = head_node->AsArg().type;
// break original link
......@@ -230,31 +265,46 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
head_node->AsArg().name) != first_conv_nodes_.end();
// precision cast node
if (!use_mlu_cast) {
if (!fuse_cast_) {
#if !ENABLE_HOST_CAST
// io copy
cur_node = InsertCastBefore(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), head_type->precision(), head_type->layout()));
#endif
if (head_type->precision() != inst_type->precision() &&
!is_first_conv_head) {
cur_node = InsertCastBefore("cast",
name_prefix + "cast",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(head_type->target(),
inst_type->precision(),
head_type->layout()));
#if ENABLE_HOST_CAST
auto type = LiteType::GetTensorTy(
head_type->target(), inst_type->precision(), head_type->layout());
#else
auto type = LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), head_type->layout());
#endif
cur_node = InsertCastBefore(
"cast", name_prefix + "cast", graph, cur_node, inst_node, type);
}
// layout cast node
if (head_type->layout() != inst_type->layout()) {
cur_node = InsertCastBefore("layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(head_type->target(),
inst_type->precision(),
inst_type->layout()));
#if ENABLE_HOST_CAST
auto type = LiteType::GetTensorTy(
head_type->target(), inst_type->precision(), inst_type->layout());
#else
auto type = LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout());
#endif
cur_node = InsertCastBefore(
"layout", name_prefix + "layout", graph, cur_node, inst_node, type);
}
#if ENABLE_HOST_CAST
// io copy
cur_node = InsertCastBefore(
"io_copy",
......@@ -264,6 +314,7 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
#endif
} else {
// io copy
cur_node = InsertCastBefore(
......@@ -380,8 +431,7 @@ bool MLUPostprocessPass::NeedInsert(Node* node, const Type* inst_type) {
void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
Node* tail_node,
Node* inst_node,
const Type* inst_type,
bool use_mlu_cast) {
const Type* inst_type) {
const auto* tail_type = tail_node->AsArg().type;
// break original link
......@@ -392,30 +442,45 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
tail_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
// precision cast node
if (!use_mlu_cast) {
if (!fuse_cast_) {
#if !ENABLE_HOST_CAST
// io copy
cur_node = InsertCastAfter(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), tail_type->precision(), tail_type->layout()));
#endif
if (tail_type->precision() != inst_type->precision()) {
cur_node = InsertCastAfter("cast",
name_prefix + "cast",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(tail_type->target(),
inst_type->precision(),
tail_type->layout()));
#if ENABLE_HOST_CAST
auto type = LiteType::GetTensorTy(
tail_type->target(), inst_type->precision(), tail_type->layout());
#else
auto type = LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), tail_type->layout());
#endif
cur_node = InsertCastAfter(
"cast", name_prefix + "cast", graph, cur_node, inst_node, type);
}
// layout cast node
if (tail_type->layout() != inst_type->layout()) {
cur_node = InsertCastAfter("layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(tail_type->target(),
inst_type->precision(),
inst_type->layout()));
#if ENABLE_HOST_CAST
auto type = LiteType::GetTensorTy(
tail_type->target(), inst_type->precision(), inst_type->layout());
#else
auto type = LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout());
#endif
cur_node = InsertCastAfter(
"layout", name_prefix + "layout", graph, cur_node, inst_node, type);
}
#if ENABLE_HOST_CAST
// io copy
cur_node = InsertCastAfter(
"io_copy",
......@@ -425,6 +490,7 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
#endif
} else {
cur_node = InsertCastAfter(
"io_copy",
......@@ -549,6 +615,14 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
in_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
} else {
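// If the producer kernel of this input already runs on the MLU, adopt its
// precision/layout here instead of forcing the host-side type below.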
if (!in_node->inlinks.empty()) {
auto& upkernel = in_node->inlinks.front()->AsStmt().picked_kernel();
if (upkernel.target() == TARGET(kMLU)) {
in_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kMLU), upkernel.precision(), upkernel.layout());
continue;
}
}
CHECK((in_node_type->target() == TARGET(kHost) ||
in_node_type->target() == TARGET(kX86)) &&
(in_node_type->precision() == PRECISION(kFloat) ||
......@@ -585,6 +659,13 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
VLOG(4) << "unused output node type: " << out_arg.name
<< out_node_type->name();
} else {
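// Likewise, if the consumer kernel of this output runs on the MLU, adopt
// its precision/layout rather than casting back to a host type.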
auto& downkernel =
out_node->outlinks.front()->AsStmt().picked_kernel();
if (downkernel.target() == TARGET(kMLU)) {
out_arg.type = LiteType::GetTensorTy(
TARGET(kMLU), downkernel.precision(), downkernel.layout());
continue;
}
out_arg.type = LiteType::GetTensorTy(
TARGET(kHost), out_node_type->precision(), DATALAYOUT(kNCHW));
VLOG(4) << "output node type: " << out_arg.name
......@@ -840,7 +921,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
op->SetSubBlock(new_block_desc);
}
void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
void ModifyValidPlaces(SSAGraph* graph, bool fuse_cast) {
// remove invalid places, since only support X86, host, MLU
auto v_places = graph->valid_places();
for (auto it = v_places.begin(); it != v_places.end();) {
......@@ -852,23 +933,27 @@ void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
}
}
if (use_mlu_cast) {
if (fuse_cast) {
// insert MLU float place for float io_copy; no effect on the subgraph type
v_places.emplace_back(TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC));
} else {
// add x86 NHWC place for cpu cast
#if !ENABLE_HOST_CAST
// add MLU NCHW place for cast kernel
std::set<paddle::lite_api::PrecisionType> prec_set{};
for (auto& place : v_places) {
prec_set.insert(place.precision);
}
prec_set.insert(PRECISION(kFloat));
#ifdef LITE_WITH_MLU
if (lite::TargetWrapperMlu::UseFirstConv()) {
prec_set.insert(PRECISION(kInt8));
}
#endif
for (auto& prec : prec_set) {
v_places.emplace_back(TARGET(kX86), prec, DATALAYOUT(kNHWC));
v_places.emplace_back(TARGET(kMLU), prec, DATALAYOUT(kNCHW));
}
v_places.emplace_back(TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC));
#endif
}
graph->SetValidPlaces(v_places);
......@@ -899,29 +984,27 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
#endif
g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
bool disable_mlu_cast = GetBoolFromEnv("LITE_DISABLE_MLU_CAST");
ModifyValidPlaces(graph.get(), !disable_mlu_cast);
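// LITE_MLU_FUSE_CAST=true fuses the precision cast into the subgraph itself
// (see AdjustSubgraph below); otherwise standalone cast/layout/io_copy
// kernels are inserted around the subgraph.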
fuse_cast_ = GetBoolFromEnv("LITE_MLU_FUSE_CAST");
ModifyValidPlaces(graph.get(), fuse_cast_);
// insert io_copy, layout and precision cast of subgraph's inputs and outputs
for (auto& node : graph->mutable_nodes()) {
if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
const Type* subgraph_arg_type = nullptr;
GetSubgraphOpArgType(&node, &subgraph_arg_type, graph.get());
if (!disable_mlu_cast) {
if (fuse_cast_) {
AdjustSubgraph(&node, subgraph_arg_type);
}
auto links_tmp = node.inlinks;
for (auto p_in : links_tmp) {
if (NeedInsert(p_in, subgraph_arg_type)) {
InsertBefore(
graph.get(), p_in, &node, subgraph_arg_type, !disable_mlu_cast);
InsertBefore(graph.get(), p_in, &node, subgraph_arg_type);
}
}
links_tmp.assign(node.outlinks.begin(), node.outlinks.end());
for (auto p_out : links_tmp) {
if (NeedInsert(p_out, subgraph_arg_type)) {
InsertAfter(
graph.get(), p_out, &node, subgraph_arg_type, !disable_mlu_cast);
InsertAfter(graph.get(), p_out, &node, subgraph_arg_type);
}
}
}
......
......@@ -88,14 +88,12 @@ class MLUPostprocessPass : public ProgramPass {
void InsertBefore(SSAGraph* graph,
Node* head_node,
Node* inst_node,
const Type* type,
bool use_mlu_cast);
const Type* type);
void InsertAfter(SSAGraph* graph,
Node* tail_node,
Node* inst_node,
const Type* type,
bool use_mlu_cast);
const Type* type);
Node* InsertCastBefore(const std::string& op_type,
const std::string& cast_arg_name,
......@@ -123,6 +121,7 @@ class MLUPostprocessPass : public ProgramPass {
private:
std::set<std::string> first_conv_nodes_;
bool fuse_cast_{false};
};
} // namespace mir
......
......@@ -7,4 +7,5 @@ add_kernel(subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS ${lite_k
add_kernel(io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} ${target_wrapper_mlu})
add_kernel(calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS ${lite_kernel_deps})
# depend on transpose function in backend/x86/math/math_function
add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_function})
add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_function} ${target_wrapper_mlu})
add_kernel(cast_compute_mlu MLU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} ${target_wrapper_mlu})
......@@ -180,30 +180,35 @@ template <paddle::lite_api::PrecisionType>
struct MLUTypeTraits {
/* using type = void; */
/* static constexpr cnmlDataType_t cnml_type = CNML_DATA_INVALID; */
/* static constexpr int proto_type = 17; */
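/* Note: proto_type is assumed to follow framework::proto::VarType::Type
 * (FP32 = 5, FP16 = 4, INT8 = 21, INT32 = 2) so it can be compared against
 * CastParam::in_dtype / out_dtype. */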
};
template <>
struct MLUTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
using type = float;
static constexpr cnmlDataType_t cnml_type = CNML_DATA_FLOAT32;
static constexpr int proto_type = 5;
};
template <>
struct MLUTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
using type = paddle::lite::fluid::float16;
static constexpr cnmlDataType_t cnml_type = CNML_DATA_FLOAT16;
static constexpr int proto_type = 4;
};
template <>
struct MLUTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
using type = int8_t;
static constexpr cnmlDataType_t cnml_type = CNML_DATA_INT8;
static constexpr int proto_type = 21;
};
template <>
struct MLUTypeTraits<paddle::lite_api::PrecisionType::kInt32> {
using type = int32_t;
static constexpr cnmlDataType_t cnml_type = CNML_DATA_INT32;
static constexpr int proto_type = 2;
};
} // namespace mlu
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/cast_compute.h"
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
REGISTER_LITE_KERNEL(cast,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::CastFp32toFp16,
fp32_to_fp16)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(cast,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::CastFp16toFp32,
fp16_to_fp32)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <vector>
#include "lite/backends/mlu/mlu_utils.h"
#include "lite/core/kernel.h"
#include "lite/kernels/mlu/bridges/tensor.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/mlu/mlu_operator.h"
#include "lite/operators/cast_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {
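// Element-wise precision cast executed on the MLU: a CNML cast op is
// compiled lazily for each distinct input shape and cached in inst_map_, so
// later runs with the same shape reuse the compiled op.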
template <lite_api::PrecisionType in_dtype, lite_api::PrecisionType out_dtype>
class CastCompute
: public KernelLite<TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::CastParam;
void Run() override {
auto param = param_.get_mutable<param_t>();
auto& mlu_context = this->ctx_->template As<MLUContext>();
auto in_dims = param->X->dims().Vectorize();
// the input shape serves as the key into the compiled-op cache
std::vector<int> ishape;
std::transform(in_dims.cbegin(),
in_dims.cend(),
std::back_inserter(ishape),
[](DDim::value_type in) { return static_cast<int>(in); });
// find compiled instruction at ishape
auto op_iter = inst_map_.find(ishape);
if (op_iter == inst_map_.end()) {
auto res = inst_map_.insert(
{ishape, CompileOperator(param, &mlu_context, ishape)});
CHECK(res.second);
op_iter = res.first;
}
// prepare param
auto exec_queue = mlu_context.exec_queue();
cnrtInvokeFuncParam_t forward_param = mlu_context.forward_param();
int data_param = 1;
forward_param.data_parallelism = &data_param;
u32_t affinity = mlu_context.affinity();
forward_param.affinity = &affinity;
forward_param.end = CNRT_PARAM_END;
// get input and output
param->Out->set_precision(out_dtype);
const void* input = param->X->template data<
typename subgraph::mlu::MLUTypeTraits<in_dtype>::type>();
/* void* output = param->Out->mutable_data(TARGET(kMLU), out_size); */
void* output = param->Out->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<out_dtype>::type>(TARGET(kMLU));
// compute op
CNML_CALL(cnmlComputeCastOpForward_V3(op_iter->second->cnml_op,
const_cast<void*>(input),
output,
&forward_param,
exec_queue));
}
~CastCompute() override{};
private:
inline cnmlCastType_t GetCastType(param_t* param) {
CHECK_EQ(subgraph::mlu::MLUTypeTraits<in_dtype>::proto_type,
param->in_dtype);
CHECK_EQ(subgraph::mlu::MLUTypeTraits<out_dtype>::proto_type,
param->out_dtype);
if (in_dtype == PRECISION(kFP16) && out_dtype == PRECISION(kFloat)) {
VLOG(4) << "choose float16 to float32";
return CNML_CAST_FLOAT16_TO_FLOAT32;
} else if (in_dtype == PRECISION(kFloat) && out_dtype == PRECISION(kFP16)) {
VLOG(4) << "choose float32 to float16";
return CNML_CAST_FLOAT32_TO_FLOAT16;
} else {
CHECK(0) << "Unsupported cast type";
}
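// unreachable: CHECK(0) above aborts; the return below presumably only
// silences the compiler's missing-return warning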
return CNML_CAST_FLOAT32_TO_FLOAT16;
}
std::shared_ptr<MLUOperator> CompileOperator(param_t* param,
MLUContext* ctx,
std::vector<int> dims) {
VLOG(4) << "compile cast operator";
// get cast type
auto cast_type = GetCastType(param);
// prepare op and io tensor
auto op = std::make_shared<MLUOperator>();
op->input_tensors.emplace_back();
op->output_tensors.emplace_back();
int* dim_strides = nullptr;
CNML_CALL(cnmlCreateTensor_V2(&op->input_tensors[0], CNML_TENSOR));
CNML_CALL(cnmlSetTensorShape_V2(
op->input_tensors[0], dims.size(), dims.data(), dim_strides));
CNML_CALL(cnmlSetTensorDataType(
op->input_tensors[0],
subgraph::mlu::MLUTypeTraits<in_dtype>::cnml_type));
CNML_CALL(cnmlCreateTensor_V2(&op->output_tensors[0], CNML_TENSOR));
CNML_CALL(cnmlSetTensorShape_V2(
op->output_tensors[0], dims.size(), dims.data(), dim_strides));
CNML_CALL(cnmlSetTensorDataType(
op->output_tensors[0],
subgraph::mlu::MLUTypeTraits<out_dtype>::cnml_type));
CNML_CALL(cnmlCreateCastOp(
&op->cnml_op, cast_type, op->input_tensors[0], op->output_tensors[0]));
CNML_CALL(cnmlSetBaseOpCorenum(op->cnml_op, ctx->MLUCoreNumber()));
CNML_CALL(cnmlSetBaseOpCoreVersion(op->cnml_op, ctx->MLUCoreVersion()));
CNML_CALL(cnmlCompileBaseOp_V2(op->cnml_op));
return op;
}
private:
std::map<std::vector<int>, std::shared_ptr<MLUOperator>> inst_map_;
};
using CastFp32toFp16 =
paddle::lite::kernels::mlu::CastCompute<PRECISION(kFloat),
PRECISION(kFP16)>;
using CastFp16toFp32 =
paddle::lite::kernels::mlu::CastCompute<PRECISION(kFP16),
PRECISION(kFloat)>;
} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -14,14 +14,7 @@
#include "lite/kernels/mlu/layout_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
// X86 layout kernel
REGISTER_LITE_KERNEL(
layout,
kX86,
......@@ -106,3 +99,89 @@ REGISTER_LITE_KERNEL(
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
// MLU layout kernel
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNHWC2NCHW<PRECISION(kFloat)>,
def_layout_nhwc2nchw_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNHWC2NCHW<PRECISION(kFP16)>,
def_layout_nhwc2nchw_fp16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNCHW2NHWC<PRECISION(kFloat)>,
def_layout_nchw2nhwc_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNCHW2NHWC<PRECISION(kFP16)>,
def_layout_nchw2nhwc_fp16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kInt8,
kNHWC,
paddle::lite::kernels::mlu::LayoutNCHW2NHWC<PRECISION(kInt8)>,
def_layout_nchw2nhwc_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
......@@ -15,6 +15,9 @@
#pragma once
#include <Eigen/Core>
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/x86/math/math_function.h"
......@@ -23,6 +26,7 @@
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/mlu/mlu_operator.h"
#include "lite/operators/layout_op.h"
namespace paddle {
......@@ -151,6 +155,125 @@ class LayoutNhwcToNchwCompute
}
};
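// Layout transform on the MLU: transposes between NCHW and NHWC (direction
// given by the in_layout template argument) with a CNML NdTranspose op,
// compiled once per input shape and cached in inst_map_.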
template <PrecisionType Precision, DataLayoutType in_layout>
class LayoutComputeMlu
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
public:
using param_t = operators::LayoutParam;
void Run() override {
auto& param = this->template Param<param_t>();
auto* x = param.x;
auto* y = param.y;
auto in_dims = x->dims().Vectorize();
y->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>();
auto& context = this->ctx_->template As<MLUContext>();
// the input shape serves as the key into the compiled-op cache
std::vector<int> ishape;
std::transform(in_dims.cbegin(),
in_dims.cend(),
std::back_inserter(ishape),
[](DDim::value_type in) { return static_cast<int>(in); });
// find compiled instruction at ishape
auto op_iter = inst_map_.find(ishape);
if (op_iter == inst_map_.end()) {
auto res =
inst_map_.insert({ishape, CompileOperator(&param, &context, ishape)});
CHECK(res.second);
op_iter = res.first;
}
// prepare param
auto exec_queue = context.exec_queue();
cnrtInvokeFuncParam_t forward_param = context.forward_param();
int data_param = 1;
forward_param.data_parallelism = &data_param;
u32_t affinity = context.affinity();
forward_param.affinity = &affinity;
forward_param.end = CNRT_PARAM_END;
// get input and output
auto mem_size = x->memory_size();
y->set_precision(Precision);
const void* input = x->template data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>();
void* output = y->mutable_data(TARGET(kMLU), mem_size);
// compute op
CNML_CALL(cnmlComputeNdTransposeProOpForward(op_iter->second->cnml_op,
const_cast<void*>(input),
output,
&forward_param,
exec_queue));
}
std::string doc() const override { return "Mlu layout transform"; }
private:
std::shared_ptr<MLUOperator> CompileOperator(param_t* param,
MLUContext* ctx,
std::vector<int> dims) {
VLOG(4) << "compile layout operator";
// get transpose axis
std::vector<int> axis;
std::vector<int> in_dims, out_dims;
if (in_layout == DATALAYOUT(kNCHW)) {
VLOG(4) << "trans layout from NCHW to NHWC";
axis = subgraph::mlu::GetAxisNCHW2NHWC<int>(dims.size());
in_dims = dims;
out_dims = subgraph::mlu::DimNCHW2NHWC(dims);
} else {
VLOG(4) << "trans layout from NHWC to NCHW";
axis = subgraph::mlu::GetAxisNHWC2NCHW<int>(dims.size());
in_dims = subgraph::mlu::DimNCHW2NHWC(dims);
out_dims = dims;
}
// prepare op and io tensor
auto op = std::make_shared<MLUOperator>();
op->input_tensors.emplace_back();
op->output_tensors.emplace_back();
int* dim_strides = nullptr;
CNML_CALL(cnmlCreateTensor_V2(&op->input_tensors[0], CNML_TENSOR));
CNML_CALL(cnmlSetTensorShape_V2(
op->input_tensors[0], in_dims.size(), in_dims.data(), dim_strides));
CNML_CALL(cnmlSetTensorDataType(
op->input_tensors[0],
subgraph::mlu::MLUTypeTraits<Precision>::cnml_type));
CNML_CALL(cnmlCreateTensor_V2(&op->output_tensors[0], CNML_TENSOR));
CNML_CALL(cnmlSetTensorShape_V2(
op->output_tensors[0], out_dims.size(), out_dims.data(), dim_strides));
CNML_CALL(cnmlSetTensorDataType(
op->output_tensors[0],
subgraph::mlu::MLUTypeTraits<Precision>::cnml_type));
cnmlNdTransposeOpParam_t transpose_param;
CNML_CALL(cnmlCreateNdTransposeOpParam(
&transpose_param, axis.data(), axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&op->cnml_op,
op->input_tensors[0],
op->output_tensors[0],
transpose_param));
CNML_CALL(cnmlDestroyNdTransposeOpParam(&transpose_param));
CNML_CALL(cnmlSetBaseOpCorenum(op->cnml_op, ctx->MLUCoreNumber()));
CNML_CALL(cnmlSetBaseOpCoreVersion(op->cnml_op, ctx->MLUCoreVersion()));
CNML_CALL(cnmlCompileBaseOp_V2(op->cnml_op));
return op;
}
std::map<std::vector<int>, std::shared_ptr<MLUOperator>> inst_map_;
};
template <PrecisionType precision>
using LayoutNHWC2NCHW = LayoutComputeMlu<precision, DATALAYOUT(kNHWC)>;
template <PrecisionType precision>
using LayoutNCHW2NHWC = LayoutComputeMlu<precision, DATALAYOUT(kNCHW)>;
} // namespace mlu
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/mlu/mlu_utils.h"
#include "lite/kernels/mlu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {
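// Owns a compiled cnmlBaseOp_t plus its compile-time input/output tensors
// and releases them in the destructor.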
struct MLUOperator {
cnmlBaseOp_t cnml_op = nullptr;
// compile time tensor
std::vector<cnmlTensor_t> input_tensors{};
std::vector<cnmlTensor_t> output_tensors{};
~MLUOperator() {
if (cnml_op != nullptr) {
CNML_CALL(cnmlDestroyBaseOp(&cnml_op));
cnml_op = nullptr;
}
if (!input_tensors.empty()) {
std::for_each(input_tensors.begin(),
input_tensors.end(),
[](cnmlTensor_t t) { CNML_CALL(cnmlDestroyTensor(&t)); });
input_tensors.clear();
}
if (!output_tensors.empty()) {
std::for_each(output_tensors.begin(),
output_tensors.end(),
[](cnmlTensor_t t) { CNML_CALL(cnmlDestroyTensor(&t)); });
output_tensors.clear();
}
}
};
} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -142,6 +142,18 @@ class SubgraphEngine : public subgraph::Engine {
return BuildDeviceProgramImpl();
}
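// Returns the first op in the subgraph block that lists input_name among its
// inputs (nullptr if none); used below to choose the data order per input.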
cpp::OpDesc* GetFirstNode(const std::string& input_name) {
for (size_t i = 0; i < block_desc_->OpsSize(); ++i) {
auto desc = block_desc_->GetOp<cpp::OpDesc>(i);
auto inputs = desc->input_vars();
if (std::find(inputs.cbegin(), inputs.cend(), input_name) !=
inputs.cend()) {
return desc;
}
}
return nullptr;
}
int BuildDeviceProgramImpl() {
int status = 0;
auto graph = std::make_shared<paddle::lite::subgraph::mlu::Graph>();
......@@ -150,12 +162,12 @@ class SubgraphEngine : public subgraph::Engine {
origin_itensors_.clear();
origin_otensors_.clear();
auto data_order = block_desc_->GetOp<cpp::OpDesc>(0)->Type() == "layout"
? CNML_NCHW
: CNML_NHWC;
// Convert all of input data vars and added into the MLU IR graph
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
for (auto& input_name : input_names_) {
auto first_node = GetFirstNode(input_name);
CHECK(first_node);
auto data_order = first_node->Type() == "layout" ? CNML_NCHW : CNML_NHWC;
auto input_tensor = scope_->FindMutableTensor(input_name);
auto data_type = input_tensor->precision();
cnmlDataType_t fp_type = PrecisionToDatatype(data_type);
......