From 8643dbc233f12f829b64cc0ee6926e41fb891ddf Mon Sep 17 00:00:00 2001
From: nhzlx
Date: Thu, 11 Apr 2019 09:02:52 +0000
Subject: [PATCH] cherry-pick from 16691:Anakin subgraph support yolo_v3 and
 faster-rcnn

---
 .../inference/anakin/convert/CMakeLists.txt   |   4 +-
 .../anakin/convert/affine_channel.cc          | 100 ++++++++++++++++++
 .../inference/anakin/convert/affine_channel.h |  39 +++++++
 .../inference/anakin/convert/op_converter.h   |  16 +--
 paddle/fluid/inference/anakin/convert/relu.cc |  18 ++++
 paddle/fluid/inference/anakin/convert/relu.h  |  11 ++
 .../inference/anakin/convert/roi_align.cc     |  59 +++++++++++
 .../inference/anakin/convert/roi_align.h      |  38 +++++++
 .../anakin/convert/test_affine_channel_op.cc  |  55 ++++++++++
 .../inference/anakin/convert/test_relu_op.cc  |  11 +-
 .../inference/anakin/convert/ut_helper.h      |   4 +-
 paddle/fluid/inference/anakin/engine.cc       |  26 +++--
 paddle/fluid/inference/anakin/engine.h        |  21 ++--
 paddle/fluid/inference/anakin/op_teller.cc    |   2 +
 .../inference/anakin/test_anakin_engine.cc    |   2 +-
 .../ir_passes/anakin_subgraph_pass.cc         |   3 +-
 .../analysis/ir_passes/subgraph_util.cc       |   2 -
 paddle/fluid/inference/api/CMakeLists.txt     |   1 +
 .../fluid/inference/api/analysis_predictor.cc |   2 +
 paddle/fluid/pybind/inference_api.cc          |   3 +
 20 files changed, 382 insertions(+), 35 deletions(-)
 create mode 100644 paddle/fluid/inference/anakin/convert/affine_channel.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/affine_channel.h
 create mode 100644 paddle/fluid/inference/anakin/convert/roi_align.cc
 create mode 100644 paddle/fluid/inference/anakin/convert/roi_align.h
 create mode 100644 paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc

diff --git a/paddle/fluid/inference/anakin/convert/CMakeLists.txt b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
index d3d1522dcc..7cc75de8ee 100644
--- a/paddle/fluid/inference/anakin/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/anakin/convert/CMakeLists.txt
@@ -1,4 +1,4 @@
-cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
+cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc roi_align.cc DEPS anakin_engine framework_proto scope op_registry)
 
 cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
 cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
@@ -14,5 +14,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
 cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
 cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
 cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
-#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
 cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL)
+cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL)
diff --git a/paddle/fluid/inference/anakin/convert/affine_channel.cc b/paddle/fluid/inference/anakin/convert/affine_channel.cc
new file mode 100644
index 0000000000..7c886df082
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/affine_channel.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+using anakin::graph::GraphGlobalMem;
+using anakin::AK_FLOAT;
+using anakin::Precision;
+using anakin::saber::NV;
+using anakin::saber::X86;
+using anakin::saber::Shape;
+using anakin::PBlock;
+using anakin::PTuple;
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+void AffineChannelOpConverter::operator()(
+    const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
+    const framework::Scope &scope, bool test_mode) {
+  framework::OpDesc op_desc(op, nullptr);
+  PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
+  PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
+
+  auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
+
+  auto input_name = op_desc.Input("X").front();
+  auto output_name = op_desc.Output("Out").front();
+
+  // Copy the Scale to CPUPlace and get the pointer.
+  auto *scale_v = scope.FindVar(op_desc.Input("Scale").front());
+  PADDLE_ENFORCE_NOT_NULL(scale_v);
+  auto *scale_t = scale_v->GetMutable<framework::LoDTensor>();
+  std::unique_ptr<framework::LoDTensor> scale_tensor(
+      new framework::LoDTensor());
+  scale_tensor->Resize(scale_t->dims());
+  TensorCopySync((*scale_t), platform::CPUPlace(), scale_tensor.get());
+
+  // Copy the Bias to CPUPlace and get the pointer.
+  auto *bias_v = scope.FindVar(op_desc.Input("Bias").front());
+  PADDLE_ENFORCE_NOT_NULL(bias_v);
+  auto *bias_t = bias_v->GetMutable<framework::LoDTensor>();
+  std::unique_ptr<framework::LoDTensor> bias_tensor(new framework::LoDTensor());
+  bias_tensor->Resize(bias_t->dims());
+  TensorCopySync((*bias_t), platform::CPUPlace(), bias_tensor.get());
+
+  engine_->AddOp(op_name, "AffineChannel", {input_name}, {output_name});
+
+  // Generate the Scale parameter of Anakin.
+  auto scale_shape = framework::vectorize2int(scale_t->dims());
+  while (scale_shape.size() < 4) {
+    scale_shape.insert(scale_shape.begin(), 1);
+  }
+  Shape anakin_scale_shape(scale_shape);
+  auto *weight1 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
+      anakin_scale_shape);
+  float *scale_cpu_data =
+      static_cast<float *>(weight1->h_tensor().mutable_data());
+  std::copy_n(scale_tensor->data<float>(), scale_tensor->numel(),
+              scale_cpu_data);
+  weight1->d_tensor().set_shape(anakin_scale_shape);
+  weight1->d_tensor().copy_from(weight1->h_tensor());
+  engine_->AddOpAttr(op_name, "weight_1", *weight1);
+
+  // Generate the Bias parameter of Anakin.
+  auto bias_shape = framework::vectorize2int(bias_t->dims());
+  while (bias_shape.size() < 4) {
+    bias_shape.insert(bias_shape.begin(), 1);
+  }
+  Shape anakin_bias_shape(bias_shape);
+  auto *weight2 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
+      anakin_bias_shape);
+  float *bias_cpu_data =
+      static_cast<float *>(weight2->h_tensor().mutable_data());
+  std::copy_n(bias_tensor->data<float>(), bias_tensor->numel(), bias_cpu_data);
+  weight2->d_tensor().set_shape(anakin_bias_shape);
+  weight2->d_tensor().copy_from(weight2->h_tensor());
+  engine_->AddOpAttr(op_name, "weight_2", *weight2);
+}
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_ANAKIN_OP_CONVERTER(affine_channel, AffineChannelOpConverter);
diff --git a/paddle/fluid/inference/anakin/convert/affine_channel.h b/paddle/fluid/inference/anakin/convert/affine_channel.h
new file mode 100644
index 0000000000..ea0043670c
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/affine_channel.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/inference/anakin/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+class AffineChannelOpConverter : public AnakinOpConverter {
+ public:
+  AffineChannelOpConverter() = default;
+
+  virtual void operator()(const framework::proto::OpDesc &op,
+                          const framework::BlockDesc &block_desc,
+                          const framework::Scope &scope,
+                          bool test_mode) override;
+  virtual ~AffineChannelOpConverter() {}
+
+ private:
+};
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/anakin/convert/op_converter.h b/paddle/fluid/inference/anakin/convert/op_converter.h
index 1ca62658ef..bffab229ed 100644
--- a/paddle/fluid/inference/anakin/convert/op_converter.h
+++ b/paddle/fluid/inference/anakin/convert/op_converter.h
@@ -81,7 +81,6 @@ class AnakinOpConverter {
       const std::unordered_set<std::string> &parameters,
       const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
     ConvertBlock(block_desc, parameters, *scope, engine);
-    engine->Freeze();
     // if the max_batch size
     int max_batch_size = engine->GetMaxBatchSize();
     PADDLE_ENFORCE(max_batch_size > 0,
@@ -91,7 +90,12 @@ class AnakinOpConverter {
     // the block_desc.
     auto max_input_shape = engine->GetMaxInputShape();
     std::map<std::string, std::vector<int>> temp_max_input_shape;
-
+    // Register outputs with anakin using the RegistVar interface before Freeze.
+    // Note that RegistVar's parameters can only be outputs, not inputs.
+    for (auto &output : outputs) {
+      engine->Graph()->RegistVar(output);
+    }
+    engine->Freeze();
     for (auto &input : inputs) {
       if (parameters.count(input)) continue;
       std::vector<int> input_shape;
@@ -99,7 +103,7 @@ class AnakinOpConverter {
       input_shape[0] = max_batch_size;
       if (max_input_shape.count(input)) {
         PADDLE_ENFORCE(max_input_shape[input].size() == 4,
-                       "the dimensions of  max_input_shape setted from "
+                       "the dimensions of max_input_shape setted from "
                        "config->EnableAnakinEngine must be 4");
         for (int i = 1; i < 4; i++) {
           input_shape[i] = max_input_shape[input][i];
@@ -118,14 +122,10 @@ class AnakinOpConverter {
       }
       temp_max_input_shape[input] = input_shape;
       engine->SetInputShape(input, input_shape);
-      engine->Graph()->RegistVar(input);  // For share from data.
     }
     engine->SetMaxInputShape(temp_max_input_shape);
     engine->Optimize();
-
-    // For anakin share with fluid tensor.
-    engine->AllocTmpMem();
-    engine->InitGraph();
+    engine->InitNet();
   }
 
   void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
diff --git a/paddle/fluid/inference/anakin/convert/relu.cc b/paddle/fluid/inference/anakin/convert/relu.cc
index 993437d014..744066e88a 100644
--- a/paddle/fluid/inference/anakin/convert/relu.cc
+++ b/paddle/fluid/inference/anakin/convert/relu.cc
@@ -41,8 +41,26 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
   engine_->AddOpAttr(op_name, "alpha", 0);
 }
 
+void LeakyReluOpConverter::operator()(const framework::proto::OpDesc &op,
+                                      const framework::BlockDesc &block_desc,
+                                      const framework::Scope &scope,
+                                      bool test_mode) {
+  framework::OpDesc op_desc(op, nullptr);
+  PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
+  PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
+
+  auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
+  auto input_name = op_desc.Input("X").front();
+  auto output_name = op_desc.Output("Out").front();
+
+  float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
+  engine_->AddOp(op_name, "ReLU", {input_name}, {output_name});
+  engine_->AddOpAttr(op_name, "alpha", alpha);
+}
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle
 
 REGISTER_ANAKIN_OP_CONVERTER(relu, ReluOpConverter);
+REGISTER_ANAKIN_OP_CONVERTER(leaky_relu, LeakyReluOpConverter);
diff --git a/paddle/fluid/inference/anakin/convert/relu.h b/paddle/fluid/inference/anakin/convert/relu.h
index 6ede506511..d7b6b6934d 100644
--- a/paddle/fluid/inference/anakin/convert/relu.h
+++ b/paddle/fluid/inference/anakin/convert/relu.h
@@ -33,6 +33,17 @@ class ReluOpConverter : public AnakinOpConverter {
   virtual ~ReluOpConverter() {}
 };
 
+class LeakyReluOpConverter : public AnakinOpConverter {
+ public:
+  LeakyReluOpConverter() = default;
+
+  virtual void operator()(const framework::proto::OpDesc &op,
+                          const framework::BlockDesc &block_desc,
+                          const framework::Scope &scope,
+                          bool test_mode) override;
+  virtual ~LeakyReluOpConverter() {}
+};
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/anakin/convert/roi_align.cc b/paddle/fluid/inference/anakin/convert/roi_align.cc
new file mode 100644
index 0000000000..0f2b08df08
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/roi_align.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/anakin/convert/roi_align.h"
+#include <algorithm>
+#include <map>
+
+using anakin::graph::GraphGlobalMem;
+using anakin::AK_FLOAT;
+using anakin::saber::NV;
+using anakin::saber::Shape;
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+void RoiAlignOpConverter::operator()(const framework::proto::OpDesc &op,
+                                     const framework::BlockDesc &block_desc,
+                                     const framework::Scope &scope,
+                                     bool test_mode) {
+  framework::OpDesc op_desc(op, nullptr);
+  PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
+  PADDLE_ENFORCE_EQ(op_desc.Input("ROIs").size(), 1);
+  PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
+
+  auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
+  auto input_x_name = op_desc.Input("X").front();
+  auto input_rois_name = op_desc.Input("ROIs").front();
+  auto output_name = op_desc.Output("Out").front();
+
+  auto spatial_scale = boost::get<float>(op_desc.GetAttr("spatial_scale"));
+  auto pooled_height = boost::get<int>(op_desc.GetAttr("pooled_height"));
+  auto pooled_width = boost::get<int>(op_desc.GetAttr("pooled_width"));
+  auto sampling_ratio = boost::get<int>(op_desc.GetAttr("sampling_ratio"));
+
+  engine_->AddOp(op_name, "RoiAlign", {input_x_name, input_rois_name},
+                 {output_name});
+  engine_->AddOpAttr(op_name, "spatial_scale", spatial_scale);
+  engine_->AddOpAttr(op_name, "pooled_height", pooled_height);
+  engine_->AddOpAttr(op_name, "pooled_width", pooled_width);
+  engine_->AddOpAttr(op_name, "sampling_ratio", sampling_ratio);
+}
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_ANAKIN_OP_CONVERTER(roi_align, RoiAlignOpConverter);
diff --git a/paddle/fluid/inference/anakin/convert/roi_align.h b/paddle/fluid/inference/anakin/convert/roi_align.h
new file mode 100644
index 0000000000..c6df4754ba
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/roi_align.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include "paddle/fluid/inference/anakin/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+class RoiAlignOpConverter : public AnakinOpConverter {
+ public:
+  RoiAlignOpConverter() = default;
+
+  virtual void operator()(const framework::proto::OpDesc &op,
+                          const framework::BlockDesc &block_desc,
+                          const framework::Scope &scope,
+                          bool test_mode) override;
+  virtual ~RoiAlignOpConverter() {}
+};
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc b/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc
new file mode 100644
index 0000000000..eb4f4e12ee
--- /dev/null
+++ b/paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc
@@ -0,0 +1,55 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
+#include "paddle/fluid/inference/anakin/convert/op_converter.h"
+#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace anakin {
+
+TEST(affine_channel, native) {
+  // Declare the difference between the inputs.
+  std::unordered_set<std::string> parameters({"scale", "bias"});
+
+  framework::Scope scope;
+  AnakinConvertValidation validator(parameters, &scope);
+  validator.DeclInputVar("x", {1, 3, 5, 2});
+  validator.DeclOutputVar("out", {1, 3, 5, 2});
+  validator.DeclParamVar("scale", {1, 3, 1, 1});
+  validator.DeclParamVar("bias", {1, 3, 1, 1});
+
+  // Prepare Op descriptions.
+  framework::OpDesc desc;
+  desc.SetType("affine_channel");
+  desc.SetInput("X", {"x"});
+  desc.SetInput("Bias", {"bias"});
+  desc.SetInput("Scale", {"scale"});
+  desc.SetOutput("Out", {"out"});
+
+  // Layout must be explicitly specified here as NCHW.
+  desc.SetAttr("data_layout", std::string("NCHW"));
+
+  validator.SetOp(*desc.Proto());
+  validator.Execute(1);
+}
+
+}  // namespace anakin
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(affine_channel);
+USE_ANAKIN_CONVERTER(affine_channel);
diff --git a/paddle/fluid/inference/anakin/convert/test_relu_op.cc b/paddle/fluid/inference/anakin/convert/test_relu_op.cc
index 04e624518a..cba19a5585 100644
--- a/paddle/fluid/inference/anakin/convert/test_relu_op.cc
+++ b/paddle/fluid/inference/anakin/convert/test_relu_op.cc
@@ -21,7 +21,7 @@ namespace paddle {
 namespace inference {
 namespace anakin {
 
-static void test_activation_op(const std::string &op_type) {
+static void test_relu_op(const std::string &op_type) {
   auto *converter = Registry<AnakinOpConverter>::Global().Lookup(op_type);
   PADDLE_ENFORCE(converter != nullptr);
   std::unordered_set<std::string> parameters;
@@ -33,6 +33,9 @@ static void test_activation_op(const std::string &op_type) {
   desc.SetType(op_type);
   desc.SetInput("X", {"act-X"});
   desc.SetOutput("Out", {"act-Out"});
+  if (op_type == "leaky_relu") {
+    desc.SetAttr("alpha", 0.1f);
+  }
 
   LOG(INFO) << "set OP";
   validator.SetOp(*desc.Proto());
@@ -41,10 +44,14 @@ static void test_activation_op(const std::string &op_type) {
   validator.Execute(5);
 }
 
-TEST(sigm_op, test) { test_activation_op("relu"); }
+TEST(activation, relu) { test_relu_op("relu"); }
+TEST(activation, leaky_relu) { test_relu_op("leaky_relu"); }
+
 }  // namespace anakin
 }  // namespace inference
 }  // namespace paddle
 
 USE_OP(relu);
 USE_ANAKIN_CONVERTER(relu);
+USE_OP(leaky_relu);
+USE_ANAKIN_CONVERTER(leaky_relu);
diff --git a/paddle/fluid/inference/anakin/convert/ut_helper.h b/paddle/fluid/inference/anakin/convert/ut_helper.h
index 029aff6704..a931efbcf4 100644
--- a/paddle/fluid/inference/anakin/convert/ut_helper.h
+++ b/paddle/fluid/inference/anakin/convert/ut_helper.h
@@ -67,7 +67,7 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
   auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
 
   for (size_t i = 0; i < num_elements; i++) {
-    *(temp_data + i) = random(0., 1.);
+    *(temp_data + i) = random(-128., 128.);
   }
 
   TensorCopySync(temp_tensor, place, tensor);
@@ -151,7 +151,7 @@ class AnakinConvertValidation {
     }
     engine_->SetMaxInputShape(temp_max_input_shape);
     engine_->Optimize();
-    engine_->InitGraph();
+    engine_->InitNet();
   }
 
   // We use the set 'neglected_output' here, because some Ops like batch norm,
diff --git a/paddle/fluid/inference/anakin/engine.cc b/paddle/fluid/inference/anakin/engine.cc
index ba044c9401..2b85d266cf 100644
--- a/paddle/fluid/inference/anakin/engine.cc
+++ b/paddle/fluid/inference/anakin/engine.cc
@@ -35,12 +35,14 @@ namespace anakin {
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
 AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(
     bool need_summary, int device, int max_batch_size,
-    std::map<std::string, std::vector<int>> max_input_shape)
+    std::map<std::string, std::vector<int>> max_input_shape,
+    std::vector<std::string> program_inputs)
     : graph_(new AnakinGraphT<TargetT, PrecisionType>()),
       net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {
   device_ = device;
   max_batch_size_ = max_batch_size;
   max_input_shape_ = max_input_shape;
+  program_inputs_ = program_inputs;
 }
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
@@ -54,7 +56,7 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::SetInputShape(
 }
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
-void AnakinEngine<TargetT, PrecisionType, RunType>::InitGraph() {
+void AnakinEngine<TargetT, PrecisionType, RunType>::InitNet() {
   net_->init(*graph_);
 }
 
@@ -85,11 +87,19 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
     int max_shape_sum =
         std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1,
                         std::multiplies<int>());
-
-    PADDLE_ENFORCE(max_shape_sum >= tensor->numel(),
-                   "The anakin input max shape should be greater than"
-                   " or equal to the real input shape, Please set the max "
-                   "input shape using EnableAnakinEngine");
+    if (tensor->numel() > max_shape_sum) {
+      PADDLE_ENFORCE(std::find(program_inputs_.begin(), program_inputs_.end(),
+                               input.first) == program_inputs_.end(),
+                     "The anakin input max shape should be greater than"
+                     " or equal to the real input shape, Please set the max "
+                     "input shape using EnableAnakinEngine");
+      VLOG(3) << "Anakin Net will be reset because of the inputs out of range: "
+              << input.first;
+      graph_->Reshape(input.first, fluid_input_shape);
+      net_.reset(new AnakinNetT<TargetT, PrecisionType, RunType>(true));
+      net_->init(*graph_);
+      anakin_input = net_->get_in(input.first);
+    }
     anakin_input->reshape(fluid_input_shape);
     ::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
                                                        fluid_input_shape);
@@ -114,7 +124,7 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
 void AnakinEngine<TargetT, PrecisionType, RunType>::Freeze() {
-  PADDLE_ENFORCE(graph_->Freeze_v3(), "Freeze anakin subgraph.");
+  PADDLE_ENFORCE(graph_->Freeze(), "Freeze anakin subgraph.");
 }
 
 template <typename TargetT, Precision PrecisionType, OpRunType RunType>
diff --git a/paddle/fluid/inference/anakin/engine.h b/paddle/fluid/inference/anakin/engine.h
index 4845ffdf5b..1325306557 100644
--- a/paddle/fluid/inference/anakin/engine.h
+++ b/paddle/fluid/inference/anakin/engine.h
@@ -58,9 +58,10 @@ class AnakinEngine {
  public:
   explicit AnakinEngine(
       bool need_summary = false, int device = 0, int max_batch_size = 1,
-      std::map<std::string, std::vector<int>> max_input_shape = {});
+      std::map<std::string, std::vector<int>> max_input_shape = {},
+      std::vector<std::string> program_inputs = {});
   ~AnakinEngine();
-  void InitGraph();
+  void InitNet();
   void SetInputShape(const std::string &name, std::vector<int> shape);
   void AddOp(const std::string &name, const std::string &type,
              const std::vector<std::string> &inputs,
@@ -81,15 +82,16 @@ class AnakinEngine {
   void SetMaxInputShape(std::map<std::string, std::vector<int>> shape) {
     max_input_shape_ = shape;
   }
+  const std::vector<std::string> &GetScalableInputs() {
+    return program_inputs_;
+  }
+  void SetScalableInputs(std::vector<std::string> program_inputs) {
+    program_inputs_ = program_inputs;
+  }
   int GetMaxBatchSize() { return max_batch_size_; }
   void Freeze();
   void Optimize();
-  void AllocTmpMem() {
-    PADDLE_ENFORCE(net_->alloc_memory_first(*graph_),
-                   "anakin alloc temp memory first failed");
-  }
   void Save(std::string path) { graph_->save(path); }
-
   bool IsInit() { return initialized_; }
   int GetDevice() { return device_; }
@@ -103,6 +105,7 @@ class AnakinEngine {
   int device_;
   std::unique_ptr<AnakinGraphT<TargetT, PrecisionType>> graph_;
   std::unique_ptr<AnakinNetT<TargetT, PrecisionType, RunType>> net_;
+  std::vector<std::string> program_inputs_;
 };
 
 class AnakinEngineManager {
@@ -120,10 +123,10 @@ class AnakinEngineManager {
   AnakinNvEngineT *Create(
       bool need_summary, int device, int max_batch_size,
       std::map<std::string, std::vector<int>> max_input_shape,
-      std::string engine_name) {
+      std::vector<std::string> program_inputs, std::string engine_name) {
     std::unique_lock<std::mutex> lk(mut_);
     auto *p = new AnakinEngine<NV, Precision::FP32>(
-        need_summary, device, max_batch_size, max_input_shape);
+        need_summary, device, max_batch_size, max_input_shape, program_inputs);
     engines_[engine_name].reset(p);
     return p;
   }
diff --git a/paddle/fluid/inference/anakin/op_teller.cc b/paddle/fluid/inference/anakin/op_teller.cc
index 2042fb18ea..72064c1790 100644
--- a/paddle/fluid/inference/anakin/op_teller.cc
+++ b/paddle/fluid/inference/anakin/op_teller.cc
@@ -44,6 +44,8 @@ struct SimpleOpTypeSetTeller : public Teller {
     teller_set.insert("sum");
     teller_set.insert("depthwise_conv2d");
     teller_set.insert("prior_box");
+    teller_set.insert("leaky_relu");
+    teller_set.insert("affine_channel");
   }
 
   bool operator()(const std::string& op_type,
diff --git a/paddle/fluid/inference/anakin/test_anakin_engine.cc b/paddle/fluid/inference/anakin/test_anakin_engine.cc
index 8fd6b8bec9..613481a555 100644
--- a/paddle/fluid/inference/anakin/test_anakin_engine.cc
+++ b/paddle/fluid/inference/anakin/test_anakin_engine.cc
@@ -68,7 +68,7 @@ TEST_F(TestAnakinEngine, Execute) {
   // engine_->AddOpAttr("x", "input_shape", input_shape);
   engine_->SetInputShape("x", {1, 1, 1, 1});
   engine_->Optimize();
-  engine_->InitGraph();
+  engine_->InitNet();
   framework::LoDTensor x;
   framework::LoDTensor y;
   x.Resize({1, 1, 1, 1});
diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
index b8d8b6fed8..cbf883a8a5 100644
--- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
@@ -192,11 +192,12 @@ void AnakinSubgraphPass::CreateAnakinOp(
   auto max_input_shape =
       Get<std::map<std::string, std::vector<int>>>("max_input_shape");
   auto max_batch_size = Get<int>("max_batch_size");
+  auto program_inputs = program_desc->GetFeedTargetNames();
 
   auto *anakin_engine =
       inference::Singleton<anakin::AnakinEngineManager>::Global().Create(
          true, Get<int>("gpu_device_id"), max_batch_size, max_input_shape,
-          engine_key);
+          program_inputs, engine_key);
   auto *scope = param_scope();
   std::unordered_set<std::string> param_set(params.begin(), params.end());
 
diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
index 7c4aab06a1..8f7c6ac755 100644
--- a/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
@@ -100,7 +100,6 @@ void RenameAndGetOutputs(
         const std::string arg_value = in_var->arguments(k);
         const std::string arg_value_with_id =
             arg_value + std::to_string(var2id[arg_value]);
-
         if (input_names_with_id.count(arg_value_with_id)) {
           replaced_names.push_back(arg_value);
           if (graph_var_map.count(arg_value)) {
@@ -149,7 +148,6 @@ void RenameAndGetOutputs(
       const std::string arg_value = out_var->arguments(k);
       const std::string arg_value_with_id =
           arg_value + std::to_string(var2id[arg_value]);
-
       if (graph_var_map.count(arg_value)) {
         add_block_var(arg_value, arg_value_with_id);
       }
diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index 882bb34683..9c80b7a839 100644
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -70,3 +70,4 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
     anakin_target(inference_anakin_api)
     anakin_target(inference_anakin_api_shared)
 endif()
+inference_analysis_test(faster_rcnn_test SRCS faster_rcnn_test.cc EXTRA_DEPS paddle_fluid)
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 6942604b07..e5991af4f7 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -888,4 +888,6 @@ USE_ANAKIN_CONVERTER(density_prior_box);
 USE_ANAKIN_CONVERTER(dropout);
 USE_ANAKIN_CONVERTER(sum);
 USE_ANAKIN_CONVERTER(prior_box);
+USE_ANAKIN_CONVERTER(leaky_relu);
+USE_ANAKIN_CONVERTER(affine_channel);
 #endif
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 236afc77f7..ace385ec60 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -229,6 +229,9 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("min_subgraph_size") = 3,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
            py::arg("use_static") = true)
+      .def("enable_anakin_engine", &AnalysisConfig::EnableAnakinEngine,
+           py::arg("max_batch_size") = 1, py::arg("max_input_shape") = {},
+           py::arg("min_subgraph_size") = 6)
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
       .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
            py::arg("x") = true)
-- 
GitLab
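
Usage note (illustrative sketch, not part of the patch): the C++ snippet below shows how a caller might exercise the new Anakin subgraph path through AnalysisConfig. The model directory and the max shape registered for the hypothetical input "x" are placeholders; the EnableAnakinEngine arguments simply mirror the defaults that the pybind binding above exposes (max_batch_size, max_input_shape, min_subgraph_size).

#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./faster_rcnn_model");  // placeholder model directory
  config.EnableUseGpu(100 /* initial GPU memory pool, in MB */, 0 /* gpu id */);

  // Offload supported subgraphs (now including leaky_relu, affine_channel and
  // roi_align) to Anakin; unsupported ops keep running on Fluid.
  std::map<std::string, std::vector<int>> max_input_shape{
      {"x", {1, 3, 800, 1333}}};  // placeholder max shape for input "x"
  config.EnableAnakinEngine(1 /* max_batch_size */, max_input_shape,
                            6 /* min_subgraph_size */);

  auto predictor = paddle::CreatePaddlePredictor(config);
  // Feed inputs and call predictor->Run(...) as with any other AnalysisConfig.
  return 0;
}

Only tensors returned by GetFeedTargetNames() are registered as scalable program inputs in this patch, so feeding any other input with more elements than its max_input_shape entry still trips the PADDLE_ENFORCE in AnakinEngine::Execute().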