提交 d2d082db 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_mem_analyse

...@@ -52,9 +52,8 @@ ExternalProject_Add( ...@@ -52,9 +52,8 @@ ExternalProject_Add(
extern_anakin extern_anakin
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLML_PROJECT} DEPENDS ${MKLML_PROJECT}
# Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code. GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin"
GIT_REPOSITORY "https://github.com/luotao1/Anakin" GIT_TAG "9424277cf9ae180a14aff09560d3cd60a49c76d2"
GIT_TAG "211d1fc5d813d70c0c14072f9083cf25f40940ea"
PREFIX ${ANAKIN_SOURCE_DIR} PREFIX ${ANAKIN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DUSE_GPU_PLACE=YES CMAKE_ARGS -DUSE_GPU_PLACE=YES
......
...@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path ...@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.InferenceTranspiler.__init__ paddle.fluid.InferenceTranspiler.__init__
paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
...@@ -299,6 +300,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', ' ...@@ -299,6 +300,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', '
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3)) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
...@@ -334,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array ...@@ -334,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.transpiler.InferenceTranspiler.__init__ paddle.fluid.transpiler.InferenceTranspiler.__init__
paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
......
...@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) ...@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
endif() endif()
if (NOT WIN32) if (NOT WIN32)
......
...@@ -3,14 +3,18 @@ cc_library(graph SRCS graph.cc DEPS node) ...@@ -3,14 +3,18 @@ cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter) cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass) cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace framework {
namespace ir {
struct Param {
std::string X = "concat_0.tmp_0";
std::string C0 = "cell_init";
std::string H0 = "hidden_init";
std::string AttentionWeight = "attention_fc.w_0";
std::string AttentionBias = "attention_fc.b_0";
std::string AttentionScalar = "attention_output.w_0";
std::string AttentionScalarBias = "attention_output.b_0";
std::string LSTMWeight = "attention_w.new";
std::string LSTMBias = "attention_b.new";
std::string Hidden = "array_to_lod_tensor_0.tmp_0";
std::string Cell = "at.cell.new";
std::string AttentionedX = "at.x.new";
std::string AttentionFCOut = "at.fc.new";
std::string LSTMX = "at.lstmx.new";
std::string LSTMOUT = "at.lstmout.new";
};
void PrepareParameters(Graph* graph, const Param& param);
void FindWhileOp(Graph* graph) {
GraphPatternDetector gpd;
std::unordered_set<int> fused_external_ops(
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});
gpd.mutable_pattern()->NewNode(
[&](Node* n) { return fused_external_ops.count(n->id()); }, "while");
if (!graph->Has(kGraphvizMarkedNodeAttr)) {
graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
}
auto& marked_nodes =
graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while");
auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node);
};
gpd(graph, handle);
Param param;
// Add AttentionLSTM node
OpDesc op_desc;
op_desc.SetType("attention_lstm");
#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
OP_SET_IN(X);
OP_SET_IN(C0);
OP_SET_IN(H0);
OP_SET_IN(AttentionWeight);
OP_SET_IN(AttentionBias);
OP_SET_IN(AttentionScalar);
OP_SET_IN(AttentionScalarBias);
OP_SET_IN(LSTMWeight);
OP_SET_IN(LSTMBias);
OP_SET_OUT(Hidden);
OP_SET_OUT(Cell);
OP_SET_OUT(AttentionedX);
OP_SET_OUT(AttentionFCOut);
OP_SET_OUT(LSTMX);
OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT
auto* X = graph->RetriveNode(34);
auto* LSTMOUT = graph->RetriveNode(81);
auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8);
#define LINK_TO(node0, node1) \
node0->outputs.push_back(node1); \
node1->inputs.push_back(node0);
auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);
LINK_TO(X, lstm_op);
LINK_TO(cell_init, lstm_op);
LINK_TO(hidden_init, lstm_op);
LINK_TO(lstm_op, LSTMOUT);
GraphSafeRemoveNodes(graph, marked_nodes);
}
#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
CHECK_P1(x0); \
CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
CHECK_P2(x0, x1); \
CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
CHECK_P3(x0, x1, x2); \
CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
CHECK_P4(x0, x1, x2, x3); \
CHECK_P1(x4);
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out);
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out);
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
// Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
GATE_W(forget);
GATE_W(input);
GATE_W(output);
GATE_W(c);
#undef GATE_W
auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
lstm_weight_t);
PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
lstm_bias_t);
}
// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out) {
int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors(
{W_forget_w0.data<float>(), W_input_w0.data<float>(),
W_output_w0.data<float>(), W_cell_w0.data<float>()});
std::array<const float*, 4> tensors1(
{W_forget_w1.data<float>(), W_input_w1.data<float>(),
W_output_w1.data<float>(), W_cell_w1.data<float>()});
for (int row = 0; row < D; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * row + D * col;
const float* src = tensors[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
for (int row = 0; row < M; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * (D + row) + D * col;
const float* src = tensors1[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
}
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out) {
std::array<const float*, 4> tensors(
{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()});
PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
int D = B_forget.dims()[0];
out->Resize(make_ddim({1, 4 * D}));
auto* out_data = out->mutable_data<float>(platform::CPUPlace());
for (size_t i = 0; i < tensors.size(); i++) {
memcpy(out_data + D * i, tensors[i], D * sizeof(float));
}
}
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PDPattern external_pattern, subblock_pattern;
FindWhileOp(graph.get());
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(attention_lstm_fuse_pass,
paddle::framework::ir::AttentionLSTMFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,12 +12,19 @@ ...@@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h" #pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle { namespace paddle {
namespace inference { namespace framework {
namespace analysis { namespace ir {
size_t Dot::counter = 0;
} // namespace analysis class AttentionLSTMFusePass : public FusePassBase {
} // namespace inference protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) { ...@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
}, },
"elementwise_add_out"); "elementwise_add_out");
pattern->AddEdge(mul_parameter_var, mul_op); mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
pattern->AddEdge(mul_tmp_input_var, mul_op); .LinksTo({mul_out_var});
pattern->AddEdge(mul_op, mul_out_var); elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
pattern->AddEdge(mul_out_var, elementwise_add_op); .LinksTo({elementwise_add_out_var});
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
} }
// Replace the node `from` in the links to `to` // Replace the node `from` in the links to `to`
...@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
GraphPatternDetecter gpd; GraphPatternDetector gpd;
BuildFCPattern(gpd.mutable_pattern()); BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \ #define GET_NODE(id) \
...@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle FC fuse"; VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the // Currently, there is no FC op available, so I will just simulate the
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::unordered_set<int> fused_ops({// first lstm
13, 15, 16,
// second lstm
23, 25, 26});
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
"any_node");
std::unordered_set<Node*> marked_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
// Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
int bias, int hidden, int cell, int xx) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc;
op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
SET_IN(X, input);
SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h);
SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false);
auto* op = graph->CreateOpNode(&op_desc);
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
LINK_TO(input_n, op);
LINK_TO(weight_x_n, op);
LINK_TO(weight_h_n, op);
LINK_TO(bias_n, op);
LINK_TO(op, hidden_n);
#undef LINK_TO
return op;
};
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
// remove all the nodes
for (auto* node : marked_nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FCLstmFusePass : public Pass {
public:
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "param_scope";
class FusePassBase : public Pass {
public:
void Init(Graph* graph) const { graph_ = graph; }
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
virtual ~FusePassBase() {}
protected:
mutable Graph* graph_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -99,13 +99,13 @@ class Graph { ...@@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc); PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc)); return AddNode(new ir::Node(var_desc, node_count_++));
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc); PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc)); return AddNode(new ir::Node(op_desc, node_count_++));
} }
// Create a control dependency var that connects 2 operations. The // Create a control dependency var that connects 2 operations. The
...@@ -115,13 +115,14 @@ class Graph { ...@@ -115,13 +115,14 @@ class Graph {
// TODO(panyx0718): control var name should be really unique. // TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); return AddNode(
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
} }
// A more free style way of creating a graph node. Mostly use for test // A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible. // or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type)); return AddNode(new ir::Node(name, type, node_count_++));
} }
// Clear all node information of the graph and return the ownership of the // Clear all node information of the graph and return the ownership of the
...@@ -142,12 +143,20 @@ class Graph { ...@@ -142,12 +143,20 @@ class Graph {
nodes_.erase(node); nodes_.erase(node);
} }
Node *RetriveNode(int id) {
auto it = id2node_.find(id);
if (it != id2node_.end()) return it->second;
return nullptr;
}
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node); nodes_[node].reset(node);
node_set_.insert(node); node_set_.insert(node);
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
id2node_[node->id()] = node;
return node; return node;
} }
...@@ -157,6 +166,8 @@ class Graph { ...@@ -157,6 +166,8 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
std::map<int, Node *> id2node_;
int node_count_{0};
}; };
bool IsControlDepVar(const ir::Node &var); bool IsControlDepVar(const ir::Node &var);
......
...@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( ...@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
for (auto &var : n->inputs) { for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) { for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n].insert(adj_n);
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
} }
} }
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { ...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
name); name);
} }
nodes_.emplace_back(new PDNode(std::move(teller), name)); nodes_.emplace_back(new PDNode(std::move(teller), this, name));
auto* cur = nodes_.back().get(); auto* cur = nodes_.back().get();
node_map_[name] = cur; node_map_[name] = cur;
return cur; return cur;
...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { ...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
edges_.emplace_back(a, b); edges_.emplace_back(a, b);
} }
void GraphPatternDetecter::operator()(Graph* graph, void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetecter::handle_t handler) { GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) return; if (!MarkPDNodesInGraph(*graph)) return;
auto subgraphs = DetectPatterns(); auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs); UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs); RemoveOverlappedMatch(&subgraphs);
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) { for (auto& g : subgraphs) {
LOG(INFO) << "optimizing #" << id++ << " subgraph";
handler(g, graph); handler(g, graph);
} }
} }
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph"; VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) { ...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
return false; return false;
} }
std::vector<GraphPatternDetecter::subgraph_t> std::vector<GraphPatternDetector::subgraph_t>
GraphPatternDetecter::DetectPatterns() { GraphPatternDetector::DetectPatterns() {
// Init empty subgraphs. // Init empty subgraphs.
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups; std::vector<HitGroup> init_groups;
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); std::array<std::vector<HitGroup>, 2> bi_records;
auto* first_pnode = pattern_.edges().front().first; // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result; if (!pdnodes2nodes_.count(first_pnode)) return result;
for (auto* node : pdnodes2nodes_[first_pnode]) { for (auto* node : pdnodes2nodes_[first_pnode]) {
HitGroup group; HitGroup group;
...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
} }
int step = 0; int step = 0;
std::array<std::vector<HitGroup>, 2> bi_records;
bi_records[0] = std::move(init_groups); bi_records[0] = std::move(init_groups);
// Extend a PDNode to subgraphs by deducing the connection relations defined // Extend a PDNode to subgraphs by deducing the connection relations defined
...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
auto& cur_groups = bi_records[1 - (step++ % 2)]; auto& cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear(); cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target // source -> target
for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) { for (Node* target : pdnodes2nodes_[edge.second]) {
...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
GraphPatternDetecter::subgraph_t subgraph; GraphPatternDetector::subgraph_t subgraph;
for (auto& role : group.roles) { for (auto& role : group.roles) {
subgraph.emplace(role.first, role.second); subgraph.emplace(role.first, role.second);
} }
...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
return result; return result;
} }
void GraphPatternDetecter::UniquePatterns( void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) { std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
if (subgraphs->empty()) return; if (subgraphs->empty()) return;
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set; std::unordered_set<size_t> set;
for (auto& g : *subgraphs) { for (auto& g : *subgraphs) {
...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns( ...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
*subgraphs = result; *subgraphs = result;
} }
void GraphPatternDetecter::RemoveOverlappedMatch( void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t>* subgraphs) { std::vector<subgraph_t>* subgraphs) {
std::vector<subgraph_t> result; std::vector<subgraph_t> result;
std::unordered_set<Node*> node_set; std::unordered_set<Node*> node_set;
...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch( ...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
*subgraphs = result; *subgraphs = result;
} }
std::string PDPattern::DotString() const {
using inference::analysis::Dot;
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PDNode*, std::string> node2dot;
for (const auto& node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto& edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto& src = node2dot.at(edge.first);
auto& trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -21,12 +21,14 @@ ...@@ -21,12 +21,14 @@
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class PDPattern;
// Some basic torminolygies: // Some basic terminologies:
// - PDPattern: a pattern defined as a data flow graph. // - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` // - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`. // that meets some conditions defined in `PDNode.teller`.
...@@ -36,30 +38,43 @@ namespace ir { ...@@ -36,30 +38,43 @@ namespace ir {
struct PDNode { struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode. // tell whether an ir::Node* is a candidation for a PDNode.
using teller_t = std::function<bool(Node*)>; using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar };
PDNode(teller_t&& teller, const std::string& name = "") // this link to others
: teller_(teller), name_(name) { PDNode& LinksTo(const std::vector<PDNode*>& others);
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); PDNode& LinksFrom(const std::vector<PDNode*>& others);
}
PDNode(PDNode&& other) = default;
std::vector<PDNode*> inlinks;
std::vector<PDNode*> outlinks;
bool Tell(Node* node) const { bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
return teller_(node); return teller_(node);
} }
bool IsOp() const { return type_ == Type::kOp; }
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete; PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete;
private: private:
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
name_(name),
type_(type) {
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
}
PDNode(PDNode&& other) = default;
friend class PDPattern;
teller_t teller_; teller_t teller_;
PDPattern* pattern_;
std::string name_; std::string name_;
Type type_;
}; };
/* /*
...@@ -102,6 +117,8 @@ class PDPattern { ...@@ -102,6 +117,8 @@ class PDPattern {
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
std::string DotString() const;
private: private:
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PDPattern, AddEdge); FRIEND_TEST(PDPattern, AddEdge);
...@@ -117,7 +134,7 @@ class PDPattern { ...@@ -117,7 +134,7 @@ class PDPattern {
}; };
/* /*
* GraphPatternDetecter helps to detect the specific patterns in the graph. * GraphPatternDetector helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes. * Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
* *
...@@ -129,7 +146,7 @@ class PDPattern { ...@@ -129,7 +146,7 @@ class PDPattern {
* *
* Usage: * Usage:
* // Create a detector * // Create a detector
* GraphPatternDetecter detector; * GraphPatternDetector detector;
* // Define the detector's pattern, by adding PDNode and define the edges. * // Define the detector's pattern, by adding PDNode and define the edges.
* auto* node0 = detector.mutable_pattern().AddNode(...) * auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...) * auto* node1 = detector.mutable_pattern().AddNode(...)
...@@ -138,11 +155,11 @@ class PDPattern { ...@@ -138,11 +155,11 @@ class PDPattern {
* detector.mutable_pattern().AddEdge(node0, node1); * detector.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered * // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns. * // subgraphs that comply with the patterns.
* GraphPatternDetecter::handle_t handler = some labmda * GraphPatternDetector::handle_t handler = some labmda
* // Execute the detector. * // Execute the detector.
* detector(&graph, handler); * detector(&graph, handler);
*/ */
class GraphPatternDetecter { class GraphPatternDetector {
public: public:
using subgraph_t = std::unordered_map<PDNode*, Node*>; using subgraph_t = std::unordered_map<PDNode*, Node*>;
...@@ -177,10 +194,62 @@ class GraphPatternDetecter { ...@@ -177,10 +194,62 @@ class GraphPatternDetecter {
using hit_rcd_t = using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>; std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_; PDPattern pattern_;
std::vector<hit_rcd_t> marked_records_;
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_; std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
}; };
// some helper methods.
// Op's input.
static bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Op's output.
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Check whether a var node is a op node's nth input.
static bool IsNthInput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) { ...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
} }
TEST(GraphPatternDetecter, MarkPDNodesInGraph) { TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
GraphPatternDetecter x; GraphPatternDetector x;
// mark o2, o3, v2 // mark o2, o3, v2
// The pattern is a graph: // The pattern is a graph:
...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
Graph graph(program); Graph graph(program);
BuildGraph(&graph); BuildGraph(&graph);
GraphPatternDetecter x; GraphPatternDetector x;
// The pattern is a graph: // The pattern is a graph:
// op -> var // op -> var
...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
x.mutable_pattern()->AddEdge(any_var, any_op1); x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0; int count = 0;
GraphPatternDetecter::handle_t handle = [&]( GraphPatternDetector::handle_t handle = [&](
const GraphPatternDetecter::subgraph_t& s, Graph* g) { const GraphPatternDetector::subgraph_t& s, Graph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
<< s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
count++; count++;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
ProgramDesc& program = Get<ProgramDesc>("program");
std::unique_ptr<proto::ProgramDesc> program_pb(
new proto::ProgramDesc(*program.Proto()));
auto block = program_pb->mutable_blocks(kRootBlockIndex);
block->clear_vars();
std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) {
if (n->NodeType() == ir::Node::Type::kVariable) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto());
}
}
}
block->clear_ops();
std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
for (ir::Node* n : nodes) {
if (!n->Op()) {
continue;
}
block->add_ops()->MergeFrom(*n->Op()->Proto());
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
void BuildNoCircleGraph(Graph* g) {
OpDesc op1;
op1.SetType("op1");
OpDesc op2;
op2.SetType("op2");
OpDesc op3;
op3.SetType("op3");
OpDesc op4;
op4.SetType("op4");
OpDesc op5;
op5.SetType("op5");
VarDesc var1("var1");
VarDesc var2("var2");
VarDesc var3("var3");
VarDesc var4("var4");
ir::Node* o1 = g->CreateOpNode(&op1);
ir::Node* o2 = g->CreateOpNode(&op2);
ir::Node* o3 = g->CreateOpNode(&op3);
ir::Node* o4 = g->CreateOpNode(&op4);
ir::Node* o5 = g->CreateOpNode(&op5);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
v2->inputs.push_back(o2);
// o2->v3->o5
o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3-v4->o5
o3->outputs.push_back(v4);
o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
v4->outputs.push_back(o5);
}
TEST(GraphToProgramPass, Basic) {
ProgramDesc prog;
std::unique_ptr<Graph> g(new Graph(prog));
BuildNoCircleGraph(g.get());
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
ProgramDesc compiled_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
pass->Apply(std::move(g));
std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
EXPECT_EQ(ops[0]->Type(), "op1");
EXPECT_EQ(ops[1]->Type(), "op2");
if (ops[2]->Type() == "op3") {
EXPECT_EQ(ops[3]->Type(), "op4");
} else if (ops[2]->Type() == "op4") {
EXPECT_EQ(ops[3]->Type(), "op3");
}
EXPECT_EQ(ops[4]->Type(), "op5");
std::unordered_set<std::string> vars;
for (VarDesc* v : compiled_prog.Block(0).AllVars()) {
vars.insert(v->Name());
}
EXPECT_TRUE(vars.find("var1") != vars.end());
EXPECT_TRUE(vars.find("var2") != vars.end());
EXPECT_TRUE(vars.find("var3") != vars.end());
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(graph_to_program_pass);
...@@ -16,11 +16,13 @@ limitations under the License. */ ...@@ -16,11 +16,13 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path"; static const char kGraphVizPath[] = "graph_viz_path";
using inference::analysis::Dot;
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( ...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
size_t var_id = 0; std::unordered_map<const ir::Node*, std::string> node2dot;
std::unordered_map<const ir::Node*, size_t> vars;
Dot dot;
sout << "digraph G {\n";
std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
for (const ir::Node* n : graph->Nodes()) { Dot::Attr("shape", "box"),
if (n->NodeType() != ir::Node::Type::kVariable) continue; Dot::Attr("fillcolor", "red")});
size_t cur_var_id = var_id++; std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
vars[n] = cur_var_id; // Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "yellow")});
sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
<< std::endl; std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
} Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "lightgray")});
size_t op_id = 0; std::vector<Dot::Attr> marked_var_attrs(
for (const ir::Node* n : graph->Nodes()) { {Dot::Attr("style", "filled,rounded"),
if (n->NodeType() != ir::Node::Type::kOperation) continue; // Dot::Attr("shape", "diamond"),
std::string op_name = "op_" + std::to_string(op_id++); Dot::Attr("fillcolor", "lightgray")});
sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
<< std::endl; auto marked_nodes = ConsumeMarkedNodes(graph.get());
for (auto in : n->inputs) { // Create nodes
std::string var_name = "var_" + std::to_string(vars[in]); for (const Node* n : graph->Nodes()) {
sout << var_name << " -> " << op_name << std::endl; std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
if (n->IsOp()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_var_attrs : var_attrs;
dot.AddNode(node_id, attr, node_id);
} }
node2dot[n] = node_id;
for (auto out : n->outputs) { }
std::string var_name = "var_" + std::to_string(vars[out]); // Create edges
sout << op_name << " -> " << var_name << std::endl; for (const Node* n : graph->Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& trg_id = node2dot.at(out);
dot.AddEdge(src_id, trg_id, {});
} }
} }
sout << "}\n"; sout << dot.Build();
return graph; return graph;
} }
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
Graph* graph) const {
marked_nodes_t res;
if (graph->Has(kGraphvizMarkedNodeAttr)) {
auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
res = attr;
attr.clear();
}
return res;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -27,10 +27,19 @@ namespace paddle { ...@@ -27,10 +27,19 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public:
using marked_nodes_t = std::unordered_set<const Node*>;
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -29,20 +29,26 @@ class Node { ...@@ -29,20 +29,26 @@ class Node {
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type, int id = -1)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(id) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc, int id = -1)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable),
id_(id) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc, int id = -1)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation) {} type_(Type::kOperation),
id_(id) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -58,6 +64,8 @@ class Node { ...@@ -58,6 +64,8 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
int id() const { return id_; }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; } bool IsVar() const { return type_ == Type::kVariable; }
...@@ -69,6 +77,7 @@ class Node { ...@@ -69,6 +77,7 @@ class Node {
std::unique_ptr<VarDesc> var_desc_; std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
int id_;
private: private:
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
struct FuseExpr {};
// sequence expand, concat fuse pattern, return concat's output
PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
// The following operators will be fused:
// concat
// sequence_expand
// sequence_expand
// The following variables will be treat as inputs:
// concat mid input, 0th input for fused op
// sequence_expand input, 1th input for fused op
// sequence_expand input, 2th input for fused op
// The following variables will be treat as outputs:
// concat output
// So the following variables will be removed:
// sequence-expand output
// sequence-expand output
// Three operators
auto* sequence_expand0 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand0");
auto* sequence_expand1 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand1");
auto* concat = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "concat" && // basic check
x->Op()->Input("X").size() == 3; // Special case
},
"concat");
auto* sequence_expand0_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand0_in");
auto* sequence_expand1_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand1_in");
// The variables
auto* sequence_expand0_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 1); // X[0]
},
"sequence_expand0_out");
auto* sequence_expand1_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 2); // x[2]
},
"sequence_expand1_out");
auto* concat_in0 = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
"concat_in0");
auto* concat_out = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
"concat_out");
// Links
sequence_expand0->LinksFrom({sequence_expand0_in})
.LinksTo({sequence_expand0_out});
sequence_expand1->LinksFrom({sequence_expand1_in})
.LinksTo({sequence_expand1_out});
concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
.LinksTo({concat_out});
return concat_out;
}
PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
PDNode* fc_w = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "mul") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_w");
PDNode* mul_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "mul") && // link
VarLinksToOp(x, "elementwise_add") && //
!x->Var()->Proto()->persistable(); // is a parameter
},
"mul_out");
PDNode* fc_mul = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "mul"; // basic
},
"fc_mul");
PDNode* fc_bias = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "elementwise_add") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_bias");
PDNode* elementwise_add = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
},
"elementwise_add");
PDNode* add_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "elementwise_add") && // link
!x->Var()->Proto()->persistable(); // is a parameter
},
"add_out");
std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
PDNode* act = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && acts.count(x->Op()->Type());
},
"act");
PDNode* fc_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
!x->Var()->Proto()->persistable(); // is a parameter
},
"fc_out");
fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
act->LinksFrom({add_out}).LinksTo({fc_out});
return fc_out;
}
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(graph.get());
GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "get one concat pattern";
// fc
GET_NODE(fc_w, detector.pattern());
GET_NODE(fc_bias, detector.pattern());
GET_NODE(act, detector.pattern());
GET_NODE(fc_out, detector.pattern());
// concat
GET_NODE(concat_in0, detector.pattern());
GET_NODE(sequence_expand0_in, detector.pattern());
GET_NODE(sequence_expand1_in, detector.pattern());
OpDesc op_desc;
op_desc.SetType("fusion_seqexpand_concat_fc");
op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
sequence_expand1_in->Name()});
op_desc.SetInput("FCWeight", {fc_w->Name()});
op_desc.SetInput("FCBias", {fc_bias->Name()});
const std::string fc_out_tmp = fc_out->Name() + ".tmp";
param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
op_desc.SetOutput("FCOut", {fc_out_tmp});
op_desc.SetOutput("Out", {fc_out->Name()});
op_desc.SetAttr("fc_activation", act->Op()->Type());
auto* op_node = graph->CreateOpNode(&op_desc);
// Add links
#define NODE_LINKS(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
NODE_LINKS(fc_w, op_node);
NODE_LINKS(fc_bias, op_node);
NODE_LINKS(concat_in0, op_node);
NODE_LINKS(sequence_expand0_in, op_node);
NODE_LINKS(sequence_expand1_in, op_node);
NODE_LINKS(op_node, fc_out);
// Clean nodes.
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
marked_nodes.erase(fc_w);
marked_nodes.erase(fc_bias);
marked_nodes.erase(concat_in0);
marked_nodes.erase(sequence_expand0_in);
marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes);
});
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seq_concat_fc_fuse_pass,
paddle::framework::ir::SeqConcatFcFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConcatFcFusePass : public FusePassBase {
public:
virtual ~SeqConcatFcFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -95,6 +95,12 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs, ...@@ -95,6 +95,12 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
need_update_ = true; need_update_ = true;
} }
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
CopyFrom(other);
block_ = block;
need_update_ = true;
}
void OpDesc::CopyFrom(const OpDesc &op_desc) { void OpDesc::CopyFrom(const OpDesc &op_desc) {
desc_.set_type(op_desc.Type()); desc_.set_type(op_desc.Type());
inputs_ = op_desc.inputs_; inputs_ = op_desc.inputs_;
...@@ -131,8 +137,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) ...@@ -131,8 +137,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
for (const proto::OpDesc::Attr &attr : desc_.attrs()) { for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name(); std::string attr_name = attr.name();
// The sub_block referred to by the BLOCK attr hasn't been added // The sub_block referred to by the BLOCK attr hasn't been added
// to ProgramDesc class yet, we skip setting BLOCK attr here. // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS attr here.
if (attr.type() != proto::AttrType::BLOCK) { if (attr.type() != proto::AttrType::BLOCK &&
attr.type() != proto::AttrType::BLOCKS) {
attrs_[attr_name] = GetAttrValue(attr); attrs_[attr_name] = GetAttrValue(attr);
} }
} }
......
...@@ -37,11 +37,7 @@ class OpDesc { ...@@ -37,11 +37,7 @@ class OpDesc {
explicit OpDesc(BlockDesc *block) : block_(block) {} explicit OpDesc(BlockDesc *block) : block_(block) {}
OpDesc(const OpDesc &other, BlockDesc *block) { OpDesc(const OpDesc &other, BlockDesc *block);
*this = other;
block_ = block;
need_update_ = true;
}
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
......
...@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ...@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
InitFromProto(); InitFromProto();
} }
void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
blocks_.clear();
desc_ = desc;
InitFromProto();
}
ProgramDesc::ProgramDesc(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
...@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() { ...@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() {
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
auto &global_block = Block(0); auto &global_block = Block(0);
// The order of feed_target_names must follow the index specified in `col`.
// since feed operator's order doesn't necessary follow 'col'.
std::vector<std::string> feed_target_names; std::vector<std::string> feed_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]); int col = boost::get<int>(op->GetAttr("col"));
if (col >= feed_target_names.size()) {
feed_target_names.resize(col + 1);
}
feed_target_names[col] = op->Output("Out")[0];
} }
} }
return feed_target_names; return feed_target_names;
...@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { ...@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() { const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
auto &global_block = Block(0); auto &global_block = Block(0);
// The order of fetch_target_names must follow the index specified in `col`.
// since fetch operator's order doesn't necessary follow 'col'.
std::vector<std::string> fetch_target_names; std::vector<std::string> fetch_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
fetch_target_names.push_back(op->Input("X")[0]); int col = boost::get<int>(op->GetAttr("col"));
if (col >= fetch_target_names.size()) {
fetch_target_names.resize(col + 1);
}
fetch_target_names[col] = op->Input("X")[0];
} }
} }
return fetch_target_names; return fetch_target_names;
......
...@@ -53,6 +53,8 @@ class ProgramDesc { ...@@ -53,6 +53,8 @@ class ProgramDesc {
void Flush(); void Flush();
void CopyFrom(const proto::ProgramDesc &desc);
proto::ProgramDesc *Proto(); proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target. // The output variable of feed_op is referenced as feed_target.
......
...@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor) ...@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal? # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
cc_library(paddle_fluid_api cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
......
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass) cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc set(analysis_deps
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc analyzer.cc
helper.cc helper.cc
# passes # passes
...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph ...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc fluid_to_ir_pass.cc
model_store_pass.cc model_store_pass.cc
DEPS framework_proto proto_desc ir_pass_manager graph pass) DEPS ${analysis_deps})
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET) ...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
endif() endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS} DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
...@@ -58,20 +61,25 @@ endif() ...@@ -58,20 +61,25 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
analysis_predictor
# ir # ir
fc_fuse_pass fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass graph_viz_pass
infer_clean_graph_pass infer_clean_graph_pass
graph_pattern_detecter graph_pattern_detector
infer_clean_graph_pass infer_clean_graph_pass
attention_lstm_fuse_pass
paddle_inference_api
pass pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc) inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
......
...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
// Ungly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
// Manual update the passes here.
"graph_viz_pass", //
"infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" //
}));
for (auto& x : data_) { for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument)); PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll(); x->RunAll();
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
const std::string &data_path, int batch_size, const std::string &data_path, int batch_size,
bool use_analysis, bool activate_ir, bool use_analysis, bool activate_ir,
int num_times = 1) { int num_times = 1) {
FLAGS_IA_enable_ir = activate_ir;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
std::string model_out;
if (use_analysis) {
Argument argument(model_path);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
model_out = "./analysis.out";
ASSERT_TRUE(PathExists(model_out));
} else {
model_out = FLAGS_infer_ditu_rnn_model;
}
NativeConfig config; NativeConfig config;
config.prog_file = model_out + "/__model__"; config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = model_out + "/param"; config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false; config.use_gpu = false;
config.device = 0; config.device = 0;
config.specify_input_name = true; config.specify_input_name = true;
auto predictor = auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config); CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
std::vector<PaddleTensor> input_slots; std::vector<PaddleTensor> input_slots;
DataRecord data(data_path, batch_size); DataRecord data(data_path, batch_size);
// Prepare inputs. // Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size); PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs, base_outputs;
base_predictor->Run(input_slots, &base_outputs);
Timer timer; Timer timer;
timer.tic(); timer.tic();
...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
<< ", latency: " << timer.toc() / num_times << "ms"; << ", latency: " << timer.toc() / num_times << "ms";
LOG(INFO) << "====================================="; LOG(INFO) << "=====================================";
for (auto &out : outputs) { PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &base_out = base_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; }); [](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
1, [](int a, int b) { return a * b; });
PADDLE_ENFORCE_EQ(size, size1);
PADDLE_ENFORCE_GT(size, 0);
float *data = static_cast<float *>(out.data.data()); float *data = static_cast<float *>(out.data.data());
for (size_t i = 0; float *base_data = static_cast<float *>(base_out.data.data());
i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); for (size_t i = 0; i < size; i++) {
i++) { EXPECT_NEAR(data[i], base_data[i], 1e-3);
EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
} }
} }
} }
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
// Directly infer with the original model. // Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) { TEST(Analyzer, DituRNN_without_analysis) {
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) { ...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass); USE_PASS(fc_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass); USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -58,6 +59,46 @@ struct Argument { ...@@ -58,6 +59,46 @@ struct Argument {
// The output storage path of ModelStorePass. // The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path; std::unique_ptr<std::string> model_output_store_path;
// Support for any other attributes.
template <typename T>
void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
VLOG(3) << "argument delete attr: " << key;
delete data;
};
}
bool Has(const std::string& name) const { return attrs_.count(name); }
template <typename T>
T* Release(const std::string& key) {
PADDLE_ENFORCE(attrs_.count(key));
auto* res = boost::any_cast<T*>(attrs_.at(key));
attrs_.erase(key);
attr_deleters_.erase(key);
return res;
}
template <typename T>
T& Get(const std::string& key) {
PADDLE_ENFORCE(Has(key));
return *boost::any_cast<T*>(attrs_.at(key));
}
~Argument() {
for (auto& item : attr_deleters_) {
item.second();
}
}
private:
std::unordered_map<std::string, boost::any> attrs_;
std::unordered_map<std::string, std::function<void()>> attr_deleters_;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
} }
if (argument_->Has("param_scope")) {
LOG(WARNING) << "parameter changes in the scope takes effect";
}
PADDLE_ENFORCE(argument_->transformed_program_desc.get()); PADDLE_ENFORCE(argument_->transformed_program_desc.get());
} }
......
...@@ -29,13 +29,13 @@ namespace paddle { ...@@ -29,13 +29,13 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
static size_t dot_node_counter{0};
/* /*
* A Dot template that helps to build a DOT graph definition. * A Dot template that helps to build a DOT graph definition.
*/ */
class Dot { class Dot {
public: public:
static size_t counter;
struct Attr { struct Attr {
std::string key; std::string key;
std::string value; std::string value;
...@@ -57,7 +57,7 @@ class Dot { ...@@ -57,7 +57,7 @@ class Dot {
Node(const std::string& name, const std::vector<Attr>& attrs) Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name), : name(name),
attrs(attrs), attrs(attrs),
id_("node_" + std::to_string(Dot::counter++)) {} id_("node_" + std::to_string(dot_node_counter++)) {}
std::string id() const { return id_; } std::string id() const { return id_; }
...@@ -65,6 +65,10 @@ class Dot { ...@@ -65,6 +65,10 @@ class Dot {
std::stringstream ss; std::stringstream ss;
CHECK(!name.empty()); CHECK(!name.empty());
ss << id_; ss << id_;
if (attrs.empty()) {
ss << "[label=" << '"' << name << '"' << "]";
return ss.str();
}
for (size_t i = 0; i < attrs.size(); i++) { for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) { if (i == 0) {
ss << "[label=" << '"' << name << '"' << " "; ss << "[label=" << '"' << name << '"' << " ";
...@@ -108,9 +112,11 @@ class Dot { ...@@ -108,9 +112,11 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {} explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& name, const std::vector<Attr>& attrs) { void AddNode(const std::string& id, const std::vector<Attr>& attrs,
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'"; std::string label = "") {
nodes_.emplace(name, Node{name, attrs}); CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id;
nodes_.emplace(id, Node{label, attrs});
} }
void AddEdge(const std::string& source, const std::string& target, void AddEdge(const std::string& source, const std::string& target,
......
...@@ -13,3 +13,47 @@ ...@@ -13,3 +13,47 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace analysis {
void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file) {
PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope);
// Load parameters.
VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file,
const std::string &param_file) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor executor(place);
PADDLE_ENFORCE(argument_->origin_program_desc.get());
framework::ProgramDesc program(*argument_->origin_program_desc);
if ((!prog_file.empty()) && (!param_file.empty())) {
LOG(INFO) << "load single model file from " << prog_file;
Load(&executor, scope, prog_file, param_file);
} else if (!dir.empty()) {
LOG(INFO) << "load from dir " << dir;
Load(&executor, scope, dir);
} else {
LOG(ERROR) << "failed to load parameters";
return false;
}
return true;
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -21,12 +21,17 @@ namespace paddle { ...@@ -21,12 +21,17 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
class FluidToIrPass final : public DataFlowGraphPass { class FluidToIrPass final : public DataFlowGraphPass {
public: public:
FluidToIrPass() = default; FluidToIrPass() = default;
bool Initialize(Argument *argument) override { bool Initialize(Argument *argument) override {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument); ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
"argument need the attr %s", kFluidToIrPassesAttr);
argument_ = argument;
if (argument->origin_program_desc) { if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might " LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called"; "duplicate called";
...@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
if (!argument->main_dfg) { if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph); argument->main_dfg.reset(new DataFlowGraph);
} }
// Persist the ProgramDesc in graph's attribute. The IR graph just keep the argument->Set("ir_program_desc", new framework::ProgramDesc(program));
// address, will segfault if the original ProgramDesc destroys.
auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer(); LOG(INFO) << "Loading parameters";
ir_program_p = new framework::ProgramDesc(program); // Load parameters to argument if needed.
if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
argument->fluid_model_param_path)) {
#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
SAFE_GET(fluid_model_dir);
SAFE_GET(fluid_model_program_path);
SAFE_GET(fluid_model_param_path);
#undef SAFE_GET
EnableParamModify(fluid_model_dir, fluid_model_program_path,
fluid_model_param_path);
}
argument_ = argument;
return true; return true;
} }
...@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
void Run(DataFlowGraph *graph) override { void Run(DataFlowGraph *graph) override {
// Call all the IR Passes // Call all the IR Passes
IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>( IRPassManager ir_passes(
argument_->main_dfg->Attr("ir_program_desc").Pointer())); argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
ir_passes.Apply(std::vector<std::string>( // Pass the scope from analysis to IR if needed.
{// Manual update the passes here. if (argument_->Has("param_scope")) {
"graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass", // Here the address is passed, attention that IR doesn't own the scope, so
"fc_fuse_pass", "graph_viz_pass"})); // the real scope in analysis should live during the IR phase.
ir_passes.graph().Set(
"param_scope", new framework::Scope *(
&argument_->Get<framework::Scope>("param_scope")));
}
const auto &ir_passes_to_apply =
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
ir_passes.Apply(ir_passes_to_apply);
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph()); argument_->main_dfg->Build(ir_passes.graph());
// PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
} }
void EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file);
std::string repr() const override { return "fluid-to-ir-pass"; } std::string repr() const override { return "fluid-to-ir-pass"; }
private:
// Load parameters from a single file or from a directory.
bool LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file);
private: private:
Argument *argument_{nullptr}; Argument *argument_{nullptr};
}; };
......
...@@ -24,6 +24,8 @@ namespace analysis { ...@@ -24,6 +24,8 @@ namespace analysis {
TEST(FluidToIrPass, Test) { TEST(FluidToIrPass, Test) {
FluidToIrPass pass; FluidToIrPass pass;
Argument argument(FLAGS_inference_model_dir); Argument argument(FLAGS_inference_model_dir);
argument.Set(kFluidToIrPassesAttr,
new std::vector<std::string>({"infer_clean_graph_pass"}));
pass.Initialize(&argument); pass.Initialize(&argument);
pass.Run(argument.main_dfg.get()); pass.Run(argument.main_dfg.get());
} }
...@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) { ...@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass); USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_fuse_pass);
...@@ -14,20 +14,24 @@ ...@@ -14,20 +14,24 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
IRPassManager::IRPassManager(const ProgramDesc& program) { IRPassManager::IRPassManager(const ProgramDesc &program,
framework::Scope *scope)
: program_(program) {
graph_.reset(new framework::ir::Graph(program)); graph_.reset(new framework::ir::Graph(program));
if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
} }
void IRPassManager::Apply(const std::vector<std::string>& passes) { void IRPassManager::Apply(const std::vector<std::string> &passes) {
graph_->Set("graph_viz_path", new std::string("./1.dot"));
// Apply all the passes // Apply all the passes
std::string pre_pass; std::string pre_pass;
for (const std::string& pass_name : passes) { for (const std::string &pass_name : passes) {
LOG(WARNING) << "Running IR pass [" << pass_name << "]"; LOG(WARNING) << "Running IR pass [" << pass_name << "]";
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -31,14 +32,15 @@ using framework::ProgramDesc; ...@@ -31,14 +32,15 @@ using framework::ProgramDesc;
class IRPassManager final { class IRPassManager final {
public: public:
IRPassManager(const ProgramDesc& program); IRPassManager(const ProgramDesc &program, framework::Scope *scope);
void Apply(const std::vector<std::string>& passes); void Apply(const std::vector<std::string> &passes);
framework::ir::Graph& graph() const { return *graph_; } framework::ir::Graph &graph() const { return *graph_; }
private: private:
std::unique_ptr<framework::ir::Graph> graph_; std::unique_ptr<framework::ir::Graph> graph_;
ProgramDesc program_;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) { ...@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
LOG(INFO) << "Total " << data_.size() << " passes"; LOG(INFO) << "Total " << data_.size() << " Analysys passes";
for (auto& pass : data_) { for (auto& pass : data_) {
LOG(WARNING) << "Running pass [" << pass->repr() << "]"; LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get()); pass->Run(argument_->main_dfg.get());
} }
} }
......
...@@ -20,7 +20,7 @@ endif(APPLE) ...@@ -20,7 +20,7 @@ endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager
graph_viz_pass fc_fuse_pass graph_viz_pass fc_fuse_pass
infer_clean_graph_pass infer_clean_graph_pass
) )
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
...@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME) ...@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_api_test) endfunction(inference_api_test)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)
cc_test(test_paddle_inference_api cc_test(test_paddle_inference_api
SRCS api_tester.cc SRCS api_tester.cc
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true;
}
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc));
PADDLE_ENFORCE(argument.Has("param_scope"));
// Update scope.
scope_.reset(argument.Release<framework::Scope>("param_scope"));
LOG(INFO) << "optimize end ==";
}
private:
NativeConfig config_;
};
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
VLOG(3) << "create NativePredictor";
if (config.use_gpu) {
// 1. GPU memeroy
PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory, 0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
if (!dynamic_cast<AnalysisPredictor*>(predictor.get())->Init(nullptr)) {
return nullptr;
}
return predictor;
}
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace inference {
template <>
std::string to_string<std::vector<float>>(
const std::vector<std::vector<float>> &vec) {
std::stringstream ss;
for (const auto &piece : vec) {
ss << to_string(piece) << "\n";
}
return ss.str();
}
template <>
std::string to_string<std::vector<std::vector<float>>>(
const std::vector<std::vector<std::vector<float>>> &vec) {
std::stringstream ss;
for (const auto &line : vec) {
for (const auto &rcd : line) {
ss << to_string(rcd) << ";\t";
}
ss << '\n';
}
return ss.str();
}
} // namespace inference
} // namespace paddle
...@@ -44,7 +44,8 @@ class Timer { ...@@ -44,7 +44,8 @@ class Timer {
} }
}; };
void split(const std::string &str, char sep, std::vector<std::string> *pieces) { static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
if (str.empty()) { if (str.empty()) {
return; return;
...@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) { ...@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
pieces->push_back(str.substr(pos)); pieces->push_back(str.substr(pos));
} }
} }
void split_to_float(const std::string &str, char sep, std::vector<float> *fs) { static void split_to_float(const std::string &str, char sep,
std::vector<float> *fs) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(str, sep, &pieces); split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs), std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
...@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) { ...@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
} }
template <> template <>
std::string to_string<std::vector<float>>( std::string to_string<std::vector<float>>(
const std::vector<std::vector<float>> &vec) { const std::vector<std::vector<float>> &vec);
std::stringstream ss;
for (const auto &piece : vec) {
ss << to_string(piece) << "\n";
}
return ss.str();
}
template <> template <>
std::string to_string<std::vector<std::vector<float>>>( std::string to_string<std::vector<std::vector<float>>>(
const std::vector<std::vector<std::vector<float>>> &vec) { const std::vector<std::vector<std::vector<float>>> &vec);
std::stringstream ss;
for (const auto &line : vec) {
for (const auto &rcd : line) {
ss << to_string(rcd) << ";\t";
}
ss << '\n';
}
return ss.str();
}
// clang-format off // clang-format off
void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) { static void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) {
// Assign buffer // Assign buffer
int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; }); int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; });
tensor->data.Resize(sizeof(float) * dim); tensor->data.Resize(sizeof(float) * dim);
......
...@@ -77,6 +77,7 @@ enum class PaddleEngineKind { ...@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility. kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference. kAnakin, // Use Anakin for inference.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT. kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
kAnalysis
// TODO(Superjomn) support following engines latter. // TODO(Superjomn) support following engines latter.
// kTensorRT, // Use TensorRT for inference. // kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin. // kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
......
...@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load( ...@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
return main_program; return main_program;
} }
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", vars);
op->SetAttr("file_path", dirname + "/param");
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, const_cast<framework::Scope*>(&scope), 0, true, true);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
const std::string& prog_filename, const std::string& prog_filename,
const std::string& param_filename); const std::string& param_filename);
// Save the variables from a scope to disk.
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate = true);
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes( ...@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
return feed_target_shapes; return feed_target_shapes;
} }
void Compile(paddle::framework::ProgramDesc* program) {
std::unique_ptr<paddle::framework::ir::Graph> g(
new paddle::framework::ir::Graph(*program));
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
pass->Apply(std::move(g));
}
template <typename Place, bool CreateVars = true, bool PrepareContext = false> template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname, void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
...@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname, ...@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
inference_program = InitProgram(&executor, scope, dirname, is_combined); inference_program = InitProgram(&executor, scope, dirname, is_combined);
} }
Compile(inference_program.get());
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"load_program_profiler"); "load_program_profiler");
...@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname, ...@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
delete scope; delete scope;
} }
USE_PASS(graph_to_program_pass);
...@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
const int D = w_dims[1] / 4; const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M, PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias"); auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
......
...@@ -29,6 +29,6 @@ target_assign_op.cu) ...@@ -29,6 +29,6 @@ target_assign_op.cu)
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu) polygon_box_transform_op.cu)
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
# Export local libraries to parent #Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
struct AppendProposalsFunctor {
LoDTensor *out_;
int64_t offset_;
Tensor *to_add_;
AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
: out_(out), offset_(offset), to_add_(to_add) {}
template <typename T>
void operator()() const {
auto *out_data = out_->data<T>();
auto *to_add_data = to_add_->data<T>();
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
}
};
class GenerateProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Scores"), "Input(Scores) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("BboxDeltas"),
"Input(BboxDeltas) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Anchors"),
"Input(Anchors) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Variances"),
"Input(Variances) shouldn't be null.");
auto scores_dims = ctx->GetInputDim("Scores");
auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas");
auto im_info_dims = ctx->GetInputDim("ImInfo");
auto anchors_dims = ctx->GetInputDim("Anchors");
auto variances_dims = ctx->GetInputDim("Variances");
ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
platform::CPUPlace());
}
};
template <class T>
void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto *bbox_deltas_data = bbox_deltas->data<T>();
auto *anchor_data = all_anchors->data<T>();
const T *variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len];
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1];
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2;
T anchor_center_y =
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) *
anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
}
// return proposals;
}
template <class T>
void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>();
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
}
}
}
template <class T>
void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2];
keep->Resize({boxes->dims()[0], 1});
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
bool SortScorePairDescend(const std::pair<float, int> &pair1,
const std::pair<float, int> &pair2) {
return pair1.first > pair2.first;
}
template <class T>
void GetMaxScoreIndex(const std::vector<T> &scores,
std::vector<std::pair<T, int>> *sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend);
}
template <class T>
T BBoxArea(const T *box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
const T nms_threshold, const float eta) {
PADDLE_ENFORCE_NOT_NULL(bbox);
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex<T>(scores_data, &sorted_indices);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
bool flag;
while (sorted_indices.size() != 0) {
int idx = sorted_indices.front().second;
flag = true;
for (size_t k = 0; k < selected_indices.size(); ++k) {
if (flag) {
const int kept_idx = selected_indices[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
selected_num++;
}
sorted_indices.erase(sorted_indices.begin());
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
Tensor keep_nms;
keep_nms.Resize({selected_num});
int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <typename DeviceContext, typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel() / 4, 1},
context.GetPlace());
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = tensor_pair.first;
Tensor scores = tensor_pair.second;
framework::VisitDataType(
framework::ToDataType(rpn_rois->type()),
AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
framework::VisitDataType(
framework::ToDataType(rpn_roi_probs->type()),
AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
num_proposals += proposals.dims()[0];
lod0.emplace_back(num_proposals);
}
lod.emplace_back(lod0);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
std::pair<Tensor, Tensor> ProposalForOneImage(
const DeviceContext &ctx, const Tensor &im_info_slice,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) const {
auto *scores_data = scores_slice.data<T>();
// Sort index
Tensor index_t;
index_t.Resize({scores_slice.numel()});
int *index = index_t.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
std::function<bool(const int64_t &, const int64_t &)> compare =
[scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
} else {
std::nth_element(index, index + pre_nms_top_n,
index + scores_slice.numel(), compare);
index_t.Resize({pre_nms_top_n});
}
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
CPUGather<T>(ctx, variances, index_t, &var_sel);
Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
ClipTiledBoxes<T>(ctx, im_info_slice, &proposals);
Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_sel);
}
Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel);
}
};
class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Scores", "The scores of anchors should be foreground.");
AddInput("BboxDeltas", "bbox_deltas.");
AddInput("ImInfo", "Information for image reshape.");
AddInput("Anchors", "All anchors.");
AddInput("Variances", " variances");
AddOutput("RpnRois", "Anchors.");
AddOutput("RpnRoiProbs", "Anchors.");
AddAttr<int>("pre_nms_topN", "pre_nms_topN");
AddAttr<int>("post_nms_topN", "post_nms_topN");
AddAttr<float>("nms_thresh", "nms_thres");
AddAttr<float>("min_size", "min size");
AddAttr<float>("eta", "eta");
AddComment(R"DOC(
Generate Proposals OP
This operator proposes rois according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
could be used to train detection net.
Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
of anchors, H and W are height and width of the feature map.
BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and
calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area.
Finally, apply nms to get final proposals as output.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
ops::GenerateProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
generate_proposals,
ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(BatchResetHiddenPrev) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) "
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionGRUOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.")
.AsDispensable();
AddInput("WeightX",
"(Tensor) The FC weight with shape (M x 3D),"
"where M is the dim size of x, D is the hidden size. ");
AddInput("WeightH",
"(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
AddInput("Bias",
"(Tensor, optional) (1 x 3D)."
"Almost same as GRUOp."
"Note: if have FC bias it should be added on this bias.")
.AsDispensable();
AddOutput("XX",
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
.AsIntermediate();
AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
.AsIntermediate();
AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault("tanh");
AddAttr<std::string>(
"gate_activation",
"(string, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault("sigmoid");
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU,
more details can refer to GRU op.
)DOC");
}
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* h0 = ctx.Input<Tensor>("H0");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* batch_reset_hidden_prev =
ctx.Output<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
batch_hidden->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias ? bias->data<T>() : NULL);
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias ? bias->data<T>() : NULL);
}
int frame_size = static_cast<int>(wx_dims[1] / 3);
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(wh_data);
gru_value.state_weight =
const_cast<T*>(wh_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (h0) {
ReorderInitState<DeviceContext, T>(
ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batched_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node =
math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM
if (FLAGS_paddle_num_threads >= 4) {
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/,
frame_size /*height of height*/);
PADDLE_ENFORCE(packed_gate);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/,
frame_size /*height of height*/);
PADDLE_ENFORCE(packed_state);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state);
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
frame_size, gru_value.prev_out_value, frame_size, packed_gate,
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
math::detail::forward_reset_output(
math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
cur_batch_size, active_gate);
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
gru_value.reset_output_value, frame_size, packed_state,
frame_size, T(1), gru_value.gate_value + frame_size * 2,
frame_size * 3);
}
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
gru_value.prev_out_value = gru_value.output_value;
}
blas.GEMM_FREE(packed_gate);
blas.GEMM_FREE(packed_state);
} else {
#endif
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
#ifdef PADDLE_WITH_MKLML
}
#endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batched_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionGRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
...@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
"FC height should be sum of all inputs width."); "FC height should be sum of all inputs width.");
if (ctx->HasInput("FCBias")) { if (ctx->HasInput("FCBias")) {
auto b_dims = ctx->GetInputDim("FCBias"); auto b_dims = ctx->GetInputDim("FCBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2."); PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D); "b_dims should be 1 or 2, get %d", b_dims.size());
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D); if (b_dims.size() == 1) {
PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D);
} else {
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D);
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D);
}
} }
ctx->SetOutputDim("Out", {ins_dims[0][0], D}); ctx->SetOutputDim("Out", {ins_dims[0][0], D});
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <functional>
#include <string> #include <string>
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__ #ifdef __AVX__
......
...@@ -38,13 +38,14 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> { ...@@ -38,13 +38,14 @@ class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
auto width = dst_dims[1]; auto width = dst_dims[1];
auto* src_data = src.data<T>(); auto* src_data = src.data<T>();
auto* dst_data = dst->data<T>(); auto* dst_data = dst->data<T>();
for (int i = 0; i < height; ++i) { const int sz = width * sizeof(T);
if (is_src_index) { if (is_src_index) {
memcpy(dst_data + i * width, src_data + index[i] * width, for (int i = 0; i < height; ++i) {
width * sizeof(T)); memcpy(dst_data + i * width, src_data + index[i] * width, sz);
} else { }
memcpy(dst_data + index[i] * width, src_data + i * width, } else {
width * sizeof(T)); for (int i = 0; i < height; ++i) {
memcpy(dst_data + index[i] * width, src_data + i * width, sz);
} }
} }
} }
......
...@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
grad->SetInput(framework::GradVarName(output_param), og_names); grad->SetInput(framework::GradVarName(output_param), og_names);
} }
} }
grad->SetInput("Communicator", {"nccl_com__do_not_change_"});
grad->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, grad_block_[0]); grad->SetBlockAttr(kParallelBlock, grad_block_[0]);
......
...@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) { ...@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
} catch (const std::exception &exp) { } catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
} }
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif #endif
InitDevices(init_p2p, devices); InitDevices(init_p2p, devices);
} }
...@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
} catch (const std::exception &exp) { } catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
} }
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif #endif
for (size_t i = 0; i < devices.size(); ++i) { for (size_t i = 0; i < devices.size(); ++i) {
......
...@@ -203,7 +203,7 @@ def resize_short(im, size): ...@@ -203,7 +203,7 @@ def resize_short(im, size):
h_new = size * h // w h_new = size * h // w
else: else:
w_new = size * w // h w_new = size * w // h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_CUBIC)
return im return im
...@@ -345,7 +345,6 @@ def simple_transform(im, ...@@ -345,7 +345,6 @@ def simple_transform(im,
if np.random.randint(2) == 0: if np.random.randint(2) == 0:
im = left_right_flip(im, is_color) im = left_right_flip(im, is_color)
else: else:
im = center_crop(im, crop_size, is_color)
im = center_crop(im, crop_size, is_color=is_color) im = center_crop(im, crop_size, is_color=is_color)
if len(im.shape) == 3: if len(im.shape) == 3:
im = to_chw(im) im = to_chw(im)
......
...@@ -39,6 +39,7 @@ __all__ = [ ...@@ -39,6 +39,7 @@ __all__ = [
'detection_map', 'detection_map',
'rpn_target_assign', 'rpn_target_assign',
'anchor_generator', 'anchor_generator',
'generate_proposals',
] ]
__auto__ = [ __auto__ = [
...@@ -1253,3 +1254,73 @@ def anchor_generator(input, ...@@ -1253,3 +1254,73 @@ def anchor_generator(input,
anchor.stop_gradient = True anchor.stop_gradient = True
var.stop_gradient = True var.stop_gradient = True
return anchor, var return anchor, var
def generate_proposals(scores,
bbox_deltas,
im_info,
anchors,
variances,
pre_nms_top_n=6000,
post_nms_top_n=1000,
nms_thresh=0.5,
min_size=0.1,
eta=1.0,
name=None):
"""
** Generate proposal labels Faster-RCNN **
This operation proposes RoIs according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores to be an object are the output of RPN. Final proposals
could be used to train detection net.
For generating proposals, this operation performs following steps:
1. Transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4)
2. Calculate box locations as proposals candidates.
3. Clip boxes to image
4. Remove predicted boxes with small area.
5. Apply NMS to get final proposals as output.
Args:
scores(Variable): A 4-D Tensor with shape [N, A, H, W] represents the probability for each box to be an object.
N is batch size, A is number of anchors, H and W are height and width of the feature map.
bbox_deltas(Variable): A 4-D Tensor with shape [N, 4*A, H, W] represents the differece between predicted box locatoin and anchor location.
im_info(Variable): A 2-D Tensor with shape [N, 3] represents origin image information for N batch. Info contains height, width and scale
between origin image size and the size of feature map.
anchors(Variable): A 4-D Tensor represents the anchors with a layout of [H, W, A, 4]. H and W are height and width of the feature map,
num_anchors is the box count of each position. Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized.
variances(Variable): The expanded variances of anchors with a layout of [H, W, num_priors, 4]. Each variance is in (xcenter, ycenter, w, h) format.
pre_nms_top_n(float): Number of total bboxes to be kept per image before NMS. 6000 by default.
post_nms_top_n(float): Number of total bboxes to be kept per image after NMS. 1000 by default.
nms_thresh(float): Threshold in NMS, 0.5 by default.
min_size(float): Remove predicted boxes with either height or width < min_size. 0.1 by default.
eta(float): Apply in adaptive NMS, if adaptive threshold > 0.5, adaptive_threshold = adaptive_threshold * eta in each iteration.
"""
helper = LayerHelper('generate_proposals', **locals())
rpn_rois = helper.create_tmp_variable(dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_tmp_variable(dtype=scores.dtype)
helper.append_op(
type="generate_proposals",
inputs={
'Scores': scores,
'BboxDeltas': bbox_deltas,
'ImInfo': im_info,
'Anchors': anchors,
'Variances': variances
},
attrs={
'pre_nms_topN': pre_nms_top_n,
'post_nms_topN': post_nms_top_n,
'nms_thresh': nms_thresh,
'min_size': min_size,
'eta': eta
},
outputs={'RpnRois': rpn_rois,
'RpnRoiProbs': rpn_roi_probs})
rpn_rois.stop_gradient = True
rpn_roi_probs.stop_gradient = True
return rpn_rois, rpn_roi_probs
...@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost) ...@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
batch_size = fluid.layers.create_tensor(dtype='int64') batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size) batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
# fluid.memory_optimize(fluid.default_main_program(), level=0) fluid.memory_optimize(fluid.default_main_program(), level=0)
fluid.release_memory(fluid.default_main_program()) # fluid.release_memory(fluid.default_main_program())
BATCH_SIZE = 16 BATCH_SIZE = 16
PASS_NUM = 1 PASS_NUM = 1
......
...@@ -92,8 +92,8 @@ def main(): ...@@ -92,8 +92,8 @@ def main():
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
# fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
fluid.release_memory(fluid.default_main_program()) # fluid.release_memory(fluid.default_main_program())
# fix the order of training data # fix the order of training data
train_data = paddle.batch( train_data = paddle.batch(
......
...@@ -201,5 +201,44 @@ class TestDetectionMAP(unittest.TestCase): ...@@ -201,5 +201,44 @@ class TestDetectionMAP(unittest.TestCase):
print(str(program)) print(str(program))
class TestGenerateProposals(unittest.TestCase):
def test_generate_proposals(self):
data_shape = [20, 64, 64]
images = fluid.layers.data(
name='images', shape=data_shape, dtype='float32')
im_info = fluid.layers.data(
name='im_info', shape=[1, 3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator(
name='anchor_generator',
input=images,
anchor_sizes=[32, 64],
aspect_ratios=[1.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = anchors.shape[2]
scores = fluid.layers.data(
name='scores', shape=[1, num_anchors, 8, 8], dtype='float32')
bbox_deltas = fluid.layers.data(
name='bbox_deltas',
shape=[1, num_anchors * 4, 8, 8],
dtype='float32')
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
name='generate_proposals',
scores=scores,
bbox_deltas=bbox_deltas,
im_info=im_info,
anchors=anchors,
variances=variances,
pre_nms_top_n=6000,
post_nms_top_n=1000,
nms_thresh=0.5,
min_size=0.1,
eta=1.0)
self.assertIsNotNone(rpn_rois)
self.assertIsNotNone(rpn_roi_probs)
print(rpn_rois.shape)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -15,17 +15,55 @@ ...@@ -15,17 +15,55 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
class TestDistTransformer2x2(TestDistBase): def download_files():
url_prefix = 'http://paddle-unittest-data.cdn.bcebos.com/dist_transformer/'
vocab_url = url_prefix + 'vocab.bpe.32000'
vocab_md5 = 'a86d345ca6e27f6591d0dccb1b9be853'
paddle.dataset.common.download(vocab_url, 'test_dist_transformer',
vocab_md5)
local_train_url = url_prefix + 'train.tok.clean.bpe.32000.en-de'
local_train_md5 = '033eb02b9449e6dd823f050782ac8914'
paddle.dataset.common.download(local_train_url, 'test_dist_transformer',
local_train_md5)
train0_url = url_prefix + 'train.tok.clean.bpe.32000.en-de.train_0'
train0_md5 = 'ddce7f602f352a0405267285379a38b1'
paddle.dataset.common.download(train0_url, 'test_dist_transformer',
train0_md5)
train1_url = url_prefix + 'train.tok.clean.bpe.32000.en-de.train_1'
train1_md5 = '8757798200180285b1a619cd7f408747'
paddle.dataset.common.download(train1_url, 'test_dist_transformer',
train1_md5)
test_url = url_prefix + 'newstest2013.tok.bpe.32000.en-de'
test_md5 = '9dd74a266dbdb25314183899f269b4a2'
paddle.dataset.common.download(test_url, 'test_dist_transformer', test_md5)
class TestDistTransformer2x2Sync(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
def test_transformer(self): def test_transformer(self):
# TODO(paddle-dev): check if the delta is OK. download_files()
# Usually start around ~8000 and converge to ~5000 #Note: loss on test dataset of the first 5 batch are:
self.check_with_place("dist_transformer.py", delta=400) # 10.518872, 10.518871, 10.518868, 10.518862, 10.518855
self.check_with_place("dist_transformer.py", delta=1e-7)
class TestDistTransformer2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_transformer(self):
download_files()
self.check_with_place("dist_transformer.py", delta=1.0)
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
from op_test import OpTest
from test_gru_op import gru
from test_fusion_lstm_op import fc, ACTIVATION
def fusion_gru(
x, # T x M
lod, # 1 x N
h0, # N x D
wx, # M x 3D
wh, # D x 3D
bias, # 1 x 3D
is_reverse,
act_state,
act_gate):
return gru(fc(x, wx, bias),
lod,
h0,
wh,
np.zeros(
(1, wh.shape[1]), dtype='float64'),
is_reverse,
act_state,
act_gate)
class TestFusionGRUOp(OpTest):
def set_confs(self):
pass
def setUp(self):
self.op_type = "fusion_gru"
self.lod = [[2, 4, 3]]
self.M = 3
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs()
T = sum(self.lod[0])
N = len(self.lod[0])
x = np.random.rand(T, self.M).astype('float64')
wx = np.random.rand(self.M, 3 * self.D).astype('float64')
wh = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
_, _, _, hidden = fusion_gru(
x, self.lod, h0, wx, wh, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {'Hidden': (hidden, self.lod)}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self):
self.check_output(atol=1e-8)
class TestFusionGRUOpNoInitial(TestFusionGRUOp):
def set_confs(self):
self.with_h0 = False
class TestFusionGRUOpNoBias(TestFusionGRUOp):
def set_confs(self):
self.with_bias = False
class TestFusionGRUOpReverse(TestFusionGRUOp):
def set_confs(self):
self.is_reverse = True
class TestFusionGRUOpMD1(TestFusionGRUOp):
def set_confs(self):
self.M = 36
self.D = 8
class TestFusionGRUOpMD2(TestFusionGRUOp):
def set_confs(self):
self.M = 8
self.D = 8
class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self):
self.lod = [[3]]
self.D = 16
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://w_idxw.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
import math
import paddle.fluid as fluid
from op_test import OpTest
from test_multiclass_nms_op import nms
from test_anchor_generator_op import anchor_generator_in_python
import copy
def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors,
variances, pre_nms_topN, post_nms_topN,
nms_thresh, min_size, eta):
all_anchors = anchors.reshape(-1, 4)
rois = np.empty((0, 5), dtype=np.float32)
roi_probs = np.empty((0, 1), dtype=np.float32)
rpn_rois = []
rpn_roi_probs = []
lod = []
num_images = scores.shape[0]
for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image(
im_info[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
lod.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, lod
def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
pre_nms_topN, post_nms_topN, nms_thresh, min_size,
eta):
# Transpose and reshape predicted bbox transformations to get them
# into the same order as the anchors:
# - bbox deltas will be (4 * A, H, W) format from conv output
# - transpose to (H, W, 4 * A)
# - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
# in slowest to fastest order to match the enumerated anchors
bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
all_anchors = all_anchors.reshape(-1, 4)
variances = variances.reshape(-1, 4)
# Same story for the scores:
# - scores are (A, H, W) format from conv output
# - transpose to (H, W, A)
# - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
# to match the order of anchors and bbox_deltas
scores = scores.transpose((1, 2, 0)).reshape(-1, 1)
# sort all (proposal, score) pairs by score from highest to lowest
# take top pre_nms_topN (e.g. 6000)
if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
order = np.argsort(-scores.squeeze())
else:
# Avoid sorting possibly large arrays;
# First partition to get top K unsorted
# and then sort just thoes
inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
order = np.argsort(-scores[inds].squeeze())
order = inds[order]
scores = scores[order, :]
bbox_deltas = bbox_deltas[order, :]
all_anchors = all_anchors[order, :]
proposals = box_coder(all_anchors, bbox_deltas, variances)
# clip proposals to image (may result in proposals with zero area
# that will be removed in the next step)
proposals = clip_tiled_boxes(proposals, im_info[:2])
# remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_info)
proposals = proposals[keep, :]
scores = scores[keep, :]
# apply loose nms (e.g. threshold = 0.7)
# take post_nms_topN (e.g. 1000)
# return the top proposals
if nms_thresh > 0:
keep = nms(boxes=proposals,
scores=scores,
nms_threshold=nms_thresh,
eta=eta)
if post_nms_topN > 0 and post_nms_topN < len(keep):
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep, :]
return proposals, scores
def box_coder(all_anchors, bbox_deltas, variances):
"""
Decode proposals by anchors and bbox_deltas from RPN
"""
#proposals: xmin, ymin, xmax, ymax
proposals = np.zeros_like(bbox_deltas, dtype=np.float32)
#anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0]
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1]
anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2
anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2
#predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
if variances is not None:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = variances[i, 0] * bbox_deltas[i, 0] * anchor_loc[
i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(variances[i, 2] *
bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(variances[i, 3] *
bbox_deltas[i, 3]) * anchor_loc[i, 1]
else:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3]
pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2
return proposals
def clip_tiled_boxes(boxes, im_shape):
"""Clip boxes to image boundaries. im_shape is [height, width] and boxes
has shape (N, 4 * num_tiled_boxes)."""
assert boxes.shape[1] % 4 == 0, \
'boxes.shape[1] is {:d}, but must be divisible by 4.'.format(
boxes.shape[1]
)
# x1 >= 0
boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
# y1 >= 0
boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
# x2 < im_shape[1]
boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
# y2 < im_shape[0]
boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
return boxes
def filter_boxes(boxes, min_size, im_info):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size *= im_info[2]
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) &
(y_ctr < im_info[0]))[0]
return keep
def iou(box_a, box_b):
"""
Apply intersection-over-union overlap between box_a and box_b
"""
xmin_a = min(box_a[0], box_a[2])
ymin_a = min(box_a[1], box_a[3])
xmax_a = max(box_a[0], box_a[2])
ymax_a = max(box_a[1], box_a[3])
xmin_b = min(box_b[0], box_b[2])
ymin_b = min(box_b[1], box_b[3])
xmax_b = max(box_b[0], box_b[2])
ymax_b = max(box_b[1], box_b[3])
area_a = (ymax_a - ymin_a + 1) * (xmax_a - xmin_a + 1)
area_b = (ymax_b - ymin_b + 1) * (xmax_b - xmin_b + 1)
if area_a <= 0 and area_b <= 0:
return 0.0
xa = max(xmin_a, xmin_b)
ya = max(ymin_a, ymin_b)
xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area)
return iou_ratio
def nms(boxes, scores, nms_threshold, eta=1.0):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
nms_threshold: (float) The overlap thresh for suppressing unnecessary
boxes.
eta: (float) The parameter for adaptive NMS.
Return:
The indices of the kept boxes with respect to num_priors.
"""
all_scores = copy.deepcopy(scores)
all_scores = all_scores.flatten()
sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
sorted_scores = all_scores[sorted_indices]
selected_indices = []
adaptive_threshold = nms_threshold
for i in range(sorted_scores.shape[0]):
idx = sorted_indices[i]
keep = True
for k in range(len(selected_indices)):
if keep:
kept_idx = selected_indices[k]
overlap = iou(boxes[idx], boxes[kept_idx])
keep = True if overlap <= adaptive_threshold else False
else:
break
if keep:
selected_indices.append(idx)
if keep and eta < 1 and adaptive_threshold > 0.5:
adaptive_threshold *= eta
return selected_indices
class TestGenerateProposalsOp(OpTest):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImInfo': self.im_info.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta
}
print("lod = ", self.lod)
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod])
}
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "generate_proposals"
self.set_data()
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 3.0
self.eta = 0.8
def init_test_input(self):
batch_size = 1
input_channels = 20
layer_h = 16
layer_w = 16
input_feat = np.random.random(
(batch_size, input_channels, layer_h, layer_w)).astype('float32')
self.anchors, self.variances = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[16., 32.],
aspect_ratios=[0.5, 1.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
self.im_info = np.array([[64., 64., 8.]]) #im_height, im_width, scale
num_anchors = self.anchors.shape[2]
self.scores = np.random.random(
(batch_size, num_anchors, layer_h, layer_w)).astype('float32')
self.bbox_deltas = np.random.random(
(batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.lod = generate_proposals_in_python(
self.scores, self.bbox_deltas, self.im_info, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta)
if __name__ == '__main__':
unittest.main()
...@@ -19,22 +19,19 @@ import numpy as np ...@@ -19,22 +19,19 @@ import numpy as np
import math import math
import functools import functools
from op_test import OpTest from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu from test_lstm_op import ACTIVATION
class TestGRUOp(OpTest): def gru(
lod = [[2, 4, 3]] input, # T x 3D
batch_size = sum(lod[0]) lod, # 1 x N
frame_size = 5 h0, # N x D
activate = { weight, # D x 3D
'identity': identity, bias, # 1 x 3D
'sigmoid': sigmoid, is_reverse,
'tanh': tanh, act_state,
'relu': relu act_gate):
} def _seq_to_batch(lod, is_reverse):
@staticmethod
def seq_to_batch(lod, is_reverse):
idx_in_seq_list = [] idx_in_seq_list = []
seq_lens = lod[0] seq_lens = lod[0]
seq_starts = [0] seq_starts = [0]
...@@ -56,121 +53,125 @@ class TestGRUOp(OpTest): ...@@ -56,121 +53,125 @@ class TestGRUOp(OpTest):
idx_in_seq_list.append(idx_in_seq) idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list, sorted_seqs return idx_in_seq_list, sorted_seqs
def gru_step(self, x, h_p, w, b): def _step(x, h_p, w, b, act_state, act_gate):
batch_size = x.shape[0] T = x.shape[0]
frame_size = w.shape[0] D = w.shape[0]
g = x + np.tile(b, (batch_size, 1)) g = x + np.tile(b, (T, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape( w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))
(frame_size, frame_size * 2)) u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
u_r = self.activate[self.attrs['gate_activation']](np.dot( u = u_r[:, :D]
h_p, w_u_r) + g[:, :frame_size * 2]) r = u_r[:, D:D * 2]
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
r_h_p = r * h_p r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape( w_c = w.flatten()[D * D * 2:].reshape((D, D))
(frame_size, frame_size)) c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p h = u * c + (1 - u) * h_p
return g, r_h_p, h return g, r_h_p, h
def gru(self): T = sum(lod[0])
input, lod = self.inputs['Input'] N = len(lod[0])
w = self.inputs['Weight'] D = weight.shape[0]
b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros( batch_gate = np.zeros((T, 3 * D), dtype='float64')
(1, self.frame_size * 3)) batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
batch_gate = self.outputs['BatchGate'] batch_hidden = np.zeros((T, D), dtype='float64')
batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev'] hidden = np.zeros((T, D), dtype='float64')
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden'] idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
idx_in_seq_list = self.idx_in_seq_list h_p = h0[sorted_seqs]
h_p = self.inputs['H0'][ max_seq_len = len(idx_in_seq_list)
self.sorted_seqs] if 'H0' in self.inputs else np.zeros( assert len(idx_in_seq_list[0]) == N
(len(idx_in_seq_list[0]), self.frame_size)) end_idx = 0
num_batch = len(idx_in_seq_list) for batch_idx in range(max_seq_len):
end_idx = 0 x = input[idx_in_seq_list[batch_idx]]
for batch_idx in range(num_batch): g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate)
x = input[idx_in_seq_list[batch_idx]] if batch_idx < (max_seq_len - 1):
g, r_h_p, h = self.gru_step(x, h_p, w, b) h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
if batch_idx < (num_batch - 1): start_idx = end_idx
h_p = h[:len(idx_in_seq_list[batch_idx + 1])] end_idx = start_idx + len(idx_in_seq_list[batch_idx])
start_idx = end_idx batch_gate[start_idx:end_idx] = g
end_idx = start_idx + len(idx_in_seq_list[batch_idx]) batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_gate[start_idx:end_idx] = g batch_hidden[start_idx:end_idx] = h
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p hidden[idx_in_seq_list[batch_idx]] = h
batch_hidden[start_idx:end_idx] = h return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = self.lod
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
h0 = np.random.rand(len(self.idx_in_seq_list[0]),
frame_size).astype('float64')
weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
bias = np.random.rand(1, frame_size * 3).astype('float64')
self.inputs = {
'Input': (input, lod),
'H0': h0,
'Weight': weight,
'Bias': bias
}
self.outputs = {
'BatchGate': np.zeros(
(batch_size, frame_size * 3), dtype='float64'),
'BatchResetHiddenPrev': np.zeros(
(batch_size, frame_size), dtype='float64'),
'BatchHidden': np.zeros(
(batch_size, frame_size), dtype='float64'),
'Hidden': np.zeros(
(batch_size, frame_size), dtype='float64')
}
class TestGRUOp(OpTest):
def set_confs(self): def set_confs(self):
self.is_reverse = False pass
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
def setUp(self): def setUp(self):
self.op_type = "gru" self.op_type = "gru"
self.lod = [[2, 4, 3]]
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs() self.set_confs()
self.set_data()
self.gru() T = sum(self.lod[0])
N = len(self.lod[0])
input = np.random.rand(T, 3 * self.D).astype('float64')
weight = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (hidden, self.lod),
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden,
}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(atol=1e-8)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoInitial(TestGRUOp): class TestGRUOpNoInitial(TestGRUOp):
def set_data(self): def set_confs(self):
super(TestGRUOpNoInitial, self).set_data() self.with_h0 = False
self.inputs.pop('H0')
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden']) self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoBias(TestGRUOp):
def set_confs(self):
self.with_bias = False
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])
class TestGRUOpReverse(TestGRUOp): class TestGRUOpReverse(TestGRUOp):
def set_confs(self): def set_confs(self):
self.is_reverse = True self.is_reverse = True
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -31,6 +31,7 @@ Steps to transpile pserver: ...@@ -31,6 +31,7 @@ Steps to transpile pserver:
""" """
import math import math
import sys
import numpy as np import numpy as np
import collections import collections
import six import six
...@@ -181,7 +182,8 @@ class DistributeTranspiler(object): ...@@ -181,7 +182,8 @@ class DistributeTranspiler(object):
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
sync_mode=True): sync_mode=True,
startup_program=None):
""" """
Run the transpiler. Run the transpiler.
...@@ -194,13 +196,17 @@ class DistributeTranspiler(object): ...@@ -194,13 +196,17 @@ class DistributeTranspiler(object):
list. list.
trainers (int): number of trainers in the distributed job. trainers (int): number of trainers in the distributed job.
sync_mode (bool): Do sync training or not, default is True. sync_mode (bool): Do sync training or not, default is True.
startup_program (Program|None): startup_program to transpile,
default is fluid.default_main_program().
""" """
if program is None: if program is None:
program = default_main_program() program = default_main_program()
if startup_program is None:
startup_program = default_startup_program()
self.origin_program = program self.origin_program = program
self.origin_startup_program = default_startup_program().clone() self.startup_program = startup_program
self.origin_startup_program = self.startup_program.clone()
self.startup_program = default_startup_program()
self.trainer_num = trainers self.trainer_num = trainers
self.sync_mode = sync_mode self.sync_mode = sync_mode
self.trainer_id = trainer_id self.trainer_id = trainer_id
...@@ -376,21 +382,18 @@ class DistributeTranspiler(object): ...@@ -376,21 +382,18 @@ class DistributeTranspiler(object):
return self.origin_program return self.origin_program
def _get_trainer_startup_program(self, def _get_trainer_startup_program(self, recv_vars, eplist):
recv_vars,
eplist,
startup_program=None):
""" """
Get transpiled trainer side startup program. Get transpiled trainer side startup program.
Args: Args:
startup_program(Program): Startup program. recv_vars (list): Variable list to recv for current trainer_id
eplist (list): A list of strings indicating
Returns: Returns:
Program: trainer side startup program. Program: trainer side startup program.
""" """
if startup_program is None: startup_program = self.startup_program
startup_program = self.startup_program
# FIXME(gongwb): delete not need ops. # FIXME(gongwb): delete not need ops.
# note that: some parameter is not trainable and those ops can't be deleted. # note that: some parameter is not trainable and those ops can't be deleted.
...@@ -438,7 +441,18 @@ class DistributeTranspiler(object): ...@@ -438,7 +441,18 @@ class DistributeTranspiler(object):
#add concat ops to merge splited parameters received from parameter servers. #add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
orig_param = startup_program.global_block().vars[varname] # NOTE: if enable memory optimization, origin vars maybe removed.
if startup_program.global_block().vars.has_key(varname):
orig_param = startup_program.global_block().vars[varname]
else:
origin_param_var = self.origin_program.global_block().vars[
varname]
orig_param = startup_program.global_block().create_var(
name=varname,
persistable=origin_param_var.persistable,
type=origin_param_var.type,
dtype=origin_param_var.dtype,
shape=origin_param_var.shape)
startup_program.global_block().append_op( startup_program.global_block().append_op(
type="concat", type="concat",
inputs={"X": splited_var}, inputs={"X": splited_var},
...@@ -461,7 +475,9 @@ class DistributeTranspiler(object): ...@@ -461,7 +475,9 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable is not distributed # NOTE: assume blocks of the same variable is not distributed
# on the same pserver, only change param/grad varnames for # on the same pserver, only change param/grad varnames for
# trainers to fetch. # trainers to fetch.
sys.stderr.write("get_pserver_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
# step1 # step1
pserver_program = Program() pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed pserver_program.random_seed = self.origin_program.random_seed
...@@ -651,32 +667,58 @@ class DistributeTranspiler(object): ...@@ -651,32 +667,58 @@ class DistributeTranspiler(object):
endpoint) endpoint)
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
# save pserver program to generate pserver side startup relatively.
self.pserver_program = pserver_program
return pserver_program return pserver_program
def get_pserver_programs(self, endpoint):
"""
Get pserver side main program and startup program for distributed training.
Args:
endpoint (str): current pserver endpoint.
Returns:
tuple: (main_program, startup_program), of type "Program"
"""
pserver_prog = self.get_pserver_program(endpoint)
pserver_startup = self.get_startup_program(endpoint)
return pserver_prog, pserver_startup
def get_startup_program(self, def get_startup_program(self,
endpoint, endpoint,
pserver_program, pserver_program=None,
startup_program=None): startup_program=None):
""" """
**Deprecated**
Get startup program for current parameter server. Get startup program for current parameter server.
Modify operator input variables if there are variables that Modify operator input variables if there are variables that
were split to several blocks. were split to several blocks.
Args: Args:
endpoint (str): current pserver endpoint. endpoint (str): current pserver endpoint.
pserver_program (Program): call get_pserver_program first and pserver_program (Program): deprecated, call get_pserver_program first.
pass the result here. startup_program (Program): deprecated, should pass startup_program
startup_program (Program): if pass None, will use when initalizing
default_startup_program
Returns: Returns:
Program: parameter server side startup program. Program: parameter server side startup program.
""" """
sys.stderr.write("get_startup_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
if pserver_program != None:
sys.stderr.write("passing pserver_program to get_startup_program()\
is deprecated, you can use new API get_pserver_programs() to\
get both pserver main program and startup program.")
if startup_program != None:
sys.stderr.write("passing startup_program to get_startup_program()\
is deprecated, use fluid.program_guard() or pass this argument\
to transpile() call.")
s_prog = Program() s_prog = Program()
if not startup_program: orig_s_prog = self.startup_program
orig_s_prog = default_startup_program()
else:
orig_s_prog = startup_program
s_prog.random_seed = orig_s_prog.random_seed s_prog.random_seed = orig_s_prog.random_seed
params = self.param_grad_ep_mapping[endpoint]["params"] params = self.param_grad_ep_mapping[endpoint]["params"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册