未验证 提交 a72dbe9a 编写于 作者: 石晓伟 提交者: GitHub

Cherry-pick benchmark related changes from release/1.4 (#17156)

* cherry-pick commit from 88770542

* cherry-pick commit from 3f0b97df

* cherry-pick from 16691:Anakin subgraph support yolo_v3 and faster-rcnn

(cherry picked from commit 8643dbc2)

* Cherry-Pick from 16662 : Anakin subgraph cpu support

(cherry picked from commit 7ad182e1)

* Cherry-pick from 1662, 16797.. : add anakin int8 support

(cherry picked from commit e14ab180)

* Cherry-pick from 16813 : change singleton to graph RegistBlock
test=release/1.4

(cherry picked from commit 4b9fa423)

* Cherry Pick : 16837 Support ShuffleNet and MobileNet-v2

Support ShuffleNet and MobileNet-v2, test=release/1.4

(cherry picked from commit a6fb066f)

* Cherry-pick : anakin subgraph add opt config layout argument #16846
test=release/1.4

(cherry picked from commit 8121b3ec)

* 1. add shuffle_channel_detect

(cherry picked from commit 6efdea89)

* update shuffle_channel op convert, test=release/1.4

(cherry picked from commit e4726a06)

* Modify symbol export rules

test=develop
上级 16922e00
......@@ -25,8 +25,9 @@ endif()
if(ANAKIN_FOUND)
message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
include_directories(${ANAKIN_ROOT})
include_directories(${ANAKIN_ROOT}/include)
include_directories(${ANAKIN_ROOT}/include/saber)
include_directories(${ANAKIN_ROOT}/saber)
link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()
......@@ -71,6 +71,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
......
......@@ -48,17 +48,37 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
auto base_op_desc = mul->Op();
// Create an FC Node.
// OpDesc desc(base_op_desc, nullptr);
OpDesc desc;
std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant"
// can be detected by the quant_dequant_fuse_pass. This pass will add
// "input_scale",
// "weight_scale" which are extracted from fake_quant op and fake_dequant op
// to mul op,
// and then delete the fake_quant op and fake_dequant op in the graph. If
// the mul op
// has the scale info, we should add those to the fused fc.
if (base_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
......@@ -1640,7 +1640,8 @@ PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
......@@ -1648,24 +1649,22 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// the quant op always be one.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
......@@ -1707,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
}
}
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape1_op->LinksFrom({reshape1_in});
reshape1_out->LinksFrom({reshape1_op});
transpose_op->LinksFrom({reshape1_out});
transpose_out->LinksFrom({transpose_op});
reshape2_op->LinksFrom({transpose_out});
reshape2_out->LinksFrom({reshape2_op});
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -880,7 +880,8 @@ struct QuantDequantOpFuse : public PatternBase {
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);
const std::string& weight_name, int times,
const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
......@@ -891,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
}
};
struct ShuffleChannelPattern : public PatternBase {
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
void operator()(PDNode* reshape1_in);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
......@@ -25,7 +25,8 @@ namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::string op_type) {
const std::string& op_type,
const std::string& quant_type) {
const std::string pattern_name = "quant_dequant_fuse";
// FusePassBase::Init(pattern_name, graph);
const int kNumFields = 5;
......@@ -38,7 +39,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("fake_quantize_range_abs_max", "X")
->assert_is_op_input(quant_type, "X")
->AsInput();
std::string quantized_op_type = "";
......@@ -46,6 +47,9 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
if (op_type == "conv2d") {
quantized_op_type = "conv2d";
weight_name = "Filter";
} else if (op_type == "depthwise_conv2d") {
quantized_op_type = "depthwise_conv2d";
weight_name = "Filter";
} else if (op_type == "conv2d_fusion") {
quantized_op_type = "conv2d_fusion";
weight_name = "Filter";
......@@ -62,7 +66,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
}
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times);
pattern(x, quantized_op_type, weight_name, times, quant_type);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -103,7 +107,6 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
// max_range = (range * range) / weight_scale
float max_range = boost::get<float>(
nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
float weight_scale = (range * range) / max_range;
......@@ -118,7 +121,8 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
new_op_desc.SetType(quantized_op_type);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
......@@ -156,11 +160,17 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_fuse";
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
auto* scope = param_scope();
for (auto& op_type : quantized_op_types) {
for (int i = 1; i <= 6; i++) {
RunQuantDequant(graph, scope, i, op_type);
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 6; i >= 1; i--) {
RunQuantDequant(graph, scope, i, op_type, quant_type);
}
}
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose_op); \
GET_IR_NODE(transpose_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op();
std::string input_name = input_node->Name();
std::string output_name = reshape2_out->Name();
auto reshape1_shape =
boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
auto reshape2_shape =
boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
int i_c = reshape1_shape[2];
int o_c = reshape2_shape[1];
int group = o_c / i_c;
framework::OpDesc new_op_desc;
new_op_desc.SetType("shuffle_channel");
new_op_desc.SetInput("X", {input_name});
new_op_desc.SetOutput("Out", {output_name});
new_op_desc.SetAttr("group", group);
new_op_desc.Flush();
// Create a new node for the fused op.
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(new_op, reshape2_out);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
transpose_out, reshape2_op});
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ShuffleChannelDetectPass : public FusePassBase {
public:
virtual ~ShuffleChannelDetectPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
scope op_registry gtest)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
......@@ -14,5 +19,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL)
cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL)
......@@ -16,16 +16,13 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
template <typename TargetT, ::anakin::Precision PrecisionT>
ActivationOpConverter<TargetT, PrecisionT>::ActivationOpConverter(
const std::string &op_type)
: op_type_(op_type) {
auto it = anakin_op_types_.find(op_type_);
PADDLE_ENFORCE(it != anakin_op_types_.end(),
......@@ -33,10 +30,10 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
anakin_op_type_ = it->second;
}
void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ActivationOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -44,8 +41,17 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "type", anakin_op_type_);
this->engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "type", anakin_op_type_);
if (op_type_ == "swish") {
float beta = boost::get<float>(op_desc.GetAttr("beta"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", beta);
}
if (op_type_ == "relu6") {
float threshold = boost::get<float>(op_desc.GetAttr("threshold"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", threshold);
}
}
} // namespace anakin
......@@ -54,3 +60,5 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
REGISTER_ANAKIN_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(tanh, TanhOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(swish, SwishOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(relu6, Relu6OpConverter);
......@@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ActivationOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ActivationOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
explicit ActivationOpConverter(const std::string &op_type);
......@@ -36,18 +37,36 @@ class ActivationOpConverter : public AnakinOpConverter {
std::string op_type_;
std::string anakin_op_type_;
std::map<std::string, std::string> anakin_op_types_{{"tanh", "TanH"},
{"sigmoid", "Sigmoid"}};
{"sigmoid", "Sigmoid"},
{"relu6", "ClippedRelu"},
{"swish", "Swish"}};
};
class TanhOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class TanhOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
TanhOpConverter() : ActivationOpConverter("tanh") {}
TanhOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("tanh") {}
};
class SigmoidOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SigmoidOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SigmoidOpConverter() : ActivationOpConverter("sigmoid") {}
SigmoidOpConverter()
: ActivationOpConverter<TargetT, PrecisionT>("sigmoid") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class Relu6OpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
Relu6OpConverter() : ActivationOpConverter<TargetT, PrecisionT>("relu6") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class SwishOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SwishOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("swish") {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
void AffineChannelOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
this->engine_->AddOp(op_name, "AffineChannel", {input_name}, {output_name});
// Copy the Scale to CPUPlace and get the pointer.
auto *scale_v = scope.FindVar(op_desc.Input("Scale").front());
PADDLE_ENFORCE_NOT_NULL(scale_v);
auto weight1 = pblock_from_var<TargetT, PrecisionT>(*scale_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
// Copy the Bias to CPUPlace and get the pointer.
auto *bias_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(bias_v);
auto weight2 = pblock_from_var<TargetT, PrecisionT>(*bias_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(affine_channel, AffineChannelOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
class AffineChannelOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
AffineChannelOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~AffineChannelOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -18,107 +18,64 @@
#include <map>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void BatchNormOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
std::map<std::string, std::string> inputs;
for (auto k : {"X", "Scale", "Bias", "Mean", "Variance"}) {
PADDLE_ENFORCE_EQ(op_desc.Input(k).size(), 1UL);
auto v = op_desc.Input(k).front();
inputs.insert({k, v});
}
auto input = op_desc.Input("X").front();
auto output = op_desc.Output("Y").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Y").front();
auto epsilon = boost::get<float>(op_desc.GetAttr("epsilon"));
// auto momentum = boost::get<float>(op_desc.GetAttr("momentum"));
auto bn_op_name = op_name + ":bn";
auto bn_output = bn_op_name + "_output";
engine_->AddOp(bn_op_name, "BatchNorm", {inputs["X"]}, {bn_output});
engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));
this->engine_->AddOp(bn_op_name, "BatchNorm", {input}, {bn_output});
this->engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
this->engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));
auto scale_op_name = op_name + ":scale";
auto get_lod_tensor = [this, &scope, &op_name](const std::string &var_name,
framework::LoDTensor *tensor) {
auto *v = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(v);
auto *t = v->GetMutable<framework::LoDTensor>();
tensor->Resize(t->dims());
TensorCopySync(*t, platform::CPUPlace(), tensor);
};
framework::LoDTensor bias_t;
framework::LoDTensor mean_t;
framework::LoDTensor scale_t;
framework::LoDTensor variance_t;
get_lod_tensor(inputs["Bias"], &bias_t);
get_lod_tensor(inputs["Mean"], &mean_t);
get_lod_tensor(inputs["Scale"], &scale_t);
get_lod_tensor(inputs["Variance"], &variance_t);
auto fill_shape = [](size_t n, std::vector<int> shape) {
shape.insert(shape.begin(), 1);
if (shape.size() < n) {
shape.insert(shape.end(), n - shape.size(), 1);
}
return shape;
};
Shape shape1(fill_shape(4, framework::vectorize2int(mean_t.dims())));
Shape shape2(fill_shape(4, framework::vectorize2int(variance_t.dims())));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *mean_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(mean_t.data<float>(), mean_t.numel(), mean_data);
engine_->AddOpAttr(bn_op_name, "weight_1", *weight1);
auto *weight2 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape2);
auto *variance_data =
static_cast<float *>(weight2->h_tensor().mutable_data());
std::copy_n(variance_t.data<float>(), variance_t.numel(), variance_data);
engine_->AddOpAttr(bn_op_name, "weight_2", *weight2);
Shape shape3(std::vector<int>({1, 1, 1, 1}));
auto *weight3 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape3);
auto *alpha_data = static_cast<float *>(weight3->h_tensor().mutable_data());
float weight3_data[] = {1};
std::copy(std::begin(weight3_data), std::end(weight3_data), alpha_data);
engine_->AddOpAttr(bn_op_name, "weight_3", *weight3);
Shape scale_shape(fill_shape(4, framework::vectorize2int(scale_t.dims())));
auto *scale =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(scale_shape);
auto *scale_data = static_cast<float *>(scale->h_tensor().mutable_data());
std::copy_n(scale_t.data<float>(), scale_t.numel(), scale_data);
Shape bias_shape(fill_shape(4, framework::vectorize2int(bias_t.dims())));
auto *bias =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(bias_shape);
auto *bias_data = static_cast<float *>(bias->h_tensor().mutable_data());
std::copy_n(bias_t.data<float>(), bias_t.numel(), bias_data);
engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output});
engine_->AddOpAttr(scale_op_name, "axis", 1);
engine_->AddOpAttr(scale_op_name, "num_axes", 1);
engine_->AddOpAttr(scale_op_name, "bias_term", true);
engine_->AddOpAttr(scale_op_name, "weight_1", *scale);
engine_->AddOpAttr(scale_op_name, "weight_2", *bias);
this->engine_->AddOp(scale_op_name, "Scale", {bn_output}, {output});
this->engine_->AddOpAttr(scale_op_name, "axis", 1);
this->engine_->AddOpAttr(scale_op_name, "num_axes", 1);
this->engine_->AddOpAttr(scale_op_name, "bias_term", true);
auto *mean_v = scope.FindVar(op_desc.Input("Mean").front());
PADDLE_ENFORCE_NOT_NULL(mean_v);
auto weight1 = pblock_from_var<TargetT, PrecisionT>(*mean_v, this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_1", *weight1);
auto *variance_v = scope.FindVar(op_desc.Input("Variance").front());
PADDLE_ENFORCE_NOT_NULL(variance_v);
auto weight2 =
pblock_from_var<TargetT, PrecisionT>(*variance_v, this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_2", *weight2);
auto *weight3 = pblock_from_vector<TargetT, PrecisionT>(
std::vector<float>({1}), this->engine_);
this->engine_->AddOpAttr(bn_op_name, "weight_3", *weight3);
auto *scale_v = scope.FindVar(op_desc.Input("Scale").front());
PADDLE_ENFORCE_NOT_NULL(scale_v);
auto scale = pblock_from_var<TargetT, PrecisionT>(*scale_v, this->engine_);
this->engine_->AddOpAttr(scale_op_name, "weight_1", *scale);
auto *bias_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(bias_v);
auto bias = pblock_from_var<TargetT, PrecisionT>(*bias_v, this->engine_);
this->engine_->AddOpAttr(scale_op_name, "weight_2", *bias);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class BatchNormOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class BatchNormOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
BatchNormOpConverter() = default;
......
......@@ -15,34 +15,23 @@
#include "paddle/fluid/inference/anakin/convert/concat.h"
#include <algorithm>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void ConcatOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ConcatOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
int axis = boost::get<int>(op_desc.GetAttr("axis"));
auto input_names = op_desc.Input("X");
// PADDLE_ENFORCE(axis > 0,
// "The axis attr of Concat op should be large than 0 for trt");
auto y_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Concat", input_names, {y_name});
engine_->AddOpAttr(op_name, "axis", axis);
this->engine_->AddOp(op_name, "Concat", input_names, {y_name});
this->engine_->AddOpAttr(op_name, "axis", axis);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ConcatOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ConcatOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ConcatOpConverter() = default;
......
......@@ -16,21 +16,18 @@
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Conv2dOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL);
......@@ -39,46 +36,69 @@ void Conv2dOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_name = op_desc.Input("Input").front();
auto output_name = op_desc.Output("Output").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Output").front();
engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
auto *filter_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(filter_v);
auto *filter_t = filter_v->GetMutable<framework::LoDTensor>();
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(filter_t->dims());
TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get());
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
// const int n_output = weight_tensor->dims()[0];
// const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
// auto filter_num = n_input * filter_h * filter_w ;
auto filter_num = weight_tensor->dims()[0];
engine_->AddOpAttr<int>(op_name, "filter_num", filter_num);
engine_->AddOpAttr<PTuple<int>>(op_name, "kernel_size", {filter_h, filter_w});
this->engine_->template AddOpAttr<int>(op_name, "filter_num", filter_num);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "kernel_size",
{filter_h, filter_w});
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
auto dilations = boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
engine_->AddOpAttr<PTuple<int>>(op_name, "dilation_rate", dilations);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dilation_rate",
dilations);
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
engine_->AddOpAttr(op_name, "group", groups);
engine_->AddOpAttr(op_name, "axis", 1);
engine_->AddOpAttr(op_name, "bias_term", false);
this->engine_->AddOpAttr(op_name, "group", groups);
this->engine_->AddOpAttr(op_name, "axis", 1);
this->engine_->AddOpAttr(op_name, "bias_term", false);
::anakin::saber::Shape anakin_shape(weight_shape);
bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
auto weight_shape = framework::vectorize2int(filter_t->dims());
Shape anakin_shape(weight_shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(weight_tensor->data<float>(), weight_tensor->numel(), cpu_data);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr(op_name, "weight_1", *weight1);
if (enable_int8) {
const float int8_range = 127.;
float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
PBlock<TargetT> *weight1 =
new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
this->engine_->RegistBlock(weight1);
float *weight_data = weight_tensor->data<float>();
std::vector<char> weight_int8;
int weight_num = weight_tensor->numel();
for (int i = 0; i < weight_tensor->numel(); i++) {
bool is_valid_int8 =
((weight_data[i] >= -128) && (weight_data[i] <= 127));
PADDLE_ENFORCE(is_valid_int8,
"We are in anakin subgraph int8 mode, the weight of conv "
"should be in range [-128, 127]");
weight_int8.push_back(static_cast<char>(weight_data[i]));
}
memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
this->engine_->Graph()->SetWeightsScale(op_name,
{weight_scale / int8_range}, false);
this->engine_->AddTensorScale(input_name, in_scale / int8_range);
} else {
auto *weight1 = pblock_from_tensor<TargetT, PrecisionT>(
*weight_tensor, weight_shape, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
}
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Conv2dOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Conv2dOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Conv2dOpConverter() = default;
......
......@@ -16,21 +16,18 @@
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/inference/anakin/convert/helper.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Conv2dFusionOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1UL);
......@@ -40,71 +37,74 @@ void Conv2dFusionOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_name = op_desc.Input("Input").front();
auto output_name = op_desc.Output("Output").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Output").front();
engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
this->engine_->AddOp(op_name, "Convolution", {input_name}, {output_name});
auto *filter_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(filter_v);
auto *filter_t = filter_v->GetMutable<framework::LoDTensor>();
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
auto *b_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(b_v);
auto *b_t = b_v->GetMutable<framework::LoDTensor>();
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(filter_t->dims());
TensorCopySync((*filter_t), platform::CPUPlace(), weight_tensor.get());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
// const int n_output = weight_tensor->dims()[0];
// const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
// auto filter_num = n_input * filter_h * filter_w ;
auto filter_num = weight_tensor->dims()[0];
engine_->AddOpAttr<int>(op_name, "filter_num", filter_num);
engine_->AddOpAttr<PTuple<int>>(op_name, "kernel_size", {filter_h, filter_w});
this->engine_->template AddOpAttr<int>(op_name, "filter_num", filter_num);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "kernel_size",
{filter_h, filter_w});
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
auto dilations = boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
engine_->AddOpAttr<PTuple<int>>(op_name, "dilation_rate", dilations);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dilation_rate",
dilations);
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
engine_->AddOpAttr(op_name, "group", groups);
engine_->AddOpAttr(op_name, "axis", 1);
engine_->AddOpAttr(op_name, "bias_term", true);
auto weight_shape = framework::vectorize2int(filter_t->dims());
Shape anakin_shape(weight_shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(weight_tensor->data<float>(), weight_tensor->numel(), cpu_data);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto bias_shape = framework::vectorize2int(b_t->dims());
framework::LoDTensor bias_tensor;
bias_tensor.Resize(b_t->dims());
TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor);
auto *bias_data = bias_tensor.data<float>();
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
// bias_shape.push_back(1);
// bias_shape.push_back(1);
Shape anakin_bias_shape(bias_shape);
this->engine_->AddOpAttr(op_name, "group", groups);
this->engine_->AddOpAttr(op_name, "axis", 1);
this->engine_->AddOpAttr(op_name, "bias_term", true);
auto *weight2 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
anakin_bias_shape);
float *cpu_data2 = static_cast<float *>(weight2->h_tensor().mutable_data());
std::copy_n(bias_data, bias_tensor.numel(), cpu_data2);
weight2->d_tensor().set_shape(anakin_bias_shape);
weight2->d_tensor().copy_from(weight2->h_tensor());
engine_->AddOpAttr(op_name, "weight_2", *weight2);
::anakin::saber::Shape anakin_shape(weight_shape);
bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
if (enable_int8) {
const float int8_range = 127.;
float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
PBlock<TargetT> *weight1 =
new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
this->engine_->RegistBlock(weight1);
float *weight_data = weight_tensor->data<float>();
std::vector<char> weight_int8;
int weight_num = weight_tensor->numel();
for (int i = 0; i < weight_tensor->numel(); i++) {
bool is_valid_int8 =
((weight_data[i] >= -128) && (weight_data[i] <= 127));
PADDLE_ENFORCE(is_valid_int8,
"We are in anakin subgraph int8 mode, the weight of conv "
"should be in range [-128, 127]");
weight_int8.push_back(static_cast<char>(weight_data[i]));
}
memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
this->engine_->Graph()->SetWeightsScale(op_name,
{weight_scale / int8_range}, false);
this->engine_->AddTensorScale(input_name, in_scale / int8_range);
} else {
auto weight_tensor = tensor_from_var(*filter_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
auto *weight1 = pblock_from_tensor<TargetT, PrecisionT>(
*weight_tensor, weight_shape, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto weight2 = pblock_from_var<TargetT, PrecisionT>(*b_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
}
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Conv2dFusionOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Conv2dFusionOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Conv2dFusionOpConverter() = default;
......
......@@ -17,17 +17,14 @@
#include <map>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void DensityPriorBoxOpConverter::operator()(
template <typename TargetT, ::anakin::Precision PrecisionT>
void DensityPriorBoxOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc& op, const framework::BlockDesc& block_desc,
const framework::Scope& scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
......@@ -81,22 +78,30 @@ void DensityPriorBoxOpConverter::operator()(
std::vector<float> temp_v = {};
engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", min_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", max_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", aspect_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens);
engine_->AddOpAttr(op_name, "is_flip", is_flip);
engine_->AddOpAttr(op_name, "is_clip", is_clip);
engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
engine_->AddOpAttr(op_name, "step_h", step_h);
engine_->AddOpAttr(op_name, "step_w", step_w);
engine_->AddOpAttr(op_name, "offset", offset);
engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", t_order);
this->engine_->AddOp(op_name, "PriorBox", {input_name, image_name},
{output_name});
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "min_size",
min_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "max_size",
max_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "aspect_ratio",
aspect_ratios);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "fixed_size",
fixed_sizes);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "fixed_ratio",
fixed_ratios);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "density", dens);
this->engine_->AddOpAttr(op_name, "is_flip", is_flip);
this->engine_->AddOpAttr(op_name, "is_clip", is_clip);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "variance",
variances);
this->engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "step_h", step_h);
this->engine_->AddOpAttr(op_name, "step_w", step_w);
this->engine_->AddOpAttr(op_name, "offset", offset);
this->engine_->template AddOpAttr<PTuple<std::string>>(op_name, "order",
t_order);
}
} // namespace anakin
......
......@@ -22,7 +22,9 @@ namespace paddle {
namespace inference {
namespace anakin {
class DensityPriorBoxOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class DensityPriorBoxOpConverter
: public AnakinOpConverter<TargetT, PrecisionT> {
public:
DensityPriorBoxOpConverter() = default;
......
......@@ -16,19 +16,14 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void DetectionOutOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto target_name = op_desc.Input("TargetBox").front();
auto prior_box_name = op_desc.Input("PriorBox").front();
......@@ -52,18 +47,19 @@ void DetectionOutOpConverter::operator()(const framework::proto::OpDesc &op,
"Not support encode_center_size code_type in DetectionOut of anakin");
}
engine_->AddOp(op_name, "DetectionOutput",
{target_name, scores_name, prior_box_name}, {output_name});
engine_->AddOpAttr(op_name, "share_location", true);
engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
engine_->AddOpAttr(op_name, "background_id", background_label);
engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
this->engine_->AddOp(op_name, "DetectionOutput",
{target_name, scores_name, prior_box_name},
{output_name});
this->engine_->AddOpAttr(op_name, "share_location", true);
this->engine_->AddOpAttr(op_name, "variance_encode_in_target", false);
this->engine_->AddOpAttr(op_name, "class_num", static_cast<int>(0));
this->engine_->AddOpAttr(op_name, "background_id", background_label);
this->engine_->AddOpAttr(op_name, "keep_top_k", keep_top_k);
this->engine_->AddOpAttr(op_name, "code_type", anakin_code_type);
this->engine_->AddOpAttr(op_name, "conf_thresh", score_threshold);
this->engine_->AddOpAttr(op_name, "nms_top_k", nms_top_k);
this->engine_->AddOpAttr(op_name, "nms_thresh", nms_threshold);
this->engine_->AddOpAttr(op_name, "nms_eta", nms_eta);
}
} // namespace anakin
......
......@@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class DetectionOutOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class DetectionOutOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
DetectionOutOpConverter() = default;
......
......@@ -16,24 +16,16 @@
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
void DropoutOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void DropoutOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Mask").size(), 1);
......@@ -43,21 +35,17 @@ void DropoutOpConverter::operator()(const framework::proto::OpDesc &op,
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Scale", {x_name}, {out_name});
this->engine_->AddOp(op_name, "Scale", {x_name}, {out_name});
auto dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
auto factor = 1 - dropout_prob;
Shape shape1(std::vector<int>({1, 1, 1, 1}));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *factor_data = static_cast<float *>(weight1->h_tensor().mutable_data());
float weight1_data[] = {factor};
std::copy(std::begin(weight1_data), std::end(weight1_data), factor_data);
auto *weight1 = pblock_from_vector<TargetT, PrecisionT>(
std::vector<float>({factor}), this->engine_);
engine_->AddOpAttr(op_name, "weight_1", *weight1);
engine_->AddOpAttr(op_name, "axis", 0);
engine_->AddOpAttr(op_name, "num_axes", 0);
engine_->AddOpAttr(op_name, "bias_term", false);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->AddOpAttr(op_name, "axis", 0);
this->engine_->AddOpAttr(op_name, "num_axes", 0);
this->engine_->AddOpAttr(op_name, "bias_term", false);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class DropoutOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class DropoutOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
DropoutOpConverter() = default;
......
......@@ -17,20 +17,14 @@
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void ElementwiseAddOpConverter::operator()(
template <typename TargetT, ::anakin::Precision PrecisionT>
void ElementwiseAddOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
......@@ -43,14 +37,16 @@ void ElementwiseAddOpConverter::operator()(
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
this->engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
std::string elementwise_type = "Add";
engine_->AddOpAttr<std::string>(op_name, "type", elementwise_type);
this->engine_->template AddOpAttr<std::string>(op_name, "type",
elementwise_type);
std::vector<float> coeff = {1.0, 1.0};
engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
}
void ElementwiseMulOpConverter::operator()(
template <typename TargetT, ::anakin::Precision PrecisionT>
void ElementwiseMulOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
......@@ -63,21 +59,12 @@ void ElementwiseMulOpConverter::operator()(
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Scale", {x_name, y_name}, {out_name});
// Fill a number to weight_1 as a placeholder.
Shape shape1(std::vector<int>({1, 1, 1, 1}));
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(shape1);
auto *placeholder_data =
static_cast<float *>(weight1->h_tensor().mutable_data());
float weight1_data[] = {1};
std::copy(std::begin(weight1_data), std::end(weight1_data), placeholder_data);
engine_->AddOpAttr(op_name, "weight_1", *weight1);
auto axis = boost::get<int>(op_desc.GetAttr("axis"));
engine_->AddOpAttr(op_name, "axis", axis);
engine_->AddOpAttr(op_name, "num_axes", 1);
engine_->AddOpAttr(op_name, "bias_term", false);
this->engine_->AddOp(op_name, "Eltwise", {x_name, y_name}, {out_name});
std::string elementwise_type = "Prod";
this->engine_->template AddOpAttr<std::string>(op_name, "type",
elementwise_type);
std::vector<float> coeff = {1.0, 1.0};
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
}
} // namespace anakin
......
......@@ -20,7 +20,9 @@ namespace paddle {
namespace inference {
namespace anakin {
class ElementwiseAddOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ElementwiseAddOpConverter
: public AnakinOpConverter<TargetT, PrecisionT> {
public:
ElementwiseAddOpConverter() = default;
......@@ -33,7 +35,9 @@ class ElementwiseAddOpConverter : public AnakinOpConverter {
private:
};
class ElementwiseMulOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ElementwiseMulOpConverter
: public AnakinOpConverter<TargetT, PrecisionT> {
public:
ElementwiseMulOpConverter() = default;
......
......@@ -16,23 +16,19 @@
#include <algorithm>
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
void FcBaseOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void FcBaseOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto input_names = op_desc.InputNames();
bool with_bias = input_names.size() == 3;
bool with_bias = input_names.size() >= 3;
std::string w_name = "Y";
std::string i_name = "X";
......@@ -46,71 +42,74 @@ void FcBaseOpConverter::operator()(const framework::proto::OpDesc &op,
// get weights
auto *y_v = scope.FindVar(op_desc.Input(w_name).front());
PADDLE_ENFORCE_NOT_NULL(y_v);
auto *y_t = y_v->GetMutable<framework::LoDTensor>();
auto input_name = op_desc.Input(i_name).front();
auto output_name = op_desc.Output("Out").front();
auto weight_tensor = tensor_from_var(*y_v, platform::CPUPlace());
auto weight_shape = framework::vectorize2int(weight_tensor->dims());
engine_->AddOp(op_name, "Dense", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "bias_term", with_bias);
engine_->AddOpAttr(op_name, "axis", 1);
auto weight_shape = framework::vectorize2int(y_t->dims());
int out_dim = weight_shape[1];
engine_->AddOpAttr(op_name, "out_dim", out_dim);
const int w_m = weight_shape[0];
const int w_k = weight_shape[1];
if (weight_shape.size() < 4UL) {
weight_shape.insert(weight_shape.begin(), 4UL - weight_shape.size(), 1);
}
Shape anakin_shape(weight_shape);
auto input_name = op_desc.Input(i_name).front();
auto output_name = op_desc.Output("Out").front();
framework::LoDTensor weight_tensor;
weight_tensor.Resize(y_t->dims());
TensorCopySync((*y_t), platform::CPUPlace(), &weight_tensor);
auto *weight_data = weight_tensor.data<float>();
PADDLE_ENFORCE(w_m * w_k == weight_tensor.numel());
this->engine_->AddOp(op_name, "Dense", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "bias_term", with_bias);
this->engine_->AddOpAttr(op_name, "axis", 1);
this->engine_->AddOpAttr(op_name, "out_dim", out_dim);
std::vector<float> trans_weight_data(weight_tensor.numel());
auto *weight_data = weight_tensor->data<float>();
PADDLE_ENFORCE(w_m * w_k == weight_tensor->numel());
std::vector<float> trans_weight_data(weight_tensor->numel());
for (int i = 0; i < w_m; i++) {
for (int j = 0; j < w_k; j++) {
trans_weight_data[i + j * w_m] = weight_data[i * w_k + j];
}
}
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(anakin_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
std::copy_n(trans_weight_data.data(), weight_tensor.numel(), cpu_data);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr(op_name, "weight_1", *weight1);
int weight_num = weight_tensor->numel();
bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
if (enable_int8) {
if (weight_shape.size() < 4UL) {
weight_shape.insert(weight_shape.begin(), 4UL - weight_shape.size(), 1);
}
::anakin::saber::Shape anakin_shape(weight_shape);
const float int8_range = 127.;
float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
float weight_scale = boost::get<float>(op_desc.GetAttr("weight_scale"));
PBlock<TargetT> *weight1 =
new PBlock<TargetT>(anakin_shape, ::anakin::AK_INT8);
this->engine_->RegistBlock(weight1);
std::vector<char> weight_int8;
for (int i = 0; i < weight_num; i++) {
bool is_valid_int8 =
((trans_weight_data[i] >= -128) && (trans_weight_data[i] <= 127));
PADDLE_ENFORCE(is_valid_int8,
"We are in anakin subgraph int8 mode, the weight of fc "
"should be in range [-128, 127]");
weight_int8.push_back(static_cast<char>(trans_weight_data[i]));
}
memcpy(static_cast<void *>(weight1->h_tensor().mutable_data()),
static_cast<void *>(weight_int8.data()), sizeof(char) * weight_num);
weight1->d_tensor().set_shape(anakin_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
this->engine_->Graph()->SetOpPrec(op_name, ::anakin::AK_INT8);
this->engine_->Graph()->SetWeightsScale(op_name,
{weight_scale / int8_range}, false);
this->engine_->AddTensorScale(input_name, in_scale / int8_range);
} else {
auto *weight1 = pblock_from_vector<TargetT, PrecisionT>(trans_weight_data,
this->engine_);
this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
}
// get bias
if (with_bias) {
auto *b_v = scope.FindVar(op_desc.Input("Bias").front());
PADDLE_ENFORCE_NOT_NULL(b_v);
auto *b_t = b_v->GetMutable<framework::LoDTensor>();
auto bias_shape = framework::vectorize2int(b_t->dims());
framework::LoDTensor bias_tensor;
bias_tensor.Resize(b_t->dims());
TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor);
auto *bias_data = bias_tensor.data<float>();
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
bias_shape.insert(bias_shape.begin(), 1);
// bias_shape.push_back(1);
// bias_shape.push_back(1);
Shape anakin_bias_shape(bias_shape);
auto *weight2 = GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(
anakin_bias_shape);
float *cpu_data2 = static_cast<float *>(weight2->h_tensor().mutable_data());
std::copy_n(bias_data, bias_tensor.numel(), cpu_data2);
weight2->d_tensor().set_shape(anakin_bias_shape);
weight2->d_tensor().copy_from(weight2->h_tensor());
engine_->AddOpAttr(op_name, "weight_2", *weight2);
auto weight2 = pblock_from_var<TargetT, PrecisionT>(*b_v, this->engine_);
this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
}
}
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class FcBaseOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class FcBaseOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
FcBaseOpConverter() = default;
......@@ -32,13 +33,15 @@ class FcBaseOpConverter : public AnakinOpConverter {
};
// with bias
class FcOpConverter : public FcBaseOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class FcOpConverter : public FcBaseOpConverter<TargetT, PrecisionT> {
public:
FcOpConverter() = default;
};
// without bias
class MulOpConverter : public FcBaseOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class MulOpConverter : public FcBaseOpConverter<TargetT, PrecisionT> {
public:
MulOpConverter() = default;
};
......
......@@ -15,20 +15,16 @@
#include "paddle/fluid/inference/anakin/convert/flatten.h"
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void FlattenOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void FlattenOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1UL);
......@@ -41,8 +37,8 @@ void FlattenOpConverter::operator()(const framework::proto::OpDesc &op,
std::vector<int> out_dims = {0, -1, 1, 1};
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Reshape", {input}, {output});
engine_->AddOpAttr<PTuple<int>>(op_name, "dims", out_dims);
this->engine_->AddOp(op_name, "Reshape", {input}, {output});
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dims", out_dims);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class FlattenOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class FlattenOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
FlattenOpConverter() = default;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/helper.h"
namespace paddle {
namespace inference {
namespace anakin {
std::unique_ptr<framework::LoDTensor> tensor_from_var(
const framework::Variable& var, const platform::Place& place) {
auto& src = var.Get<framework::LoDTensor>();
std::unique_ptr<framework::LoDTensor> dst(new framework::LoDTensor());
dst->Resize(src.dims());
TensorCopySync((src), place, dst.get());
return dst;
}
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/inference/anakin/engine.h"
#include "framework/core/net/net.h"
#include "framework/core/types.h"
#include "framework/graph/graph.h"
#include "framework/graph/graph_global_mem.h"
#include "saber/saber_types.h"
using anakin::saber::Shape;
using anakin::AK_FLOAT;
using anakin::AK_INT8;
using anakin::PBlock;
namespace paddle {
namespace inference {
namespace anakin {
std::unique_ptr<framework::LoDTensor> tensor_from_var(
const framework::Variable& var, const platform::Place& place);
template <typename TargetT, ::anakin::Precision PrecisionT>
PBlock<TargetT>* pblock_from_tensor(const framework::LoDTensor& tensor,
std::vector<int> shape_vec,
AnakinEngine<TargetT, PrecisionT>* engine) {
while (shape_vec.size() < 4) {
shape_vec.insert(shape_vec.begin(), 1);
}
Shape shape(shape_vec);
PBlock<TargetT>* weight = new PBlock<TargetT>(shape, AK_FLOAT);
engine->RegistBlock(weight);
float* cpu_data = static_cast<float*>(weight->h_tensor().mutable_data());
std::copy_n(tensor.data<float>(), tensor.numel(), cpu_data);
weight->d_tensor().set_shape(shape);
weight->d_tensor().copy_from(weight->h_tensor());
return weight;
}
template <typename TargetT, ::anakin::Precision PrecisionT>
PBlock<TargetT>* pblock_from_vector(const std::vector<float>& vec,
std::vector<int> shape_vec,
AnakinEngine<TargetT, PrecisionT>* engine) {
while (shape_vec.size() < 4) {
shape_vec.insert(shape_vec.begin(), 1);
}
Shape shape(shape_vec);
PBlock<TargetT>* weight = new PBlock<TargetT>(shape, AK_FLOAT);
engine->RegistBlock(weight);
auto* weight_data = static_cast<float*>(weight->h_tensor().mutable_data());
std::copy(std::begin(vec), std::end(vec), weight_data);
weight->d_tensor().set_shape(shape);
weight->d_tensor().copy_from(weight->h_tensor());
return weight;
}
template <typename TargetT, ::anakin::Precision PrecisionT>
PBlock<TargetT>* pblock_from_vector(const std::vector<float>& vec,
AnakinEngine<TargetT, PrecisionT>* engine) {
int size = vec.size();
return pblock_from_vector<TargetT, PrecisionT>(
vec, std::vector<int>({1, 1, 1, size}), engine);
}
template <typename TargetT, ::anakin::Precision PrecisionT>
PBlock<TargetT>* pblock_from_var(const framework::Variable& var,
AnakinEngine<TargetT, PrecisionT>* engine) {
auto tensor = tensor_from_var(var, platform::CPUPlace());
auto shape = framework::vectorize2int(tensor->dims());
return pblock_from_tensor<TargetT, PrecisionT>(*tensor, shape, engine);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -17,23 +17,16 @@
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Im2SequenceConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Im2SequenceConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 0);
......@@ -43,17 +36,19 @@ void Im2SequenceConverter::operator()(const framework::proto::OpDesc &op,
auto out_name = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Im2Sequence", {x_name}, {out_name});
this->engine_->AddOp(op_name, "Im2Sequence", {x_name}, {out_name});
std::vector<int> dilations = {1, 1};
auto paddings = boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
auto strides = boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
auto kernels = boost::get<std::vector<int>>(op_desc.GetAttr("kernels"));
engine_->AddOpAttr<PTuple<int>>(op_name, "paddings", paddings);
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
engine_->AddOpAttr<PTuple<int>>(op_name, "window_size", kernels);
engine_->AddOpAttr<PTuple<int>>(op_name, "dilations", dilations);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "paddings", paddings);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "window_size",
kernels);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dilations",
dilations);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Im2SequenceConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Im2SequenceConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Im2SequenceConverter() = default;
......
......@@ -32,10 +32,10 @@ namespace paddle {
namespace inference {
namespace anakin {
using AnakinNvEngine =
AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
template <typename TargetT, ::anakin::Precision PrecisionT>
class AnakinOpConverter {
using AnakinEngineT = AnakinEngine<TargetT, PrecisionT>;
public:
AnakinOpConverter() = default;
......@@ -45,7 +45,7 @@ class AnakinOpConverter {
void ConvertOp(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine,
const framework::Scope &scope, AnakinEngineT *engine,
bool test_mode = false) {
framework::OpDesc op_desc(op, nullptr);
std::string op_type = op_desc.Type();
......@@ -65,7 +65,7 @@ class AnakinOpConverter {
void ConvertBlock(framework::BlockDesc *block_desc,
const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine) {
const framework::Scope &scope, AnakinEngineT *engine) {
std::unique_lock<std::mutex> lock(mutex_);
framework::proto::BlockDesc *block = block_desc->Proto();
for (auto i = 0; i < block->ops_size(); i++) {
......@@ -79,9 +79,8 @@ class AnakinOpConverter {
framework::BlockDesc *block_desc, framework::Scope *scope,
const std::vector<std::string> &inputs,
const std::unordered_set<std::string> &parameters,
const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
const std::vector<std::string> &outputs, AnakinEngineT *engine) {
ConvertBlock(block_desc, parameters, *scope, engine);
engine->Freeze();
// if the max_batch size
int max_batch_size = engine->GetMaxBatchSize();
PADDLE_ENFORCE(max_batch_size > 0,
......@@ -91,6 +90,18 @@ class AnakinOpConverter {
// the block_desc.
auto max_input_shape = engine->GetMaxInputShape();
std::map<std::string, std::vector<int>> temp_max_input_shape;
// Register outputs with anakin using the RegistVar interface before Freeze.
// Note that RegistVar's parameters can only be outputs, not inputs.
for (auto &output : outputs) {
engine->Graph()->RegistVar(output);
}
engine->Freeze();
// Add scale for tensor in int8 mode.
auto tensor_scales = engine->GetTensorScales();
for (auto &item : tensor_scales) {
engine->Graph()->SetVarScale(item.first, item.second);
}
for (auto &input : inputs) {
if (parameters.count(input)) continue;
......@@ -99,7 +110,7 @@ class AnakinOpConverter {
input_shape[0] = max_batch_size;
if (max_input_shape.count(input)) {
PADDLE_ENFORCE(max_input_shape[input].size() == 4,
"the dimensions of max_input_shape setted from "
"the dimensions of max_input_shape setted from "
"config->EnableAnakinEngine must be 4");
for (int i = 1; i < 4; i++) {
input_shape[i] = max_input_shape[input][i];
......@@ -118,50 +129,104 @@ class AnakinOpConverter {
}
temp_max_input_shape[input] = input_shape;
engine->SetInputShape(input, input_shape);
engine->Graph()->RegistVar(input); // For share from data.
}
engine->SetMaxInputShape(temp_max_input_shape);
engine->Optimize();
// For anakin share with fluid tensor.
engine->AllocTmpMem();
engine->InitGraph();
engine->InitNet();
}
void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
void SetEngine(AnakinEngineT *engine) { engine_ = engine; }
virtual ~AnakinOpConverter() {}
protected:
bool test_mode_;
AnakinNvEngine *engine_{nullptr};
AnakinEngineT *engine_{nullptr};
private:
std::unordered_map<std::string, AnakinOpConverter *> converters_;
std::unordered_map<std::string, AnakinOpConverter<TargetT, PrecisionT> *>
converters_;
framework::Scope *scope_{nullptr};
std::mutex mutex_;
};
template class AnakinOpConverter<::anakin::saber::NV,
::anakin::Precision::FP32>;
template class AnakinOpConverter<::anakin::saber::NV,
::anakin::Precision::INT8>;
template class AnakinOpConverter<::anakin::saber::X86,
::anakin::Precision::FP32>;
template class AnakinOpConverter<::anakin::saber::X86,
::anakin::Precision::INT8>;
} // namespace anakin
} // namespace inference
} // namespace paddle
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
struct anakin_##op_type__##_converter \
: public ::paddle::framework::Registrar { \
anakin_##op_type__##_converter() { \
LOG(INFO) << "register convert " << #op_type__; \
::paddle::inference::Registry< \
::paddle::inference::anakin::AnakinOpConverter>::Global() \
.Register<::paddle::inference::anakin::Converter__>(#op_type__); \
} \
}; \
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
int TouchConverterRegister_anakin_##op_type__() { \
anakin_##op_type__##_converter__.Touch(); \
return 0; \
#define REGISTER_ANAKIN_OP_CONVERTER_BASE(op_type__, Converter__, \
place_type__, place_class__, \
precision_type__, precision_class__) \
struct anakin_##op_type__##_##place_type__##_##precision_type__##_converter \
: public ::paddle::framework::Registrar { \
anakin_##op_type__##_##place_type__##_##precision_type__##_converter() { \
LOG(INFO) << "register convert " << #op_type__ << " "; \
::paddle::inference::Registry< \
::paddle::inference::anakin::AnakinOpConverter< \
place_class__, precision_class__>>::Global() \
.Register<Converter__>(#op_type__); \
} \
}; \
anakin_##op_type__##_##place_type__##_##precision_type__##_converter \
anakin_##op_type__##_##place_type__##_##precision_type__##_converter__; \
int Touch_anakin_##op_type__##_##place_type__##_##precision_type__() { \
anakin_##op_type__##_##place_type__##_##precision_type__##_converter__ \
.Touch(); \
return 0; \
}
#define USE_ANAKIN_CONVERTER(op_type__) \
extern int TouchConverterRegister_anakin_##op_type__(); \
int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_anakin_##op_type__();
#define WRAP(...) __VA_ARGS__
#define REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__, \
precision_type__) \
REGISTER_ANAKIN_OP_CONVERTER_BASE( \
op_type__, \
::paddle::inference::anakin::Converter__<WRAP( \
::anakin::saber::NV, ::anakin::Precision::precision_type__)>, \
CUDA, ::anakin::saber::NV, precision_type__, \
::anakin::Precision::precision_type__)
#define REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, \
precision_type__) \
REGISTER_ANAKIN_OP_CONVERTER_BASE( \
op_type__, \
::paddle::inference::anakin::Converter__<WRAP( \
::anakin::saber::X86, ::anakin::Precision::precision_type__)>, \
CPU, ::anakin::saber::X86, precision_type__, \
::anakin::Precision::precision_type__)
#ifdef PADDLE_WITH_CUDA
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32); \
REGISTER_CUDA_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8); \
REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32); \
REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8)
#else
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, FP32); \
REGISTER_CPU_ANAKIN_OP_CONVERTER(op_type__, Converter__, INT8)
#endif
#define USE_ANAKIN_CONVERTER_BASE(op_type__, place_type__, precision_type__) \
extern int Touch_anakin_##op_type__##_##place_type__##_##precision_type__(); \
int use_converter_anakin_##op_type__##_##place_type__##_##precision_type__ \
__attribute__((unused)) = \
Touch_anakin_##op_type__##_##place_type__##_##precision_type__();
#define USE_ANAKIN_CONVERTER(op_type__) \
USE_ANAKIN_CONVERTER_BASE(op_type__, CUDA, FP32)
#define USE_INT8_ANAKIN_CONVERTER(op_type__) \
USE_ANAKIN_CONVERTER_BASE(op_type__, CUDA, INT8)
#define USE_CPU_ANAKIN_CONVERTER(op_type__) \
USE_ANAKIN_CONVERTER_BASE(op_type__, CPU, FP32)
#define USE_CPU_INT8_ANAKIN_CONVERTER(op_type__) \
USE_ANAKIN_CONVERTER_BASE(op_type__, CPU, INT8)
......@@ -17,23 +17,16 @@
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void Pool2dOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -65,13 +58,13 @@ void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
PADDLE_THROW("TensorRT unsupported pooling type!");
}
engine_->AddOp(op_name, "Pooling", {x_name}, {y_name});
engine_->AddOpAttr<PTuple<int>>(op_name, "pool_size", ksize);
engine_->AddOpAttr<PTuple<int>>(op_name, "strides", strides);
engine_->AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
engine_->AddOpAttr(op_name, "method", anakin_pool_type);
engine_->AddOpAttr(op_name, "global_pooling", global_pooling);
engine_->AddOpAttr(op_name, "cmp_out_shape_floor_as_conv", !ceil_mode);
this->engine_->AddOp(op_name, "Pooling", {x_name}, {y_name});
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "pool_size", ksize);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "strides", strides);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "padding", paddings);
this->engine_->AddOpAttr(op_name, "method", anakin_pool_type);
this->engine_->AddOpAttr(op_name, "global_pooling", global_pooling);
this->engine_->AddOpAttr(op_name, "cmp_out_shape_floor_as_conv", !ceil_mode);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class Pool2dOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class Pool2dOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
Pool2dOpConverter() = default;
......
......@@ -16,19 +16,30 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ReluOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
this->engine_->AddOp(op_name, "ReLU", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "alpha", 0);
}
template <typename TargetT, ::anakin::Precision PrecisionT>
void LeakyReluOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -37,8 +48,9 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
engine_->AddOp(op_name, "ReLU", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "alpha", 0);
float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
this->engine_->AddOp(op_name, "ReLU", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "alpha", alpha);
}
} // namespace anakin
......@@ -46,3 +58,4 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(relu, ReluOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(leaky_relu, LeakyReluOpConverter);
......@@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ReluOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ReluOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ReluOpConverter() = default;
......@@ -33,6 +34,18 @@ class ReluOpConverter : public AnakinOpConverter {
virtual ~ReluOpConverter() {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class LeakyReluOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
LeakyReluOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~LeakyReluOpConverter() {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -15,20 +15,16 @@
#include "paddle/fluid/inference/anakin/convert/reshape.h"
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ReshapeOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1UL);
......@@ -37,13 +33,13 @@ void ReshapeOpConverter::operator()(const framework::proto::OpDesc &op,
auto output = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Reshape", {input}, {output});
this->engine_->AddOp(op_name, "Reshape", {input}, {output});
auto shape = boost::get<std::vector<int>>(op_desc.GetAttr("shape"));
if (shape.size() < 4) {
shape.insert(shape.end(), 4 - shape.size(), 1);
}
engine_->AddOpAttr<PTuple<int>>(op_name, "dims", shape);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dims", shape);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ReshapeOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ReshapeOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ReshapeOpConverter() = default;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/roi_align.h"
#include <algorithm>
#include <map>
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
void RoiAlignOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("ROIs").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_x_name = op_desc.Input("X").front();
auto input_rois_name = op_desc.Input("ROIs").front();
auto output_name = op_desc.Output("Out").front();
auto spatial_scale = boost::get<float>(op_desc.GetAttr("spatial_scale"));
auto pooled_height = boost::get<int>(op_desc.GetAttr("pooled_height"));
auto pooled_width = boost::get<int>(op_desc.GetAttr("pooled_width"));
auto sampling_ratio = boost::get<int>(op_desc.GetAttr("sampling_ratio"));
this->engine_->AddOp(op_name, "RoiAlign", {input_x_name, input_rois_name},
{output_name});
this->engine_->AddOpAttr(op_name, "spatial_scale", spatial_scale);
this->engine_->AddOpAttr(op_name, "pooled_height", pooled_height);
this->engine_->AddOpAttr(op_name, "pooled_width", pooled_width);
this->engine_->AddOpAttr(op_name, "sampling_ratio", sampling_ratio);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(roi_align, RoiAlignOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
class RoiAlignOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
RoiAlignOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~RoiAlignOpConverter() {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -16,19 +16,14 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
void ScaleOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ScaleOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -44,10 +39,10 @@ void ScaleOpConverter::operator()(const framework::proto::OpDesc &op,
PADDLE_ENFORCE(bias_after_scale,
"The anakin scale layer only support bias after scale now.");
engine_->AddOp(op_name, "Power", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "shift", bias);
engine_->AddOpAttr(op_name, "scale", scale);
engine_->AddOpAttr(op_name, "power", static_cast<float>(1.0));
this->engine_->AddOp(op_name, "Power", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "shift", bias);
this->engine_->AddOpAttr(op_name, "scale", scale);
this->engine_->AddOpAttr(op_name, "power", static_cast<float>(1.0));
}
} // namespace anakin
......
......@@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ScaleOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ScaleOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ScaleOpConverter() = default;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h"
#include <algorithm>
#include <string>
#include <vector>
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ShuffleChannelOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto input = op_desc.Input("X").front();
auto output = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
this->engine_->AddOp(op_name, "ShuffleChannel", {input}, {output});
auto group = boost::get<int>(op_desc.GetAttr("group"));
this->engine_->AddOpAttr(op_name, "group", group);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ShuffleChannelOpConverter
: public AnakinOpConverter<TargetT, PrecisionT> {
public:
ShuffleChannelOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~ShuffleChannelOpConverter() {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -14,19 +14,14 @@
#include "paddle/fluid/inference/anakin/convert/softmax.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void SoftMaxOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1UL);
......@@ -41,8 +36,8 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
auto input_shape_in_fluid = input_var_desc->GetShape();
size_t input_dims = input_shape_in_fluid.size();
engine_->AddOp(op_name, "Softmax", {input}, {output});
engine_->AddOpAttr(op_name, "axis", static_cast<int>(input_dims - 1));
this->engine_->AddOp(op_name, "Softmax", {input}, {output});
this->engine_->AddOpAttr(op_name, "axis", static_cast<int>(input_dims - 1));
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class SoftMaxOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SoftMaxOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
SoftMaxOpConverter() = default;
......
......@@ -16,23 +16,16 @@
#include <algorithm>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void SplitOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void SplitOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("X").front();
auto y_names = op_desc.Output("Out");
......@@ -51,14 +44,16 @@ void SplitOpConverter::operator()(const framework::proto::OpDesc &op,
num_sum += output_lengths[i];
slice_point.push_back(num_sum);
}
engine_->AddOp(op_name, "Slice", {input_name}, y_names);
engine_->AddOpAttr(op_name, "axis", axis);
engine_->AddOpAttr<PTuple<int>>(op_name, "slice_point", slice_point);
this->engine_->AddOp(op_name, "Slice", {input_name}, y_names);
this->engine_->AddOpAttr(op_name, "axis", axis);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "slice_point",
slice_point);
// slice_dim is useless in anakin
engine_->AddOpAttr(op_name, "slice_dim", 4);
this->engine_->AddOpAttr(op_name, "slice_dim", 4);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
REGISTER_ANAKIN_OP_CONVERTER(split, SplitOpConverter);
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class SplitOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SplitOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
SplitOpConverter() = default;
......
......@@ -17,22 +17,16 @@
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void SumOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void SumOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 2);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -43,9 +37,10 @@ void SumOpConverter::operator()(const framework::proto::OpDesc &op,
std::vector<float> coeff = {1, 1};
std::string elementwise_type = "Add";
engine_->AddOp(op_name, "Eltwise", input_names, {out_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
engine_->AddOpAttr<std::string>(op_name, "type", elementwise_type);
this->engine_->AddOp(op_name, "Eltwise", input_names, {out_name});
this->engine_->template AddOpAttr<PTuple<float>>(op_name, "coeff", coeff);
this->engine_->template AddOpAttr<std::string>(op_name, "type",
elementwise_type);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class SumOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SumOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
SumOpConverter() = default;
......
......@@ -21,12 +21,14 @@ namespace paddle {
namespace inference {
namespace anakin {
static void test_activation_op(const std::string &op_type) {
auto *converter = Registry<AnakinOpConverter>::Global().Lookup(op_type);
PADDLE_ENFORCE(converter != nullptr);
template <typename TargetT>
static void test_activation_op(const std::string& op_type,
const platform::DeviceContext& context,
bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("act-X", {10, 6, 1, 1});
validator.DeclOutputVar("act-Out", {10, 6, 1, 1});
framework::OpDesc desc;
......@@ -34,6 +36,14 @@ static void test_activation_op(const std::string &op_type) {
desc.SetInput("X", {"act-X"});
desc.SetOutput("Out", {"act-Out"});
if (op_type == "swish") {
desc.SetAttr("beta", 1.0f);
}
if (op_type == "relu6") {
desc.SetAttr("threshold", 6.0f);
}
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
......@@ -41,13 +51,74 @@ static void test_activation_op(const std::string &op_type) {
validator.Execute(5);
}
TEST(sigm_op, test) { test_activation_op("sigmoid"); }
TEST(tanh_op, test) { test_activation_op("tanh"); }
#ifdef PADDLE_WITH_CUDA
TEST(sigm_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("sigmoid", ctx, true);
}
TEST(tanh_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("tanh", ctx, true);
}
TEST(relu6_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("relu6", ctx, true);
}
TEST(swish_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("swish", ctx, true);
}
#endif
/*
TEST(sigm_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_activation_op<::anakin::saber::X86>("sigmoid", ctx, false);
}
TEST(tanh_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_activation_op<::anakin::saber::X86>("tanh", ctx, false);
}
TEST(relu6_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_activation_op<::anakin::saber::X86>("relu6", ctx, false);
}
TEST(swish_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_activation_op<::anakin::saber::X86>("swish", ctx, false);
}
*/
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(sigmoid);
USE_OP(tanh);
USE_OP(relu6);
USE_OP(swish);
USE_CPU_ANAKIN_CONVERTER(sigmoid);
USE_CPU_ANAKIN_CONVERTER(tanh);
USE_CPU_ANAKIN_CONVERTER(relu6);
USE_CPU_ANAKIN_CONVERTER(swish);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(sigmoid);
USE_ANAKIN_CONVERTER(tanh);
USE_ANAKIN_CONVERTER(relu6);
USE_ANAKIN_CONVERTER(swish);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT>
void test_affine_channel_op(const platform::DeviceContext& context,
bool use_gpu) {
// Declare the difference between the inputs.
std::unordered_set<std::string> parameters({"scale", "bias"});
framework::Scope scope;
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("x", {1, 3, 5, 2});
validator.DeclOutputVar("out", {1, 3, 5, 2});
validator.DeclParamVar("scale", {3});
validator.DeclParamVar("bias", {3});
// Prepare Op descriptions.
framework::OpDesc desc;
desc.SetType("affine_channel");
desc.SetInput("X", {"x"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Scale", {"scale"});
desc.SetOutput("Out", {"out"});
// Layout must be explicitly specified here as NCHW.
desc.SetAttr("data_layout", std::string("NCHW"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
#ifdef PADDLE_WITH_CUDA
TEST(affine_channel_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_affine_channel_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(affine_channel_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_affine_channel_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(affine_channel);
USE_CPU_ANAKIN_CONVERTER(affine_channel);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(affine_channel);
#endif
......@@ -19,12 +19,14 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(batch_norm_op, test) {
template <typename TargetT>
void test_batchnorm_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters(
{"batch_norm_scale", "batch_norm_bias", "batch_norm_mean",
"batch_norm_variance"});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
std::vector<int> param_shape{2};
validator.DeclInputVar("batch_norm_X", {1, 2, 5, 5});
......@@ -64,8 +66,26 @@ TEST(batch_norm_op, test) {
validator.Execute(1, neglected_output);
}
#ifdef PADDLE_WITH_CUDA
TEST(batch_norm_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_batchnorm_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(batch_norm_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_batchnorm_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(batch_norm);
USE_CPU_ANAKIN_CONVERTER(batch_norm);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(batch_norm);
#endif
......@@ -21,10 +21,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(concat_op, test) {
template <typename TargetT>
void test_concat_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("concat_x1", {1, 2, 1, 1});
validator.DeclInputVar("concat_x2", {1, 3, 1, 1});
validator.DeclInputVar("concat_x3", {1, 1, 1, 1});
......@@ -44,31 +46,26 @@ TEST(concat_op, test) {
validator.Execute(1);
}
TEST(concat_op, test2) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
validator.DeclInputVar("concat_x1", {1, 4});
validator.DeclInputVar("concat_x2", {3, 4});
validator.DeclInputVar("concat_x3", {2, 4});
validator.DeclOutputVar("concat_out", {6, 4});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("concat");
desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"});
desc.SetOutput("Out", {"concat_out"});
int axis = 0;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
#ifdef PADDLE_WITH_CUDA
TEST(concat_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_concat_op<::anakin::saber::NV>(ctx, true);
}
#endif
validator.Execute(1);
TEST(concat_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_concat_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(concat);
USE_CPU_ANAKIN_CONVERTER(concat);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(concat);
#endif
......@@ -21,13 +21,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(conv2d_op, test) {
auto* conv2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("conv2d");
ASSERT_TRUE(conv2d_converter != nullptr);
template <typename TargetT>
void test_conv2d_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters({"conv2d-Y"});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("conv2d-X", {1, 3, 3, 3});
validator.DeclParamVar("conv2d-Y", {4, 3, 1, 1});
validator.DeclOutputVar("conv2d-Out", {1, 4, 3, 3});
......@@ -54,9 +53,27 @@ TEST(conv2d_op, test) {
validator.Execute(3);
}
#ifdef PADDLE_WITH_CUDA
TEST(conv2d_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_conv2d_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(conv2d_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_conv2d_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(conv2d);
USE_CPU_ANAKIN_CONVERTER(conv2d);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(conv2d);
#endif
......@@ -21,10 +21,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(dropout_op, native) {
template <typename TargetT>
void test_dropout_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("x", {1, 1, 2, 2});
validator.DeclOutputVar("out", {1, 1, 2, 2});
validator.DeclOutputVar("mask", {1, 1, 2, 2});
......@@ -45,9 +47,26 @@ TEST(dropout_op, native) {
validator.Execute(1, neglected_output);
}
#ifdef PADDLE_WITH_CUDA
TEST(dropout_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_dropout_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(dropout_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_dropout_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(dropout);
USE_CPU_ANAKIN_CONVERTER(dropout);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(dropout);
#endif
......@@ -21,10 +21,14 @@ namespace paddle {
namespace inference {
namespace anakin {
static void test_elementwise_op(const std::string &op_type) {
template <typename TargetT>
static void test_elementwise_op(const std::string& op_type,
const platform::DeviceContext& context,
bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("x", {1, 1, 2, 2});
validator.DeclInputVar("y", {1, 1, 2, 2});
validator.DeclOutputVar("out", {1, 1, 2, 2});
......@@ -43,14 +47,41 @@ static void test_elementwise_op(const std::string &op_type) {
validator.Execute(1);
}
TEST(elementwise_op, native_add) { test_elementwise_op("elementwise_add"); }
TEST(elementwise_op, native_mul) { test_elementwise_op("elementwise_mul"); }
#ifdef PADDLE_WITH_CUDA
TEST(elementwise_op, native_add_gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_elementwise_op<::anakin::saber::NV>("elementwise_add", ctx, true);
}
TEST(elementwise_op, native_mul_gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_elementwise_op<::anakin::saber::NV>("elementwise_mul", ctx, true);
}
#endif
TEST(elementwise_op, native_add_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_elementwise_op<::anakin::saber::X86>("elementwise_add", ctx, false);
}
TEST(elementwise_op, native_mul_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_elementwise_op<::anakin::saber::X86>("elementwise_mul", ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(elementwise_add);
USE_ANAKIN_CONVERTER(elementwise_add);
USE_OP(elementwise_mul);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(elementwise_add);
USE_ANAKIN_CONVERTER(elementwise_mul);
#endif
USE_CPU_ANAKIN_CONVERTER(elementwise_add);
USE_CPU_ANAKIN_CONVERTER(elementwise_mul);
......@@ -20,13 +20,13 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(fc_op, test) {
auto* fc_converter = Registry<AnakinOpConverter>::Global().Lookup("fc");
ASSERT_TRUE(fc_converter);
template <typename TargetT>
void test_mul_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters({"mul_y"});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("mul_x", {1, 1, 2, 2});
validator.DeclParamVar("mul_y", {4, 2});
validator.DeclOutputVar("mul_out", {1, 2});
......@@ -42,9 +42,26 @@ TEST(fc_op, test) {
validator.Execute(10);
}
#ifdef PADDLE_WITH_CUDA
TEST(mul_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_mul_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(mul_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_mul_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(mul);
USE_CPU_ANAKIN_CONVERTER(fc);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(fc);
#endif
......@@ -20,13 +20,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(flatten_op, test) {
auto *converter = Registry<AnakinOpConverter>::Global().Lookup("flatten");
ASSERT_TRUE(converter);
template <typename TargetT>
void test_flatten_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("flatten-X", {3, 10, 10, 4});
validator.DeclOutputVar("flatten-Out", {3, 400, 1, 1});
framework::OpDesc desc;
......@@ -42,10 +41,27 @@ TEST(flatten_op, test) {
validator.Execute(5);
}
#ifdef PADDLE_WITH_CUDA
TEST(flatten_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_flatten_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(flatten_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_flatten_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(reshape);
USE_OP_ITSELF(flatten);
USE_CPU_ANAKIN_CONVERTER(flatten);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(flatten);
#endif
......@@ -19,15 +19,14 @@ namespace paddle {
namespace inference {
namespace anakin {
void test_pool2d(bool global_pooling, bool ceil_mode,
template <typename TargetT>
void test_pool2d(const platform::DeviceContext& context, bool use_gpu,
bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
auto* pool2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("pool2d");
ASSERT_TRUE(pool2d_converter);
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
......@@ -64,56 +63,61 @@ void test_pool2d(bool global_pooling, bool ceil_mode,
validator.Execute(1);
}
void test_pool2d2(bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
auto* pool2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("pool2d");
ASSERT_TRUE(pool2d_converter);
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, &scope);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d_x", {1, 1, 17, 17});
validator.DeclOutputVar("pool2d_out", {1, 1, 17, 17});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pool2d");
desc.SetInput("X", {"pool2d_x"});
desc.SetOutput("Out", {"pool2d_out"});
std::vector<int> ksize({3, 3});
std::vector<int> strides({1, 1});
std::vector<int> paddings({1, 1});
std::string pooling_t = pool_type;
#ifdef PADDLE_WITH_CUDA
TEST(Pool2dOpConverter, normal) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_pool2d<::anakin::saber::NV>(ctx, true, false, false);
}
TEST(Pool2dOpConverter, test_global_pooling) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_pool2d<::anakin::saber::NV>(ctx, true, true, false);
}
desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", true);
TEST(Pool2dOpConverter, max_ceil_test) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_pool2d<::anakin::saber::NV>(ctx, true, false, true);
}
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
TEST(Pool2dOpConverter, avg_ceil_test) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_pool2d<::anakin::saber::NV>(ctx, true, false, true, "avg");
}
#endif
validator.Execute(1);
TEST(Pool2dOpConverter, normal_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_pool2d<::anakin::saber::X86>(ctx, false, false, false);
}
TEST(Pool2dOpConverter, test_global_pooling_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_pool2d<::anakin::saber::X86>(ctx, false, true, false);
}
TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, max_ceil_test_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_pool2d<::anakin::saber::X86>(ctx, false, false, true);
}
TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); }
TEST(Pool2dOpConverter, avg_ceil_test2) { test_pool2d2(false, true, "avg"); }
TEST(Pool2dOpConverter, avg_ceil_test_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_pool2d<::anakin::saber::X86>(ctx, false, false, true, "avg");
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(pool2d);
USE_CPU_ANAKIN_CONVERTER(pool2d);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(pool2d);
#endif
......@@ -21,18 +21,23 @@ namespace paddle {
namespace inference {
namespace anakin {
static void test_activation_op(const std::string &op_type) {
auto *converter = Registry<AnakinOpConverter>::Global().Lookup(op_type);
PADDLE_ENFORCE(converter != nullptr);
template <typename TargetT>
static void test_activation_op(const std::string& op_type,
const platform::DeviceContext& context,
bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("act-X", {10, 6, 1, 1});
validator.DeclOutputVar("act-Out", {10, 6, 1, 1});
framework::OpDesc desc;
desc.SetType(op_type);
desc.SetInput("X", {"act-X"});
desc.SetOutput("Out", {"act-Out"});
if (op_type == "leaky_relu") {
desc.SetAttr("alpha", 0.1f);
}
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
......@@ -41,10 +46,30 @@ static void test_activation_op(const std::string &op_type) {
validator.Execute(5);
}
TEST(sigm_op, test) { test_activation_op("relu"); }
#ifdef PADDLE_WITH_CUDA
TEST(relu_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("relu", ctx, true);
}
TEST(leaky_relu_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_activation_op<::anakin::saber::NV>("leaky_relu", ctx, true);
}
#endif
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(relu);
USE_OP(leaky_relu);
USE_CPU_ANAKIN_CONVERTER(relu);
USE_CPU_ANAKIN_CONVERTER(leaky_relu);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(relu);
USE_ANAKIN_CONVERTER(leaky_relu);
#endif
......@@ -20,12 +20,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(reshape, test) {
auto* converter = Registry<AnakinOpConverter>::Global().Lookup("reshape");
ASSERT_TRUE(converter);
template <typename TargetT>
void test_reshape1_op(const platform::DeviceContext& context, bool use_gpu) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
// validator.DeclInputVar("reshape-X", {2, 3, 3, 1});
// validator.DeclOutputVar("reshape-Out", {3, 2, 1, 3});
......@@ -45,10 +45,12 @@ TEST(reshape, test) {
validator.Execute(1);
}
TEST(reshape, test2) {
template <typename TargetT>
void test_reshape2_op(const platform::DeviceContext& context, bool use_gpu) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("reshape-X", {1, 2, 4});
validator.DeclOutputVar("reshape-Out", {1, 4, 2});
......@@ -66,9 +68,39 @@ TEST(reshape, test2) {
validator.Execute(1);
}
#ifdef PADDLE_WITH_CUDA
TEST(reshape1_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_reshape1_op<::anakin::saber::NV>(ctx, true);
}
TEST(reshape2_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_reshape2_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(reshape1_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_reshape2_op<::anakin::saber::X86>(ctx, false);
}
TEST(reshape2_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_reshape2_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(reshape);
USE_CPU_ANAKIN_CONVERTER(reshape);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(reshape);
#endif
......@@ -20,12 +20,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(softmax, test) {
auto* converter = Registry<AnakinOpConverter>::Global().Lookup("softmax");
ASSERT_TRUE(converter);
template <typename TargetT>
void test_softmax_op(const platform::DeviceContext& context, bool use_gpu) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("softmax-X", {1, 10, 2});
validator.DeclOutputVar("softmax-Out", {1, 10, 2});
......@@ -41,9 +41,27 @@ TEST(softmax, test) {
validator.Execute(1);
}
#ifdef PADDLE_WITH_CUDA
TEST(softmax_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_softmax_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(relu_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_softmax_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(softmax);
USE_CPU_ANAKIN_CONVERTER(softmax);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(softmax);
#endif
......@@ -21,12 +21,14 @@ namespace paddle {
namespace inference {
namespace anakin {
template <int Axis>
void AnakinSliceTest(const std::vector<int> &in_shape,
template <typename TargetT, int Axis>
void AnakinSliceTest(const platform::DeviceContext &context, bool use_gpu,
const std::vector<int> &in_shape,
const std::vector<int> &sections) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("split_input", in_shape);
std::vector<std::string> output_vars;
......@@ -55,51 +57,58 @@ void AnakinSliceTest(const std::vector<int> &in_shape,
// batch = 0, axis = 1, same shape
TEST(split_op, test_same_shape_axis1_batch1) {
AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 1>(ctx, true, {1, 4, 2, 2}, {2, 2});
}
// batch = 0, axis = 1, different shape
TEST(split_op, test_different_shape_axis1_batch1) {
AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1});
}
// batch = 10, axis = 1, same shape
TEST(split_op, test_same_shape_axis1_batch10) {
AnakinSliceTest<1>({1, 4, 2, 2}, {2, 2});
}
// batch = 10, axis = 1, different shape
TEST(split_op, test_different_shape_axis1_batch10) {
AnakinSliceTest<1>({1, 3, 2, 2}, {2, 1});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 1>(ctx, true, {1, 3, 2, 2}, {2, 1});
}
// batch = 0, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch1) {
AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 2>(ctx, true, {1, 3, 4, 2}, {2, 2});
}
// batch = 0, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch1) {
AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1});
}
// batch = 10, axis = 2, same shape
TEST(split_op, test_same_shape_axis2_batch10) {
AnakinSliceTest<2>({1, 3, 4, 2}, {2, 2});
}
// batch = 10, axis = 2, different shape
TEST(split_op, test_different_shape_axis2_batch10) {
AnakinSliceTest<2>({1, 3, 3, 2}, {2, 1});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 2>(ctx, true, {1, 3, 3, 2}, {2, 1});
}
// batch = 0, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch1) {
AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 3>(ctx, true, {1, 3, 2, 4}, {2, 2});
}
// batch = 0, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch1) {
AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1});
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
AnakinSliceTest<::anakin::saber::NV, 3>(ctx, true, {1, 3, 2, 3}, {2, 1});
}
// batch = 10, axis = 3, same shape
TEST(split_op, test_same_shape_axis3_batch10) {
AnakinSliceTest<3>({1, 3, 2, 4}, {2, 2});
TEST(split_op, test_different_shape_axis1_batch1_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
AnakinSliceTest<::anakin::saber::X86, 1>(ctx, false, {1, 3, 2, 3}, {2, 1});
}
TEST(split_op, test_different_shape_axis2_batch1_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
AnakinSliceTest<::anakin::saber::X86, 2>(ctx, false, {1, 3, 4, 2}, {2, 2});
}
// batch = 10, axis = 3, different shape
TEST(split_op, test_different_shape_axis3_batch10) {
AnakinSliceTest<3>({1, 3, 2, 3}, {2, 1});
TEST(split_op, test_different_shape_axis3_batch1_cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
AnakinSliceTest<::anakin::saber::X86, 3>(ctx, false, {1, 3, 2, 4}, {2, 2});
}
} // namespace anakin
......@@ -107,4 +116,7 @@ TEST(split_op, test_different_shape_axis3_batch10) {
} // namespace paddle
USE_OP(split);
USE_CPU_ANAKIN_CONVERTER(split);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(split);
#endif
......@@ -22,10 +22,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(sum, native) {
template <typename TargetT>
static void test_sum_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("sum_x1", {1, 2, 1, 2});
validator.DeclInputVar("sum_x2", {1, 2, 1, 2});
validator.DeclOutputVar("sum_out", {1, 2, 1, 2});
......@@ -40,9 +42,26 @@ TEST(sum, native) {
validator.Execute(1);
}
#ifdef PADDLE_WITH_CUDA
TEST(sum_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_sum_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(sum_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_sum_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(sum);
USE_CPU_ANAKIN_CONVERTER(sum);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(sum);
#endif
......@@ -20,12 +20,12 @@ namespace paddle {
namespace inference {
namespace anakin {
TEST(transpose_op, test) {
auto* converter = Registry<AnakinOpConverter>::Global().Lookup("transpose");
ASSERT_TRUE(converter != nullptr);
template <typename TargetT>
void test_transpose1_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("transpose-X", {2, 3, 4, 5});
validator.DeclOutputVar("transpose-Out", {4, 2, 5, 3});
......@@ -43,11 +43,12 @@ TEST(transpose_op, test) {
validator.Execute(3);
}
// test input shape's dims < 4
TEST(transpose_op, test2) {
template <typename TargetT>
void test_transpose2_op(const platform::DeviceContext& context, bool use_gpu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
AnakinConvertValidation validator(parameters, &scope);
AnakinConvertValidation<TargetT, ::anakin::Precision::FP32> validator(
parameters, &scope, context, use_gpu);
validator.DeclInputVar("transpose-X", {3, 4, 5});
validator.DeclOutputVar("transpose-Out", {3, 5, 4});
......@@ -65,9 +66,38 @@ TEST(transpose_op, test2) {
validator.Execute(1);
}
#ifdef PADDLE_WITH_CUDA
TEST(transpose1_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_transpose1_op<::anakin::saber::NV>(ctx, true);
}
TEST(transpose2_op, gpu) {
platform::CUDAPlace gpu_place(0);
platform::CUDADeviceContext ctx(gpu_place);
test_transpose2_op<::anakin::saber::NV>(ctx, true);
}
#endif
TEST(transpose1_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_transpose2_op<::anakin::saber::X86>(ctx, false);
}
TEST(transpose2_op, cpu) {
platform::CPUPlace cpu_place;
platform::CPUDeviceContext ctx(cpu_place);
test_transpose2_op<::anakin::saber::X86>(ctx, false);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(transpose);
USE_CPU_ANAKIN_CONVERTER(transpose);
#ifdef PADDLE_WITH_CUDA
USE_ANAKIN_CONVERTER(transpose);
#endif
......@@ -17,20 +17,16 @@
#include <string>
#include <vector>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void TransposeOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -38,7 +34,7 @@ void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
auto input = op_desc.Input("X").front();
auto output = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Permute", {input}, {output});
this->engine_->AddOp(op_name, "Permute", {input}, {output});
auto axis = boost::get<std::vector<int>>(op_desc.GetAttr("axis"));
size_t axis_size = axis.size();
......@@ -46,7 +42,7 @@ void TransposeOpConverter::operator()(const framework::proto::OpDesc &op,
axis.push_back(axis_size);
axis_size += 1;
}
engine_->AddOpAttr<PTuple<int>>(op_name, "dims", axis);
this->engine_->template AddOpAttr<PTuple<int>>(op_name, "dims", axis);
}
} // namespace anakin
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class TransposeOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class TransposeOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
TransposeOpConverter() = default;
......
......@@ -32,14 +32,8 @@ limitations under the License. */
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
......@@ -55,8 +49,8 @@ float random(float low, float high) {
return dist(mt);
}
void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
const platform::DeviceContext& ctx) {
void RandomizeTensor(framework::LoDTensor* tensor,
const platform::Place& place) {
auto dims = tensor->dims();
size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0);
......@@ -78,17 +72,19 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
* anakin
* layer.
*/
template <typename TargetT, ::anakin::Precision PrecisionT>
class AnakinConvertValidation {
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
using AnakinNvEngineT = AnakinEngine<TargetT, PrecisionT>;
public:
AnakinConvertValidation() = delete;
AnakinConvertValidation(const std::unordered_set<std::string>& parameters,
framework::Scope* scope)
: parameters_(parameters), scope_(scope), place_(0) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
framework::Scope* scope,
const platform::DeviceContext& ctx,
bool use_gpu = true)
: parameters_(parameters), scope_(scope), ctx_(ctx), use_gpu_(use_gpu) {
engine_.reset(new AnakinEngine<TargetT, PrecisionT>(true));
}
// Declare a Variable as input with random initialization.
......@@ -108,11 +104,10 @@ class AnakinConvertValidation {
}
void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
platform::CUDADeviceContext ctx(place_);
auto* x = scope_->Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place_, ctx);
RandomizeTensor(x_tensor, ctx_.GetPlace());
std::vector<int64_t> dim_vec_int64;
for (auto& ele : dim_vec) {
......@@ -132,7 +127,7 @@ class AnakinConvertValidation {
// should init anakin engine here.
auto& block_desc = program_desc_.Block(framework::kRootBlockIndex);
Singleton<AnakinOpConverter>::Global().ConvertOp(
Singleton<AnakinOpConverter<TargetT, PrecisionT>>::Global().ConvertOp(
desc, block_desc, parameters_, *scope_, engine_.get(),
true /*test_mode*/);
engine_->Freeze();
......@@ -151,7 +146,7 @@ class AnakinConvertValidation {
}
engine_->SetMaxInputShape(temp_max_input_shape);
engine_->Optimize();
engine_->InitGraph();
engine_->InitNet();
}
// We use the set 'neglected_output' here, because some Ops like batch norm,
......@@ -160,11 +155,8 @@ class AnakinConvertValidation {
void Execute(int batch_size,
std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op
platform::CUDADeviceContext ctx(place_);
op_->Run(*scope_, place_);
op_->Run(*scope_, ctx_.GetPlace());
// std::vector<framework::LoDTensor> input_vector;
// std::vector<framework::LoDTensor> output_vector;
std::map<std::string, framework::LoDTensor*> inputs;
for (const auto& input : op_desc_->InputArgumentNames()) {
if (parameters_.count(input)) continue;
......@@ -180,20 +172,27 @@ class AnakinConvertValidation {
std::vector<float> fluid_out;
auto* var = scope_->FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out);
framework::TensorToVector(*tensor, ctx_, &fluid_out);
fluid_outputs.push_back(fluid_out);
outputs.insert({output, tensor});
}
engine_->Execute(inputs, outputs, stream_);
if (!use_gpu_) {
engine_->Execute(inputs, outputs);
} else {
cudaStream_t stream;
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream), 0);
engine_->Execute(inputs, outputs, stream);
}
int i_output = 0;
for (const auto& output : op_desc_->OutputArgumentNames()) {
if (neglected_output.count(output)) continue;
std::vector<float> anakin_out;
auto* var = scope_->FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &anakin_out);
framework::TensorToVector(*tensor, ctx_, &anakin_out);
size_t anakin_out_size = anakin_out.size();
auto fluid_out = fluid_outputs[i_output++];
......@@ -205,15 +204,24 @@ class AnakinConvertValidation {
private:
std::unique_ptr<AnakinNvEngineT> engine_{nullptr};
cudaStream_t stream_;
std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_;
framework::ProgramDesc program_desc_;
const std::unordered_set<std::string>& parameters_;
framework::Scope* scope_;
platform::CUDAPlace place_;
const platform::DeviceContext& ctx_;
bool use_gpu_{true};
};
template class AnakinConvertValidation<::anakin::saber::NV,
::anakin::Precision::FP32>;
template class AnakinConvertValidation<::anakin::saber::X86,
::anakin::Precision::FP32>;
template class AnakinConvertValidation<::anakin::saber::NV,
::anakin::Precision::INT8>;
template class AnakinConvertValidation<::anakin::saber::X86,
::anakin::Precision::INT8>;
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -35,12 +35,15 @@ namespace anakin {
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(
bool need_summary, int device, int max_batch_size,
std::map<std::string, std::vector<int>> max_input_shape)
std::map<std::string, std::vector<int>> max_input_shape,
std::vector<std::string> program_inputs, bool auto_config_layout)
: graph_(new AnakinGraphT<TargetT, PrecisionType>()),
net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {
device_ = device;
max_batch_size_ = max_batch_size;
max_input_shape_ = max_input_shape;
program_inputs_ = program_inputs;
auto_config_layout_ = auto_config_layout;
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
......@@ -54,8 +57,8 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::SetInputShape(
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::InitGraph() {
net_->init(*graph_);
void AnakinEngine<TargetT, PrecisionType, RunType>::InitNet() {
net_->init(*graph_, auto_config_layout_);
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
......@@ -67,11 +70,11 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::AddOp(
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs,
cudaStream_t stream) {
void AnakinEngine<TargetT, PrecisionType, RunType>::BindInput(
const std::map<std::string, framework::LoDTensor *> &inputs) {
#ifdef PADDLE_WITH_CUDA
cudaDeviceSynchronize();
#endif
for (const auto &input : inputs) {
auto *tensor = input.second;
auto *data = tensor->data<float>();
......@@ -85,16 +88,53 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
int max_shape_sum =
std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1,
std::multiplies<int>());
PADDLE_ENFORCE(max_shape_sum >= tensor->numel(),
"The anakin input max shape should be greater than"
" or equal to the real input shape, Please set the max "
"input shape using EnableAnakinEngine");
if (tensor->numel() > max_shape_sum) {
PADDLE_ENFORCE(std::find(program_inputs_.begin(), program_inputs_.end(),
input.first) == program_inputs_.end(),
"The anakin input max shape should be greater than"
" or equal to the real input shape, Please set the max "
"input shape using EnableAnakinEngine");
VLOG(3) << "Anakin Net will be reset because of the inputs out of range: "
<< input.first;
graph_->Reshape(input.first, fluid_input_shape);
net_.reset(new AnakinNetT<TargetT, PrecisionType, RunType>(true));
net_->init(*graph_);
anakin_input = net_->get_in(input.first);
}
anakin_input->reshape(fluid_input_shape);
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
fluid_input_shape);
anakin_input->copy_from(tmp_anakin_tensor);
}
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs) {
BindInput(inputs);
net_->prediction();
for (const auto &output : outputs) {
platform::CPUPlace cpu_place;
auto *tensor = output.second;
auto *anakin_output = net_->get_out(output.first);
auto *anakin_data = anakin_output->data();
auto anakin_output_shape = anakin_output->valid_shape();
tensor->Resize(framework::make_ddim(anakin_output_shape));
auto *fluid_data = tensor->mutable_data<float>(cpu_place);
memory::Copy(cpu_place, static_cast<void *>(fluid_data), cpu_place,
static_cast<void *>(anakin_data),
tensor->numel() * sizeof(float));
}
}
#ifdef PADDLE_WITH_CUDA
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs,
cudaStream_t stream) {
BindInput(inputs);
net_->prediction();
cudaDeviceSynchronize();
for (const auto &output : outputs) {
......@@ -111,10 +151,11 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
}
cudaDeviceSynchronize();
}
#endif
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Freeze() {
PADDLE_ENFORCE(graph_->Freeze_v3(), "Freeze anakin subgraph.");
PADDLE_ENFORCE(graph_->Freeze(), "Freeze anakin subgraph.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
......@@ -122,6 +163,12 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Optimize() {
PADDLE_ENFORCE(graph_->Optimize(), "Graph optimization.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::RegistBlock(
::anakin::PBlock<TargetT> *block_p) {
PADDLE_ENFORCE(graph_->RegistBlock(block_p), "Block register.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
std::unique_ptr<AnakinEngine<TargetT, PrecisionType, RunType>>
AnakinEngine<TargetT, PrecisionType, RunType>::Clone() {
......@@ -130,7 +177,24 @@ AnakinEngine<TargetT, PrecisionType, RunType>::Clone() {
return std::unique_ptr<AnakinEngine>(engine);
}
#ifdef PADDLE_WITH_CUDA
template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
template class AnakinEngineManager<::anakin::saber::NV,
::anakin::Precision::FP32>;
template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::INT8>;
template class AnakinEngineManager<::anakin::saber::NV,
::anakin::Precision::INT8>;
#endif
template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::FP32>;
template class AnakinEngineManager<::anakin::saber::X86,
::anakin::Precision::FP32>;
template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::INT8>;
template class AnakinEngineManager<::anakin::saber::X86,
::anakin::Precision::INT8>;
// template class AnakinEngine<::anakin::saber::X86, ::anakin::Precision::FP32>;
} // namespace anakin
} // namespace inference
} // namespace paddle
......@@ -32,7 +32,6 @@
#include "saber/saber_types.h"
using anakin::Precision;
using anakin::saber::NV;
namespace anakin {
......@@ -58,9 +57,11 @@ class AnakinEngine {
public:
explicit AnakinEngine(
bool need_summary = false, int device = 0, int max_batch_size = 1,
std::map<std::string, std::vector<int>> max_input_shape = {});
std::map<std::string, std::vector<int>> max_input_shape = {},
std::vector<std::string> program_inputs = {},
bool auto_config_layout = false);
~AnakinEngine();
void InitGraph();
void InitNet();
void SetInputShape(const std::string &name, std::vector<int> shape);
void AddOp(const std::string &name, const std::string &type,
const std::vector<std::string> &inputs,
......@@ -81,20 +82,35 @@ class AnakinEngine {
void SetMaxInputShape(std::map<std::string, std::vector<int>> shape) {
max_input_shape_ = shape;
}
const std::vector<std::string> &GetScalableInputs() {
return program_inputs_;
}
void SetScalableInputs(std::vector<std::string> program_inputs) {
program_inputs_ = program_inputs;
}
int GetMaxBatchSize() { return max_batch_size_; }
void Freeze();
void Optimize();
void AllocTmpMem() {
PADDLE_ENFORCE(net_->alloc_memory_first(*graph_),
"anakin alloc temp memory first failed");
}
void RegistBlock(::anakin::PBlock<TargetT> *block_p);
void Save(std::string path) { graph_->save(path); }
bool IsInit() { return initialized_; }
int GetDevice() { return device_; }
void AddTensorScale(const std::string &tensor_name, float scale) {
tensor_scales_[tensor_name] = scale;
}
std::unordered_map<std::string, float> GetTensorScales() {
return tensor_scales_;
}
void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs);
#ifdef PADDLE_WITH_CUDA
void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs,
cudaStream_t stream);
#endif
private:
void BindInput(const std::map<std::string, framework::LoDTensor *> &inputs);
private:
bool initialized_{false};
......@@ -103,27 +119,33 @@ class AnakinEngine {
int device_;
std::unique_ptr<GraphT> graph_;
std::unique_ptr<NetT> net_;
std::vector<std::string> program_inputs_;
std::unordered_map<std::string, float> tensor_scales_;
// Always be false in gpu mode but true in most cpu cases.
bool auto_config_layout_;
};
template <typename TargetT, ::anakin::Precision PrecisionType>
class AnakinEngineManager {
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
using AnakinEngineT = AnakinEngine<TargetT, PrecisionType>;
public:
bool HasEngine(const std::string &name) const {
if (engines_.count(name) == 0) return false;
return engines_.at(name).get() != nullptr;
}
AnakinNvEngineT *Get(const std::string &name) const {
AnakinEngineT *Get(const std::string &name) const {
return engines_.at(name).get();
}
AnakinNvEngineT *Create(
bool need_summary, int device, int max_batch_size,
std::map<std::string, std::vector<int>> max_input_shape,
std::string engine_name) {
AnakinEngineT *Create(bool need_summary, int device, int max_batch_size,
std::map<std::string, std::vector<int>> max_input_shape,
std::vector<std::string> program_inputs,
bool auto_config_layout, std::string engine_name) {
std::unique_lock<std::mutex> lk(mut_);
auto *p = new AnakinEngine<NV, Precision::FP32>(
need_summary, device, max_batch_size, max_input_shape);
auto *p = new AnakinEngine<TargetT, PrecisionType>(
need_summary, device, max_batch_size, max_input_shape, program_inputs,
auto_config_layout);
engines_[engine_name].reset(p);
return p;
}
......@@ -135,7 +157,7 @@ class AnakinEngineManager {
}
private:
std::unordered_map<std::string, std::unique_ptr<AnakinNvEngineT>> engines_;
std::unordered_map<std::string, std::unique_ptr<AnakinEngineT>> engines_;
std::mutex mut_;
};
} // namespace anakin
......
......@@ -44,6 +44,11 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("sum");
teller_set.insert("depthwise_conv2d");
teller_set.insert("prior_box");
teller_set.insert("leaky_relu");
teller_set.insert("affine_channel");
teller_set.insert("relu6");
teller_set.insert("swish");
teller_set.insert("shuffle_channel");
}
bool operator()(const std::string& op_type,
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/fluid/inference/anakin/engine.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
......@@ -52,11 +51,9 @@ TEST_F(TestAnakinEngine, Execute) {
engine_->AddOpAttr("op1", "axis", 1);
std::vector<int> shape = {1, 1, 1, 2};
Shape tmp_shape(shape);
// PBlock<NV> weight1(tmp_shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(tmp_shape);
// auto *weight1 = new PBlock<NV>(tmp_shape, AK_FLOAT);
PBlock<NV> *weight1 = new PBlock<NV>(tmp_shape, AK_FLOAT);
engine_->RegistBlock(weight1);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
cpu_data[0] = 2.;
weight1->d_tensor().set_shape(tmp_shape);
......@@ -68,7 +65,7 @@ TEST_F(TestAnakinEngine, Execute) {
// engine_->AddOpAttr("x", "input_shape", input_shape);
engine_->SetInputShape("x", {1, 1, 1, 1});
engine_->Optimize();
engine_->InitGraph();
engine_->InitNet();
framework::LoDTensor x;
framework::LoDTensor y;
x.Resize({1, 1, 1, 1});
......
......@@ -64,20 +64,20 @@ struct Argument {
bool Has(const std::string& key) const { return valid_fields_.count(key); }
#define DECL_ARGUMENT_FIELD(field__, Field, type__) \
public: \
type__& field__() { \
PADDLE_ENFORCE(Has(#field__)); \
return field__##_; \
} \
void Set##Field(const type__& x) { \
field__##_ = x; \
valid_fields_.insert(#field__); \
} \
DECL_ARGUMENT_FIELD_VALID(field__); \
type__* field__##_ptr() { return &field__##_; } \
\
private: \
#define DECL_ARGUMENT_FIELD(field__, Field, type__) \
public: \
type__& field__() { \
PADDLE_ENFORCE(Has(#field__), "There is no such field"); \
return field__##_; \
} \
void Set##Field(const type__& x) { \
field__##_ = x; \
valid_fields_.insert(#field__); \
} \
DECL_ARGUMENT_FIELD_VALID(field__); \
type__* field__##_ptr() { return &field__##_; } \
\
private: \
type__ field__##_;
#define DECL_ARGUMENT_FIELD_VALID(field__) \
......@@ -169,7 +169,14 @@ struct Argument {
anakin_max_shape_t);
DECL_ARGUMENT_FIELD(anakin_max_batch_size, AnakinMaxBatchSize, int);
DECL_ARGUMENT_FIELD(anakin_min_subgraph_size, AnakinMinSubgraphSize, int);
DECL_ARGUMENT_FIELD(anakin_precision_mode, AnakinPrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(anakin_auto_config_layout, AnakinAutoConfigLayout, bool);
DECL_ARGUMENT_FIELD(use_anakin, UseAnakin, bool);
DECL_ARGUMENT_FIELD(anakin_passes_filter, AnakinPassesFilter,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(anakin_ops_filter, AnakinOpsFilter,
std::vector<std::string>);
// Memory optimized related.
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
......
......@@ -114,6 +114,7 @@ void IRPassManager::CreatePasses(Argument *argument,
if (pass_name == "anakin_subgraph_pass") {
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("use_gpu", new bool(argument->use_gpu()));
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
pass->Set("model_from_memory", new bool(argument->model_from_memory()));
pass->Set("engine_opt_info", new std::map<std::string, std::string>(
......@@ -122,6 +123,13 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->anakin_max_input_shape()));
pass->Set("max_batch_size", new int(argument->anakin_max_batch_size()));
bool enable_int8 =
argument->anakin_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("anakin_ops_filter",
new std::vector<std::string>(argument->anakin_ops_filter()));
pass->Set("auto_config_layout",
new bool(argument->anakin_auto_config_layout()));
}
pre_pass = pass_name;
......
......@@ -39,8 +39,14 @@ void analysis::AnakinSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph);
auto teller = [](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false;
auto &anakin_ops_filter = Get<std::vector<std::string>>("anakin_ops_filter");
auto teller = [&anakin_ops_filter](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op())
return false;
else if (std::find(anakin_ops_filter.begin(), anakin_ops_filter.end(),
node->Op()->Type()) != anakin_ops_filter.end())
return false;
return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
};
......@@ -191,22 +197,78 @@ void AnakinSubgraphPass::CreateAnakinOp(
SetAttr(op_desc->Proto(), "engine_key", engine_key);
auto max_input_shape =
Get<std::map<std::string, std::vector<int>>>("max_input_shape");
auto max_batch_size = Get<int>("max_batch_size");
auto program_inputs = program_desc->GetFeedTargetNames();
auto *anakin_engine =
inference::Singleton<anakin::AnakinEngineManager>::Global().Create(
true, Get<int>("gpu_device_id"), max_batch_size, max_input_shape,
engine_key);
bool use_gpu = Get<bool>("use_gpu");
SetAttr(op_desc->Proto(), "use_gpu", use_gpu);
bool enable_int8 = Get<bool>("enable_int8");
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
if (enable_int8) {
CreateAnakinEngine<::anakin::Precision::INT8>(&block_desc, params,
input_names, output_mapping,
program_inputs, engine_key);
} else {
CreateAnakinEngine<::anakin::Precision::FP32>(&block_desc, params,
input_names, output_mapping,
program_inputs, engine_key);
}
}
template <::anakin::Precision PrecisionT>
void AnakinSubgraphPass::CreateAnakinEngine(
framework::BlockDesc *block_desc, const std::vector<std::string> &params,
const std::set<std::string> &input_names,
const std::vector<std::string> &output_mapping,
const std::vector<std::string> &program_inputs,
const std::string &engine_key) const {
framework::BlockDesc block_desc_temp(nullptr, block_desc->Proto());
bool use_gpu = Get<bool>("use_gpu");
auto max_batch_size = Get<int>("max_batch_size");
auto max_input_shape =
Get<std::map<std::string, std::vector<int>>>("max_input_shape");
bool auto_config_layout = Get<bool>("auto_config_layout");
if (use_gpu) {
#ifdef PADDLE_WITH_CUDA
inference::Singleton<
anakin::AnakinEngineManager<::anakin::saber::NV, PrecisionT>>::Global()
.Create(true, Get<int>("gpu_device_id"), max_batch_size,
max_input_shape, program_inputs, false, engine_key);
#endif
} else {
inference::Singleton<
anakin::AnakinEngineManager<::anakin::saber::X86, PrecisionT>>::Global()
.Create(true, Get<int>("gpu_device_id"), max_batch_size,
max_input_shape, program_inputs, auto_config_layout,
engine_key);
}
auto *scope = param_scope();
std::unordered_set<std::string> param_set(params.begin(), params.end());
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
inference::Singleton<inference::anakin::AnakinOpConverter>::Global()
.ConvertBlockToAnakinEngine(
&block_desc_temp, scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, anakin_engine);
if (use_gpu) {
#ifdef PADDLE_WITH_CUDA
auto *anakin_engine =
inference::Singleton<inference::anakin::AnakinEngineManager<
::anakin::saber::NV, PrecisionT>>::Global()
.Get(engine_key);
inference::Singleton<inference::anakin::AnakinOpConverter<
::anakin::saber::NV, PrecisionT>>::Global()
.ConvertBlockToAnakinEngine(
&block_desc_temp, scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, anakin_engine);
#endif
} else {
auto *anakin_engine =
inference::Singleton<inference::anakin::AnakinEngineManager<
::anakin::saber::X86, PrecisionT>>::Global()
.Get(engine_key);
inference::Singleton<inference::anakin::AnakinOpConverter<
::anakin::saber::X86, PrecisionT>>::Global()
.ConvertBlockToAnakinEngine(
&block_desc_temp, scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, anakin_engine);
}
}
} // namespace analysis
......
......@@ -15,6 +15,7 @@
#pragma once
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/pass.h"
......@@ -36,6 +37,13 @@ class AnakinSubgraphPass : public framework::ir::FusePassBase {
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const;
void CleanIntermediateOutputs(framework::ir::Node *node);
template <::anakin::Precision PrecisionT>
void CreateAnakinEngine(framework::BlockDesc *block_desc,
const std::vector<std::string> &params,
const std::set<std::string> &input_names,
const std::vector<std::string> &output_mapping,
const std::vector<std::string> &program_inputs,
const std::string &engine_key) const;
};
} // namespace analysis
......
......@@ -100,7 +100,6 @@ void RenameAndGetOutputs(
const std::string arg_value = in_var->arguments(k);
const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
if (input_names_with_id.count(arg_value_with_id)) {
replaced_names.push_back(arg_value);
if (graph_var_map.count(arg_value)) {
......@@ -149,7 +148,6 @@ void RenameAndGetOutputs(
const std::string arg_value = out_var->arguments(k);
const std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]);
if (graph_var_map.count(arg_value)) {
add_block_var(arg_value, arg_value_with_id);
}
......
......@@ -116,6 +116,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(anakin_max_batchsize_);
CP_MEMBER(anakin_max_input_shape_);
CP_MEMBER(anakin_min_subgraph_size_);
CP_MEMBER(anakin_precision_mode_);
CP_MEMBER(anakin_auto_config_layout_);
CP_MEMBER(anakin_passes_filter_);
CP_MEMBER(anakin_ops_filter_);
// Ir related.
CP_MEMBER(enable_ir_optim_);
......@@ -269,13 +273,18 @@ void AnalysisConfig::Update() {
PADDLE_ENFORCE(!use_tensorrt_,
"Anakin sub-graph and TensorRT sub-graph are not allowed to "
"run at the same time!");
PADDLE_ENFORCE(
use_gpu_,
"Anakin sub-graph engine need gpu, please use the EnableGpu API.");
if (use_gpu_) {
LOG(INFO) << "Run Anakin GPU mode";
} else {
LOG(INFO) << "Run Anakin CPU mode";
}
pass_builder()->ClearPasses();
for (const auto &pass : kAnakinSubgraphPasses) {
pass_builder()->AppendPass(pass);
if (std::find(anakin_passes_filter_.begin(), anakin_passes_filter_.end(),
pass) == anakin_passes_filter_.end()) {
pass_builder()->AppendPass(pass);
}
}
}
......@@ -390,11 +399,17 @@ void AnalysisConfig::SwitchIrDebug(int x) {
}
void AnalysisConfig::EnableAnakinEngine(
int max_batch_size, std::map<std::string, std::vector<int>> max_input_shape,
int min_subgraph_size) {
int min_subgraph_size, AnalysisConfig::Precision precision_mode,
bool auto_config_layout, std::vector<std::string> passes_filter,
std::vector<std::string> ops_filter) {
anakin_max_batchsize_ = max_batch_size;
anakin_max_input_shape_ = max_input_shape;
anakin_min_subgraph_size_ = min_subgraph_size;
anakin_passes_filter_ = passes_filter;
anakin_ops_filter_ = ops_filter;
use_anakin_ = true;
anakin_precision_mode_ = precision_mode;
anakin_auto_config_layout_ = auto_config_layout;
Update();
}
} // namespace paddle
......@@ -383,10 +383,14 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
}
if (config_.use_gpu() && config_.anakin_engine_enabled()) {
if (config_.anakin_engine_enabled()) {
argument_.SetAnakinMaxBatchSize(config_.anakin_max_batchsize_);
argument_.SetAnakinMaxInputShape(config_.anakin_max_input_shape_);
argument_.SetAnakinMinSubgraphSize(config_.anakin_min_subgraph_size_);
argument_.SetAnakinPrecisionMode(config_.anakin_precision_mode_);
argument_.SetAnakinAutoConfigLayout(config_.anakin_auto_config_layout_);
argument_.SetAnakinPassesFilter(config_.anakin_passes_filter_);
argument_.SetAnakinOpsFilter(config_.anakin_ops_filter_);
LOG(INFO) << "Anakin subgraph engine is enabled";
}
......@@ -929,4 +933,9 @@ USE_ANAKIN_CONVERTER(density_prior_box);
USE_ANAKIN_CONVERTER(dropout);
USE_ANAKIN_CONVERTER(sum);
USE_ANAKIN_CONVERTER(prior_box);
USE_ANAKIN_CONVERTER(leaky_relu);
USE_ANAKIN_CONVERTER(affine_channel);
USE_ANAKIN_CONVERTER(relu6);
USE_ANAKIN_CONVERTER(swish);
USE_ANAKIN_CONVERTER(shuffle_channel);
#endif
......@@ -152,7 +152,10 @@ struct AnalysisConfig {
void EnableAnakinEngine(
int max_batch_size = 1,
std::map<std::string, std::vector<int>> max_input_shape = {},
int min_subgraph_size = 6);
int min_subgraph_size = 6, Precision precision = Precision::kFloat32,
bool auto_config_layout = false,
std::vector<std::string> passes_filter = {},
std::vector<std::string> ops_filter = {});
/** A boolean state indicating whether the Anakin sub-graph engine is used.
*/
......@@ -291,6 +294,10 @@ struct AnalysisConfig {
int anakin_max_batchsize_;
int anakin_min_subgraph_size_{6};
std::map<std::string, std::vector<int>> anakin_max_input_shape_;
Precision anakin_precision_mode_;
bool anakin_auto_config_layout_{false};
std::vector<std::string> anakin_passes_filter_;
std::vector<std::string> anakin_ops_filter_;
std::map<std::string, std::string> engine_opt_info_;
bool use_mkldnn_quantizer_{false};
......
......@@ -73,15 +73,15 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
// The following passes works for Anakin sub-graph engine.
const std::vector<std::string> kAnakinSubgraphPasses({
"infer_clean_graph_pass", //
"quant_conv2d_dequant_fuse_pass", //
"simplify_anakin_priorbox_detection_out_pass", //
"fillconstant_elementwisemul_fuse", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"fc_gru_fuse_pass", //
"quant_conv2d_dequant_fuse_pass", //
"anakin_subgraph_pass",
"shuffle_channel_detect_pass", //
"anakin_subgraph_pass", //
"fc_gru_fuse_pass", //
});
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
......
......@@ -4,7 +4,7 @@ lib=$1
if [ $# -ne 1 ]; then echo "No input library"; exit -1 ; fi
num_paddle_syms=$(nm -D ${lib} | grep paddle | wc -l)
num_google_syms=$(nm -D ${lib} | grep google | grep -v paddle | grep T | wc -l)
num_google_syms=$(nm -D ${lib} | grep google | grep -v paddle | grep "T " | wc -l)
if [ $num_paddle_syms -le 0 ]; then echo "Have no paddle symbols"; exit -1 ; fi
if [ $num_google_syms -ge 1 ]; then echo "Have some google symbols"; exit -1 ; fi
......
......@@ -34,28 +34,17 @@ limitations under the License. */
namespace paddle {
namespace operators {
using FluidDT = framework::proto::VarType_Type;
using inference::Singleton;
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
using inference::anakin::AnakinEngine;
class AnakinEngineOp : public framework::OperatorBase {
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
private:
std::vector<std::string> input_names_;
std::unordered_set<std::string> param_names_;
mutable AnakinNvEngineT *anakin_engine_;
std::string engine_key_;
std::string engine_serialized_data_;
bool use_gpu_;
bool enable_int8_;
public:
AnakinEngineOp(const std::string &type,
......@@ -66,10 +55,11 @@ class AnakinEngineOp : public framework::OperatorBase {
input_names_ = Inputs("Xs");
engine_key_ = Attr<std::string>("engine_key");
auto params = Attr<std::vector<std::string>>("parameters");
use_gpu_ = Attr<bool>("use_gpu");
enable_int8_ = Attr<bool>("enable_int8");
for (const auto &param : params) {
param_names_.insert(param);
}
anakin_engine_ = nullptr;
}
protected:
......@@ -80,19 +70,12 @@ class AnakinEngineOp : public framework::OperatorBase {
void RunAnakin(const framework::Scope &scope,
const platform::Place &dev_place) const {
auto *engine = GetEngine(scope, dev_place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");
std::vector<std::string> output_maps =
Attr<std::vector<std::string>>("output_name_mapping");
std::map<std::string, framework::LoDTensor *> inputs;
// Convert input tensor from fluid to engine.
for (const auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue;
auto &t =
......@@ -110,17 +93,38 @@ class AnakinEngineOp : public framework::OperatorBase {
outputs.insert({output_maps[output_index], fluid_t});
output_index += 1;
}
engine->Execute(inputs, outputs, stream);
if (enable_int8_) {
Execute<::anakin::Precision::INT8>(inputs, outputs, dev_place);
} else {
Execute<::anakin::Precision::FP32>(inputs, outputs, dev_place);
}
}
AnakinNvEngineT *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (anakin_engine_ == nullptr) {
anakin_engine_ =
inference::Singleton<inference::anakin::AnakinEngineManager>::Global()
template <::anakin::Precision PrecisionT>
void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs,
const platform::Place &dev_place) const {
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx)
.stream();
auto *engine =
inference::Singleton<inference::anakin::AnakinEngineManager<
::anakin::saber::NV, PrecisionT>>::Global()
.Get(engine_key_);
engine->Execute(inputs, outputs, stream);
#endif
} else {
auto *engine =
inference::Singleton<inference::anakin::AnakinEngineManager<
::anakin::saber::X86, PrecisionT>>::Global()
.Get(engine_key_);
engine->Execute(inputs, outputs);
}
return anakin_engine_;
}
};
......
......@@ -16,6 +16,7 @@
#include <pybind11/stl.h>
#include <cstring>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/analysis_predictor.h"
......@@ -229,6 +230,15 @@ void BindAnalysisConfig(py::module *m) {
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
py::arg("use_static") = true)
.def("enable_anakin_engine", &AnalysisConfig::EnableAnakinEngine,
py::arg("max_batch_size") = 1,
py::arg("max_input_shape") =
std::map<std::string, std::vector<int>>(),
py::arg("min_subgraph_size") = 6,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
py::arg("auto_config_layout") = false,
py::arg("passes_filter") = std::vector<std::string>(),
py::arg("ops_filter") = std::vector<std::string>())
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册