From 62e6dac402ca63b402b5dfd1d7649cba1e258d41 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 18 Aug 2017 14:30:09 +0800 Subject: [PATCH] add MKLDNNMatrix files --- paddle/gserver/layers/MKLDNNLayer.h | 1 + paddle/math/CMakeLists.txt | 15 ++++++++++ paddle/math/MKLDNNMatrix.cpp | 19 ++++++++++++ paddle/math/MKLDNNMatrix.h | 45 +++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+) create mode 100644 paddle/math/MKLDNNMatrix.cpp create mode 100644 paddle/math/MKLDNNMatrix.h diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 63e29f447ee..9533027fa6c 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "Layer.h" #include "MKLDNNBase.h" #include "mkldnn.hpp" +#include "paddle/math/MKLDNNMatrix.h" DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn_wgt); diff --git a/paddle/math/CMakeLists.txt b/paddle/math/CMakeLists.txt index bf28092e82b..ad6de18c81d 100644 --- a/paddle/math/CMakeLists.txt +++ b/paddle/math/CMakeLists.txt @@ -14,6 +14,21 @@ # file(GLOB MATH_HEADERS . *.h) file(GLOB MATH_SOURCES . *.cpp) + +message(STATUS "----------MATH_HEADERS:${MATH_HEADERS}") +message(STATUS "----------MATH_SOURCES:${MATH_SOURCES}") +if(NOT WITH_MKLDNN) + file(GLOB_RECURSE DNN_HEADER RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "MKLDNN*.h") + file(GLOB_RECURSE DNN_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "MKLDNN*.cpp") + message(STATUS "----------DNN_HEADER:${DNN_HEADER}") + message(STATUS "----------DNN_SOURCES:${DNN_SOURCES}") + list(REMOVE_ITEM MATH_HEADERS ${DNN_HEADER}) + list(REMOVE_ITEM MATH_SOURCES ${DNN_SOURCES}) + message(STATUS "Skip compiling with MKLDNNMatrix") +else() + message(STATUS "Compile with MKLDNNMatrix") +endif() + set(MATH_SOURCES "${PADDLE_SOURCE_DIR}/paddle/math/BaseMatrix.cu" "${PADDLE_SOURCE_DIR}/paddle/math/TrainingAlgorithmOp.cu" diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp new file mode 100644 index 00000000000..df8e72d78be --- /dev/null +++ b/paddle/math/MKLDNNMatrix.cpp @@ -0,0 +1,19 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MKLDNNMatrix.h" + +using namespace mkldnn; // NOLINT + +namespace paddle {} // namespace paddle diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h new file mode 100644 index 00000000000..91ef56f2c34 --- /dev/null +++ b/paddle/math/MKLDNNMatrix.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +//#include "Matrix.h" +#include "Vector.h" + +#include "mkldnn.hpp" +#include "paddle/parameter/Parameter.h" + +namespace paddle { + +static const std::map PARAM_FOARMAT_MAP = + {{mkldnn::memory::format::oi, PARAM_FORMAT_MKLDNN_OI}}; + +class MKLDNNMatrix; +typedef std::shared_ptr MKLDNNMatrixPtr; + +/** + * @brief MKLDNN Matrix. + * + */ +class MKLDNNMatrix : public CpuVector { +public: + explicit MKLDNNMatrix(size_t size, int fmt) : CpuVector(size), fmt_(fmt) {} + + ~MKLDNNMatrix() {} + +protected: + int fmt_; +}; + +} // namespace paddle -- GitLab