Commit 786d0946 authored by: J jiaopu

Add env param

Parent: 739b4ac1
```diff
@@ -82,10 +82,12 @@ class Graph {
   void AddInput(std::shared_ptr<MLUTensor> tensor) {
     inputs_.push_back(tensor->mlu_tensor());
     input_tensors_.push_back(tensor);
-    constexpr int input_dimNb = 4;
-    bool input_dim_mutable[4] = {true, false, false, false};
-    cnmlSetTensorDimMutable(
-        tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
+    if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
+      constexpr int input_dimNb = 4;
+      bool input_dim_mutable[4] = {true, false, false, false};
+      cnmlSetTensorDimMutable(
+          tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
+    }
   }
   void AddOutput(std::shared_ptr<MLUTensor> tensor) {
```
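The guard means the batch dimension of each 4-D NCHW input is only marked mutable via CNML's `cnmlSetTensorDimMutable` when the `BATCH_SIZE_CHANGEBLE` environment variable (spelled exactly as in the patch) is set, instead of unconditionally as before. Below is a minimal sketch of what an env-flag helper like `GetBoolFromEnv` does; `GetBoolFromEnvSketch` is a hypothetical stand-in, and the real helper's parsing rules may differ:

```cpp
#include <cstdlib>
#include <string>

// Hypothetical stand-in for the GetBoolFromEnv helper used in the patch:
// an unset variable yields `def`; "0" and "false" read as false, anything
// else as true.
bool GetBoolFromEnvSketch(const std::string& name, bool def = false) {
  const char* val = std::getenv(name.c_str());
  if (val == nullptr) return def;
  const std::string s(val);
  return s != "0" && s != "false" && s != "FALSE";
}

// Usage mirroring the patch: after `export BATCH_SIZE_CHANGEBLE=1` in the
// shell, the mutable-batch branch is taken:
//   if (GetBoolFromEnvSketch("BATCH_SIZE_CHANGEBLE")) { ... }
```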
```diff
@@ -81,9 +81,13 @@ class SubgraphEngine : public subgraph::Engine {
   bool InputShapeChanged() {
     std::vector<std::vector<int64_t>> new_shape;
     for (auto origin_itensor : origin_itensors_) {
-      auto iv = origin_itensor->dims().Vectorize();
-      iv.erase(iv.begin());
-      new_shape.push_back(iv);
+      if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
+        auto iv = origin_itensor->dims().Vectorize();
+        iv.erase(iv.begin());
+        new_shape.push_back(iv);
+      } else {
+        new_shape.push_back(origin_itensor->dims().Vectorize());
+      }
     }
     inputs_shape_ = new_shape;
     if (shape_graph_map_.count(inputs_shape_) > 0) {
```
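When the flag is set, the leading (batch) dimension is erased before the shape is used as a cache key, so `shape_graph_map_` no longer distinguishes inputs that differ only in batch size and the compiled graph is reused. A small self-contained illustration of that keying effect, with a plain `std::map` standing in for `shape_graph_map_` and NCHW inputs assumed:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using Shapes = std::vector<std::vector<int64_t>>;

int main() {
  // Stand-in for shape_graph_map_: one entry per distinct shape key.
  std::map<Shapes, int> shape_graph_map;

  // Batch-4 NCHW input; with BATCH_SIZE_CHANGEBLE the leading (batch)
  // dim is dropped before the key is formed.
  std::vector<int64_t> dims = {4, 3, 224, 224};
  std::vector<int64_t> key(dims.begin() + 1, dims.end());  // {3, 224, 224}
  shape_graph_map[{key}] = 42;  // pretend 42 is a compiled graph

  // A later batch-8 input produces the same key, so the lookup hits
  // and InputShapeChanged() would report no change.
  std::vector<int64_t> dims2 = {8, 3, 224, 224};
  std::vector<int64_t> key2(dims2.begin() + 1, dims2.end());
  std::cout << (shape_graph_map.count({key2}) > 0) << "\n";  // prints 1
}
```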
The same batch-stripping logic is applied where the device program is built:

```diff
@@ -106,9 +110,13 @@ class SubgraphEngine : public subgraph::Engine {
     for (auto& input_name : input_names_) {
       auto input_tensor = scope_->FindMutableTensor(input_name);
       origin_itensors_.push_back(input_tensor);
-      auto iv = input_tensor->dims().Vectorize();
-      iv.erase(iv.begin());
-      new_shape.push_back(iv);
+      if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
+        auto iv = input_tensor->dims().Vectorize();
+        iv.erase(iv.begin());
+        new_shape.push_back(iv);
+      } else {
+        new_shape.push_back(input_tensor->dims().Vectorize());
+      }
       CHECK(input_tensor);
       auto input_node = graph->AddNode(input_name,
```
The remaining two hunks touch the code that wraps the origin tensors into `MLUTensor` views for launch; the flattened side-by-side view shows identical text on both sides here, so only the surrounding context is recoverable:

```diff
@@ -239,7 +247,7 @@ class SubgraphEngine : public subgraph::Engine {
       for (size_t i = 0; i < origin_itensors_.size(); ++i) {
         paddle::lite::subgraph::mlu::MLUTensor tmp(
             origin_itensors_[i]->dims().Vectorize());
         // graph_input->at(i)->get_origin_shape());
         tmp.set_mlu_dtype(graph_input->at(i)->dtype());
         tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
         graph_in.push_back(
@@ -253,7 +261,7 @@ class SubgraphEngine : public subgraph::Engine {
             Precision>::T>(TARGET(kMLU)));
         paddle::lite::subgraph::mlu::MLUTensor tmp(
             origin_otensors_[i]->dims().Vectorize());
         // graph_output->at(i)->get_origin_shape());
         tmp.set_mlu_dtype(graph_output->at(i)->dtype());
         tmp.set_mlu_ptr(p_data);
         graph_out.push_back(
```
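To exercise the new path end to end, the variable must be present in the process environment before the MLU subgraph is compiled; otherwise the original fixed-shape behaviour is kept. A hedged sketch, with the predictor setup elided and `setenv` assumed available (it is POSIX, not a Paddle-Lite API):

```cpp
#include <cstdlib>

int main() {
  // Equivalent to `export BATCH_SIZE_CHANGEBLE=1` before launching:
  // enables the mutable-batch branch before any graph is compiled.
  // The variable name is spelled exactly as in the patch.
  setenv("BATCH_SIZE_CHANGEBLE", "1", /*overwrite=*/1);

  // ... build the Paddle-Lite predictor and run inference as usual ...
  return 0;
}
```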