提交 789112e8 编写于 作者: J jackzhang235 提交者: jackzhang235

support changeable input dims

上级 63da1451
...@@ -61,12 +61,13 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -61,12 +61,13 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int co = static_cast<int>(mean_dims[0]); int co = static_cast<int>(mean_dims[0]);
std::vector<float> variance_trans(co);
std::vector<float> mean_trans(co);
for (int i = 0; i < co; ++i) { for (int i = 0; i < co; ++i) {
variance->mutable_data<float>()[i] = variance_trans[i] =
scale->data<float>()[i] / sqrtf(variance->data<float>()[i] + epsilon); scale->data<float>()[i] / sqrtf(variance->data<float>()[i] + epsilon);
mean->mutable_data<float>()[i] = mean_trans[i] =
mean->data<float>()[i] - mean->data<float>()[i] - bias->data<float>()[i] / variance_trans[i];
bias->data<float>()[i] / variance->data<float>()[i];
} }
auto input_tensor = graph->GetNode(x_var_name); auto input_tensor = graph->GetNode(x_var_name);
...@@ -77,8 +78,10 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -77,8 +78,10 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
mean_tensor->mlu_tensor(), mean_tensor->mlu_tensor(),
variance_tensor->mlu_tensor())); variance_tensor->mlu_tensor()));
graph->BindConstData(variance_var_name, variance); graph->BindConstRawData(
graph->BindConstData(mean_var_name, mean); variance_var_name, variance_trans.data(), variance_trans.size(), true);
graph->BindConstRawData(
mean_var_name, mean_trans.data(), mean_trans.size(), true);
graph->FuseOp(bn_op); graph->FuseOp(bn_op);
CNML_CALL(cnmlDestroyBaseOp(&bn_op)); CNML_CALL(cnmlDestroyBaseOp(&bn_op));
......
...@@ -89,6 +89,14 @@ class Graph { ...@@ -89,6 +89,14 @@ class Graph {
output_tensors_.push_back(tensor); output_tensors_.push_back(tensor);
} }
// Exposes the graph's input tensor list (input_tensors_) through a non-const
// pointer so the caller can read or modify the list in place.
std::vector<std::shared_ptr<MLUTensor>>* MutableInputs() {
  auto* inputs = &input_tensors_;
  return inputs;
}
// Exposes the graph's output tensor list (output_tensors_) through a non-const
// pointer so the caller can read or modify the list in place.
std::vector<std::shared_ptr<MLUTensor>>* MutableOutputs() {
  auto* outputs = &output_tensors_;
  return outputs;
}
void FuseOp(cnmlBaseOp_t op) { CNML_CALL(cnmlFuseOp(op, fusion_op_)); } void FuseOp(cnmlBaseOp_t op) { CNML_CALL(cnmlFuseOp(op, fusion_op_)); }
void Compile(cnmlCoreVersion_t core_version, int core_number) { void Compile(cnmlCoreVersion_t core_version, int core_number) {
...@@ -100,15 +108,18 @@ class Graph { ...@@ -100,15 +108,18 @@ class Graph {
CNML_CALL(cnmlSetFusionOpCorenum(fusion_op_, core_number)); CNML_CALL(cnmlSetFusionOpCorenum(fusion_op_, core_number));
CNML_CALL(cnmlSetFusionOpCoreVersion(fusion_op_, core_version)); CNML_CALL(cnmlSetFusionOpCoreVersion(fusion_op_, core_version));
CNML_CALL(cnmlCompileFusionOp_V2(fusion_op_)); CNML_CALL(cnmlCompileFusionOp_V2(fusion_op_));
for (auto in : input_tensors_) {
input_addrs_.push_back(in->mlu_data());
}
for (auto out : output_tensors_) {
output_addrs_.push_back(out->mlu_data());
}
} }
void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que) { void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que) {
input_addrs_.resize(input_tensors_.size());
output_addrs_.resize(output_tensors_.size());
for (size_t i = 0; i < input_addrs_.size(); ++i) {
input_addrs_[i] = input_tensors_[i]->mlu_data();
}
for (size_t i = 0; i < output_addrs_.size(); ++i) {
output_addrs_[i] = output_tensors_[i]->mlu_data();
}
#if PRINT_HW_TIME #if PRINT_HW_TIME
thread_local float hw_time; thread_local float hw_time;
CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que)); CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que));
...@@ -159,7 +170,7 @@ class Graph { ...@@ -159,7 +170,7 @@ class Graph {
CNML_CALL(cnmlBindConstData_V2( CNML_CALL(cnmlBindConstData_V2(
nodes_[tensor_name]->mlu_tensor(), alloc_data, false)); nodes_[tensor_name]->mlu_tensor(), alloc_data, false));
} else if (fp_type_ == CNML_DATA_FLOAT16) { } else if (fp_type_ == CNML_DATA_FLOAT16) {
void* data_fp16 = RegisterConstData<::paddle::lite::fluid::float16>(len); void* data_fp16 = RegisterConstData<paddle::lite::fluid::float16>(len);
CNRT_CALL( CNRT_CALL(
cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)), cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)),
CNRT_FLOAT32, CNRT_FLOAT32,
...@@ -174,7 +185,7 @@ class Graph { ...@@ -174,7 +185,7 @@ class Graph {
} }
} }
void BindConstData(std::string tensor_name, ::paddle::lite::Tensor* tensor) { void BindConstData(std::string tensor_name, paddle::lite::Tensor* tensor) {
const float* data = tensor->data<float>(); const float* data = tensor->data<float>();
size_t len = tensor->data_size(); size_t len = tensor->data_size();
if (fp_type_ == CNML_DATA_FLOAT32) { if (fp_type_ == CNML_DATA_FLOAT32) {
...@@ -183,10 +194,14 @@ class Graph { ...@@ -183,10 +194,14 @@ class Graph {
const_cast<void*>(static_cast<const void*>(data)), const_cast<void*>(static_cast<const void*>(data)),
false)); false));
} else if (fp_type_ == CNML_DATA_FLOAT16) { } else if (fp_type_ == CNML_DATA_FLOAT16) {
auto* data_fp16 = tensor->mutable_data<::paddle::lite::fluid::float16>(); void* data_fp16 = RegisterConstData<paddle::lite::fluid::float16>(len);
for (size_t i = 0; i < len; ++i) { CNRT_CALL(
data_fp16[i] = static_cast<::paddle::lite::fluid::float16>(data[i]); cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)),
} CNRT_FLOAT32,
data_fp16,
CNRT_FLOAT16,
len,
nullptr));
CNML_CALL(cnmlBindConstData_V2(nodes_[tensor_name]->mlu_tensor(), CNML_CALL(cnmlBindConstData_V2(nodes_[tensor_name]->mlu_tensor(),
static_cast<void*>(data_fp16), static_cast<void*>(data_fp16),
false)); false));
...@@ -207,12 +222,13 @@ class Graph { ...@@ -207,12 +222,13 @@ class Graph {
CNML_CALL(cnmlDestroyQuantizedParam(&quant_param)); CNML_CALL(cnmlDestroyQuantizedParam(&quant_param));
} }
void SetFPType(::paddle::lite_api::PrecisionType type) { void SetFPType(paddle::lite_api::PrecisionType type) {
origin_fp_type_ = type;
switch (type) { switch (type) {
case ::paddle::lite_api::PrecisionType::kFP16: case paddle::lite_api::PrecisionType::kFP16:
fp_type_ = CNML_DATA_FLOAT16; fp_type_ = CNML_DATA_FLOAT16;
break; break;
case ::paddle::lite_api::PrecisionType::kFloat: case paddle::lite_api::PrecisionType::kFloat:
fp_type_ = CNML_DATA_FLOAT32; fp_type_ = CNML_DATA_FLOAT32;
break; break;
default: default:
...@@ -224,6 +240,7 @@ class Graph { ...@@ -224,6 +240,7 @@ class Graph {
private: private:
cnmlDataType_t fp_type_{CNML_DATA_FLOAT32}; cnmlDataType_t fp_type_{CNML_DATA_FLOAT32};
paddle::lite_api::PrecisionType origin_fp_type_{PRECISION(kFloat)};
std::unordered_map<std::string, std::shared_ptr<MLUTensor>> nodes_; std::unordered_map<std::string, std::shared_ptr<MLUTensor>> nodes_;
std::vector<cnmlTensor_t> inputs_; std::vector<cnmlTensor_t> inputs_;
std::vector<cnmlTensor_t> outputs_; std::vector<cnmlTensor_t> outputs_;
......
...@@ -46,6 +46,7 @@ void MLUTensor::remember(const std::vector<int>& shape, ...@@ -46,6 +46,7 @@ void MLUTensor::remember(const std::vector<int>& shape,
cnmlDataOrder_t shape_order) { cnmlDataOrder_t shape_order) {
tensor_type_ = tensor_type; tensor_type_ = tensor_type;
mlu_dtype_ = mlu_dtype; mlu_dtype_ = mlu_dtype;
origin_shape_.assign(shape.begin(), shape.end());
int size = 4; int size = 4;
if (shape.size() > 4 || shape_order == CNML_ARRAY) { if (shape.size() > 4 || shape_order == CNML_ARRAY) {
......
...@@ -51,6 +51,8 @@ class MLUTensor { ...@@ -51,6 +51,8 @@ class MLUTensor {
void set_mlu_dtype(cnmlDataType_t type) { mlu_dtype_ = type; } void set_mlu_dtype(cnmlDataType_t type) { mlu_dtype_ = type; }
// Read-only access to the stored origin_shape_ vector. The reference refers to
// the member itself, so no copy is made.
const std::vector<int64_t>& get_origin_shape() const {
  return origin_shape_;
}
~MLUTensor(); ~MLUTensor();
void ToFile(std::string file_name); void ToFile(std::string file_name);
...@@ -59,6 +61,7 @@ class MLUTensor { ...@@ -59,6 +61,7 @@ class MLUTensor {
cnmlTensor_t mlu_tensor_; cnmlTensor_t mlu_tensor_;
std::vector<int> shape_; std::vector<int> shape_;
std::vector<int64_t> origin_shape_;
cnmlTensorType_t tensor_type_; cnmlTensorType_t tensor_type_;
cnmlDataType_t mlu_dtype_; cnmlDataType_t mlu_dtype_;
int dim_{0}; int dim_{0};
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -40,11 +41,10 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -40,11 +41,10 @@ class SubgraphEngine : public subgraph::Engine {
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
Scope* scope, Scope* scope,
::paddle::lite_api::PrecisionType type) paddle::lite_api::PrecisionType type)
: subgraph::Engine( : subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope) { ctx, block_idx, block_desc, input_names, output_names, scope),
graph_.SetFPType(type); fp_type_(type) {}
}
int Build() { int Build() {
// In order to attach all of the ops of the block desc, we need to build // In order to attach all of the ops of the block desc, we need to build
...@@ -72,24 +72,44 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -72,24 +72,44 @@ class SubgraphEngine : public subgraph::Engine {
return 0; return 0;
} }
bool InputShapeChanged() {
std::vector<std::vector<int64_t>> new_shape;
for (auto origin_itensor : origin_itensors_) {
new_shape.push_back(origin_itensor->dims().Vectorize());
}
inputs_shape_ = new_shape;
if (shape_graph_map_.count(inputs_shape_) > 0) {
return false;
}
return true;
}
protected: protected:
int BuildDeviceProgram() override { int BuildDeviceProgram() override {
int status = 0; int status = 0;
auto graph = std::make_shared<paddle::lite::subgraph::mlu::Graph>();
graph->SetFPType(fp_type_);
std::vector<std::vector<int64_t>> new_shape;
origin_itensors_.clear();
origin_otensors_.clear();
// Convert all of input data vars and added into the MLU IR graph // Convert all of input data vars and added into the MLU IR graph
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name); auto input_tensor = scope_->FindMutableTensor(input_name);
origin_itensors_.push_back(input_tensor);
new_shape.push_back(input_tensor->dims().Vectorize());
CHECK(input_tensor); CHECK(input_tensor);
auto input_node = auto input_node = graph->AddNode(input_name,
graph_.AddNode(input_name, input_tensor->dims().Vectorize(),
input_tensor->dims().Vectorize(), CNML_TENSOR,
CNML_TENSOR, CNML_NCHW,
CNML_NCHW, graph->FPType());
graph_.FPType(),
const_cast<void*>(input_tensor->raw_data()));
CHECK(input_node); CHECK(input_node);
// MLU doesn't support dynamic dimensions/shapes, so need to rebuild // MLU doesn't support dynamic dimensions/shapes, so need to rebuild
// the program when the shape of any input tensor is changed. // the program when the shape of any input tensor is changed.
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
} }
LOG(INFO) << "START TO CONVERT "; LOG(INFO) << "START TO CONVERT ";
// Convert all of ops and its weights and added into the MLU IR graph // Convert all of ops and its weights and added into the MLU IR graph
...@@ -106,7 +126,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -106,7 +126,7 @@ class SubgraphEngine : public subgraph::Engine {
} }
auto kernel = inst.kernel(); auto kernel = inst.kernel();
status |= bridges.Select(op_type, TARGET(kMLU))( status |= bridges.Select(op_type, TARGET(kMLU))(
reinterpret_cast<void*>(&graph_), reinterpret_cast<void*>(graph.get()),
const_cast<OpLite*>(op), const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel)); const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) { if (subgraph::CHECK_FAILED(status)) {
...@@ -115,33 +135,51 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -115,33 +135,51 @@ class SubgraphEngine : public subgraph::Engine {
} }
// Obtain the output nodes of the MLU IR graph and build the graph to MLU // Obtain the output nodes of the MLU IR graph and build the graph to MLU
// runtime // runtime
std::vector<std::string> valid_output_names;
for (auto& output_name : output_names_) { for (auto& output_name : output_names_) {
if (graph_.HasNode(output_name)) { if (graph->HasNode(output_name)) {
graph_.AddOutput(graph_.GetNode(output_name)); graph->AddOutput(graph->GetNode(output_name));
auto output_tensor = scope_->FindMutableTensor(output_name); auto output_tensor = scope_->FindMutableTensor(output_name);
void* p_data = static_cast<void*>( origin_otensors_.push_back(output_tensor);
output_tensor->mutable_data<typename ::paddle::lite::subgraph::mlu::
FPTypeTraits<Precision>::T>( // auto node = graph->GetNode(output_name);
TARGET(kMLU))); // CHECK(p_data);
auto node = graph_.GetNode(output_name); // node->set_mlu_ptr(p_data);
CHECK(p_data);
node->set_mlu_ptr(p_data);
valid_output_names.push_back(output_name);
} }
} }
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
graph_.AddInput(graph_.GetNode(input_name)); graph->AddInput(graph->GetNode(input_name));
} }
CHECK(!valid_output_names.empty()) << "[MLU] no valid output names";
CHECK(!origin_otensors_.empty()) << "[MLU] no valid output names";
auto& mlu_context = this->ctx_->template As<MLUContext>(); auto& mlu_context = this->ctx_->template As<MLUContext>();
auto core_version = mlu_context.MLUCoreVersion(); auto core_version = mlu_context.MLUCoreVersion();
auto core_number = mlu_context.MLUCoreNumber(); auto core_number = mlu_context.MLUCoreNumber();
graph_.Compile(core_version, core_number); graph->Compile(core_version, core_number);
shape_graph_map_[new_shape] = graph;
return status; return status;
} }
int LaunchDeviceProgram() override { int LaunchDeviceProgram() override {
// prepare input and output memory
auto graph = shape_graph_map_[inputs_shape_];
auto* graph_input = graph->MutableInputs();
auto* graph_output = graph->MutableOutputs();
CHECK_EQ(graph_input->size(), origin_itensors_.size());
CHECK_EQ(graph_output->size(), origin_otensors_.size());
for (size_t i = 0; i < origin_itensors_.size(); ++i) {
graph_input->at(i)->set_mlu_ptr(
const_cast<void*>(origin_itensors_[i]->raw_data()));
}
for (size_t i = 0; i < origin_otensors_.size(); ++i) {
origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
void* p_data = static_cast<void*>(
origin_otensors_[i]
->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
Precision>::T>(TARGET(kMLU)));
graph_output->at(i)->set_mlu_ptr(p_data);
}
auto& mlu_context = this->ctx_->template As<MLUContext>(); auto& mlu_context = this->ctx_->template As<MLUContext>();
auto exec_queue = mlu_context.exec_queue(); auto exec_queue = mlu_context.exec_queue();
u32_t affinity = mlu_context.affinity(); u32_t affinity = mlu_context.affinity();
...@@ -150,11 +188,13 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -150,11 +188,13 @@ class SubgraphEngine : public subgraph::Engine {
forward_param.data_parallelism = &data_param; forward_param.data_parallelism = &data_param;
forward_param.affinity = &affinity; forward_param.affinity = &affinity;
forward_param.end = CNRT_PARAM_END; forward_param.end = CNRT_PARAM_END;
graph_.Compute(forward_param, exec_queue);
graph->Compute(forward_param, exec_queue);
// // =========== DUMP =================== // // =========== DUMP ===================
// for (auto input_name : input_names_) { // for (auto input_name : input_names_) {
// auto input_tensor = graph_.GetNode(input_name); // auto input_tensor =
// shape_graph_map_[inputs_shape_]->GetNode(input_name);
// auto dump_name = input_name; // auto dump_name = input_name;
// while (dump_name.find("/") != std::string::npos) { // while (dump_name.find("/") != std::string::npos) {
// dump_name = dump_name.replace(dump_name.find("/"), 1, "_"); // dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
...@@ -163,8 +203,9 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -163,8 +203,9 @@ class SubgraphEngine : public subgraph::Engine {
// input_tensor->ToFile(dump_name); // input_tensor->ToFile(dump_name);
// } // }
// for (auto output_name : output_names_) { // for (auto output_name : output_names_) {
// if (graph_.HasNode(output_name)) { // if (shape_graph_map_[inputs_shape_]->HasNode(output_name)) {
// auto output_tensor = graph_.GetNode(output_name); // auto output_tensor =
// shape_graph_map_[inputs_shape_]->GetNode(output_name);
// auto dump_name = output_name; // auto dump_name = output_name;
// while (dump_name.find("/") != std::string::npos) { // while (dump_name.find("/") != std::string::npos) {
// dump_name = dump_name.replace(dump_name.find("/"), 1, "_"); // dump_name = dump_name.replace(dump_name.find("/"), 1, "_");
...@@ -180,7 +221,11 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -180,7 +221,11 @@ class SubgraphEngine : public subgraph::Engine {
return 0; return 0;
} }
paddle::lite::subgraph::mlu::Graph graph_; paddle::lite_api::PrecisionType fp_type_;
std::vector<std::vector<int64_t>> inputs_shape_{};
std::map<std::vector<std::vector<int64_t>>,
std::shared_ptr<paddle::lite::subgraph::mlu::Graph>>
shape_graph_map_{};
}; };
template <PrecisionType Precision> template <PrecisionType Precision>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册