diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 36024d4a7d52d78ebca45c577a0fc81d97d16d91..49fa323fc66bcc9461c6c04300aab9c5120c368f 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -72,6 +72,7 @@ pass_library(identity_scale_op_clean_pass base)
 pass_library(sync_batch_norm_pass base)
 pass_library(runtime_context_cache_pass base)
 pass_library(simplify_anakin_detection_pattern_pass inference)
+pass_library(anakin_fillconstant_elementwisemul_fuse inference)
 
 # There may be many transpose-flatten structures in a model, and the output of
 # these structures will be used as inputs to the concat Op. This pattern will
@@ -82,7 +83,7 @@ foreach (index RANGE 3 6)
   file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
 endforeach()
 
-foreach (index RANGE 3 6)
+foreach (index RANGE 2 6)
   file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
 endforeach()
diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc
new file mode 100644
index 0000000000000000000000000000000000000000..83b0da0c0118a856e54d744607cee8b421f330a3
--- /dev/null
+++ b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h"
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
+#define GET_NODES                  \
+  GET_IR_NODE(fill_constant);      \
+  GET_IR_NODE(fill_constant_out);  \
+  GET_IR_NODE(elementwise_mul);    \
+  GET_IR_NODE(elementwise_mul_out);
+
+std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
+  FusePassBase::Init(pattern_name, graph.get());
+
+  GraphPatternDetector gpd;
+  auto* x = gpd.mutable_pattern()
+                ->NewNode("x")
+                ->assert_is_op_input("elementwise_mul", "X")
+                ->AsInput();
+
+  patterns::AnakinFillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(),
+                                                         pattern_name);
+  pattern(x);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_NODES;
+
+    PADDLE_ENFORCE(subgraph.count(x));
+    auto* elementwise_in = subgraph.at(x);
+    float constant_value =
+        boost::get<float>(fill_constant->Op()->GetAttr("value"));
+
+    framework::OpDesc new_op_desc;
+    new_op_desc.SetType("scale");
+    new_op_desc.SetInput("X", {elementwise_in->Name()});
+    new_op_desc.SetAttr("scale", constant_value);
+    new_op_desc.SetAttr("bias", static_cast<float>(0.0));
+    new_op_desc.SetAttr("bias_after_scale", true);
+    new_op_desc.SetOutput("Out", {elementwise_mul_out->Name()});
+    new_op_desc.Flush();
+
+    // Create a new node for the fused op.
+    auto* scale_op = graph->CreateOpNode(&new_op_desc);
+
+    IR_NODE_LINK_TO(elementwise_in, scale_op);       // Input
+    IR_NODE_LINK_TO(scale_op, elementwise_mul_out);  // Output
+
+    // Delete the unneeded nodes.
+    GraphSafeRemoveNodes(graph.get(),
+                         {fill_constant, fill_constant_out, elementwise_mul});
+  };
+
+  gpd(graph.get(), handler);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse,
+              paddle::framework::ir::AnakinFillconstantElementwisemulFuse);
diff --git a/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h
new file mode 100644
index 0000000000000000000000000000000000000000..fa95143d3adae3e3eeb913af09986fb4a401bd73
--- /dev/null
+++ b/paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class AnakinFillconstantElementwisemulFuse : public FusePassBase {
+ public:
+  virtual ~AnakinFillconstantElementwisemulFuse() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 77c9a94df2f2c003c743276187a8b34979491c61..31e259c51d1996bbed33f978013ce5f591436704 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1596,6 +1596,29 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
   return multiclass_nms_out;
 }
 
+PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
+    PDNode *elementwise_op_input) {
+  auto fill_constant =
+      pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
+
+  auto fill_constant_out = pattern->NewNode(fill_constant_out_repr())
+                               ->assert_is_op_output("fill_constant")
+                               ->assert_is_op_input("elementwise_mul", "Y")
+                               ->AsIntermediate();
+
+  auto elementwise_mul_op =
+      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
+
+  auto elementwise_mul_out = pattern->NewNode(elementwise_mul_out_repr())
+                                 ->assert_is_op_output("elementwise_mul")
+                                 ->AsOutput();
+
+  fill_constant_out->LinksFrom({fill_constant});
+  elementwise_mul_op->LinksFrom({elementwise_op_input, fill_constant_out});
+  elementwise_mul_out->LinksFrom({elementwise_mul_op});
+  return elementwise_mul_out;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 080b2f9644456370947e6a0a66be2c75ce11531f..16cb6fb7aee2c8ac2c59a758aa63001106c816d2 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -856,6 +856,21 @@ struct AnakinDetectionPattern : public PatternBase {
   }
 };
 
+struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
+  AnakinFillConstantElementWiseMulFuse(PDPattern* pattern,
+                                       const std::string& name_scope)
+      : PatternBase(pattern, name_scope,
+                    "anakin_fillconstant_elementwisemul_fuse") {}
+
+  PDNode* operator()(PDNode* elementwise_op_input);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fill_constant);
+  PATTERN_DECL_NODE(fill_constant_out);
+  PATTERN_DECL_NODE(elementwise_mul);
+  PATTERN_DECL_NODE(elementwise_mul_out);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
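Note on the pattern above: the matched subgraph is `fill_constant -> fill_constant_out -> elementwise_mul(Y)`, with the pass's free `x` node feeding the `X` input. The handler can replace the pair with a single `scale` op because multiplying a tensor elementwise by a filled constant `c` is exactly `scale(x, scale=c, bias=0, bias_after_scale=true)`. A minimal self-contained sketch of that identity (illustrative C++, not part of the patch; `scale_ref` is a made-up reference helper):

```cpp
#include <cassert>
#include <vector>

// Reference semantics of Paddle's scale op (illustrative only).
std::vector<float> scale_ref(const std::vector<float>& x, float scale,
                             float bias, bool bias_after_scale) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = bias_after_scale ? x[i] * scale + bias : (x[i] + bias) * scale;
  }
  return out;
}

int main() {
  const float c = 2.5f;  // the "value" attribute of the matched fill_constant
  const std::vector<float> x{1.f, -2.f, 3.f};

  // What the original subgraph computes: elementwise_mul(x, fill_constant(c)).
  std::vector<float> mul_out(x);
  for (float& v : mul_out) v *= c;

  // What the fused op computes: scale(x, scale=c, bias=0, bias_after_scale=true).
  assert(mul_out == scale_ref(x, c, 0.f, true));
  return 0;
}
```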
diff --git a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc
index 5ab10ba39fa86a9a65c247177cd437ba48d3797d..84fb8063e6f020d5ada2c6af7a0307360aa1c92c 100644
--- a/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc
+++ b/paddle/fluid/framework/ir/simplify_anakin_detection_pattern_pass.cc
@@ -215,6 +215,7 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
 }
 
 template class SimplifyAnakinDetectionPatternPass<1>;
+template class SimplifyAnakinDetectionPatternPass<2>;
 template class SimplifyAnakinDetectionPatternPass<3>;
 template class SimplifyAnakinDetectionPatternPass<4>;
 template class SimplifyAnakinDetectionPatternPass<5>;
@@ -227,6 +228,9 @@ template class SimplifyAnakinDetectionPatternPass<6>;
 REGISTER_PASS(simplify_anakin_detection_pattern_pass,
               paddle::framework::ir::SimplifyAnakinDetectionPatternPass<1>);
 
+REGISTER_PASS(simplify_anakin_detection_pattern_pass2,
+              paddle::framework::ir::SimplifyAnakinDetectionPatternPass<2>);
+
 REGISTER_PASS(simplify_anakin_detection_pattern_pass3,
               paddle::framework::ir::SimplifyAnakinDetectionPatternPass<3>);
diff --git a/paddle/fluid/inference/anakin/convert/CMakeLists.txt b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
index 7b08375a7a3f4da9a1b499e1ccf3101835138f87..da9ffa5bbf6bd3bbfc679ed540d07739ba04b12f 100644
--- a/paddle/fluid/inference/anakin/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
@@ -1,5 +1,8 @@
 cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
-elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc DEPS anakin_engine framework_proto scope op_registry)
+  elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
+batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
+detection_out.cc scale.cc DEPS anakin_engine framework_proto scope op_registry)
+
 cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op)
 cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv)
 cc_test(test_anakin_activation SRCS test_activation_op.cc DEPS activation_op anakin_op_converter)
@@ -13,3 +16,4 @@ cc_test(test_anakin_reshape SRCS test_reshape_op.cc DEPS anakin_op_converter reshape_op)
 cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter flatten_op reshape_op)
 cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op)
 cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op)
+cc_test(test_anakin_scale SRCS test_scale_op.cc DEPS anakin_op_converter scale_op math_function)
diff --git a/paddle/fluid/inference/anakin/convert/op_converter.h b/paddle/fluid/inference/anakin/convert/op_converter.h
index 6ce37c39e6c02acfcfcf6b6566e8ecca121041fd..2eb7f24ce544c61aab1221b1518dd1fbcb9a7ca3 100644
--- a/paddle/fluid/inference/anakin/convert/op_converter.h
+++ b/paddle/fluid/inference/anakin/convert/op_converter.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -72,32 +73,70 @@ class AnakinOpConverter {
   // The scope here should be inited with the parameter vars.
   void ConvertBlockToAnakinEngine(
-      framework::BlockDesc *block_desc, const framework::Scope &scope,
+      framework::BlockDesc *block_desc, framework::Scope *scope,
       const std::vector<std::string> &inputs,
       const std::unordered_set<std::string> &parameters,
       const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
     framework::proto::BlockDesc *block_proto = block_desc->Proto();
-    ConvertBlock(*block_proto, parameters, scope, engine);
+    ConvertBlock(*block_proto, parameters, *scope, engine);
+
+    engine->Freeze();
+    // The max_batch_size set via config->EnableAnakinEngine must be positive.
+    int max_batch_size = engine->GetMaxBatchSize();
+    PADDLE_ENFORCE(max_batch_size > 0,
+                   "The max_batch_size set from config->EnableAnakinEngine "
+                   "must be larger than 0.");
+    // If the user does not specify this variable, we use the input shape from
+    // the block_desc.
+    auto max_input_shape = engine->GetMaxInputShape();
+    std::map<std::string, std::vector<int>> temp_max_input_shape;
+
     for (auto &input : inputs) {
       if (parameters.count(input)) continue;
-      auto *var = block_desc->FindVar(input);
-      PADDLE_ENFORCE(var, "no variable called %s", input);
-
-      auto var_shape = var->GetShape();
-      PADDLE_ENFORCE(var_shape.size() == 4);
       std::vector<int> input_shape;
-      for (int i = 0; i < var_shape.size(); i++) {
-        input_shape.push_back(var_shape[i]);
+      input_shape.resize(4);
+      input_shape[0] = max_batch_size;
+      if (max_input_shape.count(input)) {
+        PADDLE_ENFORCE(max_input_shape[input].size() == 4,
+                       "The dimensions of max_input_shape set from "
+                       "config->EnableAnakinEngine must be 4.");
+        for (int i = 1; i < 4; i++) {
+          input_shape[i] = max_input_shape[input][i];
+        }
+      } else {
+        auto *var = block_desc->FindVar(input);
+        PADDLE_ENFORCE(var, "no variable called %s", input);
+
+        auto var_shape = var->GetShape();
+        PADDLE_ENFORCE(var_shape.size() == 4);
+
+        for (size_t i = 1; i < var_shape.size(); i++) {
+          input_shape[i] = var_shape[i];
+        }
       }
-      input_shape[0] = engine->GetMaxBatch();
-
+      temp_max_input_shape[input] = input_shape;
       engine->SetInputShape(input, input_shape);
+      // engine->Graph()->RegistVar(input);  // For share from data.
     }
+    engine->SetMaxInputShape(temp_max_input_shape);
 
-    // engine->Graph()->RegistAllOut();
     engine->Optimize();
     engine->InitGraph();
+    /*
+    for (auto &input : inputs) {
+      platform::CUDAPlace gpu_place(engine->GetDevice());
+      auto input_var = scope->Var();
+      auto input_tensor = input_var->GetMutable<framework::LoDTensor>();
+      auto input_max_shape = temp_max_input_shape[input];
+      input_tensor->Resize(framework::make_ddim(input_max_shape));
+      auto input_data = input_tensor->mutable_data<float>(gpu_place);
+      auto *anakin_input = engine->Net()->get_in(input);
+
+      ::anakin::saber::Tensor<::anakin::saber::NV> tmp_anakin_tensor(
+          input_data, ::anakin::saber::NV(), 0, input_max_shape);
+      anakin_input->share_from(tmp_anakin_tensor);
+    }
+    */
   }
 
   void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
diff --git a/paddle/fluid/inference/anakin/convert/scale.cc b/paddle/fluid/inference/anakin/convert/scale.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6f3aa8c5d1111dc2829e241c9331eeb521003c03
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/scale.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/anakin/convert/scale.h"
+#include <algorithm>
+#include <map>
+
+using anakin::graph::GraphGlobalMem;
+using anakin::AK_FLOAT;
+using anakin::saber::NV;
+using anakin::saber::Shape;
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+void ScaleOpConverter::operator()(const framework::proto::OpDesc &op,
+                                  const framework::Scope &scope,
+                                  bool test_mode) {
+  framework::OpDesc op_desc(op, nullptr);
+  PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
+  PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
+
+  auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
+
+  auto input_name = op_desc.Input("X").front();
+  auto output_name = op_desc.Output("Out").front();
+  float scale = boost::get<float>(op_desc.GetAttr("scale"));
+  float bias = boost::get<float>(op_desc.GetAttr("bias"));
+  bool bias_after_scale =
+      boost::get<bool>(op_desc.GetAttr("bias_after_scale"));
+  PADDLE_ENFORCE(bias_after_scale,
+                 "The anakin scale layer only supports bias after scale now.");
+
+  engine_->AddOp(op_name, "Power", {input_name}, {output_name});
+  engine_->AddOpAttr(op_name, "shift", bias);
+  engine_->AddOpAttr(op_name, "scale", scale);
+  engine_->AddOpAttr(op_name, "power", static_cast<float>(1.0));
+}
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_ANAKIN_OP_CONVERTER(scale, ScaleOpConverter);
diff --git a/paddle/fluid/inference/anakin/convert/scale.h b/paddle/fluid/inference/anakin/convert/scale.h
new file mode 100644
index 0000000000000000000000000000000000000000..b858e3c512494f80c7c3818a570e43d90d65251b
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/scale.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include "paddle/fluid/inference/anakin/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+class ScaleOpConverter : public AnakinOpConverter {
+ public:
+  ScaleOpConverter() = default;
+
+  virtual void operator()(const framework::proto::OpDesc &op,
+                          const framework::Scope &scope,
+                          bool test_mode) override;
+  virtual ~ScaleOpConverter() {}
+};
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/anakin/convert/ut_helper.h b/paddle/fluid/inference/anakin/convert/ut_helper.h
index 1b0ef8c7dbe5fd8d39c42cc39972159126ec1214..d62d11d25bba37821099492c5c292e44fc566052 100644
--- a/paddle/fluid/inference/anakin/convert/ut_helper.h
+++ b/paddle/fluid/inference/anakin/convert/ut_helper.h
@@ -122,6 +122,8 @@ class AnakinConvertValidation {
     Singleton<AnakinOpConverter>::Global().ConvertOp(
         desc, parameters_, scope_, engine_.get(), true /*test_mode*/);
     engine_->Freeze();
+
+    std::map<std::string, std::vector<int>> temp_max_input_shape;
     for (const auto& input : op_desc_->InputArgumentNames()) {
       if (parameters_.count(input)) continue;
       auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(scope_,
@@ -131,7 +133,9 @@ class AnakinConvertValidation {
         t_shape.push_back(1);
       }
       engine_->SetInputShape(input, t_shape);
+      temp_max_input_shape[input] = t_shape;
     }
+    engine_->SetMaxInputShape(temp_max_input_shape);
     engine_->Optimize();
     engine_->InitGraph();
   }
diff --git a/paddle/fluid/inference/anakin/engine.cc b/paddle/fluid/inference/anakin/engine.cc
index b8b0d06d2106010772b0b9d4d307fd2744ce00a2..176bc1254b5517603c2d8d4c8279cf0e5d4c4578 100644
--- a/paddle/fluid/inference/anakin/engine.cc
+++ b/paddle/fluid/inference/anakin/engine.cc
@@ -33,13 +33,14 @@ namespace inference {
 namespace anakin {
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
-AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary,
-                                                            int device,
-                                                            int max_batch_size)
+AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(
+    bool need_summary, int device, int max_batch_size,
+    std::map<std::string, std::vector<int>> max_input_shape)
     : graph_(new AnakinGraphT<TargetT, PrecisionType>()),
      net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {
   device_ = device;
   max_batch_size_ = max_batch_size;
+  max_input_shape_ = max_input_shape;
 }
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
@@ -75,20 +76,31 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
     auto *data = tensor->data<float>();
 
     auto fluid_input_shape = framework::vectorize2int(tensor->dims());
+    while (fluid_input_shape.size() < 4) {
+      fluid_input_shape.push_back(1);
+    }
     auto *anakin_input = net_->get_in(input.first);
-    auto net_shape = anakin_input->shape();
+    std::vector<int> max_input_shape = max_input_shape_[input.first];
+    int max_shape_sum =
+        std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1,
+                        std::multiplies<int>());
+
+    PADDLE_ENFORCE(max_shape_sum >= tensor->numel(),
+                   "The anakin input max shape should be greater than"
+                   " or equal to the real input shape. Please set the max "
+                   "input shape using EnableAnakinEngine.");
+    /*
     if (tensor->numel() > net_shape.count()) {
       graph_->Reshape(input.first, fluid_input_shape);
       net_.reset(new AnakinNetT<TargetT, PrecisionType, RunType>(true));
       net_->init(*graph_);
       anakin_input = net_->get_in(input.first);
     }
+    */
 
     anakin_input->reshape(fluid_input_shape);
-    net_shape = anakin_input->shape();
     ::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
-                                                       // net_shape);
                                                        fluid_input_shape);
     anakin_input->copy_from(tmp_anakin_tensor);
   }
diff --git a/paddle/fluid/inference/anakin/engine.h b/paddle/fluid/inference/anakin/engine.h
index 101ca491678a54ce09fd9a5aa81d63eaede46304..3835ead1946823f7dfb2b9afb746ea78cb88832b 100644
--- a/paddle/fluid/inference/anakin/engine.h
+++ b/paddle/fluid/inference/anakin/engine.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <algorithm>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -55,8 +56,9 @@ class AnakinEngine {
   using GraphT = ::anakin::graph::Graph<TargetT, PrecisionType>;
 
  public:
-  explicit AnakinEngine(bool need_summary = false, int device = 0,
-                        int max_batch_size = 1);
+  explicit AnakinEngine(
+      bool need_summary = false, int device = 0, int max_batch_size = 1,
+      std::map<std::string, std::vector<int>> max_input_shape = {});
   ~AnakinEngine();
   void InitGraph();
   void SetInputShape(const std::string &name, std::vector<int> shape);
@@ -73,10 +75,17 @@ class AnakinEngine {
   NetT *Net() { return net_.get(); }
   GraphT *Graph() { return graph_.get(); }
   std::unique_ptr<AnakinEngine> Clone();
+  const std::map<std::string, std::vector<int>> &GetMaxInputShape() {
+    return max_input_shape_;
+  }
+  void SetMaxInputShape(std::map<std::string, std::vector<int>> shape) {
+    max_input_shape_ = shape;
+  }
+  int GetMaxBatchSize() { return max_batch_size_; }
   void Freeze();
   void Optimize();
   void Save(std::string path) { graph_->save(path); }
-  int GetMaxBatch() { return max_batch_size_; }
+  int GetDevice() { return device_; }
   // void SaveSerializedData(std::string& data) { graph_->save_to_string(data);
   // }
   // void LoadSerializedData(const std::string& data) {
@@ -87,6 +96,7 @@ class AnakinEngine {
 
  private:
   int max_batch_size_;
+  std::map<std::string, std::vector<int>> max_input_shape_;
   int device_;
   std::unique_ptr<GraphT> graph_;
   std::unique_ptr<NetT> net_;
@@ -104,11 +114,13 @@ class AnakinEngineManager {
     return engines_.at(name).get();
   }
 
-  AnakinNvEngineT *Create(bool need_summary, int device, int max_batch_size,
-                          std::string engine_name) {
+  AnakinNvEngineT *Create(
+      bool need_summary, int device, int max_batch_size,
+      std::map<std::string, std::vector<int>> max_input_shape,
+      std::string engine_name) {
     std::unique_lock<std::mutex> lk(mut_);
-    auto *p = new AnakinEngine<NV, Precision::FP32>(need_summary, device,
-                                                    max_batch_size);
+    auto *p = new AnakinEngine<NV, Precision::FP32>(
+        need_summary, device, max_batch_size, max_input_shape);
     engines_[engine_name].reset(p);
     return p;
   }
diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc
index 3166f68b67af224367afc36984154671cf94a25d..3270f5b57a1fd4b97e4c2ae097620dbcad9447c7 100644
--- a/paddle/fluid/inference/anakin/op_teller.cc
+++ b/paddle/fluid/inference/anakin/op_teller.cc
@@ -38,6 +38,7 @@ struct SimpleOpTypeSetTeller : public Teller {
     teller_set.insert("transpose2");
     teller_set.insert("density_prior_box");
     teller_set.insert("detection_out");
+    teller_set.insert("scale");
   }
 
   bool operator()(const std::string& op_type,
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 87aceba4793265189b5b35e76443b5ca1a6809aa..992c779711a9b331534b43ac3b6af2df75e88c75 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -57,6 +57,7 @@ struct Argument {
   using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
   using fusion_statis_t = std::unordered_map<std::string, int>;
   using engine_opt_info_t = std::map<std::string, std::string>;
+  using anakin_max_shape_t = std::map<std::string, std::vector<int>>;
 
   bool Has(const std::string& key) const { return valid_fields_.count(key); }
 
@@ -150,6 +151,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
                       bool);
 
+  DECL_ARGUMENT_FIELD(anakin_max_input_shape, AnakinMaxInputShape,
+                      anakin_max_shape_t);
   DECL_ARGUMENT_FIELD(anakin_max_batch_size, AnakinMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 3dc9c347b5f0bf4340f056f350e1ab38f5160a28..b0e07fdf132f31087c73342e0b239c50ef93abbd 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -77,6 +77,8 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("engine_opt_info", new std::map<std::string, std::string>(
                                        argument->engine_opt_info()));
       pass->Set("predictor_id", new int(argument->predictor_id()));
+      pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
+                                       argument->anakin_max_input_shape()));
       pass->Set("max_batch_size", new int(argument->anakin_max_batch_size()));
     }
diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
index b2bd1ec0ea1143c971d4e9ace784e645c96926bf..de41e05f1a690912ee5c91643c528cb7109f3120 100644
--- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <algorithm>
+#include <map>
 #include <memory>
 #include <set>
 #include <string>
@@ -256,11 +257,14 @@ void AnakinSubgraphPass::CreateAnakinOp(
                             input_names_with_id, output_names_with_id,
                             std::to_string(predictor_id));
   SetAttr(op_desc->Proto(), "engine_key", engine_key);
-  int max_batch_size = Get<int>("max_batch_size");
+  auto max_input_shape =
+      Get<std::map<std::string, std::vector<int>>>("max_input_shape");
+  auto max_batch_size = Get<int>("max_batch_size");
 
   auto *anakin_engine =
       inference::Singleton<anakin::AnakinEngineManager>::Global().Create(
-          true, Get<int>("gpu_device_id"), max_batch_size, engine_key);
+          true, Get<int>("gpu_device_id"), max_batch_size, max_input_shape,
+          engine_key);
 
   auto *scope = param_scope();
   std::unordered_set<std::string> param_set(params.begin(), params.end());
@@ -268,7 +272,7 @@ void AnakinSubgraphPass::CreateAnakinOp(
   framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
   inference::Singleton<inference::anakin::AnakinOpConverter>::Global()
       .ConvertBlockToAnakinEngine(
-          &block_desc_temp, *scope,
+          &block_desc_temp, scope,
          std::vector<std::string>(input_names.begin(), input_names.end()),
           param_set, output_mapping, anakin_engine);
 }
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 1800f06f2de2ac7f8bd6b10b4c079ec75f13b67a..69d6ab1022e66d31a06883681d8132a99b359ce0 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -214,13 +214,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
                             std::to_string(0));
 
   // Get "" when there is no cached calibration table data.
-  std::string calibration_data = GetTrtCalibTableData(
-      Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
+  bool load_from_memory = Get<bool>("model_from_memory");
+  std::string calibration_data = "";
+  if (!load_from_memory) {
+    calibration_data = GetTrtCalibTableData(
+        Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
+  }
   SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
 
   SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
   SetAttr(op_desc->Proto(), "engine_key", engine_key);
-  bool load_from_memory = Get<bool>("model_from_memory");
   std::string trt_engine_serialized_data = "";
   if (load_from_memory) {
     std::map<std::string, std::string> engine_opt_info =
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 8360963f7366e7cd192e55d307669d4915b065de..d13ec7608c3e8075c1ef62fd4d47fbeee06e9005 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -30,7 +30,6 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
   // The parameters are on the cpu, therefore, synchronization is not necessary.
   if (!argument->use_gpu()) return;
-  return;
 
   auto &graph = argument->main_graph();
   std::vector<std::string> repetitive_params;
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 3c17f49fa350929e4c92c470c62a2dab6b6a92da..7bfdada49664544c829b1f4fc886292b29717c32 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -111,6 +111,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
 
   CP_MEMBER(use_anakin_);
   CP_MEMBER(anakin_max_batchsize_);
+  CP_MEMBER(anakin_max_input_shape_);
 
   // Ir related.
   CP_MEMBER(enable_ir_optim_);
@@ -355,8 +356,11 @@ void AnalysisConfig::SwitchIrDebug(int x) {
   ir_debug_ = x;
   Update();
 }
-void AnalysisConfig::EnableAnakinEngine(int max_batch_size) {
+void AnalysisConfig::EnableAnakinEngine(
+    int max_batch_size,
+    std::map<std::string, std::vector<int>> max_input_shape) {
   anakin_max_batchsize_ = max_batch_size;
+  anakin_max_input_shape_ = max_input_shape;
   use_anakin_ = true;
   Update();
 }
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 9c992602e0a82d816a69b369f4de6d4370896a33..bcae080bc9abdcdbeab9b9b3852a6a750a2efb6f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -380,6 +380,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   if (config_.use_gpu() && config_.anakin_engine_enabled()) {
     argument_.SetAnakinMaxBatchSize(config_.anakin_max_batchsize_);
+    argument_.SetAnakinMaxInputShape(config_.anakin_max_input_shape_);
     LOG(INFO) << "Anakin subgraph engine is enabled";
   }
 
@@ -835,3 +836,4 @@ USE_ANAKIN_CONVERTER(softmax);
 
 USE_ANAKIN_CONVERTER(detection_out);
 USE_ANAKIN_CONVERTER(density_prior_box);
+USE_ANAKIN_CONVERTER(scale);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 65dd669c95fc50d08af04a1a48fcf44f111373d3..9a29f8f77ed7081494073c76a501949e2c346c5a 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -145,7 +145,9 @@ struct AnalysisConfig {
   /**
    * \brief Turn on the usage of Anakin sub-graph engine.
    */
-  void EnableAnakinEngine(int max_batch_size = 1);
+  void EnableAnakinEngine(
+      int max_batch_size = 1,
+      std::map<std::string, std::vector<int>> max_input_shape = {});
 
   /** A boolean state indicating whether the Anakin sub-graph engine is used.
    */
@@ -271,6 +273,7 @@ struct AnalysisConfig {
   mutable std::unique_ptr<PassStrategy> pass_builder_;
   bool use_anakin_{false};
   int anakin_max_batchsize_;
+  std::map<std::string, std::vector<int>> anakin_max_input_shape_;
   std::map<std::string, std::string> engine_opt_info_;
 };
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index f6d82a57d2939ddd37fad682d0190668ff15e3d5..8db636274fb0ebaa5765a60864b37863da8e5d44 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -71,7 +71,11 @@ void GpuPassStrategy::EnableMKLDNN() {
 // The following passes works for Anakin sub-graph engine.
 const std::vector<std::string> kAnakinSubgraphPasses({
     "infer_clean_graph_pass",                   //
+    "simplify_anakin_detection_pattern_pass5",  //
+    "simplify_anakin_detection_pattern_pass4",  //
     "simplify_anakin_detection_pattern_pass3",  //
+    "simplify_anakin_detection_pattern_pass2",  //
+    "anakin_fillconstant_elementwisemul_fuse",  //
     "fc_fuse_pass",                             //
     "conv_elementwise_add_fuse_pass",           //
     "conv_bn_fuse_pass",                        //
diff --git a/paddle/fluid/operators/anakin/anakin_engine_op.h b/paddle/fluid/operators/anakin/anakin_engine_op.h
index 7a70836652db2ce9774660d3853bde37666eed71..bbe9a221b2cae71a78b4d269b2aeb160c2c57055 100644
--- a/paddle/fluid/operators/anakin/anakin_engine_op.h
+++ b/paddle/fluid/operators/anakin/anakin_engine_op.h
@@ -97,6 +97,7 @@ class AnakinEngineOp : public framework::OperatorBase {
       if (param_names_.count(x)) continue;
       auto &t =
           inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
+      /*
       auto t_shape = framework::vectorize(t.dims());
       auto *anakin_input = engine->Net()->get_in(x);
       auto net_shape = anakin_input->shape();
@@ -112,20 +113,16 @@ class AnakinEngineOp : public framework::OperatorBase {
         t.mutable_data<float>(dev_place);
         TensorCopySync(temp_t, dev_place, &t);
       }
+      */
       inputs.insert({x, &t});
     }
 
     std::map<std::string, framework::LoDTensor *> outputs;
     int output_index = 0;
     for (const auto &y : Outputs("Ys")) {
-      // std::vector<int> ddim =
-      //     engine->Net()->get_out(output_maps[output_index])->valid_shape();
-      // we need get the output anakin output shape.
       auto *fluid_v = scope.FindVar(y);
       PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
       auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
-      // fluid_t->Resize(framework::make_ddim(ddim));
-      // fluid_t->mutable_data<float>(boost::get<platform::CUDAPlace>(dev_place));
       outputs.insert({output_maps[output_index], fluid_t});
       output_index += 1;
     }
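For reference, a minimal sketch of how the API surface extended by this patch would be driven from user code. `SetModel`, `EnableUseGpu`, and the new two-argument `EnableAnakinEngine` are the `AnalysisConfig` methods touched above; the model directory and the input name "image" are hypothetical:

```cpp
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./mobilenet_v1");  // hypothetical model directory
  config.EnableUseGpu(100 /*memory_pool_init_size_mb*/, 0 /*device_id*/);

  // Each max shape must be 4-D; dim 0 is overwritten with max_batch_size
  // inside ConvertBlockToAnakinEngine.
  std::map<std::string, std::vector<int>> max_input_shape{
      {"image", {1, 3, 300, 300}}};  // "image" is a hypothetical input name
  config.EnableAnakinEngine(/*max_batch_size=*/4, max_input_shape);
  return 0;
}
```

When a max shape is not supplied for an input, the converter falls back to the 4-D shape recorded in the block's `VarDesc`, and `Execute` enforces at runtime that every real input fits inside the declared maximum.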