未验证 提交 48144e40 编写于 作者: Z zhongpu 提交者: GitHub

Exhaustive search (#22821)

* use global conv cache; test=develop

* use singleton cache; test=develop

* fix format error; test=develop

* add cudnn helper header; test=develop

* fix header error; test=develop

* fix mac unitest; test=develop

* fix mac unitest; test=develop

* fix file format; test=develop

* fix include file error, test=develop

* remove kernel_configs_ in class ExecutionContext and kernel_configs_map_ in class OperatorWithKernel, test=develop

* fix test_elementwise_mul_op_dim, test=develop
Co-authored-by: Nphlrain <phliuhongyu@126.com>
上级 da7c73f8
...@@ -905,16 +905,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, ...@@ -905,16 +905,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
const OpKernelType& key) const {
auto config_iter = kernel_configs_map_.find(key);
std::vector<KernelConfig>* kernel_configs = nullptr;
if (config_iter != kernel_configs_map_.end()) {
kernel_configs = &(config_iter->second);
}
return kernel_configs;
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the // To reduce the elapsed time of HasAttr, we use bool variable to record the
...@@ -951,8 +941,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -951,8 +941,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
ChooseKernel(*runtime_ctx, scope, place); ChooseKernel(*runtime_ctx, scope, place);
} }
std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr; Scope* transfer_scope = nullptr;
...@@ -988,8 +976,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -988,8 +976,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{ {
platform::RecordEvent record_event("compute", platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
(*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx, (*kernel_func_)(
kernel_configs)); ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
} }
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
...@@ -1058,7 +1046,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, ...@@ -1058,7 +1046,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType( auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); ExecutionContext(*this, scope, *dev_ctx, ctx));
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
......
...@@ -31,7 +31,6 @@ limitations under the License. */ ...@@ -31,7 +31,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -216,30 +215,12 @@ class OperatorBase { ...@@ -216,30 +215,12 @@ class OperatorBase {
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
}; };
#ifdef PADDLE_WITH_CUDA
using KernelConfig = boost::variant<
std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
#else
using KernelConfig = boost::variant<boost::blank>;
#endif
using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;
class ExecutionContext { class ExecutionContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const RuntimeContext& ctx, const RuntimeContext& ctx)
std::vector<KernelConfig>* configs) : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
: op_(op),
scope_(scope),
device_context_(device_context),
ctx_(ctx),
kernel_configs_(configs) {}
virtual ~ExecutionContext() {} virtual ~ExecutionContext() {}
virtual std::string InputName(const std::string& name) const { virtual std::string InputName(const std::string& name) const {
...@@ -405,15 +386,6 @@ class ExecutionContext { ...@@ -405,15 +386,6 @@ class ExecutionContext {
return temp_tensor; return temp_tensor;
} }
template <typename T>
T& GetKernelConfig(size_t idx) const {
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %lu",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
}
const RuntimeContext Context() const { return ctx_; } const RuntimeContext Context() const { return ctx_; }
std::string DebugString() const { return op_.DebugString(); } std::string DebugString() const { return op_.DebugString(); }
...@@ -423,7 +395,6 @@ class ExecutionContext { ...@@ -423,7 +395,6 @@ class ExecutionContext {
const Scope& scope_; const Scope& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
mutable std::vector<KernelConfig>* kernel_configs_;
}; };
template <> template <>
...@@ -499,8 +470,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -499,8 +470,6 @@ class OperatorWithKernel : public OperatorBase {
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
// change this to public so that in dygraph mode we can call it to check if we // change this to public so that in dygraph mode we can call it to check if we
// need transform data // need transform data
virtual OpKernelType GetKernelTypeForVar( virtual OpKernelType GetKernelTypeForVar(
...@@ -537,7 +506,6 @@ class OperatorWithKernel : public OperatorBase { ...@@ -537,7 +506,6 @@ class OperatorWithKernel : public OperatorBase {
const platform::Place& place) const; const platform::Place& place) const;
protected: protected:
mutable OpKernelConfigsMap kernel_configs_map_;
mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
......
...@@ -21,18 +21,20 @@ limitations under the License. */ ...@@ -21,18 +21,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Not thread-safe. Should be owned per-kernel. // thread-safe.
template <typename TAlgorithm> template <typename TAlgorithm>
class AlgorithmsCache { class AlgorithmsCache {
public: public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); } AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given // Caches the best algorithm for a given
// combination of tensor dimensions & compute data type. // combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm( // cudnn_dtype set for different data type
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2, TAlgorithm GetAlgorithm(const std::vector<int64_t>& dims1,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int64_t>& dims2,
const std::vector<int>& dilations, const std::vector<int>& strides,
int algorithmFlags, // can set for different data type const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
int64_t cudnn_dtype,
std::function<TAlgorithm()> gen_func); std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
...@@ -41,13 +43,14 @@ class AlgorithmsCache { ...@@ -41,13 +43,14 @@ class AlgorithmsCache {
private: private:
std::unordered_map<int64_t, TAlgorithm> hash_; std::unordered_map<int64_t, TAlgorithm> hash_;
int search_times_; int search_times_;
std::mutex cache_mutex;
}; };
template <typename TAlgorithm> template <typename TAlgorithm>
TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm( TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2, const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags, const std::vector<int>& dilations, int algorithmFlags, int64_t cudnn_dtype,
std::function<TAlgorithm()> gen_func) { std::function<TAlgorithm()> gen_func) {
int64_t seed = 0; int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously // Hash all of the inputs, use to try and look up a previously
...@@ -81,38 +84,75 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm( ...@@ -81,38 +84,75 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 + seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5; (seed << 6) + (seed >> 2) + 5;
seed ^= hashFn(static_cast<int64_t>(cudnn_dtype)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 6;
VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size(); VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size();
if (seed == 0) return gen_func(); if (seed == 0) return gen_func();
if (hash_.find(seed) == hash_.end()) { TAlgorithm ret;
TAlgorithm value = gen_func(); auto it = hash_.end();
hash_[seed] = value; bool have_found = false;
{
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(seed);
if (it != hash_.end()) {
ret = it->second;
have_found = true;
}
} }
return hash_[seed];
if (!have_found) {
ret = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
hash_[seed] = ret;
}
return ret;
} }
template <typename TAlgorithm> template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm( TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags, int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) { std::function<TAlgorithm()> gen_func) {
if (hash_.find(area) != hash_.end()) { auto it = hash_.end();
return hash_[area]; {
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(area);
if (it != hash_.end()) {
return it->second;
}
}
bool gene_flag = false;
{
std::lock_guard<std::mutex> lock(cache_mutex);
gene_flag = (search_times_ < search_times);
} }
if (search_times_ < search_times) {
auto algo = gen_func(); TAlgorithm algo{};
if (gene_flag) {
algo = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
hash_[area] = algo; hash_[area] = algo;
++search_times_; ++search_times_;
return algo; return algo;
} }
TAlgorithm algo{};
int64_t min = static_cast<uint64_t>(INT_MAX); int64_t min = static_cast<uint64_t>(INT_MAX);
{
std::lock_guard<std::mutex> lock(cache_mutex);
for (const auto& m : hash_) { for (const auto& m : hash_) {
if (m.first < min) { if (m.first < min) {
min = m.first; min = m.first;
algo = m.second; algo = m.second;
} }
} }
}
return algo; return algo;
} }
......
...@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) { ...@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) {
paddle::framework::RuntimeContext ctx({}, {}); paddle::framework::RuntimeContext ctx({}, {});
paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx, paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx,
ctx, nullptr); ctx);
ASSERT_EQ(exe_context.InputSize("input"), 1u); ASSERT_EQ(exe_context.InputSize("input"), 1u);
ASSERT_EQ(exe_context.OutputSize("output"), 1u); ASSERT_EQ(exe_context.OutputSize("output"), 1u);
......
...@@ -33,11 +33,10 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -33,11 +33,10 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Scope& scope, const framework::Scope& scope,
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs,
const NameVarMap<VarType>& var_base_map_in, const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out, const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx, configs), : ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in), var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out), var_base_map_out_(var_base_map_out),
attrs_(attrs) {} attrs_(attrs) {}
......
...@@ -80,13 +80,8 @@ void PreparedOp::PrepareData( ...@@ -80,13 +80,8 @@ void PreparedOp::PrepareData(
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx)
std::vector<framework::KernelConfig>* kernel_configs) : op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {}
: op_(op),
ctx_(ctx),
func_(func),
dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {}
template <typename VarType> template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
...@@ -111,7 +106,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -111,7 +106,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
framework::RuntimeContext ctx({}, {}); framework::RuntimeContext ctx({}, {});
auto expected_kernel_key = auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>( op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs)); op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -120,8 +115,6 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -120,8 +115,6 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
PADDLE_THROW("op %s does not have kernel for %s", op.Type(), PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key);
if (!(expected_kernel_key.place_ == place)) { if (!(expected_kernel_key.place_ == place)) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
...@@ -129,7 +122,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -129,7 +122,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
} }
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key); PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs); return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
...@@ -152,10 +145,8 @@ template <typename VarType> ...@@ -152,10 +145,8 @@ template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
std::vector<framework::KernelConfig>* kernel_configs, const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph // TODO(zjl): remove scope in dygraph
framework::Scope scope; framework::Scope scope;
...@@ -163,22 +154,21 @@ static void PreparedOpRunImpl( ...@@ -163,22 +154,21 @@ static void PreparedOpRunImpl(
static_cast<const framework::OperatorWithKernel&>(op).InferShape( static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx); &infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
kernel_configs, ins, outs, attrs)); attrs));
} }
void PreparedOp::Run(const NameVarMap<VarBase>& ins, void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins, PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, ins, outs, attrs);
outs, attrs);
} }
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, ins, outs,
kernel_configs_, ins, outs, attrs); attrs);
} }
} // namespace imperative } // namespace imperative
......
...@@ -33,8 +33,7 @@ class PreparedOp { ...@@ -33,8 +33,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx);
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins, static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
...@@ -72,7 +71,6 @@ class PreparedOp { ...@@ -72,7 +71,6 @@ class PreparedOp {
const framework::RuntimeContext& ctx_; const framework::RuntimeContext& ctx_;
framework::OperatorWithKernel::OpKernelFunc func_; framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_; platform::DeviceContext* dev_ctx_;
std::vector<framework::KernelConfig>* kernel_configs_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) { ...@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope; framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context( DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map); *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......
...@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
auto& dev_ctx = *pool.Get(dev_place); auto& dev_ctx = *pool.Get(dev_place);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr); framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids"); const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores"); const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h" #include "paddle/fluid/platform/cudnn_desc.h"
// #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -89,7 +90,43 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) { ...@@ -89,7 +90,43 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
return out; return out;
} }
using framework::AlgorithmsCache; // ConvSearchCache using framework::AlgorithmsCache to search
// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
// cudnnConvolutionBwdFilterAlgo_t
class ConvSearchCache {
public:
static ConvSearchCache& Instance() {
static ConvSearchCache instance;
return instance;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
return &forward_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
return &backward_data_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
GetBackwardFilter() {
return &backward_filter_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
return &fusion_forward_cache_;
}
private:
ConvSearchCache() {}
~ConvSearchCache() {}
ConvSearchCache(const ConvSearchCache&) {}
ConvSearchCache& operator=(const ConvSearchCache&) {}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
backward_data_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
backward_filter_cache_;
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};
struct ConvArgs { struct ConvArgs {
cudnnHandle_t handle; cudnnHandle_t handle;
...@@ -97,6 +134,7 @@ struct ConvArgs { ...@@ -97,6 +134,7 @@ struct ConvArgs {
platform::FilterDescriptor wdesc; platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc; platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o; const framework::Tensor *x, *w, *o;
cudnnDataType_t cudnn_dtype;
// strides // strides
std::vector<int> s; std::vector<int> s;
...@@ -107,8 +145,9 @@ struct ConvArgs { ...@@ -107,8 +145,9 @@ struct ConvArgs {
ConvArgs(const framework::Tensor* x, const framework::Tensor* w, ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s, const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d) const std::vector<int> p, const std::vector<int> d,
: x(x), w(w), o(o), s(s), p(p), d(d) {} cudnnDataType_t dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
}; };
template <typename perf_t> template <typename perf_t>
...@@ -121,7 +160,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -121,7 +160,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id, bool deterministic,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
...@@ -183,22 +222,24 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -183,22 +222,24 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
#endif #endif
VLOG(3) << "choose algo " << algo; VLOG(3) << "choose algo " << algo;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetForward());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< algo_cache_id << ", x_dims:" << x_dims << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" << args.s << ", args.p" << args.p << ", args.d" << args.d;
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
...@@ -244,7 +285,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -244,7 +285,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id, bool deterministic,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
...@@ -321,22 +362,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -321,22 +362,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} else if (deterministic) { } else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardData());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t"
<< algo_cache_id << ", x_dims:" << x_dims << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" << args.s << ", args.p" << args.p << ", args.d" << args.d;
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
...@@ -385,7 +427,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -385,7 +427,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id, bool deterministic,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
...@@ -449,22 +491,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -449,22 +491,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} else if (deterministic) { } else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardFilter());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< algo_cache_id << ", x_dims:" << x_dims << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p" << args.s << ", args.p" << args.p << ", args.d" << args.d;
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() { x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
......
...@@ -216,9 +216,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -216,9 +216,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
const T* filter_data = transformed_filter_channel.data<T>(); const T* filter_data = transformed_filter_channel.data<T>();
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_input, &transformed_filter_channel, ConvArgs args{&transformed_input,
&transformed_output, strides, &transformed_filter_channel,
padding_common, dilations}; &transformed_output,
strides,
padding_common,
dilations,
dtype};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
...@@ -269,7 +273,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -269,7 +273,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnConvolutionFwdAlgo_t algo{}; cudnnConvolutionFwdAlgo_t algo{};
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, 0, ctx); algo = search::Find<T>(args, exhaustive_search, false, ctx);
workspace_size = search::GetWorkspaceSize(args, algo); workspace_size = search::GetWorkspaceSize(args, algo);
#if CUDNN_VERSION_MIN(7, 0, 1) #if CUDNN_VERSION_MIN(7, 0, 1)
...@@ -518,13 +522,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -518,13 +522,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations}; dilations,
dtype};
ConvArgs args2{&transformed_input, ConvArgs args2{&transformed_input,
&transformed_filter_grad_channel, &transformed_filter_grad_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations}; dilations,
dtype};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
...@@ -580,7 +586,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -580,7 +586,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx); search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size = workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
} }
...@@ -597,7 +603,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -597,7 +603,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx); search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo)); search2::GetWorkspaceSize(args2, filter_algo));
} }
...@@ -898,15 +904,26 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -898,15 +904,26 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
ConvArgs args1{&transformed_ddX, W, ConvArgs args1{&transformed_ddX,
&transformed_ddO_channel, strides, W,
padding_common, dilations}; &transformed_ddO_channel,
ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides, strides,
padding_common, dilations}; padding_common,
ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides, dilations,
padding_common, dilations}; dtype};
ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides, ConvArgs args2{
padding_common, dilations}; &transformed_X, ddW, &transformed_ddO_channel, strides, padding_common,
dilations, dtype};
ConvArgs args3{&transformed_ddX,
dW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype};
ConvArgs args4{
&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common,
dilations, dtype};
cudnnConvolutionFwdAlgo_t fwd_algo1 = cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0); static_cast<cudnnConvolutionFwdAlgo_t>(0);
...@@ -934,7 +951,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -934,7 +951,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx); fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
} }
...@@ -949,7 +966,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -949,7 +966,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx); fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, fwd_algo2)); search2::GetWorkspaceSize(args2, fwd_algo2));
} }
...@@ -967,7 +984,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -967,7 +984,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx); search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo)); search3::GetWorkspaceSize(args3, filter_algo));
} }
...@@ -983,7 +1000,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -983,7 +1000,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx); search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
workspace_size = workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
} }
......
...@@ -178,17 +178,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -178,17 +178,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value); library, customized_type_value);
#ifdef PADDLE_WITH_CUDA
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
// TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
// to false. It should be fixed and then here should only create if library
// is kCUDNN.
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p);
}
#endif
return type; return type;
} }
...@@ -563,21 +552,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -563,21 +552,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
auto type = framework::OpKernelType( auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value); layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type; return type;
} }
...@@ -754,25 +728,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( ...@@ -754,25 +728,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
auto type = framework::OpKernelType( auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value); layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p0(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p0);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p1(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p1);
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type; return type;
} }
......
...@@ -58,8 +58,7 @@ void MainTest(const TestData& test_data) { ...@@ -58,8 +58,7 @@ void MainTest(const TestData& test_data) {
RuntimeContext runtime_ctx = RuntimeContext runtime_ctx =
RuntimeContext(op->Inputs(), op->Outputs(), scope); RuntimeContext(op->Inputs(), op->Outputs(), scope);
ExecutionContext ctx = ExecutionContext ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx);
ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
bool result = ElementwiseMulOp::AreDimsAndFormatCorrect( bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
ctx, 16, MKLDNNMemoryFormat::nChw16c); ctx, 16, MKLDNNMemoryFormat::nChw16c);
if (test_data.supposed_to_fail) if (test_data.supposed_to_fail)
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#include <array> #include <array>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h" #include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_int64(cudnn_exhaustive_search_times); DECLARE_int64(cudnn_exhaustive_search_times);
...@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
return fwd_perf_stat[0].algo; return fwd_perf_stat[0].algo;
}; };
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache = AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0); *(ConvSearchCache::Instance().GetConvFusion());
int search_times = ctx.Attr<int>("search_times"); int search_times = ctx.Attr<int>("search_times");
search_times = std::max( search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times); static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
...@@ -245,8 +245,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -245,8 +245,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func); search_func);
} else { } else {
auto dtype = platform::CudnnDataType<T>::type;
algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings, algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, search_func); dilations, 0, dtype, search_func);
} }
VLOG(3) << "choose algo " << algo; VLOG(3) << "choose algo " << algo;
} }
......
...@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace(),
ctx.device_context(), layout_, library_); layout_, library_);
} }
}; };
...@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")), ctx, framework::GradVarName("Loss")),
ctx.device_context()); ctx.GetPlace());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册