提交 084310f5 编写于 作者: N nhzlx

paddle-anakin: concat, split, pool2d converter#16003

上级 be523baa
cc_library(anakin_engine SRCS engine.cc) cc_library(anakin_engine SRCS engine.cc)
nv_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
target_link_libraries(anakin_engine anakin anakin_saber_common) target_link_libraries(anakin_engine anakin anakin_saber_common)
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine) cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
add_subdirectory(convert) add_subdirectory(convert)
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc activation.cc DEPS anakin_engine framework_proto scope operator op_registry) cc_library(anakin_op_converter SRCS fc.cc conv2d.cc activation.cc pool2d.cc concat.cc split.cc DEPS anakin_engine framework_proto scope op_registry)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op) cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter mul_op)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL) cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
cc_test(test_anakin_activation SRCS test_activation_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} activation_op anakin_op_converter SERIAL) cc_test(test_anakin_activation SRCS test_activation_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} activation_op anakin_op_converter
SERIAL)
cc_test(test_anakin_pool2d SRCS test_pool2d_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter pool_op pooling)
cc_test(test_anakin_concat SRCS test_concat_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter concat_op concat_and_split)
cc_test(test_anakin_split SRCS test_split_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} anakin_op_converter split_op concat_and_split)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/concat.h"
#include <algorithm>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void ConcatOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto input_names = op_desc.Input("X");
int axis = boost::get<int>(op_desc.GetAttr("axis"));
PADDLE_ENFORCE(axis > 0,
"The axis attr of Concat op should be large than 0 for trt");
auto y_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Concat", input_names, {y_name});
engine_->AddOpAttr(op_name, "axis", axis);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(concat, ConcatOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class ConcatOpConverter : public AnakinOpConverter {
public:
ConcatOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~ConcatOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
...@@ -48,7 +48,6 @@ void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op, ...@@ -48,7 +48,6 @@ void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
weight_tensor->Resize(filter_t->dims()); weight_tensor->Resize(filter_t->dims());
TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get()); TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get());
auto *weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
// const int n_output = weight_tensor->dims()[0]; // const int n_output = weight_tensor->dims()[0];
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/pool2d.h"
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
auto y_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
bool global_pooling = boost::get<bool>(op_desc.GetAttr("global_pooling"));
std::string pool_type =
boost::get<std::string>(op_desc.GetAttr("pooling_type"));
std::vector<int> ksize =
boost::get<std::vector<int>>(op_desc.GetAttr("ksize"));
std::vector<int> strides =
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
std::string anakin_pool_type;
if (pool_type == "max") {
anakin_pool_type = "MAX";
} else if (pool_type == "avg") {
anakin_pool_type = "AVG";
} else {
PADDLE_THROW("TensorRT unsupported pooling type!");
}
engine_->AddOp(op_name, "Pooling", {x_name}, {y_name});
engine_->AddOpAttr<PTuple<int>>(op_name, "pool_size", ksize);
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
engine_->AddOpAttr(op_name, "method", anakin_pool_type);
engine_->AddOpAttr(op_name, "global_pooling", global_pooling);
engine_->AddOpAttr(op_name, "cmp_out_shape_floor_as_conv", !ceil_mode);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(pool2d, Pool2dOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class Pool2dOpConverter : public AnakinOpConverter {
public:
Pool2dOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~Pool2dOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/split.h"
#include <algorithm>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void SplitOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("X").front();
auto y_names = op_desc.Output("Out");
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
int axis = boost::get<int>(op_desc.GetAttr("axis"));
std::vector<int> output_lengths =
boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
int split_num = output_lengths.size();
PADDLE_ENFORCE(split_num > 1,
"anakin split op converter: the split num should > 1");
int num_sum = 0;
std::vector<int> slice_point;
for (int i = 0; i < split_num - 1; i++) {
num_sum += output_lengths[i];
slice_point.push_back(num_sum);
}
engine_->AddOp(op_name, "Slice", {input_name}, y_names);
engine_->AddOpAttr(op_name, "axis", axis);
engine_->AddOpAttr<PTuple<int>>(op_name, "slice_point", slice_point);
// slice_dim is useless in anakin
engine_->AddOpAttr(op_name, "slice_dim", 4);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(split, SplitOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class SplitOpConverter : public AnakinOpConverter {
public:
SplitOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~SplitOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/concat.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
TEST(concat_op, test) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("concat_x1", {1, 10, 3, 1});
validator.DeclInputVar("concat_x2", {1, 3, 3, 1});
validator.DeclInputVar("concat_x3", {1, 7, 3, 1});
validator.DeclOutputVar("concat_out", {1, 20, 3, 1});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("concat");
desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"});
desc.SetOutput("Out", {"concat_out"});
int axis = 1;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(concat);
USE_ANAKIN_CONVERTER(concat);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
void test_pool2d(bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
auto* pool2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("pool2d");
ASSERT_TRUE(pool2d_converter);
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, scope);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d_x", {1, 3, 6, 7});
if (global_pooling)
validator.DeclOutputVar("pool2d_out", {1, 3, 1, 1});
else if (ceil_mode)
validator.DeclOutputVar("pool2d_out", {1, 3, 3, 4});
else
validator.DeclOutputVar("pool2d_out", {1, 3, 3, 3});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pool2d");
desc.SetInput("X", {"pool2d_x"});
desc.SetOutput("Out", {"pool2d_out"});
std::vector<int> ksize({2, 2});
std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0});
std::string pooling_t = pool_type;
desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", ceil_mode);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(1);
}
TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); }
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(pool2d);
USE_ANAKIN_CONVERTER(pool2d);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/split.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
template <int Axis>
void AnakinSliceTest(const std::vector<int> &in_shape,
const std::vector<int> &sections) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("split_input", in_shape);
std::vector<std::string> output_vars;
for (size_t i = 0; i < sections.size(); ++i) {
auto out_shape = in_shape;
out_shape[Axis] = sections[i];
std::string output_name = "split_out" + std::to_string(i);
validator.DeclOutputVar(output_name, out_shape);
output_vars.push_back(output_name);
}
// Prepare Op description
framework::OpDesc desc;
desc.SetType("split");
desc.SetInput("X", {"split_input"});
desc.SetOutput("Out", output_vars);
desc.SetAttr("axis", Axis);
desc.SetAttr("num", 0);
desc.SetAttr("sections", sections);
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
// batch = 0, axis = 1, same shape
TEST(split_op, test_same_shape_axis1_batch1) {
AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2});
}
// batch = 0, axis = 1, different shape
TEST(split_op, test_different_shape_axis1_batch1) {
AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1});
}
// batch = 10, axis = 1, same shape
TEST(split_op, test_same_shape_axis1_batch10) {
AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2});
}
// batch = 10, axis = 1, different shape
TEST(split_op, test_different_shape_axis1_batch10) {
AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1});
}
// batch = 0, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch1) {
AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2});
}
// batch = 0, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch1) {
AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1});
}
// batch = 10, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch10) {
AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2});
}
// batch = 10, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch10) {
AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1});
}
// batch = 0, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch1) {
AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2});
}
// batch = 0, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch1) {
AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1});
}
// batch = 10, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch10) {
AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2});
}
// batch = 10, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch10) {
AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1});
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(split);
USE_ANAKIN_CONVERTER(split);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/op_teller.h"
namespace paddle {
namespace inference {
namespace anakin {
// Just tell by the op_types.
struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() {}
bool operator()(const std::string& op_type,
const framework::OpDesc& desc) override {
return teller_set.count(op_type);
}
private:
std::unordered_set<std::string> teller_set{{"mul"}};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
for (auto& teller : tellers_) {
if ((*teller)(op_type, desc)) return true;
}
return false;
}
OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace inference {
namespace anakin {
/*
* Single Op teller definition.
* One can override this and define a more complex tell logic, considerring more
* issues such as op_desc.
*/
struct Teller {
virtual bool operator()(const std::string& op_type,
const framework::OpDesc& desc) = 0;
virtual ~Teller() = default;
};
/*
* A real example:
*
* struct SomeTeller : public Teller {
* bool operator()(const std::string& op_type,
* const framework::OpDesc& desc) override {
* return op_type == "fc" && desc.Inputs().size() == 2;
* }
*};
*/
/*
* class OpTeller helps to tell whether a fluid
* operator can be transformed to a TensorRT layer.
*/
class OpTeller {
public:
static OpTeller& Global() {
static std::unique_ptr<OpTeller> x(new OpTeller);
return *x;
}
bool Tell(const std::string& op_type, const framework::OpDesc& desc);
private:
OpTeller();
private:
std::vector<std::unique_ptr<Teller>> tellers_;
};
} // namespace anakin
} // namespace inference
} // namespace paddle
...@@ -45,7 +45,7 @@ class EngineIOConverter { ...@@ -45,7 +45,7 @@ class EngineIOConverter {
static void ConvertInput(const std::string& op_type, const LoDTensor& in, static void ConvertInput(const std::string& op_type, const LoDTensor& in,
void* out, size_t max_size, cudaStream_t* stream) { void* out, size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr); PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineIOConverter>::Lookup( auto* converter = Registry<EngineIOConverter>::Global().Lookup(
op_type, "default" /* default_type */); op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter); PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream); converter->SetStream(stream);
...@@ -56,7 +56,7 @@ class EngineIOConverter { ...@@ -56,7 +56,7 @@ class EngineIOConverter {
LoDTensor* out, size_t max_size, LoDTensor* out, size_t max_size,
cudaStream_t* stream) { cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr); PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineIOConverter>::Lookup( auto* converter = Registry<EngineIOConverter>::Global().Lookup(
op_type, "default" /* default_type */); op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter); PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream); converter->SetStream(stream);
...@@ -69,12 +69,12 @@ class EngineIOConverter { ...@@ -69,12 +69,12 @@ class EngineIOConverter {
cudaStream_t* stream_{nullptr}; cudaStream_t* stream_{nullptr};
}; };
#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \ #define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \
struct trt_io_##op_type__##_converter { \ struct trt_io_##op_type__##_converter { \
trt_io_##op_type__##_converter() { \ trt_io_##op_type__##_converter() { \
Registry<EngineIOConverter>::Register<Converter__>(#op_type__); \ Registry<EngineIOConverter>::Global().Register<Converter__>(#op_type__); \
} \ } \
}; \ }; \
trt_io_##op_type__##_converter trt_io_##op_type__##_converter__; trt_io_##op_type__##_converter trt_io_##op_type__##_converter__;
} // namespace tensorrt } // namespace tensorrt
......
...@@ -86,7 +86,7 @@ class OpConverter { ...@@ -86,7 +86,7 @@ class OpConverter {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL); PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
std::string Y = op_desc.Input("Y")[0]; std::string Y = op_desc.Input("Y")[0];
if (parameters.count(Y)) { if (parameters.count(Y)) {
it = Registry<OpConverter>::Lookup("fc"); it = Registry<OpConverter>::Global().Lookup("fc");
} }
} }
if (op_desc.Type().find("elementwise") != std::string::npos) { if (op_desc.Type().find("elementwise") != std::string::npos) {
...@@ -103,28 +103,28 @@ class OpConverter { ...@@ -103,28 +103,28 @@ class OpConverter {
if (parameters.count(Y)) { if (parameters.count(Y)) {
PADDLE_ENFORCE(add_weight_op_set.count(op_type) > 0, PADDLE_ENFORCE(add_weight_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type); "Unsupported elementwise type" + op_type);
it = it = Registry<OpConverter>::Global().Lookup("elementwise_" + op_type +
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_weight"); "_weight");
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type()); op_desc.Type());
} else { } else {
PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0, PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0,
"Unsupported elementwise type" + op_type); "Unsupported elementwise type" + op_type);
it = it = Registry<OpConverter>::Global().Lookup("elementwise_" + op_type +
Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor"); "_tensor");
} }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type()); op_desc.Type());
} }
if (op_desc.Type() == "depthwise_conv2d") { if (op_desc.Type() == "depthwise_conv2d") {
it = Registry<OpConverter>::Lookup("conv2d"); it = Registry<OpConverter>::Global().Lookup("conv2d");
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type()); op_desc.Type());
} }
if (!it) { if (!it) {
it = Registry<OpConverter>::Lookup(op_desc.Type()); it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
} }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type()); op_desc.Type());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册