未验证 提交 d2243979 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #736 from kuke/decoder_wrapper

Add pybind11 wrapper for decoder
...@@ -50,6 +50,13 @@ def batch_to_ndarray(batch_samples, lod): ...@@ -50,6 +50,13 @@ def batch_to_ndarray(batch_samples, lod):
return (batch_feature, batch_label) return (batch_feature, batch_label)
def split_infer_result(infer_seq, lod):
infer_batch = []
for i in xrange(0, len(lod[0]) - 1):
infer_batch.append(infer_seq[lod[0][i]:lod[0][i + 1]])
return infer_batch
class DaemonProcessGroup(object): class DaemonProcessGroup(object):
def __init__(self, proc_num, target, args): def __init__(self, proc_num, target, args):
self._proc_num = proc_num self._proc_num = proc_num
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "decoder.h"
std::string decode(std::vector<std::vector<float>> probs_mat) {
// Add decoding logic here
return "example decoding result";
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
std::string decode(std::vector<std::vector<float>> probs_mat);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "decoder.h"
namespace py = pybind11;
PYBIND11_MODULE(decoder, m) {
m.doc() = "Decode function for Deep ASR model";
m.def("decode",
&decode,
"Decode one input probability matrix "
"and return the transcription");
}
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from distutils.core import setup, Extension
from distutils.sysconfig import get_config_vars
args = ['-std=c++11']
# remove warning about -Wstrict-prototypes
(opt, ) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(flag for flag in opt.split()
if flag != '-Wstrict-prototypes')
ext_modules = [
Extension(
'decoder',
['pybind.cc', 'decoder.cc'],
include_dirs=['pybind11/include', '.'],
language='c++',
extra_compile_args=args, ),
]
setup(
name='decoder',
version='0.0.1',
author='Paddle',
author_email='',
description='Decoder for Deep ASR model',
ext_modules=ext_modules, )
if [ ! -d pybind11 ]; then
git clone https://github.com/pybind/pybind11.git
fi
python setup.py build_ext -i
...@@ -10,6 +10,7 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta ...@@ -10,6 +10,7 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader import data_utils.data_reader as reader
from data_utils.util import lodtensor_to_ndarray from data_utils.util import lodtensor_to_ndarray
from data_utils.util import split_infer_result
def parse_args(): def parse_args():
...@@ -58,13 +59,6 @@ def print_arguments(args): ...@@ -58,13 +59,6 @@ def print_arguments(args):
print('------------------------------------------------') print('------------------------------------------------')
def split_infer_result(infer_seq, lod):
infer_batch = []
for i in xrange(0, len(lod[0]) - 1):
infer_batch.append(infer_seq[lod[0][i]:lod[0][i + 1]])
return infer_batch
def infer(args): def infer(args):
""" Gets one batch of feature data and predicts labels for each sample. """ Gets one batch of feature data and predicts labels for each sample.
""" """
......
...@@ -13,8 +13,10 @@ import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm ...@@ -13,8 +13,10 @@ import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice import data_utils.augmentor.trans_splice as trans_splice
import data_utils.async_data_reader as reader import data_utils.async_data_reader as reader
import decoder.decoder as decoder
from data_utils.util import lodtensor_to_ndarray from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model from model_utils.model import stacked_lstmp_model
from data_utils.util import split_infer_result
def parse_args(): def parse_args():
...@@ -141,13 +143,20 @@ def infer_from_ckpt(args): ...@@ -141,13 +143,20 @@ def infer_from_ckpt(args):
infer_data_reader.recycle(features, labels, lod) infer_data_reader.recycle(features, labels, lod)
cost, acc = exe.run(infer_program, results = exe.run(infer_program,
feed={"feature": feature_t, feed={"feature": feature_t,
"label": label_t}, "label": label_t},
fetch_list=[avg_cost, accuracy], fetch_list=[prediction, avg_cost, accuracy],
return_numpy=False) return_numpy=False)
infer_costs.append(lodtensor_to_ndarray(cost)[0]) infer_costs.append(lodtensor_to_ndarray(results[1])[0])
infer_accs.append(lodtensor_to_ndarray(acc)[0]) infer_accs.append(lodtensor_to_ndarray(results[2])[0])
probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch):
print("Decoding %d: " % (batch_id * args.batch_size + index),
decoder.decode(sample))
print(np.mean(infer_costs), np.mean(infer_accs)) print(np.mean(infer_costs), np.mean(infer_accs))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册