未验证 提交 44e7fcdd 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15844 from panyx0718/infer

add per kernel config and remove const_cast.
......@@ -904,6 +904,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx);
}
std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
const OpKernelType& key) const {
auto config_iter = kernel_configs_map_.find(key);
std::vector<KernelConfig>* kernel_configs = nullptr;
if (config_iter != kernel_configs_map_.end()) {
kernel_configs = &(config_iter->second);
}
return kernel_configs;
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RuntimeContext ctx(Inputs(), Outputs(), scope);
......@@ -921,7 +931,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -940,6 +950,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
KernelTypeToString(expected_kernel_key));
}
std::vector<KernelConfig>* kernel_configs =
GetKernelConfig(expected_kernel_key);
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope =
......@@ -957,7 +970,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
this->InferShape(&infer_shape_ctx);
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
kernel_iter->second(
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -184,12 +185,30 @@ class OperatorBase {
const platform::Place& place) const = 0;
};
#ifdef PADDLE_WITH_CUDA
using KernelConfig = boost::variant<
std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
#else
using KernelConfig = boost::variant<boost::blank>;
#endif
using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;
class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context,
const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const RuntimeContext& ctx,
std::vector<KernelConfig>* configs)
: op_(op),
scope_(scope),
device_context_(device_context),
ctx_(ctx),
kernel_configs_(configs) {}
const OperatorBase& op() const { return op_; }
......@@ -398,11 +417,20 @@ class ExecutionContext {
return temp_tensor;
}
template <typename T>
T& GetKernelConfig(int idx) const {
PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
}
private:
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
mutable std::vector<KernelConfig>* kernel_configs_;
};
template <>
......@@ -483,6 +511,8 @@ class OperatorWithKernel : public OperatorBase {
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
protected:
virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
......@@ -508,6 +538,9 @@ class OperatorWithKernel : public OperatorBase {
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
const Scope& exec_scope) const;
protected:
mutable OpKernelConfigsMap kernel_configs_map_;
};
extern bool OpSupportGPU(const std::string& op_type);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace framework {
// Not thread-safe. Should be owned per-kernel.
template <typename TAlgorithm>
class AlgorithmsCache {
public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func);
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
int search_times_;
};
template <typename TAlgorithm>
TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously
// discovered algorithm, or fall back to generating a new one.
std::hash<int64_t> hashFn;
// do hash like boost
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for (const auto num : dims1) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (const auto num : dims2) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
}
for (const auto num : strides) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 2;
}
for (const auto num : paddings) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 3;
}
for (const auto num : dilations) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 4;
}
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;
if (seed == 0) return gen_func();
if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return hash_[seed];
}
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
if (hash_.find(area) != hash_.end()) {
return hash_[area];
}
if (search_times_ < search_times) {
auto algo = gen_func();
hash_[area] = algo;
++search_times_;
return algo;
}
TAlgorithm algo;
int64_t min = static_cast<uint64_t>(INT_MAX);
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
}
return algo;
}
} // namespace framework
} // namespace paddle
......@@ -50,8 +50,6 @@ class Scope;
} // namespace framework
namespace operators {
template <typename T>
class AlgorithmsCache;
class CudnnRNNCache;
......@@ -144,9 +142,6 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
#endif
operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
int, float>;
......
......@@ -249,7 +249,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
p.func(
framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
}
}
......
......@@ -44,8 +44,13 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx)
: op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs)
: op(op),
ctx(ctx),
func(func),
dev_ctx(dev_ctx),
kernel_configs(kernel_configs) {}
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel& op,
......@@ -64,8 +69,9 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = op.GetExpectedKernelType(
framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
auto expected_kernel_key =
op.GetExpectedKernelType(framework::ExecutionContext(
op, framework::Scope(), *dev_ctx, ctx, nullptr));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -83,7 +89,9 @@ class PreparedOp {
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key));
}
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
}
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx; }
......@@ -92,6 +100,7 @@ class PreparedOp {
const framework::RuntimeContext& ctx;
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
std::vector<framework::KernelConfig>* kernel_configs;
};
class OpBase;
......
......@@ -138,8 +138,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->place_ = GetExpectedPlace(expected_place, inputs);
PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
prepared_op.func(framework::ExecutionContext(
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
prepared_op.func(
framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
prepared_op.ctx, prepared_op.kernel_configs));
if (!stop_gradient) {
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
......
......@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
auto& dev_ctx = *pool.Get(dev_place);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
......@@ -42,6 +42,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
using framework::AlgorithmsCache;
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
......@@ -169,18 +170,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
......@@ -188,7 +179,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
algo = algo_cache->GetAlgorithm(
algo = algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
......@@ -382,22 +373,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* data_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) {
data_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
} else {
data_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
}
data_algo = data_algo_cache->GetAlgorithm(
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>& data_algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(
0);
data_algo = data_algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
......@@ -448,22 +428,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* f_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) {
f_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
} else {
f_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
}
filter_algo = f_algo_cache->GetAlgorithm(
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>& f_algo_cache =
ctx.GetKernelConfig<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(1);
filter_algo = f_algo_cache.GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <functional>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
......@@ -46,100 +47,5 @@ static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
template <typename TAlgorithm>
class AlgorithmsCache {
public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func);
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
std::mutex mutex_;
int search_times_;
};
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
std::lock_guard<std::mutex> lock(mutex_);
int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously
// discovered algorithm, or fall back to generating a new one.
std::hash<int64_t> hashFn;
// do hash like boost
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for (const auto num : dims1) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (const auto num : dims2) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
}
for (const auto num : strides) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 2;
}
for (const auto num : paddings) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 3;
}
for (const auto num : dilations) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 4;
}
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;
if (seed == 0) return gen_func();
if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return hash_[seed];
}
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
if (hash_.find(area) != hash_.end()) {
return hash_[area];
}
if (search_times_ < search_times) {
auto algo = gen_func();
hash_[area] = algo;
++search_times_;
return algo;
}
TAlgorithm algo;
int64_t min = static_cast<uint64_t>(INT_MAX);
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
}
return algo;
}
} // namespace operators
} // namespace paddle
......@@ -30,6 +30,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
......@@ -139,38 +141,21 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
}
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
int search_times = ctx.Attr<int>("search_times");
search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
// TODO(dangqingqing): Unify this if-else.
if (search_times > 0) {
// The searched algo will be cached by `search_times` times for
// different input dimension. For other dimensions, select the algo
// of closest area.
auto var_name = ctx.Inputs("AlgoCache")[0];
algo_cache =
ctx.scope()
.FindVar(var_name)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func);
algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func);
} else {
// Cache searched algo in Var(kCUDNNFwdAlgoCache).
// all conv ops use the same kCUDNNFwdAlgoCache variable.
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
// TODO(qingqing) remove const_cast
algo_cache =
const_cast<framework::Scope*>(ctx.scope().parent())
->Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, search_func);
algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, search_func);
}
VLOG(3) << "choose algo " << algo;
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_MKLDNN
......@@ -109,8 +110,20 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"float16 can only be used when CUDNN is used");
}
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
#ifdef PADDLE_WITH_CUDA
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
// TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
// to false. It should be fixed and then here should only create if library
// is kCUDNN.
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p);
}
#endif
return type;
}
void Conv2DOpMaker::Make() {
......@@ -410,9 +423,25 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type;
}
class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
......
......@@ -141,7 +141,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
......@@ -156,7 +156,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
ctx.AllocateTmpTensor<float, platform::CUDADeviceContext>(
......@@ -179,7 +179,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor out_side_tensor;
......@@ -200,7 +200,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx);
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
size_t memory_size = 500;
int numel = memory_size / sizeof(float);
......
......@@ -732,7 +732,6 @@ class Operator(object):
self._update_desc_attr(attr_name, attr_val)
self.desc.check_attrs()
if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册