Commit 8f6597aa, authored by luotao1

Merge branch 'develop' into infershape_example

@@ -174,7 +174,7 @@ else()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
-target_link_libraries(executor garbage_collector)
+target_link_libraries(executor garbage_collector while_op_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
...
@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
+cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
+cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
...
@@ -14,6 +14,7 @@
#pragma once
+#include <memory>
#include <string>
#include <vector>
@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place,
size_t scope_idx);
+OperatorBase *GetOp() { return op_.get(); }
std::string Name() const override;
const Scope *GetScope() const { return scope_; }
...
@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <memory>
+#include <unordered_set>
+#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
}
}
#endif
+PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
}
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
-auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
-// Var not found, not reference count has not decreased to 0
+// Reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
continue;
}
+if (!exec_scope) {
+exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+}
+// Var not found
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
...
@@ -12,20 +12,173 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <functional>
#include <queue> #include <queue>
#include <string> #include <string>
#include <tuple>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to their memory sizes, and "
"only the largest FLAGS_memory_fraction_of_eager_deletion fraction "
"of variables would be deleted.");
namespace paddle {
namespace framework {
namespace details {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
// Check whether the variable is LoDTensor based on static VarDesc info
static bool IsLoDTensor(VarDesc *var) {
return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
}
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
PADDLE_ENFORCE(IsLoDTensor(var_desc));
auto dims = var_desc->GetShape();
return SizeOfType(var_desc->GetDataType()) *
std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
std::multiplies<int64_t>());
}
// Split all variables in the graph into LoDTensor and non-LoDTensor (e.g.
// SelectedRows, LoDTensorArray) variables. Partial GC is based on static
// analysis of each variable's memory size, so SelectedRows and
// LoDTensorArray must be skipped here.
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
(*lod_tensors)[op_vars_pair.first].insert(var_name);
} else {
(*other_vars)[op_vars_pair.first].insert(var_name);
}
}
}
}
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
scope_idx_(scope_idx) {}
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
};
// NOTE: delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size <= 0
if (fraction_of_memory_size <= 0.0) return {};
/**
* Step 1: Split all variables into LoDTensor and Non-LoDTensor.
* We can only calculate memory size of LoDTensors
*/
OpToVarNameSetMap lod_tensors, other_vars;
SplitIntoLoDTensorAndNonLoDTensorVars(m, vars, &lod_tensors, &other_vars);
// Perform complete gc when fraction_of_memory_size >= 1
if (fraction_of_memory_size >= 1.0) {
return delete_lod_tensor_only ? lod_tensors : m;
}
/**
* Step 2: build GCVarInfos, and calculate total memory sizes of each device
*/
// place -> variable info (name, memory size, op, scope_idx)
std::map<platform::Place, std::vector<GCVarInfo>> place_to_vars;
// place -> total memory sizes
std::map<platform::Place, int64_t> place_to_size;
for (auto &op_vars_pair : lod_tensors) {
auto *op = op_vars_pair.first;
auto &var_names = op_vars_pair.second;
auto scope_idx = op->GetScopeIdx();
auto &place = places[scope_idx];
for (auto &var_name : var_names) {
auto var_size = GetMemorySize(vars[scope_idx], var_name);
GCVarInfo var_info(var_name, var_size, op, scope_idx);
place_to_size[place] += var_info.AbsMemorySize();
place_to_vars[place].emplace_back(std::move(var_info));
}
}
/**
* Step 3: sort GCVarInfos, and only delete the largest variables.
*/
OpToVarNameSetMap partial_vars;
for (auto &place_to_var_pair : place_to_vars) {
auto &place = place_to_var_pair.first;
auto &gc_vars = place_to_var_pair.second;
std::sort(gc_vars.begin(), gc_vars.end(),
[](const GCVarInfo &var1, const GCVarInfo &var2) {
return var1.AbsMemorySize() > var2.AbsMemorySize();
});
int64_t accumulated_size = 0;
int64_t size_threshold =
static_cast<int64_t>(fraction_of_memory_size * place_to_size[place]);
for (size_t i = 0; i < gc_vars.size() && accumulated_size < size_threshold;
++i) {
partial_vars[gc_vars[i].op_].insert(gc_vars[i].name_);
accumulated_size += gc_vars[i].AbsMemorySize();
}
}
/**
* Step 4: Combine other vars (SelectedRows, LoDTensorArray)
*/
if (!delete_lod_tensor_only) {
for (auto &op_vars_pair : other_vars) {
partial_vars[op_vars_pair.first].insert(op_vars_pair.second.begin(),
op_vars_pair.second.end());
}
}
return partial_vars;
}
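// Worked example for ShrinkGCVars (not part of this commit): with
// fraction_of_memory_size = 0.5 and four LoDTensors of 40, 30, 20 and 10
// bytes on one place (total 100), size_threshold = 50. After the descending
// sort the loop admits the 40-byte var (accumulated_size 0 < 50) and the
// 30-byte var (accumulated_size 40 < 50), then stops once accumulated_size
// reaches 70 >= 50, so only the two largest variables are eagerly deleted;
// non-LoDTensor vars are still merged in unconditionally by Step 4.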
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts =
@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
// a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted.
-std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
-op_vars_map;
+OpToVarNameSetMap op_vars_map;
for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first;
@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
}
}
op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
FLAGS_memory_fraction_of_eager_deletion);
for (auto &pair : op_vars_map) {
auto *op = pair.first;
auto &var_names = pair.second;
@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op->AddOutput(dummy_leaf);
}
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
<< FLAGS_memory_fraction_of_eager_deletion;
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)"; VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph;
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
}
} // namespace details
@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <memory>
#include <queue>
#include <string>
#include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return shrink_func(computation_op);
}
-static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
-VarDesc *var_desc = nullptr;
-std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
-var_desc = var_handle->Node()->Var();
-return var_desc != nullptr;
-});
-return var_desc;
-}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
...
@@ -13,9 +13,22 @@
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/details/var_handle.h"
+#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
-namespace details {} // namespace details
+namespace details {
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+VarDesc *var_desc = nullptr;
+std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
+var_desc = var_handle->Node()->Var();
+return var_desc != nullptr;
+});
+return var_desc;
+}
+} // namespace details
} // namespace framework
} // namespace paddle
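// Note on TryGetLatestVarDesc (not part of this commit): the VarHandles in
// `vars` are stored in version order, so the reverse find_if yields the
// VarDesc of the most recently written version whose ir::Node still carries
// one, or nullptr if none does.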
@@ -16,6 +16,7 @@
#include <atomic>
#include <map>
+#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -25,6 +26,10 @@
namespace paddle {
namespace framework {
+class VarDesc;
+class VarHandle;
namespace details {
class ComputationOpHandle;
@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
-std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
+std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
+VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace paddle {
namespace framework {
namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
target_ops[compute_op->GetScopeIdx()].first.emplace_back(
compute_op->GetOp());
} else if (compute_op->Name() == "while_grad") {
target_ops[compute_op->GetScopeIdx()].second.emplace_back(
compute_op->GetOp());
}
}
for (auto &ops_pair : target_ops) {
auto &while_ops = ops_pair.second.first;
auto &while_grad_ops = ops_pair.second.second;
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
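A short note on the pass above (an inference from the helper's name and call
site, not text from this commit): grouping the while and while_grad ops per
scope lets PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp match each forward
while op with its gradient op, so variables the backward pass still needs are
exempted from eager deletion inside the loop's sub-block.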
@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -23,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id,
-const std::vector<std::string>& skip_ref_cnt_vars)
-: prog_(prog), block_id_(block_id) {
-if (GetEagerDeletionThreshold() >= 0) {
-global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
-skip_ref_cnt_vars);
+const std::vector<std::string>& keep_vars, bool force_disable_gc)
+: prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) {
+if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) {
+global_ref_cnts_ =
+GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars);
}
}
@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
-bool create_local_scope, bool create_vars) {
+bool create_local_scope, bool create_vars,
+const std::vector<std::string>& skip_ref_cnt_vars,
+bool force_disable_gc) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc);
#endif
-auto ctx = Prepare(pdesc, block_id);
+auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id,
-const std::vector<std::string>& skip_ref_cnt_vars) {
-std::unique_ptr<ExecutorPrepareContext> ctx(
-new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
+const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
+std::unique_ptr<ExecutorPrepareContext> ctx(new ExecutorPrepareContext(
+program, block_id, skip_ref_cnt_vars, force_disable_gc));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids,
-const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
+const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
+bool force_disable_gc) {
PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equal to block number %d",
@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
for (auto& bid : block_ids) {
ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
-ctx = new ExecutorPrepareContext(program, bid);
+ctx = new ExecutorPrepareContext(program, bid, std::vector<std::string>(),
+force_disable_gc);
} else {
-ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
+ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx],
+force_disable_gc);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid);
@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc;
-// skip while_op and while_grad_op temporarily
-if (max_memory_size >= 0 && !keep_kids) {
+// FIXME(zjl): recurrent_op is rather complex, so we
+// forcibly disable gc in recurrent_op
+if (!ctx->force_disable_gc_ && max_memory_size >= 0) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA
}
#endif
+// If gc is enabled and the program has more than one block
+if (gc && ctx->prog_.Size() > 1) {
+operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
+ctx->ops_);
+}
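// Note on the check above (not part of this commit): a while op keeps its
// body in a separate sub-block, so a program whose size is 1 (block 0 only)
// cannot contain one and the scan over ctx->ops_ can be skipped.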
}
for (auto& op : ctx->ops_) {
...
@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once
#include <map>
+#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
@@ -30,7 +32,8 @@ namespace framework {
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
-std::vector<std::string>());
+std::vector<std::string>(),
+bool force_disable_gc = false);
~ExecutorPrepareContext();
@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
const framework::ProgramDesc& prog_;
size_t block_id_;
+bool force_disable_gc_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> global_ref_cnts_;
@@ -66,7 +70,10 @@ class Executor {
* Scope
*/
void Run(const ProgramDesc& prog, Scope* scope, int block_id,
-bool create_local_scope = true, bool create_vars = true);
+bool create_local_scope = true, bool create_vars = true,
+const std::vector<std::string>& skip_ref_cnt_vars =
+std::vector<std::string>(),
+bool force_disable_gc = false);
// This API is very slow.
void Run(const ProgramDesc& program, Scope* scope,
@@ -79,12 +86,14 @@ class Executor {
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
-std::vector<std::string>());
+std::vector<std::string>(),
+bool force_disable_gc = false);
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
-std::vector<std::vector<std::string>>());
+std::vector<std::vector<std::string>>(),
+bool force_disable_gc = false);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
...
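A minimal caller-side sketch of the extended interface (assumed usage, not
part of this commit; program and scope are hypothetical):

  paddle::framework::Executor exe(paddle::platform::CPUPlace());
  paddle::framework::Scope scope;
  // Keep "fetch_var" out of reference counting, and disable GC outright as
  // a complex control-flow op such as recurrent_op would:
  exe.Run(program, &scope, /*block_id=*/0,
          /*create_local_scope=*/true, /*create_vars=*/true,
          /*skip_ref_cnt_vars=*/{"fetch_var"},
          /*force_disable_gc=*/true);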
@@ -159,10 +159,9 @@ class Autograd {
for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) {
if (!pre_op) continue;
-VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id "
-<< candidate->trace_id_ << " <---- " << it.first << " <---- "
-<< pre_op->op_desc_->Type() << " trace id "
-<< pre_op->trace_id_;
+VLOG(5) << "op dep " << candidate->Type() << " trace id "
+<< candidate->trace_id_ << " <---- " << it.first << " <---- "
+<< pre_op->Type() << " trace id " << pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op);
queue.push_back(pre_op);
@@ -180,10 +179,12 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
PADDLE_ENFORCE(var_->IsInitialized(),
"Variable must be initialized when getting numpy tensor");
-std::unique_ptr<VarBase> new_var(new VarBase());
+// TODO(minqiyang): change this after move unique_name generator to CXX
+const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
+std::unique_ptr<VarBase> new_var(new VarBase(
+"Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
framework::LoDTensor* tensor =
new_var->var_->GetMutable<framework::LoDTensor>();
-tensor->Resize(var_->Get<framework::LoDTensor>().dims());
tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
if (blocking) {
@@ -199,52 +200,62 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
}
if (platform::is_gpu_place(dst_place)) {
-VLOG(3) << "copy tensor " << var_desc_->Name() << " from gpu";
+VLOG(3) << "copy tensor " << Name() << " from gpu";
}
return new_var;
}
framework::LoDTensor& VarBase::GradValue() {
-VLOG(3) << "get var grad " << var_desc_->Name();
+VLOG(3) << "get var grad " << Name();
+PADDLE_ENFORCE_NOT_NULL(grads_,
+"Could not get grad value from no grad variable");
return *(grads_->var_->GetMutable<framework::LoDTensor>());
}
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (grad_op_descs_.empty() && backward_id_ <= 0) {
-VLOG(3) << "op with no grad: " << op_desc_->Type();
+VLOG(3) << "op with no grad: " << Type();
return {};
}
-VLOG(3) << "apply op grad: " << op_desc_->Type();
-std::vector<framework::VariableValueMap> grad_outputs;
+VLOG(3) << "apply op grad: " << Type();
+std::vector<framework::VariableValueMap> tmp_grad_outputs;
if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad";
-grad_outputs.resize(1);
-grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
+tmp_grad_outputs.resize(1);
+tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
PyLayer::ApplyGrad(
backward_id_,
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else {
-grad_outputs.resize(grad_op_descs_.size());
-for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
+const size_t grad_op_count = grad_op_descs_.size();
+tmp_grad_outputs.resize(grad_op_count);
+for (size_t k = 0; k < grad_op_count; ++k) {
framework::OpDesc* grad_op_desc = grad_op_descs_[k];
-VLOG(3) << "op grad " << grad_op_desc->Type();
-for (auto it : grad_output_vars_[k]) {
-auto& outputs = grad_outputs[k][it.first];
+auto& grad_output_variable_map = grad_output_vars_[k];
+VLOG(3) << "apply grad op " << grad_op_desc->Type();
+// Allocate tmp grad output variable
+for (auto it : grad_output_variable_map) {
+auto& outputs = tmp_grad_outputs[k][it.first];
+outputs.reserve(it.second.size());
for (size_t i = 0; i < it.second.size(); ++i) {
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
-outputs.push_back(tmp_var);
+outputs.emplace_back(tmp_var);
}
}
-framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
+// Run grad op
+framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
-grad_op_desc->InferVarType(block_);
+// grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
@@ -260,9 +271,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
}
+// Add tmp grad outputs to original grad vars
for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
for (auto it : grad_output_vars_[k]) {
-auto& outputs = grad_outputs[k][it.first];
+auto& outputs = tmp_grad_outputs[k][it.first];
auto& origin_outputs = it.second;
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
@@ -316,19 +328,14 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); }
-std::vector<VarBase*> PyLayer::Apply(int func_id,
+std::vector<Variable*> PyLayer::Apply(int func_id,
const std::vector<VarBase*>& inputs) {
std::vector<framework::Variable*> invars;
for (const VarBase* in : inputs) {
invars.push_back(in->var_);
}
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-std::vector<Variable*> outvars = CallPythonFunc(py_funcs_[func_id], invars);
-std::vector<VarBase*> ret;
-for (Variable* v : outvars) {
-ret.push_back(new VarBase(v, new VarBase(true)));
-}
-return ret;
+return CallPythonFunc(py_funcs_[func_id], invars);
}
std::vector<Variable*> PyLayer::ApplyGrad(
...
@@ -112,31 +112,53 @@ class OpBase;
*/
class VarBase {
public:
-VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {}
-explicit VarBase(bool stop_gradient)
-: VarBase(new framework::Variable(),
-stop_gradient ? nullptr : new VarBase(true), stop_gradient) {}
-VarBase(framework::Variable* var, VarBase* grad)
-: VarBase(var, grad, false) {}
+// Internal interface, create VarBase from an existing variable
+VarBase(const std::string& name, framework::Variable* var, VarBase* grad,
+bool stop_gradient)
+: VarBase(name, var->Get<framework::LoDTensor>().type(),
+var->Get<framework::LoDTensor>().dims(),
+var->Get<framework::LoDTensor>().place(), var, grad,
+stop_gradient, false) {}
+// Python interface
+VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
+const std::vector<int64_t>& shape, const platform::Place& place,
+bool stop_gradient, bool persistable)
+: VarBase(name, dtype, framework::make_ddim(shape), place, stop_gradient,
+persistable) {}
+// Internal interface, create VarBase with a ddim
+VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
+const framework::DDim& shape, const platform::Place& place,
+bool stop_gradient, bool persistable)
+: VarBase(name, dtype, shape, place, nullptr, nullptr, stop_gradient,
+persistable) {}
private:
-VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
-: name_(),
-var_desc_(nullptr),
+VarBase(const std::string& name, framework::proto::VarType::Type dtype,
+const framework::DDim& shape, const platform::Place& place,
+framework::Variable* var, VarBase* grad, bool stop_gradient,
+bool persistable)
+: name_(name),
+dtype_(dtype),
+place_(place),
var_(var),
grads_(grad),
-block_(nullptr),
-persistable_(false),
stop_gradient_(stop_gradient),
+persistable_(persistable),
pre_op_(nullptr),
pre_op_out_name_(),
-pre_op_out_idx_(-1) {}
+pre_op_out_idx_(-1) {
+if (!var_) {
+var_ = new framework::Variable();
+auto tensor = var_->GetMutable<framework::LoDTensor>();
+tensor->Resize(shape);
+tensor->mutable_data(place_, dtype_);
+}
+}
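// Illustrative use of the Python-facing overload above (hypothetical
// values, not part of this commit): with var == nullptr the private
// constructor allocates an FP32 LoDTensor of shape [2, 3] on the CPU
// (contents uninitialized):
//   VarBase v("x", framework::proto::VarType::FP32, {2, 3},
//             platform::CPUPlace(), /*stop_gradient=*/false,
//             /*persistable=*/false);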
public:
virtual ~VarBase() {
-// TODO(minqiyang): remove var desc from block desc
if (var_) {
delete var_;
var_ = nullptr;
@@ -151,14 +173,30 @@ class VarBase {
pre_op_out_idx_ = -1;
}
-inline OpBase* PreOp() const { return pre_op_; }
-inline int PreOpOutIdx() const { return pre_op_out_idx_; }
+inline void SetName(const std::string& name) { name_ = name; }
+inline std::string Name() const { return name_; }
+inline std::vector<int64_t> Shape() const {
+if (var_->IsInitialized()) {
+return framework::vectorize(var_->Get<framework::LoDTensor>().dims());
+} else {
+return {};
+}
+}
+inline framework::proto::VarType::Type DType() const { return dtype_; }
inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient;
}
inline bool IsStopGradient() const { return stop_gradient_; }
+inline void SetPersistable(bool persistable) { persistable_ = persistable; }
+inline bool IsPersistable() const { return persistable_; }
+inline OpBase* PreOp() const { return pre_op_; }
+inline int PreOpOutIdx() const { return pre_op_out_idx_; }
void RunBackward();
inline void ResetPreOp(OpBase* op) {
@@ -180,7 +218,7 @@ class VarBase {
}
void ClearGradient() {
-VLOG(1) << "clear gradient of " << var_desc_->Name();
+VLOG(1) << "clear gradient of " << Name();
if (grads_ && grads_->var_ && grads_->var_->IsInitialized()) {
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
operators::math::set_constant(
@@ -196,23 +234,20 @@ class VarBase {
const bool blocking) const;
inline std::string GradName() const {
-PADDLE_ENFORCE(
-var_desc_,
-"Couldn't get gradient variable's name, please call backward() first");
-return string::Sprintf("%s@IGrad", var_desc_->Name());
+return string::Sprintf("%s@IGrad", Name());
}
std::string name_;
-framework::VarDesc* var_desc_;
+framework::proto::VarType::Type dtype_;
+platform::Place place_;
framework::Variable* var_;
VarBase* grads_;
-framework::BlockDesc* block_;
-bool persistable_;
private:
bool stop_gradient_;
+bool persistable_;
OpBase* pre_op_;
std::string pre_op_out_name_;
int pre_op_out_idx_;
@@ -223,11 +258,11 @@ class VarBase {
*/
class PYBIND11_HIDDEN OpBase {
public:
-OpBase()
-: op_desc_(nullptr),
+OpBase(const std::string& type)
+: type_(type),
+trace_id_(-1),
forward_id_(-1),
backward_id_(-1),
-trace_id_(-1),
place_(platform::CPUPlace()),
backward_hooks_() {}
@@ -249,13 +284,34 @@ class PYBIND11_HIDDEN OpBase {
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
+inline std::string Type() const { return type_; }
+inline std::string GradOpType(size_t index) const {
+PADDLE_ENFORCE_NOT_NULL(grad_op_descs_[index]);
+return grad_op_descs_[index]->Type();
+}
void RegisterBackwardHooks(const py::object& callable);
void InvokeBackwardHooks();
-// One of `op_desc_` or `forward_id_` is set, not both.
-// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
-framework::OpDesc* op_desc_;
+void TrackPreOp(const VarBase* inp_var, const std::string& inp_name) {
+if (inp_var->PreOp() && !inp_var->IsStopGradient()) {
+VLOG(3) << "add pre op " << inp_var->PreOp()->Type() << " in slot "
+<< inp_name;
+pre_ops_[inp_name].push_back(inp_var->PreOp());
+pre_ops_out_idx_[inp_name].push_back(inp_var->PreOpOutIdx());
+} else {
+VLOG(3) << "no pre op in slot " << inp_name
+<< " input var stop_gradient: " << inp_var->IsStopGradient();
+pre_ops_[inp_name].push_back(nullptr);
+// pre_ops_out_idx_[inp_name].push_back(-1);
+}
+}
+std::string type_;
+// One of `trace_id_` or `forward_id_` is set, not both.
+// For pure python PyLayer, use `forward_id_`, otherwise, use `trace_id_`.
+int trace_id_;
int forward_id_;
// When has backward, one of `grad_op_descs_` or `backward_id_` is set,
@@ -263,7 +319,6 @@ class PYBIND11_HIDDEN OpBase {
// Note: each fwd op corresponds to a vector of bwd ops.
std::vector<framework::OpDesc*> grad_op_descs_;
int backward_id_;
-int trace_id_;
platform::Place place_;
@@ -277,8 +332,6 @@ class PYBIND11_HIDDEN OpBase {
// Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_;
-framework::BlockDesc* block_;
std::vector<py::object> backward_hooks_;
};
@@ -303,8 +356,8 @@ class PyLayer {
static int NumFuncs();
-static std::vector<VarBase*> Apply(int func_id,
-const std::vector<VarBase*>& inputs);
+static std::vector<framework::Variable*> Apply(
+int func_id, const std::vector<VarBase*>& inputs);
static std::vector<framework::Variable*> ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs);
...
@@ -56,15 +56,19 @@ void CreateGradOp(const framework::OpDesc& op_desc,
}
}
-void InitVar(framework::Variable* var, framework::Variable* grad_var,
-platform::DeviceContext* dev_ctx) {
+void InitGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
+PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
PADDLE_ENFORCE_NOT_NULL(dev_ctx,
"Could not get valid device from forward op");
-auto& var_t = var->Get<framework::LoDTensor>();
-grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-var_t.dims(), dev_ctx->GetPlace());
-operators::math::set_constant(
-*dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0);
+if (var->grads_ == nullptr) {
+auto& var_t = var->var_->Get<framework::LoDTensor>();
+var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32,
+framework::vectorize(var_t.dims()),
+dev_ctx->GetPlace(), true, false);
+auto grad_t = var->grads_->var_->GetMutable<framework::LoDTensor>();
+operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+}
}
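// Behavioral note on InitGrad (not part of this commit): the gradient
// VarBase is created lazily on first use, named via GradName() as
// Name() + "@IGrad", zero-filled, placed on the op's device, and created
// with stop_gradient = true and persistable = false; its dtype is
// hard-coded to FP32 regardless of the forward tensor's dtype.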
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
@@ -85,6 +89,62 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
return result;
}
framework::VariableNameMap CreateInputVarNameMap(
const OpBase* op, const VarBasePtrMap& varbase_map) {
framework::VariableNameMap result;
auto& info_map = framework::OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(op->Type());
if (op_info == nullptr || op_info->proto_ == nullptr) {
return result;
}
for (auto& in : op_info->Proto().inputs()) {
auto it = varbase_map.find(in.name());
if (it == varbase_map.end()) {
PADDLE_ENFORCE(in.dispensable());
result[in.name()] = {};
} else {
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[in.name()] = args;
}
}
return result;
}
framework::VariableNameMap CreateOutputVarNameMap(
const OpBase* op, const VarBasePtrMap& varbase_map) {
framework::VariableNameMap result;
auto& info_map = framework::OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(op->Type());
if (op_info == nullptr || op_info->proto_ == nullptr) {
return result;
}
for (auto& out : op_info->Proto().outputs()) {
auto it = varbase_map.find(out.name());
if (it == varbase_map.end()) {
PADDLE_ENFORCE(out.dispensable());
result[out.name()] = {};
} else {
auto var_vector = it->second;
std::vector<std::string> args;
args.reserve(var_vector.size());
for (VarBase* var_base : var_vector) {
args.emplace_back(var_base->Name());
}
result[out.name()] = args;
}
}
return result;
}
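// Illustrative note (hypothetical op, not part of this commit): for an
// elementwise_add OpBase with inputs {"X": [x], "Y": [y]}, the map built by
// CreateInputVarNameMap is {"X": {x->Name()}, "Y": {y->Name()}}; a
// dispensable slot absent from varbase_map becomes an empty name list
// instead of tripping the PADDLE_ENFORCE.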
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
if (!FLAGS_tracer_profile_fname.empty()) {
std::call_once(gTracerProfileOnce, [] {
@@ -101,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
-framework::BlockDesc* block,
+framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient) {
#ifdef WITH_GPERFTOOLS
@@ -110,40 +170,27 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
}
#endif
-std::map<std::string, VarBase*> vars;
-framework::OpDesc* op_desc = op->op_desc_;
-VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
-<< op->trace_id_;
-op_desc->InferShape(*block);
-op_desc->InferVarType(block);
-std::unique_ptr<framework::OperatorBase> op_base =
-framework::OpRegistry::CreateOp(*op_desc);
framework::VariableValueMap invars_map;
framework::VariableValueMap outvars_map;
+// Construct input_vars_map and output_vars_map
+std::map<std::string, VarBase*> current_vars_map;
op->input_vars_ = inputs;
for (auto it : op->input_vars_) {
auto& invars = invars_map[it.first];
invars.reserve(it.second.size());
for (VarBase* inp : it.second) {
-PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
-op->op_desc_->Type(), inp->var_desc_->Name());
+PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
+inp->Name());
invars.emplace_back(inp->var_);
-vars[inp->var_desc_->Name()] = inp;
-if (inp->PreOp() && !inp->IsStopGradient()) {
-op->pre_ops_[it.first].push_back(inp->PreOp());
-op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
-VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
-} else {
-op->pre_ops_[it.first].push_back(nullptr);
+op->TrackPreOp(inp, it.first);
+if (!stop_gradient) {
+current_vars_map[inp->Name()] = inp;
}
-VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
-<< inp->var_->IsInitialized() << " stop_gradient "
-<< inp->IsStopGradient();
+VLOG(3) << "input var name: " << inp->Name()
+<< " inited: " << inp->var_->IsInitialized()
+<< " stop_grad: " << inp->IsStopGradient();
}
}
@@ -152,25 +199,38 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second;
outvars.reserve(outputs.size());
-for (size_t i = 0; i < outputs.size(); ++i) {
+for (size_t i = 0U; i < outputs.size(); ++i) {
VarBase* out = outputs[i];
outvars.emplace_back(out->var_);
-vars[out->var_desc_->Name()] = out;
-framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
-if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
-out->var_->GetMutable<framework::LoDTensor>();
-} else {
-LOG(ERROR) << "tracer doesn't support yet";
-}
out->TrackPreOp(op, it.first, i, stop_gradient);
+if (!stop_gradient) {
+current_vars_map[out->Name()] = out;
+}
-VLOG(3) << "output vname " << out->var_desc_->Name() << " "
-<< out->var_->IsInitialized();
+VLOG(3) << "output var name: " << out->Name()
+<< " inited: " << out->var_->IsInitialized()
+<< " stop_grad: " << out->IsStopGradient();
}
}
VLOG(3) << "tracer running " << op_desc->Type(); // Check attrs and create op
framework::VariableNameMap invars_name_map =
CreateInputVarNameMap(op, inputs);
framework::VariableNameMap outvars_name_map =
CreateOutputVarNameMap(op, outputs);
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_map);
}
std::unique_ptr<framework::OperatorBase> op_base =
framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
outvars_name_map, attrs_map);
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
VLOG(3) << "tracer running " << op->Type();
framework::RuntimeContext ctx(invars_map, outvars_map); framework::RuntimeContext ctx(invars_map, outvars_map);
// TODO(panyx0718): Cache p. // TODO(panyx0718): Cache p.
@@ -186,36 +246,44 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs));
+// construct backward op
std::set<std::string> vars_saved_for_backward;
if (!stop_gradient) {
+VLOG(5) << "start construct backward op";
+// construct grad op descs
+std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
+op->Type(), invars_name_map, outvars_name_map, attrs_map));
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
new std::unordered_map<std::string, std::string>());
-CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
+// NOTE(minqiyang): We don't support control flow op in imperative now
+// Add grad_block_ when we want to support it
+CreateGradOp(*fwd_op_desc, {}, {}, &op->grad_op_descs_, grad_to_var.get());
-op->grad_input_vars_.resize(op->grad_op_descs_.size());
-op->grad_output_vars_.resize(op->grad_op_descs_.size());
+VLOG(5) << "create grad op desc: " << op->grad_op_descs_[0]->Type();
-for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
+const size_t grad_op_count = op->grad_op_descs_.size();
+op->grad_input_vars_.resize(grad_op_count);
+op->grad_output_vars_.resize(grad_op_count);
+for (size_t i = 0; i < grad_op_count; ++i) {
framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
for (auto it : grad_op_desc->Inputs()) {
auto& grad_in_vars = op->grad_input_vars_[i][it.first];
+grad_in_vars.reserve(it.second.size());
for (const std::string& grad_invar : it.second) {
-block->FindRecursiveOrCreateVar(grad_invar);
auto var_it = grad_to_var->find(grad_invar);
if (var_it == grad_to_var->end()) {
-auto fwd_var_it = vars.find(grad_invar);
-PADDLE_ENFORCE(fwd_var_it != vars.end());
+auto fwd_var_it = current_vars_map.find(grad_invar);
+PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
// Forward inputs or outputs.
-grad_in_vars.push_back(fwd_var_it->second->var_);
+grad_in_vars.emplace_back(fwd_var_it->second->var_);
} else {
-VarBase* var = vars[var_it->second];
-if (!var->grads_->var_->IsInitialized()) {
-InitVar(var->var_, var->grads_->var_,
-prepared_op.GetDeviceContext());
-}
+VarBase* var = current_vars_map[var_it->second];
+InitGrad(var, prepared_op.GetDeviceContext());
// Douts.
-grad_in_vars.push_back(var->grads_->var_);
+grad_in_vars.emplace_back(var->grads_->var_);
}
vars_saved_for_backward.insert(it.first);
@@ -225,48 +293,48 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[i][it.first];
for (const std::string& grad_outvar : it.second) {
-block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not find the grad op output var; should this "
"operator %s's stop gradient be True",
-op_desc->Type());
+op->Type());
-VarBase* var = vars[var_it->second];
-if (!var->grads_->var_->IsInitialized()) {
-InitVar(var->var_, var->grads_->var_,
-prepared_op.GetDeviceContext());
-}
+VarBase* var = current_vars_map[var_it->second];
+InitGrad(var, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_->var_);
}
}
}
}
-op->block_ = block;
return vars_saved_for_backward;
}
std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
                                      const std::vector<VarBase*>& inputs,
                                      bool stop_gradient) {
  VLOG(3) << "py_trace " << op->Type();

  op->input_vars_[PyLayer::kFwdInp] = inputs;

  std::vector<framework::Variable*> ret_vars =
      PyLayer::Apply(op->forward_id_, inputs);

  for (VarBase* inp : inputs) {
    op->TrackPreOp(inp, PyLayer::kFwdInp);
  }

  std::vector<VarBase*>& outputs = op->output_vars_[PyLayer::kFwdOut];
  outputs.reserve(ret_vars.size());
  for (size_t i = 0U; i != ret_vars.size(); ++i) {
    framework::Variable* v = ret_vars[i];
    VarBase* out = new VarBase(string::Sprintf("%s_out_%d", op->Type(), i), v,
                               nullptr, stop_gradient);
    outputs.emplace_back(out);
    out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
  }

  if (!stop_gradient) {
    VLOG(5) << "start construct backward op";

    op->grad_input_vars_.resize(1);
    op->grad_output_vars_.resize(1);
    auto& grad_input_vars =
...@@ -281,23 +349,16 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
      grad_input_vars.push_back(out->var_);
    }

    // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
    platform::CPUPlace place;
    for (VarBase* out : outputs) {
      InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
      grad_input_vars.push_back(out->grads_->var_);
    }

    for (VarBase* inp : inputs) {
      InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
      grad_output_vars.push_back(inp->grads_->var_);
    }
  }
  return outputs;
......
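The grad-op wiring in Trace() above applies a lookup-with-fallback rule: a grad op input either resolves through grad_to_var to a forward variable whose gradient buffer is used, or it is itself a forward input/output fed directly. A minimal standalone sketch of that rule (plain C++, not the Paddle API):

#include <string>
#include <unordered_map>

// Returns nullptr when `name` is a forward input/output that should be fed
// directly; otherwise returns the forward var name whose grad buffer is used.
const std::string* ResolveGradInput(
    const std::string& name,
    const std::unordered_map<std::string, std::string>& grad_to_var) {
  auto it = grad_to_var.find(name);
  return it == grad_to_var.end() ? nullptr : &it->second;
}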
...@@ -17,6 +17,8 @@
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
...@@ -34,7 +36,8 @@ void CreateGradOp(const framework::OpDesc& op_desc,
                  framework::OpDesc** grad_op_desc,
                  std::unordered_map<std::string, std::string>* grad_to_var);

void InitVar(const VarBase* var, framework::Variable* grad_var,
             platform::DeviceContext* dev_ctx);

platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
...@@ -46,7 +49,7 @@ class Tracer {
  std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
                              const VarBasePtrMap& outputs,
                              framework::AttributeMap attrs_map,
                              const platform::Place expected_place,
                              const bool stop_gradient = false);
...@@ -126,15 +126,20 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
}

template void ZeroCopyTensor::copy_from_cpu<float>(const float *data);
template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data);
template void ZeroCopyTensor::copy_from_cpu<int32_t>(const int32_t *data);
template void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);

template float *ZeroCopyTensor::data<float>(PaddlePlace *place,
                                            int *size) const;
template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place,
                                                int *size) const;
template int32_t *ZeroCopyTensor::data<int32_t>(PaddlePlace *place,
                                                int *size) const;
template float *ZeroCopyTensor::mutable_data<float>(PaddlePlace place);
template int64_t *ZeroCopyTensor::mutable_data<int64_t>(PaddlePlace place);
template int32_t *ZeroCopyTensor::mutable_data<int32_t>(PaddlePlace place);

void *ZeroCopyTensor::FindTensor() const {
  PADDLE_ENFORCE(!name_.empty(),
......
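The explicit instantiations above are what let ZeroCopyTensor keep its template definitions in the .cc file while other translation units link against them. A minimal self-contained illustration of the pattern (the names here are illustrative, not Paddle's):

#include <cstdint>

template <typename T>
T AddOne(T v) {
  return v + 1;
}

// Instantiate once per supported element type so the symbols exist for
// linking even though the definition is hidden in this translation unit.
template float AddOne<float>(float);
template int64_t AddOne<int64_t>(int64_t);
template int32_t AddOne<int32_t>(int32_t);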
...@@ -139,9 +139,8 @@ static void TensorAssignData(PaddleTensor *tensor,
}

template <typename T>
static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
                                     const std::vector<std::vector<T>> &data) {
  auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
  int c = 0;
  for (const auto &f : data) {
...@@ -149,7 +148,15 @@ static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
      ptr[c++] = v;
    }
  }
}

template <typename T>
static void ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
                                     const PaddleBuf &data) {
  auto *ptr = tensor->mutable_data<T>(PaddlePlace::kCPU);
  for (size_t i = 0; i < data.length() / sizeof(T); i++) {
    ptr[i] = *(reinterpret_cast<T *>(data.data()) + i);
  }
}

static bool CompareTensor(const PaddleTensor &a, const PaddleTensor &b) {
......
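Both ZeroCopyTensorAssignData overloads follow the same flatten-and-copy idea: nested rows (or a raw buffer) are written into the contiguous storage returned by mutable_data(). A standalone sketch of that idea, independent of the Paddle API:

#include <cstring>
#include <vector>

// Copies each row of `rows` back-to-back into the contiguous buffer `dst`,
// which the caller must have sized to the total element count.
template <typename T>
void FlattenInto(T* dst, const std::vector<std::vector<T>>& rows) {
  size_t offset = 0;
  for (const auto& row : rows) {
    std::memcpy(dst + offset, row.data(), row.size() * sizeof(T));
    offset += row.size();
  }
}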
...@@ -107,6 +107,9 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrOptim();
  if (FLAGS_zero_copy) {
    cfg->SwitchUseFeedFetchOps(false);
  }
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
...@@ -131,7 +134,7 @@ TEST(Analyzer_Pyramid_DNN, profile) {
  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
                 input_slots_all, &outputs, FLAGS_num_threads);

  if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
    size_t size = GetSize(outputs[0]);
    PADDLE_ENFORCE_GT(size, 0);
...@@ -166,6 +169,19 @@ TEST(Analyzer_Pyramid_DNN, compare) {
      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}

// Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_Pyramid_DNN, compare_zero_copy) {
  AnalysisConfig cfg;
  SetConfig(&cfg);

  std::vector<std::vector<PaddleTensor>> input_slots_all;
  SetInput(&input_slots_all);
  std::vector<std::string> outputs_name;
  outputs_name.emplace_back("cos_sim_2.tmp_0");
  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
                             input_slots_all, outputs_name);
}

// Compare Deterministic result
TEST(Analyzer_Pyramid_DNN, compare_determine) {
  AnalysisConfig cfg;
......
...@@ -207,6 +207,9 @@ void SetConfig(AnalysisConfig *cfg) {
  cfg->DisableGpu();
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrOptim();
  if (FLAGS_zero_copy) {
    cfg->SwitchUseFeedFetchOps(false);
  }
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
...@@ -285,133 +288,17 @@ TEST(Analyzer_rnn1, multi_thread) {
                  input_slots_all, &outputs, 2 /* multi_thread */);
}

// Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_rnn1, compare_zero_copy) {
  AnalysisConfig cfg;
  SetConfig(&cfg);

  std::vector<std::vector<PaddleTensor>> input_slots_all;
  SetInput(&input_slots_all);
  std::vector<std::string> outputs_name;
  outputs_name.emplace_back("final_output.tmp_1");
  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
                             input_slots_all, outputs_name);
}

}  // namespace inference
......
...@@ -144,6 +144,9 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
  cfg->SwitchSpecifyInputNames();
  cfg->SwitchIrDebug();
  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
  if (FLAGS_zero_copy) {
    cfg->SwitchUseFeedFetchOps(false);
  }
  if (use_mkldnn) {
    cfg->EnableMKLDNN();
  }
...@@ -184,10 +187,10 @@ TEST(Analyzer_seq_pool1, compare_determine) {
                       input_slots_all);
}

// Check the fuse status
TEST(Analyzer_seq_pool1, fuse_statis) {
  AnalysisConfig cfg;
  SetConfig(&cfg);
  int num_ops;
  auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
  auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
...@@ -203,137 +206,17 @@ void analysis_fuse_statis(bool use_zerocopy) {
  EXPECT_EQ(num_ops, 171);
}

// Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_seq_pool1, compare_zero_copy) {
  AnalysisConfig cfg;
  SetConfig(&cfg);

  std::vector<std::vector<PaddleTensor>> input_slots_all;
  SetInput(&input_slots_all);
  std::vector<std::string> outputs_name;
  outputs_name.emplace_back(out_var_name);
  CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
                             input_slots_all, outputs_name);
}

}  // namespace analysis
......
...@@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true,
DEFINE_bool(record_benchmark, false,
            "Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DEFINE_bool(zero_copy, false, "Use ZeroCopy to speed up Feed/Fetch.");

DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
...@@ -67,6 +68,7 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
  LOG(INFO) << analysis_config->ToNativeConfig();
}

// Compare results between two PaddleTensor vectors
void CompareResult(const std::vector<PaddleTensor> &outputs,
                   const std::vector<PaddleTensor> &ref_outputs) {
  EXPECT_GT(outputs.size(), 0UL);
...@@ -108,6 +110,50 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
  }
}
// Compare result between a PaddleTensor and a ZeroCopyTensor
void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<ZeroCopyTensor> &ref_outputs) {
EXPECT_GT(outputs.size(), 0UL);
EXPECT_EQ(outputs.size(), ref_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &ref_out = ref_outputs[i];
size_t size = VecReduceToInt(out.shape);
EXPECT_GT(size, 0UL);
int ref_size = 0; // this is the number of elements not memory size
PaddlePlace place;
switch (out.dtype) {
case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = ref_out.data<float>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
CHECK_LE(std::abs(pdata_ref[j] - pdata[j]), FLAGS_accuracy);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
}
}
}
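The FLOAT32 branch above bounds the absolute elementwise error by FLAGS_accuracy. A plain-C++ sketch of that comparison rule (no gtest, names are illustrative):

#include <cmath>
#include <cstddef>

// True when every pair of elements differs by at most `atol` in absolute value.
bool ArraysNear(const float* a, const float* b, size_t n, float atol) {
  for (size_t i = 0; i < n; ++i) {
    if (std::fabs(a[i] - b[i]) > atol) return false;
  }
  return true;
}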
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
    const PaddlePredictor::Config *config, bool use_analysis = true) {
  const auto *analysis_config =
...@@ -205,61 +251,106 @@ void GetInputPerBatch(const std::vector<std::vector<int64_t>> &in,
  }
}
void ConvertPaddleTensorToZeroCopyTensor(
    PaddlePredictor *predictor, const std::vector<PaddleTensor> &inputs) {
  for (size_t i = 0; i < inputs.size(); i++) {
    auto input = inputs[i];
    auto tensor = predictor->GetInputTensor(input.name);
    tensor->Reshape(input.shape);
    tensor->SetLoD({input.lod});
    if (input.dtype == PaddleDType::INT64) {
      ZeroCopyTensorAssignData<int64_t>(tensor.get(), input.data);
    } else if (input.dtype == PaddleDType::FLOAT32) {
      ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
    } else if (input.dtype == PaddleDType::INT32) {
      ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
    } else {
      LOG(ERROR) << "unsupported feed type " << input.dtype;
    }
  }
}

void PredictionWarmUp(PaddlePredictor *predictor,
                      const std::vector<std::vector<PaddleTensor>> &inputs,
                      std::vector<PaddleTensor> *outputs, int num_threads,
                      int tid) {
  int batch_size = FLAGS_batch_size;
  LOG(INFO) << "Running thread " << tid << ", warm up run...";
  if (FLAGS_zero_copy) {
    ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
  }
  Timer warmup_timer;
  warmup_timer.tic();
  if (!FLAGS_zero_copy) {
    predictor->Run(inputs[0], outputs, batch_size);
  } else {
    predictor->ZeroCopyRun();
  }
  PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1);
  if (FLAGS_profile) {
    paddle::platform::ResetProfiler();
  }
}

void PredictionRun(PaddlePredictor *predictor,
                   const std::vector<std::vector<PaddleTensor>> &inputs,
                   std::vector<PaddleTensor> *outputs, int num_threads,
                   int tid) {
  int batch_size = FLAGS_batch_size;
  int num_times = FLAGS_repeat;
  LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
  Timer run_timer;
  double elapsed_time = 0;
#ifdef WITH_GPERFTOOLS
  ProfilerStart("paddle_inference.prof");
#endif
  if (!FLAGS_zero_copy) {
    run_timer.tic();
    for (size_t i = 0; i < inputs.size(); i++) {
      for (int j = 0; j < num_times; j++) {
        predictor->Run(inputs[i], outputs, batch_size);
      }
    }
    elapsed_time = run_timer.toc();
  } else {
    for (size_t i = 0; i < inputs.size(); i++) {
      ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
      run_timer.tic();
      for (int j = 0; j < num_times; j++) {
        predictor->ZeroCopyRun();
      }
      elapsed_time += run_timer.toc();
    }
  }
#ifdef WITH_GPERFTOOLS
  ProfilerStop();
#endif
  PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times,
            inputs.size());
  if (FLAGS_record_benchmark) {
    Benchmark benchmark;
    benchmark.SetName(FLAGS_model_name);
    benchmark.SetBatchSize(batch_size);
    benchmark.SetLatency(elapsed_time / num_times);
    benchmark.PersistToFile("benchmark_record.txt");
  }
}
void TestOneThreadPrediction(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
auto predictor = CreateTestPredictor(config, use_analysis);
PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
PredictionRun(predictor.get(), inputs, outputs, 1, 0);
}
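PredictionWarmUp and PredictionRun split benchmarking into an untimed warm-up and a timed loop. A self-contained sketch of that pattern, with std::chrono standing in for Paddle's Timer and run_once for the predictor call:

#include <chrono>

// Runs once untimed to warm caches/JIT state, then averages over `repeats`.
template <typename F>
double AverageLatencyMs(F run_once, int repeats) {
  run_once();  // warm-up, excluded from the measurement
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < repeats; ++i) run_once();
  std::chrono::duration<double, std::milli> elapsed =
      std::chrono::steady_clock::now() - start;
  return elapsed.count() / repeats;
}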
void TestMultiThreadPrediction(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    std::vector<PaddleTensor> *outputs, int num_threads,
    bool use_analysis = true) {
  std::vector<std::thread> threads;
  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
  predictors.emplace_back(CreateTestPredictor(config, use_analysis));
...@@ -267,7 +358,6 @@ void TestMultiThreadPrediction(
    predictors.emplace_back(predictors.front()->Clone());
  }

  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid]() {
      // Each thread should have local inputs and outputs.
...@@ -280,34 +370,8 @@ void TestMultiThreadPrediction(
            ->SetMkldnnThreadID(static_cast<int>(tid) + 1);
      }
#endif
      PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid);
      PredictionRun(predictor.get(), inputs, outputs, num_threads, tid);
    });
  }
  for (int i = 0; i < num_threads; ++i) {
...@@ -367,6 +431,31 @@ void CompareNativeAndAnalysis(
  CompareResult(analysis_outputs, native_outputs);
}
void CompareAnalysisAndZeroCopy(
PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
const std::vector<std::string> &outputs_name) {
int batch_size = FLAGS_batch_size;
// analysis
std::vector<PaddleTensor> analysis_outputs;
auto predictor = CreateTestPredictor(config, true);
predictor->Run(inputs[0], &analysis_outputs, batch_size);
// analysis + zero_copy
std::vector<ZeroCopyTensor> zerocopy_outputs;
reinterpret_cast<AnalysisConfig *>(config)->SwitchUseFeedFetchOps(false);
predictor = CreateTestPredictor(config, true);
ConvertPaddleTensorToZeroCopyTensor(predictor.get(), inputs[0]);
predictor->ZeroCopyRun();
for (size_t i = 0; i < outputs_name.size(); i++) {
ZeroCopyTensor zerocopy_output =
*predictor->GetOutputTensor(outputs_name[i]).get();
zerocopy_outputs.emplace_back(zerocopy_output);
LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(zerocopy_output);
}
// compare
CompareResult(analysis_outputs, zerocopy_outputs);
}
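For reading results back, the comparison relies on ZeroCopyTensor::data<T>(&place, &size) returning an element count, not a byte count. A hedged usage fragment following that convention (predictor and out_name are assumed to exist in the caller):

PaddlePlace place;
int size = 0;  // number of elements, not bytes
auto out_tensor = predictor->GetOutputTensor(out_name);
float *raw = out_tensor->data<float>(&place, &size);
std::vector<float> result(raw, raw + size);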
template <typename T>
std::string LoDTensorSummary(const framework::LoDTensor &tensor) {
  std::stringstream ss;
......
...@@ -30,19 +30,20 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
      ${EXTERNAL_PROJECT_NAME}
      ${EXTERNAL_PROJECT_LOG_ARGS}
      PREFIX                ${INSTALL_DIR}
      DOWNLOAD_COMMAND      wget -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} &&
                            ${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME}
      DOWNLOAD_DIR          ${INSTALL_DIR}
      DOWNLOAD_NO_PROGRESS  1
      CONFIGURE_COMMAND     ""
      BUILD_COMMAND         ""
      UPDATE_COMMAND        ""
      INSTALL_COMMAND       ""
  )
endfunction()

set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec")
if(NOT EXISTS ${WORD2VEC_INSTALL_DIR} AND NOT WIN32)
  inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
endif()
set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model")
......
...@@ -14,6 +14,7 @@
#include "paddle/fluid/memory/allocation/legacy_allocator.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>
......
include(operators)
register_operators(DEPS naive_executor)
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)

file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

namespace paddle {
...@@ -26,14 +27,6 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor;

namespace {  // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
    const std::vector<std::string> &vars) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace operators {
// OpVariant is a wrapper class of OpDesc and OperatorBase,
// so that both can be accessed through the same API.
class OpVariant {
struct InputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Inputs());
}
};
struct OutputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Outputs());
}
};
struct AttributeMapVisitor
: public boost::static_visitor<const framework::AttributeMap *> {
const framework::AttributeMap *operator()(
const framework::OpDesc *op) const {
return &(op->GetAttrMap());
}
const framework::AttributeMap *operator()(
const framework::OperatorBase *op) const {
return &(op->Attrs());
}
};
struct RawPointerVisitor : public boost::static_visitor<const void *> {
template <typename OpType>
const void *operator()(const OpType *op) const {
return op;
}
};
public:
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
const framework::VariableNameMap &Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
}
const framework::VariableNameMap &Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
}
const framework::AttributeMap &Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
}
template <typename AttrType>
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
return boost::get<AttrType>(it->second);
}
bool operator==(const OpVariant &other) const {
return RawPointer() == other.RawPointer();
}
const void *RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
}
int which() const { return static_cast<int>(op_.which()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
return reinterpret_cast<size_t>(op.RawPointer());
}
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
static std::string GetDebugString(const std::vector<std::string> &names) {
if (names.empty()) return "";
std::string ret = names[0];
for (size_t i = 1; i < names.size(); ++i) {
ret += (" " + names[i]);
}
return ret;
}
// Set the skip variables of while_op and while_grad_op.
// These variables should be skipped when eager deletion is enabled,
// because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
VLOG(2) << "Prepare to skip " << attr.size()
<< " var(s): " << GetDebugString(attr);
attrs[kSkipEagerDeletionVars] = std::move(attr);
}
// Check whether a forward while_op matches a while_grad_op
// (the program may contain many while_ops).
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
const OpVariant &grad_op) {
return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}
// Test whether a variable is skippable in the forward while_op:
// it is skippable when the variable used in while_grad
// does not come from grad_block.
static bool IsSkippableVar(const std::string &name,
framework::BlockDesc *grad_block) {
return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}
static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
const OpVariant &bwd_op) {
auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
// Find all skippable variables in forward while_op
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (IsSkippableVar(in_arg_name, grad_block)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (IsSkippableVar(out_arg_name, grad_block)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
forward_skip_vars.end()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto &fwd_input = fwd_op.Inputs().at(kX);
auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
PADDLE_ENFORCE_EQ(
fwd_input.size(), in_grads.size(),
"Backward input gradient number does not match forward input number.");
std::unordered_set<std::string> backward_skip_vars;
for (size_t i = 0; i < in_grads.size(); ++i) {
if (in_grads[i] == framework::kEmptyVarName) {
continue;
}
backward_skip_vars.insert(in_grads[i]);
backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
backward_skip_vars.end()));
}
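The backward skip list above leans on the gradient naming convention: framework::GradVarName appends a fixed gradient suffix to a forward variable name ("@GRAD", matching kXGRAD = "X@GRAD" above; the exact suffix is treated as an assumption here). A trivial standalone sketch:

#include <string>

// Mirrors the convention: the gradient of "hidden" is named "hidden@GRAD".
std::string GradName(const std::string &var) { return var + "@GRAD"; }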
// Find all while_ops and while_grad_ops in the graph or program.
// A while_grad_op and its while_op may be located in different blocks,
// so we traverse all blocks of the program to find them.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
while_ops->emplace_back(op);
} else if (op->Type() == "while_grad") {
while_grad_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
"There are extra while_grad ops in the graph or program");
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
if (while_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
while_ops->begin(), while_ops->end());
for (auto &bwd_op : *while_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : while_op_set) {
if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched while ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward while op.");
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op);
}
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
  // If block_id is not 0, return directly.
  // This is because all while_ops and while_grad_ops in the whole program
  // are processed when block_id is 0 (i.e. when Executor::Run() or the
  // ParallelExecutor constructor runs).
  // Moreover, all while_ops and while_grad_ops must be processed when
  // block_id is zero. Otherwise, while_op may run first and erase variables
  // used in while_grad_op before the while_grad_ops are even constructed.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "while") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "while_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(while_ops.size());
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
} // namespace operators
} // namespace paddle
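OpVariant::Hasher above hashes a wrapper by the address of the object it wraps, so the unordered_set in the matching pass deduplicates by pointer identity. A minimal standalone sketch of that idea:

#include <cstddef>
#include <unordered_set>

struct Wrapper {
  const void *ptr;
  bool operator==(const Wrapper &o) const { return ptr == o.ptr; }
};

// Hash by the wrapped address: two Wrappers are "the same op" iff they
// point at the same underlying object.
struct WrapperHasher {
  size_t operator()(const Wrapper &w) const {
    return reinterpret_cast<size_t>(w.ptr);
  }
};

using WrapperSet = std::unordered_set<Wrapper, WrapperHasher>;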
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...@@ -14,19 +14,30 @@

#pragma once

#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"

namespace paddle {
namespace operators {

static constexpr char kStepBlock[] = "sub_block";
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";

void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    int block_id,
    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);

void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    const std::vector<framework::OperatorBase *> &while_ops,
    const std::vector<framework::OperatorBase *> &while_grad_ops);

}  // namespace operators
}  // namespace paddle
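A hedged usage sketch of the two overloads declared above (the call sites are assumptions, not part of this diff): an executor preparing block 0 hands over its whole op list, while a pass that has already collected the while ops uses the pointer form.

// all_ops: std::vector<std::unique_ptr<framework::OperatorBase>>
paddle::operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(0, all_ops);
// while_ops / while_grad_ops: std::vector<framework::OperatorBase *>
paddle::operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    while_ops, while_grad_ops);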
...@@ -82,8 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
    Tensor track;
    int* track_value =
        track.mutable_data<int>(emission_dims, platform::CPUPlace());
    auto ker =
        jit::KernelFuncs<jit::CRFDecodingTuple<T>, platform::CPUPlace>::Cache()
            .At(tag_num);
    ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
    T max_score = -std::numeric_limits<T>::max();
    int max_i = 0;
......
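The recurring jit::Get -> KernelFuncs<...>::Cache().At(attr) migration in this commit swaps a per-call kernel lookup for a cached one. A standalone sketch of the caching idea (not the Paddle jit API; `resolve` is a stand-in for kernel dispatch):

#include <functional>
#include <unordered_map>

template <typename Key, typename Func>
class KernelCache {
 public:
  static KernelCache& Instance() {
    static KernelCache cache;
    return cache;
  }
  // Returns the cached function for `key`, resolving and storing it once.
  Func At(const Key& key, const std::function<Func(const Key&)>& resolve) {
    auto it = funcs_.find(key);
    if (it == funcs_.end()) it = funcs_.emplace(key, resolve(key)).first;
    return it->second;
  }

 private:
  std::unordered_map<Key, Func> funcs_;
};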
...@@ -110,8 +110,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
      constexpr int simd_width = 16;
      int C = c / simd_width;
      auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
                                       platform::CPUPlace>::Cache()
                          .At(0);
#pragma omp parallel for collapse(2)
      for (int ni = 0; ni < n; ni++) {
        for (int ci = 0; ci < C; ci++) {
......
...@@ -52,8 +52,9 @@ struct EmbeddingVSumFunctor {
                                 out_width, jit::SeqPoolType::kSum);
    for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
      attr.index_height = ids_lod[i + 1] - ids_lod[i];
      auto emb_seqpool =
          jit::KernelFuncs<jit::EmbSeqPoolTuple<T>, platform::CPUPlace>::Cache()
              .At(attr);
      emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
                  &attr);
    }
...@@ -135,8 +136,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
        T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
        const T *d_output_data = d_output->data<T>();

        auto vbroadcast =
            jit::KernelFuncs<jit::VBroadcastTuple<T>, platform::CPUPlace>::Cache()
                .At(out_width);
        for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
          int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
          const T *src = d_output_data + i * out_width;
......
...@@ -182,29 +182,32 @@ class FusionGRUKernel : public framework::OpKernel<T> {
  const int total_T = x_dims[0];           \
  const int D3 = wh_dims[1]

#define INIT_OTHER_DEFINES                                                   \
  auto* h0 = ctx.Input<Tensor>("H0");                                        \
  auto* wx = ctx.Input<Tensor>("WeightX");                                   \
  auto* bias = ctx.Input<Tensor>("Bias");                                    \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");                        \
  bool is_reverse = ctx.Attr<bool>("is_reverse");                            \
  const int M = x_dims[1];                                                   \
  const int D = wh_dims[0];                                                  \
  const int D2 = D * 2;                                                      \
  const jit::gru_attr_t attr(                                                \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("activation")));              \
  jit::gru_t one_step;                                                       \
  auto ComputeH1 =                                                           \
      jit::KernelFuncs<jit::GRUH1Tuple<T>, platform::CPUPlace>::Cache().At(  \
          attr);                                                             \
  auto ComputeHtPart1 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart1Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  auto ComputeHtPart2 =                                                      \
      jit::KernelFuncs<jit::GRUHtPart2Tuple<T>, platform::CPUPlace>::Cache() \
          .At(attr);                                                         \
  const T* x_data = x->data<T>();                                            \
  const T* wx_data = wx->data<T>();                                          \
  const T* wh_data = wh->data<T>();                                          \
  auto place = ctx.GetPlace();                                               \
  T* xx_data = xx->mutable_data<T>(place)

  void SeqCompute(const framework::ExecutionContext& ctx) const {
......
...@@ -235,32 +235,34 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
  const int D = wh_dims[0];                  \
  const int D4 = wh_dims[1]

#define INIT_OTHER_DEFINES                                                     \
  const T* x_data = x->data<T>();                                              \
  const T* wx_data = wx->data<T>();                                            \
  const T* wh_data = wh->data<T>();                                            \
  /* diagonal weight */                                                        \
  const T* wp_data = bias->data<T>() + D4;                                     \
  /* for peephole only */                                                      \
  T* checked_cell_data = nullptr;                                              \
  auto place = ctx.GetPlace();                                                 \
  if (use_peepholes) {                                                         \
    /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih */                           \
    auto* checked_cell = ctx.Output<Tensor>("CheckedCell");                    \
    checked_cell_data = checked_cell->mutable_data<T>(place);                  \
  }                                                                            \
  const jit::lstm_attr_t attr(                                                 \
      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")),         \
      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")),       \
      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")),            \
      use_peepholes);                                                          \
  jit::lstm_t one_step;                                                        \
  one_step.wp = wp_data;                                                       \
  one_step.checked = checked_cell_data;                                        \
  auto ComputeC1H1 =                                                           \
      jit::KernelFuncs<jit::LSTMC1H1Tuple<T>, platform::CPUPlace>::Cache().At( \
          attr);                                                               \
  auto ComputeCtHt =                                                           \
      jit::KernelFuncs<jit::LSTMCtHtTuple<T>, platform::CPUPlace>::Cache().At( \
          attr)

// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
......
...@@ -82,9 +82,11 @@ template <typename T> ...@@ -82,9 +82,11 @@ template <typename T>
static void fc_relu(const T* x, const T* w, const T* b, T* y, static void fc_relu(const T* x, const T* w, const T* b, T* y,
const jit::matmul_attr_t& attr) { const jit::matmul_attr_t& attr) {
   auto matmul =
-      jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
+      jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
+          attr);
   auto addbias_relu =
-      jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(attr.n);
+      jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
+          attr.n);
matmul(x, w, y, &attr); matmul(x, w, y, &attr);
T* dst = y; T* dst = y;
for (int i = 0; i < attr.m; ++i) { for (int i = 0; i < attr.m; ++i) {
......
...@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> { ...@@ -98,7 +98,7 @@ class FusionSeqPoolConcatKernel : public framework::OpKernel<T> {
attr.type = jit::SeqPoolType::kSqrt; attr.type = jit::SeqPoolType::kSqrt;
} }
     auto seqpool =
-        jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
+        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
             attr);
size_t n = ins.size(); size_t n = ins.size();
size_t dst_step_size = n * w; size_t dst_step_size = n * w;
......
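A hedged usage sketch of the renamed SeqPoolTuple, with illustrative attribute values (the field names follow seq_pool_attr_t as used in this commit and are an assumption here):

    jit::seq_pool_attr_t attr;
    attr.w = 8;                           // assumption: width of one row
    attr.h = 4;                           // assumption: rows to pool
    attr.type = jit::SeqPoolType::kSum;
    auto seqpool =
        jit::KernelFuncs<jit::SeqPoolTuple<float>, platform::CPUPlace>::Cache().At(attr);
    seqpool(x_data, y_data, &attr);       // void(const T*, T*, const seq_pool_attr_t*)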
...@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> { ...@@ -94,19 +94,23 @@ class FusionSquaredMatSubKernel : public framework::OpKernel<T> {
int o_numel = attr.m * attr.n; int o_numel = attr.m * attr.n;
     auto vsquare_x =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.m *
-                                                                       attr.k);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            attr.m * attr.k);
     auto vsquare_y =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(attr.k *
-                                                                       attr.n);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            attr.k * attr.n);
     auto vsquare_xy =
-        jit::Get<jit::kVSquare, jit::XYNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VSquareTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto vsub =
-        jit::Get<jit::kVSub, jit::XYZNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VSubTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto vscal =
-        jit::Get<jit::kVScal, jit::AXYNTuples<T>, platform::CPUPlace>(o_numel);
+        jit::KernelFuncs<jit::VScalTuple<T>, platform::CPUPlace>::Cache().At(
+            o_numel);
     auto matmul =
-        jit::Get<jit::kMatMul, jit::MatMulTuples<T>, platform::CPUPlace>(attr);
+        jit::KernelFuncs<jit::MatMulTuple<T>, platform::CPUPlace>::Cache().At(
+            attr);
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* y_data = y->data<T>(); const T* y_data = y->data<T>();
......
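For orientation, the tuple types used above fix these call signatures (taken from the kernel_base.h hunks later in this commit); the buffer names here are assumed:

    vsquare_x(x_data, sq_x, attr.m * attr.k);  // XYNTuple:    void(const T*, T*, int)
    vsub(lhs, rhs, out, o_numel);              // XYZNTuple:   void(const T*, const T*, T*, int)
    vscal(&alpha, out, out, o_numel);          // AXYNTuple:   void(const T*, const T*, T*, int)
    matmul(a, b, out, &attr);                  // MatMulTuple: void(const T*, const T*, T*, const matmul_attr_t*)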
...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n") ...@@ -5,7 +5,7 @@ file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place xxhash)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc) list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
......
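xxhash joins JIT_KERNEL_DEPS because kernel_key.cc (later in this commit) now derives JitCodeKey by hashing the raw attribute struct with XXH64. A minimal sketch of that keying scheme; gru_attr_like is a stand-in, not the real struct:

    #include <xxhash.h>
    #include <cstdint>

    struct gru_attr_like { int d; int act_gate; int act_cand; };  // placeholder fields

    int64_t MakeKey(const gru_attr_like& attr) {
      // Hash the struct bytes with seed 0, as the new JitCodeKey does.
      // Caveat: this assumes the struct carries no uninitialized padding bytes.
      return static_cast<int64_t>(XXH64(&attr, sizeof(attr), 0));
    }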
...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) { ...@@ -59,8 +59,6 @@ BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \ InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
void BenchJITKernel_##name##_##dtype##_##place##_::Run() void BenchJITKernel_##name##_##dtype##_##place##_::Run()
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
void RUN_ALL_BENCHMARK() { void RUN_ALL_BENCHMARK() {
for (auto p : g_all_benchmarks) { for (auto p : g_all_benchmarks) {
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) { if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() { ...@@ -90,11 +88,11 @@ std::vector<int> TestSizes() {
return s; return s;
} }
template <typename KernelTuples, typename... Args> template <typename KernelTuple, typename... Args>
struct BenchFunc { struct BenchFunc {
  // returns the average time of this function // returns the average time of this function
// TODO(TJ): clear cache every time // TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) { double operator()(const typename KernelTuple::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) { for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...); tgt(args...);
} }
...@@ -109,40 +107,17 @@ struct BenchFunc { ...@@ -109,40 +107,17 @@ struct BenchFunc {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
-template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
-          typename... Args>
-void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
-  BenchFunc<KernelTuples, Args...> benchmark;
+template <typename KernelTuple, typename PlaceType, typename... Args>
+void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
+  BenchFunc<KernelTuple, Args...> benchmark;
   std::vector<std::pair<std::string, double>> infos;
-  // test refer
-  auto refer = jit::GetRefer<KT, KernelTuples>();
-  if (!refer) {
-    LOG(FATAL) << "Refer can not be empty!";
-  }
-  infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
-  // test jitcode
-  auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
-  if (jitcode) {
-    infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
-  }
-  // test all impls in more
-  jit::KernelKey kkey(KT, PlaceType());
-  auto& pool = jit::KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        auto more = i->GetFunc();
-        infos.push_back(
-            std::make_pair(i->ImplType(), benchmark(more, args...)));
-      }
-    }
-  }
-  // Test result from Get function
-  auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
+  auto funcs = jit::GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  for (auto f : funcs) {
+    infos.push_back(std::make_pair(f.first, benchmark(f.second, args...)));
+  }
+  // Test the result of the default function from the thread-local cache
+  auto tgt = jit::KernelFuncs<KernelTuple, PlaceType>::Cache().At(attr);
   if (!tgt) {
     LOG(FATAL) << "Target can not be empty!";
   }
...@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -150,7 +125,8 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
// print // print
std::ostringstream loginfos; std::ostringstream loginfos;
-  loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
+  loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": "
+           << attr << ": ";
for (auto pair : infos) { for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; "; loginfos << pair.first << " takes " << pair.second << " us; ";
} }
...@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { ...@@ -159,8 +135,9 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYZNKernel() { void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y, z; Tensor x, y, z;
x.Resize({d}); x.Resize({d});
...@@ -171,16 +148,16 @@ void BenchXYZNKernel() { ...@@ -171,16 +148,16 @@ void BenchXYZNKernel() {
T* z_data = z.mutable_data<T>(PlaceType()); T* z_data = z.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
RandomVec<T>(d, y_data); RandomVec<T>(d, y_data);
-    BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
-                                                     y.data<T>(), z_data, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y.data<T>(), z_data,
+                                          d);
     // test inplace
-    BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
-                                                     z_data, d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), z_data, z_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchAXYNKernel() { void BenchKernelAXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const T a = static_cast<T>(3); const T a = static_cast<T>(3);
Tensor x, y; Tensor x, y;
...@@ -189,26 +166,26 @@ void BenchAXYNKernel() { ...@@ -189,26 +166,26 @@ void BenchAXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
-    BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
-                                                     d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), y_data, d);
     // test inplace
-    BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
-                                                     d);
+    BenchAllImpls<KernelTuple, PlaceType>(d, &a, x.data<T>(), x_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXRNKernel() { void BenchKernelXRN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x; Tensor x;
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType())); RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
T res; T res;
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), &res, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchXYNKernel() { void BenchKernelXYN() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
Tensor x, y; Tensor x, y;
x.Resize({d}); x.Resize({d});
...@@ -216,12 +193,13 @@ void BenchXYNKernel() { ...@@ -216,12 +193,13 @@ void BenchXYNKernel() {
T* x_data = x.mutable_data<T>(PlaceType()); T* x_data = x.mutable_data<T>(PlaceType());
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
RandomVec<T>(d, x_data); RandomVec<T>(d, x_data);
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data<T>(), y_data, d); BenchAllImpls<KernelTuple, PlaceType>(d, x.data<T>(), y_data, d);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLSTMKernel() { void BenchKernelLSTM() {
using T = typename KernelTuple::data_type;
for (bool use_peephole : {true, false}) { for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
...@@ -252,13 +230,14 @@ void BenchLSTMKernel() { ...@@ -252,13 +230,14 @@ void BenchLSTMKernel() {
step.wp = wp_data; step.wp = wp_data;
step.checked = checked_data; step.checked = checked_data;
} }
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchGRUKernel() { void BenchKernelGRU() {
using T = typename KernelTuple::data_type;
for (int d : TestSizes()) { for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
auto place = PlaceType(); auto place = PlaceType();
...@@ -275,12 +254,13 @@ void BenchGRUKernel() { ...@@ -275,12 +254,13 @@ void BenchGRUKernel() {
step.gates = x_data; step.gates = x_data;
step.ht_1 = ht_1_data; step.ht_1 = ht_1_data;
step.ht = ht_data; step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr); BenchAllImpls<KernelTuple, PlaceType>(attr, &step, &attr);
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSeqPoolKernel() { void BenchKernelSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = { std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) { for (auto type : pool_types) {
...@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() { ...@@ -294,15 +274,15 @@ void BenchSeqPoolKernel() {
RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(h * w, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
-        BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
-                                                            y_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, x_data, y_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchEmbSeqPoolKernel() { void BenchKernelEmbSeqPool() {
using T = typename KernelTuple::data_type;
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum}; std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4; int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) { for (int tbl_w : {10, 16, 256}) {
...@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() { ...@@ -324,16 +304,17 @@ void BenchEmbSeqPoolKernel() {
tbl_h - 1); tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>(); const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType()); T* o_data = out.mutable_data<T>(PlaceType());
-          BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
-              attr, table_data, idx_data, o_data, &attr);
+          BenchAllImpls<KernelTuple, PlaceType>(attr, table_data, idx_data,
+                                                o_data, &attr);
} }
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSgdKernel() { void BenchKernelSgd() {
using T = typename KernelTuple::data_type;
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
...@@ -364,15 +345,16 @@ void BenchSgdKernel() { ...@@ -364,15 +345,16 @@ void BenchSgdKernel() {
const T* grad_data = grad.data<T>(); const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data(); const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
-        BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>(
-            attr, &lr, param_data, grad_data, rows_data, param_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, &lr, param_data, grad_data,
+                                              rows_data, param_data, &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchMatMulKernel() { void BenchKernelMatMul() {
using T = typename KernelTuple::data_type;
for (int m : {1, 2, 3, 4}) { for (int m : {1, 2, 3, 4}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
for (int k : TestSizes()) { for (int k : TestSizes()) {
...@@ -386,15 +368,16 @@ void BenchMatMulKernel() { ...@@ -386,15 +368,16 @@ void BenchMatMulKernel() {
const T* b_data = b.data<T>(); const T* b_data = b.data<T>();
T* c_data = c.mutable_data<T>(PlaceType()); T* c_data = c.mutable_data<T>(PlaceType());
const jit::matmul_attr_t attr{m, n, k}; const jit::matmul_attr_t attr{m, n, k};
-        BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(attr, a_data, b_data,
-                                                           c_data, &attr);
+        BenchAllImpls<KernelTuple, PlaceType>(attr, a_data, b_data, c_data,
+                                              &attr);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchSoftmaxKernel() { void BenchKernelSoftmax() {
using T = typename KernelTuple::data_type;
for (int bs : {1, 2, 10}) { for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) { for (int n : TestSizes()) {
Tensor x, y; Tensor x, y;
...@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() { ...@@ -403,14 +386,14 @@ void BenchSoftmaxKernel() {
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f); RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
-      BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
-                                                          bs);
+      BenchAllImpls<KernelTuple, PlaceType>(n, x_data, y_data, n, bs);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchLayerNormKernel() { void BenchKernelLayerNorm() {
using T = typename KernelTuple::data_type;
const T epsilon = 9.99999975e-06; const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) { for (int n : {1, 2, 10}) {
for (int x_dim_0 : {1, 9, 17, 50}) { for (int x_dim_0 : {1, 9, 17, 50}) {
...@@ -439,16 +422,17 @@ void BenchLayerNormKernel() { ...@@ -439,16 +422,17 @@ void BenchLayerNormKernel() {
T* var_data = var.data<T>(); T* var_data = var.data<T>();
T* out_data = out.mutable_data<T>(PlaceType()); T* out_data = out.mutable_data<T>(PlaceType());
-        BenchAllImpls<KT, jit::LayerNormTuples<T>, PlaceType>(
-            right, x_data, out_data, mean_data, var_data, scale_data, bias_data,
-            left, epsilon, right);
+        BenchAllImpls<KernelTuple, PlaceType>(right, x_data, out_data,
+                                              mean_data, var_data, scale_data,
+                                              bias_data, left, epsilon, right);
} }
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchCRFDecodingKernel() { void BenchKernelCRFDecoding() {
using T = typename KernelTuple::data_type;
constexpr int state_trans_base_idx = 2; constexpr int state_trans_base_idx = 2;
for (int seq_len : {1, 11, 17, 50}) { for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : TestSizes()) { for (int tag_num : TestSizes()) {
...@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() { ...@@ -468,14 +452,15 @@ void BenchCRFDecodingKernel() {
T* alpha_data = alpha.mutable_data<T>(PlaceType()); T* alpha_data = alpha.mutable_data<T>(PlaceType());
int* track_data = track.mutable_data<int>(PlaceType()); int* track_data = track.mutable_data<int>(PlaceType());
-      BenchAllImpls<KT, jit::CRFDecodingTuples<T>, PlaceType>(
-          tag_num, seq_len, x_data, w_data, alpha_data, track_data, tag_num);
+      BenchAllImpls<KernelTuple, PlaceType>(tag_num, seq_len, x_data, w_data,
+                                            alpha_data, track_data, tag_num);
} }
} }
} }
template <jit::KernelType KT, typename T, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchVBroadcastKernel() { void BenchKernelVBroadcast() {
using T = typename KernelTuple::data_type;
for (int64_t w : {1, 16, 64, 100, 256}) { for (int64_t w : {1, 16, 64, 100, 256}) {
Tensor x; Tensor x;
x.Resize({w}); x.Resize({w});
...@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() { ...@@ -485,78 +470,86 @@ void BenchVBroadcastKernel() {
Tensor y; Tensor y;
y.Resize({h * w}); y.Resize({h * w});
T* y_data = y.mutable_data<T>(PlaceType()); T* y_data = y.mutable_data<T>(PlaceType());
-      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
-          w, x_data, y_data, static_cast<int64_t>(h), w);
+      BenchAllImpls<KernelTuple, PlaceType>(w, x_data, y_data,
+                                            static_cast<int64_t>(h), w);
} }
} }
} }
-using T = float;
-using CPUPlace = paddle::platform::CPUPlace;
-
-// xyzn
-BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
-
-// axyn
-BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
-
-// xrn
-BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
-BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
-
-// xyn
-BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
-BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
-
-// lstm and peephole
-BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
-BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
-
-// gru functions
-BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
-BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
-BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
-
-// seq pool function
-BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
-
-// embedding seq pool function
-BENCH_FP32_CPU(kEmbSeqPool) {
-  BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
-}
-
-// sgd function
-BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); }
-
-// matmul
-BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
-
-// softmax
-BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
-
-// layernorm
-BENCH_FP32_CPU(kLayerNorm) {
-  BenchLayerNormKernel<jit::kLayerNorm, T, CPUPlace>();
-}
-
-// crfdecoding
-BENCH_FP32_CPU(kCRFDecoding) {
-  BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
-}
-
-// vbroadcast function
-BENCH_FP32_CPU(kVBroadcast) {
-  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
-}
+#define BenchKernelVMul BenchKernelXYZN
+#define BenchKernelVAdd BenchKernelXYZN
+#define BenchKernelVAddRelu BenchKernelXYZN
+#define BenchKernelVSub BenchKernelXYZN
+
+#define BenchKernelVScal BenchKernelAXYN
+#define BenchKernelVAddBias BenchKernelAXYN
+
+#define BenchKernelVRelu BenchKernelXYN
+#define BenchKernelVIdentity BenchKernelXYN
+#define BenchKernelVSquare BenchKernelXYN
+#define BenchKernelVExp BenchKernelXYN
+#define BenchKernelVSigmoid BenchKernelXYN
+#define BenchKernelVTanh BenchKernelXYN
+#define BenchKernelVCopy BenchKernelXYN
+
+#define BenchKernelHMax BenchKernelXRN
+#define BenchKernelHSum BenchKernelXRN
+
+#define BenchKernelLSTMCtHt BenchKernelLSTM
+#define BenchKernelLSTMC1H1 BenchKernelLSTM
+
+#define BenchKernelGRUH1 BenchKernelGRU
+#define BenchKernelGRUHtPart1 BenchKernelGRU
+#define BenchKernelGRUHtPart2 BenchKernelGRU
+
+using CPUPlace = paddle::platform::CPUPlace;
+
+#define BENCH_FP32_CPU(name)                                \
+  BENCH_JITKERNEL(name, FP32, CPU) {                        \
+    BenchKernel##name<jit::name##Tuple<float>, CPUPlace>(); \
+  }
+
+// xyzn
+BENCH_FP32_CPU(VMul);
+BENCH_FP32_CPU(VAdd);
+BENCH_FP32_CPU(VAddRelu);
+BENCH_FP32_CPU(VSub);
+
+// axyn
+BENCH_FP32_CPU(VScal);
+BENCH_FP32_CPU(VAddBias);
+
+// xyn
+BENCH_FP32_CPU(VRelu);
+BENCH_FP32_CPU(VIdentity);
+BENCH_FP32_CPU(VSquare);
+BENCH_FP32_CPU(VExp);
+BENCH_FP32_CPU(VSigmoid);
+BENCH_FP32_CPU(VTanh);
+BENCH_FP32_CPU(VCopy);
+
+// xrn
+BENCH_FP32_CPU(HMax);
+BENCH_FP32_CPU(HSum);
+
+// LSTM
+BENCH_FP32_CPU(LSTMCtHt);
+BENCH_FP32_CPU(LSTMC1H1);
+
+// GRU
+BENCH_FP32_CPU(GRUH1);
+BENCH_FP32_CPU(GRUHtPart1);
+BENCH_FP32_CPU(GRUHtPart2);
+
+BENCH_FP32_CPU(LayerNorm);
+BENCH_FP32_CPU(CRFDecoding);
+
+BENCH_FP32_CPU(SeqPool);
+BENCH_FP32_CPU(EmbSeqPool);
+BENCH_FP32_CPU(MatMul);
+BENCH_FP32_CPU(Softmax);
+BENCH_FP32_CPU(Sgd);
+BENCH_FP32_CPU(VBroadcast);
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
......
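With the table-driven registration above, each benchmark body is generated from the kernel name. Given the BENCH_JITKERNEL machinery shown earlier in this file, BENCH_FP32_CPU(VMul); expands to roughly the following (the class name comes from the name/dtype/place token pasting):

    // Registers BenchJITKernel_VMul_FP32_CPU_ via InsertBenchmark, then defines:
    void BenchJITKernel_VMul_FP32_CPU_::Run() {
      BenchKernelVMul<jit::VMulTuple<float>, CPUPlace>();  // alias of BenchKernelXYZN
    }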
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h" #include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -81,7 +82,7 @@ void VActJitCode::genCode() { ...@@ -81,7 +82,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \ #define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override; \ bool CanBeUsed(const int& attr) const override; \
size_t CodeSize(const int& d) const override; \ size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \ std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \ return make_unique<name##JitCode>(attr, CodeSize(attr)); \
...@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid); ...@@ -96,27 +97,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh); DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me // TODO(TJ): tuning use me
bool VReluCreator::UseMe(const int& d) const { bool VReluCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VSquareCreator::UseMe(const int& d) const { bool VSquareCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VIdentityCreator::UseMe(const int& d) const { bool VIdentityCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VExpCreator::UseMe(const int& d) const { bool VExpCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx) && d < 32; return platform::MayIUse(platform::avx) && d < 32;
} }
bool VSigmoidCreator::UseMe(const int& d) const { bool VSigmoidCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
bool VTanhCreator::UseMe(const int& d) const { bool VTanhCreator::CanBeUsed(const int& d) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h" #include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() { ...@@ -142,7 +143,7 @@ void NCHW16CMulNCJitCode::genCode() {
class NCHW16CMulNCCreator : public JitCodeCreator<int> { class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public: public:
bool UseMe(const int& attr) const override { bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f); return platform::MayIUse(platform::avx512f);
} }
size_t CodeSize(const int& d) const override { return 256 * 1024; } size_t CodeSize(const int& d) const override { return 256 * 1024; }
...@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> { ...@@ -154,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \ #define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override { \ bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx) && attr <= 1024; \ return platform::MayIUse(platform::avx) && attr <= 1024; \
} \ } \
size_t CodeSize(const int& d) const override { \ size_t CodeSize(const int& d) const override { \
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h" #include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
...@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() { ...@@ -121,7 +122,7 @@ void EmbSeqPoolJitCode::genCode() {
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> { class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public: public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override { bool CanBeUsed(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) && return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0; attr.table_width % YMM_FLOAT_BLOCK == 0;
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h" #include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -86,7 +87,7 @@ void GRUJitCode::genCode() { ...@@ -86,7 +87,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \ class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \ public: \
/* TODO(TJ): enable more */ \ /* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \ bool CanBeUsed(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \ } \
size_t CodeSize(const gru_attr_t& attr) const override { \ size_t CodeSize(const gru_attr_t& attr) const override { \
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h" #include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() { ...@@ -76,7 +77,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \ #define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override { \ bool CanBeUsed(const int& attr) const override { \
return platform::MayIUse(platform::avx); \ return platform::MayIUse(platform::avx); \
} \ } \
size_t CodeSize(const int& d) const override { \ size_t CodeSize(const int& d) const override { \
......
...@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator { ...@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual void genCode() = 0; virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); } size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override { const unsigned char* getCodeInternal() const override {
const Xbyak::uint8* code = CodeGenerator::getCode(); const Xbyak::uint8* code = CodeGenerator::getCode();
return code; return code;
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h" #include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() { ...@@ -114,7 +115,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \ class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \ public: \
/* TODO(TJ): enable more */ \ /* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \ bool CanBeUsed(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \ } \
size_t CodeSize(const lstm_attr_t& attr) const override { \ size_t CodeSize(const lstm_attr_t& attr) const override { \
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h" #include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() { ...@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class MatMulCreator : public JitCodeCreator<matmul_attr_t> { class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
public: public:
bool UseMe(const matmul_attr_t& attr) const override { bool CanBeUsed(const matmul_attr_t& attr) const override {
return attr.m == 1 && platform::MayIUse(platform::avx512f) && return attr.m == 1 && platform::MayIUse(platform::avx512f) &&
attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512; attr.n % ZMM_FLOAT_BLOCK == 0 && attr.k < 512;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h" #include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones #include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() { ...@@ -57,7 +58,7 @@ void SeqPoolJitCode::genCode() {
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public: public:
bool UseMe(const seq_pool_attr_t& attr) const override { bool CanBeUsed(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
} }
size_t CodeSize(const seq_pool_attr_t& attr) const override { size_t CodeSize(const seq_pool_attr_t& attr) const override {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h" #include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof #include <stddef.h> // offsetof
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/registry.h" #include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
...@@ -104,7 +105,7 @@ void SgdJitCode::genCode() { ...@@ -104,7 +105,7 @@ void SgdJitCode::genCode() {
class SgdCreator : public JitCodeCreator<sgd_attr_t> { class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public: public:
bool UseMe(const sgd_attr_t& attr) const override { bool CanBeUsed(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) && return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0; attr.grad_width % YMM_FLOAT_BLOCK == 0;
} }
......
...@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() { ...@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class VBroadcastCreator : public JitCodeCreator<int64_t> { class VBroadcastCreator : public JitCodeCreator<int64_t> {
public: public:
bool UseMe(const int64_t& w) const override { bool CanBeUsed(const int64_t& w) const override {
return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0; return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
} }
size_t CodeSize(const int64_t& w) const override { size_t CodeSize(const int64_t& w) const override {
......
...@@ -31,7 +31,7 @@ namespace paddle { ...@@ -31,7 +31,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
// refer do not need useme, it would be the last one. // Refer kernels do not need CanBeUsed; they are always the last candidate.
void GenBase::dumpCode(const unsigned char* code) const { void GenBase::dumpCode(const unsigned char* code) const {
if (code) { if (code) {
static int counter = 0; static int counter = 0;
......
...@@ -31,9 +31,10 @@ class GenBase : public Kernel { ...@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual ~GenBase() = default; virtual ~GenBase() = default;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual size_t getSize() const = 0; virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0; virtual const unsigned char* getCodeInternal() const = 0;
const char* ImplType() const override { return "JitCode"; }
template <typename Func> template <typename Func>
Func getCode() { Func getCode() const {
const unsigned char* code = this->getCodeInternal(); const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) { if (FLAGS_dump_jitcode) {
this->dumpCode(code); this->dumpCode(code);
...@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator { ...@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual ~JitCodeCreator() = default; virtual ~JitCodeCreator() = default;
// condition when this jit code can be used. // condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0; virtual bool CanBeUsed(const Attr& attr) const = 0;
// estimate this code size // estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0; virtual size_t CodeSize(const Attr& attr) const = 0;
......
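All the creator overrides in this commit follow the same shape; a hedged sketch of a new creator against the renamed interface (MyJitCode and the code-size formula are placeholders):

    class MyCreator : public JitCodeCreator<int> {
     public:
      // Gate code generation on CPU features and the attribute's shape.
      bool CanBeUsed(const int& d) const override {
        return platform::MayIUse(platform::avx) && d % 8 == 0;
      }
      // Upper-bound estimate used to size the code buffer.
      size_t CodeSize(const int& d) const override { return 96 + d * 32; }
      std::unique_ptr<GenBase> CreateJitCode(const int& d) const override {
        return make_unique<MyJitCode>(d, CodeSize(d));
      }
    };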
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <unordered_map>
#include <utility> // for std::move
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -27,35 +29,34 @@ namespace paddle { ...@@ -27,35 +29,34 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
-template <KernelType KT, typename KernelTuples, typename PlaceType>
+template <typename KernelTuple, typename PlaceType>
 inline typename std::enable_if<
-    std::is_same<typename KernelTuples::data_type, float>::value &&
+    std::is_same<typename KernelTuple::data_type, float>::value &&
         std::is_same<PlaceType, platform::CPUPlace>::value,
-    typename KernelTuples::func_type>::type
-GetJitCode(const typename KernelTuples::attr_type& attr) {
-  using Func = typename KernelTuples::func_type;
-  using Attr = typename KernelTuples::attr_type;
-  size_t key = JitCodeKey<Attr>(attr);
-  auto& codes = JitCodePool<KT>().Instance();
+    const Kernel*>::type
+GetJitCode(const typename KernelTuple::attr_type& attr) {
+  using Attr = typename KernelTuple::attr_type;
+  int64_t key = JitCodeKey<Attr>(attr);
+  auto& codes = JitCodePool<KernelTuple::kernel_type>::Instance();
   if (codes.Has(key)) {
-    return codes.AllKernels().at(key)->template getCode<Func>();
+    return codes.AllKernels().at(key).get();
   }
   // creator is not related with attr, so can use KernelKey as key
-  KernelKey kkey(KT, PlaceType());
+  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
   // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
-  auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
+  auto& creator_map = JitCodeCreatorPool::Instance().AllCreators();
   auto iter = creator_map.find(kkey);
   if (iter != creator_map.end()) {
     auto& creators = iter->second;
     for (auto& cur : creators) {
       auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
-      if (i && i->UseMe(attr)) {
+      if (i && i->CanBeUsed(attr)) {
         auto p = i->CreateJitCode(attr);
         if (p) {
-          auto f = p->template getCode<Func>();
+          auto res = p.get();
           codes.Insert(key, std::move(p));
-          return f;
+          return res;
         }
       }
     }
...@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) { ...@@ -63,87 +64,153 @@ GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr; return nullptr;
} }
template <KernelType KT, typename KernelTuples, typename PlaceType> template <typename KernelTuple, typename PlaceType>
inline typename std::enable_if< inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value || !std::is_same<typename KernelTuple::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value, !std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type const Kernel*>::type
GetJitCode(const typename KernelTuples::attr_type& attr) { GetJitCode(const typename KernelTuple::attr_type& attr) {
return nullptr; return nullptr;
} }
// Refer code is not related to attr; it is only needed for the cast // Refer code is not related to attr; it is only needed for the cast
// Refer is always on CPUPlace // Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples> template <typename KernelTuple>
inline typename KernelTuples::func_type GetRefer() { inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels(); auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(), PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function."); "Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
if (i) { if (i) {
return i->GetFunc(); return i;
} }
} }
return nullptr; return nullptr;
} }
-template <KernelType KT, typename KernelTuples,
-          typename PlaceType = platform::CPUPlace>
-typename KernelTuples::func_type Get(
-    const typename KernelTuples::attr_type& attr) {
-  auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
-  if (jitfunc) {
-    return jitfunc;
-  }
-  // pool: (KernelKey(type, place), vector<KernelPtr>)
-  KernelKey kkey(KT, PlaceType());
-  auto& pool = KernelPool().Instance().AllKernels();
-  auto iter = pool.find(kkey);
-  if (iter != pool.end()) {
-    auto& impls = iter->second;
-    for (auto& impl : impls) {
-      auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
-      if (i && i->UseMe(attr)) {
-        return i->GetFunc();
-      }
-    }
-  }
-  // The last implementation should be reference function on CPUPlace.
-  return GetRefer<KT, KernelTuples>();
-}
+template <typename KernelTuple>
+inline typename KernelTuple::func_type GetReferFunc() {
+  auto ker = GetReferKernel<KernelTuple>();
+  auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
+  PADDLE_ENFORCE(p, "The Refer kernel should exist");
+  return p->GetFunc();
+}
+
+// Return all Kernels that can be used
+template <typename KernelTuple, typename PlaceType>
+std::vector<const Kernel*> GetAllCandidateKernels(
+    const typename KernelTuple::attr_type& attr) {
+  // the search order should be: jitcode > more > refer
+  std::vector<const Kernel*> res;
+  auto jitker = GetJitCode<KernelTuple, PlaceType>(attr);
+  if (jitker) {
+    res.emplace_back(jitker);
+  }
+  // more kernelpool: (KernelKey(type, place), vector<KernelPtr>)
+  KernelKey kkey(KernelTuple::kernel_type, PlaceType());
+  auto& pool = KernelPool::Instance().AllKernels();
+  auto iter = pool.find(kkey);
+  if (iter != pool.end()) {
+    auto& impls = iter->second;
+    for (auto& impl : impls) {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(impl.get());
+      if (i && i->CanBeUsed(attr)) {
+        res.emplace_back(i);
+      }
+    }
+  }
+  // The last implementation should be the reference function on CPUPlace.
+  auto ref = GetReferKernel<KernelTuple>();
+  PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
+  res.emplace_back(ref);
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<std::pair<std::string, typename KernelTuple::func_type>>
+GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
+  using Func = typename KernelTuple::func_type;
+  auto kers = GetAllCandidateKernels<KernelTuple, PlaceType>(attr);
+  std::vector<std::pair<std::string, Func>> res;
+  for (auto k : kers) {
+    std::string name = k->ImplType();
+    if (name == "JitCode") {
+      auto i = dynamic_cast<const GenBase*>(k);
+      PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
+    } else {
+      auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
+      PADDLE_ENFORCE(i, "kernel cast can not fail.");
+      res.emplace_back(std::make_pair(name, i->GetFunc()));
+    }
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+std::vector<typename KernelTuple::func_type> GetAllCandidateFuncs(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncsWithTypes<KernelTuple, PlaceType>(attr);
+  std::vector<typename KernelTuple::func_type> res;
+  for (auto& i : funcs) {
+    res.emplace_back(i.second);
+  }
+  return res;
+}
+
+template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
+typename KernelTuple::func_type GetDefaultBestFunc(
+    const typename KernelTuple::attr_type& attr) {
+  auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
+  PADDLE_ENFORCE_GE(funcs.size(), 1UL);
+  // A runtime benchmark over this attr could pick the best candidate here;
+  // for now the first one is taken as the default best, since candidates are
+  // searched in priority order and tuned offline.
+  return funcs[0];
+}
-template <KernelType KT, typename KernelTuples, typename PlaceType>
+template <typename KernelTuple, typename PlaceType>
 class KernelFuncs {
  public:
   KernelFuncs() = default;
   static KernelFuncs& Cache() {
-    static thread_local KernelFuncs<KT, KernelTuples, PlaceType> g_func_cache;
+    static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
     return g_func_cache;
   }

-  bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
-
-  void Insert(int key, typename KernelTuples::func_type func) {
-    funcs_.emplace(key, func);
-  }
-
-  typename KernelTuples::func_type At(int key) {
+  // the exposed interface to use
+  typename KernelTuple::func_type At(
+      const typename KernelTuple::attr_type& attr) {
+    // Note: this may not be ideal, since not all kernels have a jitcode key.
+    int64_t key = JitCodeKey<typename KernelTuple::attr_type>(attr);
     if (Has(key)) {
       return funcs_.at(key);
     }
-    auto func = Get<KT, KernelTuples, PlaceType>(key);
+    // If this attr is not cached yet, fall back to the default best func.
+    auto func = GetDefaultBestFunc<KernelTuple, PlaceType>(attr);
     Insert(key, func);
     return func;
   }

+  typename KernelTuple::func_type operator[](
+      const typename KernelTuple::attr_type& attr) {
+    return At(attr);
+  }
+
+ protected:
+  bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); }
+
+  void Insert(int64_t key, typename KernelTuple::func_type func) {
+    funcs_.emplace(key, func);
+  }
+
  private:
-  std::unordered_map<int, typename KernelTuples::func_type> funcs_;
+  std::unordered_map<int64_t, typename KernelTuple::func_type> funcs_;
   DISABLE_COPY_AND_ASSIGN(KernelFuncs);
 };
......
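Taken together, helper.h now offers three entry points; a hedged sketch of how a caller might use each (float/CPU chosen for illustration):

    namespace jit = paddle::operators::jit;
    using CPU = paddle::platform::CPUPlace;
    int n = 128;

    // 1. Fast path for operators: per-thread cached best function for this attr.
    auto vadd = jit::KernelFuncs<jit::VAddTuple<float>, CPU>::Cache().At(n);

    // 2. Every usable implementation in priority order (jitcode > more > refer),
    //    paired with its ImplType() name -- handy for tests and benchmarks.
    auto all = jit::GetAllCandidateFuncsWithTypes<jit::VAddTuple<float>, CPU>(n);
    for (auto& p : all) { /* p.first: e.g. "JitCode" or "Refer"; p.second: func */ }

    // 3. The always-available reference implementation, e.g. to verify results.
    auto ref = jit::GetReferFunc<jit::VAddTuple<float>>();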
...@@ -62,26 +62,55 @@ typedef enum { ...@@ -62,26 +62,55 @@ typedef enum {
kSqrt, kSqrt,
} SeqPoolType; } SeqPoolType;
// x, y, z, n
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int); typedef void (*func_type)(const T*, const T*, T*, int);
}; };
// a, x, y, n
template <typename T> template <typename T>
struct AXYNTuples : public XYZNTuples<T> {}; struct AXYNTuple : public XYZNTuple<T> {};
// x, y, n
template <typename T> template <typename T>
struct XYNTuples { struct XYNTuple {
typedef T data_type; typedef T data_type;
typedef int attr_type; typedef int attr_type;
typedef void (*func_type)(const T*, T*, int); typedef void (*func_type)(const T*, T*, int);
}; };
// x, return and int // x, returned value, n
template <typename T> template <typename T>
struct XRNTuples : public XYNTuples<T> {}; struct XRNTuple : public XYNTuple<T> {};
#define DECLARE_KERNELTUPLE(kernel_tuple, type) \
template <typename T> \
struct type##Tuple : public kernel_tuple<T> { \
static constexpr KernelType kernel_type = k##type; \
}
// Tuple should be corresponding to the KernelType
DECLARE_KERNELTUPLE(XYZNTuple, VMul);
DECLARE_KERNELTUPLE(XYZNTuple, VAdd);
DECLARE_KERNELTUPLE(XYZNTuple, VAddRelu);
DECLARE_KERNELTUPLE(XYZNTuple, VSub);
DECLARE_KERNELTUPLE(AXYNTuple, VScal);
DECLARE_KERNELTUPLE(AXYNTuple, VAddBias);
DECLARE_KERNELTUPLE(XYNTuple, VRelu);
DECLARE_KERNELTUPLE(XYNTuple, VIdentity);
DECLARE_KERNELTUPLE(XYNTuple, VSquare);
DECLARE_KERNELTUPLE(XYNTuple, VExp);
DECLARE_KERNELTUPLE(XYNTuple, VSigmoid);
DECLARE_KERNELTUPLE(XYNTuple, VTanh);
DECLARE_KERNELTUPLE(XYNTuple, VCopy);
DECLARE_KERNELTUPLE(XRNTuple, HMax);
DECLARE_KERNELTUPLE(XRNTuple, HSum);
typedef struct { typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh void* gates; // gates: x_ch, x_ih, x_fh, x_oh
...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t; ...@@ -122,21 +151,31 @@ typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t; typedef struct lstm_attr_s lstm_attr_t;
template <typename T> template <typename T>
struct LSTMTuples { struct LSTMTuple {
typedef T data_type; typedef T data_type;
typedef lstm_attr_t attr_type; typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*); typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
}; };
template <typename T> template <typename T>
struct GRUTuples { struct GRUTuple {
typedef T data_type; typedef T data_type;
typedef gru_attr_t attr_type; typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
DECLARE_KERNELTUPLE(LSTMTuple, LSTMCtHt);
DECLARE_KERNELTUPLE(LSTMTuple, LSTMC1H1);
DECLARE_KERNELTUPLE(GRUTuple, GRUH1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart1);
DECLARE_KERNELTUPLE(GRUTuple, GRUHtPart2);
#undef DECLARE_KERNELTUPLE
template <typename T> template <typename T>
struct VBroadcastTuples { struct VBroadcastTuple {
static constexpr KernelType kernel_type = kVBroadcast;
typedef T data_type; typedef T data_type;
typedef int64_t attr_type; typedef int64_t attr_type;
typedef void (*func_type)(const T*, T*, int64_t, int64_t); typedef void (*func_type)(const T*, T*, int64_t, int64_t);
...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s { ...@@ -151,7 +190,8 @@ typedef struct seq_pool_attr_s {
} seq_pool_attr_t; } seq_pool_attr_t;
 template <typename T>
-struct SeqPoolTuples {
+struct SeqPoolTuple {
+  static constexpr KernelType kernel_type = kSeqPool;
   typedef T data_type;
   typedef seq_pool_attr_t attr_type;
   typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
@@ -176,7 +216,8 @@ typedef struct emb_seq_pool_attr_s {
 } emb_seq_pool_attr_t;
 
 template <typename T>
-struct EmbSeqPoolTuples {
+struct EmbSeqPoolTuple {
+  static constexpr KernelType kernel_type = kEmbSeqPool;
   typedef T data_type;
   typedef emb_seq_pool_attr_t attr_type;
   typedef void (*func_type)(const T*, const int64_t*, T*,
@@ -198,7 +239,8 @@ typedef struct sgd_attr_s {
 } sgd_attr_t;
 
 template <typename T>
-struct SgdTuples {
+struct SgdTuple {
+  static constexpr KernelType kernel_type = kSgd;
   typedef T data_type;
   typedef sgd_attr_t attr_type;
   typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
@@ -214,21 +256,24 @@ typedef struct matmul_attr_s {
 } matmul_attr_t;
 
 template <typename T>
-struct MatMulTuples {
+struct MatMulTuple {
+  static constexpr KernelType kernel_type = kMatMul;
   typedef T data_type;
   typedef matmul_attr_t attr_type;
   typedef void (*func_type)(const T*, const T*, T*, const matmul_attr_t*);
 };
 
 template <typename T>
-struct CRFDecodingTuples {
+struct CRFDecodingTuple {
+  static constexpr KernelType kernel_type = kCRFDecoding;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
 };
 
 template <typename T>
-struct LayerNormTuples {
+struct LayerNormTuple {
+  static constexpr KernelType kernel_type = kLayerNorm;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
@@ -236,7 +281,8 @@ struct LayerNormTuples {
 };
 
 template <typename T>
-struct SoftmaxTuples {
+struct SoftmaxTuple {
+  static constexpr KernelType kernel_type = kSoftmax;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, T*, int, int);
@@ -244,7 +290,8 @@ struct SoftmaxTuples {
 
 // nChw16c = nChw16c .* NC
 template <typename T>
-struct NCHW16CMulNCTuples {
+struct NCHW16CMulNCTuple {
+  static constexpr KernelType kernel_type = kNCHW16CMulNC;
   typedef T data_type;
   typedef int attr_type;
   typedef void (*func_type)(const T*, const T*, T*, int, int);
@@ -255,28 +302,29 @@ class Kernel {
  public:
   Kernel() = default;
   virtual ~Kernel() = default;
+  virtual const char* ImplType() const = 0;
   DISABLE_COPY_AND_ASSIGN(Kernel);
 };
 
-template <typename KernelTuples>
+template <typename KernelTuple>
 class KernelMore : public Kernel {
  public:
-  using T = typename KernelTuples::data_type;
-  using Func = typename KernelTuples::func_type;
-  using Attr = typename KernelTuples::attr_type;
+  using T = typename KernelTuple::data_type;
+  using Func = typename KernelTuple::func_type;
+  using Attr = typename KernelTuple::attr_type;
   virtual Func GetFunc() const { return func; }
-  virtual bool UseMe(const Attr& attr) const = 0;
-  virtual const char* ImplType() const = 0;
+  // specify this kernel can be used, means it should not fail if use it.
+  virtual bool CanBeUsed(const Attr& attr) const = 0;
 
  protected:
   Func func{nullptr};
 };
 
-template <typename KernelTuples>
-class ReferKernel : public KernelMore<KernelTuples> {
+template <typename KernelTuple>
+class ReferKernel : public KernelMore<KernelTuple> {
  public:
   // Refer code can always be used
-  bool UseMe(const typename KernelTuples::attr_type& attr) const override {
+  bool CanBeUsed(const typename KernelTuple::attr_type& attr) const override {
     return true;
   }
   const char* ImplType() const override { return "Refer"; }
......
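
The renamed *Tuple traits now carry their own KernelType tag through a static kernel_type member, so call sites no longer pass the kXxx enum value separately. A minimal sketch of the pattern; SeqPoolTuple, kSeqPool, and seq_pool_attr_t come from this diff, while the WhichKernel helper and the demo main are hypothetical:

#include <cstdio>

enum KernelType { kSeqPool = 1 };
typedef struct { int h, w, type; } seq_pool_attr_t;

// The trait bundles everything a dispatcher needs: the enum tag,
// the element type, the attribute type, and the function signature.
template <typename T>
struct SeqPoolTuple {
  static constexpr KernelType kernel_type = kSeqPool;
  typedef T data_type;
  typedef seq_pool_attr_t attr_type;
  typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};

// A dispatcher can now recover the enum from the trait alone.
template <typename KernelTuple>
KernelType WhichKernel() {
  return KernelTuple::kernel_type;
}

int main() {
  std::printf("%d\n", WhichKernel<SeqPoolTuple<float>>());  // prints 1
  return 0;
}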
@@ -13,6 +13,7 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jit/kernel_key.h"
+#include <xxhash.h>  // XXH64: 13.8 GB/s
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -20,71 +21,46 @@ namespace operators {
 namespace jit {
 
 template <>
-size_t JitCodeKey<int>(const int& d) {
+int64_t JitCodeKey<int>(const int& d) {
   return d;
 }
 
 template <>
-size_t JitCodeKey<int64_t>(const int64_t& d) {
+int64_t JitCodeKey<int64_t>(const int64_t& d) {
   return d;
 }
 
-// TODO(TJ): refine and benchmark JitCodeKey generatation
-constexpr int act_type_shift = 3;  // suppot 2^3 act types
-static inline int act_type_convert(KernelType type) {
-  if (type == kVIdentity) {
-    return 0;
-  } else if (type == kVExp) {
-    return 1;
-  } else if (type == kVRelu) {
-    return 2;
-  } else if (type == kVSigmoid) {
-    return 3;
-  } else if (type == kVTanh) {
-    return 4;
-  }
-  PADDLE_THROW("Unsupported act type %d", type);
-  return 0;
-}
-
 template <>
-size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
-  size_t key = attr.d;
-  int gate_key = act_type_convert(attr.act_gate) << 1;
-  int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
-  int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
-  return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
-         attr.use_peephole;
+int64_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
+  return XXH64(&attr, sizeof(gru_attr_t), 0);
 }
 
 template <>
-size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
-  size_t key = attr.d;
-  return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
-         (act_type_convert(attr.act_cand) << act_type_shift);
+int64_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
+  int keys[5] = {
+      attr.d, static_cast<int>(attr.act_gate), static_cast<int>(attr.act_cand),
+      static_cast<int>(attr.act_cell), static_cast<int>(attr.use_peephole)};
+  return XXH64(keys, sizeof(int) * 5, 0);
 }
 
 template <>
-size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
-  size_t key = attr.w;
-  constexpr int pool_type_shift = 3;
-  return (key << pool_type_shift) + static_cast<int>(attr.type);
+int64_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
+  int keys[2] = {attr.w, static_cast<int>(attr.type)};
+  return XXH64(keys, sizeof(int) * 2, 0);
 }
 
 template <>
-size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
-  size_t key = attr.m;
-  constexpr int shift = 21;
-  return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
+int64_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
+  return XXH64(&attr, sizeof(int) * 3, 0);  // m, n, k
 }
 
 template <>
-size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
+int64_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
   return attr.table_width;
 }
 
 template <>
-size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
+int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
   return attr.grad_width;
 }
......
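
The hand-rolled shift-and-pack key generation is replaced by hashing the attribute bytes with xxHash's XXH64, which cannot collide merely because a field outgrew its bit budget. A minimal sketch of the same idea against the real xxhash API; the gru_attr_t layout below is simplified, and only XXH64's signature is taken as given:

#include <xxhash.h>  // link with -lxxhash
#include <cstdint>
#include <cstdio>

// Simplified stand-in for the attribute struct being keyed.
struct gru_attr_t {
  int d;
  int act_gate;
  int act_cand;
};

int64_t JitCodeKey(const gru_attr_t& attr) {
  // Hash the raw bytes of the struct with seed 0, as the diff does.
  // Caveat: this also hashes any padding bytes, so the struct should be
  // tightly packed or zero-initialized for stable keys.
  return XXH64(&attr, sizeof(gru_attr_t), 0);
}

int main() {
  gru_attr_t a = {8, 1, 2};
  gru_attr_t b = {8, 1, 3};
  std::printf("%lld %lld\n", (long long)JitCodeKey(a),
              (long long)JitCodeKey(b));  // distinct with high probability
  return 0;
}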
@@ -46,7 +46,7 @@ struct KernelKey {
 
 // Every JitCode should have a method to get the key from attribution
 template <typename Attr>
-size_t JitCodeKey(const Attr& attr);
+int64_t JitCodeKey(const Attr& attr);
 
 }  // namespace jit
 }  // namespace operators
......
@@ -17,6 +17,7 @@
 #include <memory>  // for unique_ptr
 #include <string>
 #include <unordered_map>
+#include <utility>  // for move
 #include <vector>
 #include "paddle/fluid/operators/jit/gen_base.h"
 #include "paddle/fluid/operators/jit/kernel_base.h"
@@ -30,7 +31,7 @@ namespace jit {
 template <KernelType KT>
 class JitCodePool {
   typedef std::unique_ptr<GenBase> GenBasePtr;
-  typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
+  typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;
 
  public:
   JitCodePool() = default;
@@ -41,9 +42,9 @@ class JitCodePool {
   const JitCodeMap& AllKernels() { return codes_; }
 
-  bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
+  bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
 
-  void Insert(size_t key, GenBasePtr value) {
+  void Insert(int64_t key, GenBasePtr value) {
     codes_.emplace(key, std::move(value));
   }
......
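
With keys widened to int64_t to match the new JitCodeKey, the per-kernel-type code cache is just a map from key to owned code object. A reduced sketch of the pool's lookup/insert contract; GenBase stands in for the generated-code base class, and the cache-or-create flow around it is assumed:

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>  // for std::move

struct GenBase {  // stand-in for the JIT-generated code object
  virtual ~GenBase() = default;
};

class JitCodePool {
  typedef std::unique_ptr<GenBase> GenBasePtr;
  typedef std::unordered_map<int64_t, GenBasePtr> JitCodeMap;

 public:
  bool Has(int64_t key) const { return codes_.find(key) != codes_.end(); }
  void Insert(int64_t key, GenBasePtr value) {
    codes_.emplace(key, std::move(value));  // pool takes ownership
  }

 private:
  JitCodeMap codes_;
};

int main() {
  JitCodePool pool;
  int64_t key = 42;  // in Paddle this comes from JitCodeKey(attr)
  if (!pool.Has(key)) pool.Insert(key, std::make_unique<GenBase>());
  return pool.Has(key) ? 0 : 1;
}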
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
   }
 }
 
-bool CRFDecodingKernel::UseMe(const int& d) const {
+bool CRFDecodingKernel::CanBeUsed(const int& d) const {
 #ifdef __AVX512F__
   constexpr int block = ZMM_FLOAT_BLOCK;
 #else
......
@@ -26,11 +26,11 @@ namespace intrinsic {
 void CRFDecoding(const int seq_len, const float* x, const float* w,
                  float* alpha, int* track, int tag_num);
 
-class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
+class CRFDecodingKernel : public KernelMore<CRFDecodingTuple<float>> {
  public:
   CRFDecodingKernel() { this->func = CRFDecoding; }
-  bool UseMe(
-      const typename CRFDecodingTuples<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename CRFDecodingTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
......
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
   }
 }
 
-bool LayerNormKernel::UseMe(const int& d) const {
+bool LayerNormKernel::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
 }
......
@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
                const float* scale, const float* bias, int height,
                const float epsilon, int right);
 
-class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
+class LayerNormKernel : public KernelMore<LayerNormTuple<float>> {
  public:
   LayerNormKernel() { this->func = LayerNorm; }
-  bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
+  bool CanBeUsed(
+      const typename LayerNormTuple<float>::attr_type&) const override;
   const char* ImplType() const override { return "Intrinsic"; }
 };
......
@@ -23,6 +23,8 @@ namespace jit {
 namespace more {
 namespace mix {
 
+using CPUPlace = platform::CPUPlace;
+
 void VSigmoid(const T* x, T* y, int n) {
   const float min = SIGMOID_THRESHOLD_MIN;
   const float max = SIGMOID_THRESHOLD_MAX;
@@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
     y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
     y[i] = static_cast<T>(0) - y[i];
   }
-  auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
+  auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
   compute(y, y, n);
   for (int i = 0; i < n; ++i) {
     y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
@@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
 
 void VTanh(const T* x, T* y, int n) {
   const T a = 2, b = -1;
-  auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
+  auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
   compute_scal(&a, x, y, n);
   compute_sigmoid(y, y, n);
   compute_scal(&a, y, y, n);
@@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
 }
 
 void Softmax(const T* x, T* y, int n, int bs) {
-  auto compute_hmax =
-      KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_hsum =
-      KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_vscal =
-      KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+  auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
   auto compute_vaddbias =
-      KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
-  auto compute_vexp =
-      KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
+      KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
+  auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
 
   for (int i = 0; i < bs; ++i) {
     T scalar;
@@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
 
 void (*getActFunc(KernelType type, int d))(const T*, T*, int) {  // NOLINT
   if (type == kVSigmoid) {
-    return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVRelu) {
-    return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVTanh) {
-    return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
   } else if (type == kVIdentity) {
-    return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
+    return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
   }
   PADDLE_THROW("Not support type: %s", type);
   return nullptr;
@@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
   const int d = attr->d;
   const int d2 = d * 2;
   const int d3 = d * 3;
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
   auto act_gate_d = getActFunc(attr->act_gate, d);
   auto act_gate_d2 = getActFunc(attr->act_gate, d2);
   auto act_gate_d3 = getActFunc(attr->act_gate, d3);
@@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
   int d = attr->d;
   int d2 = d * 2;
   int d3 = d * 3;
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
-  auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
+  auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
   auto act_gate_d = getActFunc(attr->act_gate, d);
   auto act_cand_d = getActFunc(attr->act_cand, d);
   auto act_cell_d = getActFunc(attr->act_cell, d);
@@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
   int d2 = d * 2;
   auto act_gate = getActFunc(attr->act_gate, d);
   auto act_cand = getActFunc(attr->act_cand, d);
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
   act_gate(gates, gates, d);
   act_cand(gates + d2, gates + d2, d);
   vmul_d(gates, gates + d2, ht, d);
@@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
   T* ht = reinterpret_cast<T*>(step->ht);
   const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
   auto act_gate = getActFunc(attr->act_gate, attr->d);
-  auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
+  auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
   act_gate(gates + attr->d, gates + attr->d, attr->d);
   vmul_d(ht_1, gates + attr->d, ht, attr->d);
 }
@@ -206,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
 }
 
 // TODO(TJ): tuning me
-bool VSigmoidKernel::UseMe(const int& d) const { return true; }
+bool VSigmoidKernel::CanBeUsed(const int& d) const { return true; }
 
-bool VTanhKernel::UseMe(const int& d) const { return true; }
+bool VTanhKernel::CanBeUsed(const int& d) const { return true; }
 
-bool SoftmaxKernel::UseMe(const int& d) const { return true; }
+bool SoftmaxKernel::CanBeUsed(const int& d) const { return true; }
 
-bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMCtHtKernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
+bool LSTMC1H1Kernel::CanBeUsed(const lstm_attr_t& attr) const { return true; }
 
-bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUH1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart1Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
-bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
+bool GRUHtPart2Kernel::CanBeUsed(const gru_attr_t& attr) const { return true; }
 
 }  // namespace mix
 }  // namespace more
@@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
 
 namespace mix = paddle::operators::jit::more::mix;
 
-#define REGISTER_MORE_KERNEL(key, func) \
-  REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
+#define REGISTER_MORE_KERNEL(func) \
+  REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
 
-REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_MORE_KERNEL(kVTanh, VTanh);
-REGISTER_MORE_KERNEL(kSoftmax, Softmax);
-REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
-REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
-REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
-REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
-REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
+REGISTER_MORE_KERNEL(VSigmoid);
+REGISTER_MORE_KERNEL(VTanh);
+REGISTER_MORE_KERNEL(Softmax);
+REGISTER_MORE_KERNEL(LSTMCtHt);
+REGISTER_MORE_KERNEL(LSTMC1H1);
+REGISTER_MORE_KERNEL(GRUH1);
+REGISTER_MORE_KERNEL(GRUHtPart1);
+REGISTER_MORE_KERNEL(GRUHtPart2);
 
 #undef REGISTER_MORE_KERNEL
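
Throughout this file, the old free function Get<kKernelType, Tuple, Place>(attr) becomes KernelFuncs<Tuple, Place>::Cache().At(attr): the kernel type now comes from the tuple itself, and At() hands back a cached function pointer. A stripped-down sketch of that cache shape; the real KernelFuncs also searches JIT-generated and vendor ("more") kernels before the reference one, and the concrete VExpTuple below is a simplified stand-in:

#include <cmath>
#include <cstdio>
#include <unordered_map>

// A concrete tuple for demonstration: y = exp(x) over n floats.
struct VExpTuple {
  typedef int attr_type;
  typedef void (*func_type)(const float*, float*, int);
  // Reference implementation used when nothing better is cached
  // (in Paddle this slot would be filled by JIT/MKL candidates first).
  static void Refer(const float* x, float* y, int n) {
    for (int i = 0; i < n; ++i) y[i] = std::exp(x[i]);
  }
};

template <typename KernelTuple>
class KernelFuncs {
 public:
  static KernelFuncs& Cache() {
    static thread_local KernelFuncs funcs;  // per-thread singleton
    return funcs;
  }
  typename KernelTuple::func_type At(typename KernelTuple::attr_type attr) {
    auto it = funcs_.find(attr);
    if (it == funcs_.end())
      it = funcs_.emplace(attr, &KernelTuple::Refer).first;  // cache miss
    return it->second;
  }

 private:
  std::unordered_map<typename KernelTuple::attr_type,
                     typename KernelTuple::func_type>
      funcs_;
};

int main() {
  const float x[4] = {0.f, 1.f, 2.f, 3.f};
  float y[4];
  auto compute = KernelFuncs<VExpTuple>::Cache().At(4);
  compute(x, y, 4);
  std::printf("%.3f\n", y[1]);  // ~2.718
  return 0;
}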
@@ -34,27 +34,27 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
 void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
 
-#define DECLARE_MORE_KERNEL(name, tuples)                                  \
-  class name##Kernel : public KernelMore<tuples<T>> {                      \
+#define DECLARE_MORE_KERNEL(name)                                          \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                 \
    public:                                                                 \
     name##Kernel() { this->func = name; }                                  \
-    bool UseMe(const typename tuples<T>::attr_type&) const override;       \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
     const char* ImplType() const override { return "Mixed"; }              \
   }
 
 // XYN
-DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
-DECLARE_MORE_KERNEL(VTanh, XYNTuples);
+DECLARE_MORE_KERNEL(VSigmoid);
+DECLARE_MORE_KERNEL(VTanh);
 
 // XRN
-DECLARE_MORE_KERNEL(Softmax, SoftmaxTuples);
+DECLARE_MORE_KERNEL(Softmax);
 
-DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
-DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
+DECLARE_MORE_KERNEL(LSTMCtHt);
+DECLARE_MORE_KERNEL(LSTMC1H1);
 
-DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
-DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
-DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
+DECLARE_MORE_KERNEL(GRUH1);
+DECLARE_MORE_KERNEL(GRUHtPart1);
+DECLARE_MORE_KERNEL(GRUHtPart2);
 
 #undef DECLARE_MORE_KERNEL
......
@@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
 
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
 template <>
-bool VMulKernel<float>::UseMe(const int& d) const {
+bool VMulKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VAddKernel<float>::UseMe(const int& d) const {
+bool VAddKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx) && d > 512;
 }
 
 template <>
-bool VScalKernel<float>::UseMe(const int& d) const {
+bool VScalKernel<float>::CanBeUsed(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
 }
 
 template <>
-bool VExpKernel<float>::UseMe(const int& d) const {
+bool VExpKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VSquareKernel<float>::UseMe(const int& d) const {
+bool VSquareKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VCopyKernel<float>::UseMe(const int& d) const {
+bool VCopyKernel<float>::CanBeUsed(const int& d) const {
   return d > 15;
 }
 
 template <>
-bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
+bool VBroadcastKernel<float>::CanBeUsed(const int64_t& d) const {
   return d > 127;
 }
 
 template <>
-bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
+bool VBroadcastKernel<double>::CanBeUsed(const int64_t& attr) const {
   return true;
 }
 
 template <>
-bool VSigmoidKernel<float>::UseMe(const int& d) const {
+bool VSigmoidKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool VTanhKernel<float>::UseMe(const int& d) const {
+bool VTanhKernel<float>::CanBeUsed(const int& d) const {
   return d > 7;
 }
 
 template <>
-bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<float>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
+bool SeqPoolKernel<double>::CanBeUsed(const seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<float>::CanBeUsed(const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
+bool EmbSeqPoolKernel<double>::CanBeUsed(
+    const emb_seq_pool_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<float>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
+bool SgdKernel<double>::CanBeUsed(const sgd_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<float>::CanBeUsed(const matmul_attr_t& attr) const {
   return platform::MayIUse(platform::avx);
 }
 
 template <>
-bool MatMulKernel<double>::UseMe(const matmul_attr_t& attr) const {
+bool MatMulKernel<double>::CanBeUsed(const matmul_attr_t& attr) const {
   return true;
 }
 
 template <>
-bool SoftmaxKernel<float>::UseMe(const int& d) const {
+bool SoftmaxKernel<float>::CanBeUsed(const int& d) const {
   // tuned on avx2
   return platform::MayIUse(platform::avx) && d < 60;
 }
 
-#define AWALYS_USE_ME_WITH_DOUBLE(func)                  \
-  template <>                                            \
-  bool func##Kernel<double>::UseMe(const int& d) const { \
-    return true;                                         \
+#define AWALYS_USE_ME_WITH_DOUBLE(func)                       \
+  template <>                                                 \
+  bool func##Kernel<double>::CanBeUsed(const int& d) const {  \
+    return true;                                              \
   }
 
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
@@ -250,23 +251,23 @@ AWALYS_USE_ME_WITH_DOUBLE(Softmax);
 
 namespace mkl = paddle::operators::jit::more::mkl;
 
-#define REGISTER_MKL_KERNEL(key, func)                        \
-  REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
-                          mkl::func##Kernel<double>)
+#define REGISTER_MKL_KERNEL(func)                                 \
+  REGISTER_JITKERNEL_MORE(k##func, mkl, mkl::func##Kernel<float>, \
+                          mkl::func##Kernel<double>)
 
-REGISTER_MKL_KERNEL(kMatMul, MatMul);
-REGISTER_MKL_KERNEL(kVMul, VMul);
-REGISTER_MKL_KERNEL(kVAdd, VAdd);
-REGISTER_MKL_KERNEL(kVScal, VScal);
-REGISTER_MKL_KERNEL(kVExp, VExp);
-REGISTER_MKL_KERNEL(kVSquare, VSquare);
-REGISTER_MKL_KERNEL(kVCopy, VCopy);
-REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
-REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_MKL_KERNEL(kVTanh, VTanh);
-REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
-REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
-REGISTER_MKL_KERNEL(kSoftmax, Softmax);
-REGISTER_MKL_KERNEL(kSgd, Sgd);
+REGISTER_MKL_KERNEL(MatMul);
+REGISTER_MKL_KERNEL(VMul);
+REGISTER_MKL_KERNEL(VAdd);
+REGISTER_MKL_KERNEL(VScal);
+REGISTER_MKL_KERNEL(VExp);
+REGISTER_MKL_KERNEL(VSquare);
+REGISTER_MKL_KERNEL(VCopy);
+REGISTER_MKL_KERNEL(VBroadcast);
+REGISTER_MKL_KERNEL(VSigmoid);
+REGISTER_MKL_KERNEL(VTanh);
+REGISTER_MKL_KERNEL(SeqPool);
+REGISTER_MKL_KERNEL(EmbSeqPool);
+REGISTER_MKL_KERNEL(Softmax);
+REGISTER_MKL_KERNEL(Sgd);
 
 #undef REGISTER_MKL_KERNEL
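
Dropping the explicit key argument works because the registration macro now derives the enum name by token pasting: k##func turns a kernel's bare name into its KernelType tag. A small self-contained illustration of the same trick; the REGISTER_DEMO_KERNEL name below is a stand-in, not Paddle's real registry macro, which instantiates registrar objects instead of printing:

#include <cstdio>

enum KernelType { kVMul, kVAdd };

// k##func pastes the enum tag, #func stringizes the name.
#define REGISTER_DEMO_KERNEL(func) \
  static void register_##func() { std::printf("%d %s\n", k##func, #func); }

REGISTER_DEMO_KERNEL(VMul);  // expands to use kVMul and "VMul"
REGISTER_DEMO_KERNEL(VAdd);  // expands to use kVAdd and "VAdd"

int main() {
  register_VMul();
  register_VAdd();
  return 0;
}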
@@ -175,41 +175,38 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
-#define DECLARE_MKL_KERNEL(name, tuples)                                    \
-  template <typename T>                                                     \
-  class name##Kernel : public KernelMore<tuples<T>> {                       \
+#define DECLARE_MKL_KERNEL(name)                                            \
+  template <typename T>                                                     \
+  class name##Kernel : public KernelMore<name##Tuple<T>> {                  \
    public:                                                                  \
     name##Kernel() { this->func = name<T>; }                                \
-    bool UseMe(const typename tuples<T>::attr_type&) const override;        \
+    bool CanBeUsed(const typename name##Tuple<T>::attr_type&) const override; \
     const char* ImplType() const override { return "MKL"; }                 \
   }
 
 // ABCMNK
-DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
+DECLARE_MKL_KERNEL(MatMul);
 
 // XYZN
-DECLARE_MKL_KERNEL(VMul, XYZNTuples);
-DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
+DECLARE_MKL_KERNEL(VMul);
+DECLARE_MKL_KERNEL(VAdd);
 
 // AXYN
-DECLARE_MKL_KERNEL(VScal, AXYNTuples);
+DECLARE_MKL_KERNEL(VScal);
 
 // XYN
-DECLARE_MKL_KERNEL(VExp, XYNTuples);
-DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
-DECLARE_MKL_KERNEL(VTanh, XYNTuples);
-DECLARE_MKL_KERNEL(VSquare, XYNTuples);
-DECLARE_MKL_KERNEL(VCopy, XYNTuples);
-
-DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
-
-DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
-
-DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
-
-DECLARE_MKL_KERNEL(Sgd, SgdTuples);
-
-DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
+DECLARE_MKL_KERNEL(VExp);
+DECLARE_MKL_KERNEL(VSigmoid);
+DECLARE_MKL_KERNEL(VTanh);
+DECLARE_MKL_KERNEL(VSquare);
+DECLARE_MKL_KERNEL(VCopy);
+
+// others
+DECLARE_MKL_KERNEL(SeqPool);
+DECLARE_MKL_KERNEL(EmbSeqPool);
+DECLARE_MKL_KERNEL(Softmax);
+DECLARE_MKL_KERNEL(Sgd);
+DECLARE_MKL_KERNEL(VBroadcast);
 
 #undef DECLARE_MKL_KERNEL
......
@@ -17,51 +17,43 @@
 
 namespace refer = paddle::operators::jit::refer;
 
-#define REGISTER_REFER_KERNEL(key, func)                     \
-  REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>,  \
-                           refer::func##Kernel<double>)
+#define REGISTER_REFER_KERNEL(func)                              \
+  REGISTER_JITKERNEL_REFER(k##func, refer::func##Kernel<float>,  \
+                           refer::func##Kernel<double>)
 
-REGISTER_REFER_KERNEL(kVMul, VMul);
-REGISTER_REFER_KERNEL(kVAdd, VAdd);
-REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
-REGISTER_REFER_KERNEL(kVSub, VSub);
+REGISTER_REFER_KERNEL(VMul);
+REGISTER_REFER_KERNEL(VAdd);
+REGISTER_REFER_KERNEL(VAddRelu);
+REGISTER_REFER_KERNEL(VSub);
 
-REGISTER_REFER_KERNEL(kVScal, VScal);
-REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
+REGISTER_REFER_KERNEL(VScal);
+REGISTER_REFER_KERNEL(VAddBias);
 
-REGISTER_REFER_KERNEL(kVRelu, VRelu);
-REGISTER_REFER_KERNEL(kVCopy, VCopy);
-REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
-REGISTER_REFER_KERNEL(kVSquare, VSquare);
-REGISTER_REFER_KERNEL(kVExp, VExp);
-REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
-REGISTER_REFER_KERNEL(kVTanh, VTanh);
+REGISTER_REFER_KERNEL(VRelu);
+REGISTER_REFER_KERNEL(VCopy);
+REGISTER_REFER_KERNEL(VIdentity);
+REGISTER_REFER_KERNEL(VSquare);
+REGISTER_REFER_KERNEL(VExp);
+REGISTER_REFER_KERNEL(VSigmoid);
+REGISTER_REFER_KERNEL(VTanh);
 
-REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
-REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
+REGISTER_REFER_KERNEL(LSTMCtHt);
+REGISTER_REFER_KERNEL(LSTMC1H1);
 
-REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
-REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
-REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
+REGISTER_REFER_KERNEL(GRUH1);
+REGISTER_REFER_KERNEL(GRUHtPart1);
+REGISTER_REFER_KERNEL(GRUHtPart2);
 
-REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
-REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
-
-REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
-
-REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
-
-REGISTER_REFER_KERNEL(kMatMul, MatMul);
-
-REGISTER_REFER_KERNEL(kHMax, HMax);
-REGISTER_REFER_KERNEL(kHSum, HSum);
-
-REGISTER_REFER_KERNEL(kSoftmax, Softmax);
-
-REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
-
-REGISTER_REFER_KERNEL(kSgd, Sgd);
-
-REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
+REGISTER_REFER_KERNEL(CRFDecoding);
+REGISTER_REFER_KERNEL(LayerNorm);
+REGISTER_REFER_KERNEL(NCHW16CMulNC);
+REGISTER_REFER_KERNEL(SeqPool);
+REGISTER_REFER_KERNEL(MatMul);
+REGISTER_REFER_KERNEL(HMax);
+REGISTER_REFER_KERNEL(HSum);
+REGISTER_REFER_KERNEL(Softmax);
+REGISTER_REFER_KERNEL(EmbSeqPool);
+REGISTER_REFER_KERNEL(Sgd);
+REGISTER_REFER_KERNEL(VBroadcast);
 
 #undef REGISTER_REFER_KERNEL
@@ -490,60 +490,54 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
   }
 }
 
-#define DECLARE_REFER_KERNEL(name, tuples)             \
-  template <typename T>                                \
-  class name##Kernel : public ReferKernel<tuples<T>> { \
-   public:                                             \
-    name##Kernel() { this->func = name<T>; }           \
+#define DECLARE_REFER_KERNEL(name)                          \
+  template <typename T>                                     \
+  class name##Kernel : public ReferKernel<name##Tuple<T>> { \
+   public:                                                  \
+    name##Kernel() { this->func = name<T>; }                \
   }
 
 // const T* x, const T* y, T* z, int n
-DECLARE_REFER_KERNEL(VMul, XYZNTuples);
-DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
-DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
-DECLARE_REFER_KERNEL(VSub, XYZNTuples);
+DECLARE_REFER_KERNEL(VMul);
+DECLARE_REFER_KERNEL(VAdd);
+DECLARE_REFER_KERNEL(VAddRelu);
+DECLARE_REFER_KERNEL(VSub);
 
 // const T* a, const T* x, T* y, int n
-DECLARE_REFER_KERNEL(VScal, AXYNTuples);
-DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
+DECLARE_REFER_KERNEL(VScal);
+DECLARE_REFER_KERNEL(VAddBias);
 
 // const T* x, T* y, int n
-DECLARE_REFER_KERNEL(VRelu, XYNTuples);
-DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
-DECLARE_REFER_KERNEL(VExp, XYNTuples);
-DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
-DECLARE_REFER_KERNEL(VTanh, XYNTuples);
-DECLARE_REFER_KERNEL(VSquare, XYNTuples);
-DECLARE_REFER_KERNEL(VCopy, XYNTuples);
+DECLARE_REFER_KERNEL(VRelu);
+DECLARE_REFER_KERNEL(VIdentity);
+DECLARE_REFER_KERNEL(VExp);
+DECLARE_REFER_KERNEL(VSigmoid);
+DECLARE_REFER_KERNEL(VTanh);
+DECLARE_REFER_KERNEL(VSquare);
+DECLARE_REFER_KERNEL(VCopy);
 
 // lstm_t*, const lstm_attr_t*
-DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
-DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
+DECLARE_REFER_KERNEL(LSTMCtHt);
+DECLARE_REFER_KERNEL(LSTMC1H1);
 
 // gru_t*, const gru_attr_t*
-DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
-DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
-DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
-
-DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
-DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
-
-DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
-
-DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
-
-DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
-
-DECLARE_REFER_KERNEL(HMax, XRNTuples);
-DECLARE_REFER_KERNEL(HSum, XRNTuples);
-
-DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
-
-DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
-
-DECLARE_REFER_KERNEL(Sgd, SgdTuples);
-
-DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
+DECLARE_REFER_KERNEL(GRUH1);
+DECLARE_REFER_KERNEL(GRUHtPart1);
+DECLARE_REFER_KERNEL(GRUHtPart2);
+
+DECLARE_REFER_KERNEL(HMax);
+DECLARE_REFER_KERNEL(HSum);
+
+// others
+DECLARE_REFER_KERNEL(CRFDecoding);
+DECLARE_REFER_KERNEL(LayerNorm);
+DECLARE_REFER_KERNEL(NCHW16CMulNC);
+DECLARE_REFER_KERNEL(SeqPool);
+DECLARE_REFER_KERNEL(MatMul);
+DECLARE_REFER_KERNEL(Softmax);
+DECLARE_REFER_KERNEL(EmbSeqPool);
+DECLARE_REFER_KERNEL(Sgd);
+DECLARE_REFER_KERNEL(VBroadcast);
 
 #undef DECLARE_REFER_KERNEL
......
@@ -17,6 +17,7 @@
 #include <memory>
 #include <tuple>
 #include <type_traits>
+#include <utility>  // for std::move
 #include "paddle/fluid/operators/jit/kernel_base.h"
 #include "paddle/fluid/operators/jit/kernel_pool.h"
 #include "paddle/fluid/platform/place.h"
@@ -49,8 +50,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
   void operator()(KernelType kt) const {
     KernelKey kkey(kt, PlaceType());
-    Pool().Instance().Insert(kkey,
-                             std::move(make_unique<const KERNEL_IMPL_TYPE>()));
+    Pool::Instance().Insert(kkey,
+                            std::move(make_unique<const KERNEL_IMPL_TYPE>()));
     constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
     JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
                               KernelImpls...>
......
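
The switch from Pool().Instance() to Pool::Instance() is a real fix, not cosmetics: the old spelling first default-constructs a throwaway Pool temporary and only then calls the static Instance() through it. A minimal sketch of why the static spelling is the right one, assuming the usual Meyers-singleton shape (the private constructor below is an assumption for illustration; it also makes the broken spelling fail to compile at all):

#include <cstdio>

class Pool {
 public:
  static Pool& Instance() {  // Meyers singleton
    static Pool pool;
    return pool;
  }
  void Insert(int key) { ++size_; (void)key; }
  int size() const { return size_; }

 private:
  Pool() { std::puts("Pool constructed"); }
  int size_ = 0;
};

int main() {
  // Pool().Instance().Insert(1);  // old spelling: builds a temporary Pool
  //                               // first; with a private ctor it does not
  //                               // even compile.
  Pool::Instance().Insert(1);      // registers into the real singleton
  std::printf("%d\n", Pool::Instance().size());  // 1
  return 0;
}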
(This diff has been collapsed.)
@@ -230,8 +230,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(bias->numel(), right);
 
       auto ker =
-          jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
-              right);
+          jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
+              .At(right);
       ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
           scale->data<T>(), bias->data<T>(), static_cast<int>(left),
           static_cast<const float>(epsilon), right);
......
@@ -30,17 +30,16 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
     return;
   }
   if (relu) {
-    auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
-                                    platform::CPUPlace>::Cache()
-                       .At(N);
+    auto compute =
+        jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
+            N);
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
       compute(B, dst, dst, N);
     }
   } else {
-    auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
-                                    platform::CPUPlace>::Cache()
-                       .At(N);
+    auto compute =
+        jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
......
@@ -256,8 +256,8 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
                            static_cast<int>(input.numel() / input.dims()[0]),
                            jit::SeqPoolType::kSum);
       auto seqpool =
-          jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
-              attr);
+          jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache()
+              .At(attr);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         attr.h = static_cast<int>(lod[i + 1] - lod[i]);
         seqpool(src, dst, &attr);
......
@@ -82,8 +82,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     const int kClassDim = 1;
     // 2D data. Batch x C
     auto compute_softmax =
-        jit::KernelFuncs<jit::kSoftmax, jit::SoftmaxTuples<float>,
-                         platform::CPUPlace>::Cache()
+        jit::KernelFuncs<jit::SoftmaxTuple<float>, platform::CPUPlace>::Cache()
             .At(in_dims[kClassDim]);
     compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
   }
......
@@ -48,7 +48,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
         T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
 
         auto sgd =
-            jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+            jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
+                attr);
         sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
       } else if (grad_var->IsType<framework::SelectedRows>()) {
         // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
@@ -82,7 +83,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
 
       auto sgd =
-          jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
+          jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
+              attr);
       sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
......
@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
     // Every inputs are linked now, execute!
     executor.Run(*program, &cur_scope, block->ID(),
-                 false /*create_local_scope*/);
+                 false /*create_local_scope*/, true /*create_vars*/,
+                 std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                 true /*force_disable_gc*/);
 
     // get device context from pool
     platform::DeviceContextPool &pool =
@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
     VLOG(5) << "Recurrent memory linking finished ";
     // Run step block with cur_scope
     executor.Run(*program, &cur_scope, block->ID(),
-                 false /*create_local_scope*/);
+                 false /*create_local_scope*/, true /*create_vars*/,
+                 std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                 true /*force_disable_gc*/);
 
     VLOG(5) << "executor.Run finished ";
......
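
The recurrent operators now pin force_disable_gc to true so that variables in the step scope are not eagerly deleted mid-step by the new while/eager-deletion machinery. A toy sketch of why the extra trailing parameters can be bolted onto Executor::Run without breaking existing call sites; the real signature lives in executor.h, and the defaults below are assumptions inferred from the argument comments in this diff:

#include <cstdio>
#include <string>
#include <vector>

// Stand-in showing how trailing defaulted parameters let old
// two-flag call sites keep compiling while new callers opt out of GC.
void Run(int block_id, bool create_local_scope, bool create_vars = true,
         const std::vector<std::string>& skip_ref_cnt_vars = {},
         bool force_disable_gc = false) {
  std::printf("block %d, gc %s\n", block_id,
              force_disable_gc ? "disabled" : "enabled");
}

int main() {
  Run(0, false);                                          // old-style call
  Run(0, false, true, std::vector<std::string>(), true);  // recurrent op's call
  return 0;
}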
@@ -13,10 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/pybind/imperative.h"
+
+#include <pybind11/chrono.h>
+#include <pybind11/complex.h>
+#include <pybind11/functional.h>
+#include <pybind11/stl.h>
+
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
 
 namespace paddle {
 namespace pybind {
@@ -31,20 +39,20 @@ void BindTracer(pybind11::module* m) {
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
               const imperative::VarBasePtrMap& outputs,
-              framework::BlockDesc* block,
+              framework::AttributeMap attrs_map,
               const platform::CPUPlace expected_place,
               const bool stop_gradient = false) {
-             return self.Trace(op, inputs, outputs, block, expected_place,
+             return self.Trace(op, inputs, outputs, attrs_map, expected_place,
                                stop_gradient);
            })
       .def("trace",
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
               const imperative::VarBasePtrMap& outputs,
-              framework::BlockDesc* block,
+              framework::AttributeMap attrs_map,
               const platform::CUDAPlace expected_place,
               const bool stop_gradient = false) {
-             return self.Trace(op, inputs, outputs, block, expected_place,
+             return self.Trace(op, inputs, outputs, attrs_map, expected_place,
                                stop_gradient);
            })
       .def("py_trace", &imperative::Tracer::PyTrace,
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <Python.h>
+#include <string>
 #include <vector>
 #include "paddle/fluid/imperative/layer.h"
 #include "pybind11/pybind11.h"
@@ -36,6 +37,8 @@ class Layer : public imperative::Layer {
 class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
  public:
   using imperative::OpBase::OpBase;  // Inherit constructors
+
+  PyOpBase(const std::string& name) : OpBase(name) {}
 };
 
 class PyVarBase : public imperative::VarBase {
......
@@ -23,97 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 
-// Cast boost::variant for PyBind.
-// Copy from
-// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
-namespace pybind11 {
-namespace detail {
-
-#if !defined(PYBIND11_HIDDEN)
-#ifdef _WIN32
-#define PYBIND11_HIDDEN __declspec(dllexport)
-#else
-#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
-#endif
-#endif
-
-// Can be replaced by a generic lambda in C++14
-struct PYBIND11_HIDDEN paddle_variant_caster_visitor
-    : public boost::static_visitor<handle> {
-  return_value_policy policy;
-  handle parent;
-
-  paddle_variant_caster_visitor(return_value_policy policy, handle parent)
-      : policy(policy), parent(parent) {}
-
-  template <class T>
-  handle operator()(T const &src) const {
-    return make_caster<T>::cast(src, policy, parent);
-  }
-};
-
-template <class Variant>
-struct paddle_variant_caster;
-
-template <template <class...> class V, class... Ts>
-struct paddle_variant_caster<V<Ts...>> {
-  using Type = V<Ts...>;
-
-  template <typename T>
-  typename std::enable_if<
-      !std::is_same<T, boost::detail::variant::void_>::value, bool>::type
-  try_load(handle src, bool convert) {
-    auto caster = make_caster<T>();
-    if (!load_success_ && caster.load(src, convert)) {
-      load_success_ = true;
-
-      if (std::is_same<T, std::vector<float>>::value) {
-        auto caster_ints = make_caster<std::vector<int64_t>>();
-        if (caster_ints.load(src, convert)) {
-          VLOG(4) << "This value are floats and int64_ts satisfy "
-                     "simultaneously, will set it's type to "
-                     "std::vector<int64_t>";
-          value = cast_op<std::vector<int64_t>>(caster_ints);
-          return true;
-        }
-      }
-
-      value = cast_op<T>(caster);
-      return true;
-    }
-    return false;
-  }
-
-  template <typename T>
-  typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
-                          bool>::type
-  try_load(handle src, bool convert) {
-    return false;
-  }
-
-  bool load(handle src, bool convert) {
-    auto unused = {false, try_load<Ts>(src, convert)...};
-    (void)(unused);
-    return load_success_;
-  }
-
-  static handle cast(Type const &src, return_value_policy policy,
-                     handle parent) {
-    paddle_variant_caster_visitor visitor(policy, parent);
-    return boost::apply_visitor(visitor, src);
-  }
-
-  PYBIND11_TYPE_CASTER(Type, _("Variant"));
-
-  bool load_success_{false};
-};
-
-// Add specialization for concrete variant type
-template <class... Args>
-struct type_caster<boost::variant<Args...>>
-    : paddle_variant_caster<boost::variant<Args...>> {};
-
-}  // namespace detail
-}  // namespace pybind11
+#include "paddle/fluid/pybind/pybind_boost_headers.h"
 
 namespace paddle {
 namespace pybind {
......
@@ -149,8 +149,14 @@ PYBIND11_MODULE(core, m) {
         []() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); });
 
   py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
-      // .def(py::init<>())
-      .def(py::init<bool>(), py::arg("stop_gradient") = false)
+      .def(
+          py::init<const std::string &, paddle::framework::proto::VarType::Type,
+                   const std::vector<int64_t>, const paddle::platform::CPUPlace,
+                   bool, bool>())
+      .def(
+          py::init<const std::string &, paddle::framework::proto::VarType::Type,
+                   const std::vector<int64_t>,
+                   const paddle::platform::CUDAPlace, bool, bool>())
       .def("_run_backward",
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
@@ -177,51 +183,21 @@ PYBIND11_MODULE(core, m) {
            py::return_value_policy::take_ownership)
       .def("value", [](const imperative::VarBase &self) { return self.var_; },
            py::return_value_policy::reference)
-      .def_property("name",
-                    [](const imperative::VarBase &self) { return self.name_; },
-                    [](imperative::VarBase &self, const std::string &name) {
-                      self.name_ = name;
-                    })
-      .def_property("block",
-                    [](const imperative::VarBase &self) { return self.block_; },
-                    [](imperative::VarBase &self, framework::BlockDesc *block) {
-                      self.block_ = block;
-                    },
-                    py::return_value_policy::reference)
-      .def_property(
-          "persistable",
-          [](const imperative::VarBase &self) { return self.persistable_; },
-          [](imperative::VarBase &self, const bool persistable) {
-            self.persistable_ = persistable;
-          })
-      .def_property(
-          "desc",
-          [](const imperative::VarBase &self) { return self.var_desc_; },
-          [](imperative::VarBase &self, framework::VarDesc *var_desc) {
-            self.var_desc_ = var_desc;
-          },
-          py::return_value_policy::reference)
-      .def_property(
-          "stop_gradient",
-          [](const imperative::VarBase &self) { return self.IsStopGradient(); },
-          [](imperative::VarBase &self, bool stop_gradient) {
-            self.SetStopGradient(stop_gradient);
-          });
+      .def_property("name", &imperative::VarBase::Name,
+                    &imperative::VarBase::SetName)
+      .def_property_readonly("shape", &imperative::VarBase::Shape)
+      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property("persistable", &imperative::VarBase::IsPersistable,
+                    &imperative::VarBase::SetPersistable)
+      .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
+                    &imperative::VarBase::SetStopGradient);
 
   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
-      .def(py::init<>())
+      .def(py::init<const std::string &>())
       .def("register_backward_hooks",
            [](imperative::OpBase &self, const py::object &callable) {
              self.RegisterBackwardHooks(callable);
            })
-      .def_property(
-          "desc", [](const imperative::OpBase &self) { return self.op_desc_; },
-          [](imperative::OpBase &self, framework::OpDesc *op_desc) {
-            if (op_desc) {
-              self.op_desc_ = op_desc;
-            }
-          },
-          py::return_value_policy::reference)
       .def_property("_trace_id",
                     [](const imperative::OpBase &self) {
                       pybind11::gil_scoped_release release;
@@ -260,7 +236,17 @@ PYBIND11_MODULE(core, m) {
           "apply",
           [](int func_id, const std::vector<imperative::VarBase *> &inputs)
               -> std::vector<imperative::VarBase *> {
-            return imperative::PyLayer::Apply(func_id, inputs);
+            auto ret_vars = imperative::PyLayer::Apply(func_id, inputs);
+            std::vector<imperative::VarBase *> outputs;
+            outputs.reserve(ret_vars.size());
+            for (size_t i = 0U; i != ret_vars.size(); ++i) {
+              framework::Variable *v = ret_vars[i];
+              // TODO(minqiyang): use unique_name generator to set a name
+              outputs.emplace_back(
+                  new imperative::VarBase("", v, nullptr, true));
+            }
+            return outputs;
           },
           py::return_value_policy::take_ownership)
       .def_static("register_func",
@@ -876,9 +862,11 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<const platform::Place &>())
       .def("close", &Executor::Close)
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
-                     int block_id, bool create_local_scope, bool create_vars) {
+                     int block_id, bool create_local_scope, bool create_vars,
+                     const std::vector<std::string> &fetch_vars) {
        pybind11::gil_scoped_release release;
-        self.Run(prog, scope, block_id, create_local_scope, create_vars);
+        self.Run(prog, scope, block_id, create_local_scope, create_vars,
+                 fetch_vars);
       });
 
   m.def("init_gflags", framework::InitGflags);
......
(This diff has been collapsed.)
@@ -128,11 +128,11 @@ def __bootstrap__():
         'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_ngraph',
         'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
         'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb',
-        'fast_eager_deletion_mode', 'allocator_strategy',
-        'reader_queue_speed_test_mode', 'print_sub_graph_dir',
-        'pe_profile_fname', 'warpctc_dir', 'inner_op_parallelism',
-        'enable_parallel_graph', 'multiple_of_cupti_buffer_size',
-        'enable_subgraph_optimize'
+        'fast_eager_deletion_mode', 'memory_fraction_of_eager_deletion',
+        'allocator_strategy', 'reader_queue_speed_test_mode',
+        'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
+        'inner_op_parallelism', 'enable_parallel_graph',
+        'multiple_of_cupti_buffer_size', 'enable_subgraph_optimize'
     ]
     if 'Darwin' not in sysstr:
         read_env_flags.append('use_pinned_memory')
......
@@ -590,7 +590,7 @@ class Executor(object):
                 fetch_var_name=fetch_var_name)
 
         self._feed_data(program, feed, feed_var_name, scope)
-        exe.run(program.desc, scope, 0, True, True)
+        exe.run(program.desc, scope, 0, True, True, fetch_var_name)
         outs = self._fetch_data(fetch_list, fetch_var_name, scope)
         if return_numpy:
             outs = as_numpy(outs)
......
(This diff has been collapsed.)
@@ -258,7 +258,7 @@ class PyLayer(core.PyLayer):
             cls.backward_id = core.PyLayer.num_funcs() + 1
             PyLayer.register_func(cls.backward_id, cls._do_backward)
 
-        iop = core.OpBase()
+        iop = core.OpBase(cls.__class__.__name__ + str(cls.forward_id))
         iop.forward_id = cls.forward_id
         iop.backward_id = cls.backward_id
         block.ops.append(iop)
......
@@ -36,14 +36,21 @@ class Tracer(core.Tracer):
         super(Tracer, self).__init__(block)
 
         self._ops = defaultdict()
+        self._vars = defaultdict()
         self._trace_id = 0
 
+    def trace_var(self, name, var):
+        self._vars[name] = var
+
+    def all_parameters(self):
+        return list((item for name, item in six.iteritems(self._vars)
+                     if isinstance(item, framework.Parameter)))
+
     def trace_op(self, op, stop_gradient=False):
         # record op's trace id
         op.iop._trace_id = self._trace_id
 
-        # trace op and save it
-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
+        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)
......
@@ -10704,8 +10704,9 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
 
     similarity_matrix = matmul(
         anchor, positive, transpose_x=False, transpose_y=True)
-    softmax_value = softmax(similarity_matrix)
-    cross_entropy = -1 * reduce_sum(labels * log(softmax_value), 0)
+    softmax_ce = softmax_with_cross_entropy(
+        logits=similarity_matrix, label=labels, soft_label=True)
+    cross_entropy = reduce_sum(labels * softmax_ce, 0)
     celoss = reduce_mean(cross_entropy)
 
     return l2loss + celoss
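
Both formulations target the same soft-label cross entropy, but the fused op is numerically stabler because it never materializes log(softmax(x)) as two separate kernels. Assuming softmax_with_cross_entropy with soft_label=True computes the standard per-row loss, with logits $s$ and soft labels $l$:

$$
\mathcal{L} \;=\; -\sum_{j} l_j \log \mathrm{softmax}(s)_j
            \;=\; -\sum_{j} l_j \Bigl( s_j - \log \sum_{k} e^{s_k} \Bigr),
$$

and the fused kernel evaluates the right-hand side with the usual max-shift on $s$, avoiding the overflow and underflow that a separate softmax followed by log can produce.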
@@ -377,17 +377,16 @@ class Optimizer(object):
             and list of (param, grad) Variables pair for optimization.
         """
         self._dtype = loss.dtype
-        program = loss.block.program
         optimize_ops = []
         if framework._in_imperative_mode():
             if parameter_list is not None:
                 parameters = parameter_list
             else:
-                parameters = program.global_block().all_parameters()
+                parameters = framework._imperative_tracer().all_parameters()
 
             params_grads = []
             for param in parameters:
-                if param.stop_gradient or not param.trainable:
+                if not param.trainable:
                     continue
                 # create gradient variable
                 grad_var = Variable(
@@ -396,9 +395,11 @@ class Optimizer(object):
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
                 params_grads.append((param, grad_var))
-            with program_guard(program, startup_program):
+            with program_guard(framework.default_main_program(),
+                               framework.default_startup_program()):
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
+            program = loss.block.program
             with program_guard(program, startup_program):
                 params_grads = self.backward(loss, startup_program,
                                              parameter_list, no_grad_set)
......