未验证 提交 bfb07aaf 编写于 作者: Z zhongpu 提交者: GitHub

Revert "Exhaustive search (#22821)", test=develop (#23401)

This reverts commit 48144e40.
上级 7fda333a
......@@ -905,6 +905,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx);
}
std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
const OpKernelType& key) const {
auto config_iter = kernel_configs_map_.find(key);
std::vector<KernelConfig>* kernel_configs = nullptr;
if (config_iter != kernel_configs_map_.end()) {
kernel_configs = &(config_iter->second);
}
return kernel_configs;
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the
......@@ -941,6 +951,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
ChooseKernel(*runtime_ctx, scope, place);
}
std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr;
......@@ -976,8 +988,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
(*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
kernel_configs));
}
if (!transfered_inplace_vars.empty()) {
......@@ -1046,7 +1058,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
......
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -215,12 +216,30 @@ class OperatorBase {
const platform::Place& place) const = 0;
};
#ifdef PADDLE_WITH_CUDA
using KernelConfig = boost::variant<
std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
#else
using KernelConfig = boost::variant<boost::blank>;
#endif
using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;
class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context,
const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const RuntimeContext& ctx,
std::vector<KernelConfig>* configs)
: op_(op),
scope_(scope),
device_context_(device_context),
ctx_(ctx),
kernel_configs_(configs) {}
virtual ~ExecutionContext() {}
virtual std::string InputName(const std::string& name) const {
......@@ -386,6 +405,15 @@ class ExecutionContext {
return temp_tensor;
}
template <typename T>
T& GetKernelConfig(size_t idx) const {
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %lu",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
}
const RuntimeContext Context() const { return ctx_; }
std::string DebugString() const { return op_.DebugString(); }
......@@ -395,6 +423,7 @@ class ExecutionContext {
const Scope& scope_;
const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
mutable std::vector<KernelConfig>* kernel_configs_;
};
template <>
......@@ -470,6 +499,8 @@ class OperatorWithKernel : public OperatorBase {
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
// change this to public so that in dygraph mode we can call it to check if we
// need transform data
virtual OpKernelType GetKernelTypeForVar(
......@@ -506,6 +537,7 @@ class OperatorWithKernel : public OperatorBase {
const platform::Place& place) const;
protected:
mutable OpKernelConfigsMap kernel_configs_map_;
mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
......
......@@ -21,21 +21,19 @@ limitations under the License. */
namespace paddle {
namespace framework {
// thread-safe.
// Not thread-safe. Should be owned per-kernel.
template <typename TAlgorithm>
class AlgorithmsCache {
public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
// cudnn_dtype set for different data type
TAlgorithm GetAlgorithm(const std::vector<int64_t>& dims1,
const std::vector<int64_t>& dims2,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
int64_t cudnn_dtype,
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func);
......@@ -43,14 +41,13 @@ class AlgorithmsCache {
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
int search_times_;
std::mutex cache_mutex;
};
template <typename TAlgorithm>
TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags, int64_t cudnn_dtype,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously
......@@ -84,73 +81,36 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;
seed ^= hashFn(static_cast<int64_t>(cudnn_dtype)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 6;
VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size();
if (seed == 0) return gen_func();
TAlgorithm ret;
auto it = hash_.end();
bool have_found = false;
{
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(seed);
if (it != hash_.end()) {
ret = it->second;
have_found = true;
}
}
if (!have_found) {
ret = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
hash_[seed] = ret;
if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return ret;
return hash_[seed];
}
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
auto it = hash_.end();
{
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(area);
if (it != hash_.end()) {
return it->second;
}
}
bool gene_flag = false;
{
std::lock_guard<std::mutex> lock(cache_mutex);
gene_flag = (search_times_ < search_times);
if (hash_.find(area) != hash_.end()) {
return hash_[area];
}
TAlgorithm algo{};
if (gene_flag) {
algo = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
if (search_times_ < search_times) {
auto algo = gen_func();
hash_[area] = algo;
++search_times_;
return algo;
}
TAlgorithm algo{};
int64_t min = static_cast<uint64_t>(INT_MAX);
{
std::lock_guard<std::mutex> lock(cache_mutex);
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
for (const auto& m : hash_) {
if (m.first < min) {
min = m.first;
algo = m.second;
}
}
return algo;
......
......@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) {
paddle::framework::RuntimeContext ctx({}, {});
paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx,
ctx);
ctx, nullptr);
ASSERT_EQ(exe_context.InputSize("input"), 1u);
ASSERT_EQ(exe_context.OutputSize("output"), 1u);
......
......@@ -33,10 +33,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Scope& scope,
const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx),
: ExecutionContext(op, scope, device_context, ctx, configs),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
......
......@@ -80,8 +80,13 @@ void PreparedOp::PrepareData(
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx)
: op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {}
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs)
: op_(op),
ctx_(ctx),
func_(func),
dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {}
template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
......@@ -106,7 +111,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
framework::RuntimeContext ctx({}, {});
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -115,6 +120,8 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key));
}
std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key);
if (!(expected_kernel_key.place_ == place)) {
dev_ctx = pool.Get(expected_kernel_key.place_);
......@@ -122,7 +129,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
}
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......@@ -145,8 +152,10 @@ template <typename VarType>
static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
......@@ -154,21 +163,22 @@ static void PreparedOpRunImpl(
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs));
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
kernel_configs, ins, outs, attrs));
}
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, ins, outs, attrs);
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
outs, attrs);
}
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, ins, outs,
attrs);
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
kernel_configs_, ins, outs, attrs);
}
} // namespace imperative
......
......@@ -33,7 +33,8 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx);
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
......@@ -71,6 +72,7 @@ class PreparedOp {
const framework::RuntimeContext& ctx_;
framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_;
std::vector<framework::KernelConfig>* kernel_configs_;
};
} // namespace imperative
......
......@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
*(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......
......@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
auto& dev_ctx = *pool.Get(dev_place);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h"
// #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -90,43 +89,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
return out;
}
// ConvSearchCache using framework::AlgorithmsCache to search
// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
// cudnnConvolutionBwdFilterAlgo_t
class ConvSearchCache {
public:
static ConvSearchCache& Instance() {
static ConvSearchCache instance;
return instance;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
return &forward_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
return &backward_data_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
GetBackwardFilter() {
return &backward_filter_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
return &fusion_forward_cache_;
}
private:
ConvSearchCache() {}
~ConvSearchCache() {}
ConvSearchCache(const ConvSearchCache&) {}
ConvSearchCache& operator=(const ConvSearchCache&) {}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
backward_data_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
backward_filter_cache_;
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};
using framework::AlgorithmsCache;
struct ConvArgs {
cudnnHandle_t handle;
......@@ -134,7 +97,6 @@ struct ConvArgs {
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
cudnnDataType_t cudnn_dtype;
// strides
std::vector<int> s;
......@@ -145,9 +107,8 @@ struct ConvArgs {
ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d,
cudnnDataType_t dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
const std::vector<int> p, const std::vector<int> d)
: x(x), w(w), o(o), s(s), p(p), d(d) {}
};
template <typename perf_t>
......@@ -160,7 +121,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true;
......@@ -222,24 +183,22 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
#endif
VLOG(3) << "choose algo " << algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetForward());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< algo_cache_id << ", x_dims:" << x_dims
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
......@@ -285,7 +244,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
......@@ -362,23 +321,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardData());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< algo_cache_id << ", x_dims:" << x_dims
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
......@@ -427,7 +385,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
......@@ -491,22 +449,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardFilter());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< algo_cache_id << ", x_dims:" << x_dims
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
......
......@@ -216,13 +216,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
const T* filter_data = transformed_filter_channel.data<T>();
// ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_input,
&transformed_filter_channel,
&transformed_output,
strides,
padding_common,
dilations,
dtype};
ConvArgs args{&transformed_input, &transformed_filter_channel,
&transformed_output, strides,
padding_common, dilations};
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
......@@ -273,7 +269,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnConvolutionFwdAlgo_t algo{};
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, ctx);
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
#if CUDNN_VERSION_MIN(7, 0, 1)
......@@ -522,15 +518,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
&transformed_output_grad_channel,
strides,
padding_common,
dilations,
dtype};
dilations};
ConvArgs args2{&transformed_input,
&transformed_filter_grad_channel,
&transformed_output_grad_channel,
strides,
padding_common,
dilations,
dtype};
dilations};
auto handle = dev_ctx.cudnn_handle();
DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
......@@ -586,7 +580,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx);
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
}
......@@ -603,7 +597,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo));
}
......@@ -904,26 +898,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle();
ConvArgs args1{&transformed_ddX,
W,
&transformed_ddO_channel,
strides,
padding_common,
dilations,
dtype};
ConvArgs args2{
&transformed_X, ddW, &transformed_ddO_channel, strides, padding_common,
dilations, dtype};
ConvArgs args3{&transformed_ddX,
dW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype};
ConvArgs args4{
&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common,
dilations, dtype};
ConvArgs args1{&transformed_ddX, W,
&transformed_ddO_channel, strides,
padding_common, dilations};
ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides,
padding_common, dilations};
ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides,
padding_common, dilations};
ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides,
padding_common, dilations};
cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
......@@ -951,7 +934,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
}
......@@ -966,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, fwd_algo2));
}
......@@ -984,7 +967,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo));
}
......@@ -1000,7 +983,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
}
......
......@@ -178,6 +178,17 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
#ifdef PADDLE_WITH_CUDA
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
// TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
// to false. It should be fixed and then here should only create if library
// is kCUDNN.
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p);
}
#endif
return type;
}
......@@ -552,6 +563,21 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type;
}
......@@ -728,6 +754,25 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p0(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p0);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p1(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p1);
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type;
}
......
......@@ -58,7 +58,8 @@ void MainTest(const TestData& test_data) {
RuntimeContext runtime_ctx =
RuntimeContext(op->Inputs(), op->Outputs(), scope);
ExecutionContext ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx);
ExecutionContext ctx =
ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
ctx, 16, MKLDNNMemoryFormat::nChw16c);
if (test_data.supposed_to_fail)
......
......@@ -14,10 +14,10 @@ limitations under the License. */
#include <array>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_int64(cudnn_exhaustive_search_times);
......@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
*(ConvSearchCache::Instance().GetConvFusion());
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
int search_times = ctx.Attr<int>("search_times");
search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
......@@ -245,9 +245,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func);
} else {
auto dtype = platform::CudnnDataType<T>::type;
algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, dtype, search_func);
dilations, 0, search_func);
}
VLOG(3) << "choose algo " << algo;
}
......
......@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace(),
layout_, library_);
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context(), layout_, library_);
}
};
......@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.GetPlace());
ctx.device_context());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册