未验证 提交 8c6bbb48 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Accesses to oneDNN cache optimized for conv2d (#33048)

上级 9b203ef3
...@@ -35,7 +35,8 @@ using user_function = std::function<std::shared_ptr<float>(const float*)>; ...@@ -35,7 +35,8 @@ using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory; using memory = mkldnn::memory;
template <typename T, typename TForward, template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive> typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerT { class MKLDNNHandlerT {
public: public:
MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
...@@ -72,6 +73,21 @@ class MKLDNNHandlerT { ...@@ -72,6 +73,21 @@ class MKLDNNHandlerT {
return backward_p; return backward_p;
} }
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
const std::string key_p = key_ + "@bwd_w_p";
auto backward_p =
std::static_pointer_cast<TBackward_params>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
"Error: BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
backward_p = std::make_shared<TBackward_params>(*bwd_w_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -116,6 +132,29 @@ class MKLDNNHandlerT { ...@@ -116,6 +132,29 @@ class MKLDNNHandlerT {
"@diff_src_mem_p"); "@diff_src_mem_p");
} }
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), ptr,
"@diff_wei_mem_p");
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
"@diff_wei_mem_p");
}
protected: protected:
bool isCached() { bool isCached() {
const std::string key_pd = key_common_ + "@fwd_pd"; const std::string key_pd = key_common_ + "@fwd_pd";
...@@ -243,6 +282,27 @@ class MKLDNNHandlerT { ...@@ -243,6 +282,27 @@ class MKLDNNHandlerT {
} }
} }
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_w_pd_ == nullptr) {
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_w_pd_);
}
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
const std::string& suffix) { const std::string& suffix) {
return std::static_pointer_cast<mkldnn::memory>( return std::static_pointer_cast<mkldnn::memory>(
...@@ -370,6 +430,7 @@ class MKLDNNHandlerT { ...@@ -370,6 +430,7 @@ class MKLDNNHandlerT {
std::string key_; std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_; std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_; std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
}; };
// TODO(grygielski) this class will be deleted later. // TODO(grygielski) this class will be deleted later.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册