未验证 提交 1e21e8b5 编写于 作者: 石晓伟 提交者: GitHub

Merge pull request #16611 from Shixiaowei02/release/1.4

Cherry-pick from 16498 : Deal with softmax layer in anakin subgraph
...@@ -201,7 +201,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -201,7 +201,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64") SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64")
ENDIF() ENDIF()
SET(PROTOBUF_REPO "https://github.com/google/protobuf.git") SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git")
SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546")
ExternalProject_Add( ExternalProject_Add(
......
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL) cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL) cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
......
...@@ -34,6 +34,7 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type) ...@@ -34,6 +34,7 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
} }
void ActivationOpConverter::operator()(const framework::proto::OpDesc &op, void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -27,6 +27,7 @@ class ActivationOpConverter : public AnakinOpConverter { ...@@ -27,6 +27,7 @@ class ActivationOpConverter : public AnakinOpConverter {
explicit ActivationOpConverter(const std::string &op_type); explicit ActivationOpConverter(const std::string &op_type);
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ActivationOpConverter() {} virtual ~ActivationOpConverter() {}
......
...@@ -29,6 +29,7 @@ namespace inference { ...@@ -29,6 +29,7 @@ namespace inference {
namespace anakin { namespace anakin {
void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op, void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class BatchNormOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class BatchNormOpConverter : public AnakinOpConverter {
BatchNormOpConverter() = default; BatchNormOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~BatchNormOpConverter() {} virtual ~BatchNormOpConverter() {}
......
...@@ -29,6 +29,7 @@ namespace inference { ...@@ -29,6 +29,7 @@ namespace inference {
namespace anakin { namespace anakin {
void ConcatOpConverter::operator()(const framework::proto::OpDesc &op, void ConcatOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class ConcatOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class ConcatOpConverter : public AnakinOpConverter {
ConcatOpConverter() = default; ConcatOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ConcatOpConverter() {} virtual ~ConcatOpConverter() {}
......
...@@ -28,6 +28,7 @@ namespace inference { ...@@ -28,6 +28,7 @@ namespace inference {
namespace anakin { namespace anakin {
void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op, void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class Conv2dOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class Conv2dOpConverter : public AnakinOpConverter {
Conv2dOpConverter() = default; Conv2dOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~Conv2dOpConverter() {} virtual ~Conv2dOpConverter() {}
......
...@@ -28,6 +28,7 @@ namespace inference { ...@@ -28,6 +28,7 @@ namespace inference {
namespace anakin { namespace anakin {
void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op, void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class Conv2dFusionOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class Conv2dFusionOpConverter : public AnakinOpConverter {
Conv2dFusionOpConverter() = default; Conv2dFusionOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~Conv2dFusionOpConverter() {} virtual ~Conv2dFusionOpConverter() {}
......
...@@ -27,9 +27,9 @@ namespace paddle { ...@@ -27,9 +27,9 @@ namespace paddle {
namespace inference { namespace inference {
namespace anakin { namespace anakin {
void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op, void DensityPriorBoxOpConverter::operator()(
const framework::Scope& scope, const framework::proto::OpDesc& op, const framework::BlockDesc& block_desc,
bool test_mode) { const framework::Scope& scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("Input").front(); auto input_name = op_desc.Input("Input").front();
auto image_name = op_desc.Input("Image").front(); auto image_name = op_desc.Input("Image").front();
......
...@@ -27,6 +27,7 @@ class DensityPriorBoxOpConverter : public AnakinOpConverter { ...@@ -27,6 +27,7 @@ class DensityPriorBoxOpConverter : public AnakinOpConverter {
DensityPriorBoxOpConverter() = default; DensityPriorBoxOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~DensityPriorBoxOpConverter() {} virtual ~DensityPriorBoxOpConverter() {}
......
...@@ -26,6 +26,7 @@ namespace inference { ...@@ -26,6 +26,7 @@ namespace inference {
namespace anakin { namespace anakin {
void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op, void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -27,6 +27,7 @@ class DetectionOutOpConverter : public AnakinOpConverter { ...@@ -27,6 +27,7 @@ class DetectionOutOpConverter : public AnakinOpConverter {
DetectionOutOpConverter() = default; DetectionOutOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~DetectionOutOpConverter() {} virtual ~DetectionOutOpConverter() {}
......
...@@ -31,6 +31,7 @@ namespace inference { ...@@ -31,6 +31,7 @@ namespace inference {
namespace anakin { namespace anakin {
void DropoutOpConverter::operator()(const framework::proto::OpDesc &op, void DropoutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class DropoutOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class DropoutOpConverter : public AnakinOpConverter {
DropoutOpConverter() = default; DropoutOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~DropoutOpConverter() {} virtual ~DropoutOpConverter() {}
......
...@@ -30,9 +30,9 @@ namespace paddle { ...@@ -30,9 +30,9 @@ namespace paddle {
namespace inference { namespace inference {
namespace anakin { namespace anakin {
void ElementwiseAddOpConverter::operator()(const framework::proto::OpDesc &op, void ElementwiseAddOpConverter::operator()(
const framework::Scope &scope, const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
bool test_mode) { const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
...@@ -50,9 +50,9 @@ void ElementwiseAddOpConverter::operator()(const framework::proto::OpDesc &op, ...@@ -50,9 +50,9 @@ void ElementwiseAddOpConverter::operator()(const framework::proto::OpDesc &op,
engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff); engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
} }
void ElementwiseMulOpConverter::operator()(const framework::proto::OpDesc &op, void ElementwiseMulOpConverter::operator()(
const framework::Scope &scope, const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
bool test_mode) { const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
......
...@@ -25,6 +25,7 @@ class ElementwiseAddOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class ElementwiseAddOpConverter : public AnakinOpConverter {
ElementwiseAddOpConverter() = default; ElementwiseAddOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ElementwiseAddOpConverter() {} virtual ~ElementwiseAddOpConverter() {}
...@@ -37,6 +38,7 @@ class ElementwiseMulOpConverter : public AnakinOpConverter { ...@@ -37,6 +38,7 @@ class ElementwiseMulOpConverter : public AnakinOpConverter {
ElementwiseMulOpConverter() = default; ElementwiseMulOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ElementwiseMulOpConverter() {} virtual ~ElementwiseMulOpConverter() {}
......
...@@ -27,6 +27,7 @@ namespace inference { ...@@ -27,6 +27,7 @@ namespace inference {
namespace anakin { namespace anakin {
void FcBaseOpConverter::operator()(const framework::proto::OpDesc &op, void FcBaseOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class FcBaseOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class FcBaseOpConverter : public AnakinOpConverter {
FcBaseOpConverter() = default; FcBaseOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~FcBaseOpConverter() {} virtual ~FcBaseOpConverter() {}
......
...@@ -26,6 +26,7 @@ namespace inference { ...@@ -26,6 +26,7 @@ namespace inference {
namespace anakin { namespace anakin {
void FlattenOpConverter::operator()(const framework::proto::OpDesc &op, void FlattenOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class FlattenOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class FlattenOpConverter : public AnakinOpConverter {
FlattenOpConverter() = default; FlattenOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~FlattenOpConverter() {} virtual ~FlattenOpConverter() {}
......
...@@ -31,6 +31,7 @@ namespace inference { ...@@ -31,6 +31,7 @@ namespace inference {
namespace anakin { namespace anakin {
void Im2SequenceConverter::operator()(const framework::proto::OpDesc &op, void Im2SequenceConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class Im2SequenceConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class Im2SequenceConverter : public AnakinOpConverter {
Im2SequenceConverter() = default; Im2SequenceConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~Im2SequenceConverter() {} virtual ~Im2SequenceConverter() {}
......
...@@ -40,8 +40,10 @@ class AnakinOpConverter { ...@@ -40,8 +40,10 @@ class AnakinOpConverter {
AnakinOpConverter() = default; AnakinOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {} const framework::Scope &scope, bool test_mode) {}
void ConvertOp(const framework::proto::OpDesc &op, void ConvertOp(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const std::unordered_set<std::string> &parameters, const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine, const framework::Scope &scope, AnakinNvEngine *engine,
bool test_mode = false) { bool test_mode = false) {
...@@ -58,16 +60,17 @@ class AnakinOpConverter { ...@@ -58,16 +60,17 @@ class AnakinOpConverter {
} }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type); PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type);
it->SetEngine(engine); it->SetEngine(engine);
(*it)(op, scope, test_mode); (*it)(op, block_desc, scope, test_mode);
} }
void ConvertBlock(const framework::proto::BlockDesc &block, void ConvertBlock(framework::BlockDesc *block_desc,
const std::unordered_set<std::string> &parameters, const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine) { const framework::Scope &scope, AnakinNvEngine *engine) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
for (auto i = 0; i < block.ops_size(); i++) { framework::proto::BlockDesc *block = block_desc->Proto();
auto &op = block.ops(i); for (auto i = 0; i < block->ops_size(); i++) {
ConvertOp(op, parameters, scope, engine); auto &op = block->ops(i);
ConvertOp(op, *block_desc, parameters, scope, engine);
} }
} }
...@@ -77,9 +80,7 @@ class AnakinOpConverter { ...@@ -77,9 +80,7 @@ class AnakinOpConverter {
const std::vector<std::string> &inputs, const std::vector<std::string> &inputs,
const std::unordered_set<std::string> &parameters, const std::unordered_set<std::string> &parameters,
const std::vector<std::string> &outputs, AnakinNvEngine *engine) { const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
framework::proto::BlockDesc *block_proto = block_desc->Proto(); ConvertBlock(block_desc, parameters, *scope, engine);
ConvertBlock(*block_proto, parameters, *scope, engine);
engine->Freeze(); engine->Freeze();
// if the max_batch size // if the max_batch size
int max_batch_size = engine->GetMaxBatchSize(); int max_batch_size = engine->GetMaxBatchSize();
......
...@@ -31,6 +31,7 @@ namespace inference { ...@@ -31,6 +31,7 @@ namespace inference {
namespace anakin { namespace anakin {
void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op, void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class Pool2dOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class Pool2dOpConverter : public AnakinOpConverter {
Pool2dOpConverter() = default; Pool2dOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~Pool2dOpConverter() {} virtual ~Pool2dOpConverter() {}
......
...@@ -26,6 +26,7 @@ namespace inference { ...@@ -26,6 +26,7 @@ namespace inference {
namespace anakin { namespace anakin {
void ReluOpConverter::operator()(const framework::proto::OpDesc &op, void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -27,6 +27,7 @@ class ReluOpConverter : public AnakinOpConverter { ...@@ -27,6 +27,7 @@ class ReluOpConverter : public AnakinOpConverter {
ReluOpConverter() = default; ReluOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ReluOpConverter() {} virtual ~ReluOpConverter() {}
......
...@@ -26,6 +26,7 @@ namespace inference { ...@@ -26,6 +26,7 @@ namespace inference {
namespace anakin { namespace anakin {
void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op, void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class ReshapeOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class ReshapeOpConverter : public AnakinOpConverter {
ReshapeOpConverter() = default; ReshapeOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ReshapeOpConverter() {} virtual ~ReshapeOpConverter() {}
......
...@@ -26,6 +26,7 @@ namespace inference { ...@@ -26,6 +26,7 @@ namespace inference {
namespace anakin { namespace anakin {
void ScaleOpConverter::operator()(const framework::proto::OpDesc &op, void ScaleOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -27,6 +27,7 @@ class ScaleOpConverter : public AnakinOpConverter { ...@@ -27,6 +27,7 @@ class ScaleOpConverter : public AnakinOpConverter {
ScaleOpConverter() = default; ScaleOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~ScaleOpConverter() {} virtual ~ScaleOpConverter() {}
......
...@@ -24,6 +24,7 @@ namespace inference { ...@@ -24,6 +24,7 @@ namespace inference {
namespace anakin { namespace anakin {
void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op, void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -32,8 +33,16 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op, ...@@ -32,8 +33,16 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
auto input = op_desc.Input("X").front(); auto input = op_desc.Input("X").front();
auto output = op_desc.Output("Out").front(); auto output = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front(); auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_var_desc = block_desc.FindVar(input);
PADDLE_ENFORCE(input_var_desc,
"Cant find %s variable When runing Anakin Softmax converter.",
input);
auto input_shape_in_fluid = input_var_desc->GetShape();
size_t input_dims = input_shape_in_fluid.size();
engine_->AddOp(op_name, "Softmax", {input}, {output}); engine_->AddOp(op_name, "Softmax", {input}, {output});
engine_->AddOpAttr(op_name, "axis", 2); engine_->AddOpAttr(op_name, "axis", static_cast<int>(input_dims - 1));
} }
} // namespace anakin } // namespace anakin
......
...@@ -25,6 +25,7 @@ class SoftMaxOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class SoftMaxOpConverter : public AnakinOpConverter {
SoftMaxOpConverter() = default; SoftMaxOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~SoftMaxOpConverter() {} virtual ~SoftMaxOpConverter() {}
......
...@@ -30,6 +30,7 @@ namespace inference { ...@@ -30,6 +30,7 @@ namespace inference {
namespace anakin { namespace anakin {
void SplitOpConverter::operator()(const framework::proto::OpDesc &op, void SplitOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class SplitOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class SplitOpConverter : public AnakinOpConverter {
SplitOpConverter() = default; SplitOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~SplitOpConverter() {} virtual ~SplitOpConverter() {}
......
...@@ -31,6 +31,7 @@ namespace inference { ...@@ -31,6 +31,7 @@ namespace inference {
namespace anakin { namespace anakin {
void SumOpConverter::operator()(const framework::proto::OpDesc &op, void SumOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) { const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 2); PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 2);
......
...@@ -25,6 +25,7 @@ class SumOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class SumOpConverter : public AnakinOpConverter {
SumOpConverter() = default; SumOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~SumOpConverter() {} virtual ~SumOpConverter() {}
......
...@@ -28,6 +28,7 @@ namespace inference { ...@@ -28,6 +28,7 @@ namespace inference {
namespace anakin { namespace anakin {
void TransposeOpConverter::operator()(const framework::proto::OpDesc &op, void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) { bool test_mode) {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
......
...@@ -25,6 +25,7 @@ class TransposeOpConverter : public AnakinOpConverter { ...@@ -25,6 +25,7 @@ class TransposeOpConverter : public AnakinOpConverter {
TransposeOpConverter() = default; TransposeOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op, virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, const framework::Scope &scope,
bool test_mode) override; bool test_mode) override;
virtual ~TransposeOpConverter() {} virtual ~TransposeOpConverter() {}
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -112,6 +113,17 @@ class AnakinConvertValidation { ...@@ -112,6 +113,17 @@ class AnakinConvertValidation {
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec)); x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place_, ctx); RandomizeTensor(x_tensor, place_, ctx);
std::vector<int64_t> dim_vec_int64;
for (auto& ele : dim_vec) {
dim_vec_int64.push_back(static_cast<int64_t>(ele));
}
// Add var_desc to block_desc
auto* block_desc = program_desc_.MutableBlock(framework::kRootBlockIndex);
auto* var_desc = block_desc->Var(name);
var_desc->SetShape(dim_vec_int64);
} }
void SetOp(const framework::proto::OpDesc& desc) { void SetOp(const framework::proto::OpDesc& desc) {
...@@ -119,8 +131,10 @@ class AnakinConvertValidation { ...@@ -119,8 +131,10 @@ class AnakinConvertValidation {
op_desc_.reset(new framework::OpDesc(desc, nullptr)); op_desc_.reset(new framework::OpDesc(desc, nullptr));
// should init anakin engine here. // should init anakin engine here.
auto& block_desc = program_desc_.Block(framework::kRootBlockIndex);
Singleton<AnakinOpConverter>::Global().ConvertOp( Singleton<AnakinOpConverter>::Global().ConvertOp(
desc, parameters_, *scope_, engine_.get(), true /*test_mode*/); desc, block_desc, parameters_, *scope_, engine_.get(),
true /*test_mode*/);
engine_->Freeze(); engine_->Freeze();
std::map<std::string, std::vector<int>> temp_max_input_shape; std::map<std::string, std::vector<int>> temp_max_input_shape;
...@@ -194,6 +208,7 @@ class AnakinConvertValidation { ...@@ -194,6 +208,7 @@ class AnakinConvertValidation {
cudaStream_t stream_; cudaStream_t stream_;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_; std::unique_ptr<framework::OpDesc> op_desc_;
framework::ProgramDesc program_desc_;
const std::unordered_set<std::string>& parameters_; const std::unordered_set<std::string>& parameters_;
framework::Scope* scope_; framework::Scope* scope_;
platform::CUDAPlace place_; platform::CUDAPlace place_;
......
...@@ -91,7 +91,6 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute( ...@@ -91,7 +91,6 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
" or equal to the real input shape, Please set the max " " or equal to the real input shape, Please set the max "
"input shape using EnableAnakinEngine"); "input shape using EnableAnakinEngine");
anakin_input->reshape(fluid_input_shape); anakin_input->reshape(fluid_input_shape);
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0, ::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
fluid_input_shape); fluid_input_shape);
anakin_input->copy_from(tmp_anakin_tensor); anakin_input->copy_from(tmp_anakin_tensor);
......
...@@ -168,6 +168,7 @@ struct Argument { ...@@ -168,6 +168,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(anakin_max_input_shape, AnakinMaxInputShape, DECL_ARGUMENT_FIELD(anakin_max_input_shape, AnakinMaxInputShape,
anakin_max_shape_t); anakin_max_shape_t);
DECL_ARGUMENT_FIELD(anakin_max_batch_size, AnakinMaxBatchSize, int); DECL_ARGUMENT_FIELD(anakin_max_batch_size, AnakinMaxBatchSize, int);
DECL_ARGUMENT_FIELD(anakin_min_subgraph_size, AnakinMinSubgraphSize, int);
DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool); DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool);
// Memory optimized related. // Memory optimized related.
......
...@@ -151,13 +151,20 @@ void AnakinSubgraphPass::CreateAnakinOp( ...@@ -151,13 +151,20 @@ void AnakinSubgraphPass::CreateAnakinOp(
op_desc->SetType("anakin_engine"); op_desc->SetType("anakin_engine");
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
for (framework::ir::Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
graph_var_map[node->Name()] = node;
}
}
auto &subgraph_nodes = *Agent(node).subgraph(); auto &subgraph_nodes = *Agent(node).subgraph();
// The following procedure is used to rename all the intermediate // The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph. // variables and the output variables of the subgraph.
RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id, RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id,
&output_names_with_id, &output_names, &output_name_map, &output_names_with_id, &output_names, &output_name_map,
false); graph_var_map, false);
// When anakin engine runs at the end of the operation, // When anakin engine runs at the end of the operation,
// output_mapping help us copy the data from the renamed ITensor // output_mapping help us copy the data from the renamed ITensor
...@@ -168,13 +175,6 @@ void AnakinSubgraphPass::CreateAnakinOp( ...@@ -168,13 +175,6 @@ void AnakinSubgraphPass::CreateAnakinOp(
output_mapping.push_back(output_name_map[name]); output_mapping.push_back(output_name_map[name]);
} }
auto *vars = block_desc.Proto()->mutable_vars();
for (framework::ir::Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(), PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc"); "the block has no var-desc");
PADDLE_ENFORCE(!output_mapping.empty()); PADDLE_ENFORCE(!output_mapping.empty());
......
...@@ -60,6 +60,7 @@ void RenameAndGetOutputs( ...@@ -60,6 +60,7 @@ void RenameAndGetOutputs(
std::set<std::string> *output_names_with_id, std::set<std::string> *output_names_with_id,
std::set<std::string> *output_names, std::set<std::string> *output_names,
std::unordered_map<std::string, std::string> *output_name_map, std::unordered_map<std::string, std::string> *output_name_map,
const std::unordered_map<std::string, framework::ir::Node *> &graph_var_map,
bool is_trt) { bool is_trt) {
//// In the normal case, the paddle-trt exists bug when runing the googlenet. //// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the // When there are more than two convolutions of 1 * 1 with the same input, the
...@@ -69,6 +70,15 @@ void RenameAndGetOutputs( ...@@ -69,6 +70,15 @@ void RenameAndGetOutputs(
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/> std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
same_hierarchy_conv2d_num_map; same_hierarchy_conv2d_num_map;
auto add_block_var = [&](const std::string &graph_arg,
const std::string &block_arg) {
auto arg_var_node = graph_var_map.find(graph_arg);
PADDLE_ENFORCE(arg_var_node != graph_var_map.end());
auto *var_t = block_desc->Var(block_arg);
var_t->SetShape(arg_var_node->second->Var()->GetShape());
var_t->SetDataType(arg_var_node->second->Var()->GetDataType());
};
for (size_t index = 0; index < block_desc->OpSize(); ++index) { for (size_t index = 0; index < block_desc->OpSize(); ++index) {
framework::proto::OpDesc *op = block_desc->Op(index)->Proto(); framework::proto::OpDesc *op = block_desc->Op(index)->Proto();
framework::OpDesc op_desc(*op, nullptr); framework::OpDesc op_desc(*op, nullptr);
...@@ -87,13 +97,20 @@ void RenameAndGetOutputs( ...@@ -87,13 +97,20 @@ void RenameAndGetOutputs(
auto *in_var = op->mutable_inputs(i); auto *in_var = op->mutable_inputs(i);
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < in_var->arguments_size(); k++) { // all the arguments for (int k = 0; k < in_var->arguments_size(); k++) { // all the arguments
std::string arg_value = in_var->arguments(k); const std::string arg_value = in_var->arguments(k);
std::string arg_value_with_id = const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]); arg_value + std::to_string(var2id[arg_value]);
if (input_names_with_id.count(arg_value_with_id)) { if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value); replaced_names.push_back(arg_value);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value);
}
} else { } else {
replaced_names.push_back(arg_value_with_id); replaced_names.push_back(arg_value_with_id);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value_with_id);
}
} }
} }
in_var->clear_arguments(); in_var->clear_arguments();
...@@ -105,7 +122,6 @@ void RenameAndGetOutputs( ...@@ -105,7 +122,6 @@ void RenameAndGetOutputs(
for (auto out_var : correspond_node->outputs) { for (auto out_var : correspond_node->outputs) {
var2id[out_var->Name()] = out_var->id(); var2id[out_var->Name()] = out_var->id();
} }
if (op_desc.Type() == "conv2d" && is_trt) { if (op_desc.Type() == "conv2d" && is_trt) {
auto input_var_name = op_desc.Input("Input").front(); auto input_var_name = op_desc.Input("Input").front();
auto filter_var_name = op_desc.Input("Filter").front(); auto filter_var_name = op_desc.Input("Filter").front();
...@@ -125,15 +141,18 @@ void RenameAndGetOutputs( ...@@ -125,15 +141,18 @@ void RenameAndGetOutputs(
same_hierarchy_conv2d_num_map[input_var_name] += 1; same_hierarchy_conv2d_num_map[input_var_name] += 1;
} }
} }
// rename for the output variables of op inside subgraph // rename for the output variables of op inside subgraph
for (int i = 0; i < op->outputs_size(); i++) { for (int i = 0; i < op->outputs_size(); i++) {
framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i); framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < out_var->arguments_size(); k++) { for (int k = 0; k < out_var->arguments_size(); k++) {
std::string arg_value = out_var->arguments(k); const std::string arg_value = out_var->arguments(k);
std::string arg_value_with_id = const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]); arg_value + std::to_string(var2id[arg_value]);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value_with_id);
}
if (output_names_with_id->count(arg_value_with_id)) { if (output_names_with_id->count(arg_value_with_id)) {
(*output_name_map)[arg_value] = arg_value_with_id; (*output_name_map)[arg_value] = arg_value_with_id;
} }
......
...@@ -42,6 +42,7 @@ void RenameAndGetOutputs( ...@@ -42,6 +42,7 @@ void RenameAndGetOutputs(
std::set<std::string> *output_names_with_id, std::set<std::string> *output_names_with_id,
std::set<std::string> *output_names, std::set<std::string> *output_names,
std::unordered_map<std::string, std::string> *output_name_map, std::unordered_map<std::string, std::string> *output_name_map,
const std::unordered_map<std::string, framework::ir::Node *> &graph_var_map,
bool is_trt = true); bool is_trt = true);
} // namespace analysis } // namespace analysis
......
...@@ -142,6 +142,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -142,6 +142,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
} }
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
for (framework::ir::Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
graph_var_map[node->Name()] = node;
}
}
auto &subgraph_nodes = *Agent(node).subgraph(); auto &subgraph_nodes = *Agent(node).subgraph();
// The following procedure is used to rename all the intermediate // The following procedure is used to rename all the intermediate
...@@ -157,7 +164,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -157,7 +164,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// So we have to rename the variable in the subgraph to make sure // So we have to rename the variable in the subgraph to make sure
// it is either an OP's input or an OP's output. // it is either an OP's input or an OP's output.
RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id, RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id,
&output_names_with_id, &output_names, &output_name_map); &output_names_with_id, &output_names, &output_name_map,
graph_var_map);
// When tensorrt engine runs at the end of the operation, // When tensorrt engine runs at the end of the operation,
// output_mapping help us copy the data from the renamed ITensor // output_mapping help us copy the data from the renamed ITensor
...@@ -168,14 +176,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -168,14 +176,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
output_mapping.push_back(output_name_map[name]); output_mapping.push_back(output_name_map[name]);
} }
PADDLE_ENFORCE(!output_mapping.empty()); PADDLE_ENFORCE(!output_mapping.empty());
auto *vars = block_desc.Proto()->mutable_vars();
for (framework::ir::Node *node : graph->Nodes()) {
if (node->IsVar() && node->Var()) {
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(), PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc"); "the block has no var-desc");
...@@ -213,7 +213,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -213,7 +213,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "enable_int8", enable_int8); SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key); SetAttr(op_desc->Proto(), "engine_key", engine_key);
std::string trt_engine_serialized_data = ""; std::string trt_engine_serialized_data = "";
SetAttr(op_desc->Proto(), "engine_serialized_data", SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data); trt_engine_serialized_data);
......
...@@ -115,6 +115,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -115,6 +115,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_anakin_); CP_MEMBER(use_anakin_);
CP_MEMBER(anakin_max_batchsize_); CP_MEMBER(anakin_max_batchsize_);
CP_MEMBER(anakin_max_input_shape_); CP_MEMBER(anakin_max_input_shape_);
CP_MEMBER(anakin_min_subgraph_size_);
// Ir related. // Ir related.
CP_MEMBER(enable_ir_optim_); CP_MEMBER(enable_ir_optim_);
...@@ -315,6 +316,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -315,6 +316,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << specify_input_name_; ss << specify_input_name_;
ss << cpu_math_library_num_threads_; ss << cpu_math_library_num_threads_;
ss << use_anakin_; ss << use_anakin_;
ss << anakin_min_subgraph_size_;
return ss.str(); return ss.str();
} }
...@@ -386,10 +388,11 @@ void AnalysisConfig::SwitchIrDebug(int x) { ...@@ -386,10 +388,11 @@ void AnalysisConfig::SwitchIrDebug(int x) {
Update(); Update();
} }
void AnalysisConfig::EnableAnakinEngine( void AnalysisConfig::EnableAnakinEngine(
int max_batch_size, int max_batch_size, std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> max_input_shape) { int min_subgraph_size) {
anakin_max_batchsize_ = max_batch_size; anakin_max_batchsize_ = max_batch_size;
anakin_max_input_shape_ = max_input_shape; anakin_max_input_shape_ = max_input_shape;
anakin_min_subgraph_size_ = min_subgraph_size;
use_anakin_ = true; use_anakin_ = true;
Update(); Update();
} }
......
...@@ -385,6 +385,7 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -385,6 +385,7 @@ void AnalysisPredictor::PrepareArgument() {
if (config_.use_gpu() && config_.anakin_engine_enabled()) { if (config_.use_gpu() && config_.anakin_engine_enabled()) {
argument_.SetAnakinMaxBatchSize(config_.anakin_max_batchsize_); argument_.SetAnakinMaxBatchSize(config_.anakin_max_batchsize_);
argument_.SetAnakinMaxInputShape(config_.anakin_max_input_shape_); argument_.SetAnakinMaxInputShape(config_.anakin_max_input_shape_);
argument_.SetAnakinMinSubgraphSize(config_.anakin_min_subgraph_size_);
LOG(INFO) << "Anakin subgraph engine is enabled"; LOG(INFO) << "Anakin subgraph engine is enabled";
} }
......
...@@ -151,7 +151,8 @@ struct AnalysisConfig { ...@@ -151,7 +151,8 @@ struct AnalysisConfig {
*/ */
void EnableAnakinEngine( void EnableAnakinEngine(
int max_batch_size = 1, int max_batch_size = 1,
std::map<std::string, std::vector<int>> max_input_shape = {}); std::map<std::string, std::vector<int>> max_input_shape = {},
int min_subgraph_size = 6);
/** A boolean state indicating whether the Anakin sub-graph engine is used. /** A boolean state indicating whether the Anakin sub-graph engine is used.
*/ */
...@@ -288,6 +289,7 @@ struct AnalysisConfig { ...@@ -288,6 +289,7 @@ struct AnalysisConfig {
bool use_anakin_{false}; bool use_anakin_{false};
int anakin_max_batchsize_; int anakin_max_batchsize_;
int anakin_min_subgraph_size_{6};
std::map<std::string, std::vector<int>> anakin_max_input_shape_; std::map<std::string, std::vector<int>> anakin_max_input_shape_;
std::map<std::string, std::string> engine_opt_info_; std::map<std::string, std::string> engine_opt_info_;
......
...@@ -120,40 +120,8 @@ class AnakinEngineOp : public framework::OperatorBase { ...@@ -120,40 +120,8 @@ class AnakinEngineOp : public framework::OperatorBase {
inference::Singleton<inference::anakin::AnakinEngineManager>::Global() inference::Singleton<inference::anakin::AnakinEngineManager>::Global()
.Get(engine_key_); .Get(engine_key_);
} }
return anakin_engine_; return anakin_engine_;
} }
void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
AnakinNvEngineT *engine) const {
LOG(INFO) << "Prepare Anakin engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(Attr<std::string>("subgraph"));
std::vector<std::string> output_maps =
Attr<std::vector<std::string>>("output_name_mapping");
inference::Singleton<inference::anakin::AnakinOpConverter>::Global()
.ConvertBlock(block_desc, param_names_, scope, engine);
engine->Freeze();
for (const auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
auto t_shape = framework::vectorize2int(t.dims());
// all input shape should be 4 dims
if (t_shape.size() == 2) {
t_shape.push_back(1);
t_shape.push_back(1);
}
engine->SetInputShape(x, t_shape);
}
engine->Optimize();
engine->InitGraph();
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册