diff --git a/CMakeLists.txt b/CMakeLists.txt index de47086dbd6a440cd413c7843c83b1c69d9841b2..23bbe829ac16180088bfa37df66e23f19b021ea3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,7 +39,6 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND}) -option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) @@ -180,13 +179,9 @@ set(EXTERNAL_LIBS if(WITH_GPU) include(cuda) + include(tensorrt) endif(WITH_GPU) -# TensorRT depends on GPU. -if (NOT WITH_GPU) - set(WITH_TENSORRT OFF) -endif() - if(WITH_AMD_GPU) find_package(HIP) include(hip) diff --git a/Dockerfile b/Dockerfile index 9097bb657d2366997112ec7662762a93358aa647..870304a6acc99e715dffbfabd8058be000b6872c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,7 +46,7 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin RUN curl -s -q https://glide.sh/get | sh # Install TensorRT -# The unnecessary files has been removed to make the library small. +# The unnecessary files has been removed to make the library small. It only contains include and lib now. RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \ tar -xz -C /usr/local && \ cp -rf /usr/local/TensorRT/include /usr && \ diff --git a/cmake/configure.cmake b/cmake/configure.cmake index f726405c4773994f6ca6509e5218750805b03995..e490397cc0624c310949a4b571bd00cac6e8953b 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -80,6 +80,16 @@ if(WITH_GPU) # Include cuda and cudnn include_directories(${CUDNN_INCLUDE_DIR}) include_directories(${CUDA_TOOLKIT_INCLUDE}) + + if(TENSORRT_FOUND) + if(${CUDA_VERSION_MAJOR} VERSION_LESS 8) + message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile") + endif() + if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile") + endif() + include_directories(${TENSORRT_INCLUDE_DIR}) + endif() elseif(WITH_AMD_GPU) add_definitions(-DPADDLE_WITH_HIP) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__") diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0c07d36bed65400164853b99f18ec0335341cd94 --- /dev/null +++ b/cmake/tensorrt.cmake @@ -0,0 +1,33 @@ +if(NOT WITH_GPU) + return() +endif() + +set(TENSORRT_ROOT "/usr" CACHE PATH "TENSORRT ROOT") +find_path(TENSORRT_INCLUDE_DIR NvInfer.h + PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include + $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include + NO_DEFAULT_PATH +) + +find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a + PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib + $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib + NO_DEFAULT_PATH + DOC "Path to TensorRT library.") + +if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY) + set(TENSORRT_FOUND ON) +else() + set(TENSORRT_FOUND OFF) +endif() + +if(TENSORRT_FOUND) + file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION + "${TENSORRT_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1" + TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}") + + message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " + "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ") +endif() diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index c1486b527d2e06d2b3f7e0f89458bf9a22564586..0962f40c4a64f18f7105626c54a83f1c5b299c50 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -155,13 +155,9 @@ void ParallelExecutor::BCastParamsToGPUs( #endif } -void ParallelExecutor::Run( - const std::vector &fetch_tensors, - const std::string &fetched_var_name, - const std::unordered_map &feed_tensors) { +void ParallelExecutor::Run(const std::vector &fetch_tensors, + const std::string &fetched_var_name) { platform::RecordBlock b(0); - SplitTensorToPlaces(feed_tensors); - // Create local scopes. for (auto &scope : member_->local_scopes_) { Scope &local_scope = scope->NewScope(); @@ -195,14 +191,28 @@ void ParallelExecutor::Run( auto &local_scope = *scope->Var(details::kLocalExecScopeName)->GetMutable(); scope->DeleteScope(local_scope); - local_scope = nullptr; } } -void ParallelExecutor::SplitTensorToPlaces( - const std::unordered_map &feed_tensors) { - for (auto it : feed_tensors) { - auto lod_tensors = it.second.SplitLoDTensor(member_->places_); +void ParallelExecutor::FeedTensorsIntoLocalScopes( + const std::vector> &tensors) { + PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size()); + + for (size_t i = 0; i < tensors.size(); ++i) { + auto &map = tensors[i]; + auto *scope = member_->local_scopes_[i]; + for (auto &pair : map) { + auto *trg = scope->Var(pair.first)->GetMutable(); + trg->ShareDataWith(pair.second); + trg->set_lod(pair.second.lod()); + } + } +} + +void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( + const std::unordered_map &tensors) { + for (auto pair : tensors) { + auto lod_tensors = pair.second.SplitLoDTensor(member_->places_); PADDLE_ENFORCE_EQ( member_->places_.size(), lod_tensors.size(), "The number of samples of current batch is less than the count of " @@ -211,7 +221,7 @@ void ParallelExecutor::SplitTensorToPlaces( for (size_t j = 0; j < member_->places_.size(); ++j) { // TODO(panxy0718): Do I need to delete this var? auto t = - member_->local_scopes_[j]->Var(it.first)->GetMutable(); + member_->local_scopes_[j]->Var(pair.first)->GetMutable(); t->ShareDataWith(lod_tensors[j]); t->set_lod(lod_tensors[j].lod()); } diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index b4f16dba858fb279ec23a8a04257dda6651148cc..303ac3bc55cfed57a03765b27d8aba581eabd1c8 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -44,16 +44,22 @@ class ParallelExecutor { std::vector& GetLocalScopes(); + /** + * Feed tensors to local scopes. The size of tensors should be equal to the + * size of local scopes. + */ + void FeedTensorsIntoLocalScopes( + const std::vector>& tensors); + + void FeedAndSplitTensorIntoLocalScopes( + const std::unordered_map& tensors); + void Run(const std::vector& fetch_tensors, - const std::string& fetched_var_name, - const std::unordered_map& feed_tensors); + const std::string& fetched_var_name); void BCastParamsToGPUs(const std::unordered_set& vars) const; private: - void SplitTensorToPlaces( - const std::unordered_map& feed_tensors); - ParallelExecutorPrivate* member_; }; diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 8494edee6c2c714c285c45bbb4fe1d8cb1a524aa..cc45bfe9b17d767be039cc0d8d83234b6994d6c1 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -21,7 +21,7 @@ endif() if(WITH_TESTING) add_subdirectory(tests/book) - if (WITH_TENSORRT) + if (TENSORRT_FOUND) add_subdirectory(tensorrt) endif() endif() diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index b93b925a72a55442c105e4280a3580f4ea5b93a1..364c4901b297dbd647faae85b01f682a1daace9c 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc) -if (WITH_TENSORRT) +if (TENSORRT_FOUND) list(APPEND CUDA_SRCS tensorrt.cc) endif() diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a1e8ff6399f0812773a7bb753c90e4400b1763d9..19bd30d9665dc1e8f9d475868cabbf14c8847352 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -505,11 +505,19 @@ All parameter, weight, gradient are variables in Paddle. scope, local_scopes, allow_op_delay); }) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) + // NOTE: even we return a vec* to Python use reference policy. + // We still cannot get local_scope from this vector, since the element + // of vec will be freed by Python GC. We can only return Scope* + // one by one and mark them as reference. .def("local_scopes", [](ParallelExecutor &self) -> std::vector * { return &self.GetLocalScopes(); }, py::return_value_policy::reference) + .def("feed_tensors_into_local_scopes", + &ParallelExecutor::FeedTensorsIntoLocalScopes) + .def("feed_and_split_tensor_into_local_scopes", + &ParallelExecutor::FeedAndSplitTensorIntoLocalScopes) .def("run", &ParallelExecutor::Run); BindRecordIOWriter(&m); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 4a9dbd324c90380e784cc9457845fabd858585be..159d1d5f4e70033fabf93514bd63b38f83675bff 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -190,6 +190,11 @@ void PyCUDATensorSetFromArray( static_cast(pool.Get(place)); paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); + // NOTE: For safety, here wait the copy complete. + // It because the CPU array.data() could be destroyed after this method. + // If we make this method async, it could be copied data from a memory buffer + // that has been freed. + dev_ctx->Wait(); } template <> @@ -216,6 +221,11 @@ void PyCUDATensorSetFromArray( paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(uint16_t) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); + // NOTE: For safety, here wait the copy complete. + // It because the CPU array.data() could be destroyed after this method. + // If we make this method async, it could be copied data from a memory buffer + // that has been freed. + dev_ctx->Wait(); } template diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 7ad5f0d740ac22c4f72eb5427c7f81aaa0a7a3dc..fbdd6fd449625a21f91758dc12490b02070aea1a 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -17,6 +17,7 @@ import multiprocessing import framework import executor import warnings +import sys __all__ = ['ParallelExecutor'] @@ -103,8 +104,8 @@ class ParallelExecutor(object): self.persistable_vars = [ v.name - for v in filter(lambda var: \ - var.persistable and var.type != core.VarDesc.VarType.RAW, + for v in filter( + lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW, main.list_vars()) ] @@ -124,34 +125,93 @@ class ParallelExecutor(object): allow_op_delay) self.scope = scope - def run(self, fetch_list, feed={}, feed_dict={}): + def run(self, fetch_list, feed=None, feed_dict=None): """ - :param fetch_list: A list of variable names that will be fetched. - :param feed: A dict mapping for feed variable name to LoDTensor - or numpy array. - :return: fetched value list. + Run a parallel executor with fetch_list. + + The feed parameter can be a dict or a list. If feed is a dict, the + feed data will be split into multiple devices. If feed is a list, we + assume the data has been splitted into multiple devices, the each + element in the list will be copied to each device directly. + + For example, if the feed is a dict: + >>> exe = ParallelExecutor() + >>> # the image will be splitted into devices. If there is two devices + >>> # each device will process an image with shape (24, 1, 28, 28) + >>> exe.run(feed={'image': numpy.random.random(size=(48, 1, 28, 28))}) + + For example, if the feed is a list: + >>> exe = ParallelExecutor() + >>> # each device will process each element in the list. + >>> # the 1st device will process an image with shape (48, 1, 28, 28) + >>> # the 2nd device will process an image with shape (32, 1, 28, 28) + >>> # + >>> # you can use exe.device_count to get the device number. + >>> exe.run(feed=[{"image": numpy.random.random(size=(48, 1, 28, 28))}, + >>> {"image": numpy.random.random(size=(32, 1, 28, 28))}, + >>> ]) + + + Args: + fetch_list(list): The fetched variable names + feed(list|dict|None): The feed variables. If the feed is a dict, + tensors in that dict will be splitted into each devices. If + the feed is a list, each element of the list will be copied + to each device. + feed_dict: Alias for feed parameter, for backward compatibility. + This parameter is deprecated. + + Returns: fetched result list. + """ - if not feed_dict == {}: - warnings.warn( - "The 'feed_dict' of ParallelExecutor.run() is deprecated. Please use 'feed' instead." - ) - if feed == {}: + if feed is None and feed_dict is not None: feed = feed_dict - if not isinstance(feed, dict): - raise TypeError("feed should be a dict") - - feed_tensor_dict = {} - for i, feed_name in enumerate(feed): - feed_tensor = feed[feed_name] - if not isinstance(feed_tensor, core.LoDTensor): - feed_tensor = core.LoDTensor() - feed_tensor.set(feed[feed_name], self._act_places[0]) - feed_tensor_dict[feed_name] = feed_tensor + print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`" + + if isinstance(feed, dict): + feed_tensor_dict = dict() + for feed_name in feed: + feed_tensor = feed[feed_name] + if not isinstance(feed_tensor, core.LoDTensor): + feed_tensor = core.LoDTensor() + # always set to CPU place, since the tensor need to be splitted + # it is fast in CPU + feed_tensor.set(feed[feed_name], core.CPUPlace()) + feed_tensor_dict[feed_name] = feed_tensor + + self.executor.feed_and_split_tensor_into_local_scopes( + feed_tensor_dict) + elif isinstance(feed, list) or isinstance(feed, tuple): + if len(feed) != len(self._act_places): + raise ValueError( + "Feed a list of tensor, the list should be the same size as places" + ) + + res = list() + + for i, each in enumerate(feed): + if not isinstance(each, dict): + raise TypeError( + "Each element of feed list should be a dict") + res_dict = dict() + for feed_name in each: + tensor = each[feed_name] + if not isinstance(tensor, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(tensor, self._act_places[i]) + tensor = tmp + res_dict[feed_name] = tensor + res.append(res_dict) + self.executor.feed_tensors_into_local_scopes(res) fetch_var_name = '@FETCHED_VAR_NAME@' - self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict) + self.executor.run(fetch_list, fetch_var_name) arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array() return [arr[i] for i in range(len(arr))] def bcast_params(self): self.executor.bcast_params(set(self.persistable_vars)) + + @property + def device_count(self): + return len(self._act_places) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 356c3e64b3d03b520a1bec5b5e0174e1d8ee23e8..8f17eeea139c86055f0d4a06b21cbf66d8395cdc 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1,10 +1,13 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -# The fully connected test is removed whe the WITH_MKLDNN flag is OFF -# Because the fully connected layer has only one kernel (MKLDNN) +# The MKLDNN tests are skiped when the MKLDNN flag is OFF if(NOT WITH_MKLDNN) - list(REMOVE_ITEM TEST_OPS test_fc_op) + foreach(src ${TEST_OPS}) + if(${src} MATCHES ".*_mkldnn_op$") + list(REMOVE_ITEM TEST_OPS ${src}) + endif() + endforeach() endif(NOT WITH_MKLDNN) if(NOT WITH_DISTRIBUTE) diff --git a/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7d554c2276c9acd710d14c8f8b32c802e3e17515 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_activation_mkldnn_op.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +from scipy.special import expit +from test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs + + +class TestMKLDNNReluDim2(TestRelu): + def setUp(self): + super(TestMKLDNNReluDim2, self).setUp() + + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNTanhDim2(TestTanh): + def setUp(self): + super(TestMKLDNNTanhDim2, self).setUp() + + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNSqrtDim2(TestSqrt): + def setUp(self): + super(TestMKLDNNSqrtDim2, self).setUp() + + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNAbsDim2(TestAbs): + def setUp(self): + super(TestMKLDNNAbsDim2, self).setUp() + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNReluDim4(TestRelu): + def setUp(self): + super(TestMKLDNNReluDim4, self).setUp() + + x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + out = np.maximum(x, 0) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNTanhDim4(TestTanh): + def setUp(self): + super(TestMKLDNNTanhDim4, self).setUp() + + self.inputs = { + 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") + } + self.outputs = {'Out': np.tanh(self.inputs['X'])} + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNSqrtDim4(TestSqrt): + def setUp(self): + super(TestMKLDNNSqrtDim4, self).setUp() + + self.inputs = { + 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") + } + self.outputs = {'Out': np.sqrt(self.inputs['X'])} + self.attrs = {"use_mkldnn": True} + + +class TestMKLDNNAbsDim4(TestAbs): + def setUp(self): + super(TestMKLDNNAbsDim4, self).setUp() + + x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + self.inputs = {'X': x} + self.outputs = {'Out': np.abs(self.inputs['X'])} + self.attrs = {"use_mkldnn": True} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 57d4a50e913c0d2994c62600f4e479056ed4c306..c9069777faf9d141db93184e8b1e6dc2a7034980 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1098,82 +1098,5 @@ class TestFP16Swish(TestSwish): self.check_output_with_place(place, atol=1e-3) -#--------------------test MKLDNN-------------------- -class TestMKLDNNReluDim2(TestRelu): - def setUp(self): - super(TestMKLDNNReluDim2, self).setUp() - - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNTanhDim2(TestTanh): - def setUp(self): - super(TestMKLDNNTanhDim2, self).setUp() - - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNSqrtDim2(TestSqrt): - def setUp(self): - super(TestMKLDNNSqrtDim2, self).setUp() - - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNAbsDim2(TestAbs): - def setUp(self): - super(TestMKLDNNAbsDim2, self).setUp() - - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNReluDim4(TestRelu): - def setUp(self): - super(TestMKLDNNReluDim4, self).setUp() - - x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - out = np.maximum(x, 0) - - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.outputs = {'Out': out} - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNTanhDim4(TestTanh): - def setUp(self): - super(TestMKLDNNTanhDim4, self).setUp() - - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") - } - self.outputs = {'Out': np.tanh(self.inputs['X'])} - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNSqrtDim4(TestSqrt): - def setUp(self): - super(TestMKLDNNSqrtDim4, self).setUp() - - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") - } - self.outputs = {'Out': np.sqrt(self.inputs['X'])} - self.attrs = {"use_mkldnn": True} - - -class TestMKLDNNAbsDim4(TestAbs): - def setUp(self): - super(TestMKLDNNAbsDim4, self).setUp() - - x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") - # The same reason with TestAbs - x[np.abs(x) < 0.005] = 0.02 - self.inputs = {'X': x} - self.outputs = {'Out': np.abs(self.inputs['X'])} - self.attrs = {"use_mkldnn": True} - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..db6be21baaa54d33af9f5c44d1815e4b389eb884 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py @@ -0,0 +1,36 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride + + +class TestMKLDNN(TestConv2dOp): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNWithPad(TestWithPad): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNWithStride(TestWithStride): + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 65606a0b4373b28036096cf046da5143a3b8bcd0..a478649541ba9828e55c4239090d5aee554223ac 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -373,22 +373,5 @@ class TestDepthwiseConv2(TestConv2dOp): # def init_op_type(self): # self.op_type = "conv_cudnn" - -#----------------Conv2dMKLDNN---------------- -class TestMKLDNN(TestConv2dOp): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNWithPad(TestWithPad): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNWithStride(TestWithStride): - def init_kernel_type(self): - self.use_mkldnn = True - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py similarity index 100% rename from python/paddle/fluid/tests/unittests/test_fc_op.py rename to python/paddle/fluid/tests/unittests/test_fc_mkldnn_op.py diff --git a/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..966a16dc870c041b9deb140bed57d907cf305fd8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lrn_mkldnn_op.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_lrn_op import TestLRNOp + + +class TestLRNMKLDNNOp(TestLRNOp): + def get_attrs(self): + attrs = TestLRNOp.get_attrs(self) + attrs['use_mkldnn'] = True + return attrs + + def test_check_output(self): + self.check_output(atol=0.002) + + +class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): + def get_attrs(self): + attrs = TestLRNMKLDNNOp.get_attrs(self) + attrs['is_test'] = True + return attrs + + def test_check_grad_normal(self): + def check_raise_is_test(): + try: + self.check_grad(['X'], 'Out', max_relative_error=0.01) + except Exception as e: + t = \ + "is_test attribute should be set to False in training phase." + if t in str(e): + raise AttributeError + + self.assertRaises(AttributeError, check_raise_is_test) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index 8fa480b9bce84d2936f23cce9e41e8e54014b074..eaff45cbb2a58798e9d55149510bec72eea370cd 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -87,34 +87,5 @@ class TestLRNOp(OpTest): self.check_grad(['X'], 'Out', max_relative_error=0.01) -class TestLRNMKLDNNOp(TestLRNOp): - def get_attrs(self): - attrs = TestLRNOp.get_attrs(self) - attrs['use_mkldnn'] = True - return attrs - - def test_check_output(self): - self.check_output(atol=0.002) - - -class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): - def get_attrs(self): - attrs = TestLRNMKLDNNOp.get_attrs(self) - attrs['is_test'] = True - return attrs - - def test_check_grad_normal(self): - def check_raise_is_test(): - try: - self.check_grad(['X'], 'Out', max_relative_error=0.01) - except Exception as e: - t = \ - "is_test attribute should be set to False in training phase." - if t in str(e): - raise AttributeError - - self.assertRaises(AttributeError, check_raise_is_test) - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index b653f2c11f21e1d42a0b844d73fe34a7b778b744..c783a142467f3f6a9cd210425acfc526a32a6f71 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -203,7 +203,7 @@ class TestParallelExecutorBase(unittest.TestCase): iter=50, batch_size=None, allow_op_delay=False, - feed_dict={}, + feed_dict=None, seed=None, use_parallel_executor=True): def run_executor(exe, feed, fetch_list, program=None): @@ -223,7 +223,7 @@ class TestParallelExecutorBase(unittest.TestCase): with fluid.program_guard(main, startup): if seed is not None: startup.random_seed = seed - loss = method(use_feed=len(feed_dict) > 0) + loss = method(use_feed=feed_dict is not None) adam = fluid.optimizer.Adam() adam.minimize(loss) if memory_opt: diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..003ebba18b26198427d9f313596ae85656ac24fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_pool2d_op import TestPool2d_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5 + + +class TestMKLDNNCase1(TestPool2d_Op): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNCase2(TestCase1): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNCase3(TestCase2): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNCase4(TestCase3): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNCase5(TestCase4): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNCase6(TestCase5): + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 764fa575fba1615de3171e848890b3836e640849..328a9ffd25b9fce3fd45bbe847e365f090acd17c 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -317,36 +317,5 @@ class TestCeilModeCase4(TestCase2): self.ceil_mode = True -#--------------------test pool2d MKLDNN-------------------- -class TestMKLDNNCase1(TestPool2d_Op): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNCase2(TestCase1): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNCase3(TestCase2): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNCase4(TestCase3): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNCase5(TestCase4): - def init_kernel_type(self): - self.use_mkldnn = True - - -class TestMKLDNNCase6(TestCase5): - def init_kernel_type(self): - self.use_mkldnn = True - - if __name__ == '__main__': unittest.main()