提交 cd27e536 编写于 作者: M minqiyang

Fix python code in contrib

上级 d8ddd3b3
...@@ -201,6 +201,14 @@ include(external/snappy) # download snappy ...@@ -201,6 +201,14 @@ include(external/snappy) # download snappy
include(external/snappystream) include(external/snappystream)
include(external/threadpool) include(external/threadpool)
if(WITH_GPU)
include(cuda)
include(tensorrt)
include(external/anakin)
else()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
endif()
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure
include(cupti) include(cupti)
include(configure) # add paddle env configuration include(configure) # add paddle env configuration
...@@ -229,14 +237,6 @@ set(EXTERNAL_LIBS ...@@ -229,14 +237,6 @@ set(EXTERNAL_LIBS
${PYTHON_LIBRARIES} ${PYTHON_LIBRARIES}
) )
if(WITH_GPU)
include(cuda)
include(tensorrt)
include(external/anakin)
else()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
endif()
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
find_package(HIP) find_package(HIP)
include(hip) include(hip)
......
...@@ -21,6 +21,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS ...@@ -21,6 +21,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT}/lib64 ${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib ${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu ${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
${CUDNN_ROOT}/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib $ENV{CUDNN_ROOT}/lib
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import decoder from . import decoder
from decoder import * from .decoder import *
__all__ = decoder.__all__ __all__ = decoder.__all__
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import beam_search_decoder from . import beam_search_decoder
from beam_search_decoder import * from .beam_search_decoder import *
__all__ = beam_search_decoder.__all__ __all__ = beam_search_decoder.__all__
...@@ -191,7 +191,7 @@ class StateCell(object): ...@@ -191,7 +191,7 @@ class StateCell(object):
self._helper = LayerHelper('state_cell', name=name) self._helper = LayerHelper('state_cell', name=name)
self._cur_states = {} self._cur_states = {}
self._state_names = [] self._state_names = []
for state_name, state in states.items(): for state_name, state in six.iteritems(states):
if not isinstance(state, InitState): if not isinstance(state, InitState):
raise ValueError('state must be an InitState object.') raise ValueError('state must be an InitState object.')
self._cur_states[state_name] = state self._cur_states[state_name] = state
...@@ -346,7 +346,7 @@ class StateCell(object): ...@@ -346,7 +346,7 @@ class StateCell(object):
if self._in_decoder and not self._switched_decoder: if self._in_decoder and not self._switched_decoder:
self._switch_decoder() self._switch_decoder()
for input_name, input_value in inputs.items(): for input_name, input_value in six.iteritems(inputs):
if input_name not in self._inputs: if input_name not in self._inputs:
raise ValueError('Unknown input %s. ' raise ValueError('Unknown input %s. '
'Please make sure %s in input ' 'Please make sure %s in input '
...@@ -361,7 +361,7 @@ class StateCell(object): ...@@ -361,7 +361,7 @@ class StateCell(object):
if self._in_decoder and not self._switched_decoder: if self._in_decoder and not self._switched_decoder:
self._switched_decoder() self._switched_decoder()
for state_name, decoder_state in self._states_holder.items(): for state_name, decoder_state in six.iteritems(self._states_holder):
if id(self._cur_decoder_obj) not in decoder_state: if id(self._cur_decoder_obj) not in decoder_state:
raise ValueError('Unknown decoder object, please make sure ' raise ValueError('Unknown decoder object, please make sure '
'switch_decoder been invoked.') 'switch_decoder been invoked.')
...@@ -671,7 +671,7 @@ class BeamSearchDecoder(object): ...@@ -671,7 +671,7 @@ class BeamSearchDecoder(object):
feed_dict = {} feed_dict = {}
update_dict = {} update_dict = {}
for init_var_name, init_var in self._input_var_dict.items(): for init_var_name, init_var in six.iteritems(self._input_var_dict):
if init_var_name not in self.state_cell._inputs: if init_var_name not in self.state_cell._inputs:
raise ValueError('Variable ' + init_var_name + raise ValueError('Variable ' + init_var_name +
' not found in StateCell!\n') ' not found in StateCell!\n')
...@@ -721,7 +721,8 @@ class BeamSearchDecoder(object): ...@@ -721,7 +721,8 @@ class BeamSearchDecoder(object):
self.state_cell.update_states() self.state_cell.update_states()
self.update_array(prev_ids, selected_ids) self.update_array(prev_ids, selected_ids)
self.update_array(prev_scores, selected_scores) self.update_array(prev_scores, selected_scores)
for update_name, var_to_update in update_dict.items(): for update_name, var_to_update in six.iteritems(
update_dict):
self.update_array(var_to_update, feed_dict[update_name]) self.update_array(var_to_update, feed_dict[update_name])
def read_array(self, init, is_ids=False, is_scores=False): def read_array(self, init, is_ids=False, is_scores=False):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册