Commit d397458f authored by dingminghui, committed by jackzhang235

fix(cast): fix precision error in mlu cast

Caused by the io_copy kernels being registered with PRECISION(kAny) inputs and outputs, which allowed a kernel of the wrong data type to be selected; they are now registered with explicit precisions.
Parent cb3f16ff
@@ -40,6 +40,10 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
   cast_arg->AsArg().type = cast_type;
   inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
+  VLOG(4) << "insert cast before subgraph";
+  VLOG(4) << "current node type: " << cur_node->AsArg().type->name()
+          << " cast to node type: " << cast_type->name();
   // create the stmt node
   auto* cast_inst = graph->NewInstructNode();
   // create op
@@ -89,13 +93,16 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
     const Type* in_arg_ty = kernel->GetInputDeclType("Input");
     const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
     if (TargetCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
-        TargetCompatibleTo(*out_arg_ty, *cast_type)) {
+        TargetCompatibleTo(*out_arg_ty, *cast_type) &&
+        PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
+        PrecisionCompatible(*out_arg_ty, *cast_type)) {
       is_found = true;
     }
   } else {
     CHECK(0) << "Unsupported cast type";
   }
   if (is_found) {
+    VLOG(4) << "insert kernel: " << kernel->name();
     selected_kernels.emplace_back(std::move(kernel));
     // we pick the kernel
     cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
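Note: the kernel match now also requires precision compatibility between the declared kernel types and the actual node types, so a cast kernel carrying the wrong data type can no longer be selected. A minimal sketch of what such a predicate can look like, assuming kAny/kUnk act as wildcards (hypothetical; the actual PrecisionCompatible used by this pass may differ in detail):

// Hypothetical sketch of a precision-compatibility predicate; the real
// PrecisionCompatible in MLUPostprocessPass may differ.
static bool PrecisionCompatibleSketch(const Type& a, const Type& b) {
  // Treat any/unknown precision as a wildcard; otherwise require an exact
  // match, so e.g. an FP16 tensor cannot pick a kFloat cast kernel.
  return a.precision() == PRECISION(kAny) ||
         b.precision() == PRECISION(kAny) ||
         a.precision() == PRECISION(kUnk) ||
         b.precision() == PRECISION(kUnk) ||
         a.precision() == b.precision();
}

Without such a check, TargetCompatibleTo alone could match a cast kernel of the wrong precision, which is exactly the mis-selection this commit fixes.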
@@ -125,6 +132,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
   auto* var = inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
   // for CastAfter manually set the tensor's type
   var->GetMutable<paddle::lite::Tensor>();
+  VLOG(4) << "insert cast after subgraph";
+  VLOG(4) << "current node type: " << cur_node->AsArg().type->name()
+          << " cast to node type: " << cast_type->name();
   // create the stmt node
   auto* cast_inst = graph->NewInstructNode();
@@ -174,7 +184,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
     const Type* in_arg_ty = kernel->GetInputDeclType("Input");
     const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
     if (TargetCompatibleTo(*in_arg_ty, *cast_type) &&
-        TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type)) {
+        TargetCompatibleTo(*out_arg_ty, *cur_node->AsArg().type) &&
+        PrecisionCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
+        PrecisionCompatible(*out_arg_ty, *cast_type)) {
       is_found = true;
     }
   } else {
@@ -323,10 +335,9 @@ void MLUPostprocessPass::GetSubgraphOpArgType(Node* inst_node,
       CHECK(subgraph_precision == PRECISION(kFloat) ||
             subgraph_precision == PRECISION(kFP16))
           << "MLU node has unsupported precision";
-      VLOG(4) << "picked kernel precision: "
-              << PrecisionToStr(subgraph_precision);
       *arg_type = LiteType::GetTensorTy(
           subgraph_target, subgraph_precision, subgraph_layout);
+      VLOG(4) << "picked subgraph kernel type: " << (*arg_type)->name();
       break;
     }
   }
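For reference, LiteType::GetTensorTy composes a tensor type from the picked kernel's (target, precision, layout) triple, and the new VLOG prints the composed type rather than the precision alone. A small usage sketch (the exact printed name format is an assumption):

// Sketch: compose an MLU FP16 NHWC tensor type and log its name, as the
// replacement VLOG above does for the subgraph arg type.
const Type* arg_type = LiteType::GetTensorTy(
    TARGET(kMLU), PRECISION(kFP16), DATALAYOUT(kNHWC));
VLOG(4) << "picked subgraph kernel type: " << arg_type->name();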
@@ -726,7 +737,7 @@ std::pair<bool, std::string> CheckOutputAndInsert(
   return std::make_pair(do_insert, cur_node);
 }

-// insert cast op on mlu, to avoid cast on cpu, invoke before first run
+// insert cast op on mlu, to avoid cast on cpu
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
                                         const Type* subgraph_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
@@ -820,6 +831,42 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
   op->SetSubBlock(new_block_desc);
 }

+void ModifyValidPlaces(SSAGraph* graph, bool use_mlu_cast) {
+  // remove invalid places, since only X86, Host and MLU are supported
+  auto v_places = graph->valid_places();
+  for (auto it = v_places.begin(); it != v_places.end();) {
+    if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
+        it->target != TARGET(kX86)) {
+      it = v_places.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
+  if (use_mlu_cast) {
+    // insert an MLU float place for float io_copy; no effect on subgraph type
+    v_places.emplace_back(TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC));
+  } else {
+    // add X86 NHWC places for the CPU cast
+    std::set<paddle::lite_api::PrecisionType> prec_set{};
+    for (auto& place : v_places) {
+      prec_set.insert(place.precision);
+    }
+    if (lite::TargetWrapperMlu::UseFirstConv()) {
+      prec_set.insert(PRECISION(kInt8));
+    }
+    for (auto& prec : prec_set) {
+      v_places.emplace_back(TARGET(kX86), prec, DATALAYOUT(kNHWC));
+    }
+  }
+
+  graph->SetValidPlaces(v_places);
+  VLOG(4) << "valid places after modification:";
+  for (auto& p : v_places) {
+    VLOG(4) << p.DebugString();
+  }
+}
+
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // currently for non-persistent input and output args, the mlu subgraph op
   // only supports float16/float32 data types
@@ -842,6 +889,7 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
   bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
+  ModifyValidPlaces(graph.get(), use_mlu_cast);
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
   for (auto& node : graph->mutable_nodes()) {
     if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
...
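Both this pass and the subgraph engine below read the LITE_MLU_CAST environment variable through GetBoolFromEnv. A hedged usage sketch from host code (setenv is POSIX; how the variable is set in a real deployment may differ):

#include <cstdlib>  // setenv (POSIX)

int main() {
  // Enable on-device (MLU) casts: ModifyValidPlaces then keeps an MLU float
  // place for float io_copy instead of adding X86 NHWC places for CPU casts.
  // Set this before the predictor is created so the passes can see it.
  setenv("LITE_MLU_CAST", "1", /*overwrite=*/1);
  // ... build the Paddle-Lite predictor as usual ...
  return 0;
}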
@@ -84,39 +84,6 @@ void RKNPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 }

 void MLUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-#ifdef LITE_WITH_MLU
-  // remove invalid places, since only support X86, host, MLU
-  auto v_places = graph->valid_places();
-  for (auto it = v_places.begin(); it != v_places.end();) {
-    if (it->target != TARGET(kMLU) && it->target != TARGET(kHost) &&
-        it->target != TARGET(kX86)) {
-      it = v_places.erase(it);
-    } else {
-      ++it;
-    }
-  }
-  // add x86 NHWC place
-  std::vector<paddle::lite_api::PrecisionType> precisions{PRECISION(kFloat),
-                                                          PRECISION(kFP16)};
-  if (lite::TargetWrapperMlu::UseFirstConv())
-    precisions.emplace_back(PRECISION(kInt8));
-  for (auto& prec : precisions) {
-    auto is_x86_nhwc = [prec](const Place& it) {
-      return it.layout == DATALAYOUT(kNHWC) && it.target == TARGET(kX86) &&
-             it.precision == prec;
-    };
-    if (std::find_if(v_places.cbegin(), v_places.cend(), is_x86_nhwc) ==
-        v_places.end()) {
-      v_places.emplace_back(Place{TARGET(kX86), prec, DATALAYOUT(kNHWC)});
-    }
-  }
-  graph->SetValidPlaces(v_places);
-  VLOG(4) << "valid places after modified:";
-  for (auto& p : v_places) {
-    VLOG(4) << p.DebugString();
-  }
-#endif
   std::unordered_set<std::string> supported_lists;
 #define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
 #include "lite/kernels/mlu/bridges/paddle_use_bridges.h"
...
@@ -51,7 +51,7 @@ int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto local_size = op_info->GetAttr<int>("n");
   CHECK(op_info->HasAttr("input_scale"));
   auto input_scale = op_info->GetAttr<float>("input_scale");
-  std::cout << "input scale: " << input_scale << std::endl;
+  VLOG(5) << "lrn input scale: " << input_scale;
   cnmlLrnOpParam_t param;
   cnmlBaseOp_t lrn_op;
...
@@ -41,6 +41,8 @@ class IoCopyHostToMluCompute
     auto mem_size = param.x->memory_size();
     // LOG(INFO) << "copy size " << mem_size;
     auto* data = param.y->mutable_data(TARGET(kMLU), mem_size);
+    VLOG(6) << "[io_copy host to mlu] memory size: " << mem_size
+            << " precision type: " << PrecisionToStr(Precision);
     param.y->set_precision(param.x->precision());
     CopyFromHostSync(data, param.x->raw_data(), mem_size);
   }
@@ -80,6 +82,8 @@ class IoCopyMluToHostCompute
     CHECK(param.x->target() == TARGET(kMLU));
     auto mem_size = param.x->memory_size();
     auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
+    VLOG(6) << "[io_copy mlu to host] memory size: " << mem_size
+            << " precision type: " << PrecisionToStr(Precision);
     // sync queue to ensure process done
     auto& mlu_context = this->ctx_->template As<MLUContext>();
@@ -105,11 +109,11 @@ REGISTER_LITE_KERNEL(
     host_to_device_kFloat)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kAny),
+                                      PRECISION(kFloat),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kMLU),
-                                       PRECISION(kAny),
+                                       PRECISION(kFloat),
                                        DATALAYOUT(kAny))})
     .Finalize();
@@ -122,11 +126,11 @@ REGISTER_LITE_KERNEL(
     host_to_device_kFP16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kAny),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kMLU),
-                                       PRECISION(kAny),
+                                       PRECISION(kFP16),
                                        DATALAYOUT(kAny))})
     .Finalize();
@@ -139,11 +143,11 @@ REGISTER_LITE_KERNEL(
     device_to_host_kFloat)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kMLU),
-                                      PRECISION(kAny),
+                                      PRECISION(kFloat),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
+                                       PRECISION(kFloat),
                                        DATALAYOUT(kAny))})
     .Finalize();
@@ -156,10 +160,27 @@ REGISTER_LITE_KERNEL(
     device_to_host_kFP16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kMLU),
-                                      PRECISION(kAny),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kAny),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    io_copy,
+    kMLU,
+    kInt8,
+    kNHWC,
+    paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kInt8)>,
+    device_to_host_kInt8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt8),
                                        DATALAYOUT(kAny))})
     .Finalize();
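With PRECISION(kAny) declared on the bindings, every io_copy variant was type-compatible with every tensor, so a variant carrying the wrong data type could be picked; binding explicit precisions makes the choice unambiguous. A hypothetical sketch of the now precision-keyed dispatch (helper name invented for illustration, not a framework API):

#include "lite/api/paddle_place.h"  // PrecisionType, PRECISION (path assumed)

// Hypothetical helper: map a tensor precision to the matching io_copy
// kernel alias. Before this commit the kAny declarations made all
// variants match, so the framework could pick the wrong one.
inline const char* PickDeviceToHostVariant(paddle::lite_api::PrecisionType p) {
  switch (p) {
    case PRECISION(kFloat): return "device_to_host_kFloat";
    case PRECISION(kFP16):  return "device_to_host_kFP16";
    case PRECISION(kInt8):  return "device_to_host_kInt8";
    default:                return nullptr;  // unsupported precision
  }
}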
@@ -314,6 +314,18 @@ class SubgraphEngine : public subgraph::Engine {
     }
   }

+  inline void* GetOutputDataPtr(Tensor* tensor, bool use_mlu_cast) {
+    if (use_mlu_cast) {
+      // output is float, since the cast is fused into the subgraph
+      return static_cast<void*>(tensor->mutable_data<float>(TARGET(kMLU)));
+    } else {
+      return static_cast<void*>(
+          tensor->template mutable_data<
+              typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
+              TARGET(kMLU)));
+    }
+  }
+
   int LaunchDeviceProgram() override {
     // prepare input and output memory
     auto& mlu_context = this->ctx_->template As<MLUContext>();
@@ -331,6 +343,8 @@ class SubgraphEngine : public subgraph::Engine {
     CHECK_EQ(graph_input->size(), origin_itensors_.size());
     CHECK_EQ(graph_output->size(), origin_otensors_.size());
+    bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
+
     if (!disable_batch_size_changeable_) {
       std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
           graph_in;
@@ -371,26 +385,17 @@ class SubgraphEngine : public subgraph::Engine {
         graph_out = shape_tensor_map_out_[all_inputs_shape_];
         for (size_t i = 0; i < origin_otensors_.size(); ++i) {
           // origin_otensors_[i]->Resize(new_output_size.at(i));
-          void* p_data = static_cast<void*>(
-              origin_otensors_[i]
-                  ->template mutable_data<
-                      typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
-                      TARGET(kMLU)));
-          graph_out[i]->set_mlu_ptr(p_data);
+          graph_out[i]->set_mlu_ptr(
+              GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
         }
       } else {
         graph_out.reserve(origin_otensors_.size());
         for (size_t i = 0; i < origin_otensors_.size(); ++i) {
           // origin_otensors_[i]->Resize(new_output_size.at(i));
-          void* p_data = static_cast<void*>(
-              origin_otensors_[i]
-                  ->template mutable_data<
-                      typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
-                      TARGET(kMLU)));
           paddle::lite::subgraph::mlu::MLUTensor tmp(
               origin_otensors_[i]->dims().Vectorize());
           tmp.set_mlu_dtype(graph_output->at(i)->dtype());
-          tmp.set_mlu_ptr(p_data);
+          tmp.set_mlu_ptr(GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
           graph_out.push_back(
               std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
         }
@@ -404,12 +409,8 @@ class SubgraphEngine : public subgraph::Engine {
       }
       for (size_t i = 0; i < origin_otensors_.size(); ++i) {
         origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
-        void* p_data = static_cast<void*>(
-            origin_otensors_[i]
-                ->template mutable_data<
-                    typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
-                    TARGET(kMLU)));
-        graph_output->at(i)->set_mlu_ptr(p_data);
+        graph_output->at(i)->set_mlu_ptr(
+            GetOutputDataPtr(origin_otensors_[i], use_mlu_cast));
       }
       graph->Compute(forward_param, exec_queue);
     }
...
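GetOutputDataPtr above relies on subgraph::mlu::MLUTypeTraits to map the engine's Precision parameter to a C++ element type when the cast is not fused. A hedged sketch of such a traits mapping (the actual definition lives in the MLU bridge utilities and may differ, e.g. in how FP16 storage is typed):

#include <cstdint>

#include "lite/api/paddle_place.h"  // PrecisionType, PRECISION (path assumed)

// Hypothetical sketch of the precision -> element type mapping that
// MLUTypeTraits is assumed to provide; the real trait may differ.
template <paddle::lite_api::PrecisionType P>
struct MLUTypeTraitsSketch;

template <>
struct MLUTypeTraitsSketch<PRECISION(kFloat)> {
  using type = float;
};

template <>
struct MLUTypeTraitsSketch<PRECISION(kFP16)> {
  // FP16 assumed to be stored as a raw 16-bit payload on the host side.
  using type = uint16_t;
};

template <>
struct MLUTypeTraitsSketch<PRECISION(kInt8)> {
  using type = int8_t;
};

With use_mlu_cast enabled, the cast is fused into the subgraph, so outputs are always allocated as float and the traits mapping is bypassed, which is what GetOutputDataPtr encodes.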