提交 07dcf285 编写于 作者: N nhzlx

git cherry-pick from feature/anakin-engine: update anakin subgraph #16278

上级 c407dfa3
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc DEPS anakin_engine framework_proto scope op_registry)
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv)
......@@ -9,11 +7,13 @@ cc_test(test_anakin_activation SRCS test_activation_op.cc DEPS activation_op ana
cc_test(test_anakin_pool2d SRCS test_pool2d_op.cc DEPS anakin_op_converter pool_op pooling)
cc_test(test_anakin_concat SRCS test_concat_op.cc DEPS anakin_op_converter concat_op concat_and_split)
cc_test(test_anakin_split SRCS test_split_op.cc DEPS anakin_op_converter split_op concat_and_split)
cc_test(test_anakin_elementwise SRCS test_elementwise_op.cc DEPS anakin_op_converter elementwise_add_op)
cc_test(test_anakin_elementwise SRCS test_elementwise_op.cc DEPS anakin_op_converter elementwise_add_op elementwise_mul_op)
cc_test(test_anakin_relu SRCS test_relu_op.cc DEPS activation_op anakin_op_converter SERIAL)
cc_test(test_anakin_softmax SRCS test_softmax_op.cc DEPS anakin_op_converter softmax_op softmax)
cc_test(test_anakin_reshape SRCS test_reshape_op.cc DEPS anakin_op_converter reshape_op)
cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter flatten_op reshape_op)
cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op)
cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op)
cc_test(test_anakin_scale SRCS test_scale_op.cc DEPS anakin_op_converter scale_op math_function)
cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op)
cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/dropout.h"
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void DropoutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Mask").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Scale", {x_name}, {out_name});
auto dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
auto factor = 1 - dropout_prob;
Shape shape1(std::vector<int>({1, 1, 1, 1}));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *factor_data = static_cast<float *>(weight1->h_tensor().mutable_data());
float weight1_data[] = {factor};
std::copy(std::begin(weight1_data), std::end(weight1_data), factor_data);
engine_->AddOpAttr(op_name, "weight_1", *weight1);
engine_->AddOpAttr(op_name, "axis", 0);
engine_->AddOpAttr(op_name, "num_axes", 0);
engine_->AddOpAttr(op_name, "bias_term", false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(dropout, DropoutOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class DropoutOpConverter : public AnakinOpConverter {
public:
DropoutOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~DropoutOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -35,7 +35,7 @@ void ElementwiseAddOpConverter::operator()(const framework::proto::OpDesc &op,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
......@@ -50,8 +50,39 @@ void ElementwiseAddOpConverter::operator()(const framework::proto::OpDesc &op,
engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
}
void ElementwiseMulOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
auto y_name = op_desc.Input("Y").front();
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Scale", {x_name, y_name}, {out_name});
// Fill a number to weight_1 as a placeholder.
Shape shape1(std::vector<int>({1, 1, 1, 1}));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *placeholder_data =
static_cast<float *>(weight1->h_tensor().mutable_data());
float weight1_data[] = {1};
std::copy(std::begin(weight1_data), std::end(weight1_data), placeholder_data);
engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto axis = boost::get<int>(op_desc.GetAttr("axis"));
engine_->AddOpAttr(op_name, "axis", axis);
engine_->AddOpAttr(op_name, "num_axes", 1);
engine_->AddOpAttr(op_name, "bias_term", false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(elementwise_add, ElementwiseAddOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(elementwise_mul, ElementwiseMulOpConverter);
......@@ -32,6 +32,18 @@ class ElementwiseAddOpConverter : public AnakinOpConverter {
private:
};
class ElementwiseMulOpConverter : public AnakinOpConverter {
public:
ElementwiseMulOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~ElementwiseMulOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/im2sequence.h"
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Im2SequenceConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 0);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Im2Sequence", {x_name}, {out_name});
std::vector<int> dilations = {1, 1};
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
auto kernels = boost::get<std::vector<int>>(op_desc.GetAttr("kernels"));
engine_->AddOpAttr<PTuple<int>>(op_name, "paddings", paddings);
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
engine_->AddOpAttr<PTuple<int>>(op_name, "window_size", kernels);
engine_->AddOpAttr<PTuple<int>>(op_name, "dilations", dilations);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(im2sequence, Im2SequenceConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class Im2SequenceConverter : public AnakinOpConverter {
public:
Im2SequenceConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~Im2SequenceConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -55,7 +55,11 @@ void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
if (pool_type == "max") {
anakin_pool_type = "MAX";
} else if (pool_type == "avg") {
anakin_pool_type = "AVGEXC";
if (paddings[0] || paddings[1]) {
anakin_pool_type = "AVGEXC";
} else {
anakin_pool_type = "AVG";
}
} else {
PADDLE_THROW("TensorRT unsupported pooling type!");
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/sum.h"
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void SumOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 2);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto input_names = op_desc.Input("X");
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
std::vector<float> coeff = {1, 1};
std::string elementwise_type = "Add";
engine_->AddOp(op_name, "Eltwise", input_names, {out_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
engine_->AddOpAttr<std::string>(op_name, "type", elementwise_type);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(sum, SumOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class SumOpConverter : public AnakinOpConverter {
public:
SumOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~SumOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/dropout.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
TEST(dropout_op, native) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("x", {1, 1, 2, 2});
validator.DeclOutputVar("out", {1, 1, 2, 2});
validator.DeclOutputVar("mask", {1, 1, 2, 2});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("dropout");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"out"});
desc.SetOutput("Mask", {"mask"});
float dropout_prob = 0.5;
desc.SetAttr("dropout_prob", dropout_prob);
desc.SetAttr("is_test", true);
validator.SetOp(*desc.Proto());
std::unordered_set<std::string> neglected_output = {"mask"};
validator.Execute(1, neglected_output);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(dropout);
USE_ANAKIN_CONVERTER(dropout);
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/elementwise.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
......@@ -20,20 +21,20 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(elementwise_op, native) {
static void test_elementwise_op(const std::string &op_type) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("elementwise_add_x", {1, 1, 2, 2});
validator.DeclInputVar("elementwise_y", {1, 1, 2, 2});
validator.DeclOutputVar("elementwise_out", {1, 1, 2, 2});
validator.DeclInputVar("x", {1, 1, 2, 2});
validator.DeclInputVar("y", {1, 1, 2, 2});
validator.DeclOutputVar("out", {1, 1, 2, 2});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("elementwise_add");
desc.SetInput("X", {"elementwise_add_x"});
desc.SetInput("Y", {"elementwise_y"});
desc.SetOutput("Out", {"elementwise_out"});
desc.SetType(op_type);
desc.SetInput("X", {"x"});
desc.SetInput("Y", {"y"});
desc.SetOutput("Out", {"out"});
int axis = -1;
desc.SetAttr("axis", axis);
......@@ -42,9 +43,14 @@ TEST(elementwise_op, native) {
validator.Execute(1);
}
TEST(elementwise_op, native_add) { test_elementwise_op("elementwise_add"); }
TEST(elementwise_op, native_mul) { test_elementwise_op("elementwise_mul"); }
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(elementwise_add);
USE_ANAKIN_CONVERTER(elementwise_add);
USE_OP(elementwise_mul);
USE_ANAKIN_CONVERTER(elementwise_mul);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/im2sequence.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
TEST(im2sequence_op, native) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
std::vector<int> kernels = {6, 1};
std::vector<int> strides = {1, 1};
std::vector<int> paddings = {0, 0, 0, 0};
validator.DeclInputVar("x", {1, 1, 2, 2});
validator.DeclOutputVar("out", {1, 1 * kernels[0] * kernels[1]});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("im2sequence");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"out"});
desc.SetAttr("kernels", kernels);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(im2sequence);
USE_ANAKIN_CONVERTER(im2sequence);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/sum.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
#include "paddle/fluid/operators/sum_op.h"
namespace paddle {
namespace inference {
namespace anakin {
TEST(sum, native) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("sum_x1", {1, 2, 1, 2});
validator.DeclInputVar("sum_x2", {1, 2, 1, 2});
validator.DeclOutputVar("sum_out", {1, 2, 1, 2});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("sum");
desc.SetInput("X", {"sum_x1", "sum_x2"});
desc.SetOutput("Out", {"sum_out"});
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(sum);
USE_ANAKIN_CONVERTER(sum);
......@@ -28,6 +28,7 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("relu");
teller_set.insert("pool2d");
teller_set.insert("elementwise_add");
teller_set.insert("elementwise_mul");
teller_set.insert("concat");
teller_set.insert("tanh");
teller_set.insert("conv2d");
......@@ -38,7 +39,9 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("transpose2");
teller_set.insert("density_prior_box");
teller_set.insert("detection_out");
teller_set.insert("scale");
teller_set.insert("dropout");
teller_set.insert("sigmoid");
teller_set.insert("sum");
}
bool operator()(const std::string& op_type,
......
......@@ -47,7 +47,7 @@ std::unique_ptr<framework::ir::Graph> analysis::AnakinSubgraphPass::ApplyImpl(
return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
};
SubGraphFuser fuser(graph.get(), teller, 0 /* min_subgraph_size */);
SubGraphFuser fuser(graph.get(), teller, 6 /* min_subgraph_size */);
fuser();
std::vector<std::string> graph_param_names =
......
......@@ -210,13 +210,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "parameters", params);
auto enable_int8 = Get<bool>("enable_int8");
auto use_static_engine = Get<bool>("use_static_engine");
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
std::to_string(0));
// Get "" when there is no cached calibration table data.
bool load_from_memory = Get<bool>("model_from_memory");
std::string calibration_data = "";
if (!load_from_memory) {
if (!load_from_memory && use_static_engine) {
calibration_data = GetTrtCalibTableData(
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
}
......@@ -240,7 +241,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
}
bool use_static_engine = Get<bool>("use_static_engine");
// When in int8 mode and calibration_mode, the program just produce the
// calibration table data.
bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
......
......@@ -828,12 +828,13 @@ USE_ANAKIN_CONVERTER(sigmoid);
USE_ANAKIN_CONVERTER(tanh);
USE_ANAKIN_CONVERTER(pool2d);
USE_ANAKIN_CONVERTER(elementwise_add);
USE_ANAKIN_CONVERTER(elementwise_mul);
USE_ANAKIN_CONVERTER(batch_norm);
USE_ANAKIN_CONVERTER(flatten);
USE_ANAKIN_CONVERTER(reshape);
USE_ANAKIN_CONVERTER(transpose);
USE_ANAKIN_CONVERTER(softmax);
USE_ANAKIN_CONVERTER(detection_out);
USE_ANAKIN_CONVERTER(density_prior_box);
USE_ANAKIN_CONVERTER(scale);
USE_ANAKIN_CONVERTER(dropout);
USE_ANAKIN_CONVERTER(sum);
......@@ -138,7 +138,7 @@ struct AnalysisConfig {
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1, int min_subgraph_size = 3,
Precision precision = Precision::kFloat32,
bool use_static = true);
bool use_static = false);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
......
......@@ -80,6 +80,7 @@ const std::vector<std::string> kAnakinSubgraphPasses({
"conv_elementwise_add_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"fc_gru_fuse_pass", //
"anakin_subgraph_pass",
});
......
......@@ -233,6 +233,8 @@ void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
paddle::memory::Allocator::kScratchpad);
}
std::once_flag CUDADeviceContext::init_cudnn_;
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
CUDADeviceGuard guard(place_.device);
......@@ -252,10 +254,6 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
#endif
}
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
}
driver_version_ = GetCUDADriverVersion(place_.device);
runtime_version_ = GetCUDARuntimeVersion(place_.device);
......@@ -348,12 +346,21 @@ bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
CudnnHolder* CUDADeviceContext::cudnn_holder() const {
std::call_once(init_cudnn_, [&]() {
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
}
});
return cudnn_holder_.get();
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->cudnn_handle();
return cudnn_holder()->cudnn_handle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder_.get());
return CudnnWorkspaceHandle(cudnn_holder());
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -290,9 +290,11 @@ class CUDADeviceContext : public DeviceContext {
private:
CUDAPlace place_;
static std::once_flag init_cudnn_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_;
mutable std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
......@@ -315,6 +317,7 @@ class CUDADeviceContext : public DeviceContext {
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
CudnnHolder* cudnn_holder() const;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册