提交 786d0946 编写于 作者: J jiaopu

Add env param

上级 739b4ac1
......@@ -82,10 +82,12 @@ class Graph {
void AddInput(std::shared_ptr<MLUTensor> tensor) {
inputs_.push_back(tensor->mlu_tensor());
input_tensors_.push_back(tensor);
constexpr int input_dimNb = 4;
bool input_dim_mutable[4] = {true, false, false, false};
cnmlSetTensorDimMutable(
tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
constexpr int input_dimNb = 4;
bool input_dim_mutable[4] = {true, false, false, false};
cnmlSetTensorDimMutable(
tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
}
}
void AddOutput(std::shared_ptr<MLUTensor> tensor) {
......
......@@ -81,9 +81,13 @@ class SubgraphEngine : public subgraph::Engine {
bool InputShapeChanged() {
std::vector<std::vector<int64_t>> new_shape;
for (auto origin_itensor : origin_itensors_) {
auto iv = origin_itensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
auto iv = origin_itensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
} else {
new_shape.push_back(origin_itensor->dims().Vectorize());
}
}
inputs_shape_ = new_shape;
if (shape_graph_map_.count(inputs_shape_) > 0) {
......@@ -106,9 +110,13 @@ class SubgraphEngine : public subgraph::Engine {
for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name);
origin_itensors_.push_back(input_tensor);
auto iv = input_tensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEBLE")) {
auto iv = input_tensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
} else {
new_shape.push_back(input_tensor->dims().Vectorize());
}
CHECK(input_tensor);
auto input_node = graph->AddNode(input_name,
......@@ -239,7 +247,7 @@ class SubgraphEngine : public subgraph::Engine {
for (size_t i = 0; i < origin_itensors_.size(); ++i) {
paddle::lite::subgraph::mlu::MLUTensor tmp(
origin_itensors_[i]->dims().Vectorize());
// graph_input->at(i)->get_origin_shape());
// graph_input->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_input->at(i)->dtype());
tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
graph_in.push_back(
......@@ -253,7 +261,7 @@ class SubgraphEngine : public subgraph::Engine {
Precision>::T>(TARGET(kMLU)));
paddle::lite::subgraph::mlu::MLUTensor tmp(
origin_otensors_[i]->dims().Vectorize());
// graph_output->at(i)->get_origin_shape());
// graph_output->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_output->at(i)->dtype());
tmp.set_mlu_ptr(p_data);
graph_out.push_back(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册