提交 786d0946 编写于 作者: J jiaopu

Add env param

上级 739b4ac1
...@@ -82,11 +82,13 @@ class Graph { ...@@ -82,11 +82,13 @@ class Graph {
void AddInput(std::shared_ptr<MLUTensor> tensor) { void AddInput(std::shared_ptr<MLUTensor> tensor) {
inputs_.push_back(tensor->mlu_tensor()); inputs_.push_back(tensor->mlu_tensor());
input_tensors_.push_back(tensor); input_tensors_.push_back(tensor);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
constexpr int input_dimNb = 4; constexpr int input_dimNb = 4;
bool input_dim_mutable[4] = {true, false, false, false}; bool input_dim_mutable[4] = {true, false, false, false};
cnmlSetTensorDimMutable( cnmlSetTensorDimMutable(
tensor->mlu_tensor(), input_dim_mutable, input_dimNb); tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
} }
}
void AddOutput(std::shared_ptr<MLUTensor> tensor) { void AddOutput(std::shared_ptr<MLUTensor> tensor) {
outputs_.push_back(tensor->mlu_tensor()); outputs_.push_back(tensor->mlu_tensor());
......
...@@ -81,9 +81,13 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -81,9 +81,13 @@ class SubgraphEngine : public subgraph::Engine {
bool InputShapeChanged() { bool InputShapeChanged() {
std::vector<std::vector<int64_t>> new_shape; std::vector<std::vector<int64_t>> new_shape;
for (auto origin_itensor : origin_itensors_) { for (auto origin_itensor : origin_itensors_) {
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
auto iv = origin_itensor->dims().Vectorize(); auto iv = origin_itensor->dims().Vectorize();
iv.erase(iv.begin()); iv.erase(iv.begin());
new_shape.push_back(iv); new_shape.push_back(iv);
} else {
new_shape.push_back(origin_itensor->dims().Vectorize());
}
} }
inputs_shape_ = new_shape; inputs_shape_ = new_shape;
if (shape_graph_map_.count(inputs_shape_) > 0) { if (shape_graph_map_.count(inputs_shape_) > 0) {
...@@ -106,9 +110,13 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -106,9 +110,13 @@ class SubgraphEngine : public subgraph::Engine {
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name); auto input_tensor = scope_->FindMutableTensor(input_name);
origin_itensors_.push_back(input_tensor); origin_itensors_.push_back(input_tensor);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
auto iv = input_tensor->dims().Vectorize(); auto iv = input_tensor->dims().Vectorize();
iv.erase(iv.begin()); iv.erase(iv.begin());
new_shape.push_back(iv); new_shape.push_back(iv);
} else {
new_shape.push_back(input_tensor->dims().Vectorize());
}
CHECK(input_tensor); CHECK(input_tensor);
auto input_node = graph->AddNode(input_name, auto input_node = graph->AddNode(input_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册