diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index faabb8575e057ad3f6fb9b1223e649be25b7ec6a..68abe937b19d9413189e201f61c20fb3e69d3cf7 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -22,7 +22,7 @@ limitations under the License. */ #endif #ifdef PADDLE_WITH_MKLDNN -#include "mkldnn.hpp" +#include "paddle/platform/mkldnn_helper.h" #endif #include "paddle/platform/enforce.h" @@ -122,16 +122,6 @@ class CUDNNDeviceContext : public CUDADeviceContext { #endif #ifdef PADDLE_WITH_MKLDNN -using MKLDNNStream = mkldnn::stream; -using MKLDNNEngine = mkldnn::engine; -using MKLDNNMemory = mkldnn::memory; -using MKLDNNPrimitive = mkldnn::primitive; -using MKLDNNPrimitiveDesc = mkldnn::handle; - -typedef std::shared_ptr MKLDNNEnginePtr; -typedef std::shared_ptr MKLDNNMemoryPtr; -typedef std::shared_ptr MKLDNNPrimitivePtr; -typedef std::shared_ptr MKLDNNPrimitiveDescPtr; class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place); diff --git a/paddle/platform/mkldnn_helper.h b/paddle/platform/mkldnn_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..2649e4a5f375554af225d6abfc84218f86046383 --- /dev/null +++ b/paddle/platform/mkldnn_helper.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "mkldnn.hpp" + +namespace paddle { +namespace platform { + +using MKLDNNStream = mkldnn::stream; +using MKLDNNEngine = mkldnn::engine; +using MKLDNNMemory = mkldnn::memory; +using MKLDNNPrimitive = mkldnn::primitive; +using MKLDNNPrimitiveDesc = mkldnn::handle; + +typedef std::shared_ptr MKLDNNEnginePtr; +typedef std::shared_ptr MKLDNNMemoryPtr; +typedef std::shared_ptr MKLDNNPrimitivePtr; +typedef std::shared_ptr MKLDNNPrimitiveDescPtr; + +} // namespace platform +} // namespace paddle