Unverified commit b98b7440 authored by Chen Weihang, committed by GitHub

Merge branch 'develop' into sequence_enumerate_op

@@ -52,9 +52,8 @@ ExternalProject_Add(
     extern_anakin
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLML_PROJECT}
-    # Anakin codes error on Intel(R) Xeon(R) Gold 5117 CPU, temporary do not compile avx512 related code.
-    GIT_REPOSITORY "https://github.com/luotao1/Anakin"
-    GIT_TAG "211d1fc5d813d70c0c14072f9083cf25f40940ea"
+    GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin"
+    GIT_TAG "9424277cf9ae180a14aff09560d3cd60a49c76d2"
    PREFIX ${ANAKIN_SOURCE_DIR}
    UPDATE_COMMAND ""
    CMAKE_ARGS -DUSE_GPU_PLACE=YES
......
@@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
 paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
 paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
 paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
+paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
 paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
 paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
@@ -146,6 +147,7 @@ paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', '
 paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.lrn ArgSpec(args=['input', 'n', 'k', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(5, 1.0, 0.0001, 0.75, None))
 paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
+paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
 paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
 paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
 paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
@@ -165,6 +167,7 @@ paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, ke
 paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
 paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
+paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -297,6 +300,7 @@ paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', '
 paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
 paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
 paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
+paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
 paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
 paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
@@ -379,7 +383,7 @@ paddle.fluid.LoDTensor.__init__ 1. __init__(self: paddle.fluid.core.LoDTensor, a
 paddle.fluid.LoDTensor.has_valid_recursive_sequence_lengths has_valid_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> bool
 paddle.fluid.LoDTensor.lod lod(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
 paddle.fluid.LoDTensor.recursive_sequence_lengths recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
-paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None
+paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CPUPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 22. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 23. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None 24. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPinnedPlace) -> None
 paddle.fluid.LoDTensor.set_lod set_lod(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None
 paddle.fluid.LoDTensor.set_recursive_sequence_lengths set_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor, arg0: List[List[int]]) -> None
 paddle.fluid.LoDTensor.shape shape(self: paddle.fluid.core.Tensor) -> List[int]
......
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
 
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
 endif()
 
 if (NOT WIN32)
......
@@ -64,6 +64,7 @@ static DataTypeMap* InitDataTypeMap() {
   RegType(size_t, proto::VarType::SIZE_T);
   RegType(int16_t, proto::VarType::INT16);
   RegType(uint8_t, proto::VarType::UINT8);
+  RegType(int8_t, proto::VarType::INT8);
 
 #undef RegType
   return retv;
......
@@ -54,6 +54,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
     case proto::VarType::INT16:
       visitor.template operator()<int16_t>();
      break;
+    case proto::VarType::INT8:
+      visitor.template operator()<int8_t>();
+      break;
     default:
       PADDLE_THROW("Not supported %d", type);
   }
......
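Editor's note on the hunk above: the INT8 case relies on the visitor idiom VisitDataType is built on — the caller passes a functor whose operator() is templated on the element type, and the switch instantiates it with the C++ type matching the runtime enum value. A minimal standalone sketch of that idiom (toy enum and visitor, not Paddle's actual types):

#include <cstdint>
#include <iostream>

enum class VarType { INT8, INT16, UINT8 };  // stand-in for proto::VarType

struct PrintSizeVisitor {
  template <typename T>
  void operator()() const {
    std::cout << "dispatched with sizeof(T) = " << sizeof(T) << "\n";
  }
};

template <typename Visitor>
void VisitDataType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::INT8:
      visitor.template operator()<int8_t>();  // same dispatch as the new case
      break;
    case VarType::INT16:
      visitor.template operator()<int16_t>();
      break;
    case VarType::UINT8:
      visitor.template operator()<uint8_t>();
      break;
  }
}

int main() { VisitDataType(VarType::INT8, PrintSizeVisitor{}); }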
@@ -754,17 +754,26 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                  node->Op()->Type());
   CreateComputationalOp(result, node, op_dev_id);
-  if (node->Op()->Type() == "concat") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
-              "fetch_barrier");
-  }
 }
 
+void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  for (ir::Node *input : node->inputs) {
+    VarHandle *var = nullptr;
+    for (int place_offset = 0; place_offset < num_places; ++place_offset) {
+      auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
+      auto &var_holder = var_holders[input->Name()];
+      if (!var_holder.empty()) {
+        var = var_holder.rbegin()->get();
+        op_handle->AddInput(var);
+      }
+    }
+  }
+}
+
 // Create RPC related op handles that connects its in ops and out ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
+  // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
+  // put them into transpiler.
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -799,8 +808,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-    // FIXME(typhoonzero): assume each recv op output one param
-    // Use the same place as send.
     if (recv_param_grad.size() == 2U) {
       op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
       VLOG(10) << "recv param " << recv_param_grad[0]
@@ -814,34 +821,44 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
           .emplace(varname, op_dev_id);
     }
   } else {
-    // send_barrier and fetch_barrier op can be scheduled on device 0
+    // send_barrier, fetch_barrier will run on place 0;
     op_dev_id = 0;
   }
 
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  node->Op()->Type());
   result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
       result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
       node->Op()->Type(), places_[op_dev_id]));
 
-  // TODO(panyx0718): This might not be needed anymore.
-  if (node->Op()->Type() == "send_barrier") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
-  } else if (node->Op()->Type() == "recv") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
-              "send_barrier");
-  } else if (node->Op()->Type() == "fetch_barrier") {
-    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
-  } else if (node->Op()->Type() == "send") {
-    // do nothing
+  if (node->Op()->Type() == "send") {
+    CreateOpHandleIOs(result, node, op_dev_id);
   } else {
-    PADDLE_THROW(
-        "rpc op should be in ["
-        "send, send_barrier. recv, fetch_barrier]");
-  }
-
-  CreateOpHandleIOs(result, node, op_dev_id);
+    // send_barrier, recv, fetch_barrier's inputs are deps var, get them from
+    // all places
+    auto p = places_[op_dev_id];
+    auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+
+    SetOpInputsAllPlaces(result, node, places_.size());
+    for (ir::Node *output : node->outputs) {
+      int outvar_dev_id = op_dev_id;
+      if (node->Op()->Type() == "fetch_barrier") {
+        outvar_dev_id = GetVarDeviceID(*result, output->Name());
+        PADDLE_ENFORCE_NE(outvar_dev_id, -1);
+      }
+      p = places_[outvar_dev_id];
+      ir::Node *new_node = nullptr;
+      if (output->Var()) {
+        new_node = result->CreateVarNode(output->Var());
+      } else {
+        new_node =
+            result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+      }
+      CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
+    }
+  }
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
......
@@ -107,6 +107,7 @@ message VarType {
     // Tensor<size_t> is used in C++.
     SIZE_T = 19;
     UINT8 = 20;
+    INT8 = 21;
 
     // Other types that may need additional descriptions
     LOD_TENSOR = 7;
......
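Editor's note: the new INT8 = 21 enum value only becomes usable because data_type.cc (earlier in this diff) also registers it with RegType(int8_t, proto::VarType::INT8). A standalone sketch of that two-way registration pattern, assuming toy maps in place of Paddle's DataTypeMap:

#include <cstdint>
#include <map>
#include <typeindex>
#include <typeinfo>

enum VarType { UINT8 = 20, INT8 = 21 };  // values mirror framework.proto

int main() {
  // Two maps give C++ type -> enum and enum -> C++ type lookups.
  std::map<std::type_index, VarType> cpp_to_proto;
  std::map<VarType, std::type_index> proto_to_cpp;
#define RegType(cc_type, proto_type)                                   \
  cpp_to_proto.emplace(std::type_index(typeid(cc_type)), proto_type); \
  proto_to_cpp.emplace(proto_type, std::type_index(typeid(cc_type)));
  RegType(uint8_t, UINT8);
  RegType(int8_t, INT8);  // the registration this commit adds
#undef RegType
  return proto_to_cpp.count(INT8) == 1 ? 0 : 1;
}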
@@ -3,14 +3,18 @@ cc_library(graph SRCS graph.cc DEPS node)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
+cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
-cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
-cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
+cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
 cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
+cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
 
 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
-cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
-cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
+cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
+cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
+cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace framework {
namespace ir {
struct Param {
std::string X = "concat_0.tmp_0";
std::string C0 = "cell_init";
std::string H0 = "hidden_init";
std::string AttentionWeight = "attention_fc.w_0";
std::string AttentionBias = "attention_fc.b_0";
std::string AttentionScalar = "attention_output.w_0";
std::string AttentionScalarBias = "attention_output.b_0";
std::string LSTMWeight = "attention_w.new";
std::string LSTMBias = "attention_b.new";
std::string Hidden = "array_to_lod_tensor_0.tmp_0";
std::string Cell = "at.cell.new";
std::string AttentionedX = "at.x.new";
std::string AttentionFCOut = "at.fc.new";
std::string LSTMX = "at.lstmx.new";
std::string LSTMOUT = "at.lstmout.new";
};
void PrepareParameters(Graph* graph, const Param& param);
void FindWhileOp(Graph* graph) {
GraphPatternDetector gpd;
std::unordered_set<int> fused_external_ops(
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});
gpd.mutable_pattern()->NewNode(
[&](Node* n) { return fused_external_ops.count(n->id()); }, "while");
if (!graph->Has(kGraphvizMarkedNodeAttr)) {
graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
}
auto& marked_nodes =
graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while");
auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node);
};
gpd(graph, handle);
Param param;
// Add AttentionLSTM node
OpDesc op_desc;
op_desc.SetType("attention_lstm");
#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
OP_SET_IN(X);
OP_SET_IN(C0);
OP_SET_IN(H0);
OP_SET_IN(AttentionWeight);
OP_SET_IN(AttentionBias);
OP_SET_IN(AttentionScalar);
OP_SET_IN(AttentionScalarBias);
OP_SET_IN(LSTMWeight);
OP_SET_IN(LSTMBias);
OP_SET_OUT(Hidden);
OP_SET_OUT(Cell);
OP_SET_OUT(AttentionedX);
OP_SET_OUT(AttentionFCOut);
OP_SET_OUT(LSTMX);
OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT
auto* X = graph->RetriveNode(34);
auto* LSTMOUT = graph->RetriveNode(81);
auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8);
#define LINK_TO(node0, node1) \
node0->outputs.push_back(node1); \
node1->inputs.push_back(node0);
auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);
LINK_TO(X, lstm_op);
LINK_TO(cell_init, lstm_op);
LINK_TO(hidden_init, lstm_op);
LINK_TO(lstm_op, LSTMOUT);
GraphSafeRemoveNodes(graph, marked_nodes);
}
#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
CHECK_P1(x0); \
CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
CHECK_P2(x0, x1); \
CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
CHECK_P3(x0, x1, x2); \
CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
CHECK_P4(x0, x1, x2, x3); \
CHECK_P1(x4);
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out);
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out);
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
// Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
GATE_W(forget);
GATE_W(input);
GATE_W(output);
GATE_W(c);
#undef GATE_W
auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
lstm_weight_t);
PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
lstm_bias_t);
}
// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out) {
int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors(
{W_forget_w0.data<float>(), W_input_w0.data<float>(),
W_output_w0.data<float>(), W_cell_w0.data<float>()});
std::array<const float*, 4> tensors1(
{W_forget_w1.data<float>(), W_input_w1.data<float>(),
W_output_w1.data<float>(), W_cell_w1.data<float>()});
for (int row = 0; row < D; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * row + D * col;
const float* src = tensors[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
for (int row = 0; row < M; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * (D + row) + D * col;
const float* src = tensors1[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
}
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out) {
std::array<const float*, 4> tensors(
{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()});
PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
int D = B_forget.dims()[0];
out->Resize(make_ddim({1, 4 * D}));
auto* out_data = out->mutable_data<float>(platform::CPUPlace());
for (size_t i = 0; i < tensors.size(); i++) {
memcpy(out_data + D * i, tensors[i], D * sizeof(float));
}
}
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PDPattern external_pattern, subblock_pattern;
FindWhileOp(graph.get());
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(attention_lstm_fuse_pass,
paddle::framework::ir::AttentionLSTMFusePass);
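Editor's note on PrepareLSTMWeight above: it packs the four per-gate matrices into one (D + M) x (4 * D) matrix, where each gate owns a D-wide column block, with the D rows of the x-projection stacked above the M rows of the h-projection. A standalone sketch of just that layout logic, with toy sizes and fill values chosen to make the blocks visible:

#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const int D = 2, M = 1;
  // w0[g]: D x D x-projection of gate g; w1[g]: M x D h-projection of gate g.
  std::vector<std::vector<float>> w0(4), w1(4);
  for (int g = 0; g < 4; ++g) {
    w0[g].assign(D * D, static_cast<float>(g + 1));
    w1[g].assign(M * D, static_cast<float>(-(g + 1)));
  }
  std::vector<float> out((D + M) * 4 * D);
  // Same copy loops as PrepareLSTMWeight: row-major destination, one
  // D-wide column block per gate.
  for (int row = 0; row < D; ++row)
    for (int col = 0; col < 4; ++col)
      std::memcpy(&out[4 * D * row + D * col], &w0[col][D * row],
                  D * sizeof(float));
  for (int row = 0; row < M; ++row)
    for (int col = 0; col < 4; ++col)
      std::memcpy(&out[4 * D * (D + row) + D * col], &w1[col][D * row],
                  D * sizeof(float));
  for (int r = 0; r < D + M; ++r) {
    for (int c = 0; c < 4 * D; ++c) std::printf("%5.1f", out[4 * D * r + c]);
    std::printf("\n");
  }
  return 0;
}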
@@ -12,12 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/inference/analysis/dot.h"
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 
 namespace paddle {
-namespace inference {
-namespace analysis {
-size_t Dot::counter = 0;
-}  // namespace analysis
-}  // namespace inference
+namespace framework {
+namespace ir {
+
+class AttentionLSTMFusePass : public FusePassBase {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
 }  // namespace paddle
@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
       },
       "elementwise_add_out");
 
-  pattern->AddEdge(mul_parameter_var, mul_op);
-  pattern->AddEdge(mul_tmp_input_var, mul_op);
-  pattern->AddEdge(mul_op, mul_out_var);
-  pattern->AddEdge(mul_out_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
+  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
+      .LinksTo({mul_out_var});
+  elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
+      .LinksTo({elementwise_add_out_var});
 }
 
 // Replace the node `from` in the links to `to`
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   std::unordered_set<Node*> nodes2delete;
 
-  GraphPatternDetecter gpd;
+  GraphPatternDetector gpd;
   BuildFCPattern(gpd.mutable_pattern());
 
 #define GET_NODE(id) \
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
-  auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
     // Currently, there is no FC op available, so I will just simulate the
......
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::unordered_set<int> fused_ops({// first lstm
13, 15, 16,
// second lstm
23, 25, 26});
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
"any_node");
std::unordered_set<Node*> marked_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
// Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
int bias, int hidden, int cell, int xx) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc;
op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
SET_IN(X, input);
SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h);
SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false);
auto* op = graph->CreateOpNode(&op_desc);
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
LINK_TO(input_n, op);
LINK_TO(weight_x_n, op);
LINK_TO(weight_h_n, op);
LINK_TO(bias_n, op);
LINK_TO(op, hidden_n);
#undef LINK_TO
return op;
};
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
// remove all the nodes
for (auto* node : marked_nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FCLstmFusePass : public Pass {
public:
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "param_scope";
class FusePassBase : public Pass {
public:
void Init(Graph* graph) const { graph_ = graph; }
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
virtual ~FusePassBase() {}
protected:
mutable Graph* graph_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
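Editor's note: a pass deriving from the new FusePassBase would typically call Init() and then fetch parameters through param_scope(). The sketch below is hypothetical (MyFusePass is not part of this commit) and assumes it is compiled inside the Paddle source tree:

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical example pass, illustrative only.
class MyFusePass : public FusePassBase {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const {
    Init(graph.get());             // remember the graph being rewritten
    Scope* scope = param_scope();  // enforces that kParamScopeAttr was set
    (void)scope;  // a real pass would rewrite weights held in this scope
    return graph;
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle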
@@ -132,63 +132,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     }
   }
 
-  std::vector<ir::Node *> send_ops;
-  ir::Node *send_bar = nullptr;
-  std::vector<ir::Node *> recv_ops;
-  ir::Node *fetch_bar = nullptr;
-  for (ir::Node *node : Nodes()) {
-    if (node->Name() == "send") {
-      send_ops.push_back(node);
-    } else if (node->Name() == "send_barrier") {
-      PADDLE_ENFORCE(!send_bar, "only has one send barrier");
-      send_bar = node;
-    } else if (node->Name() == "recv") {
-      recv_ops.push_back(node);
-    } else if (node->Name() == "fetch_barrier") {
-      PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier");
-      fetch_bar = node;
-    }
-  }
-  if (send_bar) {
-    for (ir::Node *send : send_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      send->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(send);
-      send_bar->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(send_bar);
-    }
-    for (ir::Node *recv : recv_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      recv->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(recv);
-      send_bar->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(send_bar);
-    }
-  }
-  if (fetch_bar) {
-    for (ir::Node *recv : recv_ops) {
-      ir::Node *dep_var = CreateControlDepVar();
-      recv->outputs.push_back(dep_var);
-      dep_var->inputs.push_back(recv);
-      fetch_bar->inputs.push_back(dep_var);
-      dep_var->outputs.push_back(fetch_bar);
-    }
-  }
-
-  std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
-  std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
-  for (ir::Node *node : Nodes()) {
-    if (IsDistTrainOp(node, send_vars, recv_vars)) {
-      if (fetch_bar && node->Name() == "concat") {
-        ir::Node *dep_var = CreateControlDepVar();
-        fetch_bar->outputs.push_back(dep_var);
-        dep_var->inputs.push_back(fetch_bar);
-        node->inputs.push_back(dep_var);
-        dep_var->outputs.push_back(node);
-      }
-    }
-  }
-
   /**
    * We should handle write after read(WAR) and write after write(WAW) here.
    * Because some of the operators of the program can be executed parallelly.
......
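Editor's note: the block deleted above ordered send/recv against their barrier ops purely through control-dependency variables — a dummy var node wired as an output of one op and an input of the other, which is also the mechanism the retained WAR/WAW handling relies on. A standalone toy of that wiring (plain structs standing in for ir::Node; illustrative only):

#include <string>
#include <vector>

struct Node {  // toy stand-in for ir::Node
  std::string name;
  std::vector<Node*> inputs, outputs;
};

// Force `after` to be scheduled behind `before` without any real data flow,
// mirroring the CreateControlDepVar() wiring in the deleted block.
void AddControlDep(Node* before, Node* after, Node* dep_var) {
  before->outputs.push_back(dep_var);
  dep_var->inputs.push_back(before);
  after->inputs.push_back(dep_var);
  dep_var->outputs.push_back(after);
}

int main() {
  Node send{"send"}, send_barrier{"send_barrier"}, dep{"dep@0"};
  AddControlDep(&send, &send_barrier, &dep);  // send_barrier waits on send
  return send_barrier.inputs.size() == 1 ? 0 : 1;
}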
@@ -99,13 +99,13 @@ class Graph {
   // Create a normal variable with non-null VarDesc.
   ir::Node *CreateVarNode(VarDesc *var_desc) {
     PADDLE_ENFORCE(var_desc);
-    return AddNode(new ir::Node(var_desc));
+    return AddNode(new ir::Node(var_desc, node_count_++));
   }
 
   // Create a normal runnable operator with OpDesc.
   ir::Node *CreateOpNode(OpDesc *op_desc) {
     PADDLE_ENFORCE(op_desc);
-    return AddNode(new ir::Node(op_desc));
+    return AddNode(new ir::Node(op_desc, node_count_++));
   }
 
   // Create a control dependency var that connects 2 operations. The
@@ -115,13 +115,14 @@ class Graph {
     // TODO(panyx0718): control var name should be really unique.
     const std::string name = string::Sprintf(
         "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
-    return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
+    return AddNode(
+        new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
   }
 
   // A more free style way of creating a graph node. Mostly use for test
   // or "copy" from another node. Avoid using it if possible.
   ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
-    return AddNode(new ir::Node(name, type));
+    return AddNode(new ir::Node(name, type, node_count_++));
   }
 
   // Clear all node information of the graph and return the ownership of the
@@ -142,12 +143,20 @@ class Graph {
     nodes_.erase(node);
   }
 
+  Node *RetriveNode(int id) {
+    auto it = id2node_.find(id);
+    if (it != id2node_.end()) return it->second;
+    return nullptr;
+  }
+
  private:
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
     nodes_[node].reset(node);
     node_set_.insert(node);
+    PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
+    id2node_[node->id()] = node;
     return node;
   }
 
@@ -157,6 +166,8 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
   std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
   std::unordered_set<ir::Node *> node_set_;
+  std::map<int, Node *> id2node_;
+  int node_count_{0};
 };
 
 bool IsControlDepVar(const ir::Node &var);
......
@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
         PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        adj_list[n].insert(adj_n);
         VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
                 << " via " << var->Name() << reinterpret_cast<void *>(var);
-        adj_list[n].insert(adj_n);
       }
     }
   }
......
@@ -17,7 +17,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/platform/enforce.h"
...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { ...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
name); name);
} }
nodes_.emplace_back(new PDNode(std::move(teller), name)); nodes_.emplace_back(new PDNode(std::move(teller), this, name));
auto* cur = nodes_.back().get(); auto* cur = nodes_.back().get();
node_map_[name] = cur; node_map_[name] = cur;
return cur; return cur;
...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { ...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
edges_.emplace_back(a, b); edges_.emplace_back(a, b);
} }
void GraphPatternDetecter::operator()(Graph* graph, void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetecter::handle_t handler) { GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) return; if (!MarkPDNodesInGraph(*graph)) return;
auto subgraphs = DetectPatterns(); auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs); UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs); RemoveOverlappedMatch(&subgraphs);
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) { for (auto& g : subgraphs) {
LOG(INFO) << "optimizing #" << id++ << " subgraph";
handler(g, graph); handler(g, graph);
} }
} }
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph"; VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) { ...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
return false; return false;
} }
std::vector<GraphPatternDetecter::subgraph_t> std::vector<GraphPatternDetector::subgraph_t>
GraphPatternDetecter::DetectPatterns() { GraphPatternDetector::DetectPatterns() {
// Init empty subgraphs. // Init empty subgraphs.
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups; std::vector<HitGroup> init_groups;
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); std::array<std::vector<HitGroup>, 2> bi_records;
auto* first_pnode = pattern_.edges().front().first; // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result; if (!pdnodes2nodes_.count(first_pnode)) return result;
for (auto* node : pdnodes2nodes_[first_pnode]) { for (auto* node : pdnodes2nodes_[first_pnode]) {
HitGroup group; HitGroup group;
...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
} }
int step = 0; int step = 0;
std::array<std::vector<HitGroup>, 2> bi_records;
bi_records[0] = std::move(init_groups); bi_records[0] = std::move(init_groups);
// Extend a PDNode to subgraphs by deducing the connection relations defined // Extend a PDNode to subgraphs by deducing the connection relations defined
...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
auto& cur_groups = bi_records[1 - (step++ % 2)]; auto& cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear(); cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target // source -> target
for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) { for (Node* target : pdnodes2nodes_[edge.second]) {
...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
GraphPatternDetecter::subgraph_t subgraph; GraphPatternDetector::subgraph_t subgraph;
for (auto& role : group.roles) { for (auto& role : group.roles) {
subgraph.emplace(role.first, role.second); subgraph.emplace(role.first, role.second);
} }
...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
return result; return result;
} }
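// Rough trace of the double-buffered extension above, for a two-edge pattern
// a->b->c (toy example, assumed):
//   bi_records[0] : partial groups seeded from candidates of the first PDNode
//   edge (a,b)    : read bi_records[0], write extended groups to bi_records[1]
//   edge (b,c)    : read bi_records[1], write completed groups to bi_records[0]
// The new `if (pre_groups.empty()) break;` short-circuits the walk as soon as
// no partial match survives an edge.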
void GraphPatternDetecter::UniquePatterns( void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) { std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
if (subgraphs->empty()) return; if (subgraphs->empty()) return;
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set; std::unordered_set<size_t> set;
for (auto& g : *subgraphs) { for (auto& g : *subgraphs) {
...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns( ...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
*subgraphs = result; *subgraphs = result;
} }
void GraphPatternDetecter::RemoveOverlappedMatch( void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t>* subgraphs) { std::vector<subgraph_t>* subgraphs) {
std::vector<subgraph_t> result; std::vector<subgraph_t> result;
std::unordered_set<Node*> node_set; std::unordered_set<Node*> node_set;
...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch( ...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
*subgraphs = result; *subgraphs = result;
} }
std::string PDPattern::DotString() const {
using inference::analysis::Dot;
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PDNode*, std::string> node2dot;
for (const auto& node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto& edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto& src = node2dot.at(edge.first);
auto& trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
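// Debugging sketch: the DOT string above can be dumped while developing a
// pass, e.g. (hypothetical call site)
//   VLOG(3) << detector.pattern().DotString();
// and rendered offline with `dot -Tpng pattern.dot -o pattern.png`.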
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
  // extend inlinks.
for (PDNode* x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
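// Usage sketch (hypothetical nodes and tellers): wiring pattern topology with
// the two helpers above instead of raw pattern->AddEdge() calls:
//   auto* in  = pattern->NewNode(var_teller, "in");
//   auto* op  = pattern->NewNode(op_teller, "op");
//   auto* out = pattern->NewNode(var_teller, "out");
//   op->LinksFrom({in}).LinksTo({out});  // in -> op -> out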
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -21,12 +21,14 @@ ...@@ -21,12 +21,14 @@
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class PDPattern;
// Some basic torminolygies: // Some basic terminologies:
// - PDPattern: a pattern defined as a data flow graph. // - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` // - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`. // that meets some conditions defined in `PDNode.teller`.
...@@ -36,30 +38,43 @@ namespace ir { ...@@ -36,30 +38,43 @@ namespace ir {
struct PDNode { struct PDNode {
// tell whether an ir::Node* is a candidate for a PDNode. // tell whether an ir::Node* is a candidate for a PDNode.
using teller_t = std::function<bool(Node*)>; using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar };
PDNode(teller_t&& teller, const std::string& name = "") // this link to others
: teller_(teller), name_(name) { PDNode& LinksTo(const std::vector<PDNode*>& others);
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); PDNode& LinksFrom(const std::vector<PDNode*>& others);
}
PDNode(PDNode&& other) = default;
std::vector<PDNode*> inlinks;
std::vector<PDNode*> outlinks;
bool Tell(Node* node) const { bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
return teller_(node); return teller_(node);
} }
bool IsOp() const { return type_ == Type::kOp; }
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete; PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete;
private: private:
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
name_(name),
type_(type) {
    PADDLE_ENFORCE(teller_ != nullptr, "invalid teller function is set.");
}
PDNode(PDNode&& other) = default;
friend class PDPattern;
teller_t teller_; teller_t teller_;
PDPattern* pattern_;
std::string name_; std::string name_;
Type type_;
}; };
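// A teller is just a predicate over ir::Node*; a typical one (sketch,
// mirroring the fuse pass further below) matches an op by type:
//   PDNode::teller_t is_mul = [](Node* x) {
//     return x && x->IsOp() && x->Op()->Type() == "mul";
//   };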
/* /*
...@@ -102,6 +117,8 @@ class PDPattern { ...@@ -102,6 +117,8 @@ class PDPattern {
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
std::string DotString() const;
private: private:
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PDPattern, AddEdge); FRIEND_TEST(PDPattern, AddEdge);
...@@ -117,7 +134,7 @@ class PDPattern { ...@@ -117,7 +134,7 @@ class PDPattern {
}; };
/* /*
* GraphPatternDetecter helps to detect the specific patterns in the graph. * GraphPatternDetector helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes. * Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
* *
...@@ -129,7 +146,7 @@ class PDPattern { ...@@ -129,7 +146,7 @@ class PDPattern {
* *
* Usage: * Usage:
* // Create a detector * // Create a detector
* GraphPatternDetecter detector; * GraphPatternDetector detector;
* // Define the detector's pattern, by adding PDNode and define the edges. * // Define the detector's pattern, by adding PDNode and define the edges.
 * auto* node0 = detector.mutable_pattern()->AddNode(...) * auto* node0 = detector.mutable_pattern()->AddNode(...)
 * auto* node1 = detector.mutable_pattern()->AddNode(...) * auto* node1 = detector.mutable_pattern()->AddNode(...)
...@@ -138,11 +155,11 @@ class PDPattern { ...@@ -138,11 +155,11 @@ class PDPattern {
 * detector.mutable_pattern()->AddEdge(node0, node1); * detector.mutable_pattern()->AddEdge(node0, node1);
 * // Create a handler, to define the behavior of treating the filtered * // Create a handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns. * // subgraphs that comply with the patterns.
 * GraphPatternDetecter::handle_t handler = some lambda * GraphPatternDetector::handle_t handler = some lambda
* // Execute the detector. * // Execute the detector.
* detector(&graph, handler); * detector(&graph, handler);
*/ */
class GraphPatternDetecter { class GraphPatternDetector {
public: public:
using subgraph_t = std::unordered_map<PDNode*, Node*>; using subgraph_t = std::unordered_map<PDNode*, Node*>;
...@@ -177,10 +194,62 @@ class GraphPatternDetecter { ...@@ -177,10 +194,62 @@ class GraphPatternDetecter {
using hit_rcd_t = using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>; std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_; PDPattern pattern_;
std::vector<hit_rcd_t> marked_records_;
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_; std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
}; };
// Some helper methods.
// Op's input.
static bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Op's output.
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
  for (auto* in : node->inputs) {
    if (in->IsOp() && in->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Check whether a var node is an op node's nth input.
static bool IsNthInput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
  if (op->Op()->Input(argument).size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
      } else {
        it++;
      }
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
      } else {
        it++;
      }
}
}
}
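// Usage sketch: a fuse handler typically collects the matched (now dead)
// nodes and removes them in one shot, keeping the variables the fused op
// still reads or writes (compare the seq_concat_fc fuse pass below):
//   std::unordered_set<const Node*> dead;
//   for (auto& item : subgraph) dead.insert(item.second);
//   dead.erase(kept_in);   // hypothetical: an input reused by the fused op
//   dead.erase(kept_out);  // hypothetical: the preserved output
//   GraphSafeRemoveNodes(graph, dead);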
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) { ...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
} }
TEST(GraphPatternDetecter, MarkPDNodesInGraph) { TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
GraphPatternDetecter x; GraphPatternDetector x;
// mark o2, o3, v2 // mark o2, o3, v2
// The pattern is a graph: // The pattern is a graph:
...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
Graph graph(program); Graph graph(program);
BuildGraph(&graph); BuildGraph(&graph);
GraphPatternDetecter x; GraphPatternDetector x;
// The pattern is a graph: // The pattern is a graph:
// op -> var // op -> var
...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
x.mutable_pattern()->AddEdge(any_var, any_op1); x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0; int count = 0;
GraphPatternDetecter::handle_t handle = [&]( GraphPatternDetector::handle_t handle = [&](
const GraphPatternDetecter::subgraph_t& s, Graph* g) { const GraphPatternDetector::subgraph_t& s, Graph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
<< s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
count++; count++;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
ProgramDesc& program = Get<ProgramDesc>("program");
std::unique_ptr<proto::ProgramDesc> program_pb(
new proto::ProgramDesc(*program.Proto()));
auto block = program_pb->mutable_blocks(kRootBlockIndex);
block->clear_vars();
std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) {
if (n->NodeType() == ir::Node::Type::kVariable) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto());
}
}
}
block->clear_ops();
std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
for (ir::Node* n : nodes) {
if (!n->Op()) {
continue;
}
block->add_ops()->MergeFrom(*n->Op()->Proto());
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
void BuildNoCircleGraph(Graph* g) {
OpDesc op1;
op1.SetType("op1");
OpDesc op2;
op2.SetType("op2");
OpDesc op3;
op3.SetType("op3");
OpDesc op4;
op4.SetType("op4");
OpDesc op5;
op5.SetType("op5");
VarDesc var1("var1");
VarDesc var2("var2");
VarDesc var3("var3");
VarDesc var4("var4");
ir::Node* o1 = g->CreateOpNode(&op1);
ir::Node* o2 = g->CreateOpNode(&op2);
ir::Node* o3 = g->CreateOpNode(&op3);
ir::Node* o4 = g->CreateOpNode(&op4);
ir::Node* o5 = g->CreateOpNode(&op5);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
v2->inputs.push_back(o2);
// o2->v3->o5
o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3-v4->o5
o3->outputs.push_back(v4);
o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
v4->outputs.push_back(o5);
}
TEST(GraphToProgramPass, Basic) {
ProgramDesc prog;
std::unique_ptr<Graph> g(new Graph(prog));
BuildNoCircleGraph(g.get());
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
ProgramDesc compiled_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
pass->Apply(std::move(g));
std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
EXPECT_EQ(ops[0]->Type(), "op1");
EXPECT_EQ(ops[1]->Type(), "op2");
if (ops[2]->Type() == "op3") {
EXPECT_EQ(ops[3]->Type(), "op4");
} else if (ops[2]->Type() == "op4") {
EXPECT_EQ(ops[3]->Type(), "op3");
}
EXPECT_EQ(ops[4]->Type(), "op5");
std::unordered_set<std::string> vars;
for (VarDesc* v : compiled_prog.Block(0).AllVars()) {
vars.insert(v->Name());
}
EXPECT_TRUE(vars.find("var1") != vars.end());
EXPECT_TRUE(vars.find("var2") != vars.end());
EXPECT_TRUE(vars.find("var3") != vars.end());
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(graph_to_program_pass);
...@@ -16,11 +16,13 @@ limitations under the License. */ ...@@ -16,11 +16,13 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path"; static const char kGraphVizPath[] = "graph_viz_path";
using inference::analysis::Dot;
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( ...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
size_t var_id = 0; std::unordered_map<const ir::Node*, std::string> node2dot;
std::unordered_map<const ir::Node*, size_t> vars;
Dot dot;
sout << "digraph G {\n";
std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
for (const ir::Node* n : graph->Nodes()) { Dot::Attr("shape", "box"),
if (n->NodeType() != ir::Node::Type::kVariable) continue; Dot::Attr("fillcolor", "red")});
size_t cur_var_id = var_id++; std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
vars[n] = cur_var_id; // Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "yellow")});
sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
<< std::endl; std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "lightgray")});
std::vector<Dot::Attr> marked_var_attrs(
{Dot::Attr("style", "filled,rounded"),
// Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "lightgray")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
if (n->IsOp()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_var_attrs : var_attrs;
dot.AddNode(node_id, attr, node_id);
} }
node2dot[n] = node_id;
size_t op_id = 0;
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : n->inputs) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
} }
// Create edges
for (auto out : n->outputs) { for (const Node* n : graph->Nodes()) {
std::string var_name = "var_" + std::to_string(vars[out]); const auto& src_id = node2dot.at(n);
sout << op_name << " -> " << var_name << std::endl; for (auto* out : n->outputs) {
const auto& trg_id = node2dot.at(out);
dot.AddEdge(src_id, trg_id, {});
} }
} }
sout << "}\n"; sout << dot.Build();
return graph; return graph;
} }
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
Graph* graph) const {
marked_nodes_t res;
if (graph->Has(kGraphvizMarkedNodeAttr)) {
auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
res = attr;
attr.clear();
}
return res;
}
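// Sketch of the producing side (assumed): a pass can highlight the nodes it
// touched by storing them under the attribute before graph_viz_pass runs:
//   graph->Set(kGraphvizMarkedNodeAttr,
//              new GraphVizPass::marked_nodes_t(my_marked_nodes));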
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -27,10 +27,19 @@ namespace paddle { ...@@ -27,10 +27,19 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public:
using marked_nodes_t = std::unordered_set<const Node*>;
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
  // Collect the nodes marked for highlighting from the graph, and consume
  // (clear) the corresponding attribute.
marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -29,20 +29,26 @@ class Node { ...@@ -29,20 +29,26 @@ class Node {
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type, int id = -1)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(id) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc, int id = -1)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable),
id_(id) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc, int id = -1)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation) {} type_(Type::kOperation),
id_(id) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -58,6 +64,8 @@ class Node { ...@@ -58,6 +64,8 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
int id() const { return id_; }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; } bool IsVar() const { return type_ == Type::kVariable; }
...@@ -69,6 +77,7 @@ class Node { ...@@ -69,6 +77,7 @@ class Node {
std::unique_ptr<VarDesc> var_desc_; std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
int id_;
private: private:
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
struct FuseExpr {};
// sequence_expand + concat fuse pattern; returns concat's output
PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
// The following operators will be fused:
// concat
// sequence_expand
// sequence_expand
// The following variables will be treated as inputs:
//   concat mid input, 0th input of the fused op
//   sequence_expand input, 1st input of the fused op
//   sequence_expand input, 2nd input of the fused op
// The following variables will be treated as outputs:
//   concat output
// So the following variables will be removed:
//   sequence_expand output
//   sequence_expand output
// Three operators
auto* sequence_expand0 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand0");
auto* sequence_expand1 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand1");
auto* concat = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "concat" && // basic check
x->Op()->Input("X").size() == 3; // Special case
},
"concat");
auto* sequence_expand0_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand0_in");
auto* sequence_expand1_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand1_in");
// The variables
auto* sequence_expand0_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 1); // X[0]
},
"sequence_expand0_out");
auto* sequence_expand1_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 2); // x[2]
},
"sequence_expand1_out");
auto* concat_in0 = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
"concat_in0");
auto* concat_out = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
"concat_out");
// Links
sequence_expand0->LinksFrom({sequence_expand0_in})
.LinksTo({sequence_expand0_out});
sequence_expand1->LinksFrom({sequence_expand1_in})
.LinksTo({sequence_expand1_out});
concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
.LinksTo({concat_out});
return concat_out;
}
PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
PDNode* fc_w = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "mul") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_w");
PDNode* mul_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "mul") && // link
VarLinksToOp(x, "elementwise_add") && //
               !x->Var()->Proto()->persistable();  // is not a parameter
},
"mul_out");
PDNode* fc_mul = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "mul"; // basic
},
"fc_mul");
PDNode* fc_bias = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "elementwise_add") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_bias");
PDNode* elementwise_add = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
},
"elementwise_add");
PDNode* add_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "elementwise_add") && // link
               !x->Var()->Proto()->persistable();  // is not a parameter
},
"add_out");
std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
PDNode* act = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && acts.count(x->Op()->Type());
},
"act");
PDNode* fc_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
               !x->Var()->Proto()->persistable();  // is not a parameter
},
"fc_out");
fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
act->LinksFrom({add_out}).LinksTo({fc_out});
return fc_out;
}
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(graph.get());
GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "get one concat pattern";
// fc
GET_NODE(fc_w, detector.pattern());
GET_NODE(fc_bias, detector.pattern());
GET_NODE(act, detector.pattern());
GET_NODE(fc_out, detector.pattern());
// concat
GET_NODE(concat_in0, detector.pattern());
GET_NODE(sequence_expand0_in, detector.pattern());
GET_NODE(sequence_expand1_in, detector.pattern());
OpDesc op_desc;
op_desc.SetType("fusion_seqexpand_concat_fc");
op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
sequence_expand1_in->Name()});
op_desc.SetInput("FCWeight", {fc_w->Name()});
op_desc.SetInput("FCBias", {fc_bias->Name()});
const std::string fc_out_tmp = fc_out->Name() + ".tmp";
param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
op_desc.SetOutput("FCOut", {fc_out_tmp});
op_desc.SetOutput("Out", {fc_out->Name()});
op_desc.SetAttr("fc_activation", act->Op()->Type());
auto* op_node = graph->CreateOpNode(&op_desc);
// Add links
#define NODE_LINKS(a, b)         \
  do {                           \
    a->outputs.push_back(b);     \
    b->inputs.push_back(a);      \
  } while (false)
NODE_LINKS(fc_w, op_node);
NODE_LINKS(fc_bias, op_node);
NODE_LINKS(concat_in0, op_node);
NODE_LINKS(sequence_expand0_in, op_node);
NODE_LINKS(sequence_expand1_in, op_node);
NODE_LINKS(op_node, fc_out);
// Clean nodes.
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
marked_nodes.erase(fc_w);
marked_nodes.erase(fc_bias);
marked_nodes.erase(concat_in0);
marked_nodes.erase(sequence_expand0_in);
marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes);
});
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seq_concat_fc_fuse_pass,
paddle::framework::ir::SeqConcatFcFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConcatFcFusePass : public FusePassBase {
public:
virtual ~SeqConcatFcFusePass() {}
protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -95,6 +95,12 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs, ...@@ -95,6 +95,12 @@ OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
need_update_ = true; need_update_ = true;
} }
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
CopyFrom(other);
block_ = block;
need_update_ = true;
}
void OpDesc::CopyFrom(const OpDesc &op_desc) { void OpDesc::CopyFrom(const OpDesc &op_desc) {
desc_.set_type(op_desc.Type()); desc_.set_type(op_desc.Type());
inputs_ = op_desc.inputs_; inputs_ = op_desc.inputs_;
...@@ -131,8 +137,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) ...@@ -131,8 +137,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
for (const proto::OpDesc::Attr &attr : desc_.attrs()) { for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name(); std::string attr_name = attr.name();
// The sub_block referred to by the BLOCK attr hasn't been added // The sub_block referred to by the BLOCK attr hasn't been added
// to ProgramDesc class yet, we skip setting BLOCK attr here. // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS attr here.
if (attr.type() != proto::AttrType::BLOCK) { if (attr.type() != proto::AttrType::BLOCK &&
attr.type() != proto::AttrType::BLOCKS) {
attrs_[attr_name] = GetAttrValue(attr); attrs_[attr_name] = GetAttrValue(attr);
} }
} }
......
...@@ -37,11 +37,7 @@ class OpDesc { ...@@ -37,11 +37,7 @@ class OpDesc {
explicit OpDesc(BlockDesc *block) : block_(block) {} explicit OpDesc(BlockDesc *block) : block_(block) {}
OpDesc(const OpDesc &other, BlockDesc *block) { OpDesc(const OpDesc &other, BlockDesc *block);
*this = other;
block_ = block;
need_update_ = true;
}
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
......
...@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) { ...@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
InitFromProto(); InitFromProto();
} }
void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
blocks_.clear();
desc_ = desc;
InitFromProto();
}
ProgramDesc::ProgramDesc(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
...@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() { ...@@ -111,10 +117,16 @@ void ProgramDesc::InitFromProto() {
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
auto &global_block = Block(0); auto &global_block = Block(0);
  // The order of feed_target_names must follow the index specified in `col`,
  // since the order of the feed operators doesn't necessarily follow `col`.
std::vector<std::string> feed_target_names; std::vector<std::string> feed_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]); int col = boost::get<int>(op->GetAttr("col"));
if (col >= feed_target_names.size()) {
feed_target_names.resize(col + 1);
}
feed_target_names[col] = op->Output("Out")[0];
} }
} }
return feed_target_names; return feed_target_names;
...@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { ...@@ -122,10 +134,16 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
const std::vector<std::string> ProgramDesc::GetFetchTargetNames() { const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
auto &global_block = Block(0); auto &global_block = Block(0);
  // The order of fetch_target_names must follow the index specified in `col`,
  // since the order of the fetch operators doesn't necessarily follow `col`.
std::vector<std::string> fetch_target_names; std::vector<std::string> fetch_target_names;
for (auto *op : global_block.AllOps()) { for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
fetch_target_names.push_back(op->Input("X")[0]); int col = boost::get<int>(op->GetAttr("col"));
if (col >= fetch_target_names.size()) {
fetch_target_names.resize(col + 1);
}
fetch_target_names[col] = op->Input("X")[0];
} }
} }
return fetch_target_names; return fetch_target_names;
......
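// Illustration (hypothetical block): if the fetch ops appear out of order,
// e.g. fetch(col=1, X="b") before fetch(col=0, X="a"), the col-indexed fill
// above still yields {"a", "b"}; the old push_back would have produced
// {"b", "a"}.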
...@@ -53,6 +53,8 @@ class ProgramDesc { ...@@ -53,6 +53,8 @@ class ProgramDesc {
void Flush(); void Flush();
void CopyFrom(const proto::ProgramDesc &desc);
proto::ProgramDesc *Proto(); proto::ProgramDesc *Proto();
// The output variable of feed_op is referenced as feed_target. // The output variable of feed_op is referenced as feed_target.
......
...@@ -40,7 +40,11 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, ...@@ -40,7 +40,11 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
"When calling this method, the Tensor's numel must be " "When calling this method, the Tensor's numel must be "
"equal or larger than zero. " "equal or larger than zero. "
"Please check Tensor::Resize has been called first."); "Please check Tensor::Resize has been called first.");
size_t size = requested_size ? requested_size : numel() * SizeOfType(type); size_t size = numel() * SizeOfType(type);
if (requested_size) {
PADDLE_ENFORCE_GE(requested_size, size);
size = requested_size;
}
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) || if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) { holder_->size() < size + offset_) {
......
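// The new contract, sketched (hypothetical tensor t resized to 4 floats,
// i.e. 16 bytes):
//   t.mutable_data(place, typeid(float), /*requested_size=*/64);  // ok, >= 16
//   t.mutable_data(place, typeid(float), /*requested_size=*/8);   // enforce fails
// Callers may over-allocate, but can no longer silently under-allocate.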
...@@ -26,7 +26,7 @@ namespace paddle { ...@@ -26,7 +26,7 @@ namespace paddle {
namespace framework { namespace framework {
template <typename T> template <typename T>
bool IsType(const std::type_index& type_index) { inline bool IsType(const std::type_index& type_index) {
return type_index == std::type_index(typeid(T)); return type_index == std::type_index(typeid(T));
} }
......
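// Usage is unchanged by the `inline` addition, e.g. (sketch, type assumed):
//   if (IsType<framework::LoDTensor>(some_type_index)) { /* ... */ }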
...@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor) ...@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal? # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
cc_library(paddle_fluid_api cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
......
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass) cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc set(analysis_deps
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc analyzer.cc
helper.cc helper.cc
# passes # passes
...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph ...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc fluid_to_ir_pass.cc
model_store_pass.cc model_store_pass.cc
DEPS framework_proto proto_desc ir_pass_manager graph pass) DEPS ${analysis_deps})
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET) ...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
endif() endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS} DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
...@@ -58,20 +61,25 @@ endif() ...@@ -58,20 +61,25 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
analysis_predictor
# ir # ir
fc_fuse_pass fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass graph_viz_pass
infer_clean_graph_pass infer_clean_graph_pass
graph_pattern_detecter graph_pattern_detector
infer_clean_graph_pass infer_clean_graph_pass
attention_lstm_fuse_pass
paddle_inference_api
pass pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc) inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
......
...@@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -72,7 +72,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax",
"depthwise_conv2d", "batch_norm"}); "depthwise_conv2d", "batch_norm", "concat"});
if (!node->IsFunction()) return false; if (!node->IsFunction()) return false;
const auto* func = static_cast<const Function*>(node); const auto* func = static_cast<const Function*>(node);
...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
  // Ugly support for the fluid-to-ir pass.
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
                    // Manually update the passes here.
"graph_viz_pass", //
"infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" //
}));
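  // Note: graph_viz_pass is interleaved after every transform above,
  // presumably so each intermediate graph can be dumped and inspected while
  // debugging the fuse pipeline.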
for (auto& x : data_) { for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument)); PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll(); x->RunAll();
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
const std::string &data_path, int batch_size, const std::string &data_path, int batch_size,
bool use_analysis, bool activate_ir, bool use_analysis, bool activate_ir,
int num_times = 1) { int num_times = 1) {
FLAGS_IA_enable_ir = activate_ir;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
std::string model_out;
if (use_analysis) {
Argument argument(model_path);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
model_out = "./analysis.out";
ASSERT_TRUE(PathExists(model_out));
} else {
model_out = FLAGS_infer_ditu_rnn_model;
}
NativeConfig config; NativeConfig config;
config.prog_file = model_out + "/__model__"; config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = model_out + "/param"; config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false; config.use_gpu = false;
config.device = 0; config.device = 0;
config.specify_input_name = true; config.specify_input_name = true;
auto predictor = auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config); CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
std::vector<PaddleTensor> input_slots; std::vector<PaddleTensor> input_slots;
DataRecord data(data_path, batch_size); DataRecord data(data_path, batch_size);
// Prepare inputs. // Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size); PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs, base_outputs;
base_predictor->Run(input_slots, &base_outputs);
Timer timer; Timer timer;
timer.tic(); timer.tic();
...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
<< ", latency: " << timer.toc() / num_times << "ms"; << ", latency: " << timer.toc() / num_times << "ms";
LOG(INFO) << "====================================="; LOG(INFO) << "=====================================";
for (auto &out : outputs) { PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &base_out = base_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; }); [](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
1, [](int a, int b) { return a * b; });
PADDLE_ENFORCE_EQ(size, size1);
PADDLE_ENFORCE_GT(size, 0);
float *data = static_cast<float *>(out.data.data()); float *data = static_cast<float *>(out.data.data());
for (size_t i = 0; float *base_data = static_cast<float *>(base_out.data.data());
i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); for (size_t i = 0; i < size; i++) {
i++) { EXPECT_NEAR(data[i], base_data[i], 1e-3);
EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
} }
} }
} }
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
// Directly infer with the original model. // Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) { TEST(Analyzer, DituRNN_without_analysis) {
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) { ...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass); USE_PASS(fc_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass); USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -58,6 +59,46 @@ struct Argument { ...@@ -58,6 +59,46 @@ struct Argument {
// The output storage path of ModelStorePass. // The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path; std::unique_ptr<std::string> model_output_store_path;
// Support for any other attributes.
template <typename T>
void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
VLOG(3) << "argument delete attr: " << key;
delete data;
};
}
bool Has(const std::string& name) const { return attrs_.count(name); }
template <typename T>
T* Release(const std::string& key) {
PADDLE_ENFORCE(attrs_.count(key));
auto* res = boost::any_cast<T*>(attrs_.at(key));
attrs_.erase(key);
attr_deleters_.erase(key);
return res;
}
template <typename T>
T& Get(const std::string& key) {
PADDLE_ENFORCE(Has(key));
return *boost::any_cast<T*>(attrs_.at(key));
}
~Argument() {
for (auto& item : attr_deleters_) {
item.second();
}
}
private:
std::unordered_map<std::string, boost::any> attrs_;
std::unordered_map<std::string, std::function<void()>> attr_deleters_;
}; };
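// Usage sketch of the new attribute slots on an existing Argument (types and
// names assumed, mirroring the fluid_to_ir pass below):
//   argument->Set("param_scope", new framework::Scope);
//   auto& scope = argument->Get<framework::Scope>("param_scope");
//   // ownership can be reclaimed before ~Argument fires the deleter:
//   std::unique_ptr<framework::Scope> owned(
//       argument->Release<framework::Scope>("param_scope"));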
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
} }
if (argument_->Has("param_scope")) {
LOG(WARNING) << "parameter changes in the scope takes effect";
}
PADDLE_ENFORCE(argument_->transformed_program_desc.get()); PADDLE_ENFORCE(argument_->transformed_program_desc.get());
} }
......
...@@ -29,13 +29,13 @@ namespace paddle { ...@@ -29,13 +29,13 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
static size_t dot_node_counter{0};
/* /*
* A Dot template that helps to build a DOT graph definition. * A Dot template that helps to build a DOT graph definition.
*/ */
class Dot { class Dot {
public: public:
static size_t counter;
struct Attr { struct Attr {
std::string key; std::string key;
std::string value; std::string value;
...@@ -57,7 +57,7 @@ class Dot { ...@@ -57,7 +57,7 @@ class Dot {
Node(const std::string& name, const std::vector<Attr>& attrs) Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name), : name(name),
attrs(attrs), attrs(attrs),
id_("node_" + std::to_string(Dot::counter++)) {} id_("node_" + std::to_string(dot_node_counter++)) {}
std::string id() const { return id_; } std::string id() const { return id_; }
...@@ -65,6 +65,10 @@ class Dot { ...@@ -65,6 +65,10 @@ class Dot {
std::stringstream ss; std::stringstream ss;
CHECK(!name.empty()); CHECK(!name.empty());
ss << id_; ss << id_;
if (attrs.empty()) {
ss << "[label=" << '"' << name << '"' << "]";
return ss.str();
}
for (size_t i = 0; i < attrs.size(); i++) { for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) { if (i == 0) {
ss << "[label=" << '"' << name << '"' << " "; ss << "[label=" << '"' << name << '"' << " ";
...@@ -108,9 +112,11 @@ class Dot { ...@@ -108,9 +112,11 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {} explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& name, const std::vector<Attr>& attrs) { void AddNode(const std::string& id, const std::vector<Attr>& attrs,
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'"; std::string label = "") {
nodes_.emplace(name, Node{name, attrs}); CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id;
nodes_.emplace(id, Node{label, attrs});
} }
void AddEdge(const std::string& source, const std::string& target, void AddEdge(const std::string& source, const std::string& target,
......
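// Minimal sketch of the updated Dot API, where the node id is now separate
// from its display label:
//   Dot dot;
//   dot.AddNode("n0", {Dot::Attr("shape", "box")}, "mul");
//   dot.AddNode("n1", {}, "elementwise_add");
//   dot.AddEdge("n0", "n1", {});
//   std::string src = dot.Build();  // feed to `dot -Tpng` for rendering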
...@@ -13,3 +13,47 @@ ...@@ -13,3 +13,47 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace analysis {
void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file) {
PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope);
// Load parameters.
VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file,
const std::string &param_file) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor executor(place);
PADDLE_ENFORCE(argument_->origin_program_desc.get());
framework::ProgramDesc program(*argument_->origin_program_desc);
if ((!prog_file.empty()) && (!param_file.empty())) {
LOG(INFO) << "load single model file from " << prog_file;
Load(&executor, scope, prog_file, param_file);
} else if (!dir.empty()) {
LOG(INFO) << "load from dir " << dir;
Load(&executor, scope, dir);
} else {
LOG(ERROR) << "failed to load parameters";
return false;
}
return true;
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -21,12 +21,17 @@ namespace paddle {
namespace inference {
namespace analysis {

static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";

class FluidToIrPass final : public DataFlowGraphPass {
 public:
  FluidToIrPass() = default;

  bool Initialize(Argument *argument) override {
    ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
    PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
                   "argument needs the attr %s", kFluidToIrPassesAttr);
    argument_ = argument;
    if (argument->origin_program_desc) {
      LOG(WARNING) << "argument's origin_program_desc is already set, might "
                      "be called twice";
...@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
    if (!argument->main_dfg) {
      argument->main_dfg.reset(new DataFlowGraph);
    }
    argument->Set("ir_program_desc", new framework::ProgramDesc(program));

    LOG(INFO) << "Loading parameters";
    // Load parameters to argument if needed.
    if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
                                      argument->fluid_model_param_path)) {
#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
      SAFE_GET(fluid_model_dir);
      SAFE_GET(fluid_model_program_path);
      SAFE_GET(fluid_model_param_path);
#undef SAFE_GET
      EnableParamModify(fluid_model_dir, fluid_model_program_path,
                        fluid_model_param_path);
    }

    return true;
  }
...@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
  void Run(DataFlowGraph *graph) override {
    // Call all the IR passes.
    IRPassManager ir_passes(
        argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
    // Pass the scope from analysis to IR if needed.
    if (argument_->Has("param_scope")) {
      // Only the address is passed here; IR does not own the scope, so the
      // real scope owned by analysis must stay alive through the IR phase.
      ir_passes.graph().Set(
          "param_scope", new framework::Scope *(
                             &argument_->Get<framework::Scope>("param_scope")));
    }
    const auto &ir_passes_to_apply =
        argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
    ir_passes.Apply(ir_passes_to_apply);

    PADDLE_ENFORCE(argument_->main_dfg.get());
    argument_->main_dfg->Build(ir_passes.graph());
    // PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
  }
  void EnableParamModify(const std::string &model_dir,
                         const std::string &prog_file,
                         const std::string &param_file);

  std::string repr() const override { return "fluid-to-ir-pass"; }

 private:
  // Load parameters from a single file or from a directory.
  bool LoadParams(framework::Scope *scope, const std::string &dir,
                  const std::string &prog_file, const std::string &param_file);

 private:
  Argument *argument_{nullptr};
};
...
...@@ -24,6 +24,8 @@ namespace analysis {
TEST(FluidToIrPass, Test) {
  FluidToIrPass pass;
  Argument argument(FLAGS_inference_model_dir);
  argument.Set(kFluidToIrPassesAttr,
               new std::vector<std::string>({"infer_clean_graph_pass"}));
  pass.Initialize(&argument);
  pass.Run(argument.main_dfg.get());
}
...@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
}  // namespace inference
}  // namespace paddle
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_fuse_pass);
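The test wires in only infer_clean_graph_pass; the attr accepts any sequence of registered IR passes. A sketch mirroring the list the old hard-coded Run() used:

// Assumed to run inside the same test fixture as above.
argument.Set(kFluidToIrPassesAttr,
             new std::vector<std::string>(
                 {"graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass",
                  "fc_fuse_pass", "graph_viz_pass"}));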
...@@ -14,20 +14,24 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace inference {
namespace analysis {

IRPassManager::IRPassManager(const ProgramDesc &program,
                             framework::Scope *scope)
    : program_(program) {
  graph_.reset(new framework::ir::Graph(program));
  if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
}

void IRPassManager::Apply(const std::vector<std::string> &passes) {
  graph_->Set("graph_viz_path", new std::string("./1.dot"));
  // Apply all the passes.
  std::string pre_pass;
  for (const std::string &pass_name : passes) {
    LOG(WARNING) << "Running IR pass [" << pass_name << "]";
    auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
    if (pass_name == "graph_viz_pass") {
...
...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace inference {
...@@ -31,14 +32,15 @@ using framework::ProgramDesc;
class IRPassManager final {
 public:
  IRPassManager(const ProgramDesc &program, framework::Scope *scope);

  void Apply(const std::vector<std::string> &passes);

  framework::ir::Graph &graph() const { return *graph_; }

 private:
  std::unique_ptr<framework::ir::Graph> graph_;
  ProgramDesc program_;
};

}  // namespace analysis
...
...@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() {
  PADDLE_ENFORCE(argument_);
  LOG(INFO) << "Total " << data_.size() << " Analysis passes";
  for (auto& pass : data_) {
    LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]";
    pass->Run(argument_->main_dfg.get());
  }
}
...
...@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
  endif(WITH_TESTING)
endfunction(inference_api_test)

cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)

cc_test(test_paddle_inference_api
    SRCS api_tester.cc
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
 * TODO(Superjomn) Replace the Native predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
      // Parameters are saved in separate files located in
      // the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true;
}
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "to prepare executor";
    // LOG(INFO) << "transformed_program_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc));
PADDLE_ENFORCE(argument.Has("param_scope"));
// Update scope.
scope_.reset(argument.Release<framework::Scope>("param_scope"));
LOG(INFO) << "optimize end ==";
}
private:
NativeConfig config_;
};
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
VLOG(3) << "create NativePredictor";
if (config.use_gpu) {
    // 1. GPU memory
PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory, 0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags;
    if (config.fraction_of_gpu_memory >= 0.0f &&
        config.fraction_of_gpu_memory <= 0.95f) {
      flags.push_back("dummy");  // placeholder for argv[0]
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
if (!dynamic_cast<AnalysisPredictor*>(predictor.get())->Init(nullptr)) {
return nullptr;
}
return predictor;
}
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
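A minimal usage sketch for the new kAnalysis engine kind; the model path is a placeholder and error handling is elided:

// Hypothetical caller code.
paddle::NativeConfig config;
config.model_dir = "/path/to/inference_model";  // or set prog_file + param_file
config.use_gpu = false;
auto predictor =
    paddle::CreatePaddlePredictor<paddle::NativeConfig,
                                  paddle::PaddleEngineKind::kAnalysis>(config);
std::vector<paddle::PaddleTensor> inputs, outputs;
// ... fill `inputs` with feed tensors ...
CHECK(predictor->Run(inputs, &outputs));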
...@@ -32,6 +32,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
      : NativePaddlePredictor(config), config_(config) {}

  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
    FLAGS_IA_enable_tensorrt_subgraph_engine = true;
    VLOG(3) << "Predictor::init()";
    FLAGS_tensorrt_max_batch_size = config_.max_batch_size;
    FLAGS_tensorrt_workspace_size = config_.workspace_size;
...@@ -161,3 +162,4 @@ USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
...@@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
  config1.use_gpu = true;
  config1.fraction_of_gpu_memory = 0.3;
  config1.device = 0;
  config1.max_batch_size = 10;

  auto predictor0 =
      CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace inference {
template <>
std::string to_string<std::vector<float>>(
const std::vector<std::vector<float>> &vec) {
std::stringstream ss;
for (const auto &piece : vec) {
ss << to_string(piece) << "\n";
}
return ss.str();
}
template <>
std::string to_string<std::vector<std::vector<float>>>(
const std::vector<std::vector<std::vector<float>>> &vec) {
std::stringstream ss;
for (const auto &line : vec) {
for (const auto &rcd : line) {
ss << to_string(rcd) << ";\t";
}
ss << '\n';
}
return ss.str();
}
} // namespace inference
} // namespace paddle
...@@ -44,7 +44,8 @@ class Timer {
  }
};

static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
...@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
    pieces->push_back(str.substr(pos));
  }
}
static void split_to_float(const std::string &str, char sep,
                           std::vector<float> *fs) {
  std::vector<std::string> pieces;
  split(str, sep, &pieces);
  std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
...@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
}
template <>
std::string to_string<std::vector<float>>(
    const std::vector<std::vector<float>> &vec);

template <>
std::string to_string<std::vector<std::vector<float>>>(
    const std::vector<std::vector<std::vector<float>>> &vec);

// clang-format off
static void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) {
  // Assign buffer
  int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; });
  tensor->data.Resize(sizeof(float) * dim);
...
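A quick sketch of the two string helpers above (expected behavior, not a test from this patch):

std::vector<std::string> pieces;
paddle::inference::split("a,b,c", ',', &pieces);           // pieces == {"a", "b", "c"}
std::vector<float> fs;
paddle::inference::split_to_float("1.0,2.5,3", ',', &fs);  // fs == {1.0f, 2.5f, 3.0f}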
...@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
  kNative = 0,         // Use the native Fluid facility.
  kAnakin,             // Use Anakin for inference.
  kAutoMixedTensorRT,  // Automatically mix Fluid with TensorRT.
  kAnalysis,           // Use the analysis- and IR-optimized predictor.
  // TODO(Superjomn) support following engines later.
  // kTensorRT,         // Use TensorRT for inference.
  // kAutoMixedAnakin,  // Automatically mix Fluid with Anakin.
...@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
  return main_program;
}
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", vars);
op->SetAttr("file_path", dirname + "/param");
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, const_cast<framework::Scope*>(&scope), 0, true, true);
}
}  // namespace inference
}  // namespace paddle
...@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
                                             const std::string& prog_filename,
                                             const std::string& param_filename);
// Save the variables from a scope to disk.
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate = true);
}  // namespace inference
}  // namespace paddle
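A usage sketch for SaveVars; the variable names and output directory are placeholders:

// Persists the listed variables via a single save_combine op; per the
// implementation above, the combined file lands at <dirname>/param.
paddle::inference::SaveVars(scope, {"fc_0.w_0", "fc_0.b_0"}, "/tmp/model_out");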
# Add TRT tests
nv_library(tensorrt_converter
  SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
       batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc
  DEPS tensorrt_engine operator scope framework_proto op_registry)

nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -18,12 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
        DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
        DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
        DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
        DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
        DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * ConcatOp, IConcatenationLayer in TRT. This layer doesn't have weights.
*/
class ConcatOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert a fluid concat op to a tensorrt concat layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
std::vector<nvinfer1::ITensor*> itensors;
for (auto& input_name : op_desc.Input("X")) {
itensors.push_back(engine_->GetITensor(input_name));
}
int axis = boost::get<int>(op_desc.GetAttr("axis"));
    PADDLE_ENFORCE(axis > 0,
                   "The axis attr of Concat op should be greater than 0 for TRT");
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
itensors.size());
axis = axis - 1; // Remove batch dim
layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(concat, ConcatOpConverter);
...@@ -79,6 +79,14 @@ class OpConverter {
      it =
          Registry<OpConverter>::Lookup("elementwise_" + op_type + "_tensor");
    }
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
}
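    // depthwise_conv2d has no dedicated converter registered here, so it is
    // dispatched to the dense conv2d converter.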
if (op_desc.Type() == "depthwise_conv2d") {
it = Registry<OpConverter>::Lookup("conv2d");
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
} }
    if (!it) {
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(concat_op, test) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1));
validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1));
validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1));
validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("concat");
desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"});
desc.SetOutput("Out", {"concat_out"});
int axis = 1;
desc.SetAttr("axis", axis);
validator.SetOp(*desc.Proto());
validator.Execute(5);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(concat);
...@@ -18,6 +18,7 @@ limitations under the License. */
#include <string>
#include <vector>

#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/profiler.h"
...@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
  return feed_target_shapes;
}
void Compile(paddle::framework::ProgramDesc* program) {
std::unique_ptr<paddle::framework::ir::Graph> g(
new paddle::framework::ir::Graph(*program));
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"graph_to_program_pass");
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
pass->Apply(std::move(g));
}
template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
...@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
        paddle::platform::DeviceContextPool::Instance().Get(place));
    inference_program = InitProgram(&executor, scope, dirname, is_combined);
  }
  Compile(inference_program.get());

  // Disable the profiler and print the timing information
  paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
                                    "load_program_profiler");
...@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
  delete scope;
}
USE_PASS(graph_to_program_pass);
...@@ -291,6 +291,8 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op)

if (WITH_GPU)
  op_library(conv_op DEPS vol2col depthwise_conv im2col)
...
...@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  const int D = w_dims[1] / 4;
  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);

  auto b_dims = ctx->GetInputDim("LSTMBias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
...
...@@ -60,20 +60,6 @@ class AucKernel : public framework::OpKernel<T> {
    const T* inference_data = predict->data<T>();
    const auto* label_data = label->data<int64_t>();

    auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
    auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
    auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
...
...@@ -29,6 +29,6 @@ target_assign_op.cu)
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
  polygon_box_transform_op.cu)
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)

# Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
struct AppendProposalsFunctor {
LoDTensor *out_;
int64_t offset_;
Tensor *to_add_;
AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
: out_(out), offset_(offset), to_add_(to_add) {}
template <typename T>
void operator()() const {
auto *out_data = out_->data<T>();
auto *to_add_data = to_add_->data<T>();
memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
}
};
class GenerateProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Scores"), "Input(Scores) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("BboxDeltas"),
"Input(BboxDeltas) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Anchors"),
"Input(Anchors) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Variances"),
"Input(Variances) shouldn't be null.");
auto scores_dims = ctx->GetInputDim("Scores");
auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas");
auto im_info_dims = ctx->GetInputDim("ImInfo");
auto anchors_dims = ctx->GetInputDim("Anchors");
auto variances_dims = ctx->GetInputDim("Variances");
ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
platform::CPUPlace());
}
};
template <class T>
void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto *bbox_deltas_data = bbox_deltas->data<T>();
auto *anchor_data = all_anchors->data<T>();
const T *variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len];
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1];
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2;
T anchor_center_y =
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) *
anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
}
// return proposals;
}
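To make the decoding above concrete, a tiny worked example (values invented, no variances):

// anchor (x1, y1, x2, y2) = (0, 0, 10, 10); deltas (dx, dy, dw, dh) = (0.1, 0.1, 0, 0)
//   w = 10, h = 10, cx = 5, cy = 5
//   cx' = 0.1 * 10 + 5 = 6;  cy' = 6;  w' = exp(0) * 10 = 10;  h' = 10
//   proposal = (6 - 5, 6 - 5, 6 + 5, 6 + 5) = (1, 1, 11, 11)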
template <class T>
void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>();
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
}
}
}
template <class T>
void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2];
keep->Resize({boxes->dims()[0], 1});
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
bool SortScorePairDescend(const std::pair<float, int> &pair1,
const std::pair<float, int> &pair2) {
return pair1.first > pair2.first;
}
template <class T>
void GetMaxScoreIndex(const std::vector<T> &scores,
std::vector<std::pair<T, int>> *sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend);
}
template <class T>
T BBoxArea(const T *box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
    // If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
const T nms_threshold, const float eta) {
PADDLE_ENFORCE_NOT_NULL(bbox);
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex<T>(scores_data, &sorted_indices);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
bool flag;
while (sorted_indices.size() != 0) {
int idx = sorted_indices.front().second;
flag = true;
for (size_t k = 0; k < selected_indices.size(); ++k) {
if (flag) {
const int kept_idx = selected_indices[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
selected_num++;
}
sorted_indices.erase(sorted_indices.begin());
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
Tensor keep_nms;
keep_nms.Resize({selected_num});
int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
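For intuition, the adaptive threshold above decays only while it stays above 0.5 (numbers invented):

// nms_thresh = 0.7, eta = 0.9: after the 1st kept box the threshold becomes
// 0.7 * 0.9 = 0.63, after the 2nd 0.567, after the 3rd 0.51; once it falls
// to <= 0.5 the decay stops.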
template <typename DeviceContext, typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo");
auto *anchors = context.Input<Tensor>("Anchors");
auto *variances = context.Input<Tensor>("Variances");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel() / 4, 1},
context.GetPlace());
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
Tensor *anchor = const_cast<framework::Tensor *>(anchors);
anchor->Resize({anchors->numel() / 4, 4});
Tensor *var = const_cast<framework::Tensor *>(variances);
var->Resize({var->numel() / 4, 4});
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor proposals = tensor_pair.first;
Tensor scores = tensor_pair.second;
framework::VisitDataType(
framework::ToDataType(rpn_rois->type()),
AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
framework::VisitDataType(
framework::ToDataType(rpn_roi_probs->type()),
AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
num_proposals += proposals.dims()[0];
lod0.emplace_back(num_proposals);
}
lod.emplace_back(lod0);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
std::pair<Tensor, Tensor> ProposalForOneImage(
const DeviceContext &ctx, const Tensor &im_info_slice,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) const {
auto *scores_data = scores_slice.data<T>();
// Sort index
Tensor index_t;
index_t.Resize({scores_slice.numel()});
int *index = index_t.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
std::function<bool(const int64_t &, const int64_t &)> compare =
[scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
} else {
std::nth_element(index, index + pre_nms_top_n,
index + scores_slice.numel(), compare);
index_t.Resize({pre_nms_top_n});
}
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
CPUGather<T>(ctx, variances, index_t, &var_sel);
Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
ClipTiledBoxes<T>(ctx, im_info_slice, &proposals);
Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_sel);
}
Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel);
}
};
class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Scores", "The scores of anchors should be foreground.");
AddInput("BboxDeltas", "bbox_deltas.");
AddInput("ImInfo", "Information for image reshape.");
AddInput("Anchors", "All anchors.");
AddInput("Variances", " variances");
AddOutput("RpnRois", "Anchors.");
AddOutput("RpnRoiProbs", "Anchors.");
AddAttr<int>("pre_nms_topN", "pre_nms_topN");
AddAttr<int>("post_nms_topN", "post_nms_topN");
AddAttr<float>("nms_thresh", "nms_thres");
AddAttr<float>("min_size", "min size");
AddAttr<float>("eta", "eta");
AddComment(R"DOC(
Generate Proposals OP
This operator proposes RoIs according to each box's probability of being a foreground
object, where the boxes are computed from anchors. BboxDeltas and Scores are the outputs
of an RPN; the final proposals can be used to train a detection network.

Scores is the probability of each box being an object, in format (N, A, H, W), where N is
the batch size, A is the number of anchors, and H and W are the height and width of the
feature map. BboxDeltas is the difference between the predicted box location and the
anchor location, in format (N, 4*A, H, W).

To generate proposals, this operator transposes and reshapes Scores and BboxDeltas into
shapes (H*W*A, 1) and (H*W*A, 4), and computes box locations as proposal candidates. It
then clips the boxes to the image, removes predicted boxes with a small area, and finally
applies NMS to produce the final proposals.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
ops::GenerateProposalsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
generate_proposals,
ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -18,15 +18,32 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
auto in_e = framework::EigenVector<T>::Flatten(*in);
const T* scale_factor = scale->data<T>();
auto out_e = framework::EigenVector<T>::Flatten(*out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (scale_factor[0] / max_range) * in_e;
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
  FakeDequantizeMaxAbsOp(const std::string& type,
                         const framework::VariableNameMap& inputs,
                         const framework::VariableNameMap& outputs,
                         const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
...@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("X",
             "(Tensor) The input with float-32/64 type is the "
             "low precision tensor.");
    AddInput("Scale", "(float) The scale in quantization stage.");
    AddOutput("Out",
              "(Tensor) The output is the dequantized high "
              "precision tensor.");
    AddAttr<float>("max_range", "(float) The max range in quantization stage.");
    AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeQuantizeMaxAbsOp:

$$Out = \frac{scale * X}{max\_range}$$

)DOC");
  }
...
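A quick numeric check of the reworked formula (values invented; 255 = 2^8 - 1 corresponds to 8-bit quantization):

// x_q = 127.0 (quantized value stored as float), scale = 0.5, max_range = 255
// out = scale * x_q / max_range = 0.5 * 127 / 255 ≈ 0.249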
...@@ -14,6 +14,42 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num,
T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
out[idx] = in[idx] * scale[0] / max_range;
}
}
template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
const T* scale_factor = scale->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = in->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
...
...@@ -19,22 +19,29 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor* scale, T max_range,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto* out = ctx.Output<framework::Tensor>("Out");

    float max_range = ctx.Attr<float>("max_range");

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    out->mutable_data<T>(dev_ctx.GetPlace());

    DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale,
                                          static_cast<T>(max_range), out);
  }
};
...
...@@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -52,6 +52,8 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
FetchBarrier operator
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(BatchResetHiddenPrev) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) "
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionGRUOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.")
.AsDispensable();
AddInput("WeightX",
"(Tensor) The FC weight with shape (M x 3D),"
"where M is the dim size of x, D is the hidden size. ");
AddInput("WeightH",
"(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
AddInput("Bias",
"(Tensor, optional) (1 x 3D)."
"Almost same as GRUOp."
"Note: if have FC bias it should be added on this bias.")
.AsDispensable();
AddOutput("XX",
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
.AsIntermediate();
AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
.AsIntermediate();
AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault("tanh");
AddAttr<std::string>(
"gate_activation",
"(string, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault("sigmoid");
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuses the fully-connected operator into the GRU;
for more details, refer to the GRU op.
)DOC");
}
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* h0 = ctx.Input<Tensor>("H0");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* batch_reset_hidden_prev =
ctx.Output<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
batch_hidden->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
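// Choose the cheaper order of work: if the input width M exceeds the gate
// width 3D, run the FC first (shrinking T x M to T x 3D) and then reorder to
// batch layout; otherwise reorder the raw input first and write the FC
// result directly into batched_gate.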
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias ? bias->data<T>() : NULL);
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias ? bias->data<T>() : NULL);
}
int frame_size = static_cast<int>(wx_dims[1] / 3);
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(wh_data);
gru_value.state_weight =
const_cast<T*>(wh_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (h0) {
ReorderInitState<DeviceContext, T>(
ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batched_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node =
math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM
if (FLAGS_paddle_num_threads >= 4) {
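// Pack the recurrent weights once up front: the update/reset gate block
// (D x 2D) and the candidate state block (D x D). Every timestep below then
// reuses the packed B matrices via GEMM_COMPUTE instead of repacking inside
// each small GEMM.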
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_gate);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_state);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state);
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
frame_size, gru_value.prev_out_value, frame_size, packed_gate,
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
math::detail::forward_reset_output(
math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
cur_batch_size, active_gate);
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
gru_value.reset_output_value, frame_size, packed_state,
frame_size, T(1), gru_value.gate_value + frame_size * 2,
frame_size * 3);
}
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
gru_value.prev_out_value = gru_value.output_value;
}
blas.GEMM_FREE(packed_gate);
blas.GEMM_FREE(packed_state);
} else {
#endif
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
#ifdef PADDLE_WITH_MKLML
}
#endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batched_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
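For readers mapping the fused kernel back to the math: per timestep it evaluates the standard GRU recurrence of the reference gru op (a sketch in LaTeX; \sigma is gate_activation, \phi is activation, and the interpolation convention follows forward_final_output):

\begin{aligned}
u_t &= \sigma(x_t W_{ux} + h_{t-1} W_{uh} + b_u) \\
r_t &= \sigma(x_t W_{rx} + h_{t-1} W_{rh} + b_r) \\
\tilde{h}_t &= \phi(x_t W_{cx} + (r_t \odot h_{t-1}) W_{ch} + b_c) \\
h_t &= (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t
\end{aligned}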
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionGRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
...@@ -15,10 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, true, "Use sequence mode");
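// seq_mode selects SeqCompute below, which walks each sequence step by step
// with small per-timestep GEMMs; when false, the op falls back to
// BatchCompute, the original batch-reordering path of the reference lstm op.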
namespace paddle {
namespace operators {
...@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width;
if (FLAGS_seq_mode) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
}
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
} }
...@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx,
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> {
public:
void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
const int total_T = x_dims[0];
const int N = x_lod[0].size() - 1; // batch size
const int M = x_dims[1]; // x frame size
const int D = wh_dims[0];
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = wh_dims[1];
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0 ? c0->data<T>() : NULL;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
int xx_offset = D4;
int gate_offset = D;
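// For a reversed LSTM, move the cursors to the last timestep and flip the
// per-step strides so the loop below walks the sequence backwards.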
if (is_reverse) {
const int offset = (total_T - 1) * D;
xx_data = xx_data + offset * 4;
hidden_out_data = hidden_out_data + offset;
cell_out_data = cell_out_data + offset;
xx_offset = -D4;
gate_offset = -D;
}
auto move_step = [&]() {
xx_data = xx_data + xx_offset;
hidden_out_data = hidden_out_data + gate_offset;
cell_out_data = cell_out_data + gate_offset;
};
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_cell_data = NULL;
const T* prev_hidden_data = NULL;
int tstart = 0;
if (h0_data) {
prev_hidden_data = h0_data + bid * D;
prev_cell_data = c0_data + bid * D;
} else {
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// cell out = input gate * candidate state
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
// hidden out = act_cell(cell out) * output gate
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
tstart = 1;
move_step();
}
for (int step = tstart; step < seq_len; ++step) {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
D4);
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// a = forget * prev_cell
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
// b = input * tilde
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
// cell out = a + b
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
// hidden out = act_cell(cell out) * output gate
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
move_step();
}
}
}
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
...@@ -339,6 +476,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
// restore the output cell state in LoDTensor from the batch cell
to_seq(dev_ctx, batch_cell, cell_out);
}
void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_seq_mode) {
SeqCompute(ctx);
} else {
BatchCompute(ctx);
}
}
};
} // namespace operators
...@@ -348,7 +492,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
ops::FuisonLSTMKernel<double>);
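As a sketch of what SeqCompute evaluates per timestep: the gate preactivations are laid out candidate-first, [c~, i, f, o], matching the W_ch, W_ih, W_fh, W_oh comment in the code; \sigma is gate_activation, \phi_c candidate_activation, \phi_o cell_activation:

\begin{aligned}
[\tilde{c}_t,\, i_t,\, f_t,\, o_t] &= x_t W_x + h_{t-1} W_h + b \\
i_t, f_t, o_t &\leftarrow \sigma(\cdot), \qquad \tilde{c}_t \leftarrow \phi_c(\tilde{c}_t) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \phi_o(c_t)
\end{aligned}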
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
void FusionSeqExpandConcatFCOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GT(
ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
PADDLE_ENFORCE(
ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
const int D = w_dims[1];
int sum = ins_dims[0][1];
for (size_t i = 1; i < ins_dims.size(); ++i) {
sum += ins_dims[i][1];
}
PADDLE_ENFORCE_EQ(sum, w_dims[0],
"The height of FCWeight should be the sum of the widths of all inputs.");
if (ctx->HasInput("FCBias")) {
auto b_dims = ctx->GetInputDim("FCBias");
PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
"The rank of FCBias should be 1 or 2, got %d", b_dims.size());
if (b_dims.size() == 1) {
PADDLE_ENFORCE_EQ(b_dims[0], D, "The shape of FCBias must be %d.", D);
} else {
PADDLE_ENFORCE_EQ(b_dims[0], 1, "The shape of FCBias must be 1x%d.", D);
PADDLE_ENFORCE_EQ(b_dims[1], D, "The shape of FCBias must be 1x%d.", D);
}
}
}
ctx->SetOutputDim("Out", {ins_dims[0][0], D});
// FCOut must be reshaped at run time, since the lod is not available in
// InferShape; explicitly share the ref lod here.
ctx->ShareLoD("X", "Out", 0);
}
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
ctx.device_context());
}
void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.")
.AsDuplicable();
AddInput("FCWeight", "(Tensor) the weights of fc.");
AddInput("FCBias", "(Tensor, optional) the bias of fc.").AsDispensable();
AddOutput("Out", "(LoDTensor) Output LodTensor.");
AddOutput(
"FCOut",
"(Tensor) the intermediate tensor to keep the result of fc."
"Shape is (N x D), where N is the batch size, D is the output dim of fc")
.AsIntermediate();
AddAttr<std::string>("fc_activation",
"(string, default: identity)"
"The activation for the result of fc."
"`identity` by default.")
.SetDefault("identity")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Fusion Sequence expand + concat + fc Operator.
All of the conditions below should be met:
The ref_level of seq_expand should be 0.
The first input of concat provides the ref lod for seq_expand.
The other inputs should share the same lod and the same batch size as the ref lod.
The seq len of the other inputs should be 1.
The concat axis should be 1.
)DOC");
}
template <typename T>
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("FCWeight");
auto* b = ctx.Input<Tensor>("FCBias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* fc_out = ctx.Output<Tensor>("FCOut");
auto* ref_in = ins[0];
auto ref_lod = ref_in->lod();
auto in1_lod = ins[1]->lod();
auto ref_dims = ref_in->dims(); // T x M0
auto in1_dims = ins[1]->dims(); // N x M1
auto w_dims = w->dims();
const int N = ref_lod[0].size() - 1;
const int total_T = ref_dims[0];
const int M0 = ref_dims[1];
const int M1 = in1_dims[1];
const int D = w_dims[1];
// Some checks; FCOut must be reshaped here,
// since InferShape cannot get the lod info.
PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only an input lod of one level is supported.");
PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only an input lod of one level is supported.");
PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
"Batch size of all inputs should be equal.");
PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
"Seq_length of other inputs should be 1.");
PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
for (size_t i = 2; i < ins.size(); ++i) {
PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
"The height of all other inputs should be equal.");
PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
"All other inputs should have the same lod.");
}
fc_out->Resize({N, D});
std::function<void(const int, const T*, T*)> fc_act;
auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
fc_act = act_functor(fc_act_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
fc_act = act_functor(fc_act_str);
}
const T* ref_in_data = ref_in->data<T>();
const T* in1_data = ins[1]->data<T>();
const T* w_data = w->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
out_data, b ? b->data<T>() : NULL);
w_data = w_data + M0 * D;
// first slice: write fc_out = in1 * W1; later slices accumulate into it
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
w_data = w_data + M1 * D;
for (size_t i = 2; i < ins.size(); ++i) {
// accumulate the remaining input slices into fc_out
const T* in_data = ins[i]->data<T>();
const int K = ins[i]->dims()[1];
blas.GEMM(CblasNoTrans, CblasNoTrans, N, D, K, static_cast<T>(1), in_data,
K, w_data, D, static_cast<T>(1), fc_out_data, D);
w_data = w_data + K * D;
}
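// Sequence-expand + add: broadcast row i of fc_out across the i-th ref
// sequence span and accumulate into the rows of out, which already hold
// the FC of the ref input.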
T* cur_out_data = out_data;
for (int i = 0; i < N; ++i) {
int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
T* src = fc_out_data + i * D;
for (int step = 0; step < seq_len; ++step) {
blas.VADD(D, cur_out_data, src, cur_out_data);
cur_out_data = cur_out_data + D;
}
}
fc_act(total_T * D, out_data, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqExpandConcatFCOpKernel<double>);
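Putting the pieces together, the fused op computes, for each time step t belonging to ref sequence i(t), the following (a sketch; W_k denotes the k-th row-block of FCWeight, split by the input widths):

\mathrm{Out}_t = \mathrm{act}\Big( x^{(0)}_t W_0 + b + \sum_{k \ge 1} x^{(k)}_{i(t)} W_k \Big)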
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
(29 more file diffs are collapsed and not shown.)