diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 077072f6eadb0c48f4ae32f94828613d89ed01c9..a3e682e54ac496e37ed4a33a7b30d9fdca381d9d 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -18,7 +18,7 @@ function(copy TARGET) set(oneValueArgs "") set(multiValueArgs SRCS DSTS DEPS) cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE) + set(fluid_lib_dist_dep ${TARGET} ${fluid_lib_dist_dep} PARENT_SCOPE) list(LENGTH copy_lib_SRCS copy_lib_SRCS_len) list(LENGTH copy_lib_DSTS copy_lib_DSTS_len) @@ -185,7 +185,8 @@ copy(cmake_cache SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt DSTS ${FLUID_INSTALL_DIR}) -add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep}) +# This command generates a complete fluid library for both train and inference +add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) # paddle fluid version execute_process( diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c6dd919a93d119723b389d3a695f0af82d711a06..6a37b5ca433a3baa1388dd4f720d782ca53e4e99 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -75,7 +75,8 @@ paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'outp paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None)) +paddle.fluid.layers.sequence_unpad ArgSpec(args=['x', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None)) paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) @@ -84,6 +85,7 @@ paddle.fluid.layers.reduce_min ArgSpec(args=['input', 'dim', 'keep_dim', 'name'] paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) @@ -127,6 +129,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.margin_rank_loss ArgSpec(args=['label', 'left', 'right', 'margin', 'name'], varargs=None, keywords=None, defaults=(0.1, None)) paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None)) paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 519a00fb073b08f6c88de8186de187476b548fd3..48b36df6499e59fe742766b5f81fd30a9fb8b900 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -12,6 +12,5 @@ endif(NOT WIN32) if(WITH_INFERENCE) # NOTE: please add subdirectory inference at last. add_subdirectory(inference) + add_subdirectory(train) endif() - -add_subdirectory(train) diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 9fbefabc841e3f6940860f60d959fee97495e4c9..d09b94a3fd32952985a37cf4246c7640d2db4f56 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -64,7 +64,8 @@ class OpHandleBase { virtual bool IsMultiDeviceTransfer() { return false; } const platform::DeviceContext *DeviceContext(platform::Place place) { - return dev_ctxes_[place]; + auto it = dev_ctxes_.find(place); + return it != dev_ctxes_.end() ? it->second : nullptr; } void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) { diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 70ec6e90a4d0106b7f838e51b8357798daa4b10d..b212666637a5289c9c6cd3585655deaeed8afd4b 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -46,6 +46,41 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; } +template +static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, + GarbageCollector* gc, + RefCntMap* ref_cnts) { + std::unordered_set erase_tensors; + + auto handler = [&](const VariableNameMap& name_map) { + for (auto& name_pair : name_map) { + for (auto& name : name_pair.second) { + auto it = ref_cnts->find(name); + if (it == ref_cnts->end()) continue; + if ((it->second)-- == 1) { + auto* var = scope.FindVar(name); + if (var != nullptr) { + VLOG(10) << "Erase tensor \'" << name << "\'"; + if (var->IsType()) { + erase_tensors.insert(var->GetMutable()); + } else if (var->IsType()) { + erase_tensors.insert( + var->GetMutable()->mutable_value()); + } + } + } + } + } + }; + + handler(op->Inputs()); + handler(op->Outputs()); + + if (!erase_tensors.empty()) { + gc->Add(erase_tensors); + } +} + Executor::Executor(const platform::Place& place) : place_(place) {} void Executor::Close() { @@ -66,7 +101,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { } else if (var_type == proto::VarType::FETCH_LIST) { var->GetMutable(); } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); + var->GetMutable>(); } else if (var_type == proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { @@ -331,9 +366,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } int64_t max_memory_size = GetEagerDeletionThreshold(); - std::unique_ptr> gc; - if (max_memory_size >= 0) { + // WhileOp would set keep_kids to false + // WhileGradOp would need the scopes created in WhileOp + // Perhaps, we should not perform eager deletion in WhileOp + // The scopes and variables created by WhileOp would be deleted + // in WhileGradOp. + if (max_memory_size >= 0 && !keep_kids) { ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { @@ -352,45 +391,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, op->Run(*local_scope, place_); if (gc != nullptr) { - std::vector erase_vars; - for (auto& input : op->Inputs()) { - for (auto& input_name : input.second) { - auto it = ctx->cur_ref_cnts_.find(input_name); - if (it == ctx->cur_ref_cnts_.end()) continue; - if (it->second == 1) { // should delete it - erase_vars.emplace_back(input_name); - ctx->cur_ref_cnts_.erase(input_name); - } else { - --(it->second); - } - } - } - - for (auto& output : op->Outputs()) { - for (auto& output_name : output.second) { - auto it = ctx->cur_ref_cnts_.find(output_name); - if (it == ctx->cur_ref_cnts_.end()) continue; - if (it->second == 1) { - erase_vars.emplace_back(output_name); - ctx->cur_ref_cnts_.erase(output_name); - } else { - --(it->second); - } - } - } - - if (!erase_vars.empty()) { - std::vector erase_tensors; - for (auto& name : erase_vars) { - auto* var = local_scope->FindVar(name); - if (var == nullptr) continue; - if (var->IsType()) { - auto* tensor = var->GetMutable(); - erase_tensors.push_back(tensor); - } - } - if (!erase_tensors.empty()) gc->Add(erase_tensors); - } + DeleteUnusedTensors(*local_scope, op.get(), gc.get(), + &(ctx->cur_ref_cnts_)); } if (FLAGS_benchmark) { diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index f0cc1338a8af50030a70a9797cbcd1b0567272b5..36b36d49c2728dbef93042158dffa26d8f56d529 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -32,38 +32,32 @@ template std::unordered_map GetNonPersistableReferenceCount( const ProgramDesc& prog, size_t block_id) { auto& block = prog.Block(block_id); - std::unordered_set ignored_vars; std::unordered_map ref_cnts; - for (auto var_desc : block.AllVars()) { - auto type = var_desc->Proto()->type().type(); - if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { - ignored_vars.insert(var_desc->Name()); // ignore persistable vars - } - } - - for (auto op_desc : block.AllOps()) { - for (auto& input : op_desc->Inputs()) { - for (auto& input_name : input.second) { - if (!ignored_vars.count(input_name)) { - if (ref_cnts.count(input_name)) - ++ref_cnts[input_name]; - else - ref_cnts[input_name] = 1; + auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) { + for (auto& name_pair : name_map) { + for (auto& name : name_pair.second) { + auto* var_desc = block.FindVar(name); + if (var_desc == nullptr || var_desc->Persistable()) continue; + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR && + type != proto::VarType::SELECTED_ROWS) { + continue; } - } - } - for (auto& output : op_desc->Outputs()) { - for (auto output_name : output.second) { - if (!ignored_vars.count(output_name)) { - if (ref_cnts.count(output_name)) - ++ref_cnts[output_name]; - else - ref_cnts[output_name] = 1; + auto it = ref_cnts.find(name); + if (it != ref_cnts.end()) { + ++it->second; + } else { + ref_cnts[name] = 1; } } } + }; + + for (auto op_desc : block.AllOps()) { + update_ref_cnts(op_desc, op_desc->Inputs()); + update_ref_cnts(op_desc, op_desc->Outputs()); } return ref_cnts; } diff --git a/paddle/fluid/framework/feed_fetch_method.cc b/paddle/fluid/framework/feed_fetch_method.cc index 8e1f93c5ebd448903d70f9668539e077875836e4..3e9353f5cf67d8de62c5551f12ea786e49190549 100644 --- a/paddle/fluid/framework/feed_fetch_method.cc +++ b/paddle/fluid/framework/feed_fetch_method.cc @@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, // be created. VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; Variable* g_feed_value = scope->Var(var_name); - auto& feed_inputs = - *(g_feed_value->GetMutable>()); + auto& feed_inputs = *(g_feed_value->GetMutable()); if (index >= feed_inputs.size()) { feed_inputs.resize(index + 1); } diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0076a8bece31f9a977b375717c25688fc0c95819..796ce1f91ce6f3e21dc6f0af8fca4960d43f6e2b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -38,6 +38,7 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +pass_library(conv_bn_fuse_pass inference) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..04459612a726bcb60f9d752dfd8927b6f5c2500d --- /dev/null +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_CONV_BN_NODES(pattern_name) \ + /* OPERATORS */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name); \ + /* CONV inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \ + /* CONV outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \ + /* BN inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name); \ + /* BN outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name) + +void recompute_bias_and_weights(const Scope* scope, + ir::Node* conv_weight, // + const ir::Node& bn_scale, // + const LoDTensor& bn_bias_tensor, // + const ir::Node& bn_mean, // + const ir::Node& bn_variance, // + LoDTensor* eltwise_y_in_tensor, // + float epsilon) { + using EigenVectorArrayMap = + Eigen::Map>; + using ConstEigenVectorArrayMap = + Eigen::Map>; + using EigenMatrixArrayMap = Eigen::Map< + Eigen::Array>; + + // Re-compute bias of conv2d from BN + PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims()); + + auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable(); + auto* variance_tensor = + scope->FindVar(bn_variance.Name())->GetMutable(); + auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable(); + + ConstEigenVectorArrayMap scale_array(scale_tensor->data(), + scale_tensor->numel(), 1); + EigenVectorArrayMap variance_array( + variance_tensor->mutable_data(platform::CPUPlace()), + variance_tensor->numel(), 1); + ConstEigenVectorArrayMap mean_array(mean_tensor->data(), + mean_tensor->numel(), 1); + ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data(), + bn_bias_tensor.numel(), 1); + + // variance will not be used anymore, so make it std_array and then tmp_array + variance_array += epsilon; + variance_array = variance_array.sqrt(); + variance_array = scale_array / variance_array; + + EigenVectorArrayMap eltwise_y_in_array( + eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 1); + + eltwise_y_in_array = + ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array; + + // Re-compute weight of conv2d from BN + auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); + auto weights_shape = weights->dims(); + auto weights_shape_2d = flatten_to_2d(weights_shape, 1); + + EigenMatrixArrayMap weights_array_2d( + weights->mutable_data(platform::CPUPlace()), weights_shape_2d[0], + weights_shape_2d[1]); + + weights_array_2d.colwise() *= variance_array; +} + +std::unique_ptr ConvBNFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_); + conv_bn_pattern(conv_input, false /*with_eltwise_add*/); + + int found_conv_bn_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBN fuse"; + + // conv, batch_norm, + // conv_weight, conv_out, + // bn_scale, bn_bias, bn_mean, bn_variance, + // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance + GET_CONV_BN_NODES(conv_bn_pattern); + + // Create eltwise_y (conv bias) variable + VarDesc eltwise_y_in_desc( + patterns::PDNodeName(name_scope_, "eltwise_y_in")); + auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); + auto* eltwise_y_in_tensor = + scope->Var(eltwise_y_in_node->Name())->GetMutable(); + + // Get batch norm bias + auto* bn_bias_tensor = + scope->FindVar(bn_bias->Name())->GetMutable(); + + // Initialize eltwise_y + eltwise_y_in_tensor->Resize(bn_bias_tensor->dims()); + std::fill_n(eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 0.0f); + + // update weights and biases + float epsilon = boost::get(batch_norm->Op()->GetAttr("epsilon")); + recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor, + *bn_mean, *bn_variance, eltwise_y_in_tensor, + epsilon); + + // Create an elementwise add node + OpDesc desc; + desc.SetInput("X", std::vector({conv_out->Name()})); + desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); + desc.SetOutput("Out", std::vector({bn_out->Name()})); + desc.SetType("elementwise_add"); + desc.SetAttr("axis", 1); + bool a = boost::get(conv->Op()->GetAttr("use_mkldnn")); + desc.SetAttr("use_mkldnn", a); + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. + + GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance, + batch_norm, bn_mean_out, bn_variance_out, + bn_saved_mean, bn_saved_variance}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(conv_out, eltwise_op); + IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); + IR_NODE_LINK_TO(eltwise_op, bn_out); + + found_conv_bn_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_bn_count); + return graph; +} + +std::unique_ptr ConvEltwiseAddBNFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_); + conv_bn_pattern(conv_input, true /*with_eltwise_add*/); + + int found_conv_bn_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBN fuse"; + + // conv, batch_norm, + // conv_weight, conv_out, + // bn_scale, bn_bias, bn_mean, bn_variance, + // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance + GET_CONV_BN_NODES(conv_bn_pattern); + // OPERATORS + GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern); + // BIAS inputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern); + // BIAS outputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern); + + // Get eltwise_y (conv bias) variable + auto* eltwise_y_in_tensor = + scope->FindVar(eltwise_y_in->Name())->GetMutable(); + + // Get batch norm bias + auto* bn_bias_tensor = + scope->FindVar(bn_bias->Name())->GetMutable(); + + // update weights and biases + float epsilon = boost::get(batch_norm->Op()->GetAttr("epsilon")); + recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor, + *bn_mean, *bn_variance, eltwise_y_in_tensor, + epsilon); + + // Update the elementwise_add node + eltwise->Op()->SetAttr("axis", 1); + eltwise->Op()->SetOutput("Out", std::vector({bn_out->Name()})); + + GraphSafeRemoveNodes( + graph.get(), + {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, + bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(eltwise, bn_out); + + found_conv_bn_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_bn_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass); +REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass, + paddle::framework::ir::ConvEltwiseAddBNFusePass); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..2c9eb574fe8e054e0ae221f08f664b91f05d95c9 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp. + */ +class ConvBNFusePass : public FusePassBase { + public: + virtual ~ConvBNFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_bn_fuse"}; +}; + +class ConvEltwiseAddBNFusePass : public FusePassBase { + public: + virtual ~ConvEltwiseAddBNFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_eltwiseadd_bn_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 46c6a52c09e896596aa6d8e1e901955a68a4957d..8625b562e7dfab5a65692863cdc22b62ce15d758 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -626,6 +626,112 @@ bool VarLinksFromOp(Node *node, const std::string &op_type) { return false; } +PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, + bool with_eltwise_add) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + + PDNode *eltwise_op = nullptr; + if (with_eltwise_add) { + eltwise_op = + pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); + } + auto *batch_norm_op = + pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm"); + // Create variables + // Conv Filter + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + + auto *conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d"); + + PDNode *eltwise_y_in_var = nullptr; + PDNode *eltwise_out_var = nullptr; + if (with_eltwise_add) { + // Conv output as Bias input + conv_out_var->assert_is_op_input("elementwise_add", "X"); + // Bias + eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->AsInput(); + eltwise_out_var = pattern->NewNode(eltwise_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("elementwise_add"); + } else { + // Conv output as BN input + conv_out_var->assert_is_op_input("batch_norm", "X"); + } + + // BN Scale + auto *bn_scale_var = pattern->NewNode(bn_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Scale"); + // BN Bias + auto *bn_bias_var = pattern->NewNode(bn_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Bias"); + // BN Mean + auto *bn_mean_var = pattern->NewNode(bn_mean_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Mean"); + // BN Variance + auto *bn_variance_var = pattern->NewNode(bn_variance_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Variance"); + + // BN output + auto *bn_out_var = pattern->NewNode(bn_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm"); + + auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "MeanOut"); + + auto *bn_variance_out_var = + pattern->NewNode(bn_variance_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "VarianceOut"); + + auto *bn_saved_mean_var = + pattern->NewNode(bn_saved_mean_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedMean"); + + auto *bn_saved_variance_var = + pattern->NewNode(bn_saved_variance_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedVariance"); + + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); + + if (with_eltwise_add) { + eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var}) + .LinksTo({eltwise_out_var}); + batch_norm_op + ->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var, + bn_variance_var}) + .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var, + bn_saved_mean_var, bn_saved_variance_var}); + } else { + batch_norm_op + ->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var, + bn_variance_var}) + .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var, + bn_saved_mean_var, bn_saved_variance_var}); + } + return bn_out_var; +} + PDNode *patterns::ConvReLU::operator()( paddle::framework::ir::PDNode *conv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 508113bf4fcab274394f2705c36eddbf4ba3c77a..cdd6413d968b065453177ff78b0aad641a09f6e7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -375,6 +375,44 @@ struct PatternBase { size_t id_; }; +// Conv with batch norm +// op: conv + (elementwise_add +) batch_norm +// named nodes: +// conv_weight, conv_out, conv, +// bn_x, bn_scale, bn_bias, bn_mean, bn_variance, +// bn_batch_norm, bn_y, bn_mean_out, bn_variance_out, +// bn_saved_mean, bn_saved_variance +struct ConvBN : public PatternBase { + ConvBN(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_bn") {} + + PDNode* operator()(PDNode* conv_input, bool with_eltwise_add); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(batch_norm); + PATTERN_DECL_NODE(eltwise); // ELEMENTWISE_ADD + // CONV inputs + PATTERN_DECL_NODE(conv_weight); // Filter + // CONV outputs + PATTERN_DECL_NODE(conv_out); // tmp + // ELTWISE inputs + PATTERN_DECL_NODE(eltwise_y_in); + // ELTWISE outputs + PATTERN_DECL_NODE(eltwise_out); // tmp + // BN inputs + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_mean); + PATTERN_DECL_NODE(bn_variance); + // BN outputs + PATTERN_DECL_NODE(bn_out); // Out + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_variance_out); + PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(bn_saved_variance); +}; + // CONV with ReLU // op: conv + relu // named nodes: diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index ba10687d65cfbbac89cfc76879c8b202ebd03229..2840d503f1454271afb309efdd435225ab077dc0 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) { } else if (var_type == proto::VarType::FETCH_LIST) { var->GetMutable(); } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); + var->GetMutable>(); } else if (var_type == proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9f930065324f13f5aa79c214e820fb6fc2f3a166..14fcde2fe3b1c3acfc0994e9cd37a784c57826d7 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -149,9 +149,17 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { platform::SetDeviceId(dev_id); #endif } - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - RunImpl(scope, place); + + // The profile has a process-wide mutex, results in serious performance issue + // in concurrency scenerio. Here use an `if` to fix this issue. + // Please not remove the `if`, ask @Superjomn if there are any concern. + if (platform::IsProfileEnabled()) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(Type(), pool.Get(place)); + RunImpl(scope, place); + } else { + RunImpl(scope, place); + } VLOG(3) << place << " " << DebugStringEx(&scope); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f06bad6c78c05804e583f859906b88fb7b500372..e8adabd26540754d5b9206294eeeed79757220bf 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -307,6 +307,10 @@ ParallelExecutor::~ParallelExecutor() { } } } + + // member_ must be destructed before gcs_ since the destructor of + // ReferenceCountOpHandle use raw pointers of gcs_ inside. + member_.reset(); } } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index fd386a5987f11ff64964e95eb7e9b83572dc790c..ef09b98b2aa91a9d729b94d15dbb676dde4092b6 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -75,7 +75,7 @@ class ParallelExecutor { private: void BCastParamsToDevices(const std::unordered_set &vars) const; - ParallelExecutorPrivate *member_; + std::unique_ptr member_; #ifdef PADDLE_WITH_CUDA // ref_cnts_ is only initialized when ParallelExecutor constructs, and then diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 1a727a2c8c759d010606d5b605823b7252b35c69..a4abd1b1283f08fb8431fbeea0cea17c8439fdd7 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -49,18 +49,18 @@ int64_t GetEagerDeletionThreshold() { Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); kids_.push_back(new Scope(this)); return *kids_.back(); } Variable* Scope::Var(const std::string& name) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); return VarInternal(name); } Variable* Scope::Var(std::string* name) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto new_name = string::Sprintf("%p.%d", this, vars_.size()); if (name != nullptr) { *name = new_name; @@ -69,29 +69,34 @@ Variable* Scope::Var(std::string* name) { } Variable* Scope::FindVar(const std::string& name) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); return FindVarInternal(name); } +Variable* Scope::FindLocalVar(const std::string& name) const { + std::lock_guard lock(mutex_); + return FindVarLocally(name); +} + const Scope* Scope::FindScope(const Variable* var) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); return FindScopeInternal(var); } void Scope::DropKids() { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); for (Scope* s : kids_) delete s; kids_.clear(); } bool Scope::HasKid(const Scope* scope) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); return it != this->kids_.end(); } std::vector Scope::LocalVarNames() const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); std::vector known_vars; known_vars.reserve(this->vars_.size()); for (auto& p : vars_) { @@ -101,7 +106,7 @@ std::vector Scope::LocalVarNames() const { } void Scope::DeleteScope(Scope* scope) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); this->kids_.erase(it); @@ -114,7 +119,7 @@ void Scope::DeleteScope(Scope* scope) const { } void Scope::EraseVars(const std::vector& var_names) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); std::set var_set(var_names.begin(), var_names.end()); for (auto it = vars_.begin(); it != vars_.end();) { if (var_set.find(it->first) != var_set.end()) { @@ -127,12 +132,12 @@ void Scope::EraseVars(const std::vector& var_names) { void Scope::Rename(const std::string& origin_name, const std::string& new_name) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); RenameInternal(origin_name, new_name); } std::string Scope::Rename(const std::string& origin_name) const { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto new_name = string::Sprintf("%p.%d", this, vars_.size()); RenameInternal(origin_name, new_name); return new_name; diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index e42fff1d79d92fb7ed61768a614d8cd98f6775a0..14f9f36812d690fc4a7440f2e7e6a85e9993a535 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -63,6 +63,11 @@ class Scope { /// Caller doesn't own the returned Variable. Variable* FindVar(const std::string& name) const; + /// Find a variable in the current scope. + /// Return nullptr if cannot find. + /// Caller doesn't own the returned Variable. + Variable* FindLocalVar(const std::string& name) const; + const Scope* parent() const { return parent_; } /// Find the scope or an ancestor scope that contains the given variable. diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 1d7a2eb5b38255531880fe3d2e5321024caf0c6b..69bcbc0e5891f95af4de8dfd49a25648ca920ab1 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -36,6 +36,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, auto size = src.numel() * SizeOfType(src.type()); if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size); } @@ -71,6 +76,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, auto stream = reinterpret_cast(ctx).stream(); if (platform::is_same_place(src_place, dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data async from " << src_place << " to " + << dst_place; + return; + } memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else { @@ -114,6 +124,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, auto dst_ptr = dst->mutable_data(dst_place, src.type()); auto size = src.numel() * SizeOfType(src.type()); if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { + if (src_ptr == dst_ptr) { + VLOG(3) << "Skip copy the same data from " << src_place << " to " + << dst_place; + return; + } memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size); } @@ -130,6 +145,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr); } else if (platform::is_gpu_place(src_place) && platform::is_gpu_place(dst_place)) { + if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) { + VLOG(3) << "Skip copy the same data from " << src_place << " to " + << dst_place; + return; + } auto src_gpu_place = boost::get(src_place); auto dst_gpu_place = boost::get(dst_place); memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc index a1e5b967a86d10f3439db662af54bb82888027b9..793ccfc79fe56707f226477b9d50b1d972ab6a59 100644 --- a/paddle/fluid/framework/tensor_util_test.cc +++ b/paddle/fluid/framework/tensor_util_test.cc @@ -41,6 +41,11 @@ TEST(TensorCopy, Tensor) { EXPECT_EQ(src_ptr[i], dst_ptr[i]); } + TensorCopy(dst_tensor, *cpu_place, &dst_tensor); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout()); Tensor slice_tensor = src_tensor.Slice(1, 2); @@ -82,6 +87,15 @@ TEST(TensorCopy, Tensor) { EXPECT_EQ(src_ptr[i], dst_ptr[i]); } + // Copy the same tensor + TensorCopy(gpu_tensor, *gpu_place, gpu_ctx, &gpu_tensor); + gpu_ctx.Wait(); + const int* dst_ptr_tmp = dst_tensor.data(); + EXPECT_NE(src_ptr, dst_ptr_tmp); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], dst_ptr_tmp[i]); + } + Tensor slice_tensor = src_tensor.Slice(1, 2); // CPU Slice Tensor to GPU Tensor diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index e33849ef502fb10b913e7e28cbd0abdb8b8ff9bb..9d3fb811191c207c75845ef8f8486e8beac7525a 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -59,6 +59,7 @@ class VarDesc { public: explicit VarDesc(const std::string &name) { desc_.set_name(name); + // TODO(paddle-dev): Why default to lodtensor. desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR); } diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index 067e0c2b8389f88639fd9b95bd680702517efee1..873e1b20a584df3ba90cf5c1a62a3879bf98ce5c 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -38,8 +38,12 @@ class Variable { template T* GetMutable() { - if (!IsType()) { + if (!holder_) { holder_.reset(new PlaceholderImpl(new T())); + } else { + PADDLE_ENFORCE(IsType(), + "Variable must be type %s, the holding type is %s", + typeid(T).name(), holder_->Type().name()); } return static_cast(holder_->Ptr()); } diff --git a/paddle/fluid/framework/variable_test.cc b/paddle/fluid/framework/variable_test.cc index c5c1d215f4a6affae0a3bdafacec40a2aee2ca19..003dcfd3dfe5ecfd563a686bb72b061aff602f73 100644 --- a/paddle/fluid/framework/variable_test.cc +++ b/paddle/fluid/framework/variable_test.cc @@ -33,9 +33,10 @@ TEST(Variable, GetMutable) { const Tensor& tt = v->Get(); EXPECT_EQ(1234, tt.content_); - std::string* s = v->GetMutable(); - *s = "hello"; - - const std::string& ss = v->Get(); - EXPECT_EQ("hello", ss); + try { + v->GetMutable(); + } catch (std::exception& e) { + return; + } + EXPECT_TRUE(false); } diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index ec1bc7825dd21628f5c37ea44a154abe7b7e8c73..9794a193bcfaae19552b1f6fbdf2dab2898033d5 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -19,9 +19,19 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api) add_subdirectory(api) +set(STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor) +set(SHARED_INFERENCE_SRCS + io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc + ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc) +if (WITH_GPU AND TENSORRT_FOUND) + set(STATIC_INFERENCE_APIS ${STATIC_INFERENCE_APIS} paddle_inference_tensorrt_subgraph_engine) + set(SHARED_INFERENCE_SRCS ${SHARED_INFERENCE_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/api/api_tensorrt_subgraph_engine.cc) +endif() + # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api - analysis_predictor zero_copy_tensor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) + if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym") @@ -29,10 +39,7 @@ if(NOT APPLE) endif() # Create shared library -cc_library(paddle_fluid_shared SHARED - SRCS io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc - ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc +cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} DEPS ${fluid_modules} paddle_fluid_api) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 8a8aeb5e09a0d9a6746f6d6d61c547363e0e2d30..d780592eb9f79e39e34fcd3bd6b086992eaa931f 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -70,7 +70,7 @@ class DfgPassManagerImpl final : public DfgPassManager { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", - "depthwise_conv2d", "batch_norm", "concat", "tanh", + "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "elementwise_add", "dropout"}); if (!node->IsFunction()) return false; diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 0aa9367bf5692e53e2a1f1247523cf9a4f0b3a1d..765145cb7da44ca13c5394ad1dc2e879e69d69d1 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -64,15 +64,17 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // - "embedding_fc_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "fc_fuse_pass", // + "infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "embedding_fc_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_eltwiseadd_bn_fuse_pass", // #ifdef PADDLE_WITH_MKLDNN "conv_relu_mkldnn_fuse_pass", // #endif diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 3bc6af5241c41bd805699121d614d431d46d863f..3095dee0f0106b2408663cd32bb4fb310111eda4 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -25,9 +25,11 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/utils/singleton.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" DECLARE_bool(profile); +DECLARE_int32(paddle_num_threads); namespace paddle { @@ -47,6 +49,9 @@ bool AnalysisPredictor::Init( } #endif + // no matter with or without MKLDNN + paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); + if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); LOG(WARNING) << "ir optimize only supports CPU currently, enable_ir_optim " @@ -335,6 +340,19 @@ bool AnalysisPredictor::LoadProgramDesc() { } return true; } + +AnalysisPredictor::~AnalysisPredictor() { +#if !defined(_WIN32) + if (FLAGS_profile) { + platform::DisableProfiler(platform::EventSortingKey::kTotal, + "./profile.log"); + } +#endif + if (sub_scope_) { + scope_->DeleteScope(sub_scope_); + } +} + std::unique_ptr AnalysisPredictor::Clone() { auto *x = new AnalysisPredictor(config_); x->Init(scope_, inference_program_); diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 0d01d7ac2b29ea6364b07af9bb3bdeb5ced6bd00..5a9f4d36959d4ee7ca16dec769d9d1283b8787cb 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -72,6 +72,7 @@ class AnalysisPredictor : public PaddlePredictor { template void GetFetchOne(const framework::LoDTensor &fetchs, PaddleTensor *output_data); + ~AnalysisPredictor(); private: contrib::AnalysisConfig config_; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 6682e0a81b20c82aa668a249d37986386d769c83..7cda9c5d8a8366bd097491f37f5352a10e4fb16c 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -23,9 +23,11 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/helper.h" +#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" DEFINE_bool(profile, false, "Turn on profiler for fluid"); +DECLARE_int32(paddle_num_threads); namespace paddle { namespace { @@ -72,6 +74,9 @@ bool NativePaddlePredictor::Init( } #endif + // no matter with or without MKLDNN + paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); + if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index 5ee6a5a93168f58770067f76ca7f6bb6f67b2965..7ac468ee4d33f49bba20a07c976055a083743cbc 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -185,3 +185,4 @@ USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(dropout); +USE_TRT_CONVERTER(pad); diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index d4e6bb3e4a4ceb361ccd35121d0ecf84a764243e..ec8471ef960a2fc44af23c52be09cd678fab3f70 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -3,6 +3,7 @@ project(cpp_inference_demo CXX C) option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) macro(safe_set_static_flag) foreach(flag_var @@ -60,6 +61,13 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") + endif() +endif(NOT WIN32) + if (NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") @@ -112,6 +120,10 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) + if (USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 44335a872f2e00b34e29a9e7601cb390a460362c..65c95f0834a9356fc14faed8342f5d1e474edf8f 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -3,6 +3,9 @@ PADDLE_ROOT=$1 TURN_ON_MKL=$2 # use MKL or Openblas TEST_GPU_CPU=$3 # test both GPU/CPU mode or only CPU mode DATA_DIR=$4 # dataset +TENSORRT_INCLUDE_DIR=$5 # TensorRT header file dir, defalut to /usr/local/TensorRT/include +TENSORRT_LIB_DIR=$6 # TensorRT lib file dir, default to /usr/local/TensorRT/lib + cd `dirname $0` current_dir=`pwd` if [ $2 == ON ]; then @@ -16,6 +19,11 @@ else use_gpu_list='false' fi +USE_TENSORRT=OFF +if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then + USE_TENSORRT=ON +fi + PREFIX=inference-vis-demos%2F URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX} @@ -86,5 +94,23 @@ for WITH_STATIC_LIB in ON OFF; do fi done done + + # --------tensorrt mobilenet------ + if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then + rm -rf * + cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \ + -DWITH_MKL=$TURN_ON_MKL \ + -DDEMO_NAME=trt_mobilenet_demo \ + -DWITH_GPU=$TEST_GPU_CPU \ + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DUSE_TENSORRT=$USE_TENSORRT \ + -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ + -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR + make -j + ./trt_mobilenet_demo \ + --modeldir=$DATA_DIR/mobilenet/model \ + --data=$DATA_DIR/mobilenet/data.txt \ + --refer=$DATA_DIR/mobilenet/result.txt + fi done set +x diff --git a/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffb12b5871f088f15e43a1b0ff7e2a8b2f5fd079 --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file contains demo of mobilenet for tensorrt. + */ + +#include +#include // use glog instead of CHECK to avoid importing other paddle header files. +#include "paddle/fluid/inference/demo_ci/utils.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); +DEFINE_string(modeldir, "", "Directory of the inference model."); +DEFINE_string(refer, "", "path to reference result for comparison."); +DEFINE_string( + data, "", + "path of data; each line is a record, format is " + "'\t predictor; + paddle::contrib::MixedRTConfig config; + config.param_file = FLAGS_modeldir + "/__params__"; + config.prog_file = FLAGS_modeldir + "/__model__"; + config.use_gpu = true; + config.device = 0; + config.max_batch_size = 1; + config.fraction_of_gpu_memory = 0.1; // set by yourself + predictor = CreatePaddlePredictor(config); + + VLOG(3) << "begin to process data"; + // Just a single batch of data. + std::string line; + std::ifstream file(FLAGS_data); + std::getline(file, line); + auto record = ProcessALine(line); + file.close(); + + // Inference. + PaddleTensor input; + input.shape = record.shape; + input.data = + PaddleBuf(record.data.data(), record.data.size() * sizeof(float)); + input.dtype = PaddleDType::FLOAT32; + + VLOG(3) << "run executor"; + std::vector output; + predictor->Run({input}, &output, 1); + + VLOG(3) << "output.size " << output.size(); + auto& tensor = output.front(); + VLOG(3) << "output: " << SummaryTensor(tensor); + + // compare with reference result + CheckOutput(FLAGS_refer, tensor); +} + +} // namespace demo +} // namespace paddle + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + paddle::demo::Main(); + return 0; +} diff --git a/paddle/fluid/inference/api/demo_ci/utils.h b/paddle/fluid/inference/api/demo_ci/utils.h index cb8990671162dff47228736e69617229528cc093..4792c97fe7d0a3f9c904774ad4a8e580cefcf237 100644 --- a/paddle/fluid/inference/api/demo_ci/utils.h +++ b/paddle/fluid/inference/api/demo_ci/utils.h @@ -14,6 +14,8 @@ #pragma once #include +#include +#include #include #include #include "paddle/fluid/inference/paddle_inference_api.h" @@ -21,6 +23,11 @@ namespace paddle { namespace demo { +struct Record { + std::vector data; + std::vector shape; +}; + static void split(const std::string& str, char sep, std::vector* pieces) { pieces->clear(); @@ -39,6 +46,58 @@ static void split(const std::string& str, char sep, } } +Record ProcessALine(const std::string& line) { + VLOG(3) << "process a line"; + std::vector columns; + split(line, '\t', &columns); + CHECK_EQ(columns.size(), 2UL) + << "data format error, should be \t"; + + Record record; + std::vector data_strs; + split(columns[0], ' ', &data_strs); + for (auto& d : data_strs) { + record.data.push_back(std::stof(d)); + } + + std::vector shape_strs; + split(columns[1], ' ', &shape_strs); + for (auto& s : shape_strs) { + record.shape.push_back(std::stoi(s)); + } + VLOG(3) << "data size " << record.data.size(); + VLOG(3) << "data shape size " << record.shape.size(); + return record; +} + +void CheckOutput(const std::string& referfile, const PaddleTensor& output) { + std::string line; + std::ifstream file(referfile); + std::getline(file, line); + auto refer = ProcessALine(line); + file.close(); + + size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); + VLOG(3) << "predictor output numel " << numel; + VLOG(3) << "reference output numel " << refer.data.size(); + CHECK_EQ(numel, refer.data.size()); + switch (output.dtype) { + case PaddleDType::INT64: { + for (size_t i = 0; i < numel; ++i) { + CHECK_EQ(static_cast(output.data.data())[i], refer.data[i]); + } + break; + } + case PaddleDType::FLOAT32: + for (size_t i = 0; i < numel; ++i) { + CHECK_LT( + fabs(static_cast(output.data.data())[i] - refer.data[i]), + 1e-5); + } + break; + } +} + /* * Get a summary of a PaddleTensor content. */ diff --git a/paddle/fluid/inference/api/demo_ci/vis_demo.cc b/paddle/fluid/inference/api/demo_ci/vis_demo.cc index fb59cea457027854a099574c867299450690e61c..db61786e2fefda29256d84b5357028ec9c39b014 100644 --- a/paddle/fluid/inference/api/demo_ci/vis_demo.cc +++ b/paddle/fluid/inference/api/demo_ci/vis_demo.cc @@ -18,10 +18,6 @@ limitations under the License. */ #include #include // use glog instead of CHECK to avoid importing other paddle header files. -#include -#include - -// #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/inference/demo_ci/utils.h" #ifdef PADDLE_WITH_CUDA @@ -38,69 +34,11 @@ DEFINE_bool(use_gpu, false, "Whether use gpu."); namespace paddle { namespace demo { -struct Record { - std::vector data; - std::vector shape; -}; - -void split(const std::string& str, char sep, std::vector* pieces); - -Record ProcessALine(const std::string& line) { - VLOG(3) << "process a line"; - std::vector columns; - split(line, '\t', &columns); - CHECK_EQ(columns.size(), 2UL) - << "data format error, should be \t"; - - Record record; - std::vector data_strs; - split(columns[0], ' ', &data_strs); - for (auto& d : data_strs) { - record.data.push_back(std::stof(d)); - } - - std::vector shape_strs; - split(columns[1], ' ', &shape_strs); - for (auto& s : shape_strs) { - record.shape.push_back(std::stoi(s)); - } - VLOG(3) << "data size " << record.data.size(); - VLOG(3) << "data shape size " << record.shape.size(); - return record; -} - -void CheckOutput(const std::string& referfile, const PaddleTensor& output) { - std::string line; - std::ifstream file(referfile); - std::getline(file, line); - auto refer = ProcessALine(line); - file.close(); - - size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); - VLOG(3) << "predictor output numel " << numel; - VLOG(3) << "reference output numel " << refer.data.size(); - CHECK_EQ(numel, refer.data.size()); - switch (output.dtype) { - case PaddleDType::INT64: { - for (size_t i = 0; i < numel; ++i) { - CHECK_EQ(static_cast(output.data.data())[i], refer.data[i]); - } - break; - } - case PaddleDType::FLOAT32: - for (size_t i = 0; i < numel; ++i) { - CHECK_LT( - fabs(static_cast(output.data.data())[i] - refer.data[i]), - 1e-5); - } - break; - } -} - /* * Use the native fluid engine to inference the demo. */ void Main(bool use_gpu) { + std::unique_ptr predictor; NativeConfig config; config.param_file = FLAGS_modeldir + "/__params__"; config.prog_file = FLAGS_modeldir + "/__model__"; @@ -111,7 +49,7 @@ void Main(bool use_gpu) { } VLOG(3) << "init predictor"; - auto predictor = + predictor = CreatePaddlePredictor(config); VLOG(3) << "begin to process data"; @@ -131,7 +69,7 @@ void Main(bool use_gpu) { VLOG(3) << "run executor"; std::vector output; - predictor->Run({input}, &output); + predictor->Run({input}, &output, 1); VLOG(3) << "output.size " << output.size(); auto& tensor = output.front(); @@ -146,9 +84,10 @@ void Main(bool use_gpu) { int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); - paddle::demo::Main(false /* use_gpu*/); if (FLAGS_use_gpu) { paddle::demo::Main(true /*use_gpu*/); + } else { + paddle::demo::Main(false /*use_gpu*/); } return 0; } diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index fac1babf6ec6131f84d3e3b9fc6efedd9f9f6cfc..0a35e10f6936313928ab21a6f17c40335e8fc882 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -26,6 +26,8 @@ nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) - nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) + +nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/pad_op.cc b/paddle/fluid/inference/tensorrt/convert/pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..218030a591fcc7e533ef37062265449d4b6044bc --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/pad_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * PadOp. + */ +class PadOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid transpose op to tensorrt tranpose layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + + const std::vector paddings = + boost::get>(op_desc.GetAttr("paddings")); + const float pad_value = boost::get(op_desc.GetAttr("pad_value")); + + nvinfer1::Dims input_shape = input->getDimensions(); + int nbDims = input_shape.nbDims; + int pad_size = static_cast(paddings.size()); + PADDLE_ENFORCE_GE(nbDims, 2); + PADDLE_ENFORCE_EQ((nbDims + 1) * 2, pad_size); + PADDLE_ENFORCE(pad_value == 0.0, "The pad layer of TRT only support zero."); + + nvinfer1::DimsHW pre_pad(paddings[pad_size - 4], paddings[pad_size - 2]); + nvinfer1::DimsHW post_pad(paddings[pad_size - 3], paddings[pad_size - 1]); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Padding, + *const_cast(input), + pre_pad, post_pad); + + PADDLE_ENFORCE(layer != nullptr); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + layer->setName(("scale (Output: " + output_name + ")").c_str()); + layer->getOutput(0)->setName(output_name.c_str()); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(pad, PadOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba35d7ddbb2f4e6062713bd82be277e7ad0cb341 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(PadConverter, main) { + framework::Scope scope; + std::unordered_set parameters; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("pad-X", nvinfer1::Dims3(3, 2, 2)); + validator.DeclOutputVar("pad-Out", nvinfer1::Dims3(3, 3, 5)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("pad"); + desc.SetInput("X", {"pad-X"}); + desc.SetOutput("Out", {"pad-Out"}); + + std::vector paddings = {0, 0, 0, 0, 0, 1, 1, 2}; + float pad_value = 0.0; + desc.SetAttr("paddings", paddings); + desc.SetAttr("pad_value", pad_value); + + LOG(INFO) << "set OP"; + validator.SetOp(*desc.Proto()); + LOG(INFO) << "execute"; + + validator.Execute(2); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(pad); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2ef13b72ed3ff6ae8ad8748ddea977e693615ac6..df3e3fcd9c75f03f4d9b0a7c12788f06bfdefd7f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -230,7 +230,7 @@ if(WITH_DISTRIBUTE) op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach() - + #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op # listen_and_serv_op sum_op executor SERIAL) @@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) @@ -299,7 +300,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) -op_library(fusion_lstm_op DEPS cpu_lstm_compute) +op_library(fusion_lstm_op DEPS jit_kernel) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(layer_norm_op DEPS cub) diff --git a/paddle/fluid/operators/adadelta_op.cc b/paddle/fluid/operators/adadelta_op.cc index d1970515f58969948b1d2db5847e4344112f77f9..89a7a49e0fa8427826f5d91274912a68f2316b61 100644 --- a/paddle/fluid/operators/adadelta_op.cc +++ b/paddle/fluid/operators/adadelta_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; + class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel { "Input(AvgSquaredGrad) of AdadeltaOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"), "Input(AvgSquaredUpdate) of AdadeltaOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of AdadeltaOp should not be null."); @@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = diff --git a/paddle/fluid/operators/adadelta_op.h b/paddle/fluid/operators/adadelta_op.h index 822458daf663d99bbb38d99205f51163a0df4c4d..6c616aa03d9809e9b7725a700c7edd5ff5d6dc42 100644 --- a/paddle/fluid/operators/adadelta_op.h +++ b/paddle/fluid/operators/adadelta_op.h @@ -23,6 +23,17 @@ template class AdadeltaOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto avg_squared_grad_out_tensor = ctx.Output("AvgSquaredGradOut"); diff --git a/paddle/fluid/operators/adagrad_op.h b/paddle/fluid/operators/adagrad_op.h index df520fcc898ff5514927dbdd845ecaecdcf3c147..0a16ce00f71586ef55007c3753e024be29d0ed56 100644 --- a/paddle/fluid/operators/adagrad_op.h +++ b/paddle/fluid/operators/adagrad_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -21,25 +22,31 @@ namespace operators { template struct SparseAdagradFunctor { - void operator()(const DeviceContext& context, - const framework::SelectedRows& grad, - const framework::Tensor& learning_rate, T epsilon, - framework::Tensor* moment, framework::Tensor* param); + void operator()(const DeviceContext &context, + const framework::SelectedRows &grad, + const framework::Tensor &learning_rate, T epsilon, + framework::Tensor *moment, framework::Tensor *param); }; template class AdagradOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out_tensor = ctx.Output("ParamOut"); - auto* moment_out_tensor = ctx.Output("MomentOut"); + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + + auto *param_out_tensor = ctx.Output("ParamOut"); + auto *moment_out_tensor = ctx.Output("MomentOut"); param_out_tensor->mutable_data(ctx.GetPlace()); moment_out_tensor->mutable_data(ctx.GetPlace()); T epsilon = static_cast(ctx.Attr("epsilon")); - auto* grad_var = ctx.InputVar("Grad"); + auto *grad_var = ctx.InputVar("Grad"); if (grad_var->IsType()) { auto param = framework::EigenVector::Flatten( *ctx.Input("Param")); @@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel { *ctx.Input("Grad")); auto moment = framework::EigenVector::Flatten( *ctx.Input("Moment")); - auto* learning_rate = ctx.Input("LearningRate"); + auto *learning_rate = ctx.Input("LearningRate"); auto param_out = framework::EigenVector::Flatten(*param_out_tensor); auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); - auto* place = ctx.template device_context().eigen_device(); + auto *place = ctx.template device_context().eigen_device(); moment_out.device(*place) = moment + grad * grad; Eigen::DSizes m_dsize(moment_out_tensor->numel()); if (platform::is_cpu_place(ctx.GetPlace())) { - auto* lr = learning_rate->data(); + auto *lr = learning_rate->data(); param_out.device(*place) = param - lr[0] * grad / (moment_out.sqrt() + epsilon); } else { @@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel { lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); } } else if (grad_var->IsType()) { - auto* param_tensor = ctx.Input("Param"); + auto *param_tensor = ctx.Input("Param"); PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor); - auto* moment_tensor = ctx.Input("Moment"); + auto *moment_tensor = ctx.Input("Moment"); PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor); SparseAdagradFunctor functor; diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 4cb1f3a80e95bdda79e6451dc3cc87e899b11779..3455d1ee54e8e6e498d0b0e6932ec099af9c0b30 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/for_range.h" @@ -199,23 +200,9 @@ struct SparseAdamFunctor { row_numel_(row_numel), row_count_(row_count) {} - inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const { - int64_t beg = 0, end = row_count_ - 1; - while (beg <= end) { - auto mid = ((beg + end) >> 1); - if (rows_[mid] == row) - return mid; - else if (rows_[mid] < row) - beg = mid + 1; - else - end = mid - 1; - } - return -1; - } - inline HOSTDEVICE void operator()(size_t i) const { - int64_t row = i / row_numel_; - auto row_idx = BinarySearchInRows(row); + auto row_idx = + math::BinarySearch(rows_, row_count_, i / row_numel_); T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; // The following code is the same as dense @@ -244,6 +231,12 @@ template class AdamOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + using paddle::framework::LoDTensor; using paddle::operators::detail::Ref; diff --git a/paddle/fluid/operators/adamax_op.cc b/paddle/fluid/operators/adamax_op.cc index 32062574bcf71ff96e451eaa6865b6bbfc3b1c80..d4aa4d338a2379adf985ba7f89b528bc402eda06 100644 --- a/paddle/fluid/operators/adamax_op.cc +++ b/paddle/fluid/operators/adamax_op.cc @@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel { "Input(LearningRate) of AdamaxOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), "Input(Beta1Pow) of AdamaxOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of AdamaxOp should not be null."); diff --git a/paddle/fluid/operators/adamax_op.h b/paddle/fluid/operators/adamax_op.h index de644676fd9c3fabdbf01d2fd9c69858c2627ed3..7137fbd9651b4523f6d1609a0595b30758aa40df 100644 --- a/paddle/fluid/operators/adamax_op.h +++ b/paddle/fluid/operators/adamax_op.h @@ -23,6 +23,17 @@ template class AdamaxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto moment_out_tensor = ctx.Output("MomentOut"); auto inf_norm_out_tensor = ctx.Output("InfNormOut"); diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index 5af0eb0b2ada66d5ae7d521d80e213f9e61f826f..855c4d70677395992e2bf685c910cbea2d37b20b 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -16,12 +16,15 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using SelectedRows = framework::SelectedRows; template using EigenVector = framework::EigenVector; @@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max_norm = context.Attr("max_norm"); - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); + auto in_var = context.InputVar("X"); + + Tensor* output = nullptr; + const Tensor* input = nullptr; + if (in_var->IsType()) { + input = context.Input("X"); + + output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + } else if (in_var->IsType()) { + auto* x = context.Input("X"); + + // merge ids in selected rows first + math::scatter::MergeAdd merge_func; + SelectedRows* merged_input = + const_cast(context.scope()) + .Var() + ->GetMutable(); + merge_func(context.template device_context(), *x, + merged_input); + input = &(merged_input->value()); + + SelectedRows* output_selected_rows = context.Output("Out"); + output_selected_rows->set_rows(merged_input->rows()); + output_selected_rows->set_height(merged_input->height()); + output = output_selected_rows->mutable_value(); + output->Resize(merged_input->value().dims()); + output->mutable_data(context.GetPlace()); + } else { + PADDLE_THROW("Unexpected branch, input variable type is %s", + in_var->Type().name()); + } + + PADDLE_ENFORCE_NOT_NULL(input); auto x = EigenVector::Flatten(*input); auto out = EigenVector::Flatten(*output); diff --git a/paddle/fluid/operators/decayed_adagrad_op.cc b/paddle/fluid/operators/decayed_adagrad_op.cc index c0f2b49a04d9e88502c4b63bca493cd2b7ad1c5c..d73ae9e2721b388212cb6efa354eb4b480df9cad 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/decayed_adagrad_op.cc @@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasInput("LearningRate"), "Input(LearningRate) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of DecayedAdagradOp should not be null."); diff --git a/paddle/fluid/operators/decayed_adagrad_op.h b/paddle/fluid/operators/decayed_adagrad_op.h index a46af078e0c6b4bf1faca0570b6a97b026864f13..5df43d33ef9f720fd20d57c53ff37cc85440b24e 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.h +++ b/paddle/fluid/operators/decayed_adagrad_op.h @@ -23,6 +23,17 @@ template class DecayedAdagradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out_tensor = ctx.Output("ParamOut"); auto moment_out_tensor = ctx.Output("MomentOut"); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 2826b82117db113d4d8c10095e89f610ca895775..e04a68717b351ddb0be5a7e70aa9297e5eb0125f 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -70,6 +70,12 @@ class FillConstantOp : public framework::OperatorBase { } }; +class FillConstantOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; + class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -102,4 +108,5 @@ Fill up a variable with specified constant value. namespace ops = paddle::operators; REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantInferShape, ops::FillConstantOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + ops::FillConstantOpVarTypeInference); diff --git a/paddle/fluid/operators/ftrl_op.cc b/paddle/fluid/operators/ftrl_op.cc index 70ba25c213046cc934f46be067080d5fdbb42f9e..b77e12d6508eb07ae137b313ca91eac951afbcbe 100644 --- a/paddle/fluid/operators/ftrl_op.cc +++ b/paddle/fluid/operators/ftrl_op.cc @@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel { "Input(Grad) of FTRL should not be null."); PADDLE_ENFORCE(ctx->HasInput("LearningRate"), "Input(LearningRate) of FTRL should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Grad").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of FTRL should not be null."); diff --git a/paddle/fluid/operators/ftrl_op.h b/paddle/fluid/operators/ftrl_op.h index 6f821e7e9944214fc5ebdf6bc7db8789b8ada6b9..8f812c9a037bfac8c1e29e32a5ad5b077c8153d1 100644 --- a/paddle/fluid/operators/ftrl_op.h +++ b/paddle/fluid/operators/ftrl_op.h @@ -28,6 +28,17 @@ template class FTRLOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto* param_out = ctx.Output("ParamOut"); auto* sq_accum_out = ctx.Output("SquaredAccumOut"); auto* lin_accum_out = ctx.Output("LinearAccumOut"); diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc index 0b917a403620e2ffb2cbb4ca7856cce9584e1eef..fdc9cb4888b3468b85abfa0c693ed8ac5b0d450b 100644 --- a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc @@ -93,11 +93,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ctx->SetOutputDim("Cell", out_dims); ctx->ShareLoD("Ids", "Hidden"); ctx->ShareLoD("Ids", "Cell"); - int xx_width; - if (ctx->Attrs().Get("use_seq")) { - xx_width = wh_dims[1]; - } else { - xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1]; + if (!ctx->Attrs().Get("use_seq")) { PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), "Assert only one Output(BatchedInput) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), @@ -112,7 +108,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedCell", out_dims); } - ctx->SetOutputDim("XX", {x_dims[0], xx_width}); + ctx->SetOutputDim("XX", {x_dims[0], wh_dims[1]}); ctx->ShareLoD("Ids", "XX"); } @@ -435,8 +431,6 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { INIT_VEC_FUNC INIT_BASE_INPUT_DATAS - // std::cout << "===> Batch Compute" << std::endl; - auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* reordered_c0 = ctx.Output("ReorderedC0"); auto* batched_input = ctx.Output("BatchedInput"); diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index ae1f6d8e489039667d861a69acabf2c632ef2061..067e6a3e7cccc1f15ebdd984f3a2441339a989ab 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,11 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_lstm_compute.h" -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -219,121 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op. template class FuisonLSTMKernel : public framework::OpKernel { public: -#define INIT_VEC_FUNC \ - std::function act_gate, act_cell, act_cand; \ - auto& act_gate_str = ctx.Attr("gate_activation"); \ - auto& act_cell_str = ctx.Attr("cell_activation"); \ - auto& act_cand_str = ctx.Attr("candidate_activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_cell = act_functor(act_cell_str); \ - act_cand = act_functor(act_cand_str); \ - } else { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_cell = act_functor(act_cell_str); \ - act_cand = act_functor(act_cand_str); \ - } - -#define INIT_BASE_INPUT_OUTPUT \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ - bool is_reverse = ctx.Attr("is_reverse"); \ - bool use_peepholes = ctx.Attr("use_peepholes"); - -#define INIT_BASE_SIZES \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 4D*/ \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D2 = D * 2; \ - const int D3 = D * 3; \ - const int D4 = wh_dims[1]; - -#define INIT_BASE_INPUT_DATAS \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wc_data = bias->data() + D4; \ - /* for peephole only*/ \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - auto* checked_cell = ctx.Output("CheckedCell"); \ - checked_cell_data = checked_cell->mutable_data(place); \ - } - -/// Compute LSTM +#define INIT_BASE_DEFINES \ + using DeviceContext = paddle::platform::CPUDeviceContext; \ + auto* x = ctx.Input("X"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 4D*/ \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D4 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wp_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ + } \ + const auto& ker = \ + math::jitkernel::KernelPool::Instance() \ + .template Get, const std::string&, \ + const std::string&, const std::string&>( \ + ctx.Attr("gate_activation"), \ + ctx.Attr("candidate_activation"), \ + ctx.Attr("cell_activation"), D, use_peepholes) + +// Wh GEMM #define GEMM_WH_ADDON(bs, prev, out) \ blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ wh_data, D4, static_cast(1), out, D4) -#define GET_Ct(ct_1, gates, ct) \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - act_cand(D, gates, gates); \ - blas.VMUL(D, gates, gates + D, gates + D); \ - blas.VMUL(D, ct_1, gates + D2, gates + D2); \ - blas.VADD(D, gates + D, gates + D2, ct) - -#define GET_Ht(ct, gates, ht) \ - /* H_t = act_cell(C_t) * ogated */ \ - act_cell(D, ct, gates + D2); \ - blas.VMUL(D, gates + D2, gates + D3, ht) - -#define GET_Ct_NOH0C0(gates, ct) \ - /* C_t = igated * cgated*/ \ - act_gate(D, gates + D, gates + D); \ - act_cand(D, gates, gates); \ - blas.VMUL(D, gates, gates + D, ct) - -#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \ - GET_Ct_NOH0C0(gates, ct); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - -#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \ - GET_Ct_NOH0C0(gates, ct); \ - /* get outgated, put W_oc * C_t on igated */ \ - blas.VMUL(D, wc_data + D2, ct, gates + D); \ - blas.VADD(D, gates + D, gates + D3, gates + D3); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - -#define COMPUTE_CtHt(gates, ct_1, ct, ht) \ - act_gate(D3, gates + D, gates + D); \ - GET_Ct(ct_1, gates, ct); \ - GET_Ht(ct, gates, ht) - -#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \ - /* get fgated and igated*/ \ - blas.VMUL(D, wc_data, ct_1, checked_cell_data); \ - blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \ - blas.VADD(D2, checked_cell_data, gates + D, gates + D); \ - act_gate(D2, gates + D, gates + D); \ - GET_Ct(ct_1, gates, ct); \ - /* get ogated*/ \ - blas.VMUL(D, wc_data + D2, ct, gates + D); \ - blas.VADD(D, gates + D, gates + D3, gates + D3); \ - act_gate(D, gates + D3, gates + D3); \ - GET_Ht(ct, gates, ht) - void SeqCompute(const framework::ExecutionContext& ctx) const { - using DeviceContext = paddle::platform::CPUDeviceContext; - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - INIT_VEC_FUNC - INIT_BASE_INPUT_DATAS - + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; auto x_lod = x->lod(); const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; @@ -357,89 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel { gate_offset = -D; } -#define MOVE_ONE_STEP \ - prev_h_data = h_out_data; \ - prev_c_data = c_out_data; \ - xx_data = xx_data + xx_offset; \ - h_out_data = h_out_data + gate_offset; \ - c_out_data = c_out_data + gate_offset - -#define PROCESS_H0C0_DEFINES \ - int bid = is_reverse ? N - 1 - i : i; \ - int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ - const T* prev_c_data = nullptr; \ - const T* prev_h_data = nullptr; \ - int tstart = 0 - -#define PROCESS_H0C0_PEEPHOLE \ - PROCESS_H0C0_DEFINES; \ - if (h0_data) { \ - prev_h_data = h0_data + bid * D; \ - prev_c_data = c0_data + bid * D; \ - } else { \ - COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \ - MOVE_ONE_STEP; \ - tstart = 1; \ - } - -#define PROCESS_H0C0 \ - PROCESS_H0C0_DEFINES; \ - if (h0_data) { \ - prev_h_data = h0_data + bid * D; \ - prev_c_data = c0_data + bid * D; \ - } else { \ - COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \ - MOVE_ONE_STEP; \ - tstart = 1; \ - } - - if (use_peepholes) { - for (int i = 0; i < N; ++i) { - PROCESS_H0C0_PEEPHOLE - for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON(1, prev_h_data, xx_data); - COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); - MOVE_ONE_STEP; - } - } - } else { - // TODO(TJ): unly workaround, clean me - std::function compute_ctht; - if (platform::jit::MayIUse(platform::jit::avx) && - act_gate_str == "sigmoid" && act_cand_str == "tanh" && - act_cell_str == "tanh" && D == 8) { - compute_ctht = math::lstm_compute_ctht; + for (int i = 0; i < N; ++i) { + int bid = is_reverse ? N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; + const T* prev_c_data = nullptr; + const T* prev_h_data = nullptr; + int tstart = 0; + if (h0_data) { + prev_h_data = h0_data + bid * D; + prev_c_data = c0_data + bid * D; } else { - compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { - COMPUTE_CtHt(gates, ct_1, ct, ht); - }; + ker->ComputeC1H1(xx_data, c_out_data, h_out_data, wp_data); + tstart = 1; + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; } - for (int i = 0; i < N; ++i) { - PROCESS_H0C0 - for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON(1, prev_h_data, xx_data); - compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data); - MOVE_ONE_STEP; - } + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data, wp_data, + checked_cell_data); + // move one step + prev_h_data = h_out_data; + prev_c_data = c_out_data; + xx_data = xx_data + xx_offset; + h_out_data = h_out_data + gate_offset; + c_out_data = c_out_data + gate_offset; } } -#undef PROCESS_H0C0_DEFINES -#undef PROCESS_H0C0_PEEPHOLE -#undef PROCESS_H0C0 -#undef MOVE_ONE_STEP } void BatchCompute(const framework::ExecutionContext& ctx) const { - using DeviceContext = platform::CPUDeviceContext; - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES + INIT_BASE_DEFINES; if (x->lod()[0].size() == 2) { xx->Resize({x_dims[0], D4}); SeqCompute(ctx); return; } - INIT_VEC_FUNC - INIT_BASE_INPUT_DATAS + INIT_OTHER_DEFINES; auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* reordered_c0 = ctx.Output("ReorderedC0"); @@ -487,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_c_data = reordered_c0_data; size_t sz = sizeof(T) * D; for (int i = 0; i < max_bs; ++i) { - std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); - std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz); + blas.VCOPY(sz, h0_data + seq_order[i] * D, reordered_h0_data); + blas.VCOPY(sz, c0_data + seq_order[i] * D, reordered_c0_data); reordered_h0_data += D; reordered_c0_data += D; } @@ -498,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cur_h_out_data = batched_h_out_data; T* cur_c_out_data = batched_c_out_data; for (int i = 0; i < max_bs; ++i) { - GET_Ct_NOH0C0(cur_in_data, cur_c_out_data); - if (use_peepholes) { - blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D); - blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3); - } - act_gate(D, cur_in_data + D3, cur_in_data + D3); - GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data); + ker->ComputeC1H1(cur_in_data, cur_c_out_data, cur_h_out_data, wp_data); cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; @@ -513,71 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_h_data = batched_h_out_data; prev_c_data = batched_c_out_data; } + + // compute kernel part const auto& batch_starts = batched_lod[0]; const int max_seq_len = batch_starts.size() - 1; const int offset = tstart * max_bs * D; batched_input_data = batched_input_data + offset * 4; batched_h_out_data = batched_h_out_data + offset; batched_c_out_data = batched_c_out_data + offset; - -#define DEFINE_CUR \ - T* cur_in_data = batched_input_data; \ - T* cur_prev_c_data = prev_c_data; \ - T* cur_c_out_data = batched_c_out_data; \ - T* cur_h_out_data = batched_h_out_data - -#define MOVE_ONE_BATCH \ - cur_in_data += D4; \ - cur_prev_c_data += D; \ - cur_c_out_data += D; \ - cur_h_out_data += D - -#define MOVE_ONE_STEP \ - prev_c_data = batched_c_out_data; \ - prev_h_data = batched_h_out_data; \ - batched_c_out_data = cur_c_out_data; \ - batched_h_out_data = cur_h_out_data; \ - batched_input_data = cur_in_data - - if (use_peepholes) { - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = batch_starts[step + 1] - batch_starts[step]; - GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); - DEFINE_CUR; - for (int i = 0; i < cur_bs; ++i) { - COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data); - MOVE_ONE_BATCH; - } - MOVE_ONE_STEP; - } - } else { - // TODO(TJ): unly workaround, clean me - std::function compute_ctht; - if (platform::jit::MayIUse(platform::jit::avx) && - act_gate_str == "sigmoid" && act_cand_str == "tanh" && - act_cell_str == "tanh" && D == 8) { - compute_ctht = math::lstm_compute_ctht; - } else { - compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { - COMPUTE_CtHt(gates, ct_1, ct, ht); - }; - } - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = batch_starts[step + 1] - batch_starts[step]; - GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); - DEFINE_CUR; - for (int i = 0; i < cur_bs; ++i) { - compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data); - MOVE_ONE_BATCH; - } - MOVE_ONE_STEP; + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + T* cur_in_data = batched_input_data; + T* cur_prev_c_data = prev_c_data; + T* cur_c_out_data = batched_c_out_data; + T* cur_h_out_data = batched_h_out_data; + for (int i = 0; i < cur_bs; ++i) { + ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data, wp_data, checked_cell_data); + // move one batch + cur_in_data += D4; + cur_prev_c_data += D; + cur_c_out_data += D; + cur_h_out_data += D; } + // move one step + prev_c_data = batched_c_out_data; + prev_h_data = batched_h_out_data; + batched_c_out_data = cur_c_out_data; + batched_h_out_data = cur_h_out_data; + batched_input_data = cur_in_data; } -#undef MOVE_ONE_STEP -#undef MOVE_ONE_BATCH -#undef DEFINE_CUR math::Batch2LoDTensorFunctor to_seq; batched_h_out->set_lod(batched_lod); @@ -594,18 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } } -#undef COMPUTE_CtHt_PEEPHOLE -#undef COMPUTE_CtHt -#undef GET_Ct_NOH0C0 -#undef COMPUTE_CtHt_NOH0C0 -#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0 -#undef GET_Ht -#undef GET_Ct #undef GEMM_WH_ADDON -#undef INIT_BASE_INPUT_DATAS -#undef INIT_BASE_SIZES -#undef INIT_BASE_INPUT_OUTPUT -#undef INIT_VEC_FUNC +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES }; } // namespace operators diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index 248c7793560db99c0af06421bf74808422016061..7b42efd623b31a703bf51d2d157130b3120b42a4 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -60,7 +60,7 @@ class OverflowOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) 1-dim tensor, contains a bool scalar. The output " "tensor of overflow operator."); AddComment(string::Sprintf(R"DOC( -Overflow operator. +Overflow %s operator. $$Out = any(X)$$ @@ -69,6 +69,8 @@ Out = Inf if any X contains Inf, Out = Nan if any X contains Nan, Out = 0 if no Inf/Nan detected. If X contains both Inf/Nan, it will return the first indicator it meeted. + +%s )DOC", GetName(), GetComments())); } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 91101356436c26171eaca2fe01dfd4d937e71717..7365bfeeb8edf09a8ad5e1cb2c61300e86bdf518 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -3,8 +3,8 @@ add_subdirectory(detail) endif(NOT WIN32) function(math_library TARGET) - # math_library is a function to create math library. - # The interface is the same as cc_library. + # math_library is a function to create math library. + # The interface is the same as cc_library. # But it handle split GPU/CPU code and link some common library. set(cc_srcs) set(cu_srcs) @@ -45,15 +45,13 @@ math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) -# TODO(TJ): ugly workaround, clean me -cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) math_library(math_function DEPS blas) math_library(maxouting) math_library(pooling) -math_library(selected_rows_functor DEPS selected_rows math_function) +math_library(selected_rows_functor DEPS selected_rows math_function blas) math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function) @@ -76,3 +74,7 @@ if(WITH_GPU) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) +cc_library(jit_kernel + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc + DEPS cpu_info cblas activation_functions) +cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h new file mode 100644 index 0000000000000000000000000000000000000000..262469beea7449eb5820b86de1ac4f790a833e79 --- /dev/null +++ b/paddle/fluid/operators/math/algorithm.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // for int64_t +#include + +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { + int64_t beg = 0, end = num - 1; + while (beg <= end) { + auto mid = ((beg + end) >> 1); + if (x[mid] == val) + return mid; + else if (x[mid] < val) + beg = mid + 1; + else + end = mid - 1; + } + return -1; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc deleted file mode 100644 index e96d1879331974e0873e13f171414bcfa8c45953..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/cpu_lstm_compute.h" - -namespace paddle { -namespace operators { -namespace math { -#ifdef __AVX__ -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht) { - namespace act = detail::forward::avx; - // gates: W_ch, W_ih, W_fh, W_oh - __m256 c, i, f, o; - c = _mm256_loadu_ps(gates); - i = _mm256_loadu_ps(gates + 8); - f = _mm256_loadu_ps(gates + 16); - o = _mm256_loadu_ps(gates + 24); - - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); - i = _mm256_loadu_ps(ct_1); - f = _mm256_mul_ps(i, act::Sigmoid(f)); - f = _mm256_add_ps(c, f); - _mm256_storeu_ps(ct, f); - - /* H_t = act_cell(C_t) * ogated */ - o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); - _mm256_storeu_ps(ht, o); -} -#endif -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h deleted file mode 100644 index 169a9e4b47f54851ad436428416eca879b78e186..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/platform/cpu_info.h" -#ifdef __AVX__ -#include -#endif - -namespace paddle { -namespace operators { -namespace math { - -// TODO(TJ): ugly workaround, clean me -template -void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { - // gates: W_ch, W_ih, W_fh, W_oh - vec_sigmoid(24, gates + 8, gates + 8); - vec_tanh(8, gates, gates); - const T *i = gates + 8, *f = gates + 16, *o = gates + 24; - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int d = 0; d < 8; ++d) { - // C_t = C_t-1 * fgated + cand_gated * igated - ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; - // H_t = act_cell(C_t) * ogated - T tmp = ct[d] * 2; - tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vec_exp(1, &tmp, &tmp); - tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); - ht[d] = tmp * o[d]; - } -} - -#ifdef __AVX__ -namespace detail { -namespace forward { -namespace avx { -__m256 Sigmoid(const __m256 a); -__m256 Tanh(const __m256 a); - -} // namespace avx -} // namespace forward -} // namespace detail - -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht); - -#endif - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 6a059968b79189458349e466079cc7a663a8e5ff..0aed253c80fc28560716cbcfa70f74ef9c84f9b6 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -125,10 +125,8 @@ inline void vec_scal(const int n, const float a, } template <> -inline void vec_scal(const int n, - const float a, - const float* x, - float* y) { +inline void vec_scal(const int n, const float a, + const float* x, float* y) { // TODO(TJ): enable me vec_scal(n, a, x, y); } @@ -181,10 +179,10 @@ inline void vec_bias_sub(const int n, const float a, } template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { // TODO(TJ): enable me vec_bias_sub(n, a, x, y); } @@ -242,7 +240,7 @@ inline void vec_cross(const int n, const float* x, } template <> -inline void vec_cross( +inline void vec_cross( const int n, const float* x, const float* y, const float* z, float* out) { // TODO(TJ): enable me vec_cross(n, x, y, z, out); @@ -296,10 +294,10 @@ inline void vec_add_bias(const int n, const float a, } template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { // TODO(TJ): enable me vec_add_bias(n, a, x, y); } @@ -390,9 +388,9 @@ inline void vec_sigmoid(const int n, const float* x, } template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { +inline void vec_sigmoid(const int n, + const float* x, + float* y) { // TODO(TJ): enable me vec_sigmoid(n, x, y); } @@ -454,9 +452,8 @@ inline void vec_relu(const int n, const float* x, } template <> -inline void vec_relu(const int n, - const float* x, - float* y) { +inline void vec_relu(const int n, const float* x, + float* y) { // TODO(TJ): enable me vec_relu(n, x, y); } diff --git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index 3ce66f49ed8354c49e8af26ca6eb48fef654a40b..cd40f1b2f984126663a5711efac24fdf6d680b32 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, + TestAndBench(sz, vec_sigmoid, ref_sigmoid); } TestAndBench(30, vec_sigmoid, ref_sigmoid); @@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) { TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, - ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); } TestAndBench(30, vec_tanh, ref_tanh); } @@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) { TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, - ref_relu); + TestAndBench(sz, vec_relu, ref_relu); } TestAndBench(30, vec_relu, ref_relu); } @@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) { TestInplace(sz, vec_sigmoid, ref_sigmoid); TestInplace(sz, vec_sigmoid, ref_sigmoid); TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, + TestInplace(sz, vec_sigmoid, ref_sigmoid); } TestInplace(30, vec_sigmoid, ref_sigmoid); @@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) { TestInplace(sz, vec_tanh, ref_tanh); TestInplace(sz, vec_tanh, ref_tanh); TestInplace(sz, vec_tanh, ref_tanh); - TestInplace(sz, vec_tanh, - ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); } TestInplace(30, vec_tanh, ref_tanh); } @@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) { TestInplace(sz, vec_relu, ref_relu); TestInplace(sz, vec_relu, ref_relu); TestInplace(sz, vec_relu, ref_relu); - TestInplace(sz, vec_relu, - ref_relu); + TestInplace(sz, vec_relu, ref_relu); } TestInplace(30, vec_relu, ref_relu); } diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 3be389912307f7aac6dda6d1018943eb8f08696d..66d37c3bf31ffa420cc527cb576dcdc5505a0960 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -46,17 +46,20 @@ __forceinline__ __device__ unsigned warp_id() { return ret; } +#define ARG_DEFINE_KernelDepthwiseConv \ + const T *const input_data, const T *const filter_data, const int batch_size, \ + const int output_channels, const int output_height, \ + const int output_width, const int input_channels, \ + const int input_height, const int input_width, \ + const int filter_multiplier, const int filter_height, \ + const int filter_width, const int stride_height, const int stride_width, \ + const int padding_height, const int padding_width, \ + const int dilate_height, const int dilate_width, T *const output_data + // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. template -__device__ __inline__ void KernelDepthwiseConv( - const T* const input_data, const T* const filter_data, const int batch_size, - const int output_channels, const int output_height, const int output_width, - const int input_channels, const int input_height, const int input_width, - const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const output_data) { +__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { const int batch = blockIdx.y; @@ -97,42 +100,105 @@ __device__ __inline__ void KernelDepthwiseConv( } } -template -__global__ void KernelDepthwiseConvSp( - const T* const input_data, const T* const filter_data, const int batch_size, - const int output_channels, const int output_height, const int output_width, - const int input_channels, const int input_height, const int input_width, - const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const output_data) { - if (c_filter_multiplier == 0) - KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, - input_height, input_width, filter_multiplier, - filter_height, filter_width, stride_height, - stride_width, padding_height, padding_width, - dilate_height, dilate_width, output_data); +template +__device__ __inline__ void KernelDepthwiseConvCFilter( + ARG_DEFINE_KernelDepthwiseConv) { + const int kWeghtSize = c_filter * c_filter; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + const T* weight = filter_data + c_out * c_filter * c_filter; + for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i]; - else - KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, - input_height, input_width, c_filter_multiplier, - filter_height, filter_height, c_stride, c_stride, - padding_height, padding_width, dilate_height, - dilate_width, output_data); + for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { + for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + + const int c_in = c_out / filter_multiplier; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + c_filter * dilate_height; + const int w_in_end = w_in_start + c_filter * dilate_width; + + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + + for (int h_in = h_in_start, h_f = 0; h_f < c_filter; + h_in += dilate_height, h_f++) { + for (int w_in = w_in_start, w_f = 0; w_f < c_filter; + w_in += dilate_width, w_f++) { + if (h_in >= 0 && h_in < input_height && w_in >= 0 && + w_in < input_width) { + const int offset = in_offset + h_in * input_width + w_in; + value += r_weight[h_f * c_filter + w_f] * input_data[offset]; + } + } + } + int index = + ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + + w_out; + output_data[index] = value; + } + } +} + +template +__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { + if (c_filter_multiplier == 0) { + if (c_filter == -1) + KernelDepthwiseConv( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, output_data); + else + KernelDepthwiseConvCFilter( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, output_data); + } else { + if (c_filter == -1) + KernelDepthwiseConv(input_data, filter_data, batch_size, + output_channels, output_height, output_width, + input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_height, + c_stride, c_stride, padding_height, padding_width, + dilate_height, dilate_width, output_data); + else + KernelDepthwiseConvCFilter( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, + padding_height, padding_width, dilate_height, dilate_width, + output_data); + } } // CUDA kernel to compute the depthwise convolution backprop w.r.t input. +#define ARG_DEFINE_KernelDepthwiseConvInputGrad \ + const T *const output_grad_data, const T *const filter_data, \ + const int batch_size, const int output_channels, \ + const int output_height, const int output_width, \ + const int input_channels, const int input_height, const int input_width, \ + const int filter_multiplier, const int filter_height, \ + const int filter_width, const int stride_height, const int stride_width, \ + const int padding_height, const int padding_width, \ + const int dilate_height, const int dilate_width, \ + T *const input_grad_data + template __device__ __inline__ void KernelDepthwiseConvInputGrad( - const T* const output_grad_data, const T* const filter_data, - const int batch_size, const int output_channels, const int output_height, - const int output_width, const int input_channels, const int input_height, - const int input_width, const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const input_grad_data) { + ARG_DEFINE_KernelDepthwiseConvInputGrad) { for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { const int batch = blockIdx.y; @@ -184,15 +250,67 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( } } -template +template +__device__ __inline__ void KernelDepthwiseConvInputGradCFilter( + ARG_DEFINE_KernelDepthwiseConvInputGrad) { + const int kWeghtSize = c_filter * c_filter * c_filter_multiplier + 1; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + for (int c_i = 0; c_i < filter_multiplier; c_i++) { + int c_out = c_in * filter_multiplier + c_i; + const T* weight = filter_data + c_out * c_filter * c_filter; + for (int i = 0; i < c_filter * c_filter; i++) + r_weight[i + c_i * c_filter * c_filter] = + weight[c_filter * c_filter - i - 1]; + } + + for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { + for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height; + + int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width; + + T value = 0; + + for (int c_i = 0; c_i < filter_multiplier; c_i++) { + int c_out = c_in * filter_multiplier + c_i; + for (int h_out = h_out_start, h_f = 0; h_f < c_filter; + h_out += dilate_height, h_f++) { + for (int w_out = w_out_start, w_f = 0; w_f < c_filter; + w_out += dilate_width, w_f++) { + int s_h_out = h_out / stride_height; + int s_w_out = w_out / stride_width; + if (h_out % stride_height == 0 && w_out % stride_width == 0 && + s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && + s_w_out < output_width) { + const int output_grad_offset = + ((batch * output_channels + c_out) * output_height + + s_h_out) * + output_width + + s_w_out; + value += + output_grad_data[output_grad_offset] * + r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter]; + } + } + } + } + int index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + input_grad_data[index] = value; + } + } +} + +template __global__ void KernelDepthwiseConvInputGradSp( - const T* const output_grad_data, const T* const filter_data, - const int batch_size, const int output_channels, const int output_height, - const int output_width, const int input_channels, const int input_height, - const int input_width, const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const input_grad_data) { + ARG_DEFINE_KernelDepthwiseConvInputGrad) { if (c_filter_multiplier == 0) KernelDepthwiseConvInputGrad( output_grad_data, filter_data, batch_size, output_channels, @@ -200,13 +318,20 @@ __global__ void KernelDepthwiseConvInputGradSp( filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, dilate_width, input_grad_data); - else + else if (c_filter == -1) KernelDepthwiseConvInputGrad( output_grad_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, input_grad_data); + else + KernelDepthwiseConvInputGradCFilter( + output_grad_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, + padding_height, padding_width, dilate_height, dilate_width, + input_grad_data); } // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. @@ -325,12 +450,14 @@ class DepthwiseConvFunctor { dim3 threads(std::min(output_width, thread), blocks, 1); dim3 grid(output_channels, batch_size, 1); int filter_multiplier = output_channels / input_channels; -#define check_case(c_filter_multiplier, c_stride) \ +#define check_case(c_filter_multiplier, c_stride, c_filter) \ if (c_filter_multiplier == 0 || \ filter_multiplier == c_filter_multiplier && \ - stride_height == stride_width && stride_height == c_stride) { \ - KernelDepthwiseConvSp<<>>( \ + stride_height == stride_width && stride_height == c_stride && \ + (ksize_height == ksize_width && ksize_height == c_filter || \ + c_filter == -1)) { \ + KernelDepthwiseConvSp<<>>( \ input_data, filter_data, batch_size, output_channels, output_height, \ output_width, input_channels, input_height, input_width, \ filter_multiplier, ksize_height, ksize_width, stride_height, \ @@ -338,11 +465,17 @@ class DepthwiseConvFunctor { dilate_width, output_data); \ return; \ } - check_case(1, 1); - check_case(1, 2); - // NOTE(liangdun): 0,0 for other case - // add other case if needed, e.g. check_case(2^n,1) - check_case(0, 0); + check_case(1, 1, 3); + check_case(1, 1, 5); + check_case(1, 1, -1); + check_case(1, 2, 3); + check_case(1, 2, 5); + check_case(1, 2, -1); + check_case(0, 0, 3); + check_case(0, 0, 5); + check_case(0, 0, -1); +// NOTE(liangdun): 0,0 for other case +// add other case if needed, e.g. check_case(2^n,1) #undef check_case } }; @@ -384,13 +517,15 @@ class DepthwiseConvInputGradFunctor { dim3 grid(input_channels, batch_size, 1); int filter_multiplier = output_channels / input_channels; -#define check_case(c_filter_multiplier, c_stride) \ +#define check_case(c_filter_multiplier, c_stride, c_filter) \ if (c_filter_multiplier == 0 || \ filter_multiplier == c_filter_multiplier && \ - stride_height == stride_width && stride_height == c_stride) { \ + stride_height == stride_width && stride_height == c_stride && \ + (ksize_height == ksize_width && ksize_height == c_filter || \ + c_filter == -1)) { \ KernelDepthwiseConvInputGradSp< \ - T, c_filter_multiplier, \ - c_stride><<>>( \ + T, c_filter_multiplier, c_stride, \ + c_filter><<>>( \ output_grad_data, filter_data, batch_size, output_channels, \ output_height, output_width, input_channels, input_height, \ input_width, filter_multiplier, ksize_height, ksize_width, \ @@ -398,11 +533,21 @@ class DepthwiseConvInputGradFunctor { dilate_height, dilate_width, input_grad_data); \ return; \ } - check_case(1, 1); - check_case(1, 2); - // NOTE(liangdun): 0,0 for other case - // add other case if needed, e.g. check_case(2^n,1) - check_case(0, 0); + check_case(1, 1, 3); + check_case(1, 1, 5); + check_case(1, 1, -1); + check_case(1, 2, 3); + check_case(1, 2, 5); + check_case(1, 2, -1); + check_case(2, 1, 3); + check_case(2, 1, 5); + check_case(2, 1, -1); + check_case(2, 2, 3); + check_case(2, 2, 5); + check_case(2, 2, -1); + check_case(0, 0, -1); +// NOTE(liangdun): 0,0 for other case +// add other case if needed, e.g. check_case(2^n,1) #undef check_case } }; diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..68b708b345334bc63b5e2e88c308d20ca6378e6b --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include +#include + +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +namespace jit = platform::jit; + +KernelPool& KernelPool::Instance() { + static thread_local KernelPool g_jit_kernels; + return g_jit_kernels; +} + +std::shared_ptr KernelPool::Get(const std::string& key) const { + if (kers_.find(key) == kers_.end()) { + return nullptr; + } + return kers_.at(key); +} + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b4dfda6db76fd4231be0acd1f90c98a2d62134b8 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include // for shared_ptr +#include +#include +#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/macros.h" + +// Note: Only support on CPU yet. +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 +#define AVX_FLOAT_BLOCK 8 +#define AVX2_FLOAT_BLOCK 8 +#define AVX512_FLOAT_BLOCK 16 + +typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; + +class Kernel { + public: + Kernel() = default; + virtual ~Kernel() = default; + int num_{0}; + int end_{0}; + int rest_{0}; + DISABLE_COPY_AND_ASSIGN(Kernel); +}; + +class KernelPool { + public: + static KernelPool &Instance(); + + template + std::shared_ptr Get(ARGS... args); + + std::shared_ptr Get(const std::string &key) const; + + private: + KernelPool() = default; + std::unordered_map> kers_; + + DISABLE_COPY_AND_ASSIGN(KernelPool); +}; + +template +class VMulKernel : public Kernel { + public: + virtual void Compute(const T *x, const T *y, T *z) const = 0; +}; + +template +class VAddKernel : public Kernel { + public: + virtual void Compute(const T *x, const T *y, T *z) const = 0; +}; + +template +class VScalKernel : public Kernel { + public: + virtual void Compute(const T a, const T *x, T *y) const = 0; + virtual void Compute(const T a, T *x) const = 0; +}; + +template +class VAddBiasKernel : public Kernel { + public: + virtual void Compute(const T a, const T *x, T *y) const = 0; +}; + +template +class VActKernel : public Kernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class VReluKernel : public VActKernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class VIdentityKernel : public VActKernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class VExpKernel : public VActKernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class VSigmoidKernel : public VActKernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class VTanhKernel : public VActKernel { + public: + virtual void Compute(const T *x, T *y) const = 0; +}; + +template +class LSTMKernel : public Kernel { + public: + virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht, + /* below only used in peephole*/ + const T *wp_data = nullptr, + T *checked = nullptr) const = 0; + + // compute c1 and h1 without c0 or h0 + virtual void ComputeC1H1(T *gates, T *ct, T *ht, + /* below only used in peephole*/ + const T *wp_data = nullptr) const = 0; +}; + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f9ea533fccdd34a5ccf061d89ffe92687d65933 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -0,0 +1,391 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include +#include "paddle/fluid/operators/math/jit_kernel_macro.h" +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +namespace jit = platform::jit; + +/* VMUL JitKernel */ +template +class VMulKernelImpl : public VMulKernel { + public: + explicit VMulKernelImpl(int d) : VMulKernel() { this->num_ = d; } + void Compute(const T* x, const T* y, T* z) const override { + for (int i = 0; i < this->num_; ++i) { + z[i] = x[i] * y[i]; + } + } +}; + +#ifdef PADDLE_WITH_MKLML +#define MKL_FLOAT(isa, block) \ + template <> \ + void VMulKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + platform::dynload::vsMul(this->num_, x, y, z); \ + } + +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VMulKernelImpl::Compute( \ + const double* x, const double* y, double* z) const { \ + platform::dynload::vdMul(this->num_, x, y, z); \ + } + +FOR_EACH_ISA(MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#endif + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VMulKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_mul_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ + } + +// avx > for > mkl +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 +#undef INTRI8_FLOAT +#undef MKL_FLOAT +#undef MKL_DOUBLE + +/* VADD JitKernel */ +template +class VAddKernelImpl : public VAddKernel { + public: + explicit VAddKernelImpl(int d) : VAddKernel() { this->num_ = d; } + void Compute(const T* x, const T* y, T* z) const override { + for (int i = 0; i < this->num_; ++i) { + z[i] = x[i] + y[i]; + } + } +}; + +#ifdef PADDLE_WITH_MKLML +#define MKL_FLOAT(isa, block) \ + template <> \ + void VAddKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + platform::dynload::vsAdd(this->num_, x, y, z); \ + } + +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VAddKernelImpl::Compute( \ + const double* x, const double* y, double* z) const { \ + platform::dynload::vdAdd(this->num_, x, y, z); \ + } + +FOR_EACH_ISA(MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#endif + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VAddKernelImpl::Compute( \ + const float* x, const float* y, float* z) const { \ + __m256 tmpx, tmpy; \ + tmpx = _mm256_loadu_ps(x); \ + tmpy = _mm256_loadu_ps(y); \ + tmpx = _mm256_add_ps(tmpx, tmpy); \ + _mm256_storeu_ps(z, tmpx); \ + } +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef INTRI8_FLOAT +#undef MKL_FLOAT +#undef MKL_DOUBLE + +/* VSCAL JitKernel */ +template +class VScalKernelImpl : public VScalKernel { + public: + explicit VScalKernelImpl(int d) : VScalKernel() { this->num_ = d; } + void Compute(const T a, const T* x, T* y) const override { + for (int i = 0; i < this->num_; ++i) { + y[i] = a * x[i]; + } + } + void Compute(const T a, T* x) const override { + for (int i = 0; i < this->num_; ++i) { + x[i] = a * x[i]; + } + } +}; + +#ifdef PADDLE_WITH_MKLML +#define MKL_FLOAT(isa, block) \ + template <> \ + void VScalKernelImpl::Compute(const float a, float* x) \ + const { \ + platform::dynload::cblas_sscal(this->num_, a, x, 1); \ + } + +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VScalKernelImpl::Compute(const double a, double* x) \ + const { \ + platform::dynload::cblas_dscal(this->num_, a, x, 1); \ + } + +FOR_EACH_ISA(MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#endif + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VScalKernelImpl::Compute( \ + const float a, const float* x, float* y) const { \ + __m256 tmp; \ + __m256 scalar = _mm256_set1_ps(a); \ + tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(y, tmp); \ + } +#define INTRI8_INPLACE_FLOAT(isa) \ + template <> \ + void VScalKernelImpl::Compute(const float a, float* x) \ + const { \ + __m256 tmp; \ + __m256 scalar = _mm256_set1_ps(a); \ + tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_mul_ps(tmp, scalar); \ + _mm256_storeu_ps(x, tmp); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI8_INPLACE_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI8_INPLACE_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI8_INPLACE_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef INTRI8_FLOAT +#undef INTRI8_INPLACE_FLOAT +#undef MKL_FLOAT +#undef MKL_DOUBLE + +/* VAddBias JitKernel */ +template +class VAddBiasKernelImpl : public VAddBiasKernel { + public: + explicit VAddBiasKernelImpl(int d) : VAddBiasKernel() { this->num_ = d; } + void Compute(const T a, const T* x, T* y) const override { + for (int i = 0; i < this->num_; ++i) { + y[i] = x[i] + a; + } + } +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VAddBiasKernelImpl::Compute( \ + const float a, const float* x, float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \ + _mm256_storeu_ps(y, tmp); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VAddBiasKernelImpl::Compute( \ + const float a, const float* x, float* y) const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \ + tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT + +/* VRelu JitKernel */ +template +class VReluKernelImpl : public VReluKernel { + public: + explicit VReluKernelImpl(int d) : VReluKernel() { this->num_ = d; } + void Compute(const T* x, T* y) const override { + for (int i = 0; i < this->num_; ++i) { + y[i] = x[i] > 0 ? x[i] : 0; + } + } +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VReluKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \ + _mm256_storeu_ps(y, tmp); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VReluKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 zeros = _mm256_setzero_ps(); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + tmp0 = _mm256_max_ps(tmp0, zeros); \ + tmp1 = _mm256_max_ps(tmp1, zeros); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#define INTRI_GT8LT16_FLOAT(isa) \ + template <> \ + VReluKernelImpl::VReluKernelImpl(int d) \ + : VReluKernel() { \ + this->num_ = d; \ + this->end_ = AVX_FLOAT_BLOCK; \ + this->rest_ = d - AVX_FLOAT_BLOCK; \ + } \ + template <> \ + void VReluKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 zeros = _mm256_setzero_ps(); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \ + tmp0 = _mm256_max_ps(tmp0, zeros); \ + tmp1 = _mm256_max_ps(tmp1, zeros); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + this->rest_, tmp1); \ + } + +#define INTRI_GT16_FLOAT(isa) \ + template <> \ + VReluKernelImpl::VReluKernelImpl(int d) \ + : VReluKernel() { \ + this->num_ = d; \ + this->end_ = d - d % AVX_FLOAT_BLOCK; \ + this->rest_ = d - AVX_FLOAT_BLOCK; \ + } \ + template <> \ + void VReluKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 zeros = _mm256_setzero_ps(); \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + tmp = _mm256_max_ps(tmp, zeros); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + __m256 tmp = _mm256_loadu_ps(x + this->rest_); \ + tmp = _mm256_max_ps(tmp, zeros); \ + _mm256_storeu_ps(y + this->rest_, tmp); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +INTRI_GT8LT16_FLOAT(jit::avx); +INTRI_GT16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +INTRI_GT8LT16_FLOAT(jit::avx2); +INTRI_GT16_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +// TODO(TJ): refine avx512 +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +INTRI_GT8LT16_FLOAT(jit::avx512f); +INTRI_GT16_FLOAT(jit::avx512f); +#endif + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef INTRI_GT8LT16_FLOAT +#undef INTRI_GT16_FLOAT + +/* An empty JitKernel */ +template +class VIdentityKernelImpl : public VIdentityKernel { + public: + explicit VIdentityKernelImpl(int d) : VIdentityKernel() { this->num_ = d; } + void Compute(const T* x, T* y) const override {} +}; + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddb, VAddBiasKernel); +REGISTER_JITKERNEL(vrelu, VReluKernel); +REGISTER_JITKERNEL(videntity, VIdentityKernel); + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc new file mode 100644 index 0000000000000000000000000000000000000000..b62e130c43743f542e2074868fc01598047d6b19 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -0,0 +1,400 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include // for exp +#include +#include "paddle/fluid/operators/math/jit_kernel_macro.h" +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { + +#ifdef __AVX__ +namespace detail { +__m256 Exp(__m256 a); +} // namespace detail +#endif + +namespace jitkernel { +namespace jit = platform::jit; + +/* VExp JitKernel */ +template +class VExpKernelImpl : public VExpKernel { + public: + explicit VExpKernelImpl(int d) : VExpKernel() { this->num_ = d; } + void Compute(const T* x, T* y) const override { + for (int i = 0; i < this->num_; ++i) { + y[i] = std::exp(x[i]); + } + } +}; + +#ifdef PADDLE_WITH_MKLML +#define MKL_FLOAT(isa, block) \ + template <> \ + void VExpKernelImpl::Compute(const float* x, float* y) \ + const { \ + platform::dynload::vsExp(this->num_, x, y); \ + } + +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VExpKernelImpl::Compute(const double* x, double* y) \ + const { \ + platform::dynload::vdExp(this->num_, x, y); \ + } +FOR_EACH_ISA(MKL_FLOAT, kLT8); +FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); +FOR_EACH_ISA(MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#endif + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VExpKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + _mm256_storeu_ps(y, detail::Exp(tmp)); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VExpKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + tmp0 = detail::Exp(tmp0); \ + tmp1 = detail::Exp(tmp1); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +#endif +// TODO(TJ): eq16 test and complete avx512 + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef MKL_FLOAT +#undef MKL_DOUBLE + +REGISTER_JITKERNEL(vexp, VExpKernel); + +/* VSigmoid JitKernel */ +template +class VSigmoidKernelImpl : public VSigmoidKernel { + public: + explicit VSigmoidKernelImpl(int d) : VSigmoidKernel() { + this->num_ = d; + vexp_ = KernelPool::Instance().template Get>(d); + } + void Compute(const T* x, T* y) const override { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < this->num_; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = static_cast(0) - y[i]; + } + vexp_->Compute(y, y); + for (int i = 0; i < this->num_; ++i) { + y[i] = static_cast(1) / (static_cast(1) + y[i]); + } + } + + private: + std::shared_ptr> vexp_; +}; + +#define INTRI_SIGMOID(tmp, min, max) \ + tmp = _mm256_max_ps(tmp, min); \ + tmp = _mm256_min_ps(tmp, max); \ + tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \ + tmp = detail::Exp(tmp); \ + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ + tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y, tmp); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_SIGMOID(tmp0, min, max); \ + INTRI_SIGMOID(tmp1, min, max); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#define INTRI_GT8LT16_FLOAT(isa) \ + template <> \ + VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ + : VSigmoidKernel() { \ + this->num_ = d; \ + this->end_ = AVX_FLOAT_BLOCK; \ + this->rest_ = d - this->end_; \ + vexp_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + } \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y, tmp); \ + const float min_ = SIGMOID_THRESHOLD_MIN; \ + const float max_ = SIGMOID_THRESHOLD_MAX; \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ + y[i] = 0.f - y[i]; \ + } \ + vexp_->Compute(y + this->end_, y + this->end_); \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = 1.f / (1.f + y[i]); \ + } \ + } + +#define INTRI_GT16_FLOAT(isa) \ + template <> \ + VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ + : VSigmoidKernel() { \ + this->num_ = d; \ + this->rest_ = d % AVX_FLOAT_BLOCK; \ + this->end_ = d - this->rest_; \ + vexp_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + } \ + template <> \ + void VSigmoidKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + INTRI_SIGMOID(tmp, min, max); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + const float min_ = SIGMOID_THRESHOLD_MIN; \ + const float max_ = SIGMOID_THRESHOLD_MAX; \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ + y[i] = 0.f - y[i]; \ + } \ + vexp_->Compute(y + this->end_, y + this->end_); \ + for (int i = this->end_; i < this->num_; ++i) { \ + y[i] = 1.f / (1.f + y[i]); \ + } \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +INTRI_GT8LT16_FLOAT(jit::avx); +INTRI_GT16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +// INTRI_GT8LT16_FLOAT(jit::avx2); +// INTRI_GT16_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +// INTRI_GT8LT16_FLOAT(jit::avx512f); +// INTRI_GT16_FLOAT(jit::avx512f); +#endif + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef INTRI_GT8LT16_FLOAT +#undef INTRI_GT16_FLOAT +#undef INTRI_VSIGMOID + +REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel); + +/* VTanh JitKernel */ +template +class VTanhKernelImpl : public VTanhKernel { + public: + explicit VTanhKernelImpl(int d) : VTanhKernel() { + this->num_ = d; + vscal_ = KernelPool::Instance().template Get>(d); + vsigmoid_ = KernelPool::Instance().template Get>(d); + vaddbias_ = KernelPool::Instance().template Get>(d); + } + void Compute(const T* x, T* y) const override { + vscal_->Compute(static_cast(2), x, y); + vsigmoid_->Compute(y, y); + vscal_->Compute(static_cast(2), y); + vaddbias_->Compute(static_cast(-1), y, y); + } + + private: + std::shared_ptr> vscal_; + std::shared_ptr> vsigmoid_; + std::shared_ptr> vaddbias_; +}; + +#define INTRI_VTANH(tmp) \ + tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \ + tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \ + tmp = detail::Exp(tmp); \ + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ + tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ + tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) + +#define INTRI8_FLOAT(isa) \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y, tmp); \ + } + +#define INTRI16_FLOAT(isa) \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_VTANH(tmp0); \ + INTRI_VTANH(tmp1); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ + } + +#define INTRI_GT8LT16_FLOAT(isa) \ + template <> \ + VTanhKernelImpl::VTanhKernelImpl(int d) \ + : VTanhKernel() { \ + this->num_ = d; \ + this->end_ = AVX_FLOAT_BLOCK; \ + this->rest_ = d - this->end_; \ + vscal_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + vsigmoid_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + vaddbias_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + } \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, \ + float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y, tmp); \ + x += AVX_FLOAT_BLOCK; \ + y += AVX_FLOAT_BLOCK; \ + vscal_->Compute(2.f, x, y); \ + vsigmoid_->Compute(y, y); \ + vscal_->Compute(2.f, y); \ + vaddbias_->Compute(-1.f, y, y); \ + } + +#define INTRI_GT16_FLOAT(isa) \ + template <> \ + VTanhKernelImpl::VTanhKernelImpl(int d) \ + : VTanhKernel() { \ + this->num_ = d; \ + this->rest_ = d % AVX_FLOAT_BLOCK; \ + this->end_ = d - this->rest_; \ + vscal_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + vsigmoid_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + vaddbias_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + } \ + template <> \ + void VTanhKernelImpl::Compute(const float* x, float* y) \ + const { \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + INTRI_VTANH(tmp); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + x += this->end_; \ + y += this->end_; \ + vscal_->Compute(2.f, x, y); \ + vsigmoid_->Compute(y, y); \ + vscal_->Compute(2.f, y); \ + vaddbias_->Compute(-1.f, y, y); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +INTRI16_FLOAT(jit::avx); +INTRI_GT8LT16_FLOAT(jit::avx); +INTRI_GT16_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +INTRI16_FLOAT(jit::avx2); +// maybe use avx at gt8lt16 and gt16 +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +INTRI16_FLOAT(jit::avx512f); +// maybe use avx at gt8lt16 and gt16 +#endif + +#undef INTRI8_FLOAT +#undef INTRI16_FLOAT +#undef INTRI_GT8LT16_FLOAT +#undef INTRI_GT16_FLOAT +#undef INTRI_VTANH + +REGISTER_JITKERNEL(vtanh, VTanhKernel); + +#undef JITKERNEL_NEW_ACT_IMPL + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..42a2b96fd945c516f8c26ca51ecb452345a9a86f --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc @@ -0,0 +1,308 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include +#include "paddle/fluid/operators/math/jit_kernel_macro.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/macros.h" + +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { +#ifdef __AVX__ +namespace detail { +__m256 Exp(__m256 a); +} // namespace detail +#endif + +namespace jitkernel { +namespace jit = platform::jit; + +#ifdef __AVX__ +typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type; + +class AVXAct { + public: + virtual ~AVXAct() = default; + virtual __m256 Compute(__m256 x) const = 0; +}; + +template +class AVXActImpl : public AVXAct { + public: + __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } +}; + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + __m256 ones = _mm256_set1_ps(1.0f); + x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); + x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); + x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); + x = detail::Exp(x); + x = _mm256_add_ps(ones, x); + return _mm256_div_ps(ones, x); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + __m256 ones = _mm256_set1_ps(1.0f); + x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); + x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); + x = detail::Exp(x); + x = _mm256_add_ps(ones, x); + x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); + return _mm256_sub_ps(x, ones); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + return _mm256_max_ps(x, _mm256_setzero_ps()); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + return x; +} +#endif + +template +static std::shared_ptr> GetActKernel( + const std::string& type, int n) { + if (type == "sigmoid") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "relu") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "tanh") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } else if (type == "identity" || type == "") { + return std::dynamic_pointer_cast>( + KernelPool::Instance().template Get>(n)); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} + +/* LSTM JitKernel */ +template +class LSTMKernelImpl : public LSTMKernel { + public: + explicit LSTMKernelImpl(const std::string& act_gate, + const std::string& act_cand, + const std::string& act_cell, int d) + : LSTMKernel() { + d_ = d; + d2_ = d * 2; + d3_ = d * 3; + act_gate_d3_ = GetActKernel(act_gate, d3_); + act_gate_d_ = GetActKernel(act_gate, d); + act_cand_d_ = GetActKernel(act_cand, d); + act_cell_d_ = GetActKernel(act_cell, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + vadd_d_ = KernelPool::Instance().template Get>(d); +#ifdef __AVX__ + auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { + if (type == "sigmoid") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "relu") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "tanh") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "identity" || type == "") { + return std::unique_ptr(new AVXActImpl()); + } + PADDLE_THROW("Not support type: %s", type); + }; + avx_act_gate_ = GetAVXAct(act_gate); + avx_act_cand_ = GetAVXAct(act_cand); + avx_act_cell_ = GetAVXAct(act_cell); +#endif + } + + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, + T* checked) const override { + // gates: W_ch, W_ih, W_fh, W_oh + act_gate_d3_->Compute(gates + d_, gates + d_); + + /* C_t = C_t-1 * fgated + cand_gated * igated */ + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, gates + d_); + vmul_d_->Compute(ct_1, gates + d2_, gates + d2_); + vadd_d_->Compute(gates + d_, gates + d2_, ct); + + /* H_t = act_cell(C_t) * ogated */ + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } + void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { + /* C_t = igated * cgated*/ + act_gate_d_->Compute(gates + d_, gates + d_); + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, ct); + /* H_t = act_cell(C_t) * ogated */ + act_gate_d_->Compute(gates + d3_, gates + d3_); + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } + + private: + int d_, d2_, d3_; + std::shared_ptr> act_gate_d3_, act_gate_d_, act_cand_d_, + act_cell_d_; + std::shared_ptr> vmul_d_; + std::shared_ptr> vadd_d_; +#ifdef __AVX__ + std::unique_ptr avx_act_gate_, avx_act_cand_, avx_act_cell_; +#endif +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + void LSTMKernelImpl::ComputeCtHt( \ + float* gates, const float* ct_1, float* ct, float* ht, \ + const float* wp_data, float* checked) const { \ + /* gates: W_ch, W_ih, W_fh, W_oh */ \ + __m256 c, i, f, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + f = _mm256_loadu_ps(gates + 16); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ + i = _mm256_loadu_ps(ct_1); \ + f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ + f = _mm256_add_ps(c, f); \ + _mm256_storeu_ps(ct, f); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ + } + +// TODO(TJ): optimize keq16 + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif + +/* Peephole JitKernel */ +template +class PeepholeKernelImpl : public LSTMKernel { + public: + explicit PeepholeKernelImpl(const std::string& act_gate, + const std::string& act_cand, + const std::string& act_cell, int d) + : LSTMKernel() { + d_ = d; + d2_ = d * 2; + d3_ = d * 3; + act_gate_d_ = GetActKernel(act_gate, d); + act_cand_d_ = GetActKernel(act_cand, d); + act_cell_d_ = GetActKernel(act_cell, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + vadd_d_ = KernelPool::Instance().template Get>(d); + vadd_d2_ = KernelPool::Instance().template Get>(d2_); + act_gate_d2_ = GetActKernel(act_gate, d2_); + } + + void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, + T* checked) const override { + /* get fgated and igated*/ + vmul_d_->Compute(wp_data, ct_1, checked); + vmul_d_->Compute(wp_data + d_, ct_1, checked + d_); + vadd_d2_->Compute(checked, gates + d_, gates + d_); + act_gate_d2_->Compute(gates + d_, gates + d_); + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, gates + d_); + vmul_d_->Compute(ct_1, gates + d2_, gates + d2_); + vadd_d_->Compute(gates + d_, gates + d2_, ct); + /* get ogated*/ + vmul_d_->Compute(wp_data + d2_, ct, gates + d_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + act_gate_d_->Compute(gates + d3_, gates + d3_); + /* H_t = act_cell(C_t) * ogated */ + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } + + void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { + /* C_t = igated * cgated*/ + act_gate_d_->Compute(gates + d_, gates + d_); + act_cand_d_->Compute(gates, gates); + vmul_d_->Compute(gates, gates + d_, ct); + /* get outgated, put W_oc * C_t on igated */ + vmul_d_->Compute(wp_data + d2_, ct, gates + d_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + /* H_t = act_cell(C_t) * ogated */ + act_gate_d_->Compute(gates + d3_, gates + d3_); + act_cell_d_->Compute(ct, gates + d2_); + vmul_d_->Compute(gates + d2_, gates + d3_, ht); + } + + private: + int d_, d2_, d3_; + std::shared_ptr> act_gate_d2_, act_gate_d_, act_cand_d_, + act_cell_d_; + std::shared_ptr> vmul_d_; + std::shared_ptr> vadd_d_, vadd_d2_; +}; + +#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> \ + KernelPool::Get, const std::string&, \ + const std::string&, const std::string&, int, bool>( \ + const std::string& act_gate, const std::string& act_cand, \ + const std::string& act_cell, int d, bool use_peephole) + +#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \ + (use_peephole ? "p" : "n") + +#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \ + if (use_peephole) { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>( \ + act_gate, act_cand, act_cell, d)); \ + } else { \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(act_gate, act_cand, \ + act_cell, d)); \ + } + +REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, + JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); + +#undef INTRI8_FLOAT +#undef JITKERNEL_DECLARE_LSTM +#undef JITKERNEL_KEY_LSTM +#undef JITKERNEL_NEW_LSTM_IMPL +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_macro.h b/paddle/fluid/operators/math/jit_kernel_macro.h new file mode 100644 index 0000000000000000000000000000000000000000..d8e55f2673560ff6afa34376b73275b57a8ceea1 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_macro.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace math { +namespace jitkernel { + +namespace jit = platform::jit; + +#define SEARCH_BLOCK(macro_, ker, dtype, isa) \ + if (d < AVX_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kLT8); \ + } else if (d == AVX_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ8); \ + } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kGT8LT16); \ + } else if (d == AVX512_FLOAT_BLOCK) { \ + macro_(ker, dtype, isa, kEQ16); \ + } else { \ + macro_(ker, dtype, isa, kGT16); \ + } + +#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \ + if (jit::MayIUse(jit::avx512f)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \ + } else if (jit::MayIUse(jit::avx2)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx2); \ + } else if (jit::MayIUse(jit::avx)) { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::avx); \ + } else { \ + SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \ + } + +#define JITKERNEL_DECLARE(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> \ + KernelPool::Get, int>(int d) + +#define JITKERNEL_KEY(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + +#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(d)) + +#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \ + marco_declare, macro_key, macro_impl) \ + marco_declare(ker_class, ker_dtype) { \ + std::string key = macro_key(ker_key, dtype_key); \ + if (kers_.find(key) == kers_.end()) { \ + std::shared_ptr> p; \ + SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \ + kers_.insert({key, std::dynamic_pointer_cast(p)}); \ + return p; \ + } \ + return std::dynamic_pointer_cast>( \ + kers_.at(key)); \ + } + +#define REGISTER_JITKERNEL(ker_key, ker_class) \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \ + JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \ + JITKERNEL_KEY, JITKERNEL_NEW_IMPL) + +#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \ + macro_impl) \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \ + macro_impl); \ + JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \ + macro_key, macro_impl) + +#define FOR_EACH_ISA(macro_, block) \ + macro_(jit::avx512f, block); \ + macro_(jit::avx2, block); \ + macro_(jit::avx, block); \ + macro_(jit::isa_any, block) + +#define FOR_EACH_BLOCK(macro_, isa) \ + macro_(isa, kLT8); \ + macro_(isa, kEQ8); \ + macro_(isa, kGT8LT16); \ + macro_(isa, kEQ16); \ + macro_(isa, kGT16) + +#define FOR_EACH_ISA_BLOCK(macro_) \ + FOR_EACH_BLOCK(macro_, jit::avx512f); \ + FOR_EACH_BLOCK(macro_, jit::avx2); \ + FOR_EACH_BLOCK(macro_, jit::avx); \ + FOR_EACH_BLOCK(macro_, jit::isa_any) + +} // namespace jitkernel +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..26590171bbeaa385ac09b04e5faf483924176598 --- /dev/null +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -0,0 +1,749 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include +#include // for exp +#include // for memcpy +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef __AVX__ +#include +#endif + +constexpr int repeat = 20000; + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; +} + +template +void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), + const T upper = static_cast(20.f)) { + static unsigned int seed = 100; + std::mt19937 rng(seed++); + std::uniform_real_distribution uniform_dist(0, 1); + for (int i = 0; i < n; ++i) { + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); + } +} + +void vrelu_ref(const int n, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] > 0.f ? x[i] : 0.f; + } +} + +#if defined __AVX__ || defined __AVX2__ +void vrelu_intri8(const int n, const float* x, float* y) { + __m256 tmp = _mm256_loadu_ps(x); + tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); + _mm256_storeu_ps(y, tmp); +} +#endif + +TEST(JitKernel, vrelu) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -10.f, 1.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vrelu_ref(d, x_data, zref_data); + } + auto trefe = GetCurrentUS(); +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vrelu_intri8(d, x_data, zref_data); + } + auto si1 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + } +#endif + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +void vaddbias_ref(const int n, const float a, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = x[i] + a; + } +} + +TEST(JitKernel, vaddbias) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -2.f, 2.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float a = 2.f; + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vaddbias_ref(d, a, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(a, x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +void vexp_ref(const int n, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = std::exp(x[i]); + } +} + +#ifdef PADDLE_WITH_MKLML +void vexp_mkl(const int n, const float* x, float* y) { + paddle::platform::dynload::vsExp(n, x, y); +} +#endif + +TEST(JitKernel, vexp) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 128, 256}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -2.f, 2.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vexp_ref(d, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vexp_mkl(d, x_data, zref_data); + } + auto tmkle = GetCurrentUS(); +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +inline float _sigmoid(float x) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (x < min) ? min : ((x > max) ? max : x); + return 1.f / (1.f + std::exp(-tmp)); +} + +void vsigmoid_ref(const int n, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = _sigmoid(x[i]); + } +} + +void vsigmoid_better( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VExpKernel>& vexp, + const int n, const float* x, float* y) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + for (int i = 0; i < n; ++i) { + y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); + y[i] = 0.f - y[i]; + } + vexp->Compute(y, y); + for (int i = 0; i < n; ++i) { + y[i] = 1.f / (1.f + y[i]); + } +} + +TEST(JitKernel, vsigmoid) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -2.f, 2.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const auto& vexp = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vsigmoid_better(vexp, d, x_data, zref_data); + } + auto tmkle = GetCurrentUS(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vsigmoid_ref(d, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +inline float _tanh(float x) { return 2.f * _sigmoid(2.f * x) - 1.f; } + +void vtanh_ref(const int n, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = _tanh(x[i]); + } +} + +void vtanh_better( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VScalKernel>& vscal, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VSigmoidKernel>& + vsigmoid, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VAddBiasKernel>& + vaddbias, + const int n, const float* x, float* y) { + vscal->Compute(2.f, x, y); + vsigmoid->Compute(y, y); + vscal->Compute(2.f, y); + vaddbias->Compute(-1.f, y, y); +} + +TEST(JitKernel, vtanh) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { + std::vector x(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data(), -2.f, 2.f); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const auto& vscal = + jit::KernelPool::Instance().template Get>(d); + const auto& vsigmoid = + jit::KernelPool::Instance().template Get>(d); + const auto& vaddbias = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data); + } + auto tmkle = GetCurrentUS(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vtanh_ref(d, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +void lstm_ctht_ref( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VSigmoidKernel>& + vsigmoid_3d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VExpKernel>& vexp_1, + const int d, float* gates, const float* ct_1, float* ct, float* ht) { + vsigmoid_3d->Compute(gates + d, gates + d); + vtanh_d->Compute(gates, gates); + const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + for (int k = 0; k < d; ++k) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[k] = ct_1[k] * f[k] + gates[k] * i[k]; + // H_t = act_cell(C_t) * ogated + float tmp = ct[k] * 2; + tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vexp_1->Compute(&tmp, &tmp); + tmp = 2.f / (1.f + tmp) - 1.f; + ht[k] = tmp * o[k]; + } +} + +void lstm_ctht_better( + const std::shared_ptr< + const paddle::operators::math::jitkernel::VSigmoidKernel>& + vsigmoid_3d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VTanhKernel>& vtanh_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VMulKernel>& vmul_d, + const std::shared_ptr< + const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, + const int d, float* gates, const float* ct_1, float* ct, float* ht) { + int d2 = d * 2; + vsigmoid_3d->Compute(gates + d, gates + d); + vtanh_d->Compute(gates, gates); + vmul_d->Compute(gates, gates + d, gates + d); + vmul_d->Compute(ct_1, gates + d2, gates + d2); + vadd_d->Compute(gates + d, gates + d2, ct); + /* H_t = act_cell(C_t) * ogated */ + vtanh_d->Compute(ct, gates + d2); + vmul_d->Compute(gates + d2, gates + d * 3, ht); +} + +TEST(JitKernel, lstm) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) { + int d4 = d * 4; + int d3 = d * 3; + std::vector x(d4), xref(d4); + std::vector ct_1(d), ct_tgt(d), ht_tgt(d); + std::vector ct_ref(d), ht_ref(d); + RandomVec(d4, x.data(), -2.f, 2.f); + RandomVec(d, ct_1.data(), -2.f, 2.f); + memcpy(xref.data(), x.data(), sizeof(float) * d4); + std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; + const auto& ker = + jit::KernelPool::Instance() + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate, act_cand, act_cell, d, false); + // below kernels are used to compute refer + const auto& vsigmoid_3d = + jit::KernelPool::Instance().template Get>( + d3); + const auto& vtanh_d = + jit::KernelPool::Instance().template Get>(d); + const auto& vexp_1 = + jit::KernelPool::Instance().template Get>(1); + const auto& vmul_d = + jit::KernelPool::Instance().template Get>(d); + const auto& vadd_d = + jit::KernelPool::Instance().template Get>(d); + + float* x_data = x.data(); + float* xref_data = xref.data(); + const float* ct_1_data = ct_1.data(); + float* ct_tgt_data = ct_tgt.data(); + float* ht_tgt_data = ht_tgt.data(); + float* ct_ref_data = ct_ref.data(); + float* ht_ref_data = ht_ref.data(); + // compute once to check correctness + lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, + ct_ref_data, ht_ref_data); + ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3); + EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3); + } + + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data, + ct_1_data, ct_ref_data, ht_ref_data); + } + auto tmkle = GetCurrentUS(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data, + ct_ref_data, ht_ref_data); + } + auto trefe = GetCurrentUS(); + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->ComputeCtHt(x_data, ct_1_data, ct_tgt_data, ht_tgt_data); + } + auto ttgte = GetCurrentUS(); + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, better(jit) takes: " << (tmkle - tmkls) / repeat + << " us, tgt takes: " << (ttgte - ttgts) / repeat; + } +} + +void vscal_ref(const int n, const float a, const float* x, float* y) { + for (int i = 0; i < n; ++i) { + y[i] = a * x[i]; + } +} +void vscal_inp_ref(const int n, const float a, float* x) { + for (int i = 0; i < n; ++i) { + x[i] = a * x[i]; + } +} +#if defined __AVX__ || defined __AVX2__ +void vscal_intri8(const int n, const float a, const float* x, float* y) { + __m256 tmp; + __m256 scalar = _mm256_set1_ps(a); + tmp = _mm256_loadu_ps(x); + tmp = _mm256_mul_ps(tmp, scalar); + _mm256_storeu_ps(y, tmp); +} +void vscal_inp_intri8(const int n, const float a, float* x) { + __m256 tmp; + __m256 scalar = _mm256_set1_ps(a); + tmp = _mm256_loadu_ps(x); + tmp = _mm256_mul_ps(tmp, scalar); + _mm256_storeu_ps(x, tmp); +} +#endif + +#ifdef PADDLE_WITH_MKLML +void vscal_inp_mkl(const int n, const float a, float* x) { + paddle::platform::dynload::cblas_sscal(n, a, x, 1); +} +#endif + +TEST(JitKernel, vscal) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d), y(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data()); + std::memcpy(y.data(), x.data(), sizeof(float) * d); + float a = 2.f; + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + float* y_data = y.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_ref(d, a, x_data, zref_data); + } + auto trefe = GetCurrentUS(); + auto trefs1 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_ref(d, a, y_data); + } + auto trefe1 = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_mkl(d, a, y_data); + } + auto tmkle = GetCurrentUS(); +#endif + +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_intri8(d, a, x_data, zref_data); + } + auto si1 = GetCurrentUS(); + auto si2 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vscal_inp_intri8(d, a, y_data); + } + auto si3 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat + << " us, inplace: " << (si3 - si2) / repeat; + } +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(a, x_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + auto ttgts1 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(a, y_data); + } + auto ttgte1 = GetCurrentUS(); + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat + << " us, inplace takes: " << (trefe1 - trefs1) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat + << "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +void vmul_ref(const int n, const float* x, const float* y, float* z) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +} + +#if defined __AVX__ || defined __AVX2__ +void vmul_intri8(const int n, const float* x, const float* y, float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_mul_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#endif + +#ifdef PADDLE_WITH_MKLML +void vmul_mkl(const int n, const float* x, const float* y, float* z) { + paddle::platform::dynload::vsMul(n, x, y, z); +} +#endif + +TEST(JitKernel, vmul) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d), y(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data()); + RandomVec(d, y.data()); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + const float* y_data = y.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vmul_ref(d, x_data, y_data, zref_data); + } + auto trefe = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vmul_mkl(d, x_data, y_data, zref_data); + } + auto tmkle = GetCurrentUS(); +#endif + +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vmul_intri8(d, x_data, y_data, zref_data); + } + auto si1 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + } +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, y_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +void vadd_ref(const int n, const float* x, const float* y, float* z) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + } +} + +#if defined __AVX__ || defined __AVX2__ +void vadd_intri8(const int n, const float* x, const float* y, float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_add_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#endif + +#ifdef PADDLE_WITH_MKLML +void vadd_mkl(const int n, const float* x, const float* y, float* z) { + paddle::platform::dynload::vsAdd(n, x, y, z); +} +#endif + +TEST(JitKernel, vadd) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d), y(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data()); + RandomVec(d, y.data()); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + const float* y_data = y.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_ref(d, x_data, y_data, zref_data); + } + auto trefe = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_mkl(d, x_data, y_data, zref_data); + } + auto tmkle = GetCurrentUS(); +#endif + +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_intri8(d, x_data, y_data, zref_data); + } + auto si1 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + } +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(x_data, y_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + +TEST(JitKernel, pool) { + namespace jit = paddle::operators::math::jitkernel; + const int frame_size = 4; + std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh"; + const auto& plstm1 = + jit::KernelPool::Instance() + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate, act_cand, act_cell, frame_size, false); + const auto& plstm2 = + jit::KernelPool::Instance() + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate, act_cand, act_cell, frame_size, false); + const auto& peephole = + jit::KernelPool::Instance() + .template Get, const std::string&, + const std::string&, const std::string&>( + act_gate, act_cand, act_cell, frame_size, true); + EXPECT_TRUE(plstm1 != peephole); + + const auto& pvmul_f = + jit::KernelPool::Instance().template Get>(4); + EXPECT_TRUE(std::dynamic_pointer_cast(plstm2) != + std::dynamic_pointer_cast(pvmul_f)); + + const auto& pvmul_d = + jit::KernelPool::Instance().template Get>(4); + EXPECT_TRUE(std::dynamic_pointer_cast(pvmul_f) != + std::dynamic_pointer_cast(pvmul_d)); + + const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulf4"); + EXPECT_EQ(pvmul_f, pvmul_from_key); + const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulf5"); + EXPECT_TRUE(pvmul_from_key2 == nullptr); +} diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 8e8baf49b2330e95ff1a868b0b0a03bc10d84484..08f57dd45ad76946cbcafb98a3414003ed9d67a9 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include -#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" namespace paddle { @@ -150,6 +151,45 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template +struct SelectedRowsSumTo { + void operator()(const platform::CPUDeviceContext& context, + const std::vector& input1, + const std::vector& input2_offsets, + framework::SelectedRows* input2) { + // Ensure all selected rows have the same height + size_t size = 0u; + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + auto& in_rows = (*iter)->rows(); + size += in_rows.end() - in_rows.begin(); + auto in1_height = (*iter)->height(); + PADDLE_ENFORCE_EQ(in1_height, input2->height()); + } + // concat rows + std::vector in2_rows; + in2_rows.reserve(in2_rows.size() + size); + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + const framework::Vector& in_rows = (*iter)->rows(); + in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end()); + } + input2->set_rows(in2_rows); + + auto* in2_value = input2->mutable_value(); + auto* in2_data = in2_value->data(); + auto blas = math::GetBlas(context); + size_t offset = 0u; + for (size_t i = 0u; i != input1.size(); ++i) { + auto& in_value = input1[i]->value(); + const auto* in_data = in_value.data(); + offset += input2_offsets[i]; + blas.VCOPY(in_value.numel(), in_data, in2_data + offset); + } + } +}; + +template struct SelectedRowsSumTo; +template struct SelectedRowsSumTo; + template struct SelectedRowsAddToTensor { void operator()(const platform::CPUDeviceContext& context, @@ -207,35 +247,45 @@ struct MergeAdd { const framework::SelectedRows& input, framework::SelectedRows* output) { framework::SelectedRows& out = *output; - auto input_rows = input.rows(); - std::set row_set(input_rows.begin(), input_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); + std::vector input_rows(input.rows()); - auto input_width = input.value().dims()[1]; - out.set_rows(merge_rows); + std::map> merge_row_map; + for (size_t i = 0; i < input_rows.size(); ++i) { + merge_row_map[input_rows[i]].push_back(i); + } + + std::vector merge_rows(merge_row_map.size()); + size_t idx = 0; + int64_t input_width = input.value().dims()[1]; out.set_height(input.height()); - out.mutable_value()->mutable_data( + + T* out_data = out.mutable_value()->mutable_data( framework::make_ddim( {static_cast(merge_rows.size()), input_width}), context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); - - auto* out_data = out.mutable_value()->data(); - auto* input_data = input.value().data(); - - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = FindPos(merge_rows, input_rows[i]); - for (int64_t j = 0; j < input_width; j++) { - out_data[out_i * input_width + j] += input_data[i * input_width + j]; + const T* in_data = input.value().data(); + + for (auto& row_pair : merge_row_map) { + auto* out_ptr = out_data + idx * input_width; + auto& rows = row_pair.second; + merge_rows[idx] = row_pair.first; + ++idx; + // rows.size() is always larger than 0 + std::memcpy(out_ptr, in_data + rows[0] * input_width, + sizeof(T) * input_width); + + for (size_t i = 1; i < rows.size(); ++i) { + auto* in_ptr = in_data + rows[i] * input_width; + for (int64_t j = 0; j < input_width; ++j) { + out_ptr[j] += in_ptr[j]; + } } } + + out.set_rows(merge_rows); } }; -template struct MergeAdd; -template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index aa419f74fcd2a53cdd734ec270bc154b78c9f2ff..900be86f91c6658a5265189a6745316c6471209e 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -12,8 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#include +#include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #define INLINE_FOR2(sizei, sizej) \ @@ -49,6 +55,15 @@ struct SelectedRowsAddTo { const int64_t input2_offset, framework::SelectedRows* input2); }; +// input2 = [all input in input1] + input2 +template +struct SelectedRowsSumTo { + void operator()(const DeviceContext& context, + const std::vector& input1, + const std::vector& input2_offsets, + framework::SelectedRows* input2); +}; + // input2 = input1 + input2 template struct SelectedRowsAddToTensor { @@ -70,6 +85,104 @@ struct MergeAdd { framework::SelectedRows* output); }; +template <> +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; + std::vector input_rows(input.rows()); + + std::map> merge_row_map; + for (size_t i = 0; i < input_rows.size(); ++i) { + merge_row_map[input_rows[i]].push_back(i); + } + + std::vector merge_rows(merge_row_map.size()); + size_t idx = 0; + int64_t input_width = input.value().dims()[1]; + out.set_height(input.height()); + + auto* out_data = out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + auto* in_data = input.value().data(); + + auto blas = GetBlas(context); + for (auto& row_pair : merge_row_map) { + auto* out_ptr = out_data + idx * input_width; + auto& rows = row_pair.second; + merge_rows[idx] = row_pair.first; + ++idx; + // rows.size() is always larger than 0 + blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr); + + for (size_t i = 1; i < rows.size(); ++i) { + blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr); + } + } + + out.set_rows(merge_rows); + } +}; + +template <> +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; + std::vector input_rows(input.rows()); + + std::map> merge_row_map; + for (size_t i = 0; i < input_rows.size(); ++i) { + merge_row_map[input_rows[i]].push_back(i); + } + + std::vector merge_rows(merge_row_map.size()); + size_t idx = 0; + int64_t input_width = input.value().dims()[1]; + out.set_height(input.height()); + + auto* out_data = out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + auto* in_data = input.value().data(); + + auto blas = GetBlas(context); + for (auto& row_pair : merge_row_map) { + auto* out_ptr = out_data + idx * input_width; + auto& rows = row_pair.second; + merge_rows[idx] = row_pair.first; + ++idx; + // rows.size() is always larger than 0 + blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr); + + for (size_t i = 1; i < rows.size(); ++i) { + blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr); + } + } + + out.set_rows(merge_rows); + } +}; + template struct Add { framework::SelectedRows operator()(const DeviceContext& context, diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index 70bed820ee58885861fa8c5535c931f258625572..835589356042b44c9fa5988aed726434fd66910a 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -219,3 +219,174 @@ TEST(selected_rows_functor, cpu_add_to) { // row9: 2.0 + 3.0 EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0); } + +TEST(selected_rows_functor, cpu_merge_add_float) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + functor; + int64_t height = 10; + int64_t row_numel = 10; + + std::vector rows{0, 4, 4, 7}; + std::unique_ptr selected_rows{ + new paddle::framework::SelectedRows(rows, height)}; + auto* in_value = selected_rows->mutable_value(); + in_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows.size()), row_numel}), + cpu_place); + functor(ctx, in_value, 1.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + merge_add_functor(ctx, *selected_rows, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + + auto* out_data = output->value().data(); + + EXPECT_EQ(out_data[0 * row_numel], 1.0); + EXPECT_EQ(out_data[1 * row_numel], 2.0); + EXPECT_EQ(out_data[2 * row_numel], 1.0); +} + +TEST(selected_rows_functor, cpu_merge_add_int) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + functor; + int64_t height = 10; + int64_t row_numel = 10; + + std::vector rows{0, 4, 4, 7}; + std::unique_ptr selected_rows{ + new paddle::framework::SelectedRows(rows, height)}; + auto* in_value = selected_rows->mutable_value(); + in_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows.size()), row_numel}), + cpu_place); + functor(ctx, in_value, 1); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + merge_add_functor(ctx, *selected_rows, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + + auto* out_data = output->value().data(); + + EXPECT_EQ(out_data[0 * row_numel], 1); + EXPECT_EQ(out_data[1 * row_numel], 2); + EXPECT_EQ(out_data[2 * row_numel], 1); +} +TEST(selected_rows_functor, cpu_sum_to) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + functor; + int64_t height = 10; + int64_t row_numel = 10; + std::vector rows1{0, 4, 7}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + cpu_place); + functor(ctx, in1_value, 1.0); + std::vector rows2{0, 5, 7, 9}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + cpu_place); + functor(ctx, in2_value, 2.0); + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + auto* out_value = output->mutable_value(); + // simplely concat two SelectedRows + out_value->mutable_data(paddle::framework::make_ddim({7, 10}), + cpu_place); + paddle::operators::math::SelectedRowsSumTo + sum_to_functor; + sum_to_functor(ctx, std::vector( + {selected_rows1.get(), selected_rows2.get()}), + std::vector({0, in1_value->numel()}), output.get()); + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + auto& out_rows = output->rows(); + // input1 rows + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + // input2 rows + EXPECT_EQ(out_rows[3], 0); + EXPECT_EQ(out_rows[4], 5); + EXPECT_EQ(out_rows[5], 7); + EXPECT_EQ(out_rows[6], 9); + auto* out_data = output->value().data(); + // input1 value + EXPECT_EQ(out_data[0 * row_numel + 0], 1.0); + EXPECT_EQ(out_data[0 * row_numel + 8], 1.0); + EXPECT_EQ(out_data[1 * row_numel + 1], 1.0); + EXPECT_EQ(out_data[2 * row_numel + 6], 1.0); + // input2 value + EXPECT_EQ(out_data[3 * row_numel + 3], 2.0); + EXPECT_EQ(out_data[3 * row_numel + 8], 2.0); + EXPECT_EQ(out_data[4 * row_numel + 4], 2.0); + EXPECT_EQ(out_data[5 * row_numel + 7], 2.0); + EXPECT_EQ(out_data[6 * row_numel + 9], 2.0); + std::unique_ptr tensor1{ + new paddle::framework::Tensor()}; + tensor1->mutable_data( + paddle::framework::make_ddim({height, row_numel}), cpu_place); + functor(ctx, tensor1.get(), 3.0); + paddle::operators::math::SelectedRowsAddToTensor< + paddle::platform::CPUDeviceContext, float> + add_to_tensor_functor; + add_to_tensor_functor(ctx, *output, tensor1.get()); + auto* tensor1_data = tensor1->data(); + // row0: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor1_data[0 * row_numel + 0], 6.0); + // row1: 3.0 + EXPECT_EQ(tensor1_data[1 * row_numel + 1], 3.0); + // row4 : 1.0 + 3.0 + EXPECT_EQ(tensor1_data[4 * row_numel + 6], 4.0); + // row5: 2.0 + 3.0 + EXPECT_EQ(tensor1_data[5 * row_numel + 7], 5.0); + // row6: 3.0 + EXPECT_EQ(tensor1_data[6 * row_numel + 1], 3.0); + // row7: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor1_data[7 * row_numel + 3], 6.0); + // row9: 2.0 + 3.0 + EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0); +} diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 69318a6598c8c69eceab7216df6382537153d34f..235b5405fb7d016f4bd8c738f75b303522183116 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence_pooling.h" #include + +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_pooling.h" namespace paddle { namespace operators { @@ -180,6 +182,7 @@ class SequencePoolFunctor { } auto lod = input.lod()[0]; auto& place = *context.eigen_device(); + auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { Tensor in_t = input.Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -191,7 +194,14 @@ class SequencePoolFunctor { if (pooltype == "AVERAGE") { out_e.device(place) = in_e.mean(Eigen::array({{0}})); } else if (pooltype == "SUM") { - out_e.device(place) = in_e.sum(Eigen::array({{0}})); + if (h > 0) { + const T* in_data = in_t.data(); + T* out_data = out_t.mutable_data(context.GetPlace()); + blas.VCOPY(w, in_data, out_data); + for (int64_t r = 1; r != h; ++r) { + blas.AXPY(w, 1., in_data + r * w, out_data); + } + } } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); @@ -223,6 +233,7 @@ class SequencePoolGradFunctor { } auto lod = in_grad->lod()[0]; auto& place = *context.eigen_device(); + auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { auto in_g_t = in_grad->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -237,7 +248,11 @@ class SequencePoolGradFunctor { if (pooltype == "AVERAGE") { in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); } else if (pooltype == "SUM") { - in_g_e.device(place) = (out_g_e).broadcast(bcast); + const T* out_g_data = out_g_t.data(); + T* in_g_data = in_g_t.mutable_data(context.GetPlace()); + for (int r = 0; r != h; ++r) { + blas.VCOPY(w, out_g_data, in_g_data + r * w); + } } else if (pooltype == "SQRT") { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 5f43c5810812260c4384349bdb709716c9a182f5..c8079a99fb8c8e05144c5390794db7757e74f6ae 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel { "Input(velocity) of Momentum should not be null."); PADDLE_ENFORCE(ctx->HasInput("LearningRate"), "Input(LearningRate) of Momentum should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(ParamOut) of Momentum should not be null."); diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/momentum_op.cu index a3932db1f3a50305d585cd3d5e86fa1b527df78b..5dc920c70979ad5c3d4fb3acc8d2dacaffe386c0 100644 --- a/paddle/fluid/operators/momentum_op.cu +++ b/paddle/fluid/operators/momentum_op.cu @@ -46,6 +46,17 @@ template class MomentumOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + const auto* grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE(grad_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Grad").front(), grad_var->Type().name()); + auto param_out = ctx.Output("ParamOut"); auto velocity_out = ctx.Output("VelocityOut"); auto param = ctx.Input("Param"); diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 264726040fb566a52b8c0cdee0a1524197d2a675..40073d21b7186ad319e4988e667750c267581300 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -23,6 +23,12 @@ template class MomentumOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + auto param_out = ctx.Output("ParamOut"); auto velocity_out = ctx.Output("VelocityOut"); auto param = ctx.Input("Param"); diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 97c36a83fc5eff421725d05f66fca05f5169d1bb..ab25628d45699dbcfc1fc5792958bae9e42e72a3 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -397,6 +397,24 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase { } }; +class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + framework::BlockDesc *sub_block = + boost::get(op_desc.GetAttr(kParallelBlock)); + for (auto &out_vars : op_desc.Outputs()) { + for (auto &out_var : out_vars.second) { + auto &var = block->FindRecursiveOrCreateVar(out_var); + auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var); + if (sub_var.GetType() != var.GetType()) { + var.SetType(sub_var.GetType()); + } + } + } + } +}; + } // namespace operators } // namespace paddle @@ -404,4 +422,5 @@ REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp, paddle::operators::ParallelDoOpProtoMaker, paddle::operators::ParallelDoGradOpDescMaker); REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp, - paddle::operators::ParallelDoGradOpShapeInference); + paddle::operators::ParallelDoGradOpShapeInference, + paddle::operators::ParallelDoGradOpVarTypeInference); diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h index 28cc91a5ed5d74994e5b960a0a4dd3c6a5e6cdcc..51b980acb5a08d431d96a3a92479dec09119c27e 100644 --- a/paddle/fluid/operators/reader/blocking_queue.h +++ b/paddle/fluid/operators/reader/blocking_queue.h @@ -31,8 +31,8 @@ class BlockingQueue { // is a workaround and a simplified version of framework::Channel as it // doesn't support GPU and it implements on buffered blocking queue. public: - explicit BlockingQueue(size_t capacity) - : capacity_(capacity), closed_(false) { + explicit BlockingQueue(size_t capacity, bool speed_test_mode = false) + : capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) { PADDLE_ENFORCE_GT( capacity_, 0, "The capacity of a reader::BlockingQueue must be greater than 0."); @@ -72,7 +72,9 @@ class BlockingQueue { if (!queue_.empty()) { PADDLE_ENFORCE_NOT_NULL(elem); *elem = queue_.front(); - queue_.pop_front(); + if (LIKELY(!speed_test_mode_)) { + queue_.pop_front(); + } send_cv_.notify_one(); return true; } else { @@ -114,6 +116,7 @@ class BlockingQueue { private: size_t capacity_; + bool speed_test_mode_; bool closed_; std::deque queue_; diff --git a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h index 4f7cfc24ec035349f3c85e84d876ad9b5b5493a6..3f041ff7e4e32b407729a22aab25d3aab199fee0 100644 --- a/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h +++ b/paddle/fluid/operators/reader/lod_tensor_blocking_queue.h @@ -33,8 +33,9 @@ class LoDTensorBlockingQueue { private: LoDTensorBlockingQueue(size_t capacity, - const std::vector& dims) - : queue_(capacity), dims_(dims) {} + const std::vector& dims, + bool speed_test_mode = false) + : queue_(capacity, speed_test_mode), dims_(dims) {} public: bool Push(const std::vector& lod_tensor_vec) { @@ -69,11 +70,12 @@ class LoDTensorBlockingQueue { class LoDTensorBlockingQueueHolder { public: - void InitOnce(size_t capacity, const std::vector& dims) { + void InitOnce(size_t capacity, const std::vector& dims, + bool speed_test_mode = false) { PADDLE_ENFORCE( queue_ == nullptr, "LoDTensorBlockingQueueHolder::InitOnce() can only be called once"); - queue_.reset(new LoDTensorBlockingQueue(capacity, dims)); + queue_.reset(new LoDTensorBlockingQueue(capacity, dims, speed_test_mode)); } inline const std::shared_ptr& GetQueue() const { diff --git a/paddle/fluid/operators/reader/reader_blocking_queue_test.cc b/paddle/fluid/operators/reader/reader_blocking_queue_test.cc index 7d1b381d56c8cdc1e79e594b18c1a1ed59ab5284..bd7ac64b2fce2452744e4756b149ee7f291d38aa 100644 --- a/paddle/fluid/operators/reader/reader_blocking_queue_test.cc +++ b/paddle/fluid/operators/reader/reader_blocking_queue_test.cc @@ -217,3 +217,27 @@ TEST(BlockingQueue, MyClassTest) { q.Receive(&b); EXPECT_EQ(a.val_, b.val_); } + +TEST(BlockingQueue, speed_test_mode) { + size_t queue_size = 10; + BlockingQueue q1(queue_size, false); + for (size_t i = 0; i < queue_size; ++i) { + q1.Send(i); + } + size_t b; + for (size_t i = 0; i < queue_size; ++i) { + q1.Receive(&b); + EXPECT_EQ(b, i); + } + EXPECT_EQ(q1.Size(), 0); + + BlockingQueue q2(queue_size, true); + for (size_t i = 0; i < queue_size; ++i) { + q2.Send(i); + } + for (size_t i = 0; i < queue_size; ++i) { + q2.Receive(&b); + EXPECT_EQ(b, 0); + } + EXPECT_EQ(q2.Size(), queue_size); +} diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index d72f85f2c44db2fa887732cfc05e1376a6a79e4a..500d86fec33830fc2cfb0412f1f2c7780d08eb02 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input. 3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while -Attr(shape) still should be set correctly to gurantee shape inference in +Attr(shape) still should be set correctly to gurantee shape inference in compile-time. )DOC"); @@ -259,7 +259,6 @@ class Reshape2Op : public ReshapeOp { : ReshapeOp(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { - ReshapeOp::InferShape(ctx); PADDLE_ENFORCE(ctx->HasOutput("XShape"), "Output(XShape) of ReshapeOp should not be null."); const auto &x_dims = ctx->GetInputDim("X"); @@ -270,6 +269,8 @@ class Reshape2Op : public ReshapeOp { } ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); ctx->ShareLoD("X", /*->*/ "XShape"); + + ReshapeOp::InferShape(ctx); } }; diff --git a/paddle/fluid/operators/rmsprop_op.cc b/paddle/fluid/operators/rmsprop_op.cc index 2f773f222e50a440801b06a4fd997bf237b34772..f06f87e61d3a4d1fc8b864b9dd84e697fb12a006 100644 --- a/paddle/fluid/operators/rmsprop_op.cc +++ b/paddle/fluid/operators/rmsprop_op.cc @@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel { "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), "Input(Moment) of RmspropOp should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); diff --git a/paddle/fluid/operators/rmsprop_op.h b/paddle/fluid/operators/rmsprop_op.h index 25ed32c5ebb2ff5be962ac1e3e38c970623d705c..797cd45fdcdbd5c3567d1676f37e148304ee6e2d 100644 --- a/paddle/fluid/operators/rmsprop_op.h +++ b/paddle/fluid/operators/rmsprop_op.h @@ -13,66 +13,254 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; template using EigenVector = framework::EigenVector; +template +struct DenseRmspropGradFunctor { + inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; } + + const T *grad_; +}; + +template +struct SparseRmspropGradFunctor { + inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows, + int64_t row_numel, int64_t row_count) + : grad_(grad), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { + auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_); + return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0; + } + + const T *grad_; + const int64_t *rows_; + int64_t row_numel_; + int64_t row_count_; +}; + +template +struct UncenteredRmspropFunctor { + UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho, + T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + } + + T *param_; + T *ms_; + T *mom_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + +template +struct CenteredRmspropFunctor { + CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr, + T rho, T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + mean_grad_(mean_grad), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g; + T mom_out = momentum_ * mom_[idx] + + lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + mean_grad_[idx] = mg_out; + } + + T *param_; + T *ms_; + T *mom_; + T *mean_grad_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + template class RmspropOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out = ctx.Output("ParamOut"); - auto* moment_out = ctx.Output("MomentOut"); - auto* mean_square_out = ctx.Output("MeanSquareOut"); + void Compute(const framework::ExecutionContext &ctx) const override { + using LoDTensor = framework::LoDTensor; + auto *grad_var = ctx.InputVar("Grad"); + auto *param_out = ctx.Output("ParamOut"); + auto *moment_out = ctx.Output("MomentOut"); + auto *mean_square_out = ctx.Output("MeanSquareOut"); - auto grad = ctx.Input("Grad"); + auto epsilon = static_cast(ctx.Attr("epsilon")); + auto rho = static_cast(ctx.Attr("decay")); + auto momentum = static_cast(ctx.Attr("momentum")); + bool centered = ctx.Attr("centered"); - param_out->mutable_data(ctx.GetPlace()); - moment_out->mutable_data(ctx.GetPlace()); - mean_square_out->mutable_data(ctx.GetPlace()); + auto &p_tensor = *ctx.Input("Param"); + auto &ms_tensor = *ctx.Input("MeanSquare"); + auto &lr_tensor = *ctx.Input("LearningRate"); + auto &mom_tensor = *ctx.Input("Moment"); - float epsilon = ctx.Attr("epsilon"); - float rho = ctx.Attr("decay"); - float momentum = ctx.Attr("momentum"); - bool centered = ctx.Attr("centered"); + PADDLE_ENFORCE_EQ(&p_tensor, param_out, + "Param and ParamOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&mom_tensor, moment_out, + "Moment and MomentOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out, + "MeanSquare and MeanSquareOut must be the same Tensor"); + + auto &dev_ctx = ctx.template device_context(); + size_t limit = static_cast(ms_tensor.numel()); + + if (grad_var->IsType()) { + auto &grad_tensor = grad_var->Get(); + + if (std::is_same::value) { + auto &place = + *ctx.template device_context().eigen_device(); + auto lr_value = lr_tensor.data()[0]; + + auto p = EigenVector::Flatten(p_tensor); + auto ms = EigenVector::Flatten(ms_tensor); + auto g = EigenVector::Flatten(grad_tensor); + auto mom = EigenVector::Flatten(mom_tensor); + + auto p_out = EigenVector::Flatten(*param_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); + + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto mg = EigenVector::Flatten(mg_tensor); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + auto mg_out = EigenVector::Flatten(*mean_grad_out); + + mg_out.device(place) = rho * mg + (1 - rho) * g; + mom_out.device(place) = + momentum * mom + + lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt(); + } else { + mom_out.device(place) = + momentum * mom + lr_value * g / (ms_out + epsilon).sqrt(); + } + p_out.device(place) = p - mom_out; + } else { + DenseRmspropGradFunctor grad_func(grad_tensor.data()); + platform::ForRange for_range(dev_ctx, limit); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), + lr_tensor.data(), rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } + } + } else if (grad_var->IsType()) { + auto &grad = grad_var->Get(); + auto *merged_grad = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + + math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, grad, merged_grad); + + platform::ForRange for_range(dev_ctx, limit); + const int64_t *rows; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + rows = merged_grad->rows().CUDAData(ctx.GetPlace()); + } else { +#endif + rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + auto &merged_tensor = merged_grad->value(); + int64_t row_count = merged_grad->rows().size(); + int64_t row_numel = merged_tensor.numel() / row_count; + SparseRmspropGradFunctor grad_func(merged_tensor.data(), rows, + row_numel, row_count); - auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); - auto g = EigenVector::Flatten(*grad); - auto mom = EigenVector::Flatten(*ctx.Input("Moment")); - - auto p_out = EigenVector::Flatten(*param_out); - auto mom_out = EigenVector::Flatten(*moment_out); - auto ms_out = EigenVector::Flatten(*mean_square_out); - auto& place = *ctx.template device_context().eigen_device(); - - Eigen::DSizes grad_dsize(static_cast(grad->numel())); - - ms_out.device(place) = rho * ms + (1 - rho) * g * g; - if (centered) { - auto mg = EigenVector::Flatten(*ctx.Input("MeanGrad")); - auto* mean_grad_out = ctx.Output("MeanGradOut"); - mean_grad_out->mutable_data(ctx.GetPlace()); - auto mg_out = EigenVector::Flatten(*mean_grad_out); - - mg_out.device(place) = rho * mg + (1 - rho) * g; - mom_out.device(place) = momentum * mom + - lr.broadcast(grad_dsize) * g / - (ms_out - mg_out.square() + epsilon).sqrt(); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } } else { - mom_out.device(place) = - momentum * mom + - lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); + PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient"); } - p_out.device(place) = p - mom_out; } }; diff --git a/paddle/fluid/operators/sequence_concat_op.cc b/paddle/fluid/operators/sequence_concat_op.cc index 397a3182953e3f1afaeadeff6d53a4f22fb95d26..3234b60861da3d0c6a8434eb11fd0488a95e171f 100644 --- a/paddle/fluid/operators/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_concat_op.cc @@ -90,11 +90,13 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel, paddle::framework::DefaultGradOpDescMaker); template using Kernel = op::SeqConcatKernel; -REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel, Kernel); +REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel, Kernel, + Kernel); + REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel, op::SeqConcatGradShapeInferer); template using GradKernel = op::SeqConcatGradKernel; REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel, - GradKernel); + GradKernel, GradKernel); diff --git a/paddle/fluid/operators/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_unpad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3a0762b9a4d3e080d5d6d10b249e0bd81980b95 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_unpad_op.h" + +namespace paddle { +namespace operators { + +class SequenceUnpadOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceUnpadOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Length"), + "Input(Length) of SequenceUnpadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceUnpadOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "The rank of Input(X) can't be less than 2."); + + auto len_dims = ctx->GetInputDim("Length"); + PADDLE_ENFORCE(len_dims.size() == 2 && len_dims[1] == 1, + "The shape of Input(Length) should be [batch_size, 1]."); + PADDLE_ENFORCE( + len_dims[0] == x_dims[0], + "Input(X) and Input(Length) should have the same first dimension."); + + int64_t out_dim_0 = -1; + if (ctx->IsRuntime()) { + out_dim_0 = x_dims[0] * x_dims[1]; + } + + std::vector out_dims_vec{out_dim_0}; + if (x_dims.size() == 2) { + out_dims_vec.push_back(1); + } else { + for (size_t i = 2; i < x_dims.size(); ++i) { + out_dims_vec.push_back(x_dims[i]); + } + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class SequenceUnpadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LoDTensor, default LoDTensor) Input tensor which " + "contains the padded sequences with equal length."); + AddInput("Length", + "(LoDTensor) The input tensor which specifies the actual ength of " + "sequences after unpadding."); + AddOutput( + "Out", + "(LoDTensor) The output tensor which contains unpadded sequences."); + AddComment(R"DOC( + Sequence Unpad Operator + + This operator removes the padding data in the input sequences and convert + them into sequences with actual length as output, identitied by lod + information. + + Example: + + Given input tensor Input(X): + X.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0], + [ 6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0]], +` + in which there are 3 sequences padded to length 5, and the acutal length + specified by Input(Length): + + Length.data = [[2], [3], [4]], + + after unpadding, Output(Out) will be: + + Out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]] + Out.lod = [[0, 2, 5, 9]] + + )DOC"); + } +}; + +class SequenceUnpadGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceUnpadGradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequenceUnpadGradOp should not be null."); + + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp, + ops::SequenceUnpadOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_unpad, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_unpad_grad, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_unpad_op.cu b/paddle/fluid/operators/sequence_unpad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..75248372237ec2cb23122f6b16e64f6ce750ebf9 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_unpad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + sequence_unpad, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel); +REGISTER_OP_CUDA_KERNEL( + sequence_unpad_grad, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_unpad_op.h b/paddle/fluid/operators/sequence_unpad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ebe3118b985bdfd41ca55e8c572047aa87502ff4 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_padding.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +template +class SequenceUnpadOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_t = ctx.Input("X"); + auto* len_t = ctx.Input("Length"); + auto* out_t = ctx.Output("Out"); + out_t->mutable_data(ctx.GetPlace()); + + const int64_t* seq_len_ptr = nullptr; + if (platform::is_gpu_place(ctx.GetPlace())) { + LoDTensor seq_len_cpu; + seq_len_cpu.Resize(len_t->dims()); + seq_len_ptr = seq_len_cpu.mutable_data(platform::CPUPlace()); + framework::TensorCopy(*len_t, platform::CPUPlace(), + ctx.template device_context(), + &seq_len_cpu); + } else { + seq_len_ptr = len_t->data(); + } + + size_t batch_size = x_t->dims()[0]; + std::vector out_lod0(batch_size + 1, 0); + for (size_t i = 0; i < batch_size; ++i) { + out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i]; + } + + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out_t->set_lod(out_lod); + + std::vector out_dims_vec{static_cast(out_lod0.back())}; + if (x_t->dims().size() == 2) { + out_dims_vec.push_back(1); + } else { + for (size_t i = 2; i < x_t->dims().size(); ++i) { + out_dims_vec.push_back(x_t->dims()[i]); + } + } + out_t->Resize(framework::make_ddim(out_dims_vec)); + + int64_t padded_length = x_t->dims()[1]; + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *x_t, out_t, + padded_length, 0, false, math::kBatchLengthWidth); + } +}; + +template +class SequenceUnpadGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_x = ctx.Output(framework::GradVarName("X")); + if (d_x) { + const auto* d_out = ctx.Input(framework::GradVarName("Out")); + const auto* x_t = ctx.Input("X"); + d_x->mutable_data(ctx.GetPlace()); + + int padded_length = x_t->dims()[1]; + + LoDTensor zero_pads; + zero_pads.Resize({1, 1}); + zero_pads.mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, &zero_pads, static_cast(0)); + + math::PaddingLoDTensorFunctor()( + ctx.template device_context(), *d_out, d_x, zero_pads, + padded_length, 0, false, math::kBatchLengthWidth); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc index fef230e42d07a5ed73b7a7a6ab682694675bb9d2..411a126bc8e2b3a8d25f436489c13970568ccae4 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/sgd_op.cc @@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); return framework::OpKernelType(data_type, ctx.device_context()); } @@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel { class SGDOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto input_var = op_desc.Input("Param")[0]; - for (auto& out_var : op_desc.Output("ParamOut")) { - if (block->FindRecursiveOrCreateVar(input_var).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::SELECTED_ROWS); - } else { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::LOD_TENSOR); + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto input_var_n = op_desc.Input("Param")[0]; + auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType(); + PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || + in_var_type == framework::proto::VarType::LOD_TENSOR, + "The input Var's type should be LoDtensor or SelectedRows," + " but the received var(%s)'s type is %s", + input_var_n, in_var_type); + + for (auto &out_var_n : op_desc.Output("ParamOut")) { + auto &out_var = block->FindRecursiveOrCreateVar(out_var_n); + if (out_var.GetType() != in_var_type) { + out_var.SetType(in_var_type); } } } diff --git a/paddle/fluid/operators/sgd_op.cu b/paddle/fluid/operators/sgd_op.cu index 243609075713305a90dc162991166ba24d54e835..d3f4eba3b24ec1ac0328ef270256cdf3abe499db 100644 --- a/paddle/fluid/operators/sgd_op.cu +++ b/paddle/fluid/operators/sgd_op.cu @@ -56,6 +56,12 @@ template class SGDOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), param_var->Type().name()); + auto* param = ctx.Input("Param"); auto* param_out = ctx.Output("ParamOut"); auto* learning_rate = ctx.Input("LearningRate"); diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc index d854e2803975543b51c50ea2bc173322d3c3ca5e..1e8708f2648d7dd3c10319bd0a4be193d2458d53 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -148,7 +148,7 @@ struct TruncatedNormal { T operator()(T value) const { auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean; } }; diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cu b/paddle/fluid/operators/truncated_gaussian_random_op.cu index ad2a9021bfe344d838dff2040b3fb9371274e218..5a3510babe4d57b9e80f0e7898df98033834ca15 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cu +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cu @@ -42,7 +42,7 @@ struct TruncatedNormal { rng.discard(n); T value = dist(rng); auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean; } }; @@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = static_cast(context.Attr("seed")); if (seed == 0) { std::random_device rd; diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 763bb403588d13c15271d26b09813dddf3a5dd8c..aa907595cb7cf165974caa69fe8eb0370471732d 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -23,14 +23,14 @@ namespace operators { template class CPUUniformRandomKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* tensor = nullptr; + void Compute(const framework::ExecutionContext &ctx) const override { + framework::Tensor *tensor = nullptr; auto out_var = ctx.OutputVar("Out"); if (out_var->IsType()) { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { auto shape = ctx.Attr>("shape"); - auto* selected_rows = out_var->GetMutable(); + auto *selected_rows = out_var->GetMutable(); tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); selected_rows->mutable_rows()->reserve(shape[0]); @@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { "uniform_random_op's output only" "supports SelectedRows and LoDTensor"); } - T* data = tensor->mutable_data(ctx.GetPlace()); + T *data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { @@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UniformRandomOp should not be null."); PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto& shape = ctx->Attrs().Get>("shape"); + auto &shape = ctx->Attrs().Get>("shape"); std::vector temp; temp.reserve(shape.size()); for (auto dim : shape) { @@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); @@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max]. class UniformRandomOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Out").front(); - if (block->FindRecursiveOrCreateVar(out_var_name).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::SELECTED_ROWS); - } else { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::LOD_TENSOR); + auto var_data_type = static_cast( + boost::get(op_desc.GetAttr("dtype"))); + + auto out_var = block->FindRecursiveOrCreateVar(out_var_name); + if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) { + out_var.SetType(framework::proto::VarType::LOD_TENSOR); } + out_var.SetDataType(var_data_type); } }; diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 2880c09263f10e9c624e11b77188171f48d9db28..b5f472d20f40fa182a4aa55ff384b0954e4ba9e3 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) { return cpu.has(Cpu::tAVX); case avx2: return cpu.has(Cpu::tAVX2); - case avx512_common: + case avx512f: return cpu.has(Cpu::tAVX512F); case avx512_core: return true && cpu.has(Cpu::tAVX512F) && cpu.has(Cpu::tAVX512BW) && diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index 30c8fbcfce92a8b06a175ddf198cde572f72b2a4..6810a1651a14cdb2080af846b21cad242b70bf35 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -43,7 +43,7 @@ typedef enum { sse42, avx, avx2, - avx512_common, + avx512f, avx512_core, avx512_core_vnni, avx512_mic, diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index dfc079e986e93c7f02f17b299e5d6293edbedd05..4286242b2a93d7046e7349a99d1d1a09dca09113 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -198,9 +198,9 @@ class CudnnHolder { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place), cudnn_holder_(nullptr) { SetDeviceId(place_.device); - compute_capability = GetCUDAComputeCapability(place_.device); - multi_process = GetCUDAMultiProcessors(place_.device); - max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); + compute_capability_ = GetCUDAComputeCapability(place_.device); + multi_process_ = GetCUDAMultiProcessors(place_.device); + max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); @@ -211,6 +211,16 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } + driver_version_ = GetCUDADriverVersion(place_.device); + runtime_version_ = GetCUDARuntimeVersion(place_.device); + + LOG(INFO) << "device: " << place_.device + << ", CUDA Capability: " << compute_capability_ + << ", Driver Version: " << driver_version_ / 1000 << "." + << (driver_version_ % 100) / 10 + << ", Runtime Version: " << runtime_version_ / 1000 << "." + << (runtime_version_ % 100) / 10; + callback_manager_.reset(new StreamCallbackManager(stream_)); } @@ -232,11 +242,11 @@ void CUDADeviceContext::Wait() const { } int CUDADeviceContext::GetComputeCapability() const { - return compute_capability; + return compute_capability_; } int CUDADeviceContext::GetMaxPhysicalThreadCount() const { - return multi_process * max_threads_per_mp; + return multi_process_ * max_threads_per_mp_; } Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 79539195157d74d4d757edee5e008cbb76c93ee2..e1ff1a1746952de5aa4bead361b50af4e99bc9bc 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -135,9 +135,11 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; cublasHandle_t cublas_handle_; - int compute_capability; - int multi_process; - int max_threads_per_mp; + int compute_capability_; + int runtime_version_; + int driver_version_; + int multi_process_; + int max_threads_per_mp_; mutable std::mutex mtx_; diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index f04395a8ac00f33501008aa12f22773ddda9b138..a251bfcd9914422cb6300adbbcdef3dfa79f441c 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -130,6 +130,13 @@ struct EOFException : public std::exception { #define UNLIKELY(condition) (condition == 0) #endif +#if !defined(_WIN32) +#define LIKELY(condition) __builtin_expect(static_cast(condition), 1) +#else +// there is no equivalent intrinsics in msvc. +#define LIKELY(condition) (condition != 0) +#endif + template inline typename std::enable_if::type throw_on_error( bool stat, const Args&... args) { diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index f599e7fbc886a60394ae4690e4160275b55b8596..8fff9844db738dbd6508569a8aaeed044e445e5f 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -46,6 +46,24 @@ int GetCUDAComputeCapability(int id) { return device_prop.major * 10 + device_prop.minor; } +int GetCUDARuntimeVersion(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int runtime_version = 0; + PADDLE_ENFORCE(cudaRuntimeGetVersion(&runtime_version), + "cudaRuntimeGetVersion failed in " + "paddle::platform::cudaRuntimeGetVersion"); + return runtime_version; +} + +int GetCUDADriverVersion(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + int driver_version = 0; + PADDLE_ENFORCE(cudaDriverGetVersion(&driver_version), + "cudaDriverGetVersion failed in " + "paddle::platform::GetCUDADriverVersion"); + return driver_version; +} + int GetCUDAMultiProcessors(int id) { PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); int count; diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index f4640d3eaa2165c35e8e14690d83e9e7e7168c0b..be44158431ff80a41f7fdf4dfd4d070667f2ac63 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -29,6 +29,12 @@ int GetCUDADeviceCount(); //! Get the compute capability of the ith GPU (format: major * 10 + minor) int GetCUDAComputeCapability(int i); +//! Get the runtime version of the ith GPU +int GetCUDARuntimeVersion(int id); + +//! Get the driver version of the ith GPU +int GetCUDADriverVersion(int id); + //! Get the MultiProcessors of the ith GPU. int GetCUDAMultiProcessors(int i); diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 4c99f4be321160caf0ee2f89a655bdfb933408e3..ab91ca5345047f3053eb8771e6a265d2a3011f85 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { platform::SetNumThreads(FLAGS_paddle_num_threads); #endif - if (platform::jit::MayIUse(platform::jit::avx512_common)) { + if (platform::jit::MayIUse(platform::jit::avx512f)) { #ifndef __AVX512F__ LOG(WARNING) << "AVX512F is available, Please re-compile on local machine"; #endif diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 652a6ec7a4e2e823b28f39b449570cd375e88e18..a35147da90e87af85308431fd7dbe965bb1fd1d7 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -276,7 +276,7 @@ struct EventItem { // Print results void PrintProfiler(const std::vector>& events_table, const std::string& sorted_domain, const size_t name_width, - const size_t data_width, double total) { + const size_t data_width, bool merge_thread) { // Output header information std::cout << "\n------------------------->" << " Profiling Report " @@ -292,6 +292,10 @@ void PrintProfiler(const std::vector>& events_table, PADDLE_THROW("Invalid profiler state", g_state); } + if (merge_thread) { + std::cout << "Note! This Report merge all thread info into one." + << std::endl; + } std::cout << "Place: " << place << std::endl; std::cout << "Time unit: ms" << std::endl; std::cout << "Sorted by " << sorted_domain @@ -312,8 +316,7 @@ void PrintProfiler(const std::vector>& events_table, << std::setw(data_width) << event_item.min_time << std::setw(data_width) << event_item.max_time << std::setw(data_width) << event_item.ave_time - << std::setw(data_width) << event_item.total_time / total - << std::endl; + << std::setw(data_width) << event_item.ratio << std::endl; } } std::cout << std::endl; @@ -321,8 +324,10 @@ void PrintProfiler(const std::vector>& events_table, // Parse the event list and output the profiling report void ParseEvents(const std::vector>& events, + bool merge_thread, EventSortingKey sorted_by = EventSortingKey::kDefault) { if (g_state == ProfilerState::kDisabled) return; + if (merge_thread && events.size() < 2) return; std::string sorted_domain; std::function sorted_func; @@ -361,34 +366,55 @@ void ParseEvents(const std::vector>& events, sorted_domain = "event first end time"; } + const std::vector>* analyze_events; + std::vector> merged_events_list; + if (merge_thread) { + std::vector merged_events; + for (size_t i = 0; i < events.size(); ++i) { + for (size_t j = 0; j < events[i].size(); ++j) { + merged_events.push_back(events[i][j]); + } + } + merged_events_list.push_back(merged_events); + analyze_events = &merged_events_list; + } else { + analyze_events = &events; + } + std::vector> events_table; size_t max_name_width = 0; - double total = 0.; // the total time - for (size_t i = 0; i < events.size(); i++) { + for (size_t i = 0; i < (*analyze_events).size(); i++) { + double total = 0.; // the total time in one thread std::list pushed_events; std::vector event_items; std::unordered_map event_idx; - for (size_t j = 0; j < events[i].size(); j++) { - if (events[i][j].type() == EventType::kPushRange) { - pushed_events.push_back(events[i][j]); - } else if (events[i][j].type() == EventType::kPopRange) { + for (size_t j = 0; j < (*analyze_events)[i].size(); j++) { + if ((*analyze_events)[i][j].type() == EventType::kPushRange) { + pushed_events.push_back((*analyze_events)[i][j]); + } else if ((*analyze_events)[i][j].type() == EventType::kPopRange) { std::list::reverse_iterator rit = pushed_events.rbegin(); while (rit != pushed_events.rend() && - rit->name() != events[i][j].name()) { + rit->name() != (*analyze_events)[i][j].name()) { ++rit; } if (rit != pushed_events.rend()) { double event_time = (g_state == ProfilerState::kCUDA || g_state == ProfilerState::kAll) - ? rit->CudaElapsedMs(events[i][j]) - : rit->CpuElapsedMs(events[i][j]); + ? rit->CudaElapsedMs((*analyze_events)[i][j]) + : rit->CpuElapsedMs((*analyze_events)[i][j]); total += event_time; - std::string event_name = - "thread" + std::to_string(rit->thread_id()) + "::" + rit->name(); - max_name_width = std::max(max_name_width, event_name.size()); + std::string event_name; + if (merge_thread) { + event_name = rit->name(); + max_name_width = std::max(max_name_width, event_name.size()); + } else { + event_name = "thread" + std::to_string(rit->thread_id()) + "::" + + rit->name(); + max_name_width = std::max(max_name_width, event_name.size()); + } if (event_idx.find(event_name) == event_idx.end()) { event_idx[event_name] = event_items.size(); @@ -413,7 +439,7 @@ void ParseEvents(const std::vector>& events, pushed_events.erase((++rit).base()); } else { LOG(WARNING) << "Cannot find the push marker of event \'" - << events[i][j].name() + << (*analyze_events)[i][j].name() << "\', which will be ignored in profiling report."; } } @@ -421,6 +447,7 @@ void ParseEvents(const std::vector>& events, // average time for (auto& item : event_items) { item.ave_time = item.total_time / item.calls; + item.ratio = item.total_time / total; } // sort if (sorted_by != EventSortingKey::kDefault) { @@ -438,7 +465,8 @@ void ParseEvents(const std::vector>& events, } // Print report - PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12, total); + PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12, + merge_thread); } void DisableProfiler(EventSortingKey sorted_key, @@ -449,7 +477,8 @@ void DisableProfiler(EventSortingKey sorted_key, Mark("_stop_profiler_", nullptr); std::vector> all_events = GetAllEvents(); - ParseEvents(all_events, sorted_key); + ParseEvents(all_events, true, sorted_key); + ParseEvents(all_events, false, sorted_key); ResetProfiler(); DeviceTracer* tracer = GetDeviceTracer(); if (tracer->IsEnabled()) { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 311cd944603e9bdfefef4daa3a9c690df5b30235..339a7c98c6a2bba2cd46790cecc169ef447c63ce 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -57,6 +57,10 @@ limitations under the License. */ #include "pybind11/stl.h" +DEFINE_bool(reader_queue_speed_test_mode, false, + "If set true, the queue.pop will only get data from queue but not " + "remove the data from queue for speed testing"); + // disable auto conversion to list in Python PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray); @@ -157,7 +161,50 @@ PYBIND11_PLUGIN(core) { .def("_get_double_element", TensorGetElement) .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); }); - py::class_(m, "LoDTensor") + py::class_(m, "LoDTensor", R"DOC( + LoDTensor is a Tensor with optional LoD information. + + np.array(lod_tensor) can convert LoDTensor to numpy array. + lod_tensor.lod() can retrieve the LoD information. + + LoD is short for Level of Details and is usually used for varied sequence + length. You can skip the following comment if you don't need optional LoD. + + For example: + A LoDTensor X can look like the example below. It contains 2 sequences. + The first has length 2 and the second has length 3, as described by x.lod. + + The first tensor dimension 5=2+3 is calculated from LoD if it's available. + It means the total number of sequence element. In X, each element has 2 + columns, hence [5, 2]. + + x.lod = [[2, 3]] + x.data = [[1, 2], [3, 4], + [5, 6], [7, 8], [9, 10]] + x.shape = [5, 2] + + LoD can have multiple levels (for example, a paragraph can have multiple + sentences and a sentence can have multiple words). In the following + LodTensor Y, the lod_level is 2. It means there are 2 sequence, the + first sequence length is 2 (has 2 sub-sequences), the second one's + length is 1. The first sequence's 2 sub-sequences have length 2 and 2, + respectively. And the second sequence's 1 sub-sequence has length 3. + + y.lod = [[2 1], [2 2 3]] + y.shape = [2+2+3, ...] + + Note: + In above description, LoD is length-based. In Paddle internal + implementation, lod is offset-based. Hence, internally, + y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based + equivlent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]). + + Sometimes LoD is called recursive_sequence_length to be more + self-explanatory. In this case, it must be length-based. Due to history + reasons. when LoD is called lod in public API, it might be offset-based. + Users should be careful about it. + + )DOC") .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) .def("__init__", @@ -337,7 +384,8 @@ All parameter, weight, gradient are variables in Paddle. return make_ddim(shape); }); auto *holder = var.GetMutable(); - holder->InitOnce(capacity, dims); + holder->InitOnce(capacity, dims, + FLAGS_reader_queue_speed_test_mode); return holder->GetQueue(); }, py::return_value_policy::copy); @@ -624,16 +672,17 @@ All parameter, weight, gradient are variables in Paddle. ExecutionStrategy allows the user to more preciously control how to run the program in ParallelExecutor by setting the property. - The available properties include: - use_cuda (bool): Whether to use CUDA or not. Default True. - num_threads (int): The number of threads that used to run the - operators in ParallelExecutor. If it is not set, it will be - set in ParallelExecutor according to the device count. - Default 0. - allow_op_delay (bool): Whether to delay the communication operators - to run. Default False. - num_iteration_per_drop_scope (int): how many iterations between - the two dropping local scopes. Default 100. + Examples: + .. code-block:: python + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 4 + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + exec_strategy=exec_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) )DOC"); @@ -643,19 +692,34 @@ All parameter, weight, gradient are variables in Paddle. [](const ExecutionStrategy &self) { return self.num_threads_; }, [](ExecutionStrategy &self, size_t num_threads) { self.num_threads_ = num_threads; - }) + }, + R"DOC(The type is INT, num_threads represents the size of thread pool that + used to run the operators of the current program in ParallelExecutor. + If :math:`num\_threads=1`, all the operators will execute one by one, + but the order maybe difference between iterations. + If it is not set, it will be set in ParallelExecutor according to the + device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU, + :math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor. + if it is not set, ParallelExecutor will get the cpu count by calling + `multiprocessing.cpu_count()`. Default 0.)DOC") .def_property( "use_cuda", [](const ExecutionStrategy &self) { return self.use_cuda_; }, [](ExecutionStrategy &self, bool use_cuda) { self.use_cuda_ = use_cuda; - }) + }) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may + // make user confuse, because ParallelExecutor has a parameter named + // 'use_cuda' too, in current implementation, ParallelExecutor's + // 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'. .def_property( "allow_op_delay", [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](ExecutionStrategy &self, bool allow_op_delay) { self.allow_op_delay_ = allow_op_delay; - }) + }, + R"DOC(The type is BOOL, allow_op_delay represents whether to delay the + communication operators to run, it may make the execution faster. + Note that in some models, allow_op_delay may cause program hang. Default False.)DOC") .def_property( "num_iteration_per_drop_scope", [](const ExecutionStrategy &self) { @@ -663,7 +727,19 @@ All parameter, weight, gradient are variables in Paddle. }, [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) { self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope; - }); + }, + R"DOC(The type is INT, num_iteration_per_drop_scope indicates how + many iterations to clean up the temp variables which + is generated during execution. It may make the execution faster, + because the temp variable's shape maybe the same between two iterations. Default 100. + + NOTES: + 1. If you fetch data when calling the 'run', the ParallelExecutor + will clean up the temp variables at the end of the current iteration. + 2. In some NLP model, it may cause the GPU memory is insufficient, + in this case, you should reduce `num_iteration_per_drop_scope`. + )DOC"); + exec_strategy.def_property( "use_experimental_executor", [](const ExecutionStrategy &self) { @@ -678,20 +754,17 @@ All parameter, weight, gradient are variables in Paddle. BuildStrategy allows the user to more preciously control how to build the SSA Graph in ParallelExecutor by setting the property. - The available properties include: - reduce_strategy (str): There are two reduce strategies, 'AllReduce' - and 'Reduce'. If you want that all parameters will be optimized - on all devices, you can choose 'AllReduce'; if you choose - 'Reduce', all parameters will be evenly allocated to different - devices for optimization, and then broadcast the optimized - parameter to other devices. Default 'AllReduce'. - gradient_scale_strategy (str): There are two ways of defining loss@grad, - 'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor - sets the loss@grad according to the number of devices. If you want - to customize loss@grad, you can choose 'Customized'. - Default 'CoeffNumDevice'. - debug_graphviz_path (str): Whether to write the SSA Graph to file in the - form of graphviz. It is useful for debugging. Default "". + Examples: + .. code-block:: python + + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + build_strategy=build_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) )DOC"); py::enum_(build_strategy, "ReduceStrategy") @@ -710,31 +783,51 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.reduce_; }, [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) { self.reduce_ = strategy; - }) + }, + R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor, + 'AllReduce' and 'Reduce'. If you want that all the parameters' + optimization are done on all devices independently, you should choose 'AllReduce'; + if you choose 'Reduce', all the parameters' optimization will be evenly distributed + to different devices, and then broadcast the optimized parameter to other devices. + In some models, `Reduce` is faster. Default 'AllReduce'. )DOC") .def_property( "gradient_scale_strategy", [](const BuildStrategy &self) { return self.gradient_scale_; }, [](BuildStrategy &self, BuildStrategy::GradientScaleStrategy strategy) { self.gradient_scale_ = strategy; - }) + }, + R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in + ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default, + ParallelExecutor sets the :math:`loss@grad` according to the number of devices. + If you want to customize :math:`loss@grad`, you can choose 'Customized'. + Default 'CoeffNumDevice'.)DOC") .def_property( "debug_graphviz_path", [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }) + }, + R"DOC(The type is STR, debug_graphviz_path indicate the path that + writing the SSA Graph to file in the form of graphviz, you. + It is useful for debugging. Default "")DOC") .def_property( "enable_data_balance", [](const BuildStrategy &self) { return self.enable_data_balance_; }, - [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }) - .def_property("fuse_elewise_add_act_ops", - [](const BuildStrategy &self) { - return self.fuse_elewise_add_act_ops_; - }, - [](BuildStrategy &self, bool b) { - self.fuse_elewise_add_act_ops_ = b; - }) + [](BuildStrategy &self, bool b) { + self.enable_data_balance_ = b; + }) // FIXME(chengudo): enable_data_balance seems not important + .def_property( + "fuse_elewise_add_act_ops", + [](const BuildStrategy &self) { + return self.fuse_elewise_add_act_ops_; + }, + [](BuildStrategy &self, bool b) { + self.fuse_elewise_add_act_ops_ = b; + }, + R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether + to fuse elementwise_add_op and activation_op, + it may make the execution faster. Default False)DOC") .def("_create_passes_from_strategy", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(); diff --git a/paddle/fluid/train/demo/README.md b/paddle/fluid/train/demo/README.md index 41b01d33828f750f67bba5f82cb7ed6fe4d4ea0a..191da20669e185d819ec5eed55427461cc0b10e4 100644 --- a/paddle/fluid/train/demo/README.md +++ b/paddle/fluid/train/demo/README.md @@ -15,7 +15,7 @@ cmake .. -DFLUID_INSTALL_DIR=$PADDLE_LIB \ -DWITH_MKL=OFF \ -DWITH_MKLDNN=OFF make -j8 -make -j8 inference_lib_dist +make -j8 fluid_lib_dist ``` ### step 2. generate program desc diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index e133323ae420ba68d90215767ab940aed744acd6..da6f5ca1586a570fb548d7b987330a8d58156e24 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -648,25 +648,25 @@ function gen_capi_package() { fi } -function gen_fluid_inference_lib() { +function gen_fluid_lib() { mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build if [[ ${WITH_C_API:-OFF} == "OFF" && ${WITH_INFERENCE:-ON} == "ON" ]] ; then cat < self.max_norm: + output = self.max_norm * y_np / norm + else: + output = y_np + self.assertTrue( + np.allclose( + np.array(out_tensor), output, atol=1e-5, equal_nan=False)) + + def test_clip_by_norm_with_selected_ros(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for place in places: + self.check_with_place(place) + + def config_test_case(self): + self.max_norm = 1.0 + self.max_relative_error = 0.006 + self.grad_shape = (4, 1) + self.grad_clipped_shape = (3, 1) + self.grad_rows = [0, 0, 1, 2] + self.grad_clipped_rows = [0, 1, 2] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py index e971f29db42a7c1a2394505a8ece3d2fd6b347e9..11095f23591edc41a82962149a52096fa17cfb93 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -25,7 +25,11 @@ class TestDistSimnetBowDense2x2(TestDistBase): self._enforce_place = "CPU" def test_simnet_bow(self): - need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'} + need_envs = { + "IS_DISTRIBUTED": '0', + "IS_SPARSE": '0', + 'IS_SELF_CONTAINED_LR': '1' + } self.check_with_place( "dist_simnet_bow.py", delta=1e-5, @@ -39,7 +43,11 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): self._enforce_place = "CPU" def test_simnet_bow(self): - need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'} + need_envs = { + "IS_DISTRIBUTED": '0', + "IS_SPARSE": '0', + 'IS_SELF_CONTAINED_LR': '1' + } self.check_with_place( "dist_simnet_bow.py", delta=100, @@ -53,7 +61,11 @@ class TestDistSimnetBowSparse2x2(TestDistBase): self._enforce_place = "CPU" def test_simnet_bow(self): - need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'} + need_envs = { + "IS_DISTRIBUTED": '0', + "IS_SPARSE": '1', + 'IS_SELF_CONTAINED_LR': '1' + } self.check_with_place( "dist_simnet_bow.py", delta=1e-5, @@ -67,7 +79,11 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): self._enforce_place = "CPU" def test_simnet_bow(self): - need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'} + need_envs = { + "IS_DISTRIBUTED": '0', + "IS_SPARSE": '1', + 'IS_SELF_CONTAINED_LR': '1' + } self.check_with_place( "dist_simnet_bow.py", delta=100, @@ -75,5 +91,59 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): need_envs=need_envs) +class TestDistSimnetBow2x2LookupTableSync(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._enforce_place = "CPU" + + def test_simnet_bow(self): + need_envs = { + "IS_DISTRIBUTED": '1', + "IS_SPARSE": '1', + 'IS_SELF_CONTAINED_LR': '1' + } + self.check_with_place( + "dist_simnet_bow.py", + delta=1e-5, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSimnetBow2x2LookupTableAsync(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._enforce_place = "CPU" + + def test_simnet_bow(self): + need_envs = { + "IS_DISTRIBUTED": '1', + "IS_SPARSE": '1', + 'IS_SELF_CONTAINED_LR': '1' + } + self.check_with_place( + "dist_simnet_bow.py", + delta=100, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._enforce_place = "CPU" + + def test_simnet_bow(self): + need_envs = { + "IS_DISTRIBUTED": '1', + "IS_SPARSE": '1', + 'IS_SELF_CONTAINED_LR': '0' + } + self.check_with_place( + "dist_simnet_bow.py", + delta=1e-5, + check_error_log=False, + need_envs=need_envs) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..70ca521d3387ac11cd41d8496b4d094667232d4c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py @@ -0,0 +1,218 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_lstm_op import lstm, ACTIVATION + + +def fc(x, w, b): + return np.dot(x, w) + b + + +def fused_embedded_fc_lstm( + ids, # T x 1 + lod, # 1 x N + embeddings=None, # Dict_size x M + wx=None, # M x 4D + bx=None, # 1 x 4D + h0=None, # N x D + c0=None, # N x D + w_h=None, # D x 4D + w_b=None, # 1 x 4D + w_c=None, # 1 x 3D + is_reverse=False, + act_gate=None, + act_cell=None, + act_cand=None): + # Make a lookup for embeddings and pass result into lstm reference + T = ids.shape[0] + M = embeddings.shape[1] + x = embeddings[ids].reshape([T, M]) + return lstm( + fc(x, wx, bx), lod, h0, c0, w_h, w_b, w_c, is_reverse, act_gate, + act_cell, act_cand) + + +class TestFusionLSTMOp(OpTest): + def set_conf(self): + pass + + def setUp(self): + self.op_type = 'fused_embedding_fc_lstm' + self.lod = [[2, 3, 5, 4]] + self.M = 8 # Embedding size + self.D = 16 # Hidden size + self.dict_size = 18 + self.has_initial_state = False + self.use_peepholes = False + self.is_reverse = False + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.set_conf() + + T = sum(self.lod[0]) + bs = len(self.lod[0]) + + # this is the weight of fc + wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32') + # this is the bias of fc + bx = np.random.normal(size=(1, 4 * self.D)).astype('float32') + + if self.use_peepholes: + b = np.random.normal(size=(1, 7 * self.D)).astype('float32') + else: + b = np.random.normal(size=(1, 4 * self.D)).astype('float32') + w_b = np.copy(b[:, 0:4 * self.D]) + w_c = b[:, 4 * self.D:] if self.use_peepholes else None + + # low is 0 , high is voc_size - 1 + ids = np.random.randint( + low=0, high=self.dict_size - 1, size=(T, 1)).astype("int64") + # embeddings as they were trained , so each entry is of M size + embeddings = np.random.random( + (self.dict_size, self.M)).astype("float32") + + # multiply embeddings via Weights + fc_embeddings = np.dot(embeddings, wx) + + # bias should be manually added into the bias of this fused embedding fc LSTM + b[0, 0:4 * self.D] += bx[0, :] + combined_biases = b[:, 0:4 * self.D] + # So let broadcast it , so they can be added + ones = np.ones([self.dict_size, 1]) + broadcasted_biases = np.dot(ones, combined_biases) + # Sum biases with Wx*embeddings + fc_embeddings += broadcasted_biases + + if self.has_initial_state: + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') + else: + h0 = np.zeros((bs, self.D)).astype('float32') + c0 = np.zeros((bs, self.D)).astype('float32') + + wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32') + + h, c = fused_embedded_fc_lstm( + ids, self.lod, embeddings, wx, bx, h0, c0, wh, w_b, w_c, + self.is_reverse, ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], ACTIVATION[self.act_cand]) + + self.inputs = { + 'Ids': (ids, self.lod), + 'Embeddings': fc_embeddings, + 'WeightH': wh, + 'Bias': b + } + + if self.has_initial_state: + self.inputs['H0'] = h0 + self.inputs['C0'] = c0 + + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + } + self.attrs = { + 'use_peepholes': self.use_peepholes, + 'is_reverse': self.is_reverse, + 'gate_activation': self.act_gate, + 'cell_activation': self.act_cell, + 'candidate_activation': self.act_cand + } + + def test_check_output(self): + for use_seq in {True, False}: + self.attrs['use_seq'] = use_seq + self.check_output() + + +class TestFusionLSTMOpInit(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + + +class TestFusionLSTMOpReverse(TestFusionLSTMOp): + def set_conf(self): + self.is_reverse = True + + +class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpMD1(TestFusionLSTMOp): + def set_conf(self): + self.M = 36 + self.D = 8 + + +class TestFusionLSTMOpMD2(TestFusionLSTMOp): + def set_conf(self): + self.M = 8 + self.D = 8 + + +class TestFusionLSTMOpMD3(TestFusionLSTMOp): + def set_conf(self): + self.M = 15 + self.D = 3 + + +class TestFusionLSTMOpBS1(TestFusionLSTMOp): + def set_conf(self): + self.lod = [[3]] + self.D = 16 + + +class TestFusionLSTMOpPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + + +class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.is_reverse = True + + +class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.lod = [[2]] + self.D = 8 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 1d8d0b55f0c5d7cffa01a100847bdf48b6d7023d..dc70477ebe1cfbffd207ebb4bbf9d9f39893d79e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -194,6 +194,14 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1)) print(str(program)) + def test_sequence_unpad(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[10, 5], dtype='float32') + length = layers.data(name='length', shape=[1], dtype='int64') + self.assertIsNotNone(layers.sequence_unpad(x=x, length=length)) + print(str(program)) + def test_lstm_unit(self): program = Program() with program_guard(program): @@ -406,6 +414,19 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_sequence_slice(self): + program = Program() + with program_guard(program): + import numpy as np + seqs = layers.data( + name='x', shape=[10, 5], dtype='float32', lod_level=1) + offset = layers.assign(input=np.array([[0, 1]]).astype('int32')) + length = layers.assign(input=np.array([[2, 1]]).astype('int32')) + out = layers.sequence_slice( + input=seqs, offset=offset, length=length) + self.assertIsNotNone(out) + print(str(program)) + def test_lod_reset(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index 70848e4e2239e2be160bb0c1a28a5aecd01a87dc..eb12bc741767340a3e7e3580a8b95065d4267693 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -19,33 +19,76 @@ import unittest import numpy as np import paddle.fluid.core as core from paddle.fluid.op import Operator +import paddle.fluid as fluid + + +def create_selected_rows_and_tensor(scope, place, height, row_num, + embedding_size): + sr = scope.var("@selected_rows@").get_selected_rows() + tensor = scope.var("grad").get_tensor() + + rows = np.random.random_integers( + low=0, high=height - 1, size=[row_num, ]).astype('int64') + sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32') + + sr.set_height(height) + sr.set_rows(rows) + sr.get_tensor().set(sr_val, place) + + tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32') + for i in range(row_num): + row = rows[i] + tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :] + + tensor.set(tensor_val, place) + return tensor_val, sr_val class TestBase(unittest.TestCase): - def setup(self, centered, epsilon=1e-6): + def setup(self, + place, + is_sparse, + centered, + size, + row_num=None, + epsilon=1e-6): np.random.seed(5) # fix seed + self.scope = fluid.global_scope() + self.place = place + self.param_name = "param" - self.param = np.random.random((123, 321)).astype("float32") + self.param = np.random.random(size).astype("float32") self.mean_square_name = "mean_square" - self.mean_square = np.random.random((123, 321)).astype("float32") + self.mean_square = np.random.uniform( + low=1, high=2, size=size).astype("float32") self.mean_grad_name = "mean_grad" - self.mean_grad = np.random.random((123, 321)).astype("float32") + self.mean_grad = np.random.random(size).astype("float32") self.lr_name = "lr" self.learning_rate = np.array([0.01]).astype("float32") self.grad_name = "grad" - self.grad = np.random.random((123, 321)).astype("float32") + + self.is_sparse = is_sparse + if self.is_sparse: + self.grad_sr_name = "@selected_rows@" + self.grad, self.grad_sr = create_selected_rows_and_tensor( + self.scope, place, size[0], row_num, size[1]) + else: + self.grad = np.random.random(size).astype("float32") + grad_tensor = self.scope.var(self.grad_name).get_tensor() + grad_tensor.set(self.grad, place) self.moment_name = "moment" - self.moment = np.zeros((123, 321)).astype("float32") + self.moment = np.random.uniform( + low=0, high=1, size=size).astype("float32") self.epsilon = epsilon self.decay = 0.9 - self.momentum = 0.0 + self.momentum = 0.1 self.centered = centered self.ms_out = self.decay * self.mean_square + (1 - self.decay @@ -61,118 +104,122 @@ class TestBase(unittest.TestCase): self.param_out = self.param - self.moment_out - def check(self, - actual_t, - expect_t, - place, - out_name, - atol=1e-5, - equal_nan=False): - self.assertTrue( - np.allclose( - actual_t, expect_t, atol=atol, equal_nan=equal_nan), - "Output (" + out_name + ") has diff at " + str(place) + "\nExpect " - + str(expect_t) + "\n" + "But Got" + str(actual_t)) - - -class TestRmspropOp(TestBase): - def check_with_place(self, place, centered, epsilon): - self.setup(centered, epsilon) - scope = core.Scope() - # create and initialize Param Variable - param = scope.var(self.param_name).get_tensor() - param.set(self.param, place) + self.param_tensor = self.scope.var(self.param_name).get_tensor() + self.param_tensor.set(self.param, place) - mean_square = scope.var(self.mean_square_name).get_tensor() - mean_square.set(self.mean_square, place) + self.mean_square_tensor = self.scope.var( + self.mean_square_name).get_tensor() + self.mean_square_tensor.set(self.mean_square, place) - lr = scope.var(self.lr_name).get_tensor() + lr = self.scope.var(self.lr_name).get_tensor() lr.set(self.learning_rate, place) - grad = scope.var(self.grad_name).get_tensor() - grad.set(self.grad, place) + self.moment_tensor = self.scope.var(self.moment_name).get_tensor() + self.moment_tensor.set(self.moment, place) - moment = scope.var(self.moment_name).get_tensor() - moment.set(self.moment, place) + if self.centered: + self.mean_grad_tensor = self.scope.var( + self.mean_grad_name).get_tensor() + self.mean_grad_tensor.set(self.mean_grad, place) - # create and run sgd operator + def check(self, actual_t, expect_t, place, out_name, atol=1e-5): + self.assertTrue( + np.allclose( + actual_t, expect_t, atol=atol), + "Output (" + out_name + ") has diff at " + str(place) + "\nExpect " + + str(expect_t) + "\n" + "But Got" + str(actual_t)) - if self.centered: - mean_grad = scope.var(self.mean_grad_name).get_tensor() - mean_grad.set(self.mean_grad, place) - - rmsprop_op = Operator( - "rmsprop", - Param=self.param_name, - Grad=self.grad_name, - MeanSquare=self.mean_square_name, - MeanGrad=self.mean_grad_name, - Moment=self.moment_name, - LearningRate=self.lr_name, - ParamOut=self.param_name, - MeanSquareOut=self.mean_square_name, - MomentOut=self.moment_name, - MeanGradOut=self.mean_grad_name, - epsilon=self.epsilon, - decay=self.decay, - momentum=self.momentum, - centered=True) - else: - rmsprop_op = Operator( - "rmsprop", - Param=self.param_name, - Grad=self.grad_name, - MeanSquare=self.mean_square_name, - Moment=self.moment_name, - LearningRate=self.lr_name, - ParamOut=self.param_name, - MeanSquareOut=self.mean_square_name, - MomentOut=self.moment_name, - epsilon=self.epsilon, - decay=self.decay, - momentum=self.momentum, - centered=False) - - rmsprop_op.run(scope, place) - - atol = 1e-5 - equal_nan = False + +class TestRmspropOp(TestBase): + def check_with_place(self, + place, + is_sparse, + centered, + size, + row_num=None, + epsilon=1e-6): + self.setup(place, is_sparse, centered, size, row_num, epsilon) + self.run_and_check() + + def run_and_check(self): + grad_name = self.grad_sr_name if self.is_sparse else self.grad_name + + kwargs = { + 'Param': self.param_name, + 'Grad': grad_name, + 'MeanSquare': self.mean_square_name, + 'Moment': self.moment_name, + 'LearningRate': self.lr_name, + 'ParamOut': self.param_name, + 'MeanSquareOut': self.mean_square_name, + 'MomentOut': self.moment_name, + 'epsilon': self.epsilon, + 'decay': self.decay, + 'momentum': self.momentum, + 'centered': self.centered + } if self.centered: - atol = 1e-3 - equal_nan = True + kwargs['MeanGrad'] = self.mean_grad_name + kwargs['MeanGradOut'] = self.mean_grad_name + + rmsprop_op = Operator('rmsprop', **kwargs) + atol = 1e-6 + + rmsprop_op.run(self.scope, self.place) self.check( - np.array(mean_square), self.ms_out, place, self.mean_square_name) + np.array(self.mean_square_tensor), + self.ms_out, + self.place, + self.mean_square_name, + atol=atol) self.check( - np.array(moment), + np.array(self.moment_tensor), self.moment_out, - place, + self.place, self.moment_name, - atol=atol, - equal_nan=equal_nan) + atol=atol) self.check( - np.array(param), + np.array(self.param_tensor), self.param_out, - place, + self.place, self.param_name, - atol=atol, - equal_nan=equal_nan) + atol=atol) if self.centered: self.check( - np.array(mean_grad), self.mg_out, place, self.mean_grad_name) + np.array(self.mean_grad_tensor), self.mg_out, self.place, + self.mean_grad_name) def test_rmsprop(self): places = [core.CPUPlace()] if core.is_compiled_with_cuda(): places.append(core.CUDAPlace(0)) + + size = (128, 320) for place in places: - self.check_with_place(place, False, 1e-6) - self.check_with_place(place, False, 1e-10) - self.check_with_place(place, True, 1e-6) - self.check_with_place(place, True, 1e-10) + for centered in [False, True]: + with fluid.scope_guard(core.Scope()): + self.check_with_place( + place, is_sparse=False, centered=centered, size=size) + + with fluid.scope_guard(core.Scope()): + self.check_with_place( + place, + is_sparse=True, + centered=centered, + row_num=512, + size=size) + + with fluid.scope_guard(core.Scope()): + self.check_with_place( + place, + is_sparse=True, + centered=centered, + row_num=60, + size=size) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..673b0ea180464b8b8f6f5c6e76d5c5c80f347d25 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py @@ -0,0 +1,75 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import six +import numpy as np +from op_test import OpTest + + +class TestSequenceUnpadOp(OpTest): + def init(self): + self.length = [2, 3, 4] + self.x_shape = (3, 5) + self.dtype = "float32" + + def compute(self): + assert len(self.length) == self.x_shape[0] + x = np.random.random(self.x_shape).astype(self.dtype) + out_lod = [self.length] + + out = x[0, 0:self.length[0]] + for i in six.moves.xrange(1, x.shape[0]): + out = np.append(out, x[i, 0:self.length[i]], axis=0) + + out_shape = (sum(self.length), ) + if len(self.x_shape) == 2: + out_shape = out_shape + (1, ) + else: + out_shape = out_shape + self.x_shape[2:] + + self.inputs = { + 'X': x, + 'Length': np.array(self.length).astype('int64').reshape(-1, 1) + } + self.outputs = {'Out': (out.reshape(out_shape), out_lod)} + + def setUp(self): + self.op_type = 'sequence_unpad' + self.init() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSequenceUnpadOp2(TestSequenceUnpadOp): + def init(self): + self.length = [2, 3, 4] + self.x_shape = (3, 5, 4, 3) + self.dtype = "float32" + + +class TestSequenceUnpadOp3(TestSequenceUnpadOp): + def init(self): + self.length = [5, 2, 3, 4] + self.x_shape = (4, 5, 3, 3, 6) + self.dtype = "float64" + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index ecdbe27f4d90268d755a712e25289cfaf4715f29..2192139f8d5950286691a77333dd8ec35505b033 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -788,7 +788,8 @@ in a single call.") tuple: (main_program, startup_program), of type "Program" """ pserver_prog = self.get_pserver_program(endpoint) - pserver_startup = self.get_startup_program(endpoint) + pserver_startup = self.get_startup_program( + endpoint, pserver_program=pserver_prog) return pserver_prog, pserver_startup def get_startup_program(self, @@ -1118,6 +1119,7 @@ to transpile() call.") def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. add split_ids_op and send_op to send gradient to pservers + # there should only be one table_name all_ops = program.global_block().ops table_grad_name = grad_var_name(self.table_name) @@ -1142,7 +1144,7 @@ to transpile() call.") if self.sync_mode else [] }, attrs={ - "sync_mode": self.sync_mode, + "sync_mode": not self.sync_mode, "epmap": pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [ @@ -1188,7 +1190,15 @@ to transpile() call.") def _create_table_optimize_block(self, pserver_index, pserver_program, pre_block_idx, grad_to_block_id): # STEP: create table optimize block + table_opt_block = pserver_program._create_block(pre_block_idx) # create table param and grad var in pserver program + # create table optimize block in pserver program + table_opt_op = [ + op for op in self.optimize_ops + if 'Param' in op.input_names and op.input("Param")[0] == + self.table_name + ][0] + origin_param_var = self.origin_program.global_block().vars[ self.table_name] @@ -1204,19 +1214,16 @@ to transpile() call.") dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True) + # parameter must be selected rows param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) grad_var = pserver_program.global_block()._clone_variable( self.origin_program.global_block().vars[grad_var_name( self.table_name)]) - # create table optimize block in pserver program - table_opt_op = [ - op for op in self.optimize_ops - if 'Param' in op.input_names and op.input("Param")[0] == - self.table_name - ][0] - table_opt_block = pserver_program._create_block(pre_block_idx) + lr_var = pserver_program.global_block()._clone_variable( + self.origin_program.global_block().vars[table_opt_op.input( + "LearningRate")[0]]) if self.sync_mode: # create grad vars in pserver program @@ -1248,8 +1255,6 @@ to transpile() call.") grad_var = pserver_program.global_block()._rename_var( origin_grad_name, splited_grad_name) - lr_var = pserver_program.global_block().vars[table_opt_op.input( - "LearningRate")[0]] inputs = { "Param": [param_var], "Grad": [grad_var],