Commit d397458f authored by dingminghui, committed by jackzhang235

fix(cast): fix precision error in mlu cast

caused by a wrong data type in io_copy
Parent cb3f16ff
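As the diff below suggests, the io_copy kernels were previously registered with PRECISION(kAny) on both input and output, so the cast-insertion pass only had to satisfy the target check, and an FP16 tensor could be routed through a copy kernel whose declared data type did not match the tensor's actual precision. The sketch below is a hypothetical, simplified illustration of that selection logic, not Paddle-Lite code: the enum, struct and helper names are invented for this example. It shows why registering one io_copy kernel per concrete precision, and additionally requiring a precision-compatibility check (as this patch does), rejects the mismatched kernel.

// Hypothetical, simplified sketch of the kernel-picking check; TensorTy,
// TargetCompatible and PrecisionCompatible are invented for illustration
// and are not the actual Paddle-Lite types.
#include <cstdio>

enum class Target { kHost, kMLU };
enum class Precision { kAny, kFloat, kFP16 };

struct TensorTy {
  Target target;
  Precision precision;
};

bool TargetCompatible(const TensorTy& decl, const TensorTy& actual) {
  return decl.target == actual.target;
}

bool PrecisionCompatible(const TensorTy& decl, const TensorTy& actual) {
  return decl.precision == Precision::kAny ||
         decl.precision == actual.precision;
}

int main() {
  TensorTy fp16_host_tensor{Target::kHost, Precision::kFP16};

  // Old io_copy declaration: PRECISION(kAny). Both checks pass trivially,
  // so the kernel is picked regardless of the tensor's real precision; the
  // commit therefore narrows the registrations to concrete precisions.
  TensorTy old_decl{Target::kHost, Precision::kAny};
  std::printf("kAny decl accepted: %d\n",
              TargetCompatible(old_decl, fp16_host_tensor) &&
                  PrecisionCompatible(old_decl, fp16_host_tensor));

  // New io_copy declarations each bind one concrete precision, so only the
  // matching kernel survives the combined target + precision check.
  TensorTy fp32_decl{Target::kHost, Precision::kFloat};
  TensorTy fp16_decl{Target::kHost, Precision::kFP16};
  std::printf("kFloat decl accepted: %d\n",
              TargetCompatible(fp32_decl, fp16_host_tensor) &&
                  PrecisionCompatible(fp32_decl, fp16_host_tensor));
  std::printf("kFP16 decl accepted: %d\n",
              TargetCompatible(fp16_decl, fp16_host_tensor) &&
                  PrecisionCompatible(fp16_decl, fp16_host_tensor));
  return 0;
}

In the real pass the same idea appears in InsertCastBefore/InsertCastAfter, which now check PrecisionCompatible alongside TargetCompatibleTo when picking the cast kernel.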
......@@ -40,6 +40,10 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
cast_arg->AsArg().type = cast_type;
inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
VLOG(4) << "insert cast before subgraph";
VLOG(4) << "curent node type: " << cur_node->AsArg().type->name()
<< " cast to node type: " << cast_type->name();
// create the stmt node
auto* cast_inst = graph->NewInstructNode();
// create op
......@@ -89,13 +93,16 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TargetCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
TargetCompatibleTo(*out_arg_ty, *cast_type)) {
TargetCompatibleTo(*out_arg_ty, *cast_type) &&
PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatible(*out_arg_ty, *cast_type)) {
is_found = true;
}
} else {
CHECK(0) << "Unsupported cast type";
}
if (is_found) {
VLOG(4) << "insert kernel: " << kernel->name();
selected_kernels.emplace_back(std::move(kernel));
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
......@@ -125,6 +132,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
auto* var = inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
// for CastAfter, manually set the tensor's type
var->GetMutable<paddle::lite::Tensor>();
VLOG(4) << "insert cast after subgraph";
VLOG(4) << "curent node type: " << cur_node->AsArg().type->name()
<< " cast to node type: " << cast_type->name();
// create the stmt node
auto* cast_inst = graph->NewInstructNode();
......@@ -174,7 +184,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (TargetCompatibleTo(*in_arg_ty, *cast_type) &&
TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type)) {
TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatible(*out_arg_ty, *cast_type)) {
is_found = true;
}
} else {
......@@ -323,10 +335,9 @@ void MLUPostprocessPass::GetSubgraphOpArgType(Node* inst_node,
CHECK(subgraph_precision == PRECISION(kFloat) ||
subgraph_precision == PRECISION(kFP16))
<< "Mlu node has unsupport precision";
VLOG(4) << "picked kernel precision: "
<< PrecisionToStr(subgraph_precision);
*arg_type = LiteType::GetTensorTy(
subgraph_target, subgraph_precision, subgraph_layout);
VLOG(4) << "picked subgraph kernel type: " << (*arg_type)->name();
break;
}
}
......@@ -726,7 +737,7 @@ std::pair<bool, std::string> CheckOutputAndInsert(
return std::make_pair(do_insert, cur_node);
}
// insert cast op on mlu, to avoid cast on cpu, invoke before first run
// insert cast op on mlu, to avoid cast on cpu
void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
const Type* subgraph_type) {
auto subgraph_op = subgraph_node->AsStmt().op();
......@@ -820,6 +831,42 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
op->SetSubBlock(new_block_desc);
}
void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
// remove invalid places, since only X86, Host and MLU are supported
auto v_places = graph->valid_places();
for (auto it = v_places.begin(); it != v_places.end();) {
if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
it->target != TARGET(kX86)) {
it = v_places.erase(it);
} else {
++it;
}
}
if (use_mlu_cast) {
// insert MLU float place for float io_copy; no effect on subgraph type
v_places.emplace_back(TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC));
} else {
// add x86 NHWC place for cpu cast
std::set<paddle::lite_api::PrecisionType> prec_set{};
for (auto& place : v_places) {
prec_set.insert(place.precision);
}
if (lite::TargetWrapperMlu::UseFirstConv()) {
prec_set.insert(PRECISION(kInt8));
}
for (auto& prec : prec_set) {
v_places.emplace_back(TARGET(kX86), prec, DATALAYOUT(kNHWC));
}
}
graph->SetValidPlaces(v_places);
VLOG(4) << "valid places after modified:";
for (auto& p : v_places) {
VLOG(4) << p.DebugString();
}
}
void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// currently for non-persistent input and output args, mlu subgraph op
// only supports float16/float32 data types
......@@ -842,6 +889,7 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
ModifyValidPlaces(graph.get(), use_mlu_cast);
// insert io_copy, layout and precision casts for the subgraph's inputs and outputs
for (auto& node : graph->mutable_nodes()) {
if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
......
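The MLU-side cast behaviour introduced above is gated by the LITE_MLU_CAST environment variable, read through GetBoolFromEnv in MLUPostprocessPass::Apply. Below is a minimal, hypothetical usage sketch for enabling it from a POSIX host program; the setenv call and its placement are illustrative assumptions rather than part of this patch, and the exact value accepted is whatever GetBoolFromEnv parses as true.

#include <cstdlib>

int main() {
  // Set before the predictor is built, since MLUPostprocessPass reads the
  // variable while optimizing the runtime graph.
  setenv("LITE_MLU_CAST", "true", 1);  // POSIX; adjust for other platforms
  // ... construct the Paddle-Lite predictor and run inference as usual ...
  return 0;
}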
......@@ -84,39 +84,6 @@ void RKNPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
void MLUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
#ifdef LITE_WITH_MLU
// remove invalid places, since only support X86, host, MLU
auto v_places = graph->valid_places();
for (auto it = v_places.begin(); it != v_places.end();) {
if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
it->target != TARGET(kX86)) {
it = v_places.erase(it);
} else {
++it;
}
}
// add x86 NHWC place
std::vector<paddle::lite_api::PrecisionType> precisions{PRECISION(kFloat),
PRECISION(kFP16)};
if (lite::TargetWrapperMlu::UseFirstConv())
precisions.emplace_back(PRECISION(kInt8));
for (auto& prec : precisions) {
auto is_x86_nhwc = [prec](const Place& it) {
return it.layout == DATALAYOUT(kNHWC) && it.target == TARGET(kX86) &&
it.precision == prec;
};
if (std::find_if(v_places.cbegin(), v_places.cend(), is_x86_nhwc) ==
v_places.end()) {
v_places.emplace_back(Place{TARGET(kX86), prec, DATALAYOUT(kNHWC)});
}
}
graph->SetValidPlaces(v_places);
VLOG(4) << "valid places after modified:";
for (auto& p : v_places) {
VLOG(4) << p.DebugString();
}
#endif
std::unordered_set<std::string> supported_lists;
#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
#include "lite/kernels/mlu/bridges/paddle_use_bridges.h"
......
......@@ -51,7 +51,7 @@ int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto local_size = op_info->GetAttr<int>("n");
CHECK(op_info->HasAttr("input_scale"));
auto input_scale = op_info->GetAttr<float>("input_scale");
std::cout << "input scale: " << input_scale << std::endl;
VLOG(5) << "lrn input scale: " << input_scale;
cnmlLrnOpParam_t param;
cnmlBaseOp_t lrn_op;
......
......@@ -41,6 +41,8 @@ class IoCopyHostToMluCompute
auto mem_size = param.x->memory_size();
// LOG(INFO) << "copy size " << mem_size;
auto* data = param.y->mutable_data(TARGET(kMLU), mem_size);
VLOG(6) << "io_copy host to mlu] memory size: " << mem_size
<< " precision type: " << PrecisionToStr(Precision);
param.y->set_precision(param.x->precision());
CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
......@@ -80,6 +82,8 @@ class IoCopyMluToHostCompute
CHECK(param.x->target() == TARGET(kMLU));
auto mem_size = param.x->memory_size();
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
VLOG(6) << "io_copy mlu to host] memory size: " << mem_size
<< " precision type: " << PrecisionToStr(Precision);
// sync queue to ensure the copy is done
auto& mlu_context = this->ctx_->template As<MLUContext>();
......@@ -105,11 +109,11 @@ REGISTER_LITE_KERNEL(
host_to_device_kFloat)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.Finalize();
......@@ -122,11 +126,11 @@ REGISTER_LITE_KERNEL(
host_to_device_kFP16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.Finalize();
......@@ -139,11 +143,11 @@ REGISTER_LITE_KERNEL(
device_to_host_kFloat)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.Finalize();
......@@ -156,10 +160,27 @@ REGISTER_LITE_KERNEL(
device_to_host_kFP16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
io_copy,
kMLU,
kInt8,
kNHWC,
paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kInt8)>,
device_to_host_kInt8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt8),
DATALAYOUT(kAny))})
.Finalize();
......@@ -314,6 +314,18 @@ class SubgraphEngine : public subgraph::Engine {
}
}
inline void* GetOutputDataPtr(Tensor* tensor, bool use_mlu_cast) {
if (use_mlu_cast) {
// output is float, since the cast is fused into the subgraph
return static_cast<void*>(tensor->mutable_data<float>(TARGET(kMLU)));
} else {
return static_cast<void*>(
tensor->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
TARGET(kMLU)));
}
}
int LaunchDeviceProgram() override {
// prepare input and output memory
auto& mlu_context = this->ctx_->template As<MLUContext>();
......@@ -331,6 +343,8 @@ class SubgraphEngine : public subgraph::Engine {
CHECK_EQ(graph_input->size(), origin_itensors_.size());
CHECK_EQ(graph_output->size(), origin_otensors_.size());
bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
if (!disable_batch_size_changeable_) {
std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
graph_in;
......@@ -371,26 +385,17 @@ class SubgraphEngine : public subgraph::Engine {
graph_out = shape_tensor_map_out_[all_inputs_shape_];
for (size_t i = 0; i < origin_otensors_.size(); ++i) {
// origin_otensors_[i]->Resize(new_output_size.at(i));
void* p_data = static_cast<void*>(
origin_otensors_[i]
->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
TARGET(kMLU)));
graph_out[i]->set_mlu_ptr(p_data);
graph_out[i]->set_mlu_ptr(
GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
}
} else {
graph_out.reserve(origin_otensors_.size());
for (size_t i = 0; i < origin_otensors_.size(); ++i) {
// origin_otensors_[i]->Resize(new_output_size.at(i));
void* p_data = static_cast<void*>(
origin_otensors_[i]
->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
TARGET(kMLU)));
paddle::lite::subgraph::mlu::MLUTensor tmp(
origin_otensors_[i]->dims().Vectorize());
tmp.set_mlu_dtype(graph_output->at(i)->dtype());
tmp.set_mlu_ptr(p_data);
tmp.set_mlu_ptr(GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
graph_out.push_back(
std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
}
......@@ -404,12 +409,8 @@ class SubgraphEngine : public subgraph::Engine {
}
for (size_t i = 0; i < origin_otensors_.size(); ++i) {
origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
void* p_data = static_cast<void*>(
origin_otensors_[i]
->template mutable_data<
typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
TARGET(kMLU)));
graph_output->at(i)->set_mlu_ptr(p_data);
graph_output->at(i)->set_mlu_ptr(
GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
}
graph->Compute(forward_param, exec_queue);
}
......