未验证 提交 bfb07aaf 编写于 作者: Z zhongpu 提交者: GitHub

Revert "Exhaustive search (#22821)", test=develop (#23401)

This reverts commit 48144e40.
上级 7fda333a
...@@ -905,6 +905,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, ...@@ -905,6 +905,16 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
const OpKernelType& key) const {
auto config_iter = kernel_configs_map_.find(key);
std::vector<KernelConfig>* kernel_configs = nullptr;
if (config_iter != kernel_configs_map_.end()) {
kernel_configs = &(config_iter->second);
}
return kernel_configs;
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the // To reduce the elapsed time of HasAttr, we use bool variable to record the
...@@ -941,6 +951,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -941,6 +951,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
ChooseKernel(*runtime_ctx, scope, place); ChooseKernel(*runtime_ctx, scope, place);
} }
std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr; Scope* transfer_scope = nullptr;
...@@ -976,8 +988,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -976,8 +988,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{ {
platform::RecordEvent record_event("compute", platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
(*kernel_func_)( (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); kernel_configs));
} }
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
...@@ -1046,7 +1058,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, ...@@ -1046,7 +1058,7 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
auto expected_kernel_key = this->GetExpectedKernelType( auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx)); ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
......
...@@ -31,6 +31,7 @@ limitations under the License. */ ...@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -215,12 +216,30 @@ class OperatorBase { ...@@ -215,12 +216,30 @@ class OperatorBase {
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
}; };
#ifdef PADDLE_WITH_CUDA
using KernelConfig = boost::variant<
std::shared_ptr<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>,
std::shared_ptr<AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>>;
#else
using KernelConfig = boost::variant<boost::blank>;
#endif
using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;
class ExecutionContext { class ExecutionContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const RuntimeContext& ctx) const RuntimeContext& ctx,
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {} std::vector<KernelConfig>* configs)
: op_(op),
scope_(scope),
device_context_(device_context),
ctx_(ctx),
kernel_configs_(configs) {}
virtual ~ExecutionContext() {} virtual ~ExecutionContext() {}
virtual std::string InputName(const std::string& name) const { virtual std::string InputName(const std::string& name) const {
...@@ -386,6 +405,15 @@ class ExecutionContext { ...@@ -386,6 +405,15 @@ class ExecutionContext {
return temp_tensor; return temp_tensor;
} }
template <typename T>
T& GetKernelConfig(size_t idx) const {
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %lu",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
}
const RuntimeContext Context() const { return ctx_; } const RuntimeContext Context() const { return ctx_; }
std::string DebugString() const { return op_.DebugString(); } std::string DebugString() const { return op_.DebugString(); }
...@@ -395,6 +423,7 @@ class ExecutionContext { ...@@ -395,6 +423,7 @@ class ExecutionContext {
const Scope& scope_; const Scope& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
mutable std::vector<KernelConfig>* kernel_configs_;
}; };
template <> template <>
...@@ -470,6 +499,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -470,6 +499,8 @@ class OperatorWithKernel : public OperatorBase {
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
// change this to public so that in dygraph mode we can call it to check if we // change this to public so that in dygraph mode we can call it to check if we
// need transform data // need transform data
virtual OpKernelType GetKernelTypeForVar( virtual OpKernelType GetKernelTypeForVar(
...@@ -506,6 +537,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -506,6 +537,7 @@ class OperatorWithKernel : public OperatorBase {
const platform::Place& place) const; const platform::Place& place) const;
protected: protected:
mutable OpKernelConfigsMap kernel_configs_map_;
mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
......
...@@ -21,21 +21,19 @@ limitations under the License. */ ...@@ -21,21 +21,19 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// thread-safe. // Not thread-safe. Should be owned per-kernel.
template <typename TAlgorithm> template <typename TAlgorithm>
class AlgorithmsCache { class AlgorithmsCache {
public: public:
AlgorithmsCache() : search_times_(0) { hash_.clear(); } AlgorithmsCache() : search_times_(0) { hash_.clear(); }
// Caches the best algorithm for a given // Caches the best algorithm for a given
// combination of tensor dimensions & compute data type. // combination of tensor dimensions & compute data type.
// cudnn_dtype set for different data type TAlgorithm GetAlgorithm(
TAlgorithm GetAlgorithm(const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int64_t>& dims2, const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& strides, const std::vector<int>& dilations,
const std::vector<int>& paddings, int algorithmFlags, // can set for different data type
const std::vector<int>& dilations, int algorithmFlags, std::function<TAlgorithm()> gen_func);
int64_t cudnn_dtype,
std::function<TAlgorithm()> gen_func);
TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func); std::function<TAlgorithm()> gen_func);
...@@ -43,14 +41,13 @@ class AlgorithmsCache { ...@@ -43,14 +41,13 @@ class AlgorithmsCache {
private: private:
std::unordered_map<int64_t, TAlgorithm> hash_; std::unordered_map<int64_t, TAlgorithm> hash_;
int search_times_; int search_times_;
std::mutex cache_mutex;
}; };
template <typename TAlgorithm> template <typename TAlgorithm>
TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm( TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2, const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags, int64_t cudnn_dtype, const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) { std::function<TAlgorithm()> gen_func) {
int64_t seed = 0; int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously // Hash all of the inputs, use to try and look up a previously
...@@ -84,73 +81,36 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm( ...@@ -84,73 +81,36 @@ TAlgorithm framework::AlgorithmsCache<TAlgorithm>::GetAlgorithm(
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 + seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5; (seed << 6) + (seed >> 2) + 5;
seed ^= hashFn(static_cast<int64_t>(cudnn_dtype)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 6;
VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size(); VLOG(10) << "seed:" << seed << ", hash_.size:" << hash_.size();
if (seed == 0) return gen_func(); if (seed == 0) return gen_func();
TAlgorithm ret; if (hash_.find(seed) == hash_.end()) {
auto it = hash_.end(); TAlgorithm value = gen_func();
bool have_found = false; hash_[seed] = value;
{
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(seed);
if (it != hash_.end()) {
ret = it->second;
have_found = true;
}
}
if (!have_found) {
ret = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
hash_[seed] = ret;
} }
return hash_[seed];
return ret;
} }
template <typename TAlgorithm> template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm( TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
int64_t area, int search_times, int algorithmFlags, int64_t area, int search_times, int algorithmFlags,
std::function<TAlgorithm()> gen_func) { std::function<TAlgorithm()> gen_func) {
auto it = hash_.end(); if (hash_.find(area) != hash_.end()) {
{ return hash_[area];
std::lock_guard<std::mutex> lock(cache_mutex);
it = hash_.find(area);
if (it != hash_.end()) {
return it->second;
}
}
bool gene_flag = false;
{
std::lock_guard<std::mutex> lock(cache_mutex);
gene_flag = (search_times_ < search_times);
} }
if (search_times_ < search_times) {
TAlgorithm algo{}; auto algo = gen_func();
if (gene_flag) {
algo = gen_func();
std::lock_guard<std::mutex> lock(cache_mutex);
hash_[area] = algo; hash_[area] = algo;
++search_times_; ++search_times_;
return algo; return algo;
} }
TAlgorithm algo{};
int64_t min = static_cast<uint64_t>(INT_MAX); int64_t min = static_cast<uint64_t>(INT_MAX);
{ for (const auto& m : hash_) {
std::lock_guard<std::mutex> lock(cache_mutex); if (m.first < min) {
for (const auto& m : hash_) { min = m.first;
if (m.first < min) { algo = m.second;
min = m.first;
algo = m.second;
}
} }
} }
return algo; return algo;
......
...@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) { ...@@ -525,7 +525,7 @@ TEST(ExecutionContextAttrAndInOut, new_api) {
paddle::framework::RuntimeContext ctx({}, {}); paddle::framework::RuntimeContext ctx({}, {});
paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx, paddle::framework::ExecutionContext exe_context(*(op.get()), scope, *dev_ctx,
ctx); ctx, nullptr);
ASSERT_EQ(exe_context.InputSize("input"), 1u); ASSERT_EQ(exe_context.InputSize("input"), 1u);
ASSERT_EQ(exe_context.OutputSize("output"), 1u); ASSERT_EQ(exe_context.OutputSize("output"), 1u);
......
...@@ -33,10 +33,11 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -33,10 +33,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Scope& scope, const framework::Scope& scope,
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs,
const NameVarMap<VarType>& var_base_map_in, const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out, const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx), : ExecutionContext(op, scope, device_context, ctx, configs),
var_base_map_in_(var_base_map_in), var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out), var_base_map_out_(var_base_map_out),
attrs_(attrs) {} attrs_(attrs) {}
......
...@@ -80,8 +80,13 @@ void PreparedOp::PrepareData( ...@@ -80,8 +80,13 @@ void PreparedOp::PrepareData(
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx) platform::DeviceContext* dev_ctx,
: op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {} std::vector<framework::KernelConfig>* kernel_configs)
: op_(op),
ctx_(ctx),
func_(func),
dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {}
template <typename VarType> template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
...@@ -106,7 +111,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -106,7 +111,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
framework::RuntimeContext ctx({}, {}); framework::RuntimeContext ctx({}, {});
auto expected_kernel_key = auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>( op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -115,6 +120,8 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -115,6 +120,8 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
PADDLE_THROW("op %s does not have kernel for %s", op.Type(), PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
std::vector<framework::KernelConfig>* kernel_configs =
op.GetKernelConfig(expected_kernel_key);
if (!(expected_kernel_key.place_ == place)) { if (!(expected_kernel_key.place_ == place)) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
...@@ -122,7 +129,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins, ...@@ -122,7 +129,7 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
} }
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key); PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx); return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
...@@ -145,8 +152,10 @@ template <typename VarType> ...@@ -145,8 +152,10 @@ template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins, platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) { std::vector<framework::KernelConfig>* kernel_configs,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph // TODO(zjl): remove scope in dygraph
framework::Scope scope; framework::Scope scope;
...@@ -154,21 +163,22 @@ static void PreparedOpRunImpl( ...@@ -154,21 +163,22 @@ static void PreparedOpRunImpl(
static_cast<const framework::OperatorWithKernel&>(op).InferShape( static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx); &infer_shape_ctx);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs, func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
attrs)); kernel_configs, ins, outs, attrs));
} }
void PreparedOp::Run(const NameVarMap<VarBase>& ins, void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, ins, outs, attrs); PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
outs, attrs);
} }
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs, const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) { const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, ins, outs, PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
attrs); kernel_configs_, ins, outs, attrs);
} }
} // namespace imperative } // namespace imperative
......
...@@ -33,7 +33,8 @@ class PreparedOp { ...@@ -33,7 +33,8 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx); platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins, static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, const NameVarMap<VarBase>& outs,
...@@ -71,6 +72,7 @@ class PreparedOp { ...@@ -71,6 +72,7 @@ class PreparedOp {
const framework::RuntimeContext& ctx_; const framework::RuntimeContext& ctx_;
framework::OperatorWithKernel::OpKernelFunc func_; framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_; platform::DeviceContext* dev_ctx_;
std::vector<framework::KernelConfig>* kernel_configs_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) { ...@@ -235,7 +235,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework::Scope scope; framework::Scope scope;
DygraphExecutionContext<imperative::VarBase> dy_exe_context( DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map); *(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......
...@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
auto& dev_ctx = *pool.Get(dev_place); auto& dev_ctx = *pool.Get(dev_place);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx); framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids"); const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores"); const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h" #include "paddle/fluid/platform/cudnn_desc.h"
// #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -90,43 +89,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) { ...@@ -90,43 +89,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
return out; return out;
} }
// ConvSearchCache using framework::AlgorithmsCache to search using framework::AlgorithmsCache;
// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
// cudnnConvolutionBwdFilterAlgo_t
class ConvSearchCache {
public:
static ConvSearchCache& Instance() {
static ConvSearchCache instance;
return instance;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
return &forward_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
return &backward_data_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
GetBackwardFilter() {
return &backward_filter_cache_;
}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
return &fusion_forward_cache_;
}
private:
ConvSearchCache() {}
~ConvSearchCache() {}
ConvSearchCache(const ConvSearchCache&) {}
ConvSearchCache& operator=(const ConvSearchCache&) {}
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
backward_data_cache_;
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
backward_filter_cache_;
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};
struct ConvArgs { struct ConvArgs {
cudnnHandle_t handle; cudnnHandle_t handle;
...@@ -134,7 +97,6 @@ struct ConvArgs { ...@@ -134,7 +97,6 @@ struct ConvArgs {
platform::FilterDescriptor wdesc; platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc; platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o; const framework::Tensor *x, *w, *o;
cudnnDataType_t cudnn_dtype;
// strides // strides
std::vector<int> s; std::vector<int> s;
...@@ -145,9 +107,8 @@ struct ConvArgs { ...@@ -145,9 +107,8 @@ struct ConvArgs {
ConvArgs(const framework::Tensor* x, const framework::Tensor* w, ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s, const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d, const std::vector<int> p, const std::vector<int> d)
cudnnDataType_t dtype) : x(x), w(w), o(o), s(s), p(p), d(d) {}
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
}; };
template <typename perf_t> template <typename perf_t>
...@@ -160,7 +121,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -160,7 +121,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
...@@ -222,24 +183,22 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -222,24 +183,22 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
#endif #endif
VLOG(3) << "choose algo " << algo; VLOG(3) << "choose algo " << algo;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetForward());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" << algo_cache_id << ", x_dims:" << x_dims
<< args.s << ", args.p" << args.p << ", args.d" << args.d; << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
...@@ -285,7 +244,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -285,7 +244,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
...@@ -362,23 +321,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -362,23 +321,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} else if (deterministic) { } else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardData());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" << algo_cache_id << ", x_dims:" << x_dims
<< args.s << ", args.p" << args.p << ", args.d" << args.d; << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
...@@ -427,7 +385,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -427,7 +385,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
template <typename T> template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search, static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
...@@ -491,22 +449,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -491,22 +449,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} else if (deterministic) { } else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else { } else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardFilter());
auto x_dims = framework::vectorize(args.x->dims()); auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims()); auto w_dims = framework::vectorize(args.w->dims());
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" << algo_cache_id << ", x_dims:" << x_dims
<< args.s << ", args.p" << args.p << ", args.d" << args.d; << ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
<< args.p << ", args.d" << args.d;
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
......
...@@ -216,13 +216,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -216,13 +216,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
const T* filter_data = transformed_filter_channel.data<T>(); const T* filter_data = transformed_filter_channel.data<T>();
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_input, ConvArgs args{&transformed_input, &transformed_filter_channel,
&transformed_filter_channel, &transformed_output, strides,
&transformed_output, padding_common, dilations};
strides,
padding_common,
dilations,
dtype};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
...@@ -273,7 +269,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -273,7 +269,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnConvolutionFwdAlgo_t algo{}; cudnnConvolutionFwdAlgo_t algo{};
using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, false, ctx); algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
workspace_size = search::GetWorkspaceSize(args, algo); workspace_size = search::GetWorkspaceSize(args, algo);
#if CUDNN_VERSION_MIN(7, 0, 1) #if CUDNN_VERSION_MIN(7, 0, 1)
...@@ -522,15 +518,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -522,15 +518,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations};
dtype};
ConvArgs args2{&transformed_input, ConvArgs args2{&transformed_input,
&transformed_filter_grad_channel, &transformed_filter_grad_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations};
dtype};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
...@@ -586,7 +580,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -586,7 +580,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
search1::Find<T>(args1, exhaustive_search, deterministic, ctx); search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx);
workspace_size = workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
} }
...@@ -603,7 +597,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -603,7 +597,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
search2::Find<T>(args2, exhaustive_search, deterministic, ctx); search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, filter_algo)); search2::GetWorkspaceSize(args2, filter_algo));
} }
...@@ -904,26 +898,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -904,26 +898,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
ConvArgs args1{&transformed_ddX, ConvArgs args1{&transformed_ddX, W,
W, &transformed_ddO_channel, strides,
&transformed_ddO_channel, padding_common, dilations};
strides, ConvArgs args2{&transformed_X, ddW, &transformed_ddO_channel, strides,
padding_common, padding_common, dilations};
dilations, ConvArgs args3{&transformed_ddX, dW, &transformed_dO_channel, strides,
dtype}; padding_common, dilations};
ConvArgs args2{ ConvArgs args4{&transformed_dX, ddW, &transformed_dO_channel, strides,
&transformed_X, ddW, &transformed_ddO_channel, strides, padding_common, padding_common, dilations};
dilations, dtype};
ConvArgs args3{&transformed_ddX,
dW,
&transformed_dO_channel,
strides,
padding_common,
dilations,
dtype};
ConvArgs args4{
&transformed_dX, ddW, &transformed_dO_channel, strides, padding_common,
dilations, dtype};
cudnnConvolutionFwdAlgo_t fwd_algo1 = cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0); static_cast<cudnnConvolutionFwdAlgo_t>(0);
...@@ -951,7 +934,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -951,7 +934,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args1.cdesc.set(dtype, padding_common, strides, dilations, c_group); args1.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx); fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
} }
...@@ -966,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -966,7 +949,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
args2.cdesc.set(dtype, padding_common, strides, dilations, c_group); args2.cdesc.set(dtype, padding_common, strides, dilations, c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx); fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, fwd_algo2)); search2::GetWorkspaceSize(args2, fwd_algo2));
} }
...@@ -984,7 +967,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -984,7 +967,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = filter_algo =
search3::Find<T>(args3, exhaustive_search, deterministic, ctx); search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size, workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo)); search3::GetWorkspaceSize(args3, filter_algo));
} }
...@@ -1000,7 +983,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> { ...@@ -1000,7 +983,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = data_algo =
search4::Find<T>(args4, exhaustive_search, deterministic, ctx); search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx);
workspace_size = workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
} }
......
...@@ -178,6 +178,17 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -178,6 +178,17 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value); library, customized_type_value);
#ifdef PADDLE_WITH_CUDA
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
// TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn
// to false. It should be fixed and then here should only create if library
// is kCUDNN.
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p);
}
#endif
return type; return type;
} }
...@@ -552,6 +563,21 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -552,6 +563,21 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
auto type = framework::OpKernelType( auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value); layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type; return type;
} }
...@@ -728,6 +754,25 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( ...@@ -728,6 +754,25 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
auto type = framework::OpKernelType( auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value); layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p0(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p0);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p1(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p1);
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type; return type;
} }
......
...@@ -58,7 +58,8 @@ void MainTest(const TestData& test_data) { ...@@ -58,7 +58,8 @@ void MainTest(const TestData& test_data) {
RuntimeContext runtime_ctx = RuntimeContext runtime_ctx =
RuntimeContext(op->Inputs(), op->Outputs(), scope); RuntimeContext(op->Inputs(), op->Outputs(), scope);
ExecutionContext ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_ctx); ExecutionContext ctx =
ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
bool result = ElementwiseMulOp::AreDimsAndFormatCorrect( bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
ctx, 16, MKLDNNMemoryFormat::nChw16c); ctx, 16, MKLDNNMemoryFormat::nChw16c);
if (test_data.supposed_to_fail) if (test_data.supposed_to_fail)
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#include <array> #include <array>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h" #include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_int64(cudnn_exhaustive_search_times); DECLARE_int64(cudnn_exhaustive_search_times);
...@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -233,7 +233,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
return fwd_perf_stat[0].algo; return fwd_perf_stat[0].algo;
}; };
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache = AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
*(ConvSearchCache::Instance().GetConvFusion()); ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
int search_times = ctx.Attr<int>("search_times"); int search_times = ctx.Attr<int>("search_times");
search_times = std::max( search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times); static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
...@@ -245,9 +245,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -245,9 +245,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
search_func); search_func);
} else { } else {
auto dtype = platform::CudnnDataType<T>::type;
algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings, algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings,
dilations, 0, dtype, search_func); dilations, 0, search_func);
} }
VLOG(3) << "choose algo " << algo; VLOG(3) << "choose algo " << algo;
} }
......
...@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
layout_, library_); ctx.device_context(), layout_, library_);
} }
}; };
...@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -174,7 +174,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")), ctx, framework::GradVarName("Loss")),
ctx.GetPlace()); ctx.device_context());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册