未验证 提交 1c591c39 编写于 作者: J jerrywgz 提交者: GitHub

Merge branch 'develop' into fix_rpn_target_assign_op

......@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif()
# remove link to python, see notes at:
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
if("${cc_library_DEPS};" MATCHES "python;")
list(REMOVE_ITEM cc_library_DEPS python)
add_dependencies(${TARGET_NAME} python)
target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
endif()
......
......@@ -49,6 +49,8 @@ struct VarHandleBase {
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
this->Node()->Name());
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
......
......@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
......@@ -44,6 +45,9 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
if(WITH_MKLDNN)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......@@ -57,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif ()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <utility>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
} // namespace
using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_};
auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate();
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
auto bias_input_names = conv_op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
if (has_bias) {
auto conv_bias_names = bias_it->second;
auto conv_bias_names_it =
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
[&conv_bias_names](Node* n) -> bool {
return n->Name() == conv_bias_names[0];
});
return std::make_pair(has_bias, *conv_bias_names_it);
}
}
return std::make_pair(false, nullptr);
};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
constexpr int nodes_removed = 3;
constexpr int nodes_added = 1;
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::pair<std::string, std::string>>& inputs,
const std::pair<std::string, std::string>& output) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
op->SetOutput(output.first, {output.second});
}
struct IsReachable {
using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) {
return &node;
}
}
return nullptr;
};
return [&](std::string from, const std::string to) -> bool {
if (from == to) return true;
std::map<std::string, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
visited[node.Name()] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) return false;
for (auto n : cur->outputs) {
if (n->Name() == to) return true;
if (!visited[n->Name()]) {
visited[n->Name()] = true;
queue.push_back(n->Name());
}
}
}
return false;
};
}
};
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
int conv_count = 0;
int elementwise_add_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
++conv_count;
}
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
}
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
......@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
seqconv_input->assert_is_op_input("sequence_conv", "X");
auto *seqconv_op = pattern->NewNode(seqconv_repr())
->assert_is_op("sequence_conv")
->assert_op_attr<bool>("paddingTrainable", false)
->assert_op_attr<int>("contextStride", 1);
auto *eltadd_op =
pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto *seqconv_weight_var =
pattern->NewNode(seqconv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("sequence_conv", "Filter");
// Bias
auto *eltadd_bias_var = pattern->NewNode(eltadd_bias_repr())
->AsInput()
->assert_is_op_input("elementwise_add");
// intermediate variable, will be removed in the IR after fuse.
auto *seqconv_out_var = pattern->NewNode(seqconv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("sequence_conv")
->assert_is_op_input("elementwise_add");
auto *eltadd_out_var = pattern->NewNode(eltadd_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("elementwise_add")
->assert_is_only_input_of_op("relu");
// output
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
seqconv_op->LinksFrom({seqconv_input, seqconv_weight_var})
.LinksTo({seqconv_out_var});
eltadd_op->LinksFrom({seqconv_out_var, eltadd_bias_var})
.LinksTo({eltadd_out_var});
relu_op->LinksFrom({eltadd_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
......@@ -999,6 +1044,46 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var;
}
PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
x_var->assert_is_op_input("elementwise_add", "X");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var});
return out_var;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -128,6 +128,15 @@ struct PDNode {
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
template <typename T>
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
});
return this;
}
private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
......@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct SeqConvEltAddRelu : public PatternBase {
SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {}
PDNode* operator()(PDNode* seqconv_input);
// declare operator node's name
PATTERN_DECL_NODE(seqconv);
PATTERN_DECL_NODE(eltadd);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(seqconv_weight);
PATTERN_DECL_NODE(seqconv_out);
PATTERN_DECL_NODE(eltadd_bias);
PATTERN_DECL_NODE(eltadd_out);
PATTERN_DECL_NODE(relu_out);
};
// FC with bias
// op: mul + elementwise_add
// named nodes:
......@@ -599,6 +633,44 @@ struct ConvBias : public PatternBase {
PATTERN_DECL_NODE(eltwise_bias);
PATTERN_DECL_NODE(eltwise_out);
};
// Convolution op
// Forward pass for convolution.
// conv_input, conv_bias and conv_filter are inputs.
// conv_output is a result of the operator.
// residual_data is data used by skip connection.
// If residual connection fusion is on, the formula is:
// conv_output = conv_op(conv_filter, conv_input, conv_bias)
// + conv_residual_data
// If the fusion is off, conv_residual_data is not added.
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);
};
// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
// connection fusion is on.
struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
->assert_is_op_input("sequence_conv")
->assert_var_not_persistable();
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
fuse_pattern(x);
// Create New OpDesc
auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
Node* eltadd_bias, Node* relu_out) {
OpDesc op_desc;
op_desc.SetType("fusion_seqconv_eltadd_relu");
op_desc.SetInput("X", {input->Name()});
op_desc.SetInput("Filter", {seqconv_weight->Name()});
op_desc.SetInput("Bias", {eltadd_bias->Name()});
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()});
scope->Var(ColMat)->GetMutable<LoDTensor>();
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(seqconv_weight, op);
IR_NODE_LINK_TO(eltadd_bias, op);
IR_NODE_LINK_TO(op, relu_out);
return op;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle SeqConv EltAdd Relu fuse";
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);
fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
relu_out);
std::unordered_set<const Node*> marked_nodes(
{seqconv, seqconv_out, eltadd, eltadd_out, relu});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
paddle::framework::ir::SeqConvEltAddReluFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConvEltAddReluFusePass : public FusePassBase {
public:
virtual ~SeqConvEltAddReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -515,20 +515,14 @@ void OpDesc::InferShape(const BlockDesc &block) const {
}
void OpDesc::InferVarType(BlockDesc *block) const {
// There are a few places that var type can be set.
// When VarDesc is created, default set to LOD_TENSOR.
// When output variable is created, default is defaut set to LOD_TENSOR.
// We limit here to be the only place that operator defines its customized
// var type inference. Hence, we don't do any "default" setting here.
auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) {
info.infer_var_type_(*this, block);
} else {
// all output type is LoDTensor by default
VLOG(10) << this->Type()
<< " has not registered InferVarType. Set output variables to "
"LOD_TENSOR";
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(proto::VarType::LOD_TENSOR);
}
}
}
}
......
......@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
ParallelExecutor::~ParallelExecutor() {
const auto dev_ctxs =
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
for (auto &dev_ctx : dev_ctxs) {
dev_ctx->Wait();
}
if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
Scope *local_scope = member_->local_scopes_[i];
......
......@@ -67,20 +67,22 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion.
const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here.
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", //
#endif
}};
......
......@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
SetConfig(&cfg);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
GetFuseStatis(predictor.get(), &num_ops);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 32);
}
// Compare result of NativeConfig and AnalysisConfig
......
......@@ -86,7 +86,7 @@ function(op_library TARGET)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
......
......@@ -300,10 +300,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
......@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_eltwise);
fuse_relu, fuse_residual_conn);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_eltwise);
mkldnn_engine, fuse_relu, fuse_residual_conn);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -386,8 +386,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
T* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
......@@ -424,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_eltwise) const {
bool fuse_residual_conn) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if (fuse_eltwise) {
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(1.0f);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
......@@ -452,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -461,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......@@ -476,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -485,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......
......@@ -132,6 +132,11 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.")
.Reuse("Input");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......@@ -164,10 +169,10 @@ void Conv2DOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_eltwise",
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is connected via skip connection "
"to a previous layer.")
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
......
......@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -25,21 +27,17 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
struct AppendProposalsFunctor {
LoDTensor *out_;
int64_t offset_;
Tensor *to_add_;
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
: out_(out), offset_(offset), to_add_(to_add) {}
template <typename T>
void apply() const {
auto *out_data = out_->data<T>();
auto *to_add_data = to_add_->data<T>();
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
}
};
static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
auto *out_data = dst->data<void>();
auto *to_add_data = src.data<void>();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data, src.numel() * size_of_t);
}
class GenerateProposalsOp : public framework::OperatorWithKernel {
public:
......@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
};
template <class T>
void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
static inline void BoxCoder(const platform::DeviceContext &ctx,
Tensor *all_anchors, Tensor *bbox_deltas,
Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
......@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
......@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
kBBoxClipDefault)) *
anchor_height;
}
......@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
}
template <class T>
void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
Tensor *boxes) {
static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
const Tensor &im_info, Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>();
T zero(0);
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
}
}
}
template <class T>
void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
static inline void FilterBoxes(const platform::DeviceContext &ctx,
Tensor *boxes, float min_size,
const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
T im_scale = im_info_data[2];
......@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
keep->Resize({keep_len});
}
bool SortScorePairDescend(const std::pair<float, int> &pair1,
const std::pair<float, int> &pair2) {
return pair1.first > pair2.first;
}
template <class T>
void GetMaxScoreIndex(const std::vector<T> &scores,
std::vector<std::pair<T, int>> *sorted_indices) {
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T> &scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices->push_back(std::make_pair(scores[i], i));
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend);
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
[](const std::pair<T, int> &a, const std::pair<T, int> &b) {
return a.first < b.first;
});
return sorted_indices;
}
template <class T>
T BBoxArea(const T *box, const bool normalized) {
static inline T BBoxArea(const T *box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
......@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) {
}
template <class T>
T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
......@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
......@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
}
}
template <typename T>
static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
int selected_num) {
Tensor keep_nms;
keep_nms.Resize({selected_num});
auto *keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T>
Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
const T nms_threshold, const float eta) {
static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
Tensor *scores, T nms_threshold, float eta) {
PADDLE_ENFORCE_NOT_NULL(bbox);
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
......@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex<T>(scores_data, &sorted_indices);
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
bool flag;
while (sorted_indices.size() != 0) {
int idx = sorted_indices.front().second;
flag = true;
for (size_t k = 0; k < selected_indices.size(); ++k) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
const int kept_idx = selected_indices[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
......@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
}
if (flag) {
selected_indices.push_back(idx);
selected_num++;
++selected_num;
}
sorted_indices.erase(sorted_indices.begin());
sorted_indices.erase(sorted_indices.end());
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
Tensor keep_nms;
keep_nms.Resize({selected_num});
int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
return VectorToTensor(selected_indices, selected_num);
}
template <typename DeviceContext, typename T>
template <typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
"Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
auto &dev_ctx = context.template device_context<DeviceContext>();
auto &dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto scores_dim = scores->dims();
auto &scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
......@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
math::Transpose<platform::CPUDeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
lod.resize(1);
auto &lod0 = lod[0];
lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
......@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = tensor_pair.first;
Tensor scores = tensor_pair.second;
framework::VisitDataType(
framework::ToDataType(rpn_rois->type()),
AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
framework::VisitDataType(
framework::ToDataType(rpn_roi_probs->type()),
AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
AppendProposals(rpn_rois, 4 * num_proposals, proposals);
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.emplace_back(num_proposals);
lod0.push_back(num_proposals);
}
lod.emplace_back(lod0);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
......@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
}
std::pair<Tensor, Tensor> ProposalForOneImage(
const DeviceContext &ctx, const Tensor &im_info_slice,
const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
......@@ -392,10 +394,9 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
std::function<bool(const int64_t &, const int64_t &)> compare =
[scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
......@@ -452,33 +453,45 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Scores", "The scores of anchors should be foreground.");
AddInput("BboxDeltas", "bbox_deltas.");
AddInput("ImInfo", "Information for image reshape.");
AddInput("Anchors", "All anchors.");
AddInput("Variances", " variances");
AddOutput("RpnRois", "Anchors.");
AddOutput("RpnRoiProbs", "Anchors.");
AddAttr<int>("pre_nms_topN", "pre_nms_topN");
AddAttr<int>("post_nms_topN", "post_nms_topN");
AddAttr<float>("nms_thresh", "nms_thres");
AddAttr<float>("min_size", "min size");
AddInput("Scores",
"(Tensor) The scores from conv is in shape (N, A, H, W), "
"N is batch size, A is number of anchors, "
"H and W are height and width of the feature map");
AddInput("BboxDeltas",
"(Tensor) Bounding box deltas from conv is in "
"shape (N, 4*A, H, W).");
AddInput("ImInfo",
"(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, scale)");
AddInput("Anchors",
"(Tensor) Bounding box anchors from anchor_generator_op "
"is in shape (A, H, W, 4).");
AddInput("Variances",
"(Tensor) Bounding box variances with same shape as `Anchors`.");
AddOutput("RpnRois",
"(LoDTensor), Output proposals with shape (rois_num, 4).");
AddOutput("RpnRoiProbs",
"(LoDTensor) Scores of proposals with shape (rois_num, 1).");
AddAttr<int>("pre_nms_topN",
"Number of top scoring RPN proposals to keep before "
"applying NMS.");
AddAttr<int>("post_nms_topN",
"Number of top scoring RPN proposals to keep after "
"applying NMS");
AddAttr<float>("nms_thresh", "NMS threshold used on RPN proposals.");
AddAttr<float>("min_size",
"Proposal height and width both need to be greater "
"than this min_size.");
AddAttr<float>("eta", "The parameter for adaptive NMS.");
AddComment(R"DOC(
Generate Proposals OP
This operator proposes rois according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
could be used to train detection net.
Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
of anchors, H and W are height and width of the feature map.
BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
This operator Generate bounding box proposals for Faster RCNN.
The propoasls are generated for a list of images based on image
score 'Scores', bounding box regression result 'BboxDeltas' as
well as predefined bounding box shapes 'anchors'. Greedy
non-maximum suppression is applied to generate the final bounding
boxes.
For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and
calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area.
Finally, apply nms to get final proposals as output.
)DOC");
}
};
......@@ -490,6 +503,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
ops::GenerateProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
generate_proposals,
ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
ops::GenerateProposalsKernel<double>);
......@@ -16,10 +16,13 @@ limitations under the License. */
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
......@@ -36,36 +39,38 @@ namespace {
int const kThreadsPerBlock = sizeof(uint64_t) * 8;
template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int size,
T *out) {
CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor {
int start_;
int delta_;
int *out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T>
void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
Tensor *value_out, Tensor *index_out) {
int num = value.numel();
static void SortDescending(const platform::CUDADeviceContext &ctx,
const Tensor &value, Tensor *value_out,
Tensor *index_out) {
int num = static_cast<int>(value.numel());
Tensor index_in_t;
int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
int block = 512;
auto stream = ctx.stream();
RangeInitKernel<<<DIVUP(num, block), block, 0, stream>>>(0, 1, num, idx_in);
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
for_range(RangeInitFunctor{0, 1, idx_in});
int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
const T *keys_in = value.data<T>();
T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out,
num);
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
// Allocate temporary storage
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
void *d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>(
......@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
}
template <typename T>
__device__ __forceinline__ T Min(T x, T y) {
return x < y ? x : y;
}
template <typename T>
__device__ __forceinline__ T Max(T x, T y) {
return x > y ? x : y;
}
template <typename T>
__global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
const T *var, const int *index,
const T *im_info, const int num,
T *proposals) {
T kBBoxClipDefault = log(1000.0 / 16.0);
CUDA_1D_KERNEL_LOOP(i, num) {
struct BoxDecodeAndClipFunctor {
const T *anchor;
const T *deltas;
const T *var;
const int *index;
const T *im_info;
T *proposals;
BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
const int *index, const T *im_info, T *proposals)
: anchor(anchor),
deltas(deltas),
var(var),
index(index),
im_info(im_info),
proposals(proposals) {}
T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
__device__ void operator()(size_t i) {
int k = index[i] * 4;
T axmin = anchor[k];
T aymin = anchor[k + 1];
......@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T dxmax = deltas[k + 2];
T dymax = deltas[k + 3];
T d_cx = 0., d_cy = 0., d_w = 0., d_h = 0.;
T d_cx, d_cy, d_w, d_h;
if (var) {
d_cx = cx + dxmin * w * var[k];
d_cy = cy + dymin * h * var[k + 1];
d_w = exp(Min<T>(dxmax * var[k + 2], kBBoxClipDefault)) * w;
d_h = exp(Min<T>(dymax * var[k + 3], kBBoxClipDefault)) * h;
d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
} else {
d_cx = cx + dxmin * w;
d_cy = cy + dymin * h;
d_w = exp(Min<T>(dxmax, kBBoxClipDefault)) * w;
d_h = exp(Min<T>(dymax, kBBoxClipDefault)) * h;
d_w = exp(Min(dxmax, bbox_clip_default)) * w;
d_h = exp(Min(dymax, bbox_clip_default)) * h;
}
T oxmin = d_cx - d_w * 0.5;
......@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T oxmax = d_cx + d_w * 0.5 - 1.;
T oymax = d_cy + d_h * 0.5 - 1.;
proposals[i * 4] = Max<T>(Min<T>(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max<T>(Min<T>(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max<T>(Min<T>(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max<T>(Min<T>(oymax, im_info[0] - 1.), 0.);
proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
}
}
__device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
__device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};
template <typename T, int BlockSize>
__global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num, int *keep_num,
int *keep) {
static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num,
int *keep_num, int *keep) {
T im_h = im_info[0];
T im_w = im_info[1];
T im_scale = im_info[2];
......@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
}
}
__device__ inline float IoU(const float *a, const float *b) {
static __device__ inline float IoU(const float *a, const float *b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
......@@ -191,8 +205,9 @@ __device__ inline float IoU(const float *a, const float *b) {
return inter_s / (s_a + s_b - inter_s);
}
__global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
static __global__ void NMSKernel(const int n_boxes,
const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
......@@ -234,9 +249,9 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
}
template <typename T>
void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
int boxes_num = proposals.dims()[0];
PADDLE_ENFORCE_EQ(boxes_num, sorted_indices.dims()[0]);
......@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const T *boxes = proposals.data<T>();
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int size_bytes = boxes_num * col_blocks * sizeof(uint64_t);
uint64_t *d_mask =
reinterpret_cast<uint64_t *>(memory::Alloc(place, size_bytes));
NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes, d_mask);
uint64_t *h_mask = reinterpret_cast<uint64_t *>(
memory::Alloc(platform::CPUPlace(), size_bytes));
memory::Copy(platform::CPUPlace(), h_mask, place, d_mask, size_bytes, 0);
framework::Vector<uint64_t> mask(boxes_num * col_blocks);
NMSKernel<<<blocks, threads>>>(
boxes_num, nms_threshold, boxes,
mask.CUDAMutableData(boost::get<platform::CUDAPlace>(ctx.GetPlace())));
std::vector<uint64_t> remv(col_blocks);
memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
......@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
if (!(remv[nblock] & (1ULL << inblock))) {
++num_to_keep;
keep_vec.push_back(i);
uint64_t *p = &h_mask[0] + i * col_blocks;
uint64_t *p = &mask[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
......@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
sizeof(int) * num_to_keep, 0);
memory::Free(place, d_mask);
memory::Free(platform::CPUPlace(), h_mask);
}
template <typename T>
std::pair<Tensor, Tensor> ProposalForOneImage(
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_info,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
......@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
int block = 512;
auto stream = ctx.stream();
BoxDecodeAndClipKernel<T><<<DIVUP(pre_nms_num, block), block, 0, stream>>>(
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_info.data<T>(), pre_nms_num,
proposals.data<T>());
{
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
for_range(BoxDecodeAndClipFunctor<T>{
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_info.data<T>(), proposals.data<T>()});
}
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
min_size = std::max(min_size, 1.0f);
auto stream = ctx.stream();
FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
proposals.data<T>(), im_info.data<T>(), min_size, pre_nms_num,
keep_num_t.data<int>(), keep_index.data<int>());
......@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
"Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
......@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair =
ProposalForOneImage<T>(dev_ctx, im_info_slice, *anchor, *var,
ProposalForOneImage<T>(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = box_score_pair.first;
Tensor scores = box_score_pair.second;
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel(), 0);
......
......@@ -86,7 +86,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context
s->response_call_back_ = nullptr;
platform::RecordEvent record_event(method, p_ctx);
platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
......@@ -143,7 +143,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
......@@ -191,7 +191,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
......@@ -221,7 +221,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......@@ -246,7 +246,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......@@ -271,7 +271,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......@@ -301,7 +301,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
platform::RecordEvent record_event(method, nullptr);
platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......
......@@ -36,7 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
platform::RecordEvent record_event("serial", &ctx);
platform::RecordRPCEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
......@@ -148,7 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
platform::RecordEvent record_event("deserial", &ctx);
platform::RecordRPCEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h"
#include <algorithm> // for min, max
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
void FusionSeqConvEltAddReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Filter"),
"Input(Filter) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Bias"),
"Input(Bias) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColMat"),
"Output(ColMat) of FusionSeqConvEltAddReluOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto w_dims = ctx->GetInputDim("Filter");
int context_length = ctx->Attrs().Get<int>("contextLength");
PADDLE_ENFORCE(
ctx->Attrs().Get<int>("contextStride") == 1,
"Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
"Filter's height should be context_length * "
"input_hidden_size .");
PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
"contextStart size should be smaller than contextLength.");
ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
ctx->ShareLoD("X", "Out");
}
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionSeqConvEltAddReluOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
// PaddingData only support false yet, should be ensured at pass.
AddInput("Filter",
"(Tensor) same as the input(Filter) of sequence conv op is an "
"learnable parameter."
"This is a tensor with shape (K, N), where K is the "
"context_length * dim size of x, N is the output feature size.");
AddInput("Bias",
"(Tensor) the learnable weights. shape (1, N), where N is the "
"output feature size");
AddOutput(
"Out",
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, N), where, T is the "
"total time steps in this mini-batch, N is the output feature size.");
AddOutput("ColMat",
"(Tensor) (T, K), where T is where T is the "
"total time steps in this mini-batch, K is height of Filter")
.AsIntermediate();
AddAttr<int>("contextLength",
"(int) the contextLength of FusionSeqConvEltAddReluOp is the "
"height of the convolution kernel.")
.GreaterThan(0);
AddAttr<int>("contextStart",
"(int, default:0) the contextStart of FusionSeqConvEltAddReluOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative. The negative number "
"means to pad contextStart time-steps of zeros or learnable "
"parameters at the beginning of each instance. The positive "
"number means to skip contextStart time-steps of each "
"instance.")
.SetDefault(0);
AddAttr<int>(
"contextStride",
"(int, default:1) the contextStride of FusionSeqConvEltAddReluOp "
"represents the stride length of convolution kernel. "
"Currently, FusionSeqConvEltAddReluOp only supports"
"contextStride=1.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
Fusion Sequence Conv and ElementwiseAdd Operator.
)DOC");
}
template <typename T>
class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("Filter");
auto* b = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<LoDTensor>("Out");
auto* col = ctx.Output<Tensor>("ColMat");
auto x_lod = x->lod();
auto x_dims = x->dims();
auto w_dims = w->dims();
PADDLE_ENFORCE_EQ(b->numel(), w_dims[1],
"bias size should be equal to output feature size.");
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL,
"Only support one level sequence now.");
const T* x_data = x->data<T>();
const T* w_data = w->data<T>();
const T* b_data = b->data<T>();
T* y_data = y->mutable_data<T>(ctx.GetPlace());
T* col_data = col->mutable_data<T>(ctx.GetPlace());
int context_start = ctx.Attr<int>("contextStart");
int context_length = ctx.Attr<int>("contextLength");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
// im2col
int src_mat_w = static_cast<int>(x_dims[1]);
int src_mat_w_sz = src_mat_w * sizeof(T);
int col_mat_w = static_cast<int>(w_dims[0]);
int col_mat_w_sz = col_mat_w * sizeof(T);
for (int i = 0; i < static_cast<int>(x_lod[0].size()) - 1; ++i) {
int st = x_lod[0][i];
int ed = x_lod[0][i + 1];
const T* src_data = x_data + st * src_mat_w;
T* dst_data = col_data + st * col_mat_w;
int seq_len = ed - st;
if (seq_len > up_pad + down_pad) {
// zero all up_pad and fill data
std::memset(dst_data, 0, up_pad * col_mat_w_sz);
dst_data = dst_data + up_pad * src_mat_w;
int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz;
for (int j = 0; j < up_pad; ++j) {
// blas.VCOPY?
std::memcpy(dst_data, src_data, copy_size);
dst_data += (col_mat_w - src_mat_w);
copy_size += src_mat_w_sz;
}
// fill data
for (int j = 0; j < seq_len - up_pad - down_pad; ++j) {
std::memcpy(dst_data, src_data, copy_size);
dst_data += col_mat_w;
src_data += src_mat_w;
}
// zero all down_pad and fill data
std::memset(dst_data, 0, down_pad * col_mat_w_sz);
copy_size -= src_mat_w_sz;
for (int j = 0; j < down_pad; ++j) {
std::memcpy(dst_data, src_data, copy_size);
dst_data += col_mat_w;
src_data += src_mat_w;
copy_size -= src_mat_w_sz;
}
} else {
PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
std::memset(dst_data, 0, seq_len * col_mat_w_sz);
dst_data = dst_data + up_pad * src_mat_w;
int zero_sz = up_pad * src_mat_w_sz;
int cur_src_sz = seq_len * src_mat_w_sz;
for (int j = 0; j < std::min(up_pad, seq_len); ++j) {
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
std::memcpy(dst_data, src_data, copy_size);
dst_data += (col_mat_w - src_mat_w);
zero_sz -= src_mat_w_sz;
}
// from bottom
dst_data = col_data + ed * col_mat_w;
src_data = x_data + st * src_mat_w;
zero_sz = down_pad * src_mat_w_sz;
for (int j = 1; j <= std::min(down_pad, seq_len); ++j) {
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T),
src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w,
copy_size);
dst_data -= col_mat_w;
zero_sz -= src_mat_w_sz;
}
}
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
col_data, w_data, y_data, b_data, true);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, ops::FusionSeqConvEltAddReluOp,
ops::FusionSeqConvEltAddReluOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seqconv_eltadd_relu,
ops::FusionSeqConvEltAddReluKernel<float>,
ops::FusionSeqConvEltAddReluKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqConvEltAddReluOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
......@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
for (int64_t i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
DECLARE_int32(paddle_num_threads);
......@@ -30,20 +31,25 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
if (B == NULL) {
return;
}
if (relu) {
const auto& vaddrelu = jitkernel::KernelPool::Instance()
.template Get<jitkernel::VAddReluKernel<T>>(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vaddrelu->Compute(B, dst, dst);
}
} else {
const auto& vadd = jitkernel::KernelPool::Instance()
.template Get<jitkernel::VAddKernel<T>>(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vadd->Compute(B, dst, dst);
}
}
if (!relu) {
return;
}
// TODO(TJ): fuse relu
LOG(FATAL) << "Not implemented!";
}
} // namespace math
......
......@@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel {
virtual void Compute(const T a, const T *x, T *y) const = 0;
};
template <typename T>
class VAddReluKernel : public Kernel {
public:
virtual void Compute(const T *x, const T *y, T *z) const = 0;
};
template <typename T>
class VActKernel : public Kernel {
public:
......
......@@ -378,11 +378,99 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {}
};
/* VAddRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx = _mm256_loadu_ps(x); \
__m256 tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
_mm256_storeu_ps(z, tmpy); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(y); \
tmp0 = _mm256_add_ps(tmp0, tmp1); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
tmp1 = _mm256_add_ps(tmp1, tmp2); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(z, tmp0); \
_mm256_storeu_ps(z + 8, tmp1); \
}
#define INTRI_COMMON_FLOAT(isa, block) \
template <> \
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
: VAddReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
} \
template <> \
void VAddReluKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmpx = _mm256_loadu_ps(x + i); \
__m256 tmpy = _mm256_loadu_ps(y + i); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, zeros); \
_mm256_storeu_ps(z + i, tmpy); \
} \
for (int i = this->end_; i < this->num_; ++i) { \
z[i] = x[i] + y[i]; \
z[i] = z[i] > 0 ? z[i] : 0; \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_COMMON_FLOAT(jit::avx, kGT16);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
} // namespace jitkernel
......
......@@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
}
}
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
void vaddrelu_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z) {
vadd->Compute(x, y, z);
vrelu->Compute(z, z);
}
TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
RandomVec<float>(d, y.data());
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
const auto& vadd =
jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
const auto& vrelu =
jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data);
}
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
}
auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat;
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4;
......
......@@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
return it->second.get();
}
const std::vector<const DeviceContext*>
DeviceContextPool::GetAllDeviceContexts() const {
std::vector<const DeviceContext*> all_device_ctx;
all_device_ctx.reserve(device_contexts_.size());
for (auto& dev_ctx : device_contexts_) {
all_device_ctx.emplace_back(dev_ctx.second.get());
}
return all_device_ctx;
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
......
......@@ -217,6 +217,9 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */
platform::DeviceContext* Get(const platform::Place& place);
/*! \brief Return all the device contexts. */
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
const Place& place) {
......
......@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool(enable_rpc_profiler, false, "Enable rpc profiler or not.");
namespace paddle {
namespace platform {
......@@ -193,6 +195,13 @@ RecordEvent::~RecordEvent() {
PopEvent(name_, dev_ctx_);
}
RecordRPCEvent::RecordRPCEvent(const std::string& name,
const DeviceContext* dev_ctx) {
if (FLAGS_enable_rpc_profiler) {
event_.reset(new platform::RecordEvent(name, dev_ctx));
}
}
RecordBlock::RecordBlock(int block_id)
: is_enabled_(false), start_ns_(PosixInNsec()) {
std::lock_guard<std::mutex> l(profiler_mu);
......
......@@ -87,6 +87,16 @@ struct RecordEvent {
std::string full_name_;
};
class RecordRPCEvent {
public:
// dev_ctx can be set to nullptr if device is cpu.
RecordRPCEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordRPCEvent() {}
private:
std::unique_ptr<RecordEvent> event_;
};
struct RecordBlock {
explicit RecordBlock(int block_id);
~RecordBlock();
......
......@@ -120,6 +120,7 @@ def __bootstrap__():
read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_period')
read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler')
if core.is_compiled_with_cuda():
read_env_flags += [
......
......@@ -324,10 +324,19 @@ class LayerHelper(object):
raise ValueError("no Parameter name %s found" % name)
return param
def create_tmp_variable(self, dtype, stop_gradient=False):
def create_variable_for_type_inference(self, dtype, stop_gradient=False):
"""Create a temporary variable that should be type inferred layer.
Note:
The default type will be set to LOD_TENSOR. However, when
the var is used as operator output, its type will be updated
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
return self.main_program.current_block().create_var(
name=unique_name.generate(".".join([self.name, 'tmp'])),
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=stop_gradient)
......@@ -388,7 +397,7 @@ class LayerHelper(object):
b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
tmp = self.create_tmp_variable(dtype=input_var.dtype)
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op(
type='elementwise_add',
inputs={'X': [input_var],
......@@ -414,7 +423,7 @@ class LayerHelper(object):
tmp = input_var
# NOTE(dzhwinter): some activation support inplace compution.
if not core.IsInplace(act_type):
tmp = self.create_tmp_variable(dtype=input_var.dtype)
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op(
type=act_type,
inputs={"X": [input_var]},
......
......@@ -80,8 +80,8 @@ def split_lod_tensor(input, mask, level=0):
"""
helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_tmp_variable(dtype=input.dtype)
out_false = helper.create_tmp_variable(dtype=input.dtype)
out_true = helper.create_variable_for_type_inference(dtype=input.dtype)
out_false = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='split_lod_tensor',
inputs={
......@@ -131,7 +131,7 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):
in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
"""
helper = LayerHelper('merge_lod_tensor', **locals())
out = helper.create_tmp_variable(dtype=in_true.dtype)
out = helper.create_variable_for_type_inference(dtype=in_true.dtype)
helper.append_op(
type='merge_lod_tensor',
inputs={'X': x,
......@@ -524,7 +524,7 @@ class StaticRNN(object):
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
tmp_o = self.helper.create_tmp_variable(dtype=o.dtype)
tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
self.helper.append_op(
type='rnn_memory_helper',
inputs={'X': [o]},
......@@ -606,7 +606,8 @@ class StaticRNN(object):
pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name)
assert isinstance(mem_var, Variable)
new_mem = self.helper.create_tmp_variable(dtype=mem_var.dtype)
new_mem = self.helper.create_variable_for_type_inference(
dtype=mem_var.dtype)
rnn_block.append_op(
type='rnn_memory_helper',
......@@ -813,7 +814,7 @@ def max_sequence_len(rank_table):
${out_comment}.
"""
helper = LayerHelper("max_seqence_len", **locals())
res = helper.create_tmp_variable(dtype="int64")
res = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="max_sequence_len",
inputs={"RankTable": rank_table},
......@@ -884,7 +885,7 @@ def array_to_lod_tensor(x, table):
lod_tensor = fluid.layers.array_to_lod_tensor(array, table)
"""
helper = LayerHelper("array_to_lod_tensor", **locals())
tmp = helper.create_tmp_variable(dtype=x.dtype)
tmp = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="array_to_lod_tensor",
inputs={'X': x,
......@@ -915,7 +916,7 @@ def increment(x, value=1.0, in_place=True):
"""
helper = LayerHelper("increment", **locals())
if not in_place:
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = x
helper.append_op(
......@@ -1012,7 +1013,7 @@ def less_than(x, y, force_cpu=None, cond=None, **ignored):
"""
helper = LayerHelper("less_than", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
attrs = dict()
......@@ -1051,7 +1052,7 @@ def equal(x, y, cond=None, **ignored):
"""
helper = LayerHelper("equal", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
helper.append_op(
......@@ -1098,7 +1099,7 @@ def array_read(array, i):
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError("array should be tensor array vairable")
out = helper.create_tmp_variable(dtype=array.dtype)
out = helper.create_variable_for_type_inference(dtype=array.dtype)
helper.append_op(
type='read_from_array',
inputs={'X': [array],
......@@ -1133,7 +1134,7 @@ def shrink_memory(x, i, table):
usage.
"""
helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='shrink_rnn_memory',
inputs={'X': [x],
......@@ -1170,7 +1171,7 @@ def array_length(array):
"""
helper = LayerHelper('array_length', **locals())
tmp = helper.create_tmp_variable(dtype='int64')
tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True
helper.append_op(
type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]})
......@@ -1590,7 +1591,7 @@ class DynamicRNN(object):
self.mem_dict = dict()
self.output_array = []
self.outputs = []
self.cond = self.helper.create_tmp_variable(dtype='bool')
self.cond = self.helper.create_variable_for_type_inference(dtype='bool')
self.cond.stop_gradient = False
self.while_op = While(self.cond)
self.input_array = []
......@@ -1924,7 +1925,7 @@ def reorder_lod_tensor_by_rank(x, rank_table):
helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable)
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x],
......@@ -1958,7 +1959,7 @@ def is_empty(x, cond=None, **ignored):
"""
helper = LayerHelper("is_empty", **locals())
if cond is None:
cond = helper.create_tmp_variable(dtype='bool')
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
elif not isinstance(cond, Variable):
raise TypeError("cond takes a variable")
......
......@@ -287,7 +287,8 @@ def detection_output(loc,
scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape)
scores = nn.transpose(scores, perm=[0, 2, 1])
scores.stop_gradient = True
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
nmsed_outs = helper.create_variable_for_type_inference(
dtype=decoded_box.dtype)
helper.append_op(
type="multiclass_nms",
inputs={'Scores': scores,
......@@ -319,7 +320,7 @@ def iou_similarity(x, y, name=None):
"""
helper = LayerHelper("iou_similarity", **locals())
if name is None:
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
......@@ -356,7 +357,8 @@ def box_coder(prior_box,
helper = LayerHelper("box_coder", **locals())
if name is None:
output_box = helper.create_tmp_variable(dtype=prior_box.dtype)
output_box = helper.create_variable_for_type_inference(
dtype=prior_box.dtype)
else:
output_box = helper.create_variable(
name=name, dtype=prior_box.dtype, persistable=False)
......@@ -387,7 +389,7 @@ def polygon_box_transform(input, name=None):
"""
helper = LayerHelper("polygon_box_transform", **locals())
if name is None:
output = helper.create_tmp_variable(dtype=input.dtype)
output = helper.create_variable_for_type_inference(dtype=input.dtype)
else:
output = helper.create_variable(
name=name, dtype=prior_box.input, persistable=False)
......@@ -455,7 +457,7 @@ def detection_map(detect_res,
helper = LayerHelper("detection_map", **locals())
def __create_var(type):
return helper.create_tmp_variable(dtype=type)
return helper.create_variable_for_type_inference(dtype=type)
map_out = __create_var('float32')
accum_pos_count_out = out_states[0] if out_states else __create_var('int32')
......@@ -562,8 +564,9 @@ def bipartite_match(dist_matrix,
>>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou)
"""
helper = LayerHelper('bipartite_match', **locals())
match_indices = helper.create_tmp_variable(dtype='int32')
match_distance = helper.create_tmp_variable(dtype=dist_matrix.dtype)
match_indices = helper.create_variable_for_type_inference(dtype='int32')
match_distance = helper.create_variable_for_type_inference(
dtype=dist_matrix.dtype)
helper.append_op(
type='bipartite_match',
inputs={'DistMat': dist_matrix},
......@@ -649,8 +652,8 @@ def target_assign(input,
gt, matched_indices, mismatch_value=0)
"""
helper = LayerHelper('target_assign', **locals())
out = helper.create_tmp_variable(dtype=input.dtype)
out_weight = helper.create_tmp_variable(dtype='float32')
out = helper.create_variable_for_type_inference(dtype=input.dtype)
out_weight = helper.create_variable_for_type_inference(dtype='float32')
helper.append_op(
type='target_assign',
inputs={
......@@ -821,9 +824,10 @@ def ssd_loss(location,
conf_loss = nn.reshape(
x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
conf_loss.stop_gradient = True
neg_indices = helper.create_tmp_variable(dtype='int32')
neg_indices = helper.create_variable_for_type_inference(dtype='int32')
dtype = matched_indices.dtype
updated_matched_indices = helper.create_tmp_variable(dtype=dtype)
updated_matched_indices = helper.create_variable_for_type_inference(
dtype=dtype)
helper.append_op(
type='mine_hard_examples',
inputs={
......@@ -1003,8 +1007,8 @@ def prior_box(input,
max_sizes = [max_sizes]
attrs['max_sizes'] = max_sizes
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
......@@ -1342,8 +1346,8 @@ def anchor_generator(input,
'offset': offset
}
anchor = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
anchor = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="anchor_generator",
inputs={"Input": input},
......@@ -1389,7 +1393,7 @@ def roi_perspective_transform(input,
"""
helper = LayerHelper('roi_perspective_transform', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="roi_perspective_transform",
inputs={"X": input,
......@@ -1423,11 +1427,15 @@ def generate_proposal_labels(rpn_rois,
helper = LayerHelper('generate_proposal_labels', **locals())
rois = helper.create_tmp_variable(dtype=rpn_rois.dtype)
labels_int32 = helper.create_tmp_variable(dtype=gt_classes.dtype)
bbox_targets = helper.create_tmp_variable(dtype=rpn_rois.dtype)
bbox_inside_weights = helper.create_tmp_variable(dtype=rpn_rois.dtype)
bbox_outside_weights = helper.create_tmp_variable(dtype=rpn_rois.dtype)
rois = helper.create_variable_for_type_inference(dtype=rpn_rois.dtype)
labels_int32 = helper.create_variable_for_type_inference(
dtype=gt_classes.dtype)
bbox_targets = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
bbox_inside_weights = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
bbox_outside_weights = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
helper.append_op(
type="generate_proposal_labels",
......@@ -1509,8 +1517,10 @@ def generate_proposals(scores,
"""
helper = LayerHelper('generate_proposals', **locals())
rpn_rois = helper.create_tmp_variable(dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_tmp_variable(dtype=scores.dtype)
rpn_rois = helper.create_variable_for_type_inference(
dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_variable_for_type_inference(
dtype=scores.dtype)
helper.append_op(
type="generate_proposals",
inputs={
......
......@@ -954,7 +954,7 @@ def read_file(reader):
"""
helper = LayerHelper('read_file')
out = [
helper.create_tmp_variable(
helper.create_variable_for_type_inference(
stop_gradient=True, dtype='float32')
for _ in range(len(reader.desc.shapes()))
]
......
......@@ -202,10 +202,12 @@ def generate_layer_fn(op_type):
out_var = out[0] if (isinstance(out, list) or
isinstance(out, tuple)) else out
else:
out_var = helper.create_tmp_variable(dtype=dtype)
out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var]
for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype)
]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out_var)
......@@ -229,7 +231,7 @@ def generate_layer_fn_noattr(op_type):
def func(x, name=None):
helper = LayerHelper(op_type, **locals())
output = helper.create_tmp_variable(dtype=x.dtype)
output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output
......
......@@ -58,11 +58,11 @@ def accuracy(input, label, k=1, correct=None, total=None):
"""
helper = LayerHelper("accuracy", **locals())
topk_out, topk_indices = nn.topk(input, k=k)
acc_out = helper.create_tmp_variable(dtype="float32")
acc_out = helper.create_variable_for_type_inference(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
correct = helper.create_variable_for_type_inference(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
total = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
......@@ -124,8 +124,8 @@ def auc(input,
auc_out=fluid.layers.auc(input=prediction, label=label)
"""
helper = LayerHelper("auc", **locals())
auc_out = helper.create_tmp_variable(dtype="float64")
batch_auc_out = helper.create_tmp_variable(dtype="float64")
auc_out = helper.create_variable_for_type_inference(dtype="float64")
batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc
......
此差异已折叠。
......@@ -152,7 +152,7 @@ def cast(x, dtype):
result = fluid.layers.cast(x=data, dtype='float64')
"""
helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=dtype)
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='cast',
inputs={'X': [x]},
......@@ -184,7 +184,7 @@ def concat(input, axis=0, name=None):
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
"""
helper = LayerHelper('concat', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='concat',
inputs={'X': input},
......@@ -221,7 +221,8 @@ def sums(input, out=None):
"""
helper = LayerHelper('sum', **locals())
if out is None:
out = helper.create_tmp_variable(dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op(
type='sum',
inputs={'X': input},
......@@ -252,7 +253,7 @@ def assign(input, output=None):
"""
helper = LayerHelper('assign', **locals())
if output is None:
output = helper.create_tmp_variable(dtype=input.dtype)
output = helper.create_variable_for_type_inference(dtype=input.dtype)
if isinstance(input, Variable):
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
......@@ -311,7 +312,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
helper = LayerHelper("fill_constant", **locals())
if out is None:
out = helper.create_tmp_variable(dtype=dtype)
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fill_constant',
inputs={},
......@@ -358,7 +359,7 @@ def fill_constant_batch_size_like(input,
${out_comment}.
"""
helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype)
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
......@@ -396,7 +397,7 @@ def argmin(x, axis=0):
out = fluid.layers.argmin(x=in, axis=-1)
"""
helper = LayerHelper("arg_min", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(
type='arg_min',
inputs={'X': x},
......@@ -427,7 +428,7 @@ def argmax(x, axis=0):
out = fluid.layers.argmax(x=in, axis=-1)
"""
helper = LayerHelper("arg_max", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(
type='arg_max',
inputs={'X': x},
......@@ -477,8 +478,10 @@ def argsort(input, axis=-1, name=None):
out, indices = fluid.layers.argsort(input, axis=0)
"""
helper = LayerHelper("argsort", **locals())
out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
out = helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True)
ids = helper.create_variable_for_type_inference(
VarDesc.VarType.INT64, stop_gradient=True)
helper.append_op(
type='argsort',
inputs={'X': input},
......@@ -562,7 +565,7 @@ def reverse(x, axis):
if isinstance(axis, int):
axis = [axis]
helper = LayerHelper("reverse", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='reverse',
inputs={'Input': x},
......@@ -654,7 +657,7 @@ def has_inf(x):
Variable: The tensor variable storing the output, only a bool value.
"""
helper = LayerHelper("isinf", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isinf", inputs={"X": x}, outputs={"Out": out})
return out
......@@ -670,7 +673,7 @@ def has_nan(x):
Variable: The tensor variable storing the output, only a bool value.
"""
helper = LayerHelper("isnan", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isnan", inputs={"X": x}, outputs={"Out": out})
return out
......@@ -687,6 +690,6 @@ def isfinite(x):
Variable: The tensor variable storing the output, contains a bool value.
"""
helper = LayerHelper("isfinite", **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out
......@@ -151,7 +151,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
......@@ -228,7 +228,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS)
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import random
from op_test import OpTest
from test_seq_conv import seqconv
class TestSeqConvEltAddRelu(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'fusion_seqconv_eltadd_relu'
self.lod = [[6, 4]]
self.in_fea_size = 16
self.out_fea_size = 8
self.context_length = 4
self.context_stride = 1
self.context_start = 0
self.set_conf()
assert self.context_stride == 1
T = sum(self.lod[0])
x = np.random.uniform(-1, 1, [T, self.in_fea_size]).astype('float32')
w = np.random.uniform(
-1, 1, [self.in_fea_size * self.context_length,
self.out_fea_size]).astype('float32')
b = np.random.uniform(-2, 1, [1, self.out_fea_size]).astype('float32')
out = seqconv(x, self.lod, w, self.context_length, self.context_start)
out = np.maximum(out + b, 0)
self.inputs = {'X': (x, self.lod), 'Filter': w, 'Bias': b}
self.attrs = {
'contextStart': self.context_start,
'contextLength': self.context_length,
'contextStride': self.context_stride
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10]]
class TestSeqConvEltAddReluBS1Case2(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[2]]
class TestSeqConvEltAddReluCase1(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[3, 5, 1, 6]]
self.context_length = 3
self.context_start = -2
class TestSeqConvEltAddReluCase2(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
self.in_fea_size = 2
self.context_length = 4
self.context_start = -1
class TestSeqConvEltAddReluCase3(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
self.context_length = 5
self.context_start = -4
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,53 @@ import random
from op_test import OpTest
def seqconv(x,
lod,
filter,
context_length,
context_start,
padding_trainable=False,
padding_data=None):
[T, M] = x.shape
col = np.zeros((T, context_length * M)).astype('float32')
offset = [0]
for seq_len in lod[0]:
offset.append(offset[-1] + seq_len)
begin_pad = np.max([0, -context_start])
for i in range(len(offset) - 1):
for j in range(context_length):
in_begin = offset[i] + context_start + j
in_end = offset[i + 1] + context_start + j
out_begin = offset[i]
out_end = offset[i + 1]
if in_begin < offset[i]:
pad_size = np.min(
[offset[i] - in_begin, offset[i + 1] - offset[i]])
if padding_trainable:
sub_w = padding_data[j:j + pad_size, :]
col[offset[i]:offset[i] + pad_size, j * M:(j + 1) *
M] = sub_w
out_begin = offset[i] + pad_size
in_begin = offset[i]
if in_end > offset[i + 1]:
pad_size = np.min(
[in_end - offset[i + 1], offset[i + 1] - offset[i]])
if padding_trainable:
sub_w = padding_data[begin_pad + context_start + j -
pad_size:begin_pad + context_start +
j, :]
col[offset[i + 1] - pad_size:offset[i + 1], j * M:(j + 1) *
M] = sub_w
in_end = offset[i + 1]
out_end = offset[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
col[out_begin:out_end, j * M:(j + 1) * M] += in_sub
return np.dot(col, filter)
class TestSeqProject(OpTest):
def setUp(self):
self.init_test_case()
......@@ -66,57 +113,9 @@ class TestSeqProject(OpTest):
'paddingTrainable': self.padding_trainable,
'contextStride': self.context_stride
}
out = np.zeros(
(self.input_size[0], self.output_represention)).astype('float32')
out = seqconv(x, self.lod, w, self.context_length, self.context_start,
self.padding_trainable, self.pad_data)
self.outputs = {'Out': out}
self.compute()
def compute(self):
x, lod = self.inputs['X']
filter = self.inputs['Filter']
pading_data = self.pad_data
out = np.zeros((self.input_size[0], self.context_length *
self.input_size[1])).astype('float32')
offset = [0]
for seq_len in lod[0]:
offset.append(offset[-1] + seq_len)
begin_pad = np.max([0, -self.context_start])
for i in range(len(offset) - 1):
for j in range(self.context_length):
in_begin = offset[i] + self.context_start + j
in_end = offset[i + 1] + self.context_start + j
out_begin = offset[i]
out_end = offset[i + 1]
if in_begin < offset[i]:
pad_size = np.min(
[offset[i] - in_begin, offset[i + 1] - offset[i]])
if self.padding_trainable:
sub_w = pading_data[j:j + pad_size, :]
out[offset[i]:offset[i] + pad_size, j * self.input_size[
1]:(j + 1) * self.input_size[1]] = sub_w
out_begin = offset[i] + pad_size
in_begin = offset[i]
if in_end > offset[i + 1]:
pad_size = np.min(
[in_end - offset[i + 1], offset[i + 1] - offset[i]])
if self.padding_trainable:
sub_w = pading_data[begin_pad + self.context_start + j -
pad_size:begin_pad +
self.context_start + j, :]
out[offset[i + 1] - pad_size:offset[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
in_end = offset[i + 1]
out_end = offset[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub
np.dot(out, filter, out=self.outputs['Out'])
def test_check_output(self):
self.check_output()
......
......@@ -30,7 +30,6 @@ class TestSliceVar(unittest.TestCase):
var = program.global_block().create_var(
name=str(random.randint(10000, 99999)),
persistable=True,
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
blocks = slice_variable(var_list, 10, min_size)
......
......@@ -74,7 +74,7 @@ class InferenceTranspiler(object):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
program. Elementwise add following convolution OP can be fused by adding
'fuse_eltwise' attribute to convolution OP and replacing its output
'fuse_residual_connection' attribute to convolution OP and replacing its output
Tensor with second parameter of elementwise_add.
The result of fuse is:
- before:
......@@ -92,7 +92,8 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self._fuse_conv_eltwise(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
......@@ -444,7 +445,7 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
......@@ -454,9 +455,30 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator
'''
conv_op._set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
eltwise_input = "X"
if eltwise_op.input("X")[0] == conv_op.output("Output")[0]:
eltwise_input = "Y"
residual_var = self.block.vars[eltwise_op.input(eltwise_input)[0]]
out_var = self.block.vars[eltwise_op.output("Out")[0]]
filter_var = self.block.vars[conv_op.input("Filter")[0]]
in_var = self.block.vars[conv_op.input("Input")[0]]
bias_var = self.block.vars[conv_op.input("Bias")[0]]
conv_op._set_attr("fuse_residual_connection", True)
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={
"Input": in_var,
"Filter": filter_var,
"Bias": bias_var,
"ResidualData": residual_var
},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self):
for i in range(len(self.block.ops)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册