Commit 5eb87506 authored by Xin Pan

add per kernel config and remove const_cast.

test=develop
Parent a7e7d952
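What this change does, in short: each OperatorWithKernel now owns a kernel_configs_map_ keyed by OpKernelType, GetExpectedKernelType registers the cuDNN AlgorithmsCache objects there, RunImpl hands the matching std::vector<KernelConfig> to the ExecutionContext, and the kernel fetches its cache with ctx.GetKernelConfig<T>(idx) instead of creating a Variable in the Scope through a const_cast. The sketch below is not part of the commit; it is a standalone, simplified stand-in (the name SimpleAlgorithmsCache and the use of int as the algorithm type are illustrative only) that shows the hash-combine lookup the cache relies on:

// Standalone sketch, not Paddle code: a trimmed-down version of the
// AlgorithmsCache added to paddle/fluid/framework/operator.h in this commit,
// with int standing in for a cudnnConvolutionFwdAlgo_t.
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

template <typename TAlgorithm>
class SimpleAlgorithmsCache {
 public:
  TAlgorithm GetAlgorithm(const std::vector<int64_t>& dims,
                          const std::vector<int>& strides,
                          std::function<TAlgorithm()> gen_func) {
    // boost-style hash combine over every shape parameter.
    std::hash<int64_t> hash_fn;
    int64_t seed = 0;
    for (auto d : dims) {
      seed ^= hash_fn(d) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    for (auto s : strides) {
      seed ^= hash_fn(s) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
    }
    auto it = hash_.find(seed);
    if (it != hash_.end()) return it->second;  // cache hit: reuse the algo
    TAlgorithm algo = gen_func();              // cache miss: run the search
    hash_[seed] = algo;
    return algo;
  }

 private:
  std::unordered_map<int64_t, TAlgorithm> hash_;
};

int main() {
  SimpleAlgorithmsCache<int> cache;
  int searches = 0;
  auto search = [&]() { ++searches; return 42; };  // stands in for the cuDNN search
  cache.GetAlgorithm({8, 3, 224, 224}, {1, 1}, search);  // miss -> searches == 1
  cache.GetAlgorithm({8, 3, 224, 224}, {1, 1}, search);  // hit  -> searches == 1
  cache.GetAlgorithm({8, 3, 112, 112}, {1, 1}, search);  // miss -> searches == 2
  std::cout << "exhaustive searches run: " << searches << "\n";
  return 0;
}

In the conv kernels the gen_func lambda runs the cuDNN exhaustive search, so a cache hit for a previously seen shape skips that search entirely.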
@@ -921,7 +921,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
  OpKernelMap& kernels = kernels_iter->second;
  auto expected_kernel_key = this->GetExpectedKernelType(
-      ExecutionContext(*this, scope, *dev_ctx, ctx));
+      ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
  auto kernel_iter = kernels.find(expected_kernel_key);
@@ -940,6 +940,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                 KernelTypeToString(expected_kernel_key));
  }
+  auto config_iter = kernel_configs_map_.find(expected_kernel_key);
+  std::vector<KernelConfig>* kernel_configs = nullptr;
+  if (config_iter != kernel_configs_map_.end()) {
+    kernel_configs = &(config_iter->second);
+  }
  // do data transformScope &transfer_scope;
  std::vector<std::string> transfered_inplace_vars;
  auto* transfer_scope =
@@ -957,7 +963,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
  this->InferShape(&infer_shape_ctx);
  // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
  // not Scope. Imperative mode only pass inputs and get outputs.
-  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
+  kernel_iter->second(
+      ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
  if (!transfered_inplace_vars.empty()) {
    // there is inplace variable has been transfered.
...
@@ -184,12 +184,125 @@ class OperatorBase {
                   const platform::Place& place) const = 0;
};
+template <typename TAlgorithm>
+class AlgorithmsCache {
+ public:
+  AlgorithmsCache() : search_times_(0) { hash_.clear(); }
+  // Caches the best algorithm for a given
+  // combination of tensor dimensions & compute data type.
+  TAlgorithm GetAlgorithm(
+      const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
+      const std::vector<int>& strides, const std::vector<int>& paddings,
+      const std::vector<int>& dilations,
+      int algorithmFlags,  // can set for different data type
+      std::function<TAlgorithm()> gen_func);
+  TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
+                          std::function<TAlgorithm()> gen_func);
+ private:
+  std::unordered_map<int64_t, TAlgorithm> hash_;
+  std::mutex mutex_;
+  int search_times_;
+};
+template <typename TAlgorithm>
+TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
+    const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
+    const std::vector<int>& strides, const std::vector<int>& paddings,
+    const std::vector<int>& dilations, int algorithmFlags,
+    std::function<TAlgorithm()> gen_func) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  int64_t seed = 0;
+  // Hash all of the inputs, use to try and look up a previously
+  // discovered algorithm, or fall back to generating a new one.
+  std::hash<int64_t> hashFn;
+  // do hash like boost
+  // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
+  for (const auto num : dims1) {
+    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  }
+  for (const auto num : dims2) {
+    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
+  }
+  for (const auto num : strides) {
+    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
+            (seed >> 2) + 2;
+  }
+  for (const auto num : paddings) {
+    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
+            (seed >> 2) + 3;
+  }
+  for (const auto num : dilations) {
+    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
+            (seed >> 2) + 4;
+  }
+  seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
+          (seed << 6) + (seed >> 2) + 5;
+  if (seed == 0) return gen_func();
+  if (hash_.find(seed) == hash_.end()) {
+    TAlgorithm value = gen_func();
+    hash_[seed] = value;
+  }
+  return hash_[seed];
+}
+template <typename TAlgorithm>
+TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
+    int64_t area, int search_times, int algorithmFlags,
+    std::function<TAlgorithm()> gen_func) {
+  if (hash_.find(area) != hash_.end()) {
+    return hash_[area];
+  }
+  if (search_times_ < search_times) {
+    auto algo = gen_func();
+    hash_[area] = algo;
+    ++search_times_;
+    return algo;
+  }
+  TAlgorithm algo;
+  int64_t min = static_cast<uint64_t>(INT_MAX);
+  for (const auto& m : hash_) {
+    if (m.first < min) {
+      min = m.first;
+      algo = m.second;
+    }
+  }
+  return algo;
+}
+#ifdef PADDLE_WITH_CUDA
+using KernelConfig = boost::variant<
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
+    std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
+#else
+using KernelConfig = boost::variant<boost::blank>;
+#endif
+using OpKernelConfigsMap =
+    std::unordered_map<OpKernelType, std::vector<KernelConfig>,
+                       OpKernelType::Hash>;
class ExecutionContext {
 public:
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context,
-                   const RuntimeContext& ctx)
-      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
+                   const RuntimeContext& ctx,
+                   std::vector<KernelConfig>* configs)
+      : op_(op),
+        scope_(scope),
+        device_context_(device_context),
+        ctx_(ctx),
+        kernel_configs_(configs) {}
  const OperatorBase& op() const { return op_; }
@@ -398,11 +511,20 @@ class ExecutionContext {
    return temp_tensor;
  }
+  template <typename T>
+  T& GetKernelConfig(int idx) const {
+    PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
+                   "%s selected kernel doesn't have kernel config %lu <= %d",
+                   op_.Type().c_str(), kernel_configs_->size(), idx);
+    return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
+  }
 private:
  const OperatorBase& op_;
  const Scope& scope_;
  const platform::DeviceContext& device_context_;
  const RuntimeContext& ctx_;
+  mutable std::vector<KernelConfig>* kernel_configs_;
};
template <>
@@ -508,6 +630,9 @@ class OperatorWithKernel : public OperatorBase {
  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
                               const Scope& exec_scope) const;
+ protected:
+  mutable OpKernelConfigsMap kernel_configs_map_;
};
extern bool OpSupportGPU(const std::string& op_type);
...
@@ -50,8 +50,6 @@ class Scope;
}  // namespace framework
namespace operators {
-template <typename T>
-class AlgorithmsCache;
class CudnnRNNCache;
@@ -144,9 +142,6 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#ifndef _WIN32
    ncclUniqueId, platform::Communicator,
#endif
-    operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
-    operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
-    operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
    operators::CudnnRNNCache,
#endif
    int, float>;
...
@@ -249,7 +249,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
-    p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+    p.func(
+        framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
  }
}
...
@@ -64,8 +64,9 @@ class PreparedOp {
    framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
-    auto expected_kernel_key = op.GetExpectedKernelType(
-        framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
+    auto expected_kernel_key =
+        op.GetExpectedKernelType(framework::ExecutionContext(
+            op, framework::Scope(), *dev_ctx, ctx, nullptr));
    VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
    auto kernel_iter = kernels.find(expected_kernel_key);
...
@@ -139,7 +139,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
  PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
  prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
  prepared_op.func(framework::ExecutionContext(
-      prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
+      prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx, nullptr));
  if (!stop_gradient) {
    std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
...
@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
    auto& dev_ctx = *pool.Get(dev_place);
    framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
-    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
+    framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);
    const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
    const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
...
@@ -42,6 +42,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
+using framework::AlgorithmsCache;
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
@@ -169,18 +170,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                workspace_size_limit, &algo));
        VLOG(3) << "cuDNN forward algo " << algo;
      } else if (exhaustive_search && (!half_float)) {
-        AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
-        if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
-          algo_cache =
-              ctx.scope()
-                  .FindVar(kCUDNNFwdAlgoCache)
-                  ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
-        } else {
-          algo_cache =
-              const_cast<framework::Scope&>(ctx.scope())
-                  .Var(kCUDNNFwdAlgoCache)
-                  ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
-        }
+        AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
+            ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
        cudnn_workspace =
            ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
                framework::make_ddim(
@@ -188,7 +179,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                dev_ctx);
        cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
-        algo = algo_cache->GetAlgorithm(
+        algo = algo_cache.GetAlgorithm(
            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
              int returned_algo_count;
              std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
@@ -382,22 +373,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      if (exhaustive_search) {
-        AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* data_algo_cache;
-        if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) {
-          data_algo_cache =
-              ctx.scope()
-                  .FindVar(kCUDNNBwdDataAlgoCache)
-                  ->GetMutable<
-                      AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
-        } else {
-          data_algo_cache =
-              const_cast<framework::Scope&>(ctx.scope())
-                  .Var(kCUDNNBwdDataAlgoCache)
-                  ->GetMutable<
-                      AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
-        }
-        data_algo = data_algo_cache->GetAlgorithm(
+        AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>& data_algo_cache =
+            ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(
+                0);
+        data_algo = data_algo_cache.GetAlgorithm(
            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
              int returned_algo_count;
              std::array<cudnnConvolutionBwdDataAlgoPerf_t,
@@ -448,22 +428,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
    if (filter_grad) {
      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
      if (exhaustive_search) {
-        AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* f_algo_cache;
-        if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) {
-          f_algo_cache =
-              ctx.scope()
-                  .FindVar(kCUDNNBwdFilterAlgoCache)
-                  ->GetMutable<
-                      AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
-        } else {
-          f_algo_cache =
-              const_cast<framework::Scope&>(ctx.scope())
-                  .Var(kCUDNNBwdFilterAlgoCache)
-                  ->GetMutable<
-                      AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
-        }
-        filter_algo = f_algo_cache->GetAlgorithm(
+        AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>& f_algo_cache =
+            ctx.GetKernelConfig<
+                AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(1);
+        filter_algo = f_algo_cache.GetAlgorithm(
            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
              int returned_algo_count;
              std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <functional>
#include <unordered_map>
#include <vector>
+#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
@@ -46,100 +47,5 @@ static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
-template <typename TAlgorithm>
-class AlgorithmsCache {
- public:
-  AlgorithmsCache() : search_times_(0) { hash_.clear(); }
-  // Caches the best algorithm for a given
-  // combination of tensor dimensions & compute data type.
-  TAlgorithm GetAlgorithm(
-      const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
-      const std::vector<int>& strides, const std::vector<int>& paddings,
-      const std::vector<int>& dilations,
-      int algorithmFlags,  // can set for different data type
-      std::function<TAlgorithm()> gen_func);
-  TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
-                          std::function<TAlgorithm()> gen_func);
- private:
-  std::unordered_map<int64_t, TAlgorithm> hash_;
-  std::mutex mutex_;
-  int search_times_;
-};
-template <typename TAlgorithm>
-TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
-    const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
-    const std::vector<int>& strides, const std::vector<int>& paddings,
-    const std::vector<int>& dilations, int algorithmFlags,
-    std::function<TAlgorithm()> gen_func) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  int64_t seed = 0;
-  // Hash all of the inputs, use to try and look up a previously
-  // discovered algorithm, or fall back to generating a new one.
-  std::hash<int64_t> hashFn;
-  // do hash like boost
-  // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
-  for (const auto num : dims1) {
-    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-  }
-  for (const auto num : dims2) {
-    seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
-  }
-  for (const auto num : strides) {
-    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
-            (seed >> 2) + 2;
-  }
-  for (const auto num : paddings) {
-    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
-            (seed >> 2) + 3;
-  }
-  for (const auto num : dilations) {
-    seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
-            (seed >> 2) + 4;
-  }
-  seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
-          (seed << 6) + (seed >> 2) + 5;
-  if (seed == 0) return gen_func();
-  if (hash_.find(seed) == hash_.end()) {
-    TAlgorithm value = gen_func();
-    hash_[seed] = value;
-  }
-  return hash_[seed];
-}
-template <typename TAlgorithm>
-TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
-    int64_t area, int search_times, int algorithmFlags,
-    std::function<TAlgorithm()> gen_func) {
-  if (hash_.find(area) != hash_.end()) {
-    return hash_[area];
-  }
-  if (search_times_ < search_times) {
-    auto algo = gen_func();
-    hash_[area] = algo;
-    ++search_times_;
-    return algo;
-  }
-  TAlgorithm algo;
-  int64_t min = static_cast<uint64_t>(INT_MAX);
-  for (const auto& m : hash_) {
-    if (m.first < min) {
-      min = m.first;
-      algo = m.second;
-    }
-  }
-  return algo;
-}
}  // namespace operators
}  // namespace paddle
@@ -30,6 +30,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
+using framework::AlgorithmsCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
@@ -139,38 +141,23 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
      }
      return fwd_perf_stat[0].algo;
    };
-    AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
+    AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
+        ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
    int search_times = ctx.Attr<int>("search_times");
    search_times = std::max(
        static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
-    // TODO(dangqingqing): Unify this if-else.
    if (search_times > 0) {
      // The searched algo will be cached by `search_times` times for
      // different input dimension. For other dimensions, select the algo
      // of closest area.
-      auto var_name = ctx.Inputs("AlgoCache")[0];
-      algo_cache =
-          ctx.scope()
-              .FindVar(var_name)
-              ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
-      algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
-                                      search_func);
+      algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
+                                     search_func);
    } else {
      // Cache searched algo in Var(kCUDNNFwdAlgoCache).
      // all conv ops use the same kCUDNNFwdAlgoCache variable.
-      if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
-        algo_cache =
-            ctx.scope()
-                .FindVar(kCUDNNFwdAlgoCache)
-                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
-      } else {
-        // TODO(qingqing) remove const_cast
-        algo_cache =
-            const_cast<framework::Scope*>(ctx.scope().parent())
-                ->Var(kCUDNNFwdAlgoCache)
-                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
-      }
-      algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings,
-                                      dilations, 0, search_func);
+      algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
+                                     dilations, 0, search_func);
    }
    VLOG(3) << "choose algo " << algo;
  }
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_MKLDNN
@@ -109,8 +110,20 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
                   "float16 can only be used when CUDNN is used");
  }
-  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
-                                 library, customized_type_value);
+  auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                      library, customized_type_value);
+#ifdef PADDLE_WITH_CUDA
+  std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
+  // TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
+  // to false. It should be fixed and then here should only create if library
+  // is kCUDNN.
+  if (configs.empty()) {
+    std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
+        new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
+    configs.push_back(p);
+  }
+#endif
+  return type;
}
void Conv2DOpMaker::Make() {
@@ -410,9 +423,25 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
  }
#endif
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_,
-                                 customized_type_value);
+  auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                      ctx.GetPlace(), layout_, library_,
+                                      customized_type_value);
+#ifdef PADDLE_WITH_CUDA
+  if (library_ == framework::LibraryType::kCUDNN) {
+    std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
+    if (configs.empty()) {
+      std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
+          p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
+      configs.push_back(p);
+      std::shared_ptr<
+          framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
+          p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
+      configs.push_back(p2);
+    }
+  }
+#endif
+  return type;
}
class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -141,7 +141,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx =
      static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
-  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
+  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
  int numel = memory_size / sizeof(float);
  framework::Tensor tensor =
@@ -156,7 +156,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx =
      static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
-  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
+  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
  int numel = memory_size / sizeof(float);
  framework::Tensor tensor =
      ctx.AllocateTmpTensor<float, platform::CUDADeviceContext>(
@@ -179,7 +179,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx =
      static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
-  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
+  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
  int numel = memory_size / sizeof(float);
  framework::Tensor out_side_tensor;
@@ -200,7 +200,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx =
      static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
-  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
+  framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
  size_t memory_size = 500;
  int numel = memory_size / sizeof(float);
...
@@ -723,7 +723,6 @@ class Operator(object):
                self._update_desc_attr(attr_name, attr_val)
        self.desc.check_attrs()
        if self._has_kernel(type):
            self.desc.infer_var_type(self.block.desc)
            self.desc.infer_shape(self.block.desc)
...