提交 128cdc5c 编写于 作者: J jackzhang235 提交者: jackzhang235

add batch size unchangeable op's black list

上级 4af82042
...@@ -103,10 +103,11 @@ class Graph { ...@@ -103,10 +103,11 @@ class Graph {
// return outputs_shape; // return outputs_shape;
// } // }
void AddInput(std::shared_ptr<MLUTensor> tensor) { void AddInput(std::shared_ptr<MLUTensor> tensor,
bool disable_batch_size_changeable = true) {
inputs_.push_back(tensor->mlu_tensor()); inputs_.push_back(tensor->mlu_tensor());
input_tensors_.push_back(tensor); input_tensors_.push_back(tensor);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEABLE")) { if (!disable_batch_size_changeable) {
constexpr int input_dimNb = 4; constexpr int input_dimNb = 4;
bool input_dim_mutable[4] = {true, false, false, false}; bool input_dim_mutable[4] = {true, false, false, false};
cnmlSetTensorDimMutable( cnmlSetTensorDimMutable(
......
...@@ -50,7 +50,15 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -50,7 +50,15 @@ class SubgraphEngine : public subgraph::Engine {
paddle::lite_api::PrecisionType type) paddle::lite_api::PrecisionType type)
: subgraph::Engine( : subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope), ctx, block_idx, block_desc, input_names, output_names, scope),
fp_type_(type) {} fp_type_(type) {
VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is "
<< GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL");
VLOG(4) << "[MLU] PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE is "
<< GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE");
if (GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE")) {
disable_batch_size_changeable_ = true;
}
}
int Build() { int Build() {
// In order to attach all of the ops of the block desc, we need to build // In order to attach all of the ops of the block desc, we need to build
...@@ -83,7 +91,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -83,7 +91,7 @@ class SubgraphEngine : public subgraph::Engine {
// used in batch changable situation // used in batch changable situation
std::vector<std::vector<int64_t>> all_shape; std::vector<std::vector<int64_t>> all_shape;
for (auto origin_itensor : origin_itensors_) { for (auto origin_itensor : origin_itensors_) {
if (GetBoolFromEnv("BATCH_SIZE_CHANGEABLE")) { if (!disable_batch_size_changeable_) {
auto iv = origin_itensor->dims().Vectorize(); auto iv = origin_itensor->dims().Vectorize();
all_shape.push_back(iv); all_shape.push_back(iv);
iv.erase(iv.begin()); iv.erase(iv.begin());
...@@ -117,6 +125,21 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -117,6 +125,21 @@ class SubgraphEngine : public subgraph::Engine {
protected: protected:
int BuildDeviceProgram() override { int BuildDeviceProgram() override {
if (!error_compile_batch_size_changeable_ &&
!disable_batch_size_changeable_) {
int status = BuildDeviceProgramImpl();
if (subgraph::CHECK_SUCCESS(status)) {
return status;
}
VLOG(4) << "[MLU] build batch_size changeable subgraph op failed, "
"changed to input_shape changeable";
}
error_compile_batch_size_changeable_ = true;
disable_batch_size_changeable_ = true;
return BuildDeviceProgramImpl();
}
int BuildDeviceProgramImpl() {
int status = 0; int status = 0;
auto graph = std::make_shared<paddle::lite::subgraph::mlu::Graph>(); auto graph = std::make_shared<paddle::lite::subgraph::mlu::Graph>();
graph->SetFPType(fp_type_); graph->SetFPType(fp_type_);
...@@ -131,7 +154,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -131,7 +154,7 @@ class SubgraphEngine : public subgraph::Engine {
auto data_type = input_tensor->precision(); auto data_type = input_tensor->precision();
cnmlDataType_t fp_type = PrecisionToDatatype(data_type); cnmlDataType_t fp_type = PrecisionToDatatype(data_type);
origin_itensors_.push_back(input_tensor); origin_itensors_.push_back(input_tensor);
if (GetBoolFromEnv("BATCH_SIZE_CHANGEABLE")) { if (!disable_batch_size_changeable_) {
auto iv = input_tensor->dims().Vectorize(); auto iv = input_tensor->dims().Vectorize();
iv.erase(iv.begin()); iv.erase(iv.begin());
new_shape.push_back(iv); new_shape.push_back(iv);
...@@ -156,6 +179,16 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -156,6 +179,16 @@ class SubgraphEngine : public subgraph::Engine {
auto op = inst.op(); auto op = inst.op();
CHECK(op); CHECK(op);
std::string op_type = op->op_info()->Type(); std::string op_type = op->op_info()->Type();
if (!disable_batch_size_changeable_ &&
std::find(unsupport_batch_size_changeable_op_type_.begin(),
unsupport_batch_size_changeable_op_type_.end(),
op_type) !=
unsupport_batch_size_changeable_op_type_.end()) {
status |= subgraph::FAILED;
VLOG(4) << "[MLU] found unsupported batch_size changeable op type: "
<< op_type;
return status;
}
op->CheckShape(); op->CheckShape();
const_cast<OpLite*>(op)->InferShape(); const_cast<OpLite*>(op)->InferShape();
if (!bridges.Exists(op_type, TARGET(kMLU))) { if (!bridges.Exists(op_type, TARGET(kMLU))) {
...@@ -185,7 +218,8 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -185,7 +218,8 @@ class SubgraphEngine : public subgraph::Engine {
} }
} }
for (auto& input_name : input_names_) { for (auto& input_name : input_names_) {
graph->AddInput(graph->GetNode(input_name)); graph->AddInput(graph->GetNode(input_name),
disable_batch_size_changeable_);
} }
CHECK(!origin_otensors_.empty()) << "[MLU] no valid output names"; CHECK(!origin_otensors_.empty()) << "[MLU] no valid output names";
...@@ -194,7 +228,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -194,7 +228,7 @@ class SubgraphEngine : public subgraph::Engine {
auto core_number = mlu_context.MLUCoreNumber(); auto core_number = mlu_context.MLUCoreNumber();
graph->Compile(core_version, core_number); graph->Compile(core_version, core_number);
shape_graph_map_[new_shape] = graph; shape_graph_map_[new_shape] = graph;
if (GetBoolFromEnv("SAVE_MLU_OFFLINE_MODEL")) { if (GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL")) {
graph->GenOfflineModel(GetOfflineModName()); graph->GenOfflineModel(GetOfflineModName());
} }
return status; return status;
...@@ -289,7 +323,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -289,7 +323,7 @@ class SubgraphEngine : public subgraph::Engine {
CHECK_EQ(graph_input->size(), origin_itensors_.size()); CHECK_EQ(graph_input->size(), origin_itensors_.size());
CHECK_EQ(graph_output->size(), origin_otensors_.size()); CHECK_EQ(graph_output->size(), origin_otensors_.size());
if (GetBoolFromEnv("BATCH_SIZE_CHANGEABLE")) { if (!disable_batch_size_changeable_) {
std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>> std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
graph_in; graph_in;
if (shape_tensor_map_in_.find(all_inputs_shape_) != if (shape_tensor_map_in_.find(all_inputs_shape_) !=
...@@ -405,6 +439,12 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -405,6 +439,12 @@ class SubgraphEngine : public subgraph::Engine {
std::map<std::vector<std::vector<int64_t>>, std::map<std::vector<std::vector<int64_t>>,
std::shared_ptr<paddle::lite::subgraph::mlu::Graph>> std::shared_ptr<paddle::lite::subgraph::mlu::Graph>>
shape_graph_map_{}; shape_graph_map_{};
// enable batch size changeable by default, this cound be changed by
// environment variable PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE and
// whether the op can be compiled with batch size changeable way
bool disable_batch_size_changeable_{false};
bool error_compile_batch_size_changeable_{false};
std::vector<std::string> unsupport_batch_size_changeable_op_type_{"concat"};
// search output runtime MLUTensor for certain output shape when enable // search output runtime MLUTensor for certain output shape when enable
// BATCH_SIZE_CHANGEABLE // BATCH_SIZE_CHANGEABLE
std::map<std::vector<std::vector<int64_t>>, std::map<std::vector<std::vector<int64_t>>,
...@@ -419,7 +459,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -419,7 +459,7 @@ class SubgraphEngine : public subgraph::Engine {
// BATCH_SIZE_CHANGEABLE // BATCH_SIZE_CHANGEABLE
std::map<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>> std::map<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
in_out_shape_map_{}; in_out_shape_map_{};
}; // namespace mlu };
template <PrecisionType Precision> template <PrecisionType Precision>
class SubgraphCompute class SubgraphCompute
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册