提交 cd27e536 编写于 作者: M minqiyang

Fix python code in contrib

上级 d8ddd3b3
......@@ -201,6 +201,14 @@ include(external/snappy) # download snappy
include(external/snappystream)
include(external/threadpool)
if(WITH_GPU)
include(cuda)
include(tensorrt)
include(external/anakin)
else()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
endif()
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
include(configure) # add paddle env configuration
......@@ -229,14 +237,6 @@ set(EXTERNAL_LIBS
${PYTHON_LIBRARIES}
)
if(WITH_GPU)
include(cuda)
include(tensorrt)
include(external/anakin)
else()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
endif()
if(WITH_AMD_GPU)
find_package(HIP)
include(hip)
......
......@@ -21,6 +21,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
${CUDNN_ROOT}/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/
$ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import decoder
from decoder import *
from . import decoder
from .decoder import *
__all__ = decoder.__all__
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import beam_search_decoder
from beam_search_decoder import *
from . import beam_search_decoder
from .beam_search_decoder import *
__all__ = beam_search_decoder.__all__
......@@ -191,7 +191,7 @@ class StateCell(object):
self._helper = LayerHelper('state_cell', name=name)
self._cur_states = {}
self._state_names = []
for state_name, state in states.items():
for state_name, state in six.iteritems(states):
if not isinstance(state, InitState):
raise ValueError('state must be an InitState object.')
self._cur_states[state_name] = state
......@@ -346,7 +346,7 @@ class StateCell(object):
if self._in_decoder and not self._switched_decoder:
self._switch_decoder()
for input_name, input_value in inputs.items():
for input_name, input_value in six.iteritems(inputs):
if input_name not in self._inputs:
raise ValueError('Unknown input %s. '
'Please make sure %s in input '
......@@ -361,7 +361,7 @@ class StateCell(object):
if self._in_decoder and not self._switched_decoder:
self._switched_decoder()
for state_name, decoder_state in self._states_holder.items():
for state_name, decoder_state in six.iteritems(self._states_holder):
if id(self._cur_decoder_obj) not in decoder_state:
raise ValueError('Unknown decoder object, please make sure '
'switch_decoder been invoked.')
......@@ -671,7 +671,7 @@ class BeamSearchDecoder(object):
feed_dict = {}
update_dict = {}
for init_var_name, init_var in self._input_var_dict.items():
for init_var_name, init_var in six.iteritems(self._input_var_dict):
if init_var_name not in self.state_cell._inputs:
raise ValueError('Variable ' + init_var_name +
' not found in StateCell!\n')
......@@ -721,7 +721,8 @@ class BeamSearchDecoder(object):
self.state_cell.update_states()
self.update_array(prev_ids, selected_ids)
self.update_array(prev_scores, selected_scores)
for update_name, var_to_update in update_dict.items():
for update_name, var_to_update in six.iteritems(
update_dict):
self.update_array(var_to_update, feed_dict[update_name])
def read_array(self, init, is_ids=False, is_scores=False):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册