diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index d5eaa9877181a2a3f7319693fc00f13e34873190..3f1be11d85555671eebb1c2ba3a5642d64d7f2bf 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -748,7 +748,7 @@ function(grpc_library TARGET_NAME)
   #FIXME(putcn): the following line is supposed to generate *.pb.h and cc, but
   # somehow it didn't. Lines 602 to 604 patch this. Leaving this here
   # for now to enable dist CI.
-  protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
+  paddle_protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
   set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc")
   set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h")
   cc_library("${TARGET_NAME}_proto" SRCS "${grpc_proto_srcs}")
@@ -791,7 +791,7 @@ function(brpc_library TARGET_NAME)
   get_filename_component(PROTO_WE ${brpc_library_PROTO} NAME_WE)
   get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
-  protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
+  paddle_protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
   cc_library("${TARGET_NAME}_proto" SRCS "${brpc_proto_srcs}")
   cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
 endfunction()
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 6937d13dbaa60723f8d55964db8db1eb03f4057b..0a4edea2c3cdaa6b457ce2098313c9b962484a7a 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -70,8 +70,8 @@ paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param
 paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
 paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
 paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None))
-paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None))
-paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid'))
+paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
+paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
 paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None,
defaults=(None,)) paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None) @@ -215,6 +215,7 @@ paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', ' paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)) paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.tree_conv ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc index fe21e21bcfc42bfb3251a7d0d15aa5926f56813f..b7d6edd389d8e40835dadf56d7c54d53402f6f4d 100644 --- a/paddle/fluid/framework/details/all_reduce_deps_pass.cc +++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc @@ -82,13 +82,13 @@ std::unique_ptr AllReduceDepsPass::ApplyImpl( PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error", op1->DebugString(), op2->DebugString()); - auto l_it = vars.find(i0->name_); - auto r_it = vars.find(i1->name_); + auto l_it = vars.find(i0->name()); + auto r_it = vars.find(i1->name()); if (l_it->second < r_it->second) return true; if (l_it->second == r_it->second) { - return i0->name_ < i1->name_; + return i0->name() < i1->name(); } return false; diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index a24e3d3e487e488f0d0c59809a0adc9f9524cc6e..dd77f7099f581a5b825916c4ea010023f3ad5bcd 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -70,9 +70,9 @@ void AllReduceOpHandle::RunImpl() { auto *s = local_scopes_[i]; auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get(); auto &lod_tensor = - local_scope.FindVar(in_var_handles[i]->name_)->Get(); + local_scope.FindVar(in_var_handles[i]->name())->Get(); lod_tensors.emplace_back(&lod_tensor); - PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, + PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(), "The name of input and output should be equal."); } @@ -134,7 +134,7 @@ void AllReduceOpHandle::RunImpl() { auto &trg = *this->local_scopes_[0] ->FindVar(kLocalExecScopeName) ->Get() - ->FindVar(out_var_handles[0]->name_) + ->FindVar(out_var_handles[0]->name()) ->GetMutable(); // Reduce All Tensor to trg in CPU @@ -145,7 +145,7 @@ void 
AllReduceOpHandle::RunImpl() { auto &scope = *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); auto &p = places_[i]; - auto *var = scope.FindVar(out_var_handles[i]->name_); + auto *var = scope.FindVar(out_var_handles[i]->name()); auto *dev_ctx = dev_ctxes_.at(p); RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index cf280c29ff8c7416be3b2d0b529bd04776150950..89d626edddfee3d2c43a3cf2232ad4fc1611e655 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -56,11 +56,11 @@ void BroadcastOpHandle::BroadcastOneVar( const std::vector &out_var_handles, const std::vector &var_scopes) { auto *in_var = - var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); + var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name()); PADDLE_ENFORCE_NOT_NULL(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); if (UNLIKELY(!in_tensor.IsInitialized())) { - VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!"; + VLOG(3) << "in var " << in_var_handle.name() << "not inited, return!"; return; } @@ -71,9 +71,9 @@ void BroadcastOpHandle::BroadcastOneVar( if (out_var_handle->IsTheSameVar(in_var_handle)) { continue; } - auto &out_p = out_var_handle->place_; - auto *out_var = var_scopes.at(out_var_handle->scope_idx_) - ->FindVar(out_var_handle->name_); + auto &out_p = out_var_handle->place(); + auto *out_var = var_scopes.at(out_var_handle->scope_idx()) + ->FindVar(out_var_handle->name()); RunAndRecordEvent(out_p, [in_tensor, out_var] { paddle::framework::TensorCopy( @@ -91,11 +91,11 @@ void BroadcastOpHandle::BroadcastOneVar( size_t numel = static_cast(in_tensor.numel()); for (auto out_var_handle : out_var_handles) { - Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) - ->FindVar(out_var_handle->name_); + Variable *out_var = var_scopes.at(out_var_handle->scope_idx()) + ->FindVar(out_var_handle->name()); int dst_id = - boost::get(out_var_handle->place_).device; + boost::get(out_var_handle->place()).device; auto &nccl_ctx = nccl_ctxs_->at(dst_id); @@ -106,7 +106,7 @@ void BroadcastOpHandle::BroadcastOneVar( } else { send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) .Resize(in_tensor.dims()) - .mutable_data(out_var_handle->place_); + .mutable_data(out_var_handle->place()); } broadcast_calls.emplace_back( @@ -126,11 +126,11 @@ void BroadcastOpHandle::BroadcastOneVar( } if (!out_handle->IsTheSameVar(in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle.scope_idx_) - ->FindVar(out_var_handles[0]->name_); + auto out_var = var_scopes.at(in_var_handle.scope_idx()) + ->FindVar(out_var_handles[0]->name()); paddle::framework::TensorCopy( - in_tensor, in_var_handle.place_, - *(dev_ctxes_.at(in_var_handle.place_)), + in_tensor, in_var_handle.place(), + *(dev_ctxes_.at(in_var_handle.place())), &VariableVisitor::GetMutableTensor(out_var)); } }); @@ -148,7 +148,7 @@ void BroadcastOpHandle::InitOutputValue( var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); } auto *in_var = - var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); + var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name()); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); @@ -158,9 +158,9 @@ void BroadcastOpHandle::InitOutputValue( if (out_var_handle->IsTheSameVar(in_var_handle)) { continue; } - auto t_out_p = 
out_var_handle->place_; - auto *out_var = var_scopes.at(out_var_handle->scope_idx_) - ->FindVar(out_var_handle->name_); + auto t_out_p = out_var_handle->place(); + auto *out_var = var_scopes.at(out_var_handle->scope_idx()) + ->FindVar(out_var_handle->name()); PADDLE_ENFORCE_NOT_NULL(out_var); if (is_gpu_place(in_tensor.place())) { PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index cc562c7b102cea80e18cbd2c054c34415a7442c9..48dcc52623369f7b0f51cd8c8aeb198b37467d5f 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -100,13 +100,13 @@ void DataBalanceOpHandle::RunImpl() { std::vector> lod_tensors(data_num); std::vector device_sizes; for (int i = 0; i < static_cast(in_var_handles.size()); ++i) { - PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, + PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(), "The name of input and output should be equal."); int place_idx = i / data_num; int data_idx = i % data_num; auto *local_scope = local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get(); - auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_); + auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name()); PADDLE_ENFORCE(tensor_var->IsType()); auto *tensor = tensor_var->GetMutable(); lod_tensors[data_idx].push_back(tensor); diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 648adae06facb504042d8286f6eab5d98e99c015..bbf81e1b8e49cae133858f7aa121701fb0f5456f 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -52,12 +52,12 @@ void FetchOpHandle::RunImpl() { for (size_t i = 0; i < inputs_.size(); ++i) { auto *var_handle = static_cast(inputs_[i]); - auto &scope = scopes.at(var_handle->scope_idx_); + auto &scope = scopes.at(var_handle->scope_idx()); auto *var = scope->FindVar(kLocalExecScopeName) ->Get() - ->FindVar(var_handle->name_); + ->FindVar(var_handle->name()); PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", - var_handle->name_); + var_handle->name()); auto &t = var->Get(); if (platform::is_gpu_place(t.place())) { diff --git a/paddle/fluid/framework/details/fuse_vars_op_handle.cc b/paddle/fluid/framework/details/fuse_vars_op_handle.cc index 018c9bff71e553d8a3641f06f10b350453676b24..d65b0920698748e8a2ded728d78fbcd69b7bae0e 100644 --- a/paddle/fluid/framework/details/fuse_vars_op_handle.cc +++ b/paddle/fluid/framework/details/fuse_vars_op_handle.cc @@ -29,14 +29,14 @@ void FuseVarsOpHandle::RunImpl() { auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); auto out_var_handle = out_var_handles[0]; - auto out_var = scope->Var(out_var_handle->name_); + auto out_var = scope->Var(out_var_handle->name()); auto out_tensor = out_var->GetMutable(); out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_); int64_t s = 0; for (size_t i = 1; i < out_var_handles.size(); ++i) { - auto out_name = out_var_handles[i]->name_; + auto out_name = out_var_handles[i]->name(); auto out_t = scope->Var(out_name)->GetMutable(); auto numel = this->inputs_numel_.at(out_name); out_t->ShareDataWith(out_tensor->Slice(s, s + numel)); diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 
ca4633c5a8f22fc9f7319b06aa766f9fe37dc68c..179cca44cb1871bb9667074f6c6b32edee42be09 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -49,7 +49,7 @@ void GatherOpHandle::RunImpl() { auto in_0_handle = in_var_handles[0]; auto pre_in_var = - var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); + var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name()); PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE(pre_in_var->IsType(), @@ -65,7 +65,7 @@ void GatherOpHandle::RunImpl() { // Gather the inputs for (auto *in_handle : in_var_handles) { auto *in_var = - var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); + var_scopes.at(in_handle->scope_idx())->FindVar(in_handle->name()); PADDLE_ENFORCE_NOT_NULL(in_var); VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var); @@ -77,7 +77,7 @@ void GatherOpHandle::RunImpl() { } // NOTE: The Places of all input tensor must be all on CPU or all on GPU. - platform::Place t_out_p = out_var_handle->place_; + platform::Place t_out_p = out_var_handle->place(); if (platform::is_gpu_place(pre_in_value.place())) { PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), "Places of input and output must be all on GPU."); @@ -85,8 +85,8 @@ void GatherOpHandle::RunImpl() { t_out_p = platform::CPUPlace(); } - auto out_var = - var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); + auto out_var = var_scopes.at(out_var_handle->scope_idx()) + ->FindVar(out_var_handle->name()); PADDLE_ENFORCE_NOT_NULL(out_var); auto out_value = out_var->GetMutable(); out_value->set_height(pre_in_value.height()); @@ -99,9 +99,9 @@ void GatherOpHandle::RunImpl() { Tensor *out_tensor = out_value->mutable_value(); // copy - auto dev_ctx = dev_ctxes_.at(out_var_handle->place_); - RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx, - t_out_p] { + auto dev_ctx = dev_ctxes_.at(out_var_handle->place()); + RunAndRecordEvent(out_var_handle->place(), [in_tensors, out_tensor, &dev_ctx, + t_out_p] { int s = 0, e = 0; for (size_t j = 0; j < in_tensors.size(); ++j) { e += in_tensors[j].dims()[0]; diff --git a/paddle/fluid/framework/details/memory_early_delete_pass.cc b/paddle/fluid/framework/details/memory_early_delete_pass.cc index 06a2451c136e3243ba41661fa691f9a6ef8b52ac..5906b7d57ce122520a4594f1528e00982eaa1a7f 100644 --- a/paddle/fluid/framework/details/memory_early_delete_pass.cc +++ b/paddle/fluid/framework/details/memory_early_delete_pass.cc @@ -33,7 +33,7 @@ static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) { queue.pop(); for (auto* op : var->PendingOps()) { auto* compute_op = dynamic_cast(op); - if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { + if (compute_op != nullptr && compute_op->GetPlace() == var_in->place()) { return compute_op; } for (auto* out_var : op->Outputs()) { @@ -64,7 +64,7 @@ std::unique_ptr MemoryEarlyDeletePass::ApplyImpl( for (auto& var : vars) { auto* var_handle = dynamic_cast(var); auto var_name = var->Node()->Name(); - auto& var_place = var_handle->place_; + auto& var_place = var_handle->place(); if (unlived_vars.count(var_name) == 0) continue; if (!unlived_vars[var_name].empty()) { if (compute_op != nullptr && diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index c203073845375c879a0fc10564f5dad0f19ceae4..e82eb104fa9f461ec370fc4b31551dd1a9214a7c 100644 --- 
a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -52,11 +52,11 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, vars[var_ptr] = cur_var_id; if (var_handle_ptr) { - sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ + sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name() << "\\n" - << var_handle_ptr->place_ << "\\n" - << "scope: " << var_handle_ptr->scope_idx_ << "\\n" - << "v" << var_handle_ptr->version_ << "\"]" << std::endl; + << var_handle_ptr->place() << "\\n" + << "scope: " << var_handle_ptr->scope_idx() << "\\n" + << "v" << var_handle_ptr->version() << "\"]" << std::endl; } else if (dummy_ptr) { sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; } diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 7a5f7de57ef20b4b909894ff8d742a65ea05874d..ee4c8a6ecf77e5d0f23f38b763917d926afdb07a 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -60,8 +60,8 @@ void ReduceOpHandle::GatherSelectedRows( *CollectiveContext::GetInstance(); // 1. gather local selected rows, merge them - std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp"; - auto scope = local_scopes_.at(out_var_handle->scope_idx_); + std::string gathered_var_name = out_var_handle->name() + "_gathered_tmp"; + auto scope = local_scopes_.at(out_var_handle->scope_idx()); auto gathered_var_mid = scope->Var(gathered_var_name); auto gathered_select_rows = gathered_var_mid->GetMutable(); @@ -73,7 +73,7 @@ void ReduceOpHandle::GatherSelectedRows( // merge them auto merged_dev_ctx = dynamic_cast(dev_ctxes.at(out_place)); std::string merged_var_name = - GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_); + GetRemoteVarName(out_var_handle->name(), collective_context.trainer_id_); auto merged_select_rows = scope->Var(merged_var_name)->GetMutable(); operators::math::scatter::MergeAdd merge_func; @@ -101,7 +101,7 @@ void ReduceOpHandle::GatherSelectedRows( operators::distributed::RemoteVar var; var.trainer_id_ = i; - var.var_name_ = GetRemoteVarName(out_var_handle->name_, i); + var.var_name_ = GetRemoteVarName(out_var_handle->name(), i); var.ep_ = collective_context.endpoints_[i]; vars.push_back(var); @@ -166,7 +166,7 @@ void ReduceOpHandle::RunImpl() { } auto pre_in_var = - var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); + var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name()); PADDLE_ENFORCE_NOT_NULL(pre_in_var); // Wait input done, this Wait is asynchronous operation @@ -175,15 +175,15 @@ void ReduceOpHandle::RunImpl() { // NOTE: The Places of all input tensor must be all on CPU or all on GPU. 
std::vector in_places; // used to get dev_ctx for (auto *in_handle : in_var_handles) { - in_places.emplace_back(in_handle->place_); + in_places.emplace_back(in_handle->place()); auto in_var = - var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); + var_scopes.at(in_handle->scope_idx())->FindVar(in_handle->name()); PADDLE_ENFORCE_NOT_NULL(in_var); VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var); } - auto out_var = - var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); + auto out_var = var_scopes.at(out_var_handle->scope_idx()) + ->FindVar(out_var_handle->name()); PADDLE_ENFORCE_NOT_NULL(out_var); // NOTE: The tensors' Place of input and output must be all on GPU or all on @@ -191,9 +191,9 @@ void ReduceOpHandle::RunImpl() { auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place(); platform::Place t_out_p; if (platform::is_gpu_place(in_p)) { - PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_), + PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place()), "Places of input and output must be all on GPU."); - t_out_p = out_var_handle->place_; + t_out_p = out_var_handle->place(); } else { t_out_p = platform::CPUPlace(); } @@ -253,7 +253,7 @@ void ReduceOpHandle::RunImpl() { auto &reduce_sum_trg = *this->local_scopes_[0] ->FindVar(kLocalExecScopeName) ->Get() - ->FindVar(out_var_handle->name_) + ->FindVar(out_var_handle->name()) ->GetMutable(); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); VisitDataType(lod_tensors[0]->type(), func); @@ -269,9 +269,9 @@ void ReduceOpHandle::RunImpl() { auto pre_in = pre_in_var->Get(); VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var); VariableVisitor::GetMutableTensor(out_var).mutable_data( - out_var_handle->place_, pre_in.type()); + out_var_handle->place(), pre_in.type()); - auto out_p = out_var_handle->place_; + auto out_p = out_var_handle->place(); int root_id = boost::get(out_p).device; std::vector> all_reduce_calls; for (size_t i = 0; i < var_scopes.size(); ++i) { @@ -286,7 +286,7 @@ void ReduceOpHandle::RunImpl() { if (root_id == dev_id) { recvbuffer = out_var->GetMutable()->mutable_data( - out_var_handle->place_); + out_var_handle->place()); } int type = platform::ToNCCLDataType(lod_tensor.type()); @@ -320,8 +320,8 @@ std::vector ReduceOpHandle::GetInputValues( const std::vector &var_scopes) const { std::vector in_selected_rows; for (auto *in_handle : in_var_handles) { - auto &in_sr = var_scopes.at(in_handle->scope_idx_) - ->FindVar(in_handle->name_) + auto &in_sr = var_scopes.at(in_handle->scope_idx()) + ->FindVar(in_handle->name()) ->Get(); in_selected_rows.emplace_back(&in_sr); } diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index dfa6c1ade1a024bb9087144d0e96fa5b0417f06a..3e082f247adf7fe22db2b62802f0a87c9c93447a 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -30,7 +30,7 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, void RPCOpHandle::RunImpl() { for (auto *in : inputs_) { - auto &p = static_cast(in)->place_; + auto &p = static_cast(in)->place(); if (ir::IsControlDepVar(*in->Node())) { continue; } diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index e1b8e8fe05f0615d689e78d9c405cc5d76d2abb1..6924549f36d6365534ab288257899a78107675cc 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ 
b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -68,7 +68,7 @@ struct ScaleLossGradFunctor {
 void ScaleLossGradOpHandle::RunImpl() {
   // Doesn't wait any event
-  std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
+  std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
   auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
 
   auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index 3b007d7b1a52df765a2dbd41939f8f865123cb43..8321c32f8b1d73bf5e6080b4b314abc9fd20536d 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -111,15 +111,22 @@ struct VarHandle : public VarHandleBase {
   // version field currently is not used, however, just store the version to
   // debug easily.
+ private:
   size_t version_;
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
 
+ public:
   bool IsTheSameVar(const VarHandle& o) const {
     return o.generated_op_ == generated_op_ && o.name_ == name_ &&
            o.scope_idx_ == scope_idx_;
   }
+
+  size_t version() const { return version_; }
+  size_t scope_idx() const { return scope_idx_; }
+  const std::string& name() const { return name_; }
+  const platform::Place& place() const { return place_; }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 336ab426c21d9de93693c44d8fc6bc5b37b58864..965bbd0fd26ce39f72b622bce0ecb7b3bbdf4f2f 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -127,6 +127,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
   use_tensorrt_ = true;
   tensorrt_workspace_size_ = workspace_size;
   tensorrt_max_batchsize_ = max_batch_size;
+  tensorrt_min_subgraph_size_ = min_subgraph_size;
 
   Update();
 }
@@ -145,8 +146,8 @@ void contrib::AnalysisConfig::Update() {
       LOG(ERROR) << "TensorRT engine is not available when EnableGpu() not actived.";
     } else {
-      // Append after the infer_clean pass.
-      pass_builder()->InsertPass(1, "tensorrt_subgraph_pass");
+      // Append after the Affine_channel_conv_fuse pass.
+      pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
     }
   }
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 585634fae9c85f77cc77d774ac166891014a025c..3917b9b65b5905be38ba8a236aa158f42586c825 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -561,6 +561,7 @@ AnalysisPredictor::~AnalysisPredictor() {
 }
 
 std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() {
+  std::lock_guard<std::mutex> lk(clone_mutex_);
   auto *x = new AnalysisPredictor(config_);
   x->Init(scope_, inference_program_);
   return std::unique_ptr<PaddlePredictor>(x);
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index a6e126c5d533f4299ccc3deed7d116cabc71f75b..6ca4b5e9bed7505fca3b833dfbb7026ff550d258 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -115,6 +115,8 @@ class AnalysisPredictor : public PaddlePredictor {
   // concurrency problems, wrong results and memory leak, so cache them.
   std::vector<framework::LoDTensor> feed_tensors_;
   details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
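The `clone_mutex_` added here serializes `Clone()`, which makes the clone-per-thread pattern exercised by the tester change below safe. A minimal sketch of that pattern, assuming a `main_predictor` built elsewhere and some prepared `inputs` (both names are illustrative, not part of this patch):

```cpp
#include <thread>
#include <vector>
#include "paddle/fluid/inference/api/paddle_api.h"

void RunConcurrently(paddle::PaddlePredictor *main_predictor,
                     const std::vector<paddle::PaddleTensor> &inputs,
                     int num_threads) {
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    workers.emplace_back([&] {
      // Concurrent Clone() calls are now serialized by clone_mutex_;
      // each thread then runs on its own predictor instance.
      auto predictor = main_predictor->Clone();
      std::vector<paddle::PaddleTensor> outputs;
      predictor->Run(inputs, &outputs);
    });
  }
  for (auto &w : workers) w.join();
}
```

Locking only `Clone()` keeps `Run()` lock-free; every thread still executes on its own predictor instance.

+  // A mutex to help make Clone thread safe.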
+  std::mutex clone_mutex_;
 
  private:
   // Some status here that help to determine the status inside the predictor.
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index 6169e60541e4a14d560e719d56624b3219dbcefd..3df26cde3d5defac97074c9bc4086e81f9ec0c93 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -179,8 +179,9 @@ TEST(AnalysisPredictor, Clone) {
     threads.emplace_back([&predictors, &inputs, i] {
       LOG(INFO) << "thread #" << i << " running";
       std::vector<PaddleTensor> outputs;
+      auto predictor = predictors.front()->Clone();
       for (int j = 0; j < 10; j++) {
-        ASSERT_TRUE(predictors[i]->Run(inputs, &outputs));
+        ASSERT_TRUE(predictor->Run(inputs, &outputs));
       }
     });
   }
diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc
index 85e250aaaf4a18a261a4bfc5271670f93565a336..e18bc02d92eb517fa20dc83811694b8ac80ae316 100644
--- a/paddle/fluid/inference/api/api_impl.cc
+++ b/paddle/fluid/inference/api/api_impl.cc
@@ -161,13 +161,16 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
 }
 
 std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
+  std::lock_guard<std::mutex> lk(clone_mutex_);
   VLOG(3) << "Predictor::clone";
   std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
-
-  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
+  // Hot fix for a bug where results differ across threads.
+  // TODO(Superjomn) re-implement a real clone here.
+  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(nullptr)) {
     LOG(ERROR) << "fail to call Init";
     return nullptr;
   }
+
 #ifdef __clang__
   // fix clang compile error
   return cls;
diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h
index d2133bd467376c723a80a98725ac7c70234c54b0..96b94777304382a9d4be115a84f80ead69249863 100644
--- a/paddle/fluid/inference/api/api_impl.h
+++ b/paddle/fluid/inference/api/api_impl.h
@@ -74,6 +74,8 @@ class NativePaddlePredictor : public PaddlePredictor {
   // Do not use unique_ptr, use parent scope to delete
   framework::Scope *sub_scope_{nullptr};
   details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
+  // A mutex to make Clone thread safe.
+ std::mutex clone_mutex_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 0f540699b8ffea94c3f3aaf3354a0462e0e674a9..f60ff40c5da3e9e03c2cb3583263394cb82db805 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -33,9 +33,15 @@ void ZeroCopyTensor::Reshape(const std::vector &shape) { tensor->Resize(framework::make_ddim(shape)); } +#define EAGER_GET_TENSOR \ + if (!tensor_) { \ + tensor_ = FindTensor(); \ + } \ + auto *tensor = static_cast(tensor_); + template T *ZeroCopyTensor::mutable_data(PaddlePlace place) { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; switch (static_cast(place)) { case static_cast(PaddlePlace::kCPU): { return tensor->mutable_data(platform::CPUPlace()); @@ -52,7 +58,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { template T *ZeroCopyTensor::data(PaddlePlace *place, int *size) const { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; auto *res = tensor->data(); if (platform::is_cpu_place(tensor->place())) { @@ -87,13 +93,13 @@ void *ZeroCopyTensor::FindTensor() const { } std::vector ZeroCopyTensor::shape() const { - auto *tensor = static_cast(FindTensor()); - PADDLE_ENFORCE(tensor, "not found tensor called %s in the scope", name_); + EAGER_GET_TENSOR; + PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_); return framework::vectorize(tensor->dims()); } void ZeroCopyTensor::SetLoD(const std::vector> &x) { - auto *tensor = static_cast(FindTensor()); + EAGER_GET_TENSOR; framework::LoD lod; for (auto &level : x) { lod.emplace_back(level); @@ -102,8 +108,8 @@ void ZeroCopyTensor::SetLoD(const std::vector> &x) { } std::vector> ZeroCopyTensor::lod() const { + EAGER_GET_TENSOR; std::vector> res; - auto *tensor = static_cast(FindTensor()); for (auto &level : tensor->lod()) { res.emplace_back(level); } diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 832c8cdf2849279c4c32a81e9f81ef522c401b86..46b510fd1ec94c59032b8f41a2ac4d6aa87dc150 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -146,6 +146,9 @@ class ZeroCopyTensor { bool input_or_output_; friend class AnalysisPredictor; void* scope_{nullptr}; + // The corresponding tensor pointer inside Paddle workspace is cached for + // performance. + mutable void* tensor_{nullptr}; }; /** A simple Inference API for Paddle. @@ -167,18 +170,40 @@ class PaddlePredictor { std::vector* output_data, int batch_size = -1) = 0; - /** Zero copy input and output optimization. - * Get the input or output tensors, and operate on their memory directly, - * without copy. + /** \brief Get a mutable tensor directly. + * + * NOTE Only works in AnalysisPredictor. + * + * One can also use this to modify any temporary variable related tensors in + * the predictor. + * */ virtual std::unique_ptr GetInputTensor( const std::string& name) { return nullptr; } + /** + * \brief Get an immutable tensor without copy. + * + * NOTE Only works in AnalysisPredictor. + * One can use this API to get any temporary tensors in the predictor and + * read it. + */ virtual std::unique_ptr GetOutputTensor( const std::string& name) { return nullptr; } + /** + * \brief Run the predictor with zero-copied inputs and outputs. + * + * NOTE Only works in AnalysisPredictor. 
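+ *
+ * A minimal usage sketch inside this doc comment (the tensor names "x" and
+ * "y" and the sizes are illustrative placeholders, not part of the API):
+ *
+ *   auto input = predictor->GetInputTensor("x");
+ *   input->Reshape({batch_size, dim});
+ *   float* in_data = input->mutable_data<float>(PaddlePlace::kCPU);
+ *   // ... fill in_data with the input values ...
+ *   predictor->ZeroCopyRun();
+ *   auto output = predictor->GetOutputTensor("y");
+ *   int size = 0;
+ *   PaddlePlace place;
+ *   float* out_data = output->data<float>(&place, &size);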
+ *
+ * This will save the IO copy for transferring inputs and outputs to the
+ * predictor workspace and bring some performance improvement.
+ * To use it, one should call the `AnalysisConfig.SwitchUseFeedFetchOp(true)`
+ * and then use the `GetInputTensor` and `GetOutputTensor` to directly write
+ * or read the input/output tensors.
+ */
  virtual bool ZeroCopyRun() { return false; }
 
   /** Clone a predictor that share the model weights, the Cloned predictor
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 0f670658892b9926dcc534038925c46047a113fd..adbf98e9e8a535938157ae5c8214d8bfffbc3314 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -84,7 +84,12 @@ inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR} analyzer_lac_te
 # MM DNN
 set(MM_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mm_dnn")
 download_model_and_data(${MM_DNN_INSTALL_DIR} "MM_DNN_model.tar.gz" "MM_DNN_data.txt.tar.gz")
-inference_analysis_api_test(test_analyzer_mm_dnn ${MM_DNN_INSTALL_DIR} analyzer_mm_dnn_tester.cc)
+inference_analysis_api_test(test_analyzer_mm_dnn ${MM_DNN_INSTALL_DIR} analyzer_mm_dnn_tester.cc SERIAL)
+
+# Pyramid DNN
+set(PYRAMID_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/pyramid_dnn")
+download_model_and_data(${PYRAMID_DNN_INSTALL_DIR} "PyramidDNN_model.tar.gz" "PyramidDNN_data.txt.tar.gz")
+inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR} analyzer_pyramid_dnn_tester.cc)
 
 # text_classification
 set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
diff --git a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ad2c46e48d5a34a457a615f313f1ac3cc916b200
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+using contrib::AnalysisConfig;
+
+struct DataRecord {
+  std::vector<std::vector<int64_t>> query_basic, query_phrase, title_basic,
+      title_phrase;
+  std::vector<size_t> lod1, lod2, lod3, lod4;
+  size_t batch_iter{0}, batch_size{1}, num_samples;  // total number of samples
+  DataRecord() = default;
+  explicit DataRecord(const std::string &path, int batch_size = 1)
+      : batch_size(batch_size) {
+    Load(path);
+  }
+  DataRecord NextBatch() {
+    DataRecord data;
+    size_t batch_end = batch_iter + batch_size;
+    // NOTE: skip the final batch if not enough data is provided.
+    if (batch_end <= query_basic.size()) {
+      GetInputPerBatch(query_basic, &data.query_basic, &data.lod1, batch_iter,
+                       batch_end);
+      GetInputPerBatch(query_phrase, &data.query_phrase, &data.lod2, batch_iter,
+                       batch_end);
+      GetInputPerBatch(title_basic, &data.title_basic, &data.lod3, batch_iter,
+                       batch_end);
+      GetInputPerBatch(title_phrase, &data.title_phrase, &data.lod4, batch_iter,
+                       batch_end);
+    }
+    batch_iter += batch_size;
+    return data;
+  }
+  void Load(const std::string &path) {
+    std::ifstream file(path);
+    std::string line;
+    int num_lines = 0;
+    while (std::getline(file, line)) {
+      std::vector<std::string> data;
+      split(line, ';', &data);
+      // load query data
+      std::vector<int64_t> query_basic_data;
+      split_to_int64(data[1], ' ', &query_basic_data);
+      std::vector<int64_t> query_phrase_data;
+      split_to_int64(data[2], ' ', &query_phrase_data);
+      // load title data
+      std::vector<int64_t> title_basic_data;
+      split_to_int64(data[3], ' ', &title_basic_data);
+      std::vector<int64_t> title_phrase_data;
+      split_to_int64(data[4], ' ', &title_phrase_data);
+      // filter the empty data
+      bool flag =
+          data[1].size() && data[2].size() && data[3].size() && data[4].size();
+      if (flag) {
+        query_basic.push_back(std::move(query_basic_data));
+        query_phrase.push_back(std::move(query_phrase_data));
+        title_basic.push_back(std::move(title_basic_data));
+        title_phrase.push_back(std::move(title_phrase_data));
+        num_lines++;
+      }
+    }
+    num_samples = num_lines;
+  }
+};
+
+void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
+                   int batch_size) {
+  PaddleTensor query_basic_tensor, query_phrase_tensor, title_basic_tensor,
+      title_phrase_tensor;
+  query_basic_tensor.name = "query_basic";
+  query_phrase_tensor.name = "query_phrase";
+  title_basic_tensor.name = "pos_title_basic";
+  title_phrase_tensor.name = "pos_title_phrase";
+  auto one_batch = data->NextBatch();
+  // assign data
+  TensorAssignData<int64_t>(&query_basic_tensor, one_batch.query_basic,
+                            one_batch.lod1);
+  TensorAssignData<int64_t>(&query_phrase_tensor, one_batch.query_phrase,
+                            one_batch.lod2);
+  TensorAssignData<int64_t>(&title_basic_tensor, one_batch.title_basic,
+                            one_batch.lod3);
+  TensorAssignData<int64_t>(&title_phrase_tensor, one_batch.title_phrase,
+                            one_batch.lod4);
+  // Set inputs.
+  input_slots->assign({query_basic_tensor, query_phrase_tensor,
+                       title_basic_tensor, title_phrase_tensor});
+  for (auto &tensor : *input_slots) {
+    tensor.dtype = PaddleDType::INT64;
+  }
+}
+
+void SetConfig(contrib::AnalysisConfig *cfg) {
+  cfg->SetModel(FLAGS_infer_model);
+  cfg->DisableGpu();
+  cfg->SwitchSpecifyInputNames();
+  cfg->SwitchIrOptim();
+}
+
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
+  DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+  std::vector<PaddleTensor> input_slots;
+  int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
+  LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
+  for (int bid = 0; bid < epoch; ++bid) {
+    PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+    (*inputs).emplace_back(input_slots);
+  }
+}
+
+// Easy for profiling independently.
+TEST(Analyzer_Pyramid_DNN, profile) {
+  contrib::AnalysisConfig cfg;
+  SetConfig(&cfg);
+  std::vector<PaddleTensor> outputs;
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
+
+  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
+    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
+    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(size, 0);
+    float *result = static_cast<float *>(outputs[0].data.data());
+    // output is probability, which is in (0, 1).
+ for (size_t i = 0; i < size; i++) { + EXPECT_GT(result[i], 0); + EXPECT_LT(result[i], 1); + } + } +} + +// Check the fuse status +TEST(Analyzer_Pyramid_DNN, fuse_statis) { + contrib::AnalysisConfig cfg; + SetConfig(&cfg); + + int num_ops; + auto predictor = CreatePaddlePredictor(cfg); + auto fuse_statis = GetFuseStatis( + static_cast(predictor.get()), &num_ops); +} + +// Compare result of NativeConfig and AnalysisConfig +TEST(Analyzer_Pyramid_DNN, compare) { + contrib::AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareNativeAndAnalysis( + reinterpret_cast(&cfg), input_slots_all); +} + +// Compare Deterministic result +TEST(Analyzer_Pyramid_DNN, compare_determine) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareDeterministic(reinterpret_cast(&cfg), + input_slots_all); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e53a6a562ad1ed2ca02683b07cf6d4b56bc2cde7..992a2bdd5ad639bf6176328e94da6eb71a41790c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 264a7880938cc86d91f8f1c992b5bc8a742361be..0360cf5273591946570cac47e2578e43f498b550 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -58,7 +58,6 @@ class WhileOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); - auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -73,27 +72,18 @@ class WhileOp : public framework::OperatorBase { PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), "Condition of while op must in CPU memory."); + bool is_test = Attr("is_test"); auto &skip_vars = Attr>(kSkipEagerDeletionVars); VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); - bool is_test = Attr("is_test"); auto ctx = executor.Prepare(*program, block->ID(), skip_vars); - - if (!is_test) { - while (cond.data()[0]) { - auto ¤t_scope = scope.NewScope(); - step_scopes->push_back(¤t_scope); - executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, - true); - } - } else { + while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); - executor.CreateVariables(*program, ¤t_scope, block->ID()); - while (cond.data()[0]) { - executor.RunPreparedContext(ctx.get(), ¤t_scope, false, false, - 
false); + step_scopes->push_back(¤t_scope); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); + if (is_test) { + scope.DeleteScope(¤t_scope); } - scope.DeleteScope(¤t_scope); } } }; diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 800c7a3705ddfdd4e3c17fccdcd0f049fe47c801..7fcbf85f187e944adee0dcb69cb71e453daa224b 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -7,7 +7,7 @@ if(WITH_GRPC) else() set(cc_generic_services "true") endif() -configure_file(send_recv.proto.in ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto @ONLY) +configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY) # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") @@ -19,8 +19,8 @@ if(WITH_GRPC) variable_response.cc collective_client.cc collective_server.cc ${GRPC_SRCS} - PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto - DEPS lod_tensor selected_rows_functor memory ${GRPC_DEPS}) + PROTO send_recv.proto + DEPS lod_tensor selected_rows_functor memory) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 087f903a8bba9a4bfcd7eaabd7098555442a904e..752d706cbfab8eb3027fe9610c25b7400ecfed1d 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("origin_mode", + "bool" + "use origin mode in article https://arxiv.org/abs/1412.3555") + .SetDefault(false); AddComment(R"DOC( GRU Operator implements part calculations of the complete GRU as following: @@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { using DeviceContext = paddle::platform::CPUDeviceContext; + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); @@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel { math::detail::forward_final_output( math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); + cur_batch_size, active_node, origin_mode); gru_value.prev_out_value = gru_value.output_value; } @@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel { math::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); gru_value.prev_out_value = gru_value.output_value; } diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index 55721c283dd18c2f9642563a9ce1eabfce16fd7b..ba918b3def22e3c60c4155f77ecbaad85d520928 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -21,6 +21,7 @@ template class GRUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = 
context.Input("Weight"); @@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel { gru_value.reset_output_value = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); gru_value.prev_out_value = gru_value.output_value; } diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 0b551e8046be16c95f7d6b10b68b32a9af594f73..45c769ee37260bf912ebc848d58019557f4adc07 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -41,6 +41,7 @@ template class GRUGradKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + bool origin_mode = context.Attr("origin_mode"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); @@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel { math::GRUUnitGradFunctor::compute( dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); } if (input_grad) { input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc index 82a808b01e99ec33b0ca00a065fb301d3c633b19..e3beedcf10b6286c92371c48cae7912aef35e7a3 100644 --- a/paddle/fluid/operators/gru_unit_op.cc +++ b/paddle/fluid/operators/gru_unit_op.cc @@ -111,6 +111,13 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { "The activation type used in update gate and reset gate.") .SetDefault(sigmoid) .InEnum({identity, sigmoid, tanh, relu}); + AddAttr("origin_mode", + "bool" + "use origin mode in article (https://arxiv.org/pdf/1406.1078.pdf)") + .SetDefault(false); AddComment(R"DOC( GRUUnit Operator implements partial calculations of the GRU unit as following: diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 451ec61ba1f7239d92c6dfbad0b2961e74e1bc17..712ef05d8631ac74b92795321202cb5590286e82 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -113,7 +113,11 @@ class GRUUnitKernel : public framework::OpKernel { auto c = g.slice(c_offsets, extents); // output candidate // calculate final output - h.device(place) = u * (c - h_p) + h_p; + if (context.Attr("origin_mode")) { + h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p + } else { + h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p + } } }; @@ -180,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel { auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate - ActGradCompute(context.Attr("gate_activation"), place, u, u, - d_g.slice(u_offsets, extents), d_h * (c - h_p)); - // backward for unactivated output candidate - ActGradCompute(context.Attr("activation"), place, c, c, - d_g.slice(c_offsets, extents), d_h * u); + if (context.Attr("origin_mode")) { + ActGradCompute(context.Attr("gate_activation"), place, u, u, + d_g.slice(u_offsets, extents), d_h * (h_p - c)); + // backward for unactivated output candidate + ActGradCompute(context.Attr("activation"), place, c, c, + d_g.slice(c_offsets, extents), d_h * (1 - u)); + } else { + ActGradCompute(context.Attr("gate_activation"), place, u, u, + d_g.slice(u_offsets, extents), d_h * (c - h_p)); + // backward for unactivated output candidate + ActGradCompute(context.Attr("activation"), place, c, c, + d_g.slice(c_offsets, 
extents), d_h * u); + } // backward for reset_hidden_prev auto blas = math::GetBlas(context); blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, @@ -213,7 +225,11 @@ class GRUUnitGradKernel : public framework::OpKernel { T* hidden_prev_grad_data = hidden_prev_grad->mutable_data(context.GetPlace()); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); - d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); + if (context.Attr("origin_mode")) { + d_h_p.device(place) = d_r_h_p * r + d_h * u; + } else { + d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u); + } blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, frame_size); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 600ab14d37aad6b95516c5bd6551d12165596f57..dc27e543f0dfd65e556f9e3a138778972ad6982f 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -60,6 +60,7 @@ math_library(matrix_bit_code) math_library(unpooling) math_library(vol2col) math_library(prelu) +math_library(tree2col DEPS math_function) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h index 47c771f7c5c01b651423c7886207abf4a4297019..c6dd972e12b763283a4212d4c56844afb1c2fd7a 100644 --- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h @@ -56,7 +56,8 @@ template void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); frame_state[i] = r_value_frame_state; output_value[i] = r_output; @@ -146,7 +147,8 @@ template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { #ifdef __AVX__ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); @@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state); @@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, if (rest > 0) { i = n - block; op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, - &r_prev_out_last, &r_output, active_node); + &r_prev_out_last, &r_output, active_node, origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state_last); @@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output, template inline void forward_final_output(OpFinalOutput 
op_final_output, GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_node) { + int batch_size, ActivationType active_node, + bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpFinalOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, value.prev_out_value, value.output_value, - frame_size, active_node); + frame_size, active_node, origin_mode); } else { hl_naive_gru_forward_final_output( op_final_output, value.gate_value, value.prev_out_value, - value.output_value, frame_size, active_node); + value.output_value, frame_size, active_node, origin_mode); } value.gate_value += frame_size * 3; @@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { T r_update_gate_value; T r_update_gate_grad; T r_frame_state_value; @@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node); + &r_prev_out_grad, &r_out_grad, active_node, origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -338,8 +342,8 @@ template void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, - int frame_size, - ActivationType active_node) { + int frame_size, ActivationType active_node, + bool origin_mode) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node); + &r_prev_out_grad, &r_out_grad, active_node, origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -431,16 +435,18 @@ template inline void backward_state_grad(OpStateGrad op_state_grad, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - ActivationType active_node) { + ActivationType active_node, bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_backward_state_grad( - op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, frame_size, active_node); + hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value, + grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, + frame_size, active_node, origin_mode); } else { - hl_naive_gru_backward_state_grad( - op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, frame_size, active_node); + hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value, + grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, + frame_size, active_node, origin_mode); } value.gate_value += frame_size * 3; diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 
813d69f6aba722609a0523a5be71d32f91f76d59..6b57da1046a05b15b9c3302104d9f4d12c52227f 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, int batch_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); gate_value[frame_idx + frame_size * 2] = r_value_frame_state; output_value[frame_idx] = r_output; @@ -109,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, int batch_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -139,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad, - &r_out_grad, active_node); + &r_out_grad, active_node, origin_mode); gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h index f6d192358bd84eb56a2e01eb36f28d8832ef271f..894f5f04d2451151964965bd721ff35e353ff2b5 100644 --- a/paddle/fluid/operators/math/detail/gru_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_kernel.h @@ -57,10 +57,16 @@ class gru_finalOutput { public: HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state, T *prev_out, T *value_output, - ActivationType act_input) { + ActivationType act_input, bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); - *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + - ((*value_update_gate) * (*value_frame_state)); + if (origin_mode) { + *value_output = ((*value_update_gate) * (*prev_out)) + + *value_frame_state - + ((*value_update_gate) * (*value_frame_state)); + } else { + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + + ((*value_update_gate) * (*value_frame_state)); + } } #ifndef __NVCC__ #ifndef __AVX__ @@ -69,11 +75,20 @@ class gru_finalOutput { static const bool avx = true; HOSTDEVICE void operator()(__m256 *value_update_gate, __m256 *value_frame_state, __m256 *prev_out, - __m256 *value_output, ActivationType act_input) { + __m256 *value_output, ActivationType act_input, + bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); - *value_output = _mm256_add_ps( - _mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)), - _mm256_mul_ps(*value_update_gate, *value_frame_state)); + if (origin_mode) { + *value_output = _mm256_sub_ps( + _mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out), + *value_frame_state), + _mm256_mul_ps(*value_update_gate, *value_frame_state)); + } else { + *value_output = _mm256_add_ps( 
+ _mm256_sub_ps(*prev_out, + _mm256_mul_ps(*value_update_gate, *prev_out)), + _mm256_mul_ps(*value_update_gate, *value_frame_state)); + } } #endif #endif @@ -88,13 +103,23 @@ class gru_stateGrad { HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, T *value_frame_state, T *grad_frame_state, T *value_prev_out, T *grad_prev_out, - T *grad_output, ActivationType act_input) { - *grad_update_gate = (*grad_output * (*value_frame_state)); - *grad_update_gate -= (*grad_output * (*value_prev_out)); - *grad_prev_out -= (*grad_output * (*value_update_gate)); - *grad_prev_out += *grad_output; - *grad_frame_state = activation(*grad_output * (*value_update_gate), - *value_frame_state, act_input); + T *grad_output, ActivationType act_input, + bool origin_mode) { + if (origin_mode) { + *grad_update_gate = + (*grad_output) * ((*value_prev_out) - (*value_frame_state)); + *grad_prev_out += (*grad_output * (*value_update_gate)); + *grad_frame_state = activation( + *grad_output * (static_cast(1.0) - (*value_update_gate)), + *value_frame_state, act_input); + } else { + *grad_update_gate = + (*grad_output) * ((*value_frame_state) - (*value_prev_out)); + *grad_prev_out += + (*grad_output * (static_cast(1.0) - *value_update_gate)); + *grad_frame_state = activation(*grad_output * (*value_update_gate), + *value_frame_state, act_input); + } } #ifndef __NVCC__ #ifndef __AVX__ @@ -106,17 +131,27 @@ class gru_stateGrad { __m256 *value_frame_state, __m256 *grad_frame_state, __m256 *value_prev_out, __m256 *grad_prev_out, __m256 *grad_output, - ActivationType act_input) { - *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state); - *grad_update_gate = _mm256_sub_ps( - *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out)); - *grad_prev_out = _mm256_add_ps( - _mm256_sub_ps(*grad_prev_out, - _mm256_mul_ps(*grad_output, *value_update_gate)), - *grad_output); - *grad_frame_state = - activation(_mm256_mul_ps(*grad_output, *value_update_gate), - *value_frame_state, act_input); + ActivationType act_input, bool origin_mode) { + if (origin_mode) { + *grad_update_gate = _mm256_mul_ps( + *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state)); + *grad_prev_out = _mm256_add_ps( + *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate)); + *grad_frame_state = activation( + _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f), + *value_update_gate)), + *value_frame_state, act_input); + } else { + *grad_update_gate = _mm256_mul_ps( + *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out)); + *grad_prev_out = _mm256_add_ps( + *grad_prev_out, + _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f), + *value_update_gate))); + *grad_frame_state = + activation(_mm256_mul_ps(*grad_output, *value_update_gate), + *value_frame_state, act_input); + } } #endif #endif diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index 0e15b81deef43a932d4b2d3f545393b0ad9e080c..07c5cbf33378e6f6cee8a82448f55399966a2574 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -23,7 +23,8 @@ struct GRUUnitFunctor { static void compute(const platform::CPUDeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { #ifndef __NVCC__ auto blas = math::GetBlas(context); if (value.prev_out_value) { @@ -43,7 +44,8 @@ 
struct GRUUnitFunctor { } detail::forward_final_output(detail::forward::gru_finalOutput(), value, - frame_size, batch_size, active_node); + frame_size, batch_size, active_node, + origin_mode); #endif } }; @@ -54,10 +56,12 @@ struct GRUUnitGradFunctor { GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, - grad, frame_size, batch_size, active_node); + grad, frame_size, batch_size, active_node, + origin_mode); auto blas = math::GetBlas(context); if (value.prev_out_value && grad.prev_out_grad) { blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index 1327d914952d57aab6e5d17090d0ea976a6d4755..ec7e4d2228c38161bb1f3f97ec21b91db454adb4 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -24,7 +24,8 @@ struct GRUUnitFunctor { static void compute(const platform::CUDADeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -73,14 +74,14 @@ struct GRUUnitFunctor { T><<>>( detail::forward::gru_finalOutput(), value.gate_value, value.prev_out_value, value.output_value, frame_size, batch_size, - active_node); + active_node, origin_mode); } else { detail::KeGruForwardFinalOutput, /* is_batch= */ true, T><<>>( detail::forward::gru_finalOutput(), value.gate_value, value.prev_out_value, value.output_value, frame_size, batch_size, - active_node); + active_node, origin_mode); } } }; @@ -91,7 +92,8 @@ struct GRUUnitGradFunctor { GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -111,14 +113,14 @@ struct GRUUnitGradFunctor { /* is_batch= */ false><<>>( detail::backward::gru_stateGrad(), value.gate_value, grad.gate_grad, value.prev_out_value, grad.prev_out_grad, - grad.output_grad, frame_size, batch_size, active_node); + grad.output_grad, frame_size, batch_size, active_node, origin_mode); } else { detail::KeGruBackwardStateGrad< detail::backward::gru_stateGrad, /* is_batch= */ true><<>>( detail::backward::gru_stateGrad(), value.gate_value, grad.gate_grad, value.prev_out_value, grad.prev_out_grad, - grad.output_grad, frame_size, batch_size, active_node); + grad.output_grad, frame_size, batch_size, active_node, origin_mode); } auto blas = math::GetBlas(context); diff --git a/paddle/fluid/operators/math/gru_compute.h b/paddle/fluid/operators/math/gru_compute.h index c5816b16cd90410fcc48929931c25d0d561ad653..f5ddec0aaa275a32a5a9937699066a170edc0825 100644 --- a/paddle/fluid/operators/math/gru_compute.h +++ b/paddle/fluid/operators/math/gru_compute.h @@ -44,7 +44,8 @@ struct GRUUnitFunctor { static void compute(const DeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate); + const detail::ActivationType active_gate, + bool origin_mode); 
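For reference, the two update rules that gru_finalOutput switches between differ only in which convex weight multiplies the previous hidden state, and the backward branches above follow by differentiating each rule. A minimal NumPy sketch (illustrative only, not part of the patch):

```python
import numpy as np

def gru_final_output(u, c, h_prev, origin_mode):
    # u: update gate, c: candidate (frame) state, h_prev: previous hidden state
    if origin_mode:
        # origin_mode=True (Cho et al., 2014): h_t = u_t * h_{t-1} + (1 - u_t) * c_t
        return u * h_prev + (1.0 - u) * c
    # origin_mode=False (default): h_t = (1 - u_t) * h_{t-1} + u_t * c_t
    return (1.0 - u) * h_prev + u * c

u, c, h = np.random.rand(4), np.random.randn(4), np.random.randn(4)
# The two modes coincide after substituting u -> 1 - u:
assert np.allclose(gru_final_output(u, c, h, True),
                   gru_final_output(1.0 - u, c, h, False))
```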
}; template @@ -52,7 +53,8 @@ struct GRUUnitGradFunctor { static void compute(const DeviceContext &context, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate); + const detail::ActivationType active_gate, + bool origin_mode); }; } // namespace math diff --git a/paddle/fluid/operators/math/tree2col.cc b/paddle/fluid/operators/math/tree2col.cc new file mode 100644 index 0000000000000000000000000000000000000000..05ce5bc7a205ae51ae147450e7c0f23ee0fe28e2 --- /dev/null +++ b/paddle/fluid/operators/math/tree2col.cc @@ -0,0 +1,197 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/math/tree2col.h" +#include +#include + +namespace paddle { +namespace operators { +namespace math { +using Tensor = framework::Tensor; +std::vector Tree2ColUtil::construct_patch( + size_t root, int max_depth, const std::vector> &tr) { + std::stack> stack; + std::unordered_map visited; + std::vector patch; + + stack.push(TreeNode(root, 1, 1, 0)); + patch.emplace_back(TreeNode(root, 1, 1, 0)); + visited[root] = true; + + while (!stack.empty()) { + TreeNode &u = stack.top(); + bool end = true; + size_t node = u.get_node(), sz = tr[node].size(); + visited[node] = true; + for (size_t i = 0; i < sz; i++) { + size_t v = tr[node][i]; + if (!visited[v] && static_cast(u.get_depth()) + 1 < max_depth) { + visited[v] = true; + stack.push(TreeNode(v, i, sz, u.get_depth() + 1)); + patch.push_back(TreeNode(v, i + 1, sz, u.get_depth() + 1)); + end = false; + } + } + if (end) { + stack.pop(); + } + } + return patch; +} + +void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet, + std::vector> *tr, + size_t *node_count) { + auto edge_set_dims = EdgeSet.dims(); + PADDLE_ENFORCE_EQ(edge_set_dims[1], 2); + int64_t edge_count = EdgeSet.numel(); + + const int *edge_data = EdgeSet.data(); + + for (int64_t i = 0; i < edge_count; i += 2) { + int u = edge_data[i], v = edge_data[i + 1]; + if (u != 0 && v != 0) (*node_count)++; + } + (*node_count)++; + + tr->resize(static_cast(*node_count + 1)); + + for (int64_t i = 0; i < edge_count; i += 2) { + int u = edge_data[i], v = edge_data[i + 1]; + if (u != 0 && v != 0) { + tr->at(u).push_back(v); + } else { + break; + } + } +} + +template +class Tree2ColFunctor { + public: + void operator()(const platform::CPUDeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &node_features, + framework::Tensor *patch, int max_depth) { + std::vector> tr; + auto feature_dims = node_features.dims(); + auto cpu_place = boost::get(context.GetPlace()); + math::SetConstant constant; + int64_t feature_size = feature_dims[1]; + size_t patch_elem_size = 3 * static_cast(feature_size); + size_t node_count = 0, patch_count = 0, patch_size; + Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); + std::vector> processing_list; + for (size_t u = 1; u <= node_count; u++) { + 
std::vector temp_patch = + Tree2ColUtil::construct_patch(u, max_depth, tr); + if (!temp_patch.empty()) { + processing_list.emplace_back(temp_patch); + } + } + patch_size = processing_list.size(); + + T *patch_data = + patch->mutable_data({static_cast(patch_size), + static_cast(patch_elem_size)}, + cpu_place); + constant(context, patch, 0); + const T *features = node_features.data(); + + for (auto &patch_item : processing_list) { + size_t pointer_base = patch_count * patch_elem_size; + for (auto &v : patch_item) { + T eta_l = v.eta_l(max_depth), eta_r = v.eta_r(max_depth), + eta_t = v.eta_t(max_depth); + size_t id = v.get_node() - 1; + for (int i = 0; i < feature_size; i++) { + patch_data[pointer_base + i * 3] += + eta_l * features[id * feature_size + i]; + patch_data[pointer_base + i * 3 + 1] += + eta_r * features[id * feature_size + i]; + patch_data[pointer_base + i * 3 + 2] += + eta_t * features[id * feature_size + i]; + } + } + patch_count++; + } + patch->Resize({static_cast(patch_count), + static_cast(patch_elem_size)}); + } +}; +template +class Col2TreeFunctor { + public: + void operator()(const platform::CPUDeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &out_grad, framework::Tensor *in_grad, + int max_depth) { + std::vector> tr; + auto output_dims = out_grad.dims(); + auto cpu_place = boost::get(context.GetPlace()); + math::SetConstant constant; + int64_t output_size = output_dims[1]; + size_t grad_elem_size = 3 * static_cast(output_size); + size_t node_count = 0, grad_count = 0; + Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); + std::vector> processing_list; + std::vector> grad_list; + grad_list.resize(node_count); + for (size_t u = 1; u <= node_count; u++) { + std::vector tmp = + Tree2ColUtil::construct_patch(u, max_depth, tr); + if (!tmp.empty()) { + processing_list.push_back(tmp); + } + } + for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) { + for (auto v : processing_list[patch_id]) { + grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1)); + } + } + T *grad_data = + in_grad->mutable_data({static_cast(node_count), + static_cast(grad_elem_size)}, + cpu_place); + + constant(context, in_grad, 0); + const T *out_g = out_grad.data(); + for (auto &patch_item : grad_list) { + size_t pointer_base = grad_count * grad_elem_size; + for (auto &v : patch_item) { + T eta_l = v.eta_l(max_depth), eta_r = v.eta_r(max_depth), + eta_t = v.eta_t(max_depth); + size_t id = v.get_node() - 1; + for (int i = 0; i < output_size; i++) { + grad_data[pointer_base + i * 3] += + eta_l * out_g[id * output_size + i]; + grad_data[pointer_base + i * 3 + 1] += + eta_r * out_g[id * output_size + i]; + grad_data[pointer_base + i * 3 + 2] += + eta_t * out_g[id * output_size + i]; + } + } + grad_count++; + } + } +}; + +template class Tree2ColFunctor; +template class Tree2ColFunctor; +template class Col2TreeFunctor; +template class Col2TreeFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/tree2col.cu b/paddle/fluid/operators/math/tree2col.cu new file mode 100644 index 0000000000000000000000000000000000000000..3c50a525c2eac440ace3cb1d87af6abb3c5a9628 --- /dev/null +++ b/paddle/fluid/operators/math/tree2col.cu @@ -0,0 +1,208 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/tree2col.h" + +namespace paddle { +namespace operators { +namespace math { +using Tensor = framework::Tensor; +using Node = paddle::operators::math::TreeNode; +template +__global__ void tree2col(const T* eta, const int* node, const int* index, + const T* vectors, T* result, int feature_size, int n) { + const int thread_id = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + const int patch_id = thread_id / feature_size; + const int j = thread_id % feature_size; + if (patch_id < n) { + const int begin_o = patch_id * 3 * feature_size; + const int begin = index[patch_id * 2], end = index[patch_id * 2 + 1]; + T res_l = 0, res_r = 0, res_t = 0; + for (int i = begin; i < end; i++) { + const int id = node[i]; + const T vec = vectors[id * feature_size + j]; + res_l += eta[i * 3] * vec; + res_r += eta[i * 3 + 1] * vec; + res_t += eta[i * 3 + 2] * vec; + } + result[begin_o + j * 3] = res_l; + result[begin_o + j * 3 + 1] = res_r; + result[begin_o + j * 3 + 2] = res_t; + } +} +template +class Tree2ColFunctor { + public: + void operator()(const paddle::platform::CUDADeviceContext& context, + const framework::Tensor& EdgeSet, + const framework::Tensor& node_features, + framework::Tensor* patch, int max_depth) { + std::vector> tr; + auto gpu_place = boost::get(context.GetPlace()); + auto cpu_place = platform::CPUPlace(); + auto stream = context.stream(); + auto feature_dims = node_features.dims(); + math::SetConstant constant; + + Tensor EdgeSet_cpu; + framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu); + int64_t feature_size = feature_dims[1]; + size_t patch_elem_size = 3 * static_cast(feature_size); + size_t node_count = 0, patch_count = 0, total_size = 0; + size_t max_size = feature_dims[0]; + Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count); + + std::vector> processing_list; + for (size_t u = 1; u <= node_count; u++) { + std::vector tmp = Tree2ColUtil::construct_patch(u, max_depth, tr); + if (!tmp.empty()) { + processing_list.push_back(tmp); + total_size += tmp.size(); + } + } + + size_t patch_size = processing_list.size(); + Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu; + int* node = node_cpu.mutable_data({static_cast(total_size)}, + cpu_place); + T* eta = eta_cpu.mutable_data({static_cast(total_size * 3)}, + cpu_place); + int* index = index_cpu.mutable_data( + {static_cast(patch_size * 2)}, cpu_place); + + int idx = 0, index_idx = 0; + for (auto& tmp : processing_list) { + index[index_idx++] = idx; + for (auto& v : tmp) { + node[idx] = static_cast(v.node - 1); + eta[idx * 3] = v.eta_l(max_depth); + eta[idx * 3 + 1] = v.eta_r(max_depth); + eta[idx * 3 + 2] = v.eta_t(max_depth); + idx++; + } + index[index_idx++] = idx; + } + framework::TensorCopy(node_cpu, gpu_place, context, &node_gpu); + framework::TensorCopy(eta_cpu, gpu_place, context, &eta_gpu); + framework::TensorCopy(index_cpu, gpu_place, context, &index_gpu); + + int elem_size = patch_size * feature_size; + int blocks = (elem_size + 1024 - 1) / 1024; + int block_x = 
512; + int block_y = (blocks + 512 - 1) / 512; + dim3 threads(1024, 1); + dim3 grid(block_x, block_y); + + patch->mutable_data( + {static_cast(max_size), static_cast(patch_elem_size)}, + gpu_place); + constant(context, patch, 0); + tree2col<<>>( + eta_gpu.data(), node_gpu.data(), index_gpu.data(), + node_features.data(), patch->data(), feature_size, patch_size); + } +}; +template +class Col2TreeFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& EdgeSet, + const framework::Tensor& patch_grad, + framework::Tensor* embedding_grad, int max_depth) { + std::vector> tr; + auto gpu_place = boost::get(context.GetPlace()); + auto cpu_place = platform::CPUPlace(); + auto stream = context.stream(); + auto output_dims = patch_grad.dims(); + math::SetConstant constant; + + Tensor EdgeSet_cpu; + framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu); + int64_t output_size = output_dims[1]; + size_t patch_elem_size = 3 * static_cast(output_size); + size_t node_count = 0, patch_count = 0; + size_t max_size = output_dims[0]; + Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count); + std::vector> processing_list; + std::vector> grad_list; + grad_list.resize(node_count); + size_t total_size = 0, grad_size = node_count; + for (size_t u = 1; u <= node_count; u++) { + std::vector tmp = Tree2ColUtil::construct_patch(u, max_depth, tr); + if (!tmp.empty()) { + processing_list.push_back(tmp); + } + } + for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) { + for (auto v : processing_list[patch_id]) { + grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1)); + } + } + for (auto& tmp : grad_list) { + total_size += tmp.size(); + } + + Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu; + int* node = node_cpu.mutable_data({static_cast(total_size)}, + cpu_place); + T* eta = eta_cpu.mutable_data({static_cast(total_size * 3)}, + cpu_place); + int* index = index_cpu.mutable_data( + {static_cast(grad_size * 2)}, cpu_place); + + size_t idx = 0, index_idx = 0; + for (auto& tmp : grad_list) { + index[index_idx++] = idx; + for (auto& v : tmp) { + node[idx] = static_cast(v.node - 1); + eta[idx * 3] = v.eta_l(max_depth); + eta[idx * 3 + 1] = v.eta_r(max_depth); + eta[idx * 3 + 2] = v.eta_t(max_depth); + idx++; + } + index[index_idx++] = idx; + } + framework::TensorCopy(node_cpu, gpu_place, &node_gpu); + framework::TensorCopy(eta_cpu, gpu_place, &eta_gpu); + framework::TensorCopy(index_cpu, gpu_place, &index_gpu); + + int elem_size = output_size * grad_size; + int blocks = (elem_size + 1024 - 1) / 1024; + int block_x = 512; + int block_y = (blocks + 512 - 1) / 512; + dim3 threads(1024, 1); + dim3 grid(block_x, block_y); + + embedding_grad->mutable_data( + {static_cast(max_size), static_cast(patch_elem_size)}, + gpu_place); + + constant(context, embedding_grad, 0); + tree2col<<>>( + eta_gpu.data(), node_gpu.data(), index_gpu.data(), + patch_grad.data(), embedding_grad->data(), output_size, + grad_size); + } +}; + +template class Tree2ColFunctor; +template class Tree2ColFunctor; +template class Col2TreeFunctor; +template class Col2TreeFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/tree2col.h b/paddle/fluid/operators/math/tree2col.h new file mode 100644 index 0000000000000000000000000000000000000000..478ba78e259d327cc440f34161c8cf476109bb8c --- /dev/null +++ b/paddle/fluid/operators/math/tree2col.h @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +using Tensor = framework::Tensor; +using DDim = framework::DDim; +namespace operators { +namespace math { +class TreeNode { + public: + size_t node; + explicit TreeNode(size_t node = 0, size_t index = 0, size_t pclen = 0, + size_t depth = 0) + : node(node), index(index), pclen(pclen), depth(depth) {} + template + T eta_t(T filter_depth) { + return ((filter_depth - this->depth) / filter_depth); + } + template + T eta_l(T filter_depth) { + T temp; + if (this->pclen == 1) { + temp = 0.5; + } else { + temp = (this->index - 1.0) / (this->pclen - 1.0); + } + return (1.0 - this->eta_t(filter_depth)) * temp; + } + template + T eta_r(T filter_depth) { + return (1.0 - this->eta_t(filter_depth)) * + (1.0 - this->eta_l(filter_depth)); + } + TreeNode change_node(size_t v) { + return TreeNode(v, this->index, this->pclen, this->depth); + } + size_t get_node() { return this->node; } + size_t get_depth() { return this->depth; } + + private: + size_t index, pclen, depth; +}; +class Tree2ColUtil { + public: + static std::vector construct_patch( + size_t root, int max_depth, const std::vector> &tr); + + static void construct_tree(const Tensor &EdgeSet, + std::vector> *tr, + size_t *node_count); +}; + +template +class Tree2ColFunctor { + public: + void operator()(const DeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &node_features, + framework::Tensor *patch, int max_depth); +}; +template +class Col2TreeFunctor { + public: + void operator()(const DeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &out_grad, framework::Tensor *in_grad, + int max_depth); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..615ea285e54b97a8fb81acfef9bf0d18ac4e914d --- /dev/null +++ b/paddle/fluid/operators/tree_conv_op.cc @@ -0,0 +1,129 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
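The eta_t/eta_l/eta_r members of TreeNode in tree2col.h above implement the continuous-binary-tree weights of the TBCNN paper (arXiv:1409.5718). A Python mirror of those formulas (illustrative; the parameter names are descriptive, not from the patch):

```python
def eta_t(depth, filter_depth):
    # "top" weight: 1 for the patch root (depth 0), shrinking toward max depth
    return float(filter_depth - depth) / float(filter_depth)

def eta_l(depth, index, pclen, filter_depth):
    # "left" weight: sibling position (index is 1-based among pclen siblings)
    pos = 0.5 if pclen == 1 else (index - 1.0) / (pclen - 1.0)
    return (1.0 - eta_t(depth, filter_depth)) * pos

def eta_r(depth, index, pclen, filter_depth):
    # "right" weight: the remainder of the non-top mass
    return (1.0 - eta_t(depth, filter_depth)) * \
           (1.0 - eta_l(depth, index, pclen, filter_depth))

# A lone child at depth 1 under a depth-2 filter:
assert eta_t(1, 2) == 0.5 and eta_l(1, 1, 1, 2) == 0.25
```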
+ +#include "paddle/fluid/operators/tree_conv_op.h" +#include + +namespace paddle { +namespace operators { +class TreeConvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("NodesVector", + "(Tensor) The feature vector of every node on the tree. " + "The shape of the feature vector must be " + "[max_tree_node_size, feature_size]."); + AddInput("EdgeSet", + "(Tensor) The Edges of Tree. The edge must be directional. " + "The shape of the edge set must be [max_tree_node_size, 2]."); + AddInput("Filter", + "(Tensor) The feature detector. " + "The shape of the filter is " + "[feature_size, 3, output_size, num_filters]."); + AddOutput("Out", + "(Tensor) The feature vector of subtrees. " + "The shape of the output tensor is [max_tree_node_size, " + "output_size, num_filters]. " + "The output tensor could be a new feature " + "vector for next tree convolution layers."); + AddAttr("max_depth", + "(int, default: 2) The depth of feature detector.") + .SetDefault(2) + .GreaterThan(1); + AddComment(R"DOC( +**Tree-Based Convolution Operator** + +Tree-Based Convolution is a kind of convolution based on tree structure. +Tree-Based Convolution is a part of Tree-Based Convolution Neural Network(TBCNN), +which is used to classify tree structures, such as Abstract Syntax Tree. +Tree-Based Convolution proposed a kind of data structure called continuous binary tree, +which regards multiway tree as binary tree. +The paper of Tree-Based Convolution Operator is here: +https://arxiv.org/abs/1409.5718v1 +)DOC"); + } +}; +class TreeConvOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out")); + auto edge_dims = ctx->GetInputDim("EdgeSet"); + auto vector_dims = ctx->GetInputDim("NodesVector"); + auto filter_dims = ctx->GetInputDim("Filter"); + PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2"); + PADDLE_ENFORCE_EQ(edge_dims.size(), 3, + "The dimension of EdgeSet Tensor should be 3"); + PADDLE_ENFORCE_EQ(vector_dims.size(), 3, + "The dimension of NodesVector Tensor should be 3"); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, + "The dimension of Filter Tensor should be 4"); + PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3"); + PADDLE_ENFORCE_EQ( + filter_dims[0], vector_dims[2], + "Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"); + auto output_dims = framework::make_ddim( + {vector_dims[0], vector_dims[1], filter_dims[2], filter_dims[3]}); + ctx->SetOutputDim("Out", output_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("NodesVector")->type(), + ctx.device_context()); + } +}; + +class TreeConvGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto vectors_dims = ctx->GetInputDim("NodesVector"); + auto filter_dims = ctx->GetInputDim("Filter"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "the gradient of output(Out) must not be null"); + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } + if (ctx->HasOutput(framework::GradVarName("NodesVector"))) { + 
ctx->SetOutputDim(framework::GradVarName("NodesVector"), vectors_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("NodesVector")->type(), + ctx.device_context()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(tree_conv, ops::TreeConvOp, ops::TreeConvOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OPERATOR(tree_conv_grad, ops::TreeConvGradOp); + +REGISTER_OP_CPU_KERNEL( + tree_conv, ops::TreeConvKernel, + ops::TreeConvKernel); + +REGISTER_OP_CPU_KERNEL( + tree_conv_grad, + ops::TreeConvGradKernel, + ops::TreeConvGradKernel); diff --git a/paddle/fluid/operators/tree_conv_op.cu b/paddle/fluid/operators/tree_conv_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..eebfe412bdd65139d9657aae78288f66d9d7bc06 --- /dev/null +++ b/paddle/fluid/operators/tree_conv_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/tree_conv_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + tree_conv, ops::TreeConvKernel, + ops::TreeConvKernel); +REGISTER_OP_CUDA_KERNEL( + tree_conv_grad, + ops::TreeConvGradKernel, + ops::TreeConvGradKernel); diff --git a/paddle/fluid/operators/tree_conv_op.h b/paddle/fluid/operators/tree_conv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a84589b32fd0016e0372c50aac8156b2dce883ba --- /dev/null +++ b/paddle/fluid/operators/tree_conv_op.h @@ -0,0 +1,146 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
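TreeConvOp::InferShape above maps NodesVector [batch, max_tree_node_size, feature_size] and Filter [feature_size, 3, output_size, num_filters] to Out [batch, max_tree_node_size, output_size, num_filters]. A small sketch of that check (hypothetical helper, not part of the patch):

```python
def tree_conv_out_shape(nodes_vector_dims, filter_dims):
    batch, max_nodes, feature_size = nodes_vector_dims
    f, three, output_size, num_filters = filter_dims
    assert three == 3, "Filter dim[1] should be 3"
    assert f == feature_size, "Filter dim[0] must equal NodesVector dim[2]"
    return [batch, max_nodes, output_size, num_filters]

assert tree_conv_out_shape([2, 10, 5], [5, 3, 6, 1]) == [2, 10, 6, 1]
```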
+ +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/tree2col.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; +template +class TreeConvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + math::Tree2ColFunctor tree2col; + math::SetConstant constant; + + auto *Edges = ctx.Input("EdgeSet"); + auto *Embeddings = ctx.Input("NodesVector"); + auto *Filter = ctx.Input("Filter"); + auto *output_emb = ctx.Output("Out"); + int max_depth = ctx.Attr("max_depth"); + + auto &dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + + Tensor W; + W.ShareDataWith(*Filter); + W.Resize(framework::flatten_to_2d(Filter->dims(), 2)); + + int batch_size = static_cast(Edges->dims()[0]); + int n = static_cast(Embeddings->dims()[1]); + int out_size = static_cast(Filter->dims()[2]); + int num_filters = static_cast(Filter->dims()[3]); + output_emb->mutable_data({batch_size, n, out_size, num_filters}, + ctx.GetPlace()); + + auto edge_set_slicedim = framework::slice_ddim( + Edges->dims(), 1, static_cast(Edges->dims().size())); + + auto embedding_slicedim = framework::slice_ddim( + Embeddings->dims(), 1, static_cast(Embeddings->dims().size())); + + auto output_slicedim = framework::slice_ddim( + output_emb->dims(), 1, static_cast(output_emb->dims().size())); + + output_slicedim = framework::flatten_to_2d(output_slicedim, 1); + + for (int idx = 0; idx < batch_size; idx++) { + auto edge_set = Edges->Slice(idx, idx + 1).Resize(edge_set_slicedim); + auto embeddings = + Embeddings->Slice(idx, idx + 1).Resize(embedding_slicedim); + auto out_vec = output_emb->Slice(idx, idx + 1).Resize(output_slicedim); + Tensor patch; + tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth); + constant(dev_ctx, &out_vec, 0); + blas.MatMul(patch, W, &out_vec); + } + } +}; +template +class TreeConvGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *out_g = ctx.Input(framework::GradVarName("Out")); + auto *in_g = ctx.Output(framework::GradVarName("NodesVector")); + auto *filter_g = ctx.Output(framework::GradVarName("Filter")); + int max_depth = ctx.Attr("max_depth"); + auto *Embeddings = ctx.Input("NodesVector"); + auto *edges = ctx.Input("EdgeSet"); + auto *Filter = ctx.Input("Filter"); + math::Tree2ColFunctor tree2col; + math::Col2TreeFunctor col2tree; + math::SetConstant constant; + auto &dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + + Tensor W; + W.ShareDataWith(*Filter); + W.Resize(framework::flatten_to_2d(Filter->dims(), 1)); + + int batch_size = static_cast(Embeddings->dims()[0]); + + auto edge_set_slicedim = framework::slice_ddim( + edges->dims(), 1, static_cast(edges->dims().size())); + + auto embedding_slicedim = framework::slice_ddim( + Embeddings->dims(), 1, static_cast(Embeddings->dims().size())); + + auto out_grad_dims = framework::slice_ddim( + out_g->dims(), 1, static_cast(out_g->dims().size())); + out_grad_dims = framework::flatten_to_2d(out_grad_dims, 1); + if (filter_g) { + filter_g->mutable_data(Filter->dims(), ctx.GetPlace()); + Tensor f_g; + f_g.ShareDataWith(*filter_g); + f_g.Resize(framework::flatten_to_2d(Filter->dims(), 2)); + constant(dev_ctx, filter_g, 0); + for (int batch_id = 0; batch_id < batch_size; batch_id++) { + auto edge_set = + 
edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim); + auto embeddings = Embeddings->Slice(batch_id, batch_id + 1) + .Resize(embedding_slicedim); + auto out_grad = + out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims); + Tensor patch; + tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth); + blas.MatMul(patch, true, out_grad, false, T(1.0), &f_g, T(1.0)); + } + } + if (in_g) { + auto input_grad_dims = framework::slice_ddim( + in_g->dims(), 1, static_cast(in_g->dims().size())); + in_g->mutable_data(Embeddings->dims(), ctx.GetPlace()); + constant(dev_ctx, in_g, 0); + for (int batch_id = 0; batch_id < batch_size; batch_id++) { + auto edge_set = + edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim); + auto out_grad = + out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims); + auto in_grad = + in_g->Slice(batch_id, batch_id + 1).Resize(input_grad_dims); + Tensor in_grad_temp; + col2tree(dev_ctx, edge_set, out_grad, &in_grad_temp, max_depth); + blas.MatMul(in_grad_temp, false, W, true, &in_grad); + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 4ca6a5170eb57b0d799159b7ecc55c2389246041..25f95ffbb0acf618f19b36987093d5884369e530 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -200,7 +200,6 @@ class AsyncExecutor(object): local_path, self.instance.get_worker_index(), self.instance.get_node_cnt() / 2, - file_cnt, multi_processes=process_num) self.instance.barrier_worker() #wait for download_data diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 56971cff43b919f946091cebf909731a23d60736..ea88d8b4d09ab5816eabdfb978ee31143b9da5cd 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -183,6 +183,7 @@ __all__ = [ 'psroi_pool', 'teacher_student_sigmoid_loss', 'huber_loss', + 'tree_conv', ] kIgnoreIndex = -100 @@ -864,12 +865,14 @@ def dynamic_gru(input, is_reverse=False, gate_activation='sigmoid', candidate_activation='tanh', - h_0=None): + h_0=None, + origin_mode=False): """ **Gated Recurrent Unit (GRU) Layer** - Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on - Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ . + If origin_mode is False, the equation of a GRU step comes from the paper + `Empirical Evaluation of Gated Recurrent Neural Networks on Sequence + Modeling <https://arxiv.org/abs/1412.3555>`_ . The formula is as follows: @@ -883,6 +886,21 @@ def dynamic_gru(input, h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t} + + If origin_mode is True, the equation comes from the paper + `Learning Phrase Representations using RNN Encoder-Decoder for Statistical + Machine Translation <https://arxiv.org/abs/1406.1078>`_ + + .. math:: + + u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u) + + r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r) + + \\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c) + + h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \\tilde{h_t} + The :math:`\odot` is the element-wise product of the vectors. :math:`act_g` is the update gate and reset gate activation function and :math:`sigmoid` is usually used for it.
:math:`act_c` is the activation function for @@ -980,7 +998,8 @@ def dynamic_gru(input, attrs={ 'is_reverse': is_reverse, 'gate_activation': gate_activation, - 'activation': candidate_activation + 'activation': candidate_activation, + 'origin_mode': origin_mode }) return hidden @@ -991,9 +1010,14 @@ def gru_unit(input, param_attr=None, bias_attr=None, activation='tanh', - gate_activation='sigmoid'): + gate_activation='sigmoid', + origin_mode=False): """ - GRU unit layer. The equation of a gru step is: + **GRU unit layer** + + If origin_mode is True, the equation of a GRU step comes from the paper + `Learning Phrase Representations using RNN Encoder-Decoder for Statistical + Machine Translation <https://arxiv.org/abs/1406.1078>`_ .. math:: u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u) @@ -1002,7 +1026,21 @@ def gru_unit(input, m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m) - h_t & = dot((1-u_t), m_t) + dot(u_t, h_{t-1}) + h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t) + + If origin_mode is False, the equation of a GRU step comes from the paper + `Empirical Evaluation of Gated Recurrent Neural Networks on Sequence + Modeling <https://arxiv.org/abs/1412.3555>`_ + + .. math:: + u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u) + + r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r) + + m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m) + + h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t) + The inputs of the GRU unit include :math:`z_t` and :math:`h_{t-1}`. In terms of the equation above, the :math:`z_t` is split into 3 parts - @@ -9893,3 +9931,73 @@ def huber_loss(input, label, delta): 'Residual': residual}, attrs={'delta': delta}) return out + + +@templatedoc() +def tree_conv(nodes_vector, + edge_set, + output_size, + num_filters=1, + max_depth=2, + act='tanh', + param_attr=None, + bias_attr=None, + name=None): + """ + ${comment} + + Args: + nodes_vector(${nodes_vector_type}): ${nodes_vector_comment} + edge_set(${edge_set_type}): ${edge_set_comment} + output_size(int): output feature width + num_filters(int): number of filters, Default 1 + max_depth(int): max depth of filters, Default 2 + act(str): activation function, Default tanh + param_attr(ParamAttr): the parameter attribute for the filters, Default None + bias_attr(ParamAttr): the parameter attribute for the bias of this layer, Default None + name(str): a name of this layer (optional). If set to None, the layer will be named automatically, Default None + + Returns: + out(${out_type}): ${out_comment} + + Examples: + ..
code-block:: python + + nodes_vector = layers.data(name='vectors', shape=[None, 10, 5], dtype='float32') + # None for batch size, 10 for max_node_size of dataset, 5 for vector width + edge_set = layers.data(name='edge_set', shape=[None, 10, 2], dtype='int32') + # None for batch size, 10 for max_node_size of dataset, 2 since every edge joins two nodes + # edges must be directional + out_vector = layers.tree_conv(nodes_vector, edge_set, 6, 1, 2, 'tanh', + ParamAttr(initializer=Constant(1.0)), ParamAttr(initializer=Constant(1.0))) + # the shape of output will be [None, 10, 6, 1], + # None for batch size, 10 for max_node_size of dataset, 6 for output size, 1 for 1 filter + out_vector = layers.reshape(out_vector, shape=[None, 10, 6]) + # after reshape, the output tensor can serve as nodes_vector for the next tree_conv layer + out_vector_2 = layers.tree_conv(out_vector, edge_set, 3, 4, 2, 'tanh', + ParamAttr(initializer=Constant(1.0)), ParamAttr(initializer=Constant(1.0))) + # the output tensor can also be pooled (the paper calls this global pooling) + pooled = layers.reduce_max(out_vector, dim=2) # global pooling + """ + helper = LayerHelper("tree_conv", **locals()) + dtype = helper.input_dtype('nodes_vector') + feature_size = nodes_vector.shape[2] + W_shape = [feature_size, 3, output_size, num_filters] + W = helper.create_parameter( + attr=param_attr, shape=W_shape, dtype=dtype, is_bias=False) + if name is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_variable(name=name, dtype=dtype, persistable=False) + helper.append_op( + type='tree_conv', + inputs={'NodesVector': nodes_vector, + 'EdgeSet': edge_set, + 'Filter': W}, + outputs={'Out': out, }, + attrs={'max_depth': max_depth}) + if helper.bias_attr: + pre_activation = helper.append_bias_op(out) + else: + pre_activation = out + return helper.append_activation(pre_activation) diff --git a/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py b/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py index 82193737967b2bebdd17cef8752eeb9cec2e85ce..07afa742c6b7d28b129192e4b9ffc41a405d3367 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py +++ b/python/paddle/fluid/tests/book/high-level-api/recommender_system/test_recommender_system_newapi.py @@ -231,14 +231,17 @@ def infer(use_cuda, inference_program, params_dirname): # Correspondingly, recursive_sequence_lengths = [[3, 2]] contains one # level of detail info, indicating that `data` consists of two sequences # of length 3 and 2, respectively.
- user_id = fluid.create_lod_tensor([[1]], [[1]], place) - gender_id = fluid.create_lod_tensor([[1]], [[1]], place) - age_id = fluid.create_lod_tensor([[0]], [[1]], place) - job_id = fluid.create_lod_tensor([[10]], [[1]], place) - movie_id = fluid.create_lod_tensor([[783]], [[1]], place) - category_id = fluid.create_lod_tensor([[10, 8, 9]], [[3]], place) - movie_title = fluid.create_lod_tensor([[1069, 4140, 2923, 710, 988]], [[5]], - place) + user_id = fluid.create_lod_tensor([[np.int64(1)]], [[1]], place) + gender_id = fluid.create_lod_tensor([[np.int64(1)]], [[1]], place) + age_id = fluid.create_lod_tensor([[np.int64(0)]], [[1]], place) + job_id = fluid.create_lod_tensor([[np.int64(10)]], [[1]], place) + movie_id = fluid.create_lod_tensor([[np.int64(783)]], [[1]], place) + category_id = fluid.create_lod_tensor( + [np.array( + [10, 8, 9], dtype='int64')], [[3]], place) + movie_title = fluid.create_lod_tensor( + [np.array( + [1069, 4140, 2923, 710, 988], dtype='int64')], [[5]], place) results = inferencer.infer( { diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py index cf8c48f34697d789d3d81d4d94f90a7169657baf..0e1efc8212ec2913ca3653c47bd2d9e298a772ee 100644 --- a/python/paddle/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/fluid/tests/book/test_recommender_system.py @@ -271,26 +271,30 @@ def infer(use_cuda, save_dirname=None): # Correspondingly, recursive_sequence_lengths = [[3, 2]] contains one # level of detail info, indicating that `data` consists of two sequences # of length 3 and 2, respectively. - user_id = fluid.create_lod_tensor([[1]], [[1]], place) + user_id = fluid.create_lod_tensor([[np.int64(1)]], [[1]], place) assert feed_target_names[1] == "gender_id" - gender_id = fluid.create_lod_tensor([[1]], [[1]], place) + gender_id = fluid.create_lod_tensor([[np.int64(1)]], [[1]], place) assert feed_target_names[2] == "age_id" - age_id = fluid.create_lod_tensor([[0]], [[1]], place) + age_id = fluid.create_lod_tensor([[np.int64(0)]], [[1]], place) assert feed_target_names[3] == "job_id" - job_id = fluid.create_lod_tensor([[10]], [[1]], place) + job_id = fluid.create_lod_tensor([[np.int64(10)]], [[1]], place) assert feed_target_names[4] == "movie_id" - movie_id = fluid.create_lod_tensor([[783]], [[1]], place) + movie_id = fluid.create_lod_tensor([[np.int64(783)]], [[1]], place) assert feed_target_names[5] == "category_id" - category_id = fluid.create_lod_tensor([[10, 8, 9]], [[3]], place) + category_id = fluid.create_lod_tensor( + [np.array( + [10, 8, 9], dtype='int64')], [[3]], place) assert feed_target_names[6] == "movie_title" - movie_title = fluid.create_lod_tensor([[1069, 4140, 2923, 710, 988]], - [[5]], place) + movie_title = fluid.create_lod_tensor( + [np.array( + [1069, 4140, 2923, 710, 988], dtype='int64')], [[5]], + place) # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. 
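The explicit int64 casts in the two recommender-system tests above, like the .astype("int64") changes in the unit tests below, pin the feed data to int64: NumPy builds arrays from plain Python int lists using the platform's C long, which is int32 on Windows and then fails the int64 dtype checks of these ops. A small illustration (assumption: the surrounding feed slots are int64 variables):

```python
import numpy as np

# int64 on most Linux/macOS builds, but int32 on Windows:
implicit = np.array([1069, 4140, 2923, 710, 988])
explicit = np.array([1069, 4140, 2923, 710, 988], dtype='int64')
assert explicit.dtype == np.int64  # always holds; implicit.dtype may not
```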
diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 810e8a1a8547a92de923877695178e780981edeb..b75abd424a46976e3426c16b91fb977b3be2e94f 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -24,7 +24,7 @@ class TestAucOp(OpTest): def setUp(self): self.op_type = "auc" pred = np.random.random((128, 2)).astype("float32") - labels = np.random.randint(0, 2, (128, 1)) + labels = np.random.randint(0, 2, (128, 1)).astype("int64") num_thresholds = 200 stat_pos = np.zeros((num_thresholds + 1, )).astype("int64") diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index f61a447fd77d001c1a9c42c46ba9a0d747e926ec..6606162733487b15ef55f1a4677fb382e6e7e0ac 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -31,7 +31,8 @@ def gru( is_reverse, act_state, act_gate, - dtype='float32'): + dtype='float32', + origin_mode=False): def _seq_to_batch(lod, is_reverse): idx_in_seq_list = [] seq_lens = lod[0] @@ -66,7 +67,10 @@ def gru( w_c = w.flatten()[D * D * 2:].reshape((D, D)) c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:]) g = np.hstack((u_r, c)) - h = u * c + (1 - u) * h_p + if origin_mode: + h = (1 - u) * c + u * h_p + else: + h = u * c + (1 - u) * h_p return g, r_h_p, h T = sum(lod[0]) @@ -110,6 +114,7 @@ class TestGRUOp(OpTest): self.act_state = 'tanh' self.act_gate = 'sigmoid' self.dtype = 'float64' + self.origin_mode = False self.set_confs() T = sum(self.lod[0]) @@ -126,7 +131,8 @@ class TestGRUOp(OpTest): batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru( input, self.lod, h0, weight, bias, self.is_reverse, - ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype) + ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype, + self.origin_mode) self.inputs = {'Input': (input, self.lod), 'Weight': weight} if self.with_bias: @@ -145,7 +151,8 @@ class TestGRUOp(OpTest): self.attrs = { 'activation': self.act_state, 'gate_activation': self.act_gate, - 'is_reverse': self.is_reverse + 'is_reverse': self.is_reverse, + 'origin_mode': self.origin_mode } def test_check_output(self): @@ -155,12 +162,24 @@ class TestGRUOp(OpTest): self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) +class TestGRUOriginMode(TestGRUOp): + def set_confs(self): + self.origin_mode = True + + class TestGRUOp2(TestGRUOp): def set_confs(self): self.D = 19 self.dtype = 'float32' +class TestGRUOp2OriginMode(TestGRUOp): + def set_confs(self): + self.D = 19 + self.dtype = 'float32' + self.origin_mode = True + + class TestGRUOpNoInitial(TestGRUOp): def set_confs(self): self.with_h0 = False @@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp): self.is_reverse = True +class TestGRUOpReverseOriginMode(TestGRUOp): + def set_confs(self): + self.is_reverse = True + self.origin_mode = True + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py index b5a66fdf086f2abc0c9a8af663241b9eda739407..78f2f030f5b6dc9d827f3930dff590a0f5b784fb 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_unit_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_unit_op.py @@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest): GRUActivationType.relu: relu, } - def set_inputs(self): + def set_inputs(self, origin_mode=False): batch_size = 
self.batch_size frame_size = self.frame_size self.op_type = 'gru_unit' @@ -68,10 +68,11 @@ class TestGRUUnitOp(OpTest): } self.attrs = { 'activation': GRUActivationType.tanh, - 'gate_activation': GRUActivationType.sigmoid + 'gate_activation': GRUActivationType.sigmoid, + 'origin_mode': origin_mode } - def set_outputs(self): + def set_outputs(self, origin_mode=False): # GRU calculations batch_size = self.batch_size frame_size = self.frame_size @@ -93,7 +94,10 @@ class TestGRUUnitOp(OpTest): c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + g[:, frame_size * 2:]) g = np.hstack((u_r, c)) - h = u * c + (1 - u) * h_p + if origin_mode: + h = (1 - u) * c + u * h_p + else: + h = u * c + (1 - u) * h_p self.outputs = { 'Gate': g.astype('float64'), 'ResetHiddenPrev': r_h_p.astype('float64'), @@ -111,8 +115,14 @@ class TestGRUUnitOp(OpTest): self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden']) +class TestGRUUnitOpOriginMode(TestGRUUnitOp): + def setUp(self): + self.set_inputs(origin_mode=True) + self.set_outputs(origin_mode=True) + + class TestGRUUnitOpWithBias(TestGRUUnitOp): - def set_inputs(self): + def set_inputs(self, origin_mode=False): batch_size = self.batch_size frame_size = self.frame_size super(TestGRUUnitOpWithBias, self).set_inputs() @@ -120,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): -0.1, 0.1, (1, frame_size * 3)).astype('float64') self.attrs = { 'activation': GRUActivationType.identity, - 'gate_activation': GRUActivationType.sigmoid + 'gate_activation': GRUActivationType.sigmoid, + 'origin_mode': origin_mode } def test_check_grad(self): @@ -132,5 +143,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): no_grad_set=set('Input')) +class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias): + def setUp(self): + self.set_inputs(origin_mode=True) + self.set_outputs(origin_mode=True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index f4f97446746836f4acf42fa662cc30b20af6e4b1..1e462d13d0755f48fd73a9eae335584858ecb17f 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -68,7 +68,8 @@ class TestNCE(OpTest): weight = np.random.randn(num_classes, dim).astype(np.float32) bias = np.random.randn(num_classes).astype(np.float32) sample_weight = np.random.randn(batch_size).astype(np.float32) - labels = np.random.randint(0, num_classes, (batch_size, num_true_class)) + labels = np.random.randint(0, num_classes, + (batch_size, num_true_class)).astype("int64") self.attrs = { 'num_total_classes': num_classes, 'num_neg_samples': num_neg_samples, diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py index ee0941f19838355180edb5771f1e85292a64de59..e0eba2147c6288e5b2f30373f610db78493d5e03 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py @@ -24,14 +24,14 @@ import os def Lenet(data, class_dim): - conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None) + conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) bn1 = fluid.layers.batch_norm(conv1, act='relu') pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2) - conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None) + conv2 = fluid.layers.conv2d(pool1, 16, 5, 1, act=None) bn2 = fluid.layers.batch_norm(conv2, act='relu') pool2 = 
fluid.layers.pool2d(bn2, 2, 'max', 2) - fc1 = fluid.layers.fc(pool2, size=500, act='relu') + fc1 = fluid.layers.fc(pool2, size=50, act='relu') fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax') return fc2 diff --git a/python/paddle/fluid/tests/unittests/test_tree_conv_op.py b/python/paddle/fluid/tests/unittests/test_tree_conv_op.py new file mode 100644 index 0000000000000000000000000000000000000000..712453d29101d0878eb362e60e22d6b21f0ba026 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tree_conv_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from op_test import OpTest + + +def collect_node_patch(og, max_depth): + """ + The naive method to construct patches + :param og: original graph + :param max_depth: the depth of convolution filters + :return: convolution patches + """ + + def gen(node, max_depth): + collected = [(node, 1, 1, 0, max_depth)] + + def recurse_helper(node, depth): + if depth > max_depth: + return + l = len(og[node]) + for idx, c in enumerate(og[node], 1): + if depth + 1 < max_depth: + collected.append((c, idx, l, depth + 1, max_depth)) + recurse_helper(c, depth + 1) + + recurse_helper(node, 0) + return collected + + res = [] + for u in range(1, len(og)): + lis = gen(u, max_depth) + if len(lis) > 0: + res.append(lis) + return res + + +class TestTreeConvOp(OpTest): + def setUp(self): + self.n = 17 + self.fea_size = 3 + self.output_size = 1 + self.max_depth = 2 + self.batch_size = 1 + self.num_filters = 1 + adj_array = [ + 1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10, 5, 11, 6, 12, + 6, 13, 9, 14, 9, 15, 9, 16, 9, 17 + ] + adj = np.array(adj_array).reshape((1, self.n - 1, 2)).astype('int32') + adj = np.tile(adj, (self.batch_size, 1, 1)) + self.op_type = 'tree_conv' + vectors = np.random.random( + (self.batch_size, self.n, self.fea_size)).astype('float32') + self.inputs = { + 'EdgeSet': adj, + 'NodesVector': vectors, + 'Filter': np.random.random((self.fea_size, 3, self.output_size, + self.num_filters)).astype('float32') + } + self.attrs = {'max_depth': self.max_depth} + vectors = [] + for i in range(self.batch_size): + vector = self.get_output_naive(i) + vectors.append(vector) + self.outputs = {'Out': np.array(vectors).astype('float32'), } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['NodesVector', 'Filter'], 'Out', max_relative_error=0.5) + + def get_output_naive(self, batch_id): + og = [[] for i in range(1, self.n + 2)] + st = np.array(self.inputs['EdgeSet'][batch_id]).tolist() + for e in st: + og[e[0]].append(e[1]) + patches = collect_node_patch(og, self.max_depth) + W = np.array(self.inputs['Filter']).astype('float32') + W = np.transpose(W, axes=[1, 0, 2, 3]) + vec = [] + for i, patch in enumerate(patches, 1): + result = np.zeros((1, W.shape[2], W.shape[3])) + for v in patch: + eta_t = float(v[4] - v[3]) / float(v[4]) + eta_l = (1.0 - eta_t) * (0.5 if v[2] == 1 else + 
float(v[1] - 1.0) / float(v[2] - 1.0)) + eta_r = (1.0 - eta_t) * (1.0 - eta_l) + x = self.inputs['NodesVector'][batch_id][v[0] - 1] + eta = np.array([eta_l, eta_r, eta_t]).reshape( + (3, 1)).astype('float32') + Wconvi = np.tensordot(eta, W, axes=([0], [0])) + x = np.array(x).reshape((1, 1, self.fea_size)) + res = np.tensordot(x, Wconvi, axes=2) + result = result + res + vec.append(result) + vec = np.concatenate(vec, axis=0) + vec = np.concatenate( + [ + vec, np.zeros( + (self.n - vec.shape[0], W.shape[2], W.shape[3]), + dtype='float32') + ], + axis=0) + return vec
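For every node v in a patch, get_output_naive above contracts that node's feature vector with an eta-blended filter and accumulates the results. A simplified standalone instance of one such contribution (values illustrative, not from the test):

```python
import numpy as np

F, out_size, num_filters = 3, 6, 1
W = np.random.random((3, F, out_size, num_filters))  # filter transposed to [3, F, out, filters]
x = np.random.random((F,))                           # one node's feature vector
eta = np.array([0.2, 0.3, 0.5])                      # (eta_l, eta_r, eta_t) for that node
Wconv = np.tensordot(eta, W, axes=([0], [0]))        # blend the 3 filter slots -> [F, out, filters]
res = np.tensordot(x, Wconv, axes=([0], [0]))        # contract the feature axis -> [out, filters]
assert res.shape == (out_size, num_filters)
```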