提交 6447b69a 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-reshape-reuse-input

test=develop
...@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME) ...@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} mklml) add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed") target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif() endif()
# remove link to python, see notes at:
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
if("${cc_library_DEPS};" MATCHES "python;")
list(REMOVE_ITEM cc_library_DEPS python)
add_dependencies(${TARGET_NAME} python)
target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
endif() endif()
......
...@@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var ...@@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None)) paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None)) paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)) paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)) paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
......
...@@ -49,6 +49,8 @@ struct VarHandleBase { ...@@ -49,6 +49,8 @@ struct VarHandleBase {
void AddOutput(OpHandleBase* out, ir::Node* node) { void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) { if (pending_ops_.find(out) == pending_ops_.end()) {
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
this->Node()->Name());
pending_ops_.insert(out); pending_ops_.insert(out);
node_->outputs.push_back(node); node_->outputs.push_back(node);
} }
......
...@@ -37,12 +37,17 @@ pass_library(embedding_fc_lstm_fuse_pass inference) ...@@ -37,12 +37,17 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
pass_library(conv_bn_fuse_pass inference) pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base) pass_library(mkldnn_placement_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference) pass_library(conv_relu_mkldnn_fuse_pass inference)
endif() endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
if(WITH_MKLDNN)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
...@@ -56,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g ...@@ -56,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN) if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif () endif ()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
LoDTensor vec_y;
vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>();
const float* b = vec_b.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < vec_a.numel(); i++) {
y[i] = f(a[i], b[i]);
}
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input);
int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBias fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_bias_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern); // CONV op
// bias
GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
// output
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
// elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input));
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
VLOG(3) << "do not perform conv+bias fuse";
return;
}
auto* eltwise_bias_tensor =
scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
*conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
// take eltwise bias as conv bias
OpDesc desc;
desc.SetInput(
"Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
IR_NODE_LINK_TO(conv_weight, conv_bias_node);
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
*/
class ConvBiasFusePass : public FusePassBase {
public:
virtual ~ConvBiasFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <utility>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
} // namespace
using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_};
auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate();
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
auto bias_input_names = conv_op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
if (has_bias) {
auto conv_bias_names = bias_it->second;
auto conv_bias_names_it =
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
[&conv_bias_names](Node* n) -> bool {
return n->Name() == conv_bias_names[0];
});
return std::make_pair(has_bias, *conv_bias_names_it);
}
}
return std::make_pair(false, nullptr);
};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
constexpr int nodes_removed = 3;
constexpr int nodes_added = 1;
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::pair<std::string, std::string>>& inputs,
const std::pair<std::string, std::string>& output) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
op->SetOutput(output.first, {output.second});
}
struct IsReachable {
using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) {
return &node;
}
}
return nullptr;
};
return [&](std::string from, const std::string to) -> bool {
if (from == to) return true;
std::map<std::string, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
visited[node.Name()] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) return false;
for (auto n : cur->outputs) {
if (n->Name() == to) return true;
if (!visited[n->Name()]) {
visited[n->Name()] = true;
queue.push_back(n->Name());
}
}
}
return false;
};
}
};
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
int conv_count = 0;
int elementwise_add_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
++conv_count;
}
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
}
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
...@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()( ...@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var; return relu_out_var;
} }
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
seqconv_input->assert_is_op_input("sequence_conv", "X");
auto *seqconv_op = pattern->NewNode(seqconv_repr())
->assert_is_op("sequence_conv")
->assert_op_attr<bool>("paddingTrainable", false)
->assert_op_attr<int>("contextStride", 1);
auto *eltadd_op =
pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto *seqconv_weight_var =
pattern->NewNode(seqconv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("sequence_conv", "Filter");
// Bias
auto *eltadd_bias_var = pattern->NewNode(eltadd_bias_repr())
->AsInput()
->assert_is_op_input("elementwise_add");
// intermediate variable, will be removed in the IR after fuse.
auto *seqconv_out_var = pattern->NewNode(seqconv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("sequence_conv")
->assert_is_op_input("elementwise_add");
auto *eltadd_out_var = pattern->NewNode(eltadd_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("elementwise_add")
->assert_is_only_input_of_op("relu");
// output
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
seqconv_op->LinksFrom({seqconv_input, seqconv_weight_var})
.LinksTo({seqconv_out_var});
eltadd_op->LinksFrom({seqconv_out_var, eltadd_bias_var})
.LinksTo({eltadd_out_var});
relu_op->LinksFrom({eltadd_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) { bool with_bias) {
// Create shared nodes. // Create shared nodes.
...@@ -966,6 +1011,79 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( ...@@ -966,6 +1011,79 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad; return ele_add_grad;
} }
PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("elementwise_add", "Y");
// output
auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
eltiwse_op->LinksFrom({conv_out_var, eltwise_bias_var})
.LinksTo({eltwise_out_var});
return eltwise_out_var;
}
PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
x_var->assert_is_op_input("elementwise_add", "X");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var});
return out_var;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -128,6 +128,15 @@ struct PDNode { ...@@ -128,6 +128,15 @@ struct PDNode {
const std::unordered_set<std::string>& op_types, const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth); const std::string& argument, int nth);
template <typename T>
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
});
return this;
}
private: private:
PDNode(PDPattern* pattern, const std::string& name = "", PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar) Type type = Type::kVar)
...@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase { ...@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out); PATTERN_DECL_NODE(relu_out);
}; };
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct SeqConvEltAddRelu : public PatternBase {
SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {}
PDNode* operator()(PDNode* seqconv_input);
// declare operator node's name
PATTERN_DECL_NODE(seqconv);
PATTERN_DECL_NODE(eltadd);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(seqconv_weight);
PATTERN_DECL_NODE(seqconv_out);
PATTERN_DECL_NODE(eltadd_bias);
PATTERN_DECL_NODE(eltadd_out);
PATTERN_DECL_NODE(relu_out);
};
// FC with bias // FC with bias
// op: mul + elementwise_add // op: mul + elementwise_add
// named nodes: // named nodes:
...@@ -578,6 +612,65 @@ struct ElewiseAddActInplaceGrad : public PatternBase { ...@@ -578,6 +612,65 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
PATTERN_DECL_NODE(d_ele_y); PATTERN_DECL_NODE(d_ele_y);
PATTERN_DECL_NODE(ele_y); PATTERN_DECL_NODE(ele_y);
}; };
// Conv with Elementwise_add as bias
// op: conv + elementwise_add
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// eltwise_bias, eltwise_out,
// elementwise_add
struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(eltwise_bias);
PATTERN_DECL_NODE(eltwise_out);
};
// Convolution op
// Forward pass for convolution.
// conv_input, conv_bias and conv_filter are inputs.
// conv_output is a result of the operator.
// residual_data is data used by skip connection.
// If residual connection fusion is on, the formula is:
// conv_output = conv_op(conv_filter, conv_input, conv_bias)
// + conv_residual_data
// If the fusion is off, conv_residual_data is not added.
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);
};
// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
// connection fusion is on.
struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
->assert_is_op_input("sequence_conv")
->assert_var_not_persistable();
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
fuse_pattern(x);
// Create New OpDesc
auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
Node* eltadd_bias, Node* relu_out) {
OpDesc op_desc;
op_desc.SetType("fusion_seqconv_eltadd_relu");
op_desc.SetInput("X", {input->Name()});
op_desc.SetInput("Filter", {seqconv_weight->Name()});
op_desc.SetInput("Bias", {eltadd_bias->Name()});
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()});
scope->Var(ColMat)->GetMutable<LoDTensor>();
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(seqconv_weight, op);
IR_NODE_LINK_TO(eltadd_bias, op);
IR_NODE_LINK_TO(op, relu_out);
return op;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle SeqConv EltAdd Relu fuse";
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);
fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
relu_out);
std::unordered_set<const Node*> marked_nodes(
{seqconv, seqconv_out, eltadd, eltadd_out, relu});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
paddle::framework::ir::SeqConvEltAddReluFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConvEltAddReluFusePass : public FusePassBase {
public:
virtual ~SeqConvEltAddReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
const auto dev_ctxs =
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
for (auto &dev_ctx : dev_ctxs) {
dev_ctx->Wait();
}
if (member_->own_local_scope_) { if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
Scope *local_scope = member_->local_scopes_[i]; Scope *local_scope = member_->local_scopes_[i];
......
...@@ -101,10 +101,12 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } ...@@ -101,10 +101,12 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes; std::vector<std::string> passes;
#ifdef PADDLE_WITH_MKLDNN
if (use_mkldnn_) { if (use_mkldnn_) {
VLOG(3) << "Adding MKL-DNN placement pass"; VLOG(3) << "Adding MKL-DNN placement pass";
passes.push_back("mkldnn_placement_pass"); passes.push_back("mkldnn_placement_pass");
} }
#endif
for (auto& pass : ir_passes_) { for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) { if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass); passes.push_back(pass);
......
...@@ -69,6 +69,7 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -69,6 +69,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
// Manual update the passes here. // Manual update the passes here.
"infer_clean_graph_pass", // "infer_clean_graph_pass", //
"attention_lstm_fuse_pass", // "attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", // "embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", // "fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", // "mul_lstm_fuse_pass", //
...@@ -79,7 +80,9 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -79,7 +80,9 @@ class Analyzer : public OrderedRegistry<PassManager> {
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", // "conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", //
#endif #endif
}}; }};
......
...@@ -77,10 +77,6 @@ bool AnalysisPredictor::Init( ...@@ -77,10 +77,6 @@ bool AnalysisPredictor::Init(
inference_program_ = program; inference_program_ = program;
} }
if (config_._use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
executor_->Prepare(scope_.get(), *inference_program_, 0, executor_->Prepare(scope_.get(), *inference_program_, 0,
config_.use_feed_fetch_ops); config_.use_feed_fetch_ops);
......
...@@ -18,12 +18,12 @@ namespace paddle { ...@@ -18,12 +18,12 @@ namespace paddle {
namespace inference { namespace inference {
using namespace framework; // NOLINT using namespace framework; // NOLINT
static std::vector<float> result_data;
struct DataRecord { struct DataRecord {
std::vector<std::vector<std::vector<float>>> link_step_data_all; std::vector<std::vector<std::vector<float>>> link_step_data_all;
std::vector<size_t> lod; std::vector<size_t> lod;
std::vector<std::vector<float>> rnn_link_data; std::vector<std::vector<float>> rnn_link_data;
std::vector<float> result_data;
size_t num_samples; // total number of samples size_t num_samples; // total number of samples
size_t batch_iter{0}; size_t batch_iter{0};
size_t batch_size{1}; size_t batch_size{1};
...@@ -57,6 +57,7 @@ struct DataRecord { ...@@ -57,6 +57,7 @@ struct DataRecord {
std::ifstream file(path); std::ifstream file(path);
std::string line; std::string line;
int num_lines = 0; int num_lines = 0;
result_data.clear();
while (std::getline(file, line)) { while (std::getline(file, line)) {
num_lines++; num_lines++;
std::vector<std::string> data; std::vector<std::string> data;
...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) { ...@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result // the first inference result
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PADDLE_ENFORCE_GT(outputs.size(), 0); PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]); size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0); PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data()); float *result = static_cast<float *>(outputs[0].data.data());
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], data.result_data[i], 1e-3); EXPECT_NEAR(result[i], result_data[i], 1e-3);
} }
} }
} }
......
...@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) { ...@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
SetConfig(&cfg); SetConfig(&cfg);
int num_ops; int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
GetFuseStatis(predictor.get(), &num_ops);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 32);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
...@@ -86,7 +86,7 @@ function(op_library TARGET) ...@@ -86,7 +86,7 @@ function(op_library TARGET)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops. # remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op" foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op" "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op") "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}") if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return() return()
endif() endif()
......
...@@ -300,10 +300,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -300,10 +300,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise"); bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation // TODO(tpatejko): add support for dilation
PADDLE_ENFORCE( PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet"); "dilation in convolution is not implemented yet");
...@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
fuse_relu, fuse_eltwise); fuse_relu, fuse_residual_conn);
} else { } else {
conv_pd = conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_eltwise); mkldnn_engine, fuse_relu, fuse_residual_conn);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -386,8 +386,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -386,8 +386,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
T* output_data = T* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize()); output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
...@@ -424,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -424,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private: private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_eltwise) const { bool fuse_residual_conn) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the // the scale parameter. It is assumed that when fuse_residual_connection is
// Output tensor contains the data coming from residual connection. The // true, the output tensor contains the data coming from residual
// result of this post_op is: Output = scale * Output + Conv_Out. // connection. The result of this post_op is:
if (fuse_eltwise) { // Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(1.0f); post_operations.append_sum(1.0f);
} }
// Fusion with ReLU layer is executed through the PostOps feature. Create a // Fusion with ReLU layer is executed through the PostOps feature. Create a
...@@ -452,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -452,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const { const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -461,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -461,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise); mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -476,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -476,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const { const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -485,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -485,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims, bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise); mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
......
...@@ -132,6 +132,11 @@ void Conv2DOpMaker::Make() { ...@@ -132,6 +132,11 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.") "The format of output tensor is also NCHW.")
.Reuse("Input"); .Reuse("Input");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the " "(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of " "strides(h_stride, w_stride) of "
...@@ -164,10 +169,10 @@ void Conv2DOpMaker::Make() { ...@@ -164,10 +169,10 @@ void Conv2DOpMaker::Make() {
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel") AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_eltwise", AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used " "(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is connected via skip connection " "whenever convolution output is as an input to residual "
"to a previous layer.") "connection.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
......
...@@ -20,7 +20,7 @@ detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu) ...@@ -20,7 +20,7 @@ detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op.cu) iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc) detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(anchor_generator_op SRCS anchor_generator_op.cc detection_library(anchor_generator_op SRCS anchor_generator_op.cc
anchor_generator_op.cu) anchor_generator_op.cu)
......
...@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cmath>
#include <cstring>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -25,21 +27,17 @@ namespace operators { ...@@ -25,21 +27,17 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
struct AppendProposalsFunctor { static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
LoDTensor *out_;
int64_t offset_;
Tensor *to_add_;
AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add) static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
: out_(out), offset_(offset), to_add_(to_add) {} auto *out_data = dst->data<void>();
auto *to_add_data = src.data<void>();
template <typename T> size_t size_of_t = framework::SizeOfType(src.type());
void apply() const { offset *= size_of_t;
auto *out_data = out_->data<T>(); std::memcpy(
auto *to_add_data = to_add_->data<T>(); reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T)); to_add_data, src.numel() * size_of_t);
} }
};
class GenerateProposalsOp : public framework::OperatorWithKernel { class GenerateProposalsOp : public framework::OperatorWithKernel {
public: public:
...@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
}; };
template <class T> template <class T>
void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, static inline void BoxCoder(const platform::DeviceContext &ctx,
Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) { Tensor *all_anchors, Tensor *bbox_deltas,
Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace()); T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0]; int64_t row = all_anchors->dims()[0];
...@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, ...@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
anchor_center_y; anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] * bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2], bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) * kBBoxClipDefault)) *
anchor_width; anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] * bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3], bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) * kBBoxClipDefault)) *
anchor_height; anchor_height;
} else { } else {
bbox_center_x = bbox_center_x =
...@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, ...@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y = bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2], bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) * kBBoxClipDefault)) *
anchor_width; anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3], bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) * kBBoxClipDefault)) *
anchor_height; anchor_height;
} }
...@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors, ...@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
} }
template <class T> template <class T>
void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info, static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
Tensor *boxes) { const Tensor &im_info, Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace()); T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>(); const T *im_info_data = im_info.data<T>();
T zero(0);
for (int64_t i = 0; i < boxes->numel(); ++i) { for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) { if (i % 4 == 0) {
boxes_data[i] = boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f); std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else if (i % 4 == 1) { } else if (i % 4 == 1) {
boxes_data[i] = boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f); std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} else if (i % 4 == 2) { } else if (i % 4 == 2) {
boxes_data[i] = boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f); std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else { } else {
boxes_data[i] = boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f); std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} }
} }
} }
template <class T> template <class T>
void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, static inline void FilterBoxes(const platform::DeviceContext &ctx,
float min_size, const Tensor &im_info, Tensor *keep) { Tensor *boxes, float min_size,
const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>(); const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace()); T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
T im_scale = im_info_data[2]; T im_scale = im_info_data[2];
...@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, ...@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
keep->Resize({keep_len}); keep->Resize({keep_len});
} }
bool SortScorePairDescend(const std::pair<float, int> &pair1,
const std::pair<float, int> &pair2) {
return pair1.first > pair2.first;
}
template <class T> template <class T>
void GetMaxScoreIndex(const std::vector<T> &scores, static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
std::vector<std::pair<T, int>> *sorted_indices) { const std::vector<T> &scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) { for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices->push_back(std::make_pair(scores[i], i)); sorted_indices.emplace_back(scores[i], i);
} }
// Sort the score pair according to the scores in descending order // Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(), std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
SortScorePairDescend); [](const std::pair<T, int> &a, const std::pair<T, int> &b) {
return a.first < b.first;
});
return sorted_indices;
} }
template <class T> template <class T>
T BBoxArea(const T *box, const bool normalized) { static inline T BBoxArea(const T *box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) { if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid // If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0. // (e.g. xmax < xmin or ymax < ymin), return 0.
...@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) { ...@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) {
} }
template <class T> template <class T>
T JaccardOverlap(const T *box1, const T *box2, const bool normalized) { static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) { box2[3] < box1[1]) {
return static_cast<T>(0.); return static_cast<T>(0.);
...@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) { ...@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]); const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]); const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]); const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1); const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1); const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h; const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized); const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized); const T bbox2_area = BBoxArea<T>(box2, normalized);
...@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) { ...@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
} }
} }
template <typename T>
static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
int selected_num) {
Tensor keep_nms;
keep_nms.Resize({selected_num});
auto *keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T> template <class T>
Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores, static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
const T nms_threshold, const float eta) { Tensor *scores, T nms_threshold, float eta) {
PADDLE_ENFORCE_NOT_NULL(bbox); PADDLE_ENFORCE_NOT_NULL(bbox);
int64_t num_boxes = bbox->dims()[0]; int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax] // 4: [xmin ymin xmax ymax]
...@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores, ...@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
std::vector<T> scores_data(num_boxes); std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin()); std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices; std::vector<std::pair<T, int>> sorted_indices =
GetMaxScoreIndex<T>(scores_data, &sorted_indices); GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices; std::vector<int> selected_indices;
int selected_num = 0; int selected_num = 0;
T adaptive_threshold = nms_threshold; T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>(); const T *bbox_data = bbox->data<T>();
bool flag;
while (sorted_indices.size() != 0) { while (sorted_indices.size() != 0) {
int idx = sorted_indices.front().second; int idx = sorted_indices.back().second;
flag = true; bool flag = true;
for (size_t k = 0; k < selected_indices.size(); ++k) { for (int kept_idx : selected_indices) {
if (flag) { if (flag) {
const int kept_idx = selected_indices[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size, T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false); bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold); flag = (overlap <= adaptive_threshold);
...@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores, ...@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
} }
if (flag) { if (flag) {
selected_indices.push_back(idx); selected_indices.push_back(idx);
selected_num++; ++selected_num;
} }
sorted_indices.erase(sorted_indices.begin()); sorted_indices.erase(sorted_indices.end());
if (flag && eta < 1 && adaptive_threshold > 0.5) { if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta; adaptive_threshold *= eta;
} }
} }
Tensor keep_nms; return VectorToTensor(selected_indices, selected_num);
keep_nms.Resize({selected_num});
int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
} }
template <typename DeviceContext, typename T> template <typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> { class GenerateProposalsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors"); auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
auto *variances = context.Input<Tensor>("Variances"); "Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
...@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
float min_size = context.Attr<float>("min_size"); float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta"); float eta = context.Attr<float>("eta");
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto scores_dim = scores->dims(); auto &scores_dim = scores->dims();
int64_t num = scores_dim[0]; int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1]; int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2]; int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3]; int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims(); auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1]; int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2]; int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3]; int64_t w_bbox = bbox_dim[3];
...@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_swap.mutable_data<T>({num, h_score, w_score, c_score}, scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace()); dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans; math::Transpose<platform::CPUDeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1}; std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis); trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod; framework::LoD lod;
std::vector<size_t> lod0(1, 0); lod.resize(1);
Tensor *anchor = const_cast<framework::Tensor *>(anchors); auto &lod0 = lod[0];
anchor->Resize({anchors->numel() / 4, 4}); lod0.push_back(0);
Tensor *var = const_cast<framework::Tensor *>(variances); anchors.Resize({anchors.numel() / 4, 4});
var->Resize({var->numel() / 4, 4}); variances.Resize({variances.numel() / 4, 4});
int64_t num_proposals = 0; int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) { for (int64_t i = 0; i < num; ++i) {
...@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1}); scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair = std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var, ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n, bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta); post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = tensor_pair.first; Tensor &proposals = tensor_pair.first;
Tensor scores = tensor_pair.second; Tensor &scores = tensor_pair.second;
framework::VisitDataType(
framework::ToDataType(rpn_rois->type()),
AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
framework::VisitDataType(
framework::ToDataType(rpn_roi_probs->type()),
AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
AppendProposals(rpn_rois, 4 * num_proposals, proposals);
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0]; num_proposals += proposals.dims()[0];
lod0.emplace_back(num_proposals); lod0.push_back(num_proposals);
} }
lod.emplace_back(lod0);
rpn_rois->set_lod(lod); rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod); rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4}); rpn_rois->Resize({num_proposals, 4});
...@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
} }
std::pair<Tensor, Tensor> ProposalForOneImage( std::pair<Tensor, Tensor> ProposalForOneImage(
const DeviceContext &ctx, const Tensor &im_info_slice, const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice,
const Tensor &anchors, const Tensor &variances, const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4] const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1] const Tensor &scores_slice, // [N, 1]
...@@ -392,8 +394,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -392,8 +394,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
for (int i = 0; i < scores_slice.numel(); ++i) { for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i; index[i] = i;
} }
std::function<bool(const int64_t &, const int64_t &)> compare = auto compare = [scores_data](const int64_t &i, const int64_t &j) {
[scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j]; return scores_data[i] > scores_data[j];
}; };
...@@ -452,33 +453,45 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -452,33 +453,45 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker { class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Scores", "The scores of anchors should be foreground."); AddInput("Scores",
AddInput("BboxDeltas", "bbox_deltas."); "(Tensor) The scores from conv is in shape (N, A, H, W), "
AddInput("ImInfo", "Information for image reshape."); "N is batch size, A is number of anchors, "
AddInput("Anchors", "All anchors."); "H and W are height and width of the feature map");
AddInput("Variances", " variances"); AddInput("BboxDeltas",
"(Tensor) Bounding box deltas from conv is in "
AddOutput("RpnRois", "Anchors."); "shape (N, 4*A, H, W).");
AddOutput("RpnRoiProbs", "Anchors."); AddInput("ImInfo",
AddAttr<int>("pre_nms_topN", "pre_nms_topN"); "(Tensor) Information for image reshape is in shape (N, 3), "
AddAttr<int>("post_nms_topN", "post_nms_topN"); "in format (height, width, scale)");
AddAttr<float>("nms_thresh", "nms_thres"); AddInput("Anchors",
AddAttr<float>("min_size", "min size"); "(Tensor) Bounding box anchors from anchor_generator_op "
"is in shape (A, H, W, 4).");
AddInput("Variances",
"(Tensor) Bounding box variances with same shape as `Anchors`.");
AddOutput("RpnRois",
"(LoDTensor), Output proposals with shape (rois_num, 4).");
AddOutput("RpnRoiProbs",
"(LoDTensor) Scores of proposals with shape (rois_num, 1).");
AddAttr<int>("pre_nms_topN",
"Number of top scoring RPN proposals to keep before "
"applying NMS.");
AddAttr<int>("post_nms_topN",
"Number of top scoring RPN proposals to keep after "
"applying NMS");
AddAttr<float>("nms_thresh", "NMS threshold used on RPN proposals.");
AddAttr<float>("min_size",
"Proposal height and width both need to be greater "
"than this min_size.");
AddAttr<float>("eta", "The parameter for adaptive NMS."); AddAttr<float>("eta", "The parameter for adaptive NMS.");
AddComment(R"DOC( AddComment(R"DOC(
Generate Proposals OP This operator Generate bounding box proposals for Faster RCNN.
The propoasls are generated for a list of images based on image
This operator proposes rois according to each box with their probability to be a foreground object and score 'Scores', bounding box regression result 'BboxDeltas' as
the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals well as predefined bounding box shapes 'anchors'. Greedy
could be used to train detection net. non-maximum suppression is applied to generate the final bounding
boxes.
Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
of anchors, H and W are height and width of the feature map.
BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and
calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area.
Finally, apply nms to get final proposals as output.
)DOC"); )DOC");
} }
}; };
...@@ -490,6 +503,5 @@ namespace ops = paddle::operators; ...@@ -490,6 +503,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp, REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
ops::GenerateProposalsOpMaker, ops::GenerateProposalsOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
generate_proposals, ops::GenerateProposalsKernel<double>);
ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -16,10 +16,13 @@ limitations under the License. */ ...@@ -16,10 +16,13 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "cub/cub.cuh" #include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h" #include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,36 +39,38 @@ namespace { ...@@ -36,36 +39,38 @@ namespace {
int const kThreadsPerBlock = sizeof(uint64_t) * 8; int const kThreadsPerBlock = sizeof(uint64_t) * 8;
template <typename T> static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
__global__ void RangeInitKernel(const T start, const T delta, const int size,
T *out) { struct RangeInitFunctor {
CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; } int start_;
} int delta_;
int *out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T> template <typename T>
void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value, static void SortDescending(const platform::CUDADeviceContext &ctx,
Tensor *value_out, Tensor *index_out) { const Tensor &value, Tensor *value_out,
int num = value.numel(); Tensor *index_out) {
int num = static_cast<int>(value.numel());
Tensor index_in_t; Tensor index_in_t;
int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace()); int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
int block = 512; platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
auto stream = ctx.stream(); for_range(RangeInitFunctor{0, 1, idx_in});
RangeInitKernel<<<DIVUP(num, block), block, 0, stream>>>(0, 1, num, idx_in);
int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace()); int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
const T *keys_in = value.data<T>(); const T *keys_in = value.data<T>();
T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace()); T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
// Determine temporary device storage requirements // Determine temporary device storage requirements
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
num);
// Allocate temporary storage // Allocate temporary storage
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
d_temp_storage = memory::Alloc(place, temp_storage_bytes); void *d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation // Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>( cub::DeviceRadixSort::SortPairsDescending<T, int>(
...@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value, ...@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
} }
template <typename T> template <typename T>
__device__ __forceinline__ T Min(T x, T y) { struct BoxDecodeAndClipFunctor {
return x < y ? x : y; const T *anchor;
} const T *deltas;
const T *var;
template <typename T> const int *index;
__device__ __forceinline__ T Max(T x, T y) { const T *im_info;
return x > y ? x : y;
} T *proposals;
template <typename T> BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
__global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas, const int *index, const T *im_info, T *proposals)
const T *var, const int *index, : anchor(anchor),
const T *im_info, const int num, deltas(deltas),
T *proposals) { var(var),
T kBBoxClipDefault = log(1000.0 / 16.0); index(index),
CUDA_1D_KERNEL_LOOP(i, num) { im_info(im_info),
proposals(proposals) {}
T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
__device__ void operator()(size_t i) {
int k = index[i] * 4; int k = index[i] * 4;
T axmin = anchor[k]; T axmin = anchor[k];
T aymin = anchor[k + 1]; T aymin = anchor[k + 1];
...@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas, ...@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T dxmax = deltas[k + 2]; T dxmax = deltas[k + 2];
T dymax = deltas[k + 3]; T dymax = deltas[k + 3];
T d_cx = 0., d_cy = 0., d_w = 0., d_h = 0.; T d_cx, d_cy, d_w, d_h;
if (var) { if (var) {
d_cx = cx + dxmin * w * var[k]; d_cx = cx + dxmin * w * var[k];
d_cy = cy + dymin * h * var[k + 1]; d_cy = cy + dymin * h * var[k + 1];
d_w = exp(Min<T>(dxmax * var[k + 2], kBBoxClipDefault)) * w; d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
d_h = exp(Min<T>(dymax * var[k + 3], kBBoxClipDefault)) * h; d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
} else { } else {
d_cx = cx + dxmin * w; d_cx = cx + dxmin * w;
d_cy = cy + dymin * h; d_cy = cy + dymin * h;
d_w = exp(Min<T>(dxmax, kBBoxClipDefault)) * w; d_w = exp(Min(dxmax, bbox_clip_default)) * w;
d_h = exp(Min<T>(dymax, kBBoxClipDefault)) * h; d_h = exp(Min(dymax, bbox_clip_default)) * h;
} }
T oxmin = d_cx - d_w * 0.5; T oxmin = d_cx - d_w * 0.5;
...@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas, ...@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T oxmax = d_cx + d_w * 0.5 - 1.; T oxmax = d_cx + d_w * 0.5 - 1.;
T oymax = d_cy + d_h * 0.5 - 1.; T oymax = d_cy + d_h * 0.5 - 1.;
proposals[i * 4] = Max<T>(Min<T>(oxmin, im_info[1] - 1.), 0.); proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max<T>(Min<T>(oymin, im_info[0] - 1.), 0.); proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max<T>(Min<T>(oxmax, im_info[1] - 1.), 0.); proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max<T>(Min<T>(oymax, im_info[0] - 1.), 0.); proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
} }
}
__device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
__device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};
template <typename T, int BlockSize> template <typename T, int BlockSize>
__global__ void FilterBBoxes(const T *bboxes, const T *im_info, static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num, int *keep_num, const T min_size, const int num,
int *keep) { int *keep_num, int *keep) {
T im_h = im_info[0]; T im_h = im_info[0];
T im_w = im_info[1]; T im_w = im_info[1];
T im_scale = im_info[2]; T im_scale = im_info[2];
...@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info, ...@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
} }
} }
__device__ inline float IoU(const float *a, const float *b) { static __device__ inline float IoU(const float *a, const float *b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]); float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]); float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
...@@ -191,7 +205,8 @@ __device__ inline float IoU(const float *a, const float *b) { ...@@ -191,7 +205,8 @@ __device__ inline float IoU(const float *a, const float *b) {
return inter_s / (s_a + s_b - inter_s); return inter_s / (s_a + s_b - inter_s);
} }
__global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh, static __global__ void NMSKernel(const int n_boxes,
const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) { const float *dev_boxes, uint64_t *dev_mask) {
const int row_start = blockIdx.y; const int row_start = blockIdx.y;
const int col_start = blockIdx.x; const int col_start = blockIdx.x;
...@@ -234,7 +249,7 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh, ...@@ -234,7 +249,7 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
} }
template <typename T> template <typename T>
void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals, static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold, const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) { Tensor *keep_out) {
int boxes_num = proposals.dims()[0]; int boxes_num = proposals.dims()[0];
...@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals, ...@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const T *boxes = proposals.data<T>(); const T *boxes = proposals.data<T>();
auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace()); auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
int size_bytes = boxes_num * col_blocks * sizeof(uint64_t); framework::Vector<uint64_t> mask(boxes_num * col_blocks);
uint64_t *d_mask = NMSKernel<<<blocks, threads>>>(
reinterpret_cast<uint64_t *>(memory::Alloc(place, size_bytes)); boxes_num, nms_threshold, boxes,
NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes, d_mask); mask.CUDAMutableData(boost::get<platform::CUDAPlace>(ctx.GetPlace())));
uint64_t *h_mask = reinterpret_cast<uint64_t *>(
memory::Alloc(platform::CPUPlace(), size_bytes));
memory::Copy(platform::CPUPlace(), h_mask, place, d_mask, size_bytes, 0);
std::vector<uint64_t> remv(col_blocks); std::vector<uint64_t> remv(col_blocks);
memset(&remv[0], 0, sizeof(uint64_t) * col_blocks); memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
...@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals, ...@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
if (!(remv[nblock] & (1ULL << inblock))) { if (!(remv[nblock] & (1ULL << inblock))) {
++num_to_keep; ++num_to_keep;
keep_vec.push_back(i); keep_vec.push_back(i);
uint64_t *p = &h_mask[0] + i * col_blocks; uint64_t *p = &mask[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) { for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j]; remv[j] |= p[j];
} }
...@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals, ...@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace()); int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(), memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
sizeof(int) * num_to_keep, 0); sizeof(int) * num_to_keep, 0);
memory::Free(place, d_mask);
memory::Free(platform::CPUPlace(), h_mask);
} }
template <typename T> template <typename T>
std::pair<Tensor, Tensor> ProposalForOneImage( static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_info, const platform::CUDADeviceContext &ctx, const Tensor &im_info,
const Tensor &anchors, const Tensor &variances, const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4] const Tensor &bbox_deltas, // [M, 4]
...@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
// 2. box decode and clipping // 2. box decode and clipping
Tensor proposals; Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace()); proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
int block = 512;
auto stream = ctx.stream(); {
BoxDecodeAndClipKernel<T><<<DIVUP(pre_nms_num, block), block, 0, stream>>>( platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
for_range(BoxDecodeAndClipFunctor<T>{
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(), anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_info.data<T>(), pre_nms_num, index_sort.data<int>(), im_info.data<T>(), proposals.data<T>()});
proposals.data<T>()); }
// 3. filter // 3. filter
Tensor keep_index, keep_num_t; Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace()); keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, ctx.GetPlace()); keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
min_size = std::max(min_size, 1.0f); min_size = std::max(min_size, 1.0f);
auto stream = ctx.stream();
FilterBBoxes<T, 512><<<1, 512, 0, stream>>>( FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
proposals.data<T>(), im_info.data<T>(), min_size, pre_nms_num, proposals.data<T>(), im_info.data<T>(), min_size, pre_nms_num,
keep_num_t.data<int>(), keep_index.data<int>()); keep_num_t.data<int>(), keep_index.data<int>());
...@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors"); auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
auto *variances = context.Input<Tensor>("Variances"); "Cannot find input Anchors(%s) in scope",
context.Inputs("Anchors")[0]);
auto variances = detail::Ref(context.Input<Tensor>("Variances"),
"Cannot find input Variances(%s) in scope",
context.Inputs("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
...@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis); trans(dev_ctx, *scores, &scores_swap, axis);
Tensor *anchor = const_cast<framework::Tensor *>(anchors); anchors.Resize({anchors.numel() / 4, 4});
anchor->Resize({anchors->numel() / 4, 4}); variances.Resize({variances.numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4}, rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace()); context.GetPlace());
...@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice.Resize({h_score * w_score * c_score, 1}); scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair = std::pair<Tensor, Tensor> box_score_pair =
ProposalForOneImage<T>(dev_ctx, im_info_slice, *anchor, *var, ProposalForOneImage<T>(dev_ctx, im_info_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n, bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta); post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = box_score_pair.first; Tensor &proposals = box_score_pair.first;
Tensor scores = box_score_pair.second; Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place, memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel(), 0); proposals.data<T>(), sizeof(T) * proposals.numel(), 0);
......
此差异已折叠。
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/***************************************************************************
*
* Copyright (c) 2015 Baidu.com, Inc. All Rights Reserved
*
**************************************************************************/
/**
* @file include/gpc.h
* @author huhan02(com@baidu.com)
* @date 2015/12/18 13:52:10
* @brief
*
* @modified by sunyipeng
* @email sunyipeng@baidu.com
* @date 2018/6/12
**/
#ifndef PADDLE_FLUID_OPERATORS_DETECTION_GPC_H_ // GPC_H_
#define PADDLE_FLUID_OPERATORS_DETECTION_GPC_H_ // GPC_H_
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
namespace gpc {
typedef enum { // Set operation type
GPC_DIFF, // Difference
GPC_INT, // Intersection
GPC_XOR, // Exclusive or
GPC_UNION // Union
} gpc_op;
typedef struct { // Polygon vertex structure
double x; // Vertex x component
double y; // vertex y component
} gpc_vertex;
typedef struct { // Vertex list structure
int num_vertices; // Number of vertices in list
gpc_vertex *vertex; // Vertex array pointer
} gpc_vertex_list;
typedef struct { // Polygon set structure
int num_contours; // Number of contours in polygon
int *hole; // Hole external contour flags
gpc_vertex_list *contour; // Contour array pointer
} gpc_polygon;
typedef struct { // Tristrip set structure
int num_strips; // Number of tristrips
gpc_vertex_list *strip; // Tristrip array pointer
} gpc_tristrip;
typedef enum { LEFT, RIGHT } gpc_left_right;
typedef enum { ABOVE, BELOW } gpc_above_below;
typedef enum { CLIP, SUBJ } gpc_clip_subj;
typedef enum { /* Edge intersection classes */
NUL, /* Empty non-intersection */
EMX, /* External maximum */
ELI, /* External left intermediate */
TED, /* Top edge */
ERI, /* External right intermediate */
RED, /* Right edge */
IMM, /* Internal maximum and minimum */
IMN, /* Internal minimum */
EMN, /* External minimum */
EMM, /* External maximum and minimum */
LED, /* Left edge */
ILI, /* Internal left intermediate */
BED, /* Bottom edge */
IRI, /* Internal right intermediate */
IMX, /* Internal maximum */
FUL /* Full non-intersection */
} vertex_type;
typedef enum { /* Horizontal edge states */
NH, /* No horizontal edge */
BH, /* Bottom horizontal edge */
TH /* Top horizontal edge */
} h_state;
typedef enum { /* Edge bundle state */
UNBUNDLED, /* Isolated edge not within a bundle */
BUNDLE_HEAD, /* Bundle head node */
BUNDLE_TAIL /* Passive bundle tail node */
} bundle_state;
typedef struct v_shape { /* Internal vertex list datatype */
double x; /* X coordinate component */
double y; /* Y coordinate component */
struct v_shape *next; /* Pointer to next vertex in list */
} vertex_node;
typedef struct p_shape { /* Internal contour / tristrip type */
int active; /* Active flag / vertex count */
int hole; /* Hole / external contour flag */
vertex_node *v[2]; /* Left and right vertex list ptrs */
struct p_shape *next; /* Pointer to next polygon contour */
struct p_shape *proxy; /* Pointer to actual structure used */
} polygon_node;
typedef struct edge_shape {
gpc_vertex vertex; /* Piggy-backed contour vertex data */
gpc_vertex bot; /* Edge lower (x, y) coordinate */
gpc_vertex top; /* Edge upper (x, y) coordinate */
double xb; /* Scanbeam bottom x coordinate */
double xt; /* Scanbeam top x coordinate */
double dx; /* Change in x for a unit y increase */
int type; /* Clip / subject edge flag */
int bundle[2][2]; /* Bundle edge flags */
int bside[2]; /* Bundle left / right indicators */
bundle_state bstate[2]; /* Edge bundle state */
polygon_node *outp[2]; /* Output polygon / tristrip pointer */
struct edge_shape *prev; /* Previous edge in the AET */
struct edge_shape *next; /* Next edge in the AET */
struct edge_shape *pred; /* Edge connected at the lower end */
struct edge_shape *succ; /* Edge connected at the upper end */
struct edge_shape *next_bound; /* Pointer to next bound in LMT */
} edge_node;
inline bool gpc_eq(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline bool gpc_prev_index(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline int gpc_prev_index(int i, int n) { return ((i - 1 + n) % n); }
inline int gpc_next_index(int i, int n) { return ((i + 1) % n); }
inline int gpc_optimal(gpc_vertex *v, int i, int n) {
return (v[(i + 1) % n].y != v[i].y || v[(i - 1 + n) % n].y != v[i].y);
}
inline int gpc_fwd_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y >= v[i].vertex.y);
}
inline int gpc_not_fmax(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_rev_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y >= v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_not_rmax(edge_node *v, int i, int n) {
return (v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
// inline void gpc_p_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_p_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->prev;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
// inline void gpc_n_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_n_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->next;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
template <typename T>
void gpc_malloc(T *&p, int b, char *s) {
if (b > 0) {
p = (T *)malloc(b);
if (!p) {
fprintf(stderr, "gpc malloc failure: %s\n", s);
exit(0);
}
} else {
p = NULL;
}
}
template <typename T>
void gpc_free(T *&p) {
if (p) {
free(p);
p = NULL;
}
}
/*
===========================================================================
Public Function Prototypes
===========================================================================
*/
void add_vertex(vertex_node **t, double x, double y);
void gpc_vertex_create(edge_node *e, int p, int s, double x, double y);
/*
void gpc_read_polygon(FILE *infile_ptr, int read_hole_flags,
gpc_polygon *polygon);
void gpc_write_polygon(FILE *outfile_ptr, int write_hole_flags,
gpc_polygon *polygon);
*/
void gpc_add_contour(gpc_polygon *polygon, gpc_vertex_list *contour, int hole);
void gpc_polygon_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon, gpc_polygon *result_polygon);
void gpc_tristrip_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon,
gpc_tristrip *result_tristrip);
void gpc_polygon_to_tristrip(gpc_polygon *polygon, gpc_tristrip *tristrip);
void gpc_free_polygon(gpc_polygon *polygon);
void gpc_free_tristrip(gpc_tristrip *tristrip);
} // namespace gpc
#endif // PADDLE_FLUID_OPERATORS_DETECTION_GPC_H_
/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */
...@@ -9,10 +9,11 @@ http://www.apache.org/licenses/LICENSE-2.0 ...@@ -9,10 +9,11 @@ http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/poly_util.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,9 +21,6 @@ namespace operators { ...@@ -20,9 +21,6 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
constexpr int64_t kOutputDim = 6;
constexpr int64_t kBBoxSize = 4;
class MultiClassNMSOp : public framework::OperatorWithKernel { class MultiClassNMSOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -42,10 +40,15 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -42,10 +40,15 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
"The rank of Input(BBoxes) must be 3."); "The rank of Input(BBoxes) must be 3.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3, PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3."); "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[2], 4, PADDLE_ENFORCE(box_dims[2] == 4 || box_dims[2] == 8 || box_dims[2] == 16 ||
"The 2nd dimension of Input(BBoxes) must be 4, " box_dims[2] == 24 || box_dims[2] == 32,
"The 2nd dimension of Input(BBoxes) must be 4 or 8, "
"represents the layout of coordinate " "represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]"); "[xmin, ymin, xmax, ymax] or "
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
"8 points: [xi, yi] i= 1,2,...,8 or "
"12 points: [xi, yi] i= 1,2,...,12 or "
"16 points: [xi, yi] i= 1,2,...,16");
PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2], PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2],
"The 1st dimensiong of Input(BBoxes) must be equal to " "The 1st dimensiong of Input(BBoxes) must be equal to "
"3rd dimension of Input(Scores), which represents the " "3rd dimension of Input(Scores), which represents the "
...@@ -53,7 +56,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -53,7 +56,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
ctx->SetOutputDim("Out", {box_dims[1], 6}); ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
} }
protected: protected:
...@@ -128,6 +131,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2, ...@@ -128,6 +131,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
} }
} }
template <class T>
T PolyIoU(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
T bbox1_area = PolyArea<T>(box1, box_size, normalized);
T bbox2_area = PolyArea<T>(box2, box_size, normalized);
T inter_area = PolyOverlapArea<T>(box1, box2, box_size, normalized);
if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
return T(0.);
} else {
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T> template <typename T>
class MultiClassNMSKernel : public framework::OpKernel<T> { class MultiClassNMSKernel : public framework::OpKernel<T> {
public: public:
...@@ -137,6 +155,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -137,6 +155,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
// The total boxes for each instance. // The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0]; int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax] // 4: [xmin ymin xmax ymax]
// 8: [x1 y1 x2 y2 x3 y3 x4 y4]
// 16, 24, or 32: [x1 y1 x2 y2 ... xn yn], n = 8, 12 or 16
int64_t box_size = bbox.dims()[1]; int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes); std::vector<T> scores_data(num_boxes);
...@@ -154,8 +174,19 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -154,8 +174,19 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
for (size_t k = 0; k < selected_indices->size(); ++k) { for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) { if (keep) {
const int kept_idx = (*selected_indices)[k]; const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size, T overlap = T(0.);
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true); bbox_data + kept_idx * box_size, true);
}
// 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
if (box_size == 8 || box_size == 16 || box_size == 24 ||
box_size == 32) {
overlap =
PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, box_size, true);
}
keep = overlap <= adaptive_threshold; keep = overlap <= adaptive_threshold;
} else { } else {
break; break;
...@@ -228,7 +259,9 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -228,7 +259,9 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices, const std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) const { Tensor* outs) const {
int predict_dim = scores.dims()[1]; int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
int64_t out_dim = bboxes.dims()[1] + 2;
auto* scores_data = scores.data<T>(); auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>(); auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>(); auto* odata = outs->data<T>();
...@@ -240,11 +273,11 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -240,11 +273,11 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
const std::vector<int>& indices = it.second; const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) { for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j]; int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize; const T* bdata = bboxes_data + idx * box_size;
odata[count * kOutputDim] = label; // label odata[count * out_dim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score odata[count * out_dim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax // xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++; count++;
} }
} }
...@@ -261,6 +294,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -261,6 +294,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
int64_t class_num = score_dims[1]; int64_t class_num = score_dims[1];
int64_t predict_dim = score_dims[2]; int64_t predict_dim = score_dims[2];
int64_t box_dim = boxes->dims()[2]; int64_t box_dim = boxes->dims()[2];
int64_t out_dim = boxes->dims()[2] + 2;
std::vector<std::map<int, std::vector<int>>> all_indices; std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0}; std::vector<size_t> batch_starts = {0};
...@@ -283,7 +317,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -283,7 +317,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
T* od = outs->mutable_data<T>({1}, ctx.GetPlace()); T* od = outs->mutable_data<T>({1}, ctx.GetPlace());
od[0] = -1; od[0] = -1;
} else { } else {
outs->mutable_data<T>({num_kept, kOutputDim}, ctx.GetPlace()); outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
for (int64_t i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = scores->Slice(i, i + 1); Tensor ins_score = scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim}); ins_score.Resize({class_num, predict_dim});
...@@ -311,10 +345,11 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -311,10 +345,11 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("BBoxes", AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the " "(Tensor) A 3-D Tensor with shape "
"[N, M, 4 or 8 16 24 32] represents the "
"predicted locations of M bounding bboxes, N is the batch size. " "predicted locations of M bounding bboxes, N is the batch size. "
"Each bounding box has four coordinate values and the layout is " "Each bounding box has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax]."); "[xmin, ymin, xmax, ymax], when box size equals to 4.");
AddInput("Scores", AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the " "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
"predicted confidence predictions. N is the batch size, C is the " "predicted confidence predictions. N is the batch size, C is the "
...@@ -351,8 +386,12 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -351,8 +386,12 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", AddOutput("Out",
"(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
"detections. Each row has 6 values: " "detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], No is the total " "[label, confidence, xmin, ymin, xmax, ymax] or "
"number of detections in this mini-batch. For each instance, " "(LoDTensor) A 2-D LoDTensor with shape [No, 10] represents the "
"detections. Each row has 10 values: "
"[label, confidence, x1, y1, x2, y2, x3, y3, x4, y4]. No is the "
"total number of detections in this mini-batch."
"For each instance, "
"the offsets in first dimension are called LoD, the number of " "the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected bbox."); "no detected bbox.");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef POLY_UTIL_CC_
#define POLY_UTIL_CC_
#include "paddle/fluid/operators/detection/poly_util.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using gpc::gpc_polygon_clip;
using gpc::gpc_free_polygon;
template <class T>
void Array2PointVec(const T*& box, const size_t box_size,
std::vector<Point_<T>>& vec) {
size_t pts_num = box_size / 2;
vec.resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec.at(i).x = box[2 * i];
vec.at(i).y = box[2 * i + 1];
}
}
template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly) {
size_t pts_num = box_size / 2;
poly.num_contours = 1;
poly.hole = (int*)malloc(sizeof(int));
poly.hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num;
poly.contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = box[2 * i];
poly.contour->vertex[i].y = box[2 * i + 1];
}
}
template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly) {
int pts_num = vec.size();
poly.num_contours = 1;
poly.hole = (int*)malloc(sizeof(int));
poly.hole[0] = 0;
poly.contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly.contour->num_vertices = pts_num;
poly.contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly.contour->vertex[i].x = vec[i].x;
poly.contour->vertex[i].y = vec[i].y;
}
}
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec) {
int pts_num = contour.num_vertices;
vec.resize(pts_num);
for (int i = 0; i < pts_num; i++) {
vec.at(i).x = contour.vertex[i].x;
vec.at(i).y = contour.vertex[i].y;
}
}
template <class T>
T GetContourArea(std::vector<Point_<T>>& vec) {
size_t pts_num = vec.size();
if (pts_num < 3) return T(0.);
T area = T(0.);
for (size_t i = 0; i < pts_num; ++i) {
area += vec[i].x * vec[(i + 1) % pts_num].y -
vec[i].y * vec[(i + 1) % pts_num].x;
}
return std::fabs(area / 2.0);
}
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, vec);
return GetContourArea<T>(vec);
}
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, poly1);
Array2Poly<T>(box2, box_size, poly2);
gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
T inter_area = T(0.);
int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], resvec);
// inter_area += std::fabs(cv::contourArea(resvec)) + 0.5f *
// (cv::arcLength(resvec, true));
inter_area += GetContourArea<T>(resvec);
}
gpc::gpc_free_polygon(&poly1);
gpc::gpc_free_polygon(&poly2);
gpc::gpc_free_polygon(&respoly);
return inter_area;
}
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef POLY_UTIL_H_
#define POLY_UTIL_H_
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/gpc.h"
namespace paddle {
namespace operators {
template <class T>
class Point_ {
public:
// default constructor
Point_() {}
Point_(T _x, T _y) {}
Point_(const Point_& pt) {}
Point_& operator=(const Point_& pt);
// conversion to another data type
// template<typename _T> operator Point_<_T>() const;
// conversion to the old-style C structures
// operator Vec<T, 2>() const;
// checks whether the point is inside the specified rectangle
// bool inside(const Rect_<T>& r) const;
T x; //!< x coordinate of the point
T y; //!< y coordinate of the point
};
template <class T>
void Array2PointVec(const T*& box, const size_t box_size,
std::vector<Point_<T>>& vec);
template <class T>
void Array2Poly(const T*& box, const size_t box_size, gpc::gpc_polygon& poly);
template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon& poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>& vec);
template <class T>
T GetContourArea(std::vector<Point_<T>>& vec);
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized);
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/detection/poly_util.cc"
#endif // POLY_UTIL_H_
...@@ -41,9 +41,9 @@ class PolygonBoxTransformCPUKernel : public framework::OpKernel<T> { ...@@ -41,9 +41,9 @@ class PolygonBoxTransformCPUKernel : public framework::OpKernel<T> {
for (int id_w = 0; id_w < width; ++id_w) { for (int id_w = 0; id_w < width; ++id_w) {
id = id_n * height * width + width * id_h + id_w; id = id_n * height * width + width * id_h + id_w;
if (id_n % 2 == 0) { if (id_n % 2 == 0) {
out_data[id] = id_w - in_data[id]; out_data[id] = id_w * 4 - in_data[id];
} else { } else {
out_data[id] = id_h - in_data[id]; out_data[id] = id_h * 4 - in_data[id];
} }
} }
} }
......
...@@ -32,9 +32,9 @@ __global__ void PolygonBoxTransformKernel(const int n, const int h, const int w, ...@@ -32,9 +32,9 @@ __global__ void PolygonBoxTransformKernel(const int n, const int h, const int w,
if (id_n < n && id_h < h && id_w < w) { if (id_n < n && id_h < h && id_w < w) {
int id = id_n * h * w + w * id_h + id_w; int id = id_n * h * w + w * id_h + id_w;
if (id_n % 2 == 0) { if (id_n % 2 == 0) {
output[id] = id_w - input[id]; output[id] = id_w * 4 - input[id];
} else { } else {
output[id] = id_h - input[id]; output[id] = id_h * 4 - input[id];
} }
} }
} }
......
...@@ -86,7 +86,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -86,7 +86,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = nullptr; s->response_call_back_ = nullptr;
platform::RecordEvent record_event(method, p_ctx); platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
...@@ -143,7 +143,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -143,7 +143,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx); platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
...@@ -191,7 +191,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -191,7 +191,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx); platform::RecordRPCEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
...@@ -221,7 +221,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, ...@@ -221,7 +221,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr); platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
...@@ -246,7 +246,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, ...@@ -246,7 +246,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr); platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
...@@ -271,7 +271,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, ...@@ -271,7 +271,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
platform::RecordEvent record_event(method, nullptr); platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
...@@ -301,7 +301,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -301,7 +301,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir); req.set_out_varname(dir);
platform::RecordEvent record_event(method, nullptr); platform::RecordRPCEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......
...@@ -36,7 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -36,7 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, ::grpc::ByteBuffer* msg,
const std::string& out_name) { const std::string& out_name) {
platform::RecordEvent record_event("serial", &ctx); platform::RecordRPCEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU // Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed. // the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {}; DestroyCallback destroy_callback = [](void* backing) {};
...@@ -148,7 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -148,7 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var) { framework::Variable** var) {
platform::RecordEvent record_event("deserial", &ctx); platform::RecordRPCEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx); operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar(); *var = resp.GetVar();
......
此差异已折叠。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqConvEltAddReluOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
...@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D // check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1); PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0]; int64_t index_size = index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
const T* p_src = src.data<T>(); const T* p_src = src.data<T>();
const int* p_index = index.data<int>(); const int* p_index = index.data<int>();
...@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
int index_ = p_index[i]; int index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes); memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
} }
......
...@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat) ...@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
DEPS cpu_info cblas activation_functions) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel { ...@@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel {
virtual void Compute(const T a, const T *x, T *y) const = 0; virtual void Compute(const T a, const T *x, T *y) const = 0;
}; };
template <typename T>
class VAddReluKernel : public Kernel {
public:
virtual void Compute(const T *x, const T *y, T *z) const = 0;
};
template <typename T> template <typename T>
class VActKernel : public Kernel { class VActKernel : public Kernel {
public: public:
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -174,4 +174,4 @@ REGISTER_OP_CPU_KERNEL( ...@@ -174,4 +174,4 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
roi_pool_grad, roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>, ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>); ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>);
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册