提交 8f6597aa 编写于 作者: L luotao1

Merge branch 'develop' into infershape_example

...@@ -174,7 +174,7 @@ else() ...@@ -174,7 +174,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
target_link_libraries(executor garbage_collector) target_link_libraries(executor garbage_collector while_op_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
......
...@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_ ...@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase { ...@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx); size_t scope_idx);
OperatorBase *GetOp() { return op_.get(); }
std::string Name() const override; std::string Name() const override;
const Scope *GetScope() const { return scope_; } const Scope *GetScope() const { return scope_; }
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle( ...@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
} }
} }
#endif #endif
PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
} }
EagerDeletionOpHandle::~EagerDeletionOpHandle() { EagerDeletionOpHandle::~EagerDeletionOpHandle() {
...@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { ...@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages; std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) { for (auto &name : var_names_) {
auto it = ref_cnts_->find(name); auto it = ref_cnts_->find(name);
// Var not found, not reference count has not decreased to 0 // Reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) { if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
continue; continue;
} }
if (!exec_scope) {
exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
// Var not found
auto *var = exec_scope->FindVar(name); auto *var = exec_scope->FindVar(name);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
......
...@@ -12,20 +12,173 @@ ...@@ -12,20 +12,173 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include <functional>
#include <queue> #include <queue>
#include <string> #include <string>
#include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
// Check whether the variable is LoDTensor based on static VarDesc info
static bool IsLoDTensor(VarDesc *var) {
return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
}
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
PADDLE_ENFORCE(IsLoDTensor(var_desc));
auto dims = var_desc->GetShape();
return SizeOfType(var_desc->GetDataType()) *
std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
// Split all variables in the graph into LoDTensor and Non-LoDTensor (e.g.
// SelectedRows, LoDTensorArray)
// Since partial GC is based on static analysis of memory size of each variable
// So we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
(*lod_tensors)[op_vars_pair.first].insert(var_name);
} else {
(*other_vars)[op_vars_pair.first].insert(var_name);
}
}
}
}
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
scope_idx_(scope_idx) {}
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
};
// Delete delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0
if (fraction_of_memory_size <= 0.0) return {};
/**
* Step 1: Split all variables into LoDTensor and Non-LoDTensor.
* We can only calculate memory size of LoDTensors
*/
OpToVarNameSetMap lod_tensors, other_vars;
SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
// Perform complete gc when fraction_of_memory_size >= 1
if (fraction_of_memory_size >= 1.0) {
return delete_lod_tensor_only ? lod_tensors : m;
}
/**
* Step 2: build GCVarInfos, and calculate total memory sizes of each device
*/
// place -> variable info (name, memory size, place, scope_idx)
std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
// place -> total memory sizes
std::map<platform::Place, int64_t> place_to_size;
for (auto &op_vars_pair : lod_tensors) {
auto *op = op_vars_pair.first;
auto &var_names = op_vars_pair.second;
auto scope_idx = op->GetScopeIdx();
auto &place = places[scope_idx];
for (auto &var_name : var_names) {
auto var_size = GetMemorySize(vars[scope_idx], var_name);
GCVarInfo var_info(var_name, var_size, op, scope_idx);
place_to_size[place] += var_info.AbsMemorySize();
place_to_vars[place].emplace_back(std::move(var_info));
}
}
/**
* Step 3: sort GCVarInfos, and only delete the largest variables.
*/
OpToVarNameSetMap partial_vars;
for (auto &place_to_var_pair : place_to_vars) {
auto &place = place_to_var_pair.first;
auto &gc_vars = place_to_var_pair.second;
std::sort(gc_vars.begin(), gc_vars.end(),
[](const GCVarInfo &var1, const GCVarInfo &var2) {
return var1.AbsMemorySize() > var2.AbsMemorySize();
});
int64_t accumulated_size = 0;
int64_t size_threshold =
static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
++i) {
partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
accumulated_size += gc_vars[i].AbsMemorySize();
}
}
/**
* Step 4: Combine other vars (SelectedRows, LoDTensorArray)
*/
if (!delete_lod_tensor_only) {
for (auto &op_vars_pair : other_vars) {
partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
op_vars_pair.second.end());
}
}
return partial_vars;
}
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = auto &ref_cnts =
...@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
// a reverse map of last_live_ops // a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted. // i.e., last op --> variable names which can be deleted.
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>> OpToVarNameSetMap op_vars_map;
op_vars_map;
for (auto &var_ops_map : last_live_ops) { for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) { for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first; const std::string &var_name = var_ops_pair.first;
...@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
} }
} }
op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
FLAGS_memory_fraction_of_eager_deletion);
for (auto &pair : op_vars_map) { for (auto &pair : op_vars_map) {
auto *op = pair.first; auto *op = pair.first;
auto &var_names = pair.second; auto &var_names = pair.second;
...@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op->AddOutput(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf);
} }
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
<< FLAGS_memory_fraction_of_eager_deletion;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)"; VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph;
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
} }
} // namespace details } // namespace details
...@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass, ...@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces) .RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector); .RequirePassAttr(paddle::framework::details::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
...@@ -12,9 +12,13 @@ ...@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <queue> #include <queue>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
...@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, ...@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return shrink_func(computation_op); return shrink_func(computation_op);
} }
static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
......
...@@ -13,9 +13,22 @@ ...@@ -13,9 +13,22 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details {} // namespace details namespace details {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
} // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <atomic> #include <atomic>
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -25,6 +26,10 @@ ...@@ -25,6 +26,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class VarDesc;
class VarHandle;
namespace details { namespace details {
class ComputationOpHandle; class ComputationOpHandle;
...@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector"; ...@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places"; const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars = using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>; std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var"; const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle {
namespace framework {
namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
target_ops[compute_op->GetScopeIdx()].first.emplace_back(
compute_op->GetOp());
} else if (compute_op->Name() == "while_grad") {
target_ops[compute_op->GetScopeIdx()].second.emplace_back(
compute_op->GetOp());
}
}
for (auto &ops_pair : target_ops) {
auto &while_ops = ops_pair.second.first;
auto &while_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include <deque> #include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
...@@ -23,6 +27,7 @@ limitations under the License. */ ...@@ -23,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts( ...@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id, const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars) const std::vector<std::string>& keep_vars, bool force_disable_gc)
: prog_(prog), block_id_(block_id) { : prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
if (GetEagerDeletionThreshold() >= 0) { if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), global_ref_cnts_ =
skip_ref_cnt_vars); GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
} }
} }
...@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
} }
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
bool force_disable_gc) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc); if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif #endif
auto ctx = Prepare(pdesc, block_id); auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
} }
...@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) { const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx( std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars)); program, block_id, skip_ref_cnt_vars, force_disable_gc));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
...@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids, const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) { const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
bool force_disable_gc) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(), skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d", "skip_ref_cnt_vars should be either empty or equals to block number %d",
...@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
for (auto& bid : block_ids) { for (auto& bid : block_ids) {
ExecutorPrepareContext* ctx; ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) { if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid); ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
force_disable_gc);
} else { } else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]); ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
force_disable_gc);
} }
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid); auto& block = program.Block(bid);
...@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc; std::unique_ptr<GarbageCollector> gc;
// skip while_op and while_grad_op temporarily // FIXME(zjl): recurrent_op is rather complex, we would
if (max_memory_size >= 0 && !keep_kids) { // disable gc forcely in recurrent_op
if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
} }
#endif #endif
// If gc is enabled and block size > 1
if (gc && ctx->prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
ctx->ops_);
}
} }
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
......
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -30,7 +32,8 @@ namespace framework { ...@@ -30,7 +32,8 @@ namespace framework {
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id, ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>()); std::vector<std::string>(),
bool force_disable_gc = false);
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
...@@ -38,6 +41,7 @@ struct ExecutorPrepareContext { ...@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
bool force_disable_gc_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> global_ref_cnts_; std::unordered_map<std::string, size_t> global_ref_cnts_;
...@@ -66,7 +70,10 @@ class Executor { ...@@ -66,7 +70,10 @@ class Executor {
* Scope * Scope
*/ */
void Run(const ProgramDesc& prog, Scope* scope, int block_id, void Run(const ProgramDesc& prog, Scope* scope, int block_id,
bool create_local_scope = true, bool create_vars = true); bool create_local_scope = true, bool create_vars = true,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
// This API is very slow. // This API is very slow.
void Run(const ProgramDesc& program, Scope* scope, void Run(const ProgramDesc& program, Scope* scope,
...@@ -79,12 +86,14 @@ class Executor { ...@@ -79,12 +86,14 @@ class Executor {
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>()); std::vector<std::string>(),
bool force_disable_gc = false);
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare( static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids, const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars = const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>()); std::vector<std::vector<std::string>>(),
bool force_disable_gc = false);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
...@@ -159,10 +159,9 @@ class Autograd { ...@@ -159,10 +159,9 @@ class Autograd {
for (auto it : candidate->pre_ops_) { for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) { for (OpBase* pre_op : it.second) {
if (!pre_op) continue; if (!pre_op) continue;
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id " VLOG(5) << "op dep " << candidate->Type() << " trace id "
<< candidate->trace_id_ << " <---- " << it.first << " <---- " << candidate->trace_id_ << " <---- " << it.first << " <---- "
<< pre_op->op_desc_->Type() << " trace id " << pre_op->Type() << " trace id " << pre_op->trace_id_;
<< pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) { if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op); visited.insert(pre_op);
queue.push_back(pre_op); queue.push_back(pre_op);
...@@ -180,10 +179,12 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -180,10 +179,12 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
PADDLE_ENFORCE(var_->IsInitialized(), PADDLE_ENFORCE(var_->IsInitialized(),
"Variable must be initialized when getting numpy tensor"); "Variable must be initialized when getting numpy tensor");
std::unique_ptr<VarBase> new_var(new VarBase()); // TODO(minqiyang): change this after move unique_name generator to CXX
const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
std::unique_ptr<VarBase> new_var(new VarBase(
"Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
framework::LoDTensor* tensor = framework::LoDTensor* tensor =
new_var->var_->GetMutable<framework::LoDTensor>(); new_var->var_->GetMutable<framework::LoDTensor>();
tensor->Resize(var_->Get<framework::LoDTensor>().dims());
tensor->set_lod(var_->Get<framework::LoDTensor>().lod()); tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
if (blocking) { if (blocking) {
...@@ -199,52 +200,62 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -199,52 +200,62 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
} }
if (platform::is_gpu_place(dst_place)) { if (platform::is_gpu_place(dst_place)) {
VLOG(3) << "copy tensor " << var_desc_->Name() << " from gpu"; VLOG(3) << "copy tensor " << Name() << " from gpu";
} }
return new_var; return new_var;
} }
framework::LoDTensor& VarBase::GradValue() { framework::LoDTensor& VarBase::GradValue() {
VLOG(3) << "get var grad " << var_desc_->Name(); VLOG(3) << "get var grad " << Name();
PADDLE_ENFORCE_NOT_NULL(grads_,
"Could not get grad value from no grad variable");
return *(grads_->var_->GetMutable<framework::LoDTensor>()); return *(grads_->var_->GetMutable<framework::LoDTensor>());
} }
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (grad_op_descs_.empty() && backward_id_ <= 0) { if (grad_op_descs_.empty() && backward_id_ <= 0) {
VLOG(3) << "op with no grad: " << op_desc_->Type(); VLOG(3) << "op with no grad: " << Type();
return {}; return {};
} }
VLOG(3) << "apply op grad: " << op_desc_->Type(); VLOG(3) << "apply op grad: " << Type();
std::vector<framework::VariableValueMap> grad_outputs; std::vector<framework::VariableValueMap> tmp_grad_outputs;
if (backward_id_ > 0) { if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad"; VLOG(3) << "py_layer_grad";
grad_outputs.resize(1); tmp_grad_outputs.resize(1);
grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] = tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
PyLayer::ApplyGrad( PyLayer::ApplyGrad(
backward_id_, backward_id_,
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]); grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else { } else {
grad_outputs.resize(grad_op_descs_.size()); const size_t grad_op_count = grad_op_descs_.size();
for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
tmp_grad_outputs.resize(grad_op_count);
for (size_t k = 0; k < grad_op_count; ++k) {
framework::OpDesc* grad_op_desc = grad_op_descs_[k]; framework::OpDesc* grad_op_desc = grad_op_descs_[k];
VLOG(3) << "op grad " << grad_op_desc->Type(); auto& grad_output_variable_map = grad_output_vars_[k];
for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first]; VLOG(3) << "apply grad op " << grad_op_desc->Type();
// Allocate tmp grad output variable
for (auto it : grad_output_variable_map) {
auto& outputs = tmp_grad_outputs[k][it.first];
outputs.reserve(it.second.size());
for (size_t i = 0; i < it.second.size(); ++i) { for (size_t i = 0; i < it.second.size(); ++i) {
// Allocate a new variable // Allocate a new variable
Variable* tmp_var = new framework::Variable(); Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>(); tmp_var->GetMutable<framework::LoDTensor>();
outputs.push_back(tmp_var); outputs.emplace_back(tmp_var);
} }
} }
framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]); // Run grad op
framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
// No need to do compile time infer shape here. // No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_); // grad_op_desc_->InferShape(*block_);
grad_op_desc->InferVarType(block_); // grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc); framework::OpRegistry::CreateOp(*grad_op_desc);
...@@ -260,9 +271,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -260,9 +271,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
} }
} }
// Add tmp grad outputs to original grad vars
for (size_t k = 0; k < grad_output_vars_.size(); ++k) { for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
for (auto it : grad_output_vars_[k]) { for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first]; auto& outputs = tmp_grad_outputs[k][it.first];
auto& origin_outputs = it.second; auto& origin_outputs = it.second;
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size()); PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
...@@ -316,19 +328,14 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) { ...@@ -316,19 +328,14 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); } int PyLayer::NumFuncs() { return py_funcs_.size(); }
std::vector<VarBase*> PyLayer::Apply(int func_id, std::vector<Variable*> PyLayer::Apply(int func_id,
const std::vector<VarBase*>& inputs) { const std::vector<VarBase*>& inputs) {
std::vector<framework::Variable*> invars; std::vector<framework::Variable*> invars;
for (const VarBase* in : inputs) { for (const VarBase* in : inputs) {
invars.push_back(in->var_); invars.push_back(in->var_);
} }
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end()); PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars); return CallPythonFunc(py_funcs_[func_id], invars);
std::vector<VarBase*> ret;
for (Variable* v : outvars) {
ret.push_back(new VarBase(v, new VarBase(true)));
}
return ret;
} }
std::vector<Variable*> PyLayer::ApplyGrad( std::vector<Variable*> PyLayer::ApplyGrad(
......
...@@ -112,31 +112,53 @@ class OpBase; ...@@ -112,31 +112,53 @@ class OpBase;
*/ */
class VarBase { class VarBase {
public: public:
VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {} // Internal interface, create VarBase from exist variable
VarBase(const std::string& name, framework::Variable* var, VarBase* grad,
explicit VarBase(bool stop_gradient) bool stop_gradient)
: VarBase(new framework::Variable(), : VarBase(name, var->Get<framework::LoDTensor>().type(),
stop_gradient ? nullptr : new VarBase(true), stop_gradient) {} var->Get<framework::LoDTensor>().dims(),
var->Get<framework::LoDTensor>().place(), var, grad,
VarBase(framework::Variable* var, VarBase* grad) stop_gradient, false) {}
: VarBase(var, grad, false) {}
// Python interface
VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
const std::vector<int64_t>& shape, const platform::Place& place,
bool stop_gradient, bool persistable)
: VarBase(name, dtype, framework::make_ddim(shape), place, stop_gradient,
persistable) {}
// Internal interface, create VarBase from with ddim
VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
const framework::DDim& shape, const platform::Place& place,
bool stop_gradient, bool persistable)
: VarBase(name, dtype, shape, place, nullptr, nullptr, stop_gradient,
persistable) {}
private: private:
VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient) VarBase(const std::string& name, framework::proto::VarType::Type dtype,
: name_(), const framework::DDim& shape, const platform::Place& place,
var_desc_(nullptr), framework::Variable* var, VarBase* grad, bool stop_gradient,
bool persistable)
: name_(name),
dtype_(dtype),
place_(place),
var_(var), var_(var),
grads_(grad), grads_(grad),
block_(nullptr),
persistable_(false),
stop_gradient_(stop_gradient), stop_gradient_(stop_gradient),
persistable_(persistable),
pre_op_(nullptr), pre_op_(nullptr),
pre_op_out_name_(), pre_op_out_name_(),
pre_op_out_idx_(-1) {} pre_op_out_idx_(-1) {
if (!var_) {
var_ = new framework::Variable();
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->Resize(shape);
tensor->mutable_data(place_, dtype_);
}
}
public: public:
virtual ~VarBase() { virtual ~VarBase() {
// TODO(minqiyang): remove var desc from block desc
if (var_) { if (var_) {
delete var_; delete var_;
var_ = nullptr; var_ = nullptr;
...@@ -151,14 +173,30 @@ class VarBase { ...@@ -151,14 +173,30 @@ class VarBase {
pre_op_out_idx_ = -1; pre_op_out_idx_ = -1;
} }
inline OpBase* PreOp() const { return pre_op_; } inline void SetName(const std::string& name) { name_ = name; }
inline int PreOpOutIdx() const { return pre_op_out_idx_; } inline std::string Name() const { return name_; }
inline std::vector<int64_t> Shape() const {
if (var_->IsInitialized()) {
return framework::vectorize(var_->Get<framework::LoDTensor>().dims());
} else {
return {};
}
}
inline framework::proto::VarType::Type DType() const { return dtype_; }
inline void SetStopGradient(bool stop_gradient) { inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient; stop_gradient_ = stop_gradient;
} }
inline bool IsStopGradient() const { return stop_gradient_; } inline bool IsStopGradient() const { return stop_gradient_; }
inline void SetPersistable(bool persistable) { persistable_ = persistable; }
inline bool IsPersistable() const { return persistable_; }
inline OpBase* PreOp() const { return pre_op_; }
inline int PreOpOutIdx() const { return pre_op_out_idx_; }
void RunBackward(); void RunBackward();
inline void ResetPreOp(OpBase* op) { inline void ResetPreOp(OpBase* op) {
...@@ -180,7 +218,7 @@ class VarBase { ...@@ -180,7 +218,7 @@ class VarBase {
} }
void ClearGradient() { void ClearGradient() {
VLOG(1) << "clear gradient of " << var_desc_->Name(); VLOG(1) << "clear gradient of " << Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) { if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>(); auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant( operators::math::set_constant(
...@@ -196,23 +234,20 @@ class VarBase { ...@@ -196,23 +234,20 @@ class VarBase {
const bool blocking) const; const bool blocking) const;
inline std::string GradName() const { inline std::string GradName() const {
PADDLE_ENFORCE( return string::Sprintf("%s@IGrad", Name());
var_desc_,
"Couldn't get gradient variable's name, please call backward() first");
return string::Sprintf("%s@IGrad", var_desc_->Name());
} }
std::string name_; std::string name_;
framework::VarDesc* var_desc_; framework::proto::VarType::Type dtype_;
platform::Place place_;
framework::Variable* var_; framework::Variable* var_;
VarBase* grads_; VarBase* grads_;
framework::BlockDesc* block_;
bool persistable_;
private: private:
bool stop_gradient_; bool stop_gradient_;
bool persistable_;
OpBase* pre_op_; OpBase* pre_op_;
std::string pre_op_out_name_; std::string pre_op_out_name_;
int pre_op_out_idx_; int pre_op_out_idx_;
...@@ -223,11 +258,11 @@ class VarBase { ...@@ -223,11 +258,11 @@ class VarBase {
*/ */
class PYBIND11_HIDDEN OpBase { class PYBIND11_HIDDEN OpBase {
public: public:
OpBase() OpBase(const std::string& type)
: op_desc_(nullptr), : type_(type),
trace_id_(-1),
forward_id_(-1), forward_id_(-1),
backward_id_(-1), backward_id_(-1),
trace_id_(-1),
place_(platform::CPUPlace()), place_(platform::CPUPlace()),
backward_hooks_() {} backward_hooks_() {}
...@@ -249,13 +284,34 @@ class PYBIND11_HIDDEN OpBase { ...@@ -249,13 +284,34 @@ class PYBIND11_HIDDEN OpBase {
std::map<std::string, std::vector<VarBase*>> ApplyGrad(); std::map<std::string, std::vector<VarBase*>> ApplyGrad();
inline std::string Type() const { return type_; }
inline std::string GradOpType(size_t index) const {
PADDLE_ENFORCE_NOT_NULL(grad_op_descs_[index]);
return grad_op_descs_[index]->Type();
}
void RegisterBackwardHooks(const py::object& callable); void RegisterBackwardHooks(const py::object& callable);
void InvokeBackwardHooks(); void InvokeBackwardHooks();
// One of `op_desc_` or `forward_id_` is set, not both. void TrackPreOp(const VarBase* inp_var, const std::string& inp_name) {
// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_. if (inp_var->PreOp() && !inp_var->IsStopGradient()) {
framework::OpDesc* op_desc_; VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot "
<< inp_name;
pre_ops_[inp_name].push_back(inp_var->PreOp());
pre_ops_out_idx_[inp_name].push_back(inp_var->PreOpOutIdx());
} else {
VLOG(3) << "no pre op in slot " << inp_name
<< " input var stop_gradient: " << inp_var->IsStopGradient();
pre_ops_[inp_name].push_back(nullptr);
// pre_ops_out_idx_[inp_name].push_back(-1);
}
}
std::string type_;
// One of `trace_id_` or `forward_id_` is set, not both.
// For pure python PyLayer, use `forward_id_`, otherwise, use trace_id_.
int trace_id_;
int forward_id_; int forward_id_;
// When has backward, one of `grad_op_descs_` or `backward_id_` is set, // When has backward, one of `grad_op_descs_` or `backward_id_` is set,
...@@ -263,7 +319,6 @@ class PYBIND11_HIDDEN OpBase { ...@@ -263,7 +319,6 @@ class PYBIND11_HIDDEN OpBase {
// Note: each fwd op corresponds to a vector of bwd ops. // Note: each fwd op corresponds to a vector of bwd ops.
std::vector<framework::OpDesc*> grad_op_descs_; std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_; int backward_id_;
int trace_id_;
platform::Place place_; platform::Place place_;
...@@ -277,8 +332,6 @@ class PYBIND11_HIDDEN OpBase { ...@@ -277,8 +332,6 @@ class PYBIND11_HIDDEN OpBase {
// Outputs to a vector of bwd ops. // Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_; std::vector<framework::VariableValueMap> grad_output_vars_;
framework::BlockDesc* block_;
std::vector<py::object> backward_hooks_; std::vector<py::object> backward_hooks_;
}; };
...@@ -303,8 +356,8 @@ class PyLayer { ...@@ -303,8 +356,8 @@ class PyLayer {
static int NumFuncs(); static int NumFuncs();
static std::vector<VarBase*> Apply(int func_id, static std::vector<framework::Variable*> Apply(
const std::vector<VarBase*>& inputs); int func_id, const std::vector<VarBase*>& inputs);
static std::vector<framework::Variable*> ApplyGrad( static std::vector<framework::Variable*> ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs); int func_id, const std::vector<framework::Variable*>& inputs);
......
...@@ -56,15 +56,19 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -56,15 +56,19 @@ void CreateGradOp(const framework::OpDesc& op_desc,
} }
} }
void InitVar(framework::Variable* var, framework::Variable* grad_var, void InitGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
platform::DeviceContext* dev_ctx) { PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
PADDLE_ENFORCE_NOT_NULL(dev_ctx, PADDLE_ENFORCE_NOT_NULL(dev_ctx,
"Could not get valid device from forward op"); "Could not get valid device from forward op");
auto& var_t = var->Get<framework::LoDTensor>();
grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>( if (var->grads_ == nullptr) {
var_t.dims(), dev_ctx->GetPlace()); auto& var_t = var->var_->Get<framework::LoDTensor>();
operators::math::set_constant( var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32,
*dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0); framework::vectorize(var_t.dims()),
dev_ctx->GetPlace(), true, false);
auto grad_t = var->grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(*dev_ctx, grad_t, 0.0);
}
} }
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
...@@ -85,6 +89,62 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) { ...@@ -85,6 +89,62 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
return result; return result;
} }
framework::VariableNameMap CreateInputVarNameMap(
const OpBase* op, const VarBasePtrMap& varbase_map) {
framework::VariableNameMap result;
auto& info_map = framework::OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(op->Type());
if (op_info == nullptr || op_info->proto_ == nullptr) {
return result;
}
for (auto& in : op_info->Proto().inputs()) {
auto it = varbase_map.find(in.name());
if (it == varbase_map.end()) {
PADDLE_ENFORCE(in.dispensable());
result[in.name()] = {};
} else {
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[in.name()] = args;
}
}
return result;
}
framework::VariableNameMap CreateOutputVarNameMap(
const OpBase* op, const VarBasePtrMap& varbase_map) {
framework::VariableNameMap result;
auto& info_map = framework::OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(op->Type());
if (op_info == nullptr || op_info->proto_ == nullptr) {
return result;
}
for (auto& out : op_info->Proto().outputs()) {
auto it = varbase_map.find(out.name());
if (it == varbase_map.end()) {
PADDLE_ENFORCE(out.dispensable());
result[out.name()] = {};
} else {
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[out.name()] = args;
}
}
return result;
}
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) { Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
if (!FLAGS_tracer_profile_fname.empty()) { if (!FLAGS_tracer_profile_fname.empty()) {
std::call_once(gTracerProfileOnce, [] { std::call_once(gTracerProfileOnce, [] {
...@@ -101,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) { ...@@ -101,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, const VarBasePtrMap& outputs,
framework::BlockDesc* block, framework::AttributeMap attrs_map,
const platform::Place expected_place, const platform::Place expected_place,
const bool stop_gradient) { const bool stop_gradient) {
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
...@@ -110,40 +170,27 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -110,40 +170,27 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
} }
#endif #endif
std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_;
VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
<< op->trace_id_;
op_desc->InferShape(*block);
op_desc->InferVarType(block);
std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(*op_desc);
framework::VariableValueMap invars_map; framework::VariableValueMap invars_map;
framework::VariableValueMap outvars_map; framework::VariableValueMap outvars_map;
// Construct input_vars_map and output_vars_map
std::map<std::string, VarBase*> current_vars_map;
op->input_vars_ = inputs; op->input_vars_ = inputs;
for (auto it : op->input_vars_) { for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first]; auto& invars = invars_map[it.first];
invars.reserve(it.second.size()); invars.reserve(it.second.size());
for (VarBase* inp : it.second) { for (VarBase* inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
op->op_desc_->Type(), inp->var_desc_->Name()); inp->Name());
invars.emplace_back(inp->var_); invars.emplace_back(inp->var_);
vars[inp->var_desc_->Name()] = inp; op->TrackPreOp(inp, it.first);
if (inp->PreOp() && !inp->IsStopGradient()) { if (!stop_gradient) {
op->pre_ops_[it.first].push_back(inp->PreOp()); current_vars_map[inp->Name()] = inp;
op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
} else {
op->pre_ops_[it.first].push_back(nullptr);
} }
VLOG(3) << "input vname " << inp->var_desc_->Name() << " " VLOG(3) << "input var name: " << inp->Name()
<< inp->var_->IsInitialized() << " stop_gradient " << " inited: " << inp->var_->IsInitialized()
<< inp->IsStopGradient(); << " stop_grad: " << inp->IsStopGradient();
} }
} }
...@@ -152,25 +199,38 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -152,25 +199,38 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto& outvars = outvars_map[it.first]; auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second; const std::vector<VarBase*>& outputs = it.second;
outvars.reserve(outputs.size()); outvars.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0U; i < outputs.size(); ++i) {
VarBase* out = outputs[i]; VarBase* out = outputs[i];
outvars.emplace_back(out->var_); outvars.emplace_back(out->var_);
vars[out->var_desc_->Name()] = out;
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
out->var_->GetMutable<framework::LoDTensor>();
} else {
LOG(ERROR) << "tracer doesn't support yet";
}
out->TrackPreOp(op, it.first, i, stop_gradient); out->TrackPreOp(op, it.first, i, stop_gradient);
if (!stop_gradient) {
current_vars_map[out->Name()] = out;
}
VLOG(3) << "output vname " << out->var_desc_->Name() << " " VLOG(3) << "input var name: " << out->Name()
<< out->var_->IsInitialized(); << " inited: " << out->var_->IsInitialized()
<< " stop_grad: " << out->IsStopGradient();
}
} }
// Check attrs and create op
framework::VariableNameMap invars_name_map =
CreateInputVarNameMap(op, inputs);
framework::VariableNameMap outvars_name_map =
CreateOutputVarNameMap(op, outputs);
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_map);
} }
VLOG(3) << "tracer running " << op_desc->Type(); std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
outvars_name_map, attrs_map);
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
VLOG(3) << "tracer running " << op->Type();
framework::RuntimeContext ctx(invars_map, outvars_map); framework::RuntimeContext ctx(invars_map, outvars_map);
// TODO(panyx0718): Cache p. // TODO(panyx0718): Cache p.
...@@ -186,36 +246,44 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -186,36 +246,44 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx, framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs)); prepared_op.ctx, prepared_op.kernel_configs));
// construct backward op
std::set<std::string> vars_saved_for_backward; std::set<std::string> vars_saved_for_backward;
if (!stop_gradient) { if (!stop_gradient) {
VLOG(5) << "start construct backward op";
// construct grad op descs
std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
op->Type(), invars_name_map, outvars_name_map, attrs_map));
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var( std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
new std::unordered_map<std::string, std::string>()); new std::unordered_map<std::string, std::string>());
CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get()); // NOTE(minqiyang): We don't support control flow op in imperative now
// Add grad_block_ when we want to support it
CreateGradOp(*fwd_op_desc, {}, {}, &op->grad_op_descs_, grad_to_var.get());
op->grad_input_vars_.resize(op->grad_op_descs_.size()); VLOG(5) << "create grad op desc: " << op->grad_op_descs_[0]->Type();
op->grad_output_vars_.resize(op->grad_op_descs_.size());
for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) { const size_t grad_op_count = op->grad_op_descs_.size();
op->grad_input_vars_.resize(grad_op_count);
op->grad_output_vars_.resize(grad_op_count);
for (size_t i = 0; i < grad_op_count; ++i) {
framework::OpDesc* grad_op_desc = op->grad_op_descs_[i]; framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
for (auto it : grad_op_desc->Inputs()) { for (auto it : grad_op_desc->Inputs()) {
auto& grad_in_vars = op->grad_input_vars_[i][it.first]; auto& grad_in_vars = op->grad_input_vars_[i][it.first];
grad_in_vars.reserve(it.second.size());
for (const std::string& grad_invar : it.second) { for (const std::string& grad_invar : it.second) {
block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = grad_to_var->find(grad_invar); auto var_it = grad_to_var->find(grad_invar);
if (var_it == grad_to_var->end()) { if (var_it == grad_to_var->end()) {
auto fwd_var_it = vars.find(grad_invar); auto fwd_var_it = current_vars_map.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != vars.end()); PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
// Forward inputs or outputs. // Forward inputs or outputs.
grad_in_vars.push_back(fwd_var_it->second->var_); grad_in_vars.emplace_back(fwd_var_it->second->var_);
} else { } else {
VarBase* var = vars[var_it->second]; VarBase* var = current_vars_map[var_it->second];
if (!var->grads_->var_->IsInitialized()) { InitGrad(var, prepared_op.GetDeviceContext());
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
// Douts. // Douts.
grad_in_vars.push_back(var->grads_->var_); grad_in_vars.emplace_back(var->grads_->var_);
} }
vars_saved_for_backward.insert(it.first); vars_saved_for_backward.insert(it.first);
...@@ -225,48 +293,48 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, ...@@ -225,48 +293,48 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : grad_op_desc->Outputs()) { for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[i][it.first]; auto& grad_out_vars = op->grad_output_vars_[i][it.first];
for (const std::string& grad_outvar : it.second) { for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar); auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(), PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this " "Could not found the grad op output var, should this "
"operator %s's stop gradient be True", "operator %s's stop gradient be True",
op_desc->Type()); op->Type());
VarBase* var = vars[var_it->second]; VarBase* var = current_vars_map[var_it->second];
if (!var->grads_->var_->IsInitialized()) { InitGrad(var, prepared_op.GetDeviceContext());
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
grad_out_vars.push_back(var->grads_->var_); grad_out_vars.push_back(var->grads_->var_);
} }
} }
} }
} }
op->block_ = block;
return vars_saved_for_backward; return vars_saved_for_backward;
} }
std::vector<VarBase*> Tracer::PyTrace(OpBase* op, std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
const std::vector<VarBase*>& inputs, const std::vector<VarBase*>& inputs,
bool stop_gradient) { bool stop_gradient) {
VLOG(3) << "py_trace"; VLOG(3) << "py_trace " << op->Type();
op->input_vars_[PyLayer::kFwdInp] = inputs; op->input_vars_[PyLayer::kFwdInp] = inputs;
op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
std::vector<framework::Variable*> ret_vars =
PyLayer::Apply(op->forward_id_, inputs);
for (VarBase* inp : inputs) { for (VarBase* inp : inputs) {
if (inp->PreOp() && !inp->IsStopGradient()) { op->TrackPreOp(inp, PyLayer::kFwdInp);
op->pre_ops_[PyLayer::kFwdInp].push_back(inp->PreOp());
op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->PreOpOutIdx());
} else {
op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr);
}
} }
auto& outputs = op->output_vars_[PyLayer::kFwdOut]; std::vector<VarBase*>& outputs = op->output_vars_[PyLayer::kFwdOut];
for (size_t i = 0; i < outputs.size(); ++i) { outputs.reserve(ret_vars.size());
VarBase* out = outputs[i]; for (size_t i = 0U; i != ret_vars.size(); ++i) {
framework::Variable* v = ret_vars[i];
VarBase* out = new VarBase(string::Sprintf("%s_out_%d", op->Type(), i), v,
nullptr, stop_gradient);
outputs.emplace_back(out);
out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient); out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
} }
if (!stop_gradient) { if (!stop_gradient) {
VLOG(5) << "start construct backward op";
op->grad_input_vars_.resize(1); op->grad_input_vars_.resize(1);
op->grad_output_vars_.resize(1); op->grad_output_vars_.resize(1);
auto& grad_input_vars = auto& grad_input_vars =
...@@ -281,23 +349,16 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op, ...@@ -281,23 +349,16 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
grad_input_vars.push_back(out->var_); grad_input_vars.push_back(out->var_);
} }
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
platform::CPUPlace place; platform::CPUPlace place;
for (VarBase* out : outputs) { for (VarBase* out : outputs) {
InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
grad_input_vars.push_back(out->grads_->var_); grad_input_vars.push_back(out->grads_->var_);
if (!grad_input_vars.back()->IsInitialized()) {
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
InitVar(out->var_, grad_input_vars.back(),
platform::DeviceContextPool::Instance().Get(place));
}
} }
for (const VarBase* inp : inputs) { for (VarBase* inp : inputs) {
InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
grad_output_vars.push_back(inp->grads_->var_); grad_output_vars.push_back(inp->grads_->var_);
if (!grad_output_vars.back()->IsInitialized()) {
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
InitVar(inp->var_, grad_output_vars.back(),
platform::DeviceContextPool::Instance().Get(place));
}
} }
} }
return outputs; return outputs;
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <map> #include <map>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
...@@ -34,7 +36,8 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -34,7 +36,8 @@ void CreateGradOp(const framework::OpDesc& op_desc,
framework::OpDesc** grad_op_desc, framework::OpDesc** grad_op_desc,
std::unordered_map<std::string, std::string>* grad_to_var); std::unordered_map<std::string, std::string>* grad_to_var);
void InitVar(framework::Variable* var, framework::Variable* grad_var); void InitVar(const VarBase* var, framework::Variable* grad_var,
platform::DeviceContext* dev_ctx);
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs); platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
...@@ -46,7 +49,7 @@ class Tracer { ...@@ -46,7 +49,7 @@ class Tracer {
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs, std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, const VarBasePtrMap& outputs,
framework::BlockDesc* block, framework::AttributeMap attrs_map,
const platform::Place expected_place, const platform::Place expected_place,
const bool stop_gradient = false); const bool stop_gradient = false);
......
...@@ -126,15 +126,20 @@ void ZeroCopyTensor::copy_to_cpu(T *data) { ...@@ -126,15 +126,20 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
} }
template void ZeroCopyTensor::copy_from_cpu<float>(const float *data); template void ZeroCopyTensor::copy_from_cpu<float>(const float *data);
template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data); template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data);
template void ZeroCopyTensor::copy_from_cpu<int32_t>(const int32_t *data);
template void ZeroCopyTensor::copy_to_cpu<float>(float *data); template void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data); template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template float *ZeroCopyTensor::data<float>(PaddlePlace *place, template float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const; int *size) const;
template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place, template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place,
int *size) const; int *size) const;
template int32_t *ZeroCopyTensor::data<int32_t>(PaddlePlace *place,
int *size) const;
template float *ZeroCopyTensor::mutable_data<float>(PaddlePlace place); template float *ZeroCopyTensor::mutable_data<float>(PaddlePlace place);
template int64_t *ZeroCopyTensor::mutable_data<int64_t>(PaddlePlace place); template int64_t *ZeroCopyTensor::mutable_data<int64_t>(PaddlePlace place);
template int32_t *ZeroCopyTensor::mutable_data<int32_t>(PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const { void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE(!name_.empty(), PADDLE_ENFORCE(!name_.empty(),
......
...@@ -139,9 +139,8 @@ static void TensorAssignData(PaddleTensor *tensor, ...@@ -139,9 +139,8 @@ static void TensorAssignData(PaddleTensor *tensor,
} }
template <typename T> template <typename T>
static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor, static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
const std::vector<std::vector<T>> &data) { const std::vector<std::vector<T>> &data) {
int size{0};
auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU); auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
int c = 0; int c = 0;
for (const auto &f : data) { for (const auto &f : data) {
...@@ -149,7 +148,15 @@ static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor, ...@@ -149,7 +148,15 @@ static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
ptr[c++] = v; ptr[c++] = v;
} }
} }
return size; }
template <typename T>
static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
const PaddleBuf &data) {
auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
for (size_t i = 0; i < data.length() / sizeof(T); i++) {
ptr[i] = *(reinterpret_cast<T *>(data.data()) + i);
}
} }
static bool CompareTensor(const PaddleTensor &a, const PaddleTensor &b) { static bool CompareTensor(const PaddleTensor &a, const PaddleTensor &b) {
......
...@@ -107,6 +107,9 @@ void SetConfig(AnalysisConfig *cfg) { ...@@ -107,6 +107,9 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->DisableGpu(); cfg->DisableGpu();
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
cfg->SwitchIrOptim(); cfg->SwitchIrOptim();
if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false);
}
} }
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
...@@ -131,7 +134,7 @@ TEST(Analyzer_Pyramid_DNN, profile) { ...@@ -131,7 +134,7 @@ TEST(Analyzer_Pyramid_DNN, profile) {
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads); input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
PADDLE_ENFORCE_EQ(outputs.size(), 1UL); PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
size_t size = GetSize(outputs[0]); size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0); PADDLE_ENFORCE_GT(size, 0);
...@@ -166,6 +169,19 @@ TEST(Analyzer_Pyramid_DNN, compare) { ...@@ -166,6 +169,19 @@ TEST(Analyzer_Pyramid_DNN, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_Pyramid_DNN, compare_zero_copy) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
std::vector<std::string> outputs_name;
outputs_name.emplace_back("cos_sim_2.tmp_0");
CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
input_slots_all, outputs_name);
}
// Compare Deterministic result // Compare Deterministic result
TEST(Analyzer_Pyramid_DNN, compare_determine) { TEST(Analyzer_Pyramid_DNN, compare_determine) {
AnalysisConfig cfg; AnalysisConfig cfg;
......
...@@ -207,6 +207,9 @@ void SetConfig(AnalysisConfig *cfg) { ...@@ -207,6 +207,9 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->DisableGpu(); cfg->DisableGpu();
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
cfg->SwitchIrOptim(); cfg->SwitchIrOptim();
if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false);
}
} }
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
...@@ -285,133 +288,17 @@ TEST(Analyzer_rnn1, multi_thread) { ...@@ -285,133 +288,17 @@ TEST(Analyzer_rnn1, multi_thread) {
input_slots_all, &outputs, 2 /* multi_thread */); input_slots_all, &outputs, 2 /* multi_thread */);
} }
// Validate that the AnalysisPredictor + ZeroCopyTensor really works by testing // Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
// on the complex RNN1 model. TEST(Analyzer_rnn1, compare_zero_copy) {
TEST(Analyzer_rnn1, ZeroCopy) { AnalysisConfig cfg;
AnalysisConfig config; SetConfig(&cfg);
SetConfig(&config);
config.SwitchUseFeedFetchOps(false);
PaddlePlace place;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
config.SwitchUseFeedFetchOps(true);
auto native_predictor =
CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
config.SwitchUseFeedFetchOps(
true); // the analysis predictor needs feed/fetch.
auto analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
#define NEW_TENSOR(name__) \
auto name__##_tensor = predictor->GetInputTensor(#name__);
NEW_TENSOR(data_lod_attention);
NEW_TENSOR(cell_init);
NEW_TENSOR(data);
NEW_TENSOR(week);
NEW_TENSOR(minute);
NEW_TENSOR(hidden_init);
// Prepare data for AnalysisPredictor
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
PrepareZeroCopyInputs(data_lod_attention_tensor.get(), cell_init_tensor.get(),
data_tensor.get(), hidden_init_tensor.get(),
week_tensor.get(), minute_tensor.get(), &data,
FLAGS_batch_size);
// Prepare data for NativePredictor
std::vector<std::vector<PaddleTensor>> native_inputs;
SetInput(&native_inputs);
std::vector<PaddleTensor> native_outputs;
std::vector<PaddleTensor> analysis_outputs;
auto output_tensor = predictor->GetOutputTensor("final_output.tmp_1");
// Run analysis predictor
int num_ops;
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 1);
ASSERT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
ASSERT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
ASSERT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists.
Timer timer;
double total_time{0};
for (int i = 0; i < FLAGS_repeat; i++) {
timer.tic();
predictor->ZeroCopyRun();
total_time += timer.toc();
}
LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
ASSERT_TRUE(native_predictor->Run(native_inputs.front(), &native_outputs));
LOG(INFO) << "native output " << DescribeTensor(native_outputs.front());
int output_size{0}; // this is the number of elements not memory size
auto *zero_copy_data = output_tensor->data<float>(&place, &output_size);
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (int i = 0; i < output_size; i++) {
EXPECT_NEAR(zero_copy_data[i], native_data[i], 1e-3);
}
}
TEST(Analyzer_rnn1, ZeroCopyMultiThread) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(false);
#define NEW_TENSOR(name__) \
auto name__##_tensor = predictor->GetInputTensor(#name__);
std::vector<std::unique_ptr<PaddlePredictor>> predictors;
predictors.emplace_back(CreatePaddlePredictor<AnalysisConfig>(config));
for (int tid = 1; tid < FLAGS_num_threads; tid++) {
predictors.emplace_back(predictors.front()->Clone());
}
double total_time_of_threads{0};
std::vector<std::thread> threads;
for (int tid = 0; tid < FLAGS_num_threads; tid++) {
threads.emplace_back([&, tid] {
auto &predictor = predictors[tid];
NEW_TENSOR(data_lod_attention);
NEW_TENSOR(cell_init);
NEW_TENSOR(data);
NEW_TENSOR(week);
NEW_TENSOR(minute);
NEW_TENSOR(hidden_init);
// Prepare data for AnalysisPredictor
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
Timer timer;
double total_time{0};
for (int i = 0; i < FLAGS_repeat; i++) {
PrepareZeroCopyInputs(data_lod_attention_tensor.get(),
cell_init_tensor.get(), data_tensor.get(),
hidden_init_tensor.get(), week_tensor.get(),
minute_tensor.get(), &data, FLAGS_batch_size);
timer.tic();
predictor->ZeroCopyRun();
total_time += timer.toc();
}
total_time_of_threads += total_time;
LOG(INFO) << "thread time: " << total_time / FLAGS_repeat;
});
}
for (auto &t : threads) {
t.join();
}
LOG(INFO) << "average time: " std::vector<std::vector<PaddleTensor>> input_slots_all;
<< total_time_of_threads / FLAGS_num_threads / FLAGS_repeat; SetInput(&input_slots_all);
std::vector<std::string> outputs_name;
outputs_name.emplace_back("final_output.tmp_1");
CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
input_slots_all, outputs_name);
} }
} // namespace inference } // namespace inference
......
...@@ -144,6 +144,9 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { ...@@ -144,6 +144,9 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
cfg->SwitchIrDebug(); cfg->SwitchIrDebug();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false);
}
if (use_mkldnn) { if (use_mkldnn) {
cfg->EnableMKLDNN(); cfg->EnableMKLDNN();
} }
...@@ -184,10 +187,10 @@ TEST(Analyzer_seq_pool1, compare_determine) { ...@@ -184,10 +187,10 @@ TEST(Analyzer_seq_pool1, compare_determine) {
input_slots_all); input_slots_all);
} }
void analysis_fuse_statis(bool use_zerocopy) { // Check the fuse status
TEST(Analyzer_seq_pool1, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg.SwitchUseFeedFetchOps(!use_zerocopy);
int num_ops; int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops); auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
...@@ -203,137 +206,17 @@ void analysis_fuse_statis(bool use_zerocopy) { ...@@ -203,137 +206,17 @@ void analysis_fuse_statis(bool use_zerocopy) {
EXPECT_EQ(num_ops, 171); EXPECT_EQ(num_ops, 171);
} }
// Check the fuse status // Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_seq_pool1, fuse_statis) { analysis_fuse_statis(false); } TEST(Analyzer_seq_pool1, compare_zero_copy) {
AnalysisConfig cfg;
void PrepareZeroCopyInputs( SetConfig(&cfg);
const std::unique_ptr<PaddlePredictor> &predictor,
std::vector<std::unique_ptr<ZeroCopyTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
// only feed one batch
const auto &one_batch = data.NextBatch();
inputs->clear();
for (size_t i = 0; i < one_batch.size(); ++i) {
auto &slot = one_batch[i];
auto tensor = predictor->GetInputTensor(slot.name + "_embed");
tensor->Reshape(slot.shape);
tensor->SetLoD({slot.lod});
ZeroCopyTensorAssignData<float>(tensor.get(), slot.data);
inputs->emplace_back(std::move(tensor));
}
}
// return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(false);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
std::vector<std::unique_ptr<ZeroCopyTensor>> inputs;
PrepareZeroCopyInputs(predictor, &inputs);
auto output_tensor = predictor->GetOutputTensor(out_var_name);
Timer timer;
LOG(INFO) << "Warm up run...";
timer.tic();
predictor->ZeroCopyRun();
PrintTime(FLAGS_batch_size, 1, 1, 0, timer.toc(), 1);
if (FLAGS_profile) {
paddle::platform::ResetProfiler();
}
LOG(INFO) << "Run " << repeat_times << " times...";
timer.tic();
for (int i = 0; i < repeat_times; i++) {
predictor->ZeroCopyRun();
}
PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times,
1);
LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
PaddlePlace place;
int output_size{0};
auto *pdata = output_tensor->data<float>(&place, &output_size);
std::vector<float> res(output_size);
for (int i = 0; i < output_size; ++i) {
res[i] = pdata[i];
}
return res;
}
TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); }
TEST(Analyzer_seq_pool1, zerocopy_profile_threads) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(false);
std::vector<std::unique_ptr<PaddlePredictor>> predictors;
predictors.emplace_back(CreatePaddlePredictor<AnalysisConfig>(config));
for (int tid = 1; tid < FLAGS_num_threads; tid++) {
predictors.emplace_back(predictors.front()->Clone());
}
double total_time_of_threads{0};
std::vector<std::thread> threads;
for (int tid = 0; tid < FLAGS_num_threads; tid++) {
threads.emplace_back([&, tid] {
auto &predictor = predictors[tid];
std::vector<std::unique_ptr<ZeroCopyTensor>> inputs;
PrepareZeroCopyInputs(predictor, &inputs);
auto output_tensor = predictor->GetOutputTensor(out_var_name);
Timer timer;
double total_time{0};
LOG(INFO) << "Warm up run...";
timer.tic();
predictor->ZeroCopyRun();
PrintTime(FLAGS_batch_size, 1, FLAGS_num_threads, tid, timer.toc(), 1);
if (FLAGS_profile) {
paddle::platform::ResetProfiler();
}
int repeat_times = FLAGS_repeat;
LOG(INFO) << "Run " << repeat_times << " times...";
timer.tic();
for (int i = 0; i < repeat_times; i++) {
predictor->ZeroCopyRun();
}
total_time += timer.toc();
total_time_of_threads += total_time;
LOG(INFO) << "thread time: " << total_time / repeat_times;
});
}
for (auto &t : threads) {
t.join();
}
LOG(INFO) << "average time: "
<< total_time_of_threads / FLAGS_num_threads / FLAGS_repeat;
}
TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { analysis_fuse_statis(true); }
TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
AnalysisConfig config;
SetConfig(&config);
config.SwitchUseFeedFetchOps(true);
auto predictor = CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
std::vector<PaddleTensor> native_outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
ASSERT_TRUE(predictor->Run(input_slots_all[0], &native_outputs)); std::vector<std::string> outputs_name;
EXPECT_EQ(native_outputs.size(), 1UL); outputs_name.emplace_back(out_var_name);
CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
auto zerocopy_output = zerocopy_profile(1); input_slots_all, outputs_name);
EXPECT_EQ(zerocopy_output.size() * sizeof(float),
native_outputs.front().data.length());
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < zerocopy_output.size(); ++i) {
EXPECT_LT(
std::fabs((zerocopy_output[i] - native_data[i]) / zerocopy_output[i]),
1e-3);
}
} }
} // namespace analysis } // namespace analysis
......
...@@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true, ...@@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true,
DEFINE_bool(record_benchmark, false, DEFINE_bool(record_benchmark, false,
"Record benchmark after profiling the model"); "Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy."); DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DEFINE_bool(zero_copy, false, "Use ZeroCopy to speedup Feed/Fetch.");
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads); DECLARE_int32(paddle_num_threads);
...@@ -67,6 +68,7 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) { ...@@ -67,6 +68,7 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
LOG(INFO) << analysis_config->ToNativeConfig(); LOG(INFO) << analysis_config->ToNativeConfig();
} }
// Compare result between two PaddleTensor
void CompareResult(const std::vector<PaddleTensor> &outputs, void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &ref_outputs) { const std::vector<PaddleTensor> &ref_outputs) {
EXPECT_GT(outputs.size(), 0UL); EXPECT_GT(outputs.size(), 0UL);
...@@ -108,6 +110,50 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -108,6 +110,50 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
} }
} }
// Compare result between a PaddleTensor and a ZeroCopyTensor
void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<ZeroCopyTensor> &ref_outputs) {
EXPECT_GT(outputs.size(), 0UL);
EXPECT_EQ(outputs.size(), ref_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &ref_out = ref_outputs[i];
size_t size = VecReduceToInt(out.shape);
EXPECT_GT(size, 0UL);
int ref_size = 0; // this is the number of elements not memory size
PaddlePlace place;
switch (out.dtype) {
case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = ref_out.data<float>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
CHECK_LE(std::abs(pdata_ref[j] - pdata[j]), FLAGS_accuracy);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
}
}
}
std::unique_ptr<PaddlePredictor> CreateTestPredictor( std::unique_ptr<PaddlePredictor> CreateTestPredictor(
const PaddlePredictor::Config *config, bool use_analysis = true) { const PaddlePredictor::Config *config, bool use_analysis = true) {
const auto *analysis_config = const auto *analysis_config =
...@@ -205,52 +251,99 @@ void GetInputPerBatch(const std::vector<std::vector<int64_t>> &in, ...@@ -205,52 +251,99 @@ void GetInputPerBatch(const std::vector<std::vector<int64_t>> &in,
} }
} }
void TestOneThreadPrediction( void ConvertPaddleTensorToZeroCopyTensor(
const PaddlePredictor::Config *config, PaddlePredictor *predictor, const std::vector<PaddleTensor> &inputs) {
for (size_t i = 0; i < inputs.size(); i++) {
auto input = inputs[i];
auto tensor = predictor->GetInputTensor(input.name);
tensor->Reshape(input.shape);
tensor->SetLoD({input.lod});
if (input.dtype == PaddleDType::INT64) {
ZeroCopyTensorAssignData<int64_t>(tensor.get(), input.data);
} else if (input.dtype == PaddleDType::FLOAT32) {
ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
} else if (input.dtype == PaddleDType::INT32) {
ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
} else {
LOG(ERROR) << "unsupported feed type " << input.dtype;
}
}
}
void PredictionWarmUp(PaddlePredictor *predictor,
const std::vector<std::vector<PaddleTensor>> &inputs, const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, bool use_analysis = true) { std::vector<PaddleTensor> *outputs, int num_threads,
int tid) {
int batch_size = FLAGS_batch_size; int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat; LOG(INFO) << "Running thread " << tid << ", warm up run...";
auto predictor = CreateTestPredictor(config, use_analysis); if (FLAGS_zero_copy) {
ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
// warmup run }
LOG(INFO) << "Warm up run...";
{
Timer warmup_timer; Timer warmup_timer;
warmup_timer.tic(); warmup_timer.tic();
if (!FLAGS_zero_copy) {
predictor->Run(inputs[0], outputs, batch_size); predictor->Run(inputs[0], outputs, batch_size);
PrintTime(batch_size, 1, 1, 0, warmup_timer.toc(), 1); } else {
predictor->ZeroCopyRun();
}
PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1);
if (FLAGS_profile) { if (FLAGS_profile) {
paddle::platform::ResetProfiler(); paddle::platform::ResetProfiler();
} }
} }
LOG(INFO) << "Run " << num_times << " times..."; void PredictionRun(PaddlePredictor *predictor,
{ const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
int tid) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
Timer run_timer; Timer run_timer;
run_timer.tic(); double elapsed_time = 0;
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
ProfilerStart("paddle_inference.prof"); ProfilerStart("paddle_inference.prof");
#endif #endif
for (int i = 0; i < num_times; i++) { if (!FLAGS_zero_copy) {
for (size_t j = 0; j < inputs.size(); j++) { run_timer.tic();
predictor->Run(inputs[j], outputs, batch_size); for (size_t i = 0; i < inputs.size(); i++) {
for (int j = 0; j < num_times; j++) {
predictor->Run(inputs[i], outputs, batch_size);
}
}
elapsed_time = run_timer.toc();
} else {
for (size_t i = 0; i < inputs.size(); i++) {
ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
run_timer.tic();
for (int j = 0; j < num_times; j++) {
predictor->ZeroCopyRun();
}
elapsed_time += run_timer.toc();
} }
} }
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
ProfilerStop(); ProfilerStop();
#endif #endif
double latency = run_timer.toc() / (num_times > 1 ? num_times : 1); PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times,
PrintTime(batch_size, num_times, 1, 0, latency, inputs.size()); inputs.size());
if (FLAGS_record_benchmark) { if (FLAGS_record_benchmark) {
Benchmark benchmark; Benchmark benchmark;
benchmark.SetName(FLAGS_model_name); benchmark.SetName(FLAGS_model_name);
benchmark.SetBatchSize(batch_size); benchmark.SetBatchSize(batch_size);
benchmark.SetLatency(latency); benchmark.SetLatency(elapsed_time / num_times);
benchmark.PersistToFile("benchmark_record.txt"); benchmark.PersistToFile("benchmark_record.txt");
} }
} }
void TestOneThreadPrediction(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
auto predictor = CreateTestPredictor(config, use_analysis);
PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
PredictionRun(predictor.get(), inputs, outputs, 1, 0);
} }
void TestMultiThreadPrediction( void TestMultiThreadPrediction(
...@@ -258,8 +351,6 @@ void TestMultiThreadPrediction( ...@@ -258,8 +351,6 @@ void TestMultiThreadPrediction(
const std::vector<std::vector<PaddleTensor>> &inputs, const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads, std::vector<PaddleTensor> *outputs, int num_threads,
bool use_analysis = true) { bool use_analysis = true) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::vector<std::unique_ptr<PaddlePredictor>> predictors; std::vector<std::unique_ptr<PaddlePredictor>> predictors;
predictors.emplace_back(CreateTestPredictor(config, use_analysis)); predictors.emplace_back(CreateTestPredictor(config, use_analysis));
...@@ -267,7 +358,6 @@ void TestMultiThreadPrediction( ...@@ -267,7 +358,6 @@ void TestMultiThreadPrediction(
predictors.emplace_back(predictors.front()->Clone()); predictors.emplace_back(predictors.front()->Clone());
} }
size_t total_time{0};
for (int tid = 0; tid < num_threads; ++tid) { for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
// Each thread should have local inputs and outputs. // Each thread should have local inputs and outputs.
...@@ -280,34 +370,8 @@ void TestMultiThreadPrediction( ...@@ -280,34 +370,8 @@ void TestMultiThreadPrediction(
->SetMkldnnThreadID(static_cast<int>(tid) + 1); ->SetMkldnnThreadID(static_cast<int>(tid) + 1);
} }
#endif #endif
PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid);
// warmup run PredictionRun(predictor.get(), inputs, outputs, num_threads, tid);
LOG(INFO) << "Running thread " << tid << ", warm up run...";
{
Timer warmup_timer;
warmup_timer.tic();
predictor->Run(inputs[0], outputs, batch_size);
PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1);
if (FLAGS_profile) {
paddle::platform::ResetProfiler();
}
}
LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
{
Timer timer;
timer.tic();
for (int i = 0; i < num_times; i++) {
for (const auto &input : inputs) {
ASSERT_TRUE(predictor->Run(input, &outputs_tid));
}
}
auto time = timer.toc();
total_time += time;
PrintTime(batch_size, num_times, num_threads, tid, time / num_times,
inputs.size());
}
}); });
} }
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
...@@ -367,6 +431,31 @@ void CompareNativeAndAnalysis( ...@@ -367,6 +431,31 @@ void CompareNativeAndAnalysis(
CompareResult(analysis_outputs, native_outputs); CompareResult(analysis_outputs, native_outputs);
} }
void CompareAnalysisAndZeroCopy(
PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
const std::vector<std::string> &outputs_name) {
int batch_size = FLAGS_batch_size;
// analysis
std::vector<PaddleTensor> analysis_outputs;
auto predictor = CreateTestPredictor(config, true);
predictor->Run(inputs[0], &analysis_outputs, batch_size);
// analysis + zero_copy
std::vector<ZeroCopyTensor> zerocopy_outputs;
reinterpret_cast<AnalysisConfig *>(config)->SwitchUseFeedFetchOps(false);
predictor = CreateTestPredictor(config, true);
ConvertPaddleTensorToZeroCopyTensor(predictor.get(), inputs[0]);
predictor->ZeroCopyRun();
for (size_t i = 0; i < outputs_name.size(); i++) {
ZeroCopyTensor zerocopy_output =
*predictor->GetOutputTensor(outputs_name[i]).get();
zerocopy_outputs.emplace_back(zerocopy_output);
LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(zerocopy_output);
}
// compare
CompareResult(analysis_outputs, zerocopy_outputs);
}
template <typename T> template <typename T>
std::string LoDTensorSummary(const framework::LoDTensor &tensor) { std::string LoDTensorSummary(const framework::LoDTensor &tensor) {
std::stringstream ss; std::stringstream ss;
......
...@@ -30,18 +30,19 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME) ...@@ -30,18 +30,19 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
${EXTERNAL_PROJECT_NAME} ${EXTERNAL_PROJECT_NAME}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${INSTALL_DIR} PREFIX ${INSTALL_DIR}
URL ${URL}/${FILENAME} DOWNLOAD_COMMAND wget -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} &&
${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME}
DOWNLOAD_DIR ${INSTALL_DIR} DOWNLOAD_DIR ${INSTALL_DIR}
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
UPDATE_COMMAND "" UPDATE_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory ${UNPACK_DIR} ${INSTALL_DIR} INSTALL_COMMAND ""
) )
endfunction() endfunction()
set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec") set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec")
if (NOT EXISTS ${WORD2VEC_INSTALL_DIR}) if(NOT EXISTS ${WORD2VEC_INSTALL_DIR} AND NOT WIN32)
inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz") inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
endif() endif()
set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model") set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model")
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/memory/allocation/legacy_allocator.h" #include "paddle/fluid/memory/allocation/legacy_allocator.h"
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
......
include(operators) include(operators)
register_operators(DEPS naive_executor) register_operators(DEPS naive_executor)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
...@@ -26,14 +27,6 @@ namespace operators { ...@@ -26,14 +27,6 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
static constexpr char kStepBlock[] = "sub_block";
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
namespace { // NOLINT namespace { // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString( static std::string GetSkipEagerDeletionVarsDebugString(
const std::vector<std::string> &vars) { const std::vector<std::string> &vars) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace operators {
// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class OpVariant {
struct InputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Inputs());
}
};
struct OutputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Outputs());
}
};
struct AttributeMapVisitor
: public boost::static_visitor<const framework::AttributeMap *> {
const framework::AttributeMap *operator()(
const framework::OpDesc *op) const {
return &(op->GetAttrMap());
}
const framework::AttributeMap *operator()(
const framework::OperatorBase *op) const {
return &(op->Attrs());
}
};
struct RawPointerVisitor : public boost::static_visitor<const void *> {
template <typename OpType>
const void *operator()(const OpType *op) const {
return op;
}
};
public:
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
const framework::VariableNameMap &Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
}
const framework::VariableNameMap &Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
}
const framework::AttributeMap &Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
}
template <typename AttrType>
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
return boost::get<AttrType>(it->second);
}
bool operator==(const OpVariant &other) const {
return RawPointer() == other.RawPointer();
}
const void *RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
}
int which() const { return static_cast<int>(op_.which()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
return reinterpret_cast<size_t>(op.RawPointer());
}
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
static std::string GetDebugString(const std::vector<std::string> &names) {
if (names.empty()) return "";
std::string ret = names[0];
for (size_t i = 1; i < names.size(); ++i) {
ret += (" " + names[i]);
}
return ret;
}
// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
VLOG(2) << "Prepare to skip " << attr.size()
<< " var(s): " << GetDebugString(attr);
attrs[kSkipEagerDeletionVars] = std::move(attr);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
const OpVariant &grad_op) {
return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}
// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static bool IsSkippableVar(const std::string &name,
framework::BlockDesc *grad_block) {
return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}
static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
const OpVariant &bwd_op) {
auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
// Find all skippable variables in forward while_op
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (IsSkippableVar(in_arg_name, grad_block)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (IsSkippableVar(out_arg_name, grad_block)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
forward_skip_vars.end()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto &fwd_input = fwd_op.Inputs().at(kX);
auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
PADDLE_ENFORCE_EQ(
fwd_input.size(), in_grads.size(),
"Backward input gradient number does not match forward input number.");
std::unordered_set<std::string> backward_skip_vars;
for (size_t i = 0; i < in_grads.size(); ++i) {
if (in_grads[i] == framework::kEmptyVarName) {
continue;
}
backward_skip_vars.insert(in_grads[i]);
backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
backward_skip_vars.end()));
}
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
while_ops->emplace_back(op);
} else if (op->Type() == "while_grad") {
while_grad_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
"There are extra while_grad ops in the graph or program");
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
if (while_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
while_ops->begin(), while_ops->end());
for (auto &bwd_op : *while_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : while_op_set) {
if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched while ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward while op.");
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op);
}
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// would be processed when block_id is 0 (i.e. when Executor::Run() or
// ParallelExecutor constructs).
// What's more, all while_ops and while_grad_ops must be processed when
// block_id is zero. If not, while_op may run first and erase variables
// used in while_grad_op, and in this moment, while_grad_ops may be not
// constructed yet.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "while") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "while_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(while_ops.size());
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,19 +14,30 @@ ...@@ -14,19 +14,30 @@
#pragma once #pragma once
#include "paddle/fluid/framework/ir/graph.h" #include <memory>
#include "paddle/fluid/framework/ir/pass.h" #include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace operators {
namespace details {
class EagerDeletionPass : public ir::Pass { static constexpr char kStepBlock[] = "sub_block";
protected: static constexpr char kCondition[] = "Condition";
std::unique_ptr<ir::Graph> ApplyImpl( static constexpr char kStepScopes[] = "StepScopes";
std::unique_ptr<ir::Graph> graph) const override; static constexpr char kX[] = "X";
}; static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
} // namespace details void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
} // namespace framework int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops);
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>, auto ker =
platform::CPUPlace>(tag_num); jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
.At(tag_num);
ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
......
...@@ -110,8 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -110,8 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>, auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
platform::CPUPlace>(0); platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
......
...@@ -52,8 +52,9 @@ struct EmbeddingVSumFunctor { ...@@ -52,8 +52,9 @@ struct EmbeddingVSumFunctor {
out_width, jit::SeqPoolType::kSum); out_width, jit::SeqPoolType::kSum);
for (size_t i = 0; i != ids_lod.size() - 1; ++i) { for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i]; attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>, auto emb_seqpool =
platform::CPUPlace>(attr); jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
.At(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width, emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr); &attr);
} }
...@@ -135,8 +136,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -135,8 +136,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace()); T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
const T *d_output_data = d_output->data<T>(); const T *d_output_data = d_output->data<T>();
auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>, auto vbroadcast =
platform::CPUPlace>(out_width); jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
.At(out_width);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]); int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
const T *src = d_output_data + i * out_width; const T *src = d_output_data + i * out_width;
......
...@@ -196,11 +196,14 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -196,11 +196,14 @@ class FusionGRUKernel : public framework::OpKernel<T> {
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
jit::gru_t one_step; \ jit::gru_t one_step; \
auto ComputeH1 = \ auto ComputeH1 = \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At( \
attr); \
auto ComputeHtPart1 = \ auto ComputeHtPart1 = \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \
auto ComputeHtPart2 = \ auto ComputeHtPart2 = \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \ jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
.At(attr); \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
......
...@@ -258,9 +258,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -258,9 +258,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.wp = wp_data; \ one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \ one_step.checked = checked_cell_data; \
auto ComputeC1H1 = \ auto ComputeC1H1 = \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \ jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
attr); \
auto ComputeCtHt = \ auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr) jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
......
...@@ -82,9 +82,11 @@ template <typename T> ...@@ -82,9 +82,11 @@ template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) { const jit::matmul_attr_t& attr) {
auto matmul = auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr); jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
auto addbias_relu = auto addbias_relu =
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n); jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
attr.n);
matmul(x, w, y, &attr); matmul(x, w, y, &attr);
T* dst = y; T* dst = y;
for (int i = 0; i < attr.m; ++i) { for (int i = 0; i < attr.m; ++i) {
......
...@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
attr.type = jit::SeqPoolType::kSqrt; attr.type = jit::SeqPoolType::kSqrt;
} }
auto seqpool = auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
attr); attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w; size_t dst_step_size = n * w;
......
...@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> { ...@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
int o_numel = attr.m * attr.n; int o_numel = attr.m * attr.n;
auto vsquare_x = auto vsquare_x =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m * jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.k); attr.m * attr.k);
auto vsquare_y = auto vsquare_y =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k * jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
attr.n); attr.k * attr.n);
auto vsquare_xy = auto vsquare_xy =
jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel); jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vsub = auto vsub =
jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel); jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto vscal = auto vscal =
jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel); jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
o_numel);
auto matmul = auto matmul =
jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr); jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
attr);
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
......
...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n") ...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc) list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
......
...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) { ...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \ InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
void BenchJITKernel_##name##_##dtype##_##place##_::Run() void BenchJITKernel_##name##_##dtype##_##place##_::Run()
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
void RUN_ALL_BENCHMARK() { void RUN_ALL_BENCHMARK() {
for (auto p : g_all_benchmarks) { for (auto p : g_all_benchmarks) {
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) { if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() { ...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
return s; return s;
} }
template <typename KernelTuples, typename... Args> template <typename KernelTuple, typename... Args>
struct BenchFunc { struct BenchFunc {
// return this function avg time // return this function avg time
// TODO(TJ): clear cache every time // TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) { double operator()(const typename KernelTuple::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) { for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...); tgt(args...);
} }
...@@ -109,40 +107,17 @@ struct BenchFunc { ...@@ -109,40 +107,17 @@ struct BenchFunc {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType, template <typename KernelTuple, typename PlaceType, typename... Args>
typename... Args> void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { BenchFunc<KernelTuple, Args...> benchmark;
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos; std::vector<std::pair<std::string, double>> infos;
// test refer auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
auto refer = jit::GetRefer<KT, KernelTuples>(); for (auto f : funcs) {
if (!refer) { infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
LOG(FATAL) << "Refer can not be empty!";
} }
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function // Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr); auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
if (!tgt) { if (!tgt) {
LOG(FATAL) << "Target can not be empty!"; LOG(FATAL) << "Target can not be empty!";
} }
...@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
// print // print
std::ostringstream loginfos; std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
<< attr << ": ";
for (auto pair : infos) { for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; "; loginfos << pair.first << " takes " << pair.second << " us; ";
} }
...@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYZNKernel() { void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y, z; Tensor x, y, z;
x.Resize({d}); x.Resize({d});
...@@ -171,16 +148,16 @@ void BenchXYZNKernel() { ...@@ -171,16 +148,16 @@ void BenchXYZNKernel() {
T* z_data = z.mutable_data<T>(PlaceType()); T* z_data = z.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
RandomVec<T>(d, y_data); RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
y.data<T>(), z_data, d); d);
// test inplace // test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data, BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
z_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchAXYNKernel() { void BenchKernelAXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const T a = static_cast<T>(3); const T a = static_cast<T>(3);
Tensor x, y; Tensor x, y;
...@@ -189,26 +166,26 @@ void BenchAXYNKernel() { ...@@ -189,26 +166,26 @@ void BenchAXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data, BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
d);
// test inplace // test inplace
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data, BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXRNKernel() { void BenchKernelXRN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x; Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType())); RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res; T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYNKernel() { void BenchKernelXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y; Tensor x, y;
x.Resize({d}); x.Resize({d});
...@@ -216,12 +193,13 @@ void BenchXYNKernel() { ...@@ -216,12 +193,13 @@ void BenchXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLSTMKernel() { void BenchKernelLSTM() {
using T = typename KernelTuple::data_type;
for (bool use_peephole : {true, false}) { for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
...@@ -252,13 +230,14 @@ void BenchLSTMKernel() { ...@@ -252,13 +230,14 @@ void BenchLSTMKernel() {
step.wp = wp_data; step.wp = wp_data;
step.checked = checked_data; step.checked = checked_data;
} }
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchGRUKernel() { void BenchKernelGRU() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
auto place = PlaceType(); auto place = PlaceType();
...@@ -275,12 +254,13 @@ void BenchGRUKernel() { ...@@ -275,12 +254,13 @@ void BenchGRUKernel() {
step.gates = x_data; step.gates = x_data;
step.ht_1 = ht_1_data; step.ht_1 = ht_1_data;
step.ht = ht_data; step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSeqPoolKernel() { void BenchKernelSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = { std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) { for (auto type : pool_types) {
...@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() { ...@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() {
RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data, BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
y_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchEmbSeqPoolKernel() { void BenchKernelEmbSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4; int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) { for (int tbl_w : {10, 16, 256}) {
...@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() { ...@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() {
tbl_h - 1); tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>(); const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType()); T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
attr, table_data, idx_data, o_data, &attr); o_data, &attr);
} }
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSgdKernel() { void BenchKernelSgd() {
using T = typename KernelTuple::data_type;
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
...@@ -364,15 +345,16 @@ void BenchSgdKernel() { ...@@ -364,15 +345,16 @@ void BenchSgdKernel() {
const T* grad_data = grad.data<T>(); const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data(); const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
attr, &lr, param_data, grad_data, rows_data, param_data, &attr); rows_data, param_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchMatMulKernel() { void BenchKernelMatMul() {
using T = typename KernelTuple::data_type;
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
for (int k : TestSizes()) { for (int k : TestSizes()) {
...@@ -386,15 +368,16 @@ void BenchMatMulKernel() { ...@@ -386,15 +368,16 @@ void BenchMatMulKernel() {
const T* b_data = b.data<T>(); const T* b_data = b.data<T>();
T* c_data = c.mutable_data<T>(PlaceType()); T* c_data = c.mutable_data<T>(PlaceType());
const jit::matmul_attr_t attr{m, n, k}; const jit::matmul_attr_t attr{m, n, k};
BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data, BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
c_data, &attr); &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSoftmaxKernel() { void BenchKernelSoftmax() {
using T = typename KernelTuple::data_type;
for (int bs : {1, 2, 10}) { for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
Tensor x, y; Tensor x, y;
...@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() { ...@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n, BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
bs);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLayerNormKernel() { void BenchKernelLayerNorm() {
using T = typename KernelTuple::data_type;
const T epsilon = 9.99999975e-06; const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) { for (int n : {1, 2, 10}) {
for (int x_dim_0 : {1, 9, 17, 50}) { for (int x_dim_0 : {1, 9, 17, 50}) {
...@@ -439,16 +422,17 @@ void BenchLayerNormKernel() { ...@@ -439,16 +422,17 @@ void BenchLayerNormKernel() {
T* var_data = var.data<T>(); T* var_data = var.data<T>();
T* out_data = out.mutable_data<T>(PlaceType()); T* out_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
right, x_data, out_data, mean_data, var_data, scale_data, bias_data, mean_data, var_data, scale_data,
left, epsilon, right); bias_data, left, epsilon, right);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchCRFDecodingKernel() { void BenchKernelCRFDecoding() {
using T = typename KernelTuple::data_type;
constexpr int state_trans_base_idx = 2; constexpr int state_trans_base_idx = 2;
for (int seq_len : {1, 11, 17, 50}) { for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) { for (int tag_num : TestSizes()) {
...@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() { ...@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() {
T* alpha_data = alpha.mutable_data<T>(PlaceType()); T* alpha_data = alpha.mutable_data<T>(PlaceType());
int* track_data = track.mutable_data<int>(PlaceType()); int* track_data = track.mutable_data<int>(PlaceType());
BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num); alpha_data, track_data, tag_num);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchVBroadcastKernel() { void BenchKernelVBroadcast() {
using T = typename KernelTuple::data_type;
for (int64_t w : {1, 16, 64, 100, 256}) { for (int64_t w : {1, 16, 64, 100, 256}) {
Tensor x; Tensor x;
x.Resize({w}); x.Resize({w});
...@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() { ...@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() {
Tensor y; Tensor y;
y.Resize({h * w}); y.Resize({h * w});
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>( BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
w, x_data, y_data, static_cast<int64_t>(h), w); static_cast<int64_t>(h), w);
} }
} }
} }
using T = float; #define BenchKernelVMul BenchKernelXYZN
using CPUPlace = paddle::platform::CPUPlace; #define BenchKernelVAdd BenchKernelXYZN
#define BenchKernelVAddRelu BenchKernelXYZN
#define BenchKernelVSub BenchKernelXYZN
// xyzn #define BenchKernelVScal BenchKernelAXYN
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); } #define BenchKernelVAddBias BenchKernelAXYN
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
// axyn #define BenchKernelVRelu BenchKernelXYN
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); } #define BenchKernelVIdentity BenchKernelXYN
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); } #define BenchKernelVSquare BenchKernelXYN
#define BenchKernelVExp BenchKernelXYN
#define BenchKernelVSigmoid BenchKernelXYN
#define BenchKernelVTanh BenchKernelXYN
#define BenchKernelVCopy BenchKernelXYN
// xrn #define BenchKernelHMax BenchKernelXRN
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); } #define BenchKernelHSum BenchKernelXRN
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
// xyn #define BenchKernelLSTMCtHt BenchKernelLSTM
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); } #define BenchKernelLSTMC1H1 BenchKernelLSTM
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
// gru functions
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
// sgd function #define BenchKernelGRUH1 BenchKernelGRU
BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); } #define BenchKernelGRUHtPart1 BenchKernelGRU
#define BenchKernelGRUHtPart2 BenchKernelGRU
// matmul using CPUPlace = paddle::platform::CPUPlace;
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
// softmax #define BENCH_FP32_CPU(name) \
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); } BENCH_JITKERNEL(name, FP32, CPU) { \
BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
}
// layernorm // xyzn
BENCH_FP32_CPU(kLayerNorm) { BENCH_FP32_CPU(VMul);
BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>(); BENCH_FP32_CPU(VAdd);
} BENCH_FP32_CPU(VAddRelu);
BENCH_FP32_CPU(VSub);
// crfdecoding // axyn
BENCH_FP32_CPU(kCRFDecoding) { BENCH_FP32_CPU(VScal);
BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>(); BENCH_FP32_CPU(VAddBias);
}
// vbroadcast function // xyn
BENCH_FP32_CPU(kVBroadcast) { BENCH_FP32_CPU(VRelu);
BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>(); BENCH_FP32_CPU(VIdentity);
} BENCH_FP32_CPU(VSquare);
BENCH_FP32_CPU(VExp);
BENCH_FP32_CPU(VSigmoid);
BENCH_FP32_CPU(VTanh);
BENCH_FP32_CPU(VCopy);
// xrn
BENCH_FP32_CPU(HMax);
BENCH_FP32_CPU(HSum);
// LSTM
BENCH_FP32_CPU(LSTMCtHt);
BENCH_FP32_CPU(LSTMC1H1);
// GRU
BENCH_FP32_CPU(GRUH1);
BENCH_FP32_CPU(GRUHtPart1);
BENCH_FP32_CPU(GRUHtPart2);
BENCH_FP32_CPU(LayerNorm);
BENCH_FP32_CPU(CRFDecoding);
BENCH_FP32_CPU(SeqPool);
BENCH_FP32_CPU(EmbSeqPool);
BENCH_FP32_CPU(MatMul);
BENCH_FP32_CPU(Softmax);
BENCH_FP32_CPU(Sgd);
BENCH_FP32_CPU(VBroadcast);
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -81,7 +82,7 @@ void VActJitCode::genCode() { ...@@ -81,7 +82,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \ #define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override; \ bool CanBeUsed(const int& attr) const override; \
size_t CodeSize(const int& d) const override; \ size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \ std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \ return make_unique<name##JitCode>(attr, CodeSize(attr)); \
...@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid); ...@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh); DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me // TODO(TJ): tuning use me
bool VReluCreator::UseMe(const int& d) const { bool VReluCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VSquareCreator::UseMe(const int& d) const { bool VSquareCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VIdentityCreator::UseMe(const int& d) const { bool VIdentityCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VExpCreator::UseMe(const int& d) const { bool VExpCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d < 32; return platform::MayIUse(platform::avx) && d < 32;
} }
bool VSigmoidCreator::UseMe(const int& d) const { bool VSigmoidCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VTanhCreator::UseMe(const int& d) const { bool VTanhCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h" #include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() { ...@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() {
class NCHW16CMulNCCreator : public JitCodeCreator<int> { class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public: public:
bool UseMe(const int& attr) const override { bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f); return platform::MayIUse(platform::avx512f);
} }
size_t CodeSize(const int& d) const override { return 256 * 1024; } size_t CodeSize(const int& d) const override { return 256 * 1024; }
...@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> { ...@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \ #define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override { \ bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx) && attr <= 1024; \ return platform::MayIUse(platform::avx) && attr <= 1024; \
} \ } \
size_t CodeSize(const int& d) const override { \ size_t CodeSize(const int& d) const override { \
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h" #include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
...@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() { ...@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() {
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> { class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public: public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override { bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) && return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0; attr.table_width % YMM_FLOAT_BLOCK == 0;
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h" #include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -86,7 +87,7 @@ void GRUJitCode::genCode() { ...@@ -86,7 +87,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \ class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \ public: \
/* TODO(TJ): enable more */ \ /* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \ bool CanBeUsed(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \ } \
size_t CodeSize(const gru_attr_t& attr) const override { \ size_t CodeSize(const gru_attr_t& attr) const override { \
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h" #include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() { ...@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \ #define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override { \ bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx); \ return platform::MayIUse(platform::avx); \
} \ } \
size_t CodeSize(const int& d) const override { \ size_t CodeSize(const int& d) const override { \
......
...@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { ...@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual void genCode() = 0; virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); } size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override { const unsigned char* getCodeInternal() const override {
const Xbyak::uint8* code = CodeGenerator::getCode(); const Xbyak::uint8* code = CodeGenerator::getCode();
return code; return code;
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h" #include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() { ...@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \ class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \ public: \
/* TODO(TJ): enable more */ \ /* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \ bool CanBeUsed(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \ } \
size_t CodeSize(const lstm_attr_t& attr) const override { \ size_t CodeSize(const lstm_attr_t& attr) const override { \
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h" #include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() { ...@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class MatMulCreator : public JitCodeCreator<matmul_attr_t> { class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
public: public:
bool UseMe(const matmul_attr_t& attr) const override { bool CanBeUsed(const matmul_attr_t& attr) const override {
return attr.m == 1 && platform::MayIUse(platform::avx512f) && return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512; attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() { ...@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() {
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public: public:
bool UseMe(const seq_pool_attr_t& attr) const override { bool CanBeUsed(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
size_t CodeSize(const seq_pool_attr_t& attr) const override { size_t CodeSize(const seq_pool_attr_t& attr) const override {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h" #include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -104,7 +105,7 @@ void SgdJitCode::genCode() { ...@@ -104,7 +105,7 @@ void SgdJitCode::genCode() {
class SgdCreator : public JitCodeCreator<sgd_attr_t> { class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public: public:
bool UseMe(const sgd_attr_t& attr) const override { bool CanBeUsed(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) && return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0; attr.grad_width % YMM_FLOAT_BLOCK == 0;
} }
......
...@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() { ...@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class VBroadcastCreator : public JitCodeCreator<int64_t> { class VBroadcastCreator : public JitCodeCreator<int64_t> {
public: public:
bool UseMe(const int64_t& w) const override { bool CanBeUsed(const int64_t& w) const override {
return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0; return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
} }
size_t CodeSize(const int64_t& w) const override { size_t CodeSize(const int64_t& w) const override {
......
...@@ -31,7 +31,7 @@ namespace paddle { ...@@ -31,7 +31,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
// refer do not need useme, it would be the last one. // refer do not need CanBeUsed, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const { void GenBase::dumpCode(const unsigned char* code) const {
if (code) { if (code) {
static int counter = 0; static int counter = 0;
......
...@@ -31,9 +31,10 @@ class GenBase : public Kernel { ...@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual ~GenBase() = default; virtual ~GenBase() = default;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual size_t getSize() const = 0; virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0; virtual const unsigned char* getCodeInternal() const = 0;
const char* ImplType() const override { return "JitCode"; }
template <typename Func> template <typename Func>
Func getCode() { Func getCode() const {
const unsigned char* code = this->getCodeInternal(); const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) { if (FLAGS_dump_jitcode) {
this->dumpCode(code); this->dumpCode(code);
...@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator { ...@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual ~JitCodeCreator() = default; virtual ~JitCodeCreator() = default;
// condition when this jit code can be used. // condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0; virtual bool CanBeUsed(const Attr& attr) const = 0;
// estimate this code size // estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0; virtual size_t CodeSize(const Attr& attr) const = 0;
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -27,35 +29,34 @@ namespace paddle { ...@@ -27,35 +29,34 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value && std::is_same<typename KernelTuple::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value, std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type const Kernel*>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuples::func_type; using Attr = typename KernelTuple::attr_type;
using Attr = typename KernelTuples::attr_type; int64_t key = JitCodeKey<Attr>(attr);
size_t key = JitCodeKey<Attr>(attr); auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) { if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>(); return codes.AllKernels().at(key).get();
} }
// creator is not related with attr, so can use KernelKey as key // creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>) // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
auto iter = creator_map.find(kkey); auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) { if (iter != creator_map.end()) {
auto& creators = iter->second; auto& creators = iter->second;
for (auto& cur : creators) { for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get()); auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) { if (i && i->CanBeUsed(attr)) {
auto p = i->CreateJitCode(attr); auto p = i->CreateJitCode(attr);
if (p) { if (p) {
auto f = p->template getCode<Func>(); auto res = p.get();
codes.Insert(key, std::move(p)); codes.Insert(key, std::move(p));
return f; return res;
} }
} }
} }
...@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) { ...@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value || !std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type const Kernel*>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr; return nullptr;
} }
// Refer code do not related with attr, which is just for cast // Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples> template <typename KernelTuple>
inline typename KernelTuples::func_type GetRefer() { inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels(); auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(), PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function."); "Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) { if (i) {
return i->GetFunc(); return i;
} }
} }
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, template <typename KernelTuple>
typename PlaceType = platform::CPUPlace> inline typename KernelTuple::func_type GetReferFunc() {
typename KernelTuples::func_type Get( auto ker = GetReferKernel<KernelTuple>();
const typename KernelTuples::attr_type& attr) { auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr); PADDLE_ENFORCE(p, "The Refer kernel should exsit");
if (jitfunc) { return p->GetFunc();
return jitfunc; }
// Return all Kernels that can be used
template <typename KernelTuple, typename PlaceType>
std::vector<const Kernel*> GetAllCandidateKernels(
const typename KernelTuple::attr_type& attr) {
// the search order shoudl be jitcode > more > refer
std::vector<const Kernel*> res;
auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
if (jitker) {
res.emplace_back(jitker);
} }
// pool: (KernelKey(type, place), vector<KernelPtr>) // more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType()); KernelKey kkey(KernelTuple::kernel_type, PlaceType());
auto& pool = KernelPool().Instance().AllKernels(); auto& pool = KernelPool::Instance().AllKernels();
auto iter = pool.find(kkey); auto iter = pool.find(kkey);
if (iter != pool.end()) { if (iter != pool.end()) {
auto& impls = iter->second; auto& impls = iter->second;
for (auto& impl : impls) { for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
if (i && i->UseMe(attr)) { if (i && i->CanBeUsed(attr)) {
return i->GetFunc(); res.emplace_back(i);
} }
} }
} }
// The last implementation should be reference function on CPUPlace. // The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>(); auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
res.emplace_back(ref);
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<std::pair<std::string, typename KernelTuple::func_type>>
GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
using Func = typename KernelTuple::func_type;
auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
std::vector<std::pair<std::string, Func>> res;
for (auto k : kers) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
std::vector<typename KernelTuple::func_type> res;
for (auto& i : funcs) {
res.emplace_back(i.second);
}
return res;
}
template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return funcs[0];
} }
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
class KernelFuncs { class KernelFuncs {
public: public:
KernelFuncs() = default; KernelFuncs() = default;
static KernelFuncs& Cache() { static KernelFuncs& Cache() {
static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache; static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache; return g_func_cache;
} }
bool Has(int key) const { return funcs_.find(key) != funcs_.end(); } // the exposed interface to use
typename KernelTuple::func_type At(
void Insert(int key, typename KernelTuples::func_type func) { const typename KernelTuple::attr_type& attr) {
funcs_.emplace(key, func); // Maybe here is not good enough, not all kernels should have jitcode
} int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
typename KernelTuples::func_type At(int key) {
if (Has(key)) { if (Has(key)) {
return funcs_.at(key); return funcs_.at(key);
} }
auto func = Get<KT, KernelTuples, PlaceType>(key); // If do not have this attr in cache then get the default best
auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
Insert(key, func); Insert(key, func);
return func; return func;
} }
typename KernelTuple::func_type operator[](
const typename KernelTuple::attr_type& attr) {
return At(attr);
}
protected:
bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
void Insert(int64_t key, typename KernelTuple::func_type func) {
funcs_.emplace(key, func);
}
private: private:
std::unordered_map<int, typename KernelTuples::func_type> funcs_; std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
DISABLE_COPY_AND_ASSIGN(KernelFuncs); DISABLE_COPY_AND_ASSIGN(KernelFuncs);
}; };
......
...@@ -62,26 +62,55 @@ typedef enum { ...@@ -62,26 +62,55 @@ typedef enum {
kSqrt, kSqrt,
} SeqPoolType; } SeqPoolType;
// x, y, z, n
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int); typedef void (*func_type)(const T*, const T*, T*, int);
}; };
// a, x, y, n
template <typename T> template <typename T>
struct AXYNTuples : public XYZNTuples<T> {}; struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
template <typename T> template <typename T>
struct XYNTuples { struct XYNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int); typedef void (*func_type)(const T*, T*, int);
}; };
// x, return and int // x, returned value, n
template <typename T> template <typename T>
struct XRNTuples : public XYNTuples<T> {}; struct XRNTuple : public XYNTuple<T> {};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Tuple should be corresponding to the KernelType
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
typedef struct { typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh void* gates; // gates: x_ch, x_ih, x_fh, x_oh
...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t; ...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t; typedef struct lstm_attr_s lstm_attr_t;
template <typename T> template <typename T>
struct LSTMTuples { struct LSTMTuple {
typedef T data_type; typedef T data_type;
typedef lstm_attr_t attr_type; typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*); typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
}; };
template <typename T> template <typename T>
struct GRUTuples { struct GRUTuple {
typedef T data_type; typedef T data_type;
typedef gru_attr_t attr_type; typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
template <typename T> template <typename T>
struct VBroadcastTuples { struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type; typedef T data_type;
typedef int64_t attr_type; typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t); typedef void (*func_type)(const T*, T*, int64_t, int64_t);
...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s { ...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t; } seq_pool_attr_t;
template <typename T> template <typename T>
struct SeqPoolTuples { struct SeqPoolTuple {
static constexpr KernelType kernel_type = kSeqPool;
typedef T data_type; typedef T data_type;
typedef seq_pool_attr_t attr_type; typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
...@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s { ...@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
} emb_seq_pool_attr_t; } emb_seq_pool_attr_t;
template <typename T> template <typename T>
struct EmbSeqPoolTuples { struct EmbSeqPoolTuple {
static constexpr KernelType kernel_type = kEmbSeqPool;
typedef T data_type; typedef T data_type;
typedef emb_seq_pool_attr_t attr_type; typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const int64_t*, T*,
...@@ -198,7 +239,8 @@ typedef struct sgd_attr_s { ...@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
} sgd_attr_t; } sgd_attr_t;
template <typename T> template <typename T>
struct SgdTuples { struct SgdTuple {
static constexpr KernelType kernel_type = kSgd;
typedef T data_type; typedef T data_type;
typedef sgd_attr_t attr_type; typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*, typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
...@@ -214,21 +256,24 @@ typedef struct matmul_attr_s { ...@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
} matmul_attr_t; } matmul_attr_t;
template <typename T> template <typename T>
struct MatMulTuples { struct MatMulTuple {
static constexpr KernelType kernel_type = kMatMul;
typedef T data_type; typedef T data_type;
typedef matmul_attr_t attr_type; typedef matmul_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*); typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
}; };
template <typename T> template <typename T>
struct CRFDecodingTuples { struct CRFDecodingTuple {
static constexpr KernelType kernel_type = kCRFDecoding;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int); typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
}; };
template <typename T> template <typename T>
struct LayerNormTuples { struct LayerNormTuple {
static constexpr KernelType kernel_type = kLayerNorm;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int, typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
...@@ -236,7 +281,8 @@ struct LayerNormTuples { ...@@ -236,7 +281,8 @@ struct LayerNormTuples {
}; };
template <typename T> template <typename T>
struct SoftmaxTuples { struct SoftmaxTuple {
static constexpr KernelType kernel_type = kSoftmax;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int, int); typedef void (*func_type)(const T*, T*, int, int);
...@@ -244,7 +290,8 @@ struct SoftmaxTuples { ...@@ -244,7 +290,8 @@ struct SoftmaxTuples {
// nChw16c = nChw16c .* NC // nChw16c = nChw16c .* NC
template <typename T> template <typename T>
struct NCHW16CMulNCTuples { struct NCHW16CMulNCTuple {
static constexpr KernelType kernel_type = kNCHW16CMulNC;
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int); typedef void (*func_type)(const T*, const T*, T*, int, int);
...@@ -255,28 +302,29 @@ class Kernel { ...@@ -255,28 +302,29 @@ class Kernel {
public: public:
Kernel() = default; Kernel() = default;
virtual ~Kernel() = default; virtual ~Kernel() = default;
virtual const char* ImplType() const = 0;
DISABLE_COPY_AND_ASSIGN(Kernel); DISABLE_COPY_AND_ASSIGN(Kernel);
}; };
template <typename KernelTuples> template <typename KernelTuple>
class KernelMore : public Kernel { class KernelMore : public Kernel {
public: public:
using T = typename KernelTuples::data_type; using T = typename KernelTuple::data_type;
using Func = typename KernelTuples::func_type; using Func = typename KernelTuple::func_type;
using Attr = typename KernelTuples::attr_type; using Attr = typename KernelTuple::attr_type;
virtual Func GetFunc() const { return func; } virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0; // specify this kernel can be used, means it should not fail if use it.
virtual const char* ImplType() const = 0; virtual bool CanBeUsed(const Attr& attr) const = 0;
protected: protected:
Func func{nullptr}; Func func{nullptr};
}; };
template <typename KernelTuples> template <typename KernelTuple>
class ReferKernel : public KernelMore<KernelTuples> { class ReferKernel : public KernelMore<KernelTuple> {
public: public:
// Refer code can always be used // Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override { bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
return true; return true;
} }
const char* ImplType() const override { return "Refer"; } const char* ImplType() const override { return "Refer"; }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h> // XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -20,71 +21,46 @@ namespace operators { ...@@ -20,71 +21,46 @@ namespace operators {
namespace jit { namespace jit {
template <> template <>
size_t JitCodeKey<int>(const int& d) { int64_t JitCodeKey<int>(const int& d) {
return d; return d;
} }
template <> template <>
size_t JitCodeKey<int64_t>(const int64_t& d) { int64_t JitCodeKey<int64_t>(const int64_t& d) {
return d; return d;
} }
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr int act_type_shift = 3; // suppot 2^3 act types
static inline int act_type_convert(KernelType type) {
if (type == kVIdentity) {
return 0;
} else if (type == kVExp) {
return 1;
} else if (type == kVRelu) {
return 2;
} else if (type == kVSigmoid) {
return 3;
} else if (type == kVTanh) {
return 4;
}
PADDLE_THROW("Unsupported act type %d", type);
return 0;
}
template <> template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d; return XXH64(&attr, sizeof(gru_attr_t), 0);
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
} }
template <> template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d; int keys[5] = {
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) + attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
(act_type_convert(attr.act_cand) << act_type_shift); static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
return XXH64(keys, sizeof(int) * 5, 0);
} }
template <> template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) { int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w; int keys[2] = {attr.w, static_cast<int>(attr.type)};
constexpr int pool_type_shift = 3; return XXH64(keys, sizeof(int) * 2, 0);
return (key << pool_type_shift) + static_cast<int>(attr.type);
} }
template <> template <>
size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) { int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
size_t key = attr.m; return XXH64(&attr, sizeof(int) * 3, 0); // m, n, k
constexpr int shift = 21;
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
} }
template <> template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) { int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width; return attr.table_width;
} }
template <> template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) { int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width; return attr.grad_width;
} }
......
...@@ -46,7 +46,7 @@ struct KernelKey { ...@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from attribution // Every JitCode should have a method to get the key from attribution
template <typename Attr> template <typename Attr>
size_t JitCodeKey(const Attr& attr); int64_t JitCodeKey(const Attr& attr);
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> // for move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -30,7 +31,7 @@ namespace jit { ...@@ -30,7 +31,7 @@ namespace jit {
template <KernelType KT> template <KernelType KT>
class JitCodePool { class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr; typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap; typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
public: public:
JitCodePool() = default; JitCodePool() = default;
...@@ -41,9 +42,9 @@ class JitCodePool { ...@@ -41,9 +42,9 @@ class JitCodePool {
const JitCodeMap& AllKernels() { return codes_; } const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); } bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) { void Insert(int64_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value)); codes_.emplace(key, std::move(value));
} }
......
...@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w, ...@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
} }
} }
bool CRFDecodingKernel::UseMe(const int& d) const { bool CRFDecodingKernel::CanBeUsed(const int& d) const {
#ifdef __AVX512F__ #ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK; constexpr int block = ZMM_FLOAT_BLOCK;
#else #else
......
...@@ -26,11 +26,11 @@ namespace intrinsic { ...@@ -26,11 +26,11 @@ namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w, void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num); float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> { class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
public: public:
CRFDecodingKernel() { this->func = CRFDecoding; } CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe( bool CanBeUsed(
const typename CRFDecodingTuples<float>::attr_type&) const override; const typename CRFDecodingTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };
......
...@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var, ...@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
} }
} }
bool LayerNormKernel::UseMe(const int& d) const { bool LayerNormKernel::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK; return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
} }
......
...@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var, ...@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height, const float* scale, const float* bias, int height,
const float epsilon, int right); const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> { class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
public: public:
LayerNormKernel() { this->func = LayerNorm; } LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override; bool CanBeUsed(
const typename LayerNormTuple<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; } const char* ImplType() const override { return "Intrinsic"; }
}; };
......
...@@ -23,6 +23,8 @@ namespace jit { ...@@ -23,6 +23,8 @@ namespace jit {
namespace more { namespace more {
namespace mix { namespace mix {
using CPUPlace = platform::CPUPlace;
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN; const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX; const float max = SIGMOID_THRESHOLD_MAX;
...@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) { ...@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]); y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i]; y[i] = static_cast<T>(0) - y[i];
} }
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n); auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
compute(y, y, n); compute(y, y, n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]); y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
...@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) { ...@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
void VTanh(const T* x, T* y, int n) { void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1; const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n); auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n); auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
compute_scal(&a, x, y, n); compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n); compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n); compute_scal(&a, y, y, n);
...@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) { ...@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
} }
void Softmax(const T* x, T* y, int n, int bs) { void Softmax(const T* x, T* y, int n, int bs) {
auto compute_hmax = auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n); auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
auto compute_hsum = auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vscal =
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
auto compute_vaddbias = auto compute_vaddbias =
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n); KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
auto compute_vexp = auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
for (int i = 0; i < bs; ++i) { for (int i = 0; i < bs; ++i) {
T scalar; T scalar;
...@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) { ...@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) { if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVRelu) { } else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVTanh) { } else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d); return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW("Not support type: %s", type);
return nullptr; return nullptr;
...@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { ...@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
const int d = attr->d; const int d = attr->d;
const int d2 = d * 2; const int d2 = d * 2;
const int d3 = d * 3; const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2); auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2); auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3); auto act_gate_d3 = getActFunc(attr->act_gate, d3);
...@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) { ...@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
int d = attr->d; int d = attr->d;
int d2 = d * 2; int d2 = d * 2;
int d3 = d * 3; int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d); auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
auto act_gate_d = getActFunc(attr->act_gate, d); auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d); auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d); auto act_cell_d = getActFunc(attr->act_cell, d);
...@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) { ...@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
int d2 = d * 2; int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d); auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d); auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
act_gate(gates, gates, d); act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d); act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d); vmul_d(gates, gates + d2, ht, d);
...@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) { ...@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T* ht = reinterpret_cast<T*>(step->ht); T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1); const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d); auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d); auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d); act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d); vmul_d(ht_1, gates + attr->d, ht, attr->d);
} }
...@@ -206,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { ...@@ -206,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
} }
// TODO(TJ): tuning me // TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; } bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; } bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
bool SoftmaxKernel::UseMe(const int& d) const { return true; } bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; } bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; } bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; } bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; } bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
} // namespace mix } // namespace mix
} // namespace more } // namespace more
...@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; } ...@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
namespace mix = paddle::operators::jit::more::mix; namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \ #define REGISTER_MORE_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel) REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid); REGISTER_MORE_KERNEL(VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh); REGISTER_MORE_KERNEL(VTanh);
REGISTER_MORE_KERNEL(kSoftmax, Softmax); REGISTER_MORE_KERNEL(Softmax);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_MORE_KERNEL(LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_MORE_KERNEL(LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1); REGISTER_MORE_KERNEL(GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_MORE_KERNEL(GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_MORE_KERNEL(GRUHtPart2);
#undef REGISTER_MORE_KERNEL #undef REGISTER_MORE_KERNEL
...@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr); ...@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr); void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr); void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \ #define DECLARE_MORE_KERNEL(name) \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name; } \ name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \ const char* ImplType() const override { return "Mixed"; } \
} }
// XYN // XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples); DECLARE_MORE_KERNEL(VSigmoid);
DECLARE_MORE_KERNEL(VTanh, XYNTuples); DECLARE_MORE_KERNEL(VTanh);
// XRN // XRN
DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples); DECLARE_MORE_KERNEL(Softmax);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_MORE_KERNEL(LSTMCtHt);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_MORE_KERNEL(LSTMC1H1);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples); DECLARE_MORE_KERNEL(GRUH1);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart1);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples); DECLARE_MORE_KERNEL(GRUHtPart2);
#undef DECLARE_MORE_KERNEL #undef DECLARE_MORE_KERNEL
......
...@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) { ...@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <> template <>
bool VMulKernel<float>::UseMe(const int& d) const { bool VMulKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512; return platform::MayIUse(platform::avx512f) && d > 512;
} }
template <> template <>
bool VAddKernel<float>::UseMe(const int& d) const { bool VAddKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d > 512; return platform::MayIUse(platform::avx) && d > 512;
} }
template <> template <>
bool VScalKernel<float>::UseMe(const int& d) const { bool VScalKernel<float>::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512; return platform::MayIUse(platform::avx512f) && d > 512;
} }
template <> template <>
bool VExpKernel<float>::UseMe(const int& d) const { bool VExpKernel<float>::CanBeUsed(const int& d) const {
return d > 7; return d > 7;
} }
template <> template <>
bool VSquareKernel<float>::UseMe(const int& d) const { bool VSquareKernel<float>::CanBeUsed(const int& d) const {
return d > 7; return d > 7;
} }
template <> template <>
bool VCopyKernel<float>::UseMe(const int& d) const { bool VCopyKernel<float>::CanBeUsed(const int& d) const {
return d > 15; return d > 15;
} }
template <> template <>
bool VBroadcastKernel<float>::UseMe(const int64_t& d) const { bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
return d > 127; return d > 127;
} }
template <> template <>
bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const { bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
return true; return true;
} }
template <> template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const { bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
return d > 7; return d > 7;
} }
template <> template <>
bool VTanhKernel<float>::UseMe(const int& d) const { bool VTanhKernel<float>::CanBeUsed(const int& d) const {
return d > 7; return d > 7;
} }
template <> template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const { bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const { bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const { bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const { bool EmbSeqPoolKernel<double>::CanBeUsed(
const emb_seq_pool_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const { bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const { bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const { bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
template <> template <>
bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const { bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
return true; return true;
} }
template <> template <>
bool SoftmaxKernel<float>::UseMe(const int& d) const { bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
// tuned on avx2 // tuned on avx2
return platform::MayIUse(platform::avx) && d < 60; return platform::MayIUse(platform::avx) && d < 60;
} }
#define AWALYS_USE_ME_WITH_DOUBLE(func) \ #define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \ template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \ bool func##Kernel<double>::CanBeUsed(const int& d) const { \
return true; \ return true; \
} }
...@@ -250,23 +251,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax); ...@@ -250,23 +251,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
namespace mkl = paddle::operators::jit::more::mkl; namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \ #define REGISTER_MKL_KERNEL(func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \ REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>) mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kMatMul, MatMul); REGISTER_MKL_KERNEL(MatMul);
REGISTER_MKL_KERNEL(kVMul, VMul); REGISTER_MKL_KERNEL(VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd); REGISTER_MKL_KERNEL(VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy); REGISTER_MKL_KERNEL(VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast); REGISTER_MKL_KERNEL(VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax); REGISTER_MKL_KERNEL(Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd); REGISTER_MKL_KERNEL(Sgd);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL
...@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, ...@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_MKL_KERNEL(name, tuples) \ #define DECLARE_MKL_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \ bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \ const char* ImplType() const override { return "MKL"; } \
} }
// ABCMNK // ABCMNK
DECLARE_MKL_KERNEL(MatMul, MatMulTuples); DECLARE_MKL_KERNEL(MatMul);
// XYZN // XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples); DECLARE_MKL_KERNEL(VMul);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples); DECLARE_MKL_KERNEL(VAdd);
// AXYN // AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples); DECLARE_MKL_KERNEL(VScal);
// XYN // XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VExp);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh);
DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(VSquare);
DECLARE_MKL_KERNEL(VCopy, XYNTuples); DECLARE_MKL_KERNEL(VCopy);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); // others
DECLARE_MKL_KERNEL(SeqPool);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); DECLARE_MKL_KERNEL(EmbSeqPool);
DECLARE_MKL_KERNEL(Softmax);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); DECLARE_MKL_KERNEL(Sgd);
DECLARE_MKL_KERNEL(VBroadcast);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL
......
...@@ -17,51 +17,43 @@ ...@@ -17,51 +17,43 @@
namespace refer = paddle::operators::jit::refer; namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \ #define REGISTER_REFER_KERNEL(func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \ REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>, \
refer::func##Kernel<double>) refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul); REGISTER_REFER_KERNEL(VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd); REGISTER_REFER_KERNEL(VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu); REGISTER_REFER_KERNEL(VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub); REGISTER_REFER_KERNEL(VSub);
REGISTER_REFER_KERNEL(kVScal, VScal); REGISTER_REFER_KERNEL(VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias); REGISTER_REFER_KERNEL(VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy); REGISTER_REFER_KERNEL(VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity); REGISTER_REFER_KERNEL(VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare); REGISTER_REFER_KERNEL(VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid); REGISTER_REFER_KERNEL(VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh); REGISTER_REFER_KERNEL(VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt); REGISTER_REFER_KERNEL(LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1); REGISTER_REFER_KERNEL(LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1); REGISTER_REFER_KERNEL(GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1); REGISTER_REFER_KERNEL(GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2); REGISTER_REFER_KERNEL(GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding); REGISTER_REFER_KERNEL(CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(LayerNorm);
REGISTER_REFER_KERNEL(NCHW16CMulNC);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(SeqPool);
REGISTER_REFER_KERNEL(MatMul);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool); REGISTER_REFER_KERNEL(HMax);
REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(kMatMul, MatMul); REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(kHMax, HMax); REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(kHSum, HSum); REGISTER_REFER_KERNEL(VBroadcast);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL
...@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, ...@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
} }
} }
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<name##Tuple<T>> { \
public: \ public: \
name##Kernel() { this->func = name<T>; } \ name##Kernel() { this->func = name<T>; } \
} }
// const T* x, const T* y, T* z, int n // const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples); DECLARE_REFER_KERNEL(VMul);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples); DECLARE_REFER_KERNEL(VAdd);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples); DECLARE_REFER_KERNEL(VAddRelu);
DECLARE_REFER_KERNEL(VSub, XYZNTuples); DECLARE_REFER_KERNEL(VSub);
// const T* a, const T* x, T* y, int n // const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples); DECLARE_REFER_KERNEL(VScal);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples); DECLARE_REFER_KERNEL(VAddBias);
// const T* x, T* y, int n // const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples); DECLARE_REFER_KERNEL(VRelu);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples); DECLARE_REFER_KERNEL(VIdentity);
DECLARE_REFER_KERNEL(VExp, XYNTuples); DECLARE_REFER_KERNEL(VExp);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh);
DECLARE_REFER_KERNEL(VSquare, XYNTuples); DECLARE_REFER_KERNEL(VSquare);
DECLARE_REFER_KERNEL(VCopy, XYNTuples); DECLARE_REFER_KERNEL(VCopy);
// lstm_t*, const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples); DECLARE_REFER_KERNEL(LSTMC1H1);
// gru_t*, const gru_attr_t* // gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples); DECLARE_REFER_KERNEL(GRUH1);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart1);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples); DECLARE_REFER_KERNEL(GRUHtPart2);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples); DECLARE_REFER_KERNEL(HMax);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(HSum);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); // others
DECLARE_REFER_KERNEL(CRFDecoding);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); DECLARE_REFER_KERNEL(LayerNorm);
DECLARE_REFER_KERNEL(NCHW16CMulNC);
DECLARE_REFER_KERNEL(MatMul, MatMulTuples); DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(HMax, XRNTuples); DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(HSum, XRNTuples); DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); DECLARE_REFER_KERNEL(VBroadcast);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -49,7 +50,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> { ...@@ -49,7 +50,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
void operator()(KernelType kt) const { void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType()); KernelKey kkey(kt, PlaceType());
Pool().Instance().Insert(kkey, Pool::Instance().Insert(kkey,
std::move(make_unique<const KERNEL_IMPL_TYPE>())); std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1, JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
......
此差异已折叠。
...@@ -230,8 +230,8 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -230,8 +230,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
auto ker = auto ker =
jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>( jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
right); .At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon), right); static_cast<const float>(epsilon), right);
......
...@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return; return;
} }
if (relu) { if (relu) {
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
.At(N); N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} else { } else {
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>, auto compute =
platform::CPUPlace>::Cache() jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
.At(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
......
...@@ -256,8 +256,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -256,8 +256,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
static_cast<int>(input.numel() / input.dims()[0]), static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum); jit::SeqPoolType::kSum);
auto seqpool = auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>( jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
attr); .At(attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]); attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr); seqpool(src, dst, &attr);
......
...@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> { ...@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
const int kClassDim = 1; const int kClassDim = 1;
// 2D data. Batch x C // 2D data. Batch x C
auto compute_softmax = auto compute_softmax =
jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>, jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
platform::CPUPlace>::Cache()
.At(in_dims[kClassDim]); .At(in_dims[kClassDim]);
compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]); compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
} }
......
...@@ -48,7 +48,8 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -48,7 +48,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd = auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr); jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced. // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
...@@ -82,7 +83,8 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -82,7 +83,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd = auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr); jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr); sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
......
...@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute! // Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/); false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
...@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
VLOG(5) << "Recurrent memory linking finished "; VLOG(5) << "Recurrent memory linking finished ";
// Run step block with cur_scope // Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/); false /*create_local_scope*/, true /*create_vars*/,
std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
VLOG(5) << "executor.Run finished "; VLOG(5) << "executor.Run finished ";
......
...@@ -13,10 +13,18 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -31,20 +39,20 @@ void BindTracer(pybind11::module* m) { ...@@ -31,20 +39,20 @@ void BindTracer(pybind11::module* m) {
[](imperative::Tracer& self, imperative::OpBase* op, [](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs, const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs, const imperative::VarBasePtrMap& outputs,
framework::BlockDesc* block, framework::AttributeMap attrs_map,
const platform::CPUPlace expected_place, const platform::CPUPlace expected_place,
const bool stop_gradient = false) { const bool stop_gradient = false) {
return self.Trace(op, inputs, outputs, block, expected_place, return self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient); stop_gradient);
}) })
.def("trace", .def("trace",
[](imperative::Tracer& self, imperative::OpBase* op, [](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs, const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs, const imperative::VarBasePtrMap& outputs,
framework::BlockDesc* block, framework::AttributeMap attrs_map,
const platform::CUDAPlace expected_place, const platform::CUDAPlace expected_place,
const bool stop_gradient = false) { const bool stop_gradient = false) {
return self.Trace(op, inputs, outputs, block, expected_place, return self.Trace(op, inputs, outputs, attrs_map, expected_place,
stop_gradient); stop_gradient);
}) })
.def("py_trace", &imperative::Tracer::PyTrace, .def("py_trace", &imperative::Tracer::PyTrace,
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <Python.h> #include <Python.h>
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -36,6 +37,8 @@ class Layer : public imperative::Layer { ...@@ -36,6 +37,8 @@ class Layer : public imperative::Layer {
class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase { class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
public: public:
using imperative::OpBase::OpBase; // Inherit constructors using imperative::OpBase::OpBase; // Inherit constructors
PyOpBase(const std::string& name) : OpBase(name) {}
}; };
class PyVarBase : public imperative::VarBase { class PyVarBase : public imperative::VarBase {
......
...@@ -23,97 +23,7 @@ limitations under the License. */ ...@@ -23,97 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
// Cast boost::variant for PyBind. #include "paddle/fluid/pybind/pybind_boost_headers.h"
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace pybind11 {
namespace detail {
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
// Can be replaced by a generic lambda in C++14
struct PYBIND11_HIDDEN paddle_variant_caster_visitor
: public boost::static_visitor<handle> {
return_value_policy policy;
handle parent;
paddle_variant_caster_visitor(return_value_policy policy, handle parent)
: policy(policy), parent(parent) {}
template <class T>
handle operator()(T const &src) const {
return make_caster<T>::cast(src, policy, parent);
}
};
template <class Variant>
struct paddle_variant_caster;
template <template <class...> class V, class... Ts>
struct paddle_variant_caster<V<Ts...>> {
using Type = V<Ts...>;
template <typename T>
typename std::enable_if<
!std::is_same<T, boost::detail::variant::void_>::value, bool>::type
try_load(handle src, bool convert) {
auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) {
load_success_ = true;
if (std::is_same<T, std::vector<float>>::value) {
auto caster_ints = make_caster<std::vector<int64_t>>();
if (caster_ints.load(src, convert)) {
VLOG(4) << "This value are floats and int64_ts satisfy "
"simultaneously, will set it's type to "
"std::vector<int64_t>";
value = cast_op<std::vector<int64_t>>(caster_ints);
return true;
}
}
value = cast_op<T>(caster);
return true;
}
return false;
}
template <typename T>
typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
bool>::type
try_load(handle src, bool convert) {
return false;
}
bool load(handle src, bool convert) {
auto unused = {false, try_load<Ts>(src, convert)...};
(void)(unused);
return load_success_;
}
static handle cast(Type const &src, return_value_policy policy,
handle parent) {
paddle_variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
}
PYBIND11_TYPE_CASTER(Type, _("Variant"));
bool load_success_{false};
};
// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
: paddle_variant_caster<boost::variant<Args...>> {};
} // namespace detail
} // namespace pybind11
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
......
...@@ -149,8 +149,14 @@ PYBIND11_MODULE(core, m) { ...@@ -149,8 +149,14 @@ PYBIND11_MODULE(core, m) {
[]() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); }); []() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); });
py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
// .def(py::init<>()) .def(
.def(py::init<bool>(), py::arg("stop_gradient") = false) py::init<const std::string &, paddle::framework::proto::VarType::Type,
const std::vector<int64_t>, const paddle::platform::CPUPlace,
bool, bool>())
.def(
py::init<const std::string &, paddle::framework::proto::VarType::Type,
const std::vector<int64_t>,
const paddle::platform::CUDAPlace, bool, bool>())
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self) { self.RunBackward(); }) [](imperative::VarBase &self) { self.RunBackward(); })
.def("_grad_name", &imperative::VarBase::GradName) .def("_grad_name", &imperative::VarBase::GradName)
...@@ -177,51 +183,21 @@ PYBIND11_MODULE(core, m) { ...@@ -177,51 +183,21 @@ PYBIND11_MODULE(core, m) {
py::return_value_policy::take_ownership) py::return_value_policy::take_ownership)
.def("value", [](const imperative::VarBase &self) { return self.var_; }, .def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property("name", .def_property("name", &imperative::VarBase::Name,
[](const imperative::VarBase &self) { return self.name_; }, &imperative::VarBase::SetName)
[](imperative::VarBase &self, const std::string &name) { .def_property_readonly("shape", &imperative::VarBase::Shape)
self.name_ = name; .def_property_readonly("dtype", &imperative::VarBase::DType)
}) .def_property("persistable", &imperative::VarBase::IsPersistable,
.def_property("block", &imperative::VarBase::SetPersistable)
[](const imperative::VarBase &self) { return self.block_; }, .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
[](imperative::VarBase &self, framework::BlockDesc *block) { &imperative::VarBase::SetStopGradient);
self.block_ = block;
},
py::return_value_policy::reference)
.def_property(
"persistable",
[](const imperative::VarBase &self) { return self.persistable_; },
[](imperative::VarBase &self, const bool persistable) {
self.persistable_ = persistable;
})
.def_property(
"desc",
[](const imperative::VarBase &self) { return self.var_desc_; },
[](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc;
},
py::return_value_policy::reference)
.def_property(
"stop_gradient",
[](const imperative::VarBase &self) { return self.IsStopGradient(); },
[](imperative::VarBase &self, bool stop_gradient) {
self.SetStopGradient(stop_gradient);
});
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC") py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<const std::string &>())
.def("register_backward_hooks", .def("register_backward_hooks",
[](imperative::OpBase &self, const py::object &callable) { [](imperative::OpBase &self, const py::object &callable) {
self.RegisterBackwardHooks(callable); self.RegisterBackwardHooks(callable);
}) })
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
if (op_desc) {
self.op_desc_ = op_desc;
}
},
py::return_value_policy::reference)
.def_property("_trace_id", .def_property("_trace_id",
[](const imperative::OpBase &self) { [](const imperative::OpBase &self) {
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
...@@ -260,7 +236,17 @@ PYBIND11_MODULE(core, m) { ...@@ -260,7 +236,17 @@ PYBIND11_MODULE(core, m) {
"apply", "apply",
[](int func_id, const std::vector<imperative::VarBase *> &inputs) [](int func_id, const std::vector<imperative::VarBase *> &inputs)
-> std::vector<imperative::VarBase *> { -> std::vector<imperative::VarBase *> {
return imperative::PyLayer::Apply(func_id, inputs); auto ret_vars = imperative::PyLayer::Apply(func_id, inputs);
std::vector<imperative::VarBase *> outputs;
outputs.reserve(ret_vars.size());
for (size_t i = 0U; i != ret_vars.size(); ++i) {
framework::Variable *v = ret_vars[i];
// TODO(minqiyang): use unique_name generator to set a name
outputs.emplace_back(
new imperative::VarBase("", v, nullptr, true));
}
return outputs;
}, },
py::return_value_policy::take_ownership) py::return_value_policy::take_ownership)
.def_static("register_func", .def_static("register_func",
...@@ -876,9 +862,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -876,9 +862,11 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
.def("close", &Executor::Close) .def("close", &Executor::Close)
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars) { int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) {
pybind11::gil_scoped_release release; pybind11::gil_scoped_release release;
self.Run(prog, scope, block_id, create_local_scope, create_vars); self.Run(prog, scope, block_id, create_local_scope, create_vars,
fetch_vars);
}); });
m.def("init_gflags", framework::InitGflags); m.def("init_gflags", framework::InitGflags);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/platform/variant.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
// Cast boost::variant for PyBind.
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace pybind11 {
namespace detail {
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
// Can be replaced by a generic lambda in C++14
struct PYBIND11_HIDDEN paddle_variant_caster_visitor
: public boost::static_visitor<handle> {
return_value_policy policy;
handle parent;
paddle_variant_caster_visitor(return_value_policy policy, handle parent)
: policy(policy), parent(parent) {}
template <class T>
handle operator()(T const &src) const {
return make_caster<T>::cast(src, policy, parent);
}
};
template <class Variant>
struct paddle_variant_caster;
template <template <class...> class V, class... Ts>
struct paddle_variant_caster<V<Ts...>> {
using Type = V<Ts...>;
template <typename T>
typename std::enable_if<
!std::is_same<T, boost::detail::variant::void_>::value, bool>::type
try_load(handle src, bool convert) {
auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) {
load_success_ = true;
if (std::is_same<T, std::vector<float>>::value) {
auto caster_ints = make_caster<std::vector<int64_t>>();
if (caster_ints.load(src, convert)) {
VLOG(4) << "This value are floats and int64_ts satisfy "
"simultaneously, will set it's type to "
"std::vector<int64_t>";
value = cast_op<std::vector<int64_t>>(caster_ints);
return true;
}
}
value = cast_op<T>(caster);
return true;
}
return false;
}
template <typename T>
typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
bool>::type
try_load(handle src, bool convert) {
return false;
}
bool load(handle src, bool convert) {
auto unused = {false, try_load<Ts>(src, convert)...};
(void)(unused);
return load_success_;
}
static handle cast(Type const &src, return_value_policy policy,
handle parent) {
paddle_variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
}
PYBIND11_TYPE_CASTER(Type, _("Variant"));
bool load_success_{false};
};
// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
: paddle_variant_caster<boost::variant<Args...>> {};
} // namespace detail
} // namespace pybind11
...@@ -128,11 +128,11 @@ def __bootstrap__(): ...@@ -128,11 +128,11 @@ def __bootstrap__():
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph', 'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
'fast_eager_deletion_mode', 'allocator_strategy', 'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
'reader_queue_speed_test_mode', 'print_sub_graph_dir', 'allocator_strategy', 'reader_queue_speed_test_mode',
'pe_profile_fname', 'warpctc_dir', 'inner_op_parallelism', 'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
'enable_parallel_graph', 'multiple_of_cupti_buffer_size', 'inner_op_parallelism', 'enable_parallel_graph',
'enable_subgraph_optimize' 'multiple_of_cupti_buffer_size', 'enable_subgraph_optimize'
] ]
if 'Darwin' not in sysstr: if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory') read_env_flags.append('use_pinned_memory')
......
...@@ -590,7 +590,7 @@ class Executor(object): ...@@ -590,7 +590,7 @@ class Executor(object):
fetch_var_name=fetch_var_name) fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope) self._feed_data(program, feed, feed_var_name, scope)
exe.run(program.desc, scope, 0, True, True) exe.run(program.desc, scope, 0, True, True, fetch_var_name)
outs = self._fetch_data(fetch_list, fetch_var_name, scope) outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
......
此差异已折叠。
...@@ -258,7 +258,7 @@ class PyLayer(core.PyLayer): ...@@ -258,7 +258,7 @@ class PyLayer(core.PyLayer):
cls.backward_id = core.PyLayer.num_funcs() + 1 cls.backward_id = core.PyLayer.num_funcs() + 1
PyLayer.register_func(cls.backward_id, cls._do_backward) PyLayer.register_func(cls.backward_id, cls._do_backward)
iop = core.OpBase() iop = core.OpBase(cls.__class__.__name__ + str(cls.forward_id))
iop.forward_id = cls.forward_id iop.forward_id = cls.forward_id
iop.backward_id = cls.backward_id iop.backward_id = cls.backward_id
block.ops.append(iop) block.ops.append(iop)
......
...@@ -36,14 +36,21 @@ class Tracer(core.Tracer): ...@@ -36,14 +36,21 @@ class Tracer(core.Tracer):
super(Tracer, self).__init__(block) super(Tracer, self).__init__(block)
self._ops = defaultdict() self._ops = defaultdict()
self._vars = defaultdict()
self._trace_id = 0 self._trace_id = 0
def trace_var(self, name, var):
self._vars[name] = var
def all_parameters(self):
return list((item for name, item in six.iteritems(self._vars)
if isinstance(item, framework.Parameter)))
def trace_op(self, op, stop_gradient=False): def trace_op(self, op, stop_gradient=False):
# record op's trace id # record op's trace id
op.iop._trace_id = self._trace_id op.iop._trace_id = self._trace_id
# trace op and save it backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
framework._current_expected_place(), framework._current_expected_place(),
stop_gradient) stop_gradient)
......
...@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): ...@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
similarity_matrix = matmul( similarity_matrix = matmul(
anchor, positive, transpose_x=False, transpose_y=True) anchor, positive, transpose_x=False, transpose_y=True)
softmax_value = softmax(similarity_matrix) softmax_ce = softmax_with_cross_entropy(
cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0) logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = reduce_sum(labels * softmax_ce, 0)
celoss = reduce_mean(cross_entropy) celoss = reduce_mean(cross_entropy)
return l2loss + celoss return l2loss + celoss
...@@ -377,17 +377,16 @@ class Optimizer(object): ...@@ -377,17 +377,16 @@ class Optimizer(object):
and list of (param, grad) Variables pair for optimization. and list of (param, grad) Variables pair for optimization.
""" """
self._dtype = loss.dtype self._dtype = loss.dtype
program = loss.block.program
optimize_ops = [] optimize_ops = []
if framework._in_imperative_mode(): if framework._in_imperative_mode():
if parameter_list is not None: if parameter_list is not None:
parameters = parameter_list parameters = parameter_list
else: else:
parameters = program.global_block().all_parameters() parameters = framework._imperative_tracer().all_parameters()
params_grads = [] params_grads = []
for param in parameters: for param in parameters:
if param.stop_gradient or not param.trainable: if not param.trainable:
continue continue
# create gradient variable # create gradient variable
grad_var = Variable( grad_var = Variable(
...@@ -396,9 +395,11 @@ class Optimizer(object): ...@@ -396,9 +395,11 @@ class Optimizer(object):
stop_gradient=True, stop_gradient=True,
ivar=param._ivar._grad_ivar()) ivar=param._ivar._grad_ivar())
params_grads.append((param, grad_var)) params_grads.append((param, grad_var))
with program_guard(program, startup_program): with program_guard(framework.default_main_program(),
framework.default_startup_program()):
optimize_ops = self._create_optimization_pass(params_grads) optimize_ops = self._create_optimization_pass(params_grads)
else: else:
program = loss.block.program
with program_guard(program, startup_program): with program_guard(program, startup_program):
params_grads = self.backward(loss, startup_program, params_grads = self.backward(loss, startup_program,
parameter_list, no_grad_set) parameter_list, no_grad_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册