From cd27e536671962bf2082780ac93926e7d7c54860 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 2 Aug 2018 15:47:15 +0800 Subject: [PATCH] Fix python code in contrib --- CMakeLists.txt | 16 ++++++++-------- cmake/cudnn.cmake | 1 + python/paddle/fluid/contrib/__init__.py | 4 ++-- python/paddle/fluid/contrib/decoder/__init__.py | 4 ++-- .../fluid/contrib/decoder/beam_search_decoder.py | 11 ++++++----- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f180b5cfa6..2cab76e8f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -201,6 +201,14 @@ include(external/snappy) # download snappy include(external/snappystream) include(external/threadpool) +if(WITH_GPU) + include(cuda) + include(tensorrt) + include(external/anakin) +else() + set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE) +endif() + include(cudnn) # set cudnn libraries, must before configure include(cupti) include(configure) # add paddle env configuration @@ -229,14 +237,6 @@ set(EXTERNAL_LIBS ${PYTHON_LIBRARIES} ) -if(WITH_GPU) - include(cuda) - include(tensorrt) - include(external/anakin) -else() - set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE) -endif() - if(WITH_AMD_GPU) find_package(HIP) include(hip) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 2c84061ff5..9eebea816c 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -21,6 +21,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS ${CUDNN_ROOT}/lib64 ${CUDNN_ROOT}/lib ${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu + ${CUDNN_ROOT}/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/ $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 12cd5d918e..a183543d07 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import decoder -from decoder import * +from . import decoder +from .decoder import * __all__ = decoder.__all__ diff --git a/python/paddle/fluid/contrib/decoder/__init__.py b/python/paddle/fluid/contrib/decoder/__init__.py index 22cfe69269..6343c1543d 100644 --- a/python/paddle/fluid/contrib/decoder/__init__.py +++ b/python/paddle/fluid/contrib/decoder/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import beam_search_decoder -from beam_search_decoder import * +from . import beam_search_decoder +from .beam_search_decoder import * __all__ = beam_search_decoder.__all__ diff --git a/python/paddle/fluid/contrib/decoder/beam_search_decoder.py b/python/paddle/fluid/contrib/decoder/beam_search_decoder.py index ba6e138782..24e920f01f 100644 --- a/python/paddle/fluid/contrib/decoder/beam_search_decoder.py +++ b/python/paddle/fluid/contrib/decoder/beam_search_decoder.py @@ -191,7 +191,7 @@ class StateCell(object): self._helper = LayerHelper('state_cell', name=name) self._cur_states = {} self._state_names = [] - for state_name, state in states.items(): + for state_name, state in six.iteritems(states): if not isinstance(state, InitState): raise ValueError('state must be an InitState object.') self._cur_states[state_name] = state @@ -346,7 +346,7 @@ class StateCell(object): if self._in_decoder and not self._switched_decoder: self._switch_decoder() - for input_name, input_value in inputs.items(): + for input_name, input_value in six.iteritems(inputs): if input_name not in self._inputs: raise ValueError('Unknown input %s. ' 'Please make sure %s in input ' @@ -361,7 +361,7 @@ class StateCell(object): if self._in_decoder and not self._switched_decoder: self._switched_decoder() - for state_name, decoder_state in self._states_holder.items(): + for state_name, decoder_state in six.iteritems(self._states_holder): if id(self._cur_decoder_obj) not in decoder_state: raise ValueError('Unknown decoder object, please make sure ' 'switch_decoder been invoked.') @@ -671,7 +671,7 @@ class BeamSearchDecoder(object): feed_dict = {} update_dict = {} - for init_var_name, init_var in self._input_var_dict.items(): + for init_var_name, init_var in six.iteritems(self._input_var_dict): if init_var_name not in self.state_cell._inputs: raise ValueError('Variable ' + init_var_name + ' not found in StateCell!\n') @@ -721,7 +721,8 @@ class BeamSearchDecoder(object): self.state_cell.update_states() self.update_array(prev_ids, selected_ids) self.update_array(prev_scores, selected_scores) - for update_name, var_to_update in update_dict.items(): + for update_name, var_to_update in six.iteritems( + update_dict): self.update_array(var_to_update, feed_dict[update_name]) def read_array(self, init, is_ids=False, is_scores=False): -- GitLab