diff --git a/.gitignore b/.gitignore index db91b19e48091d4a62d8f02bd9db6e57bc729a95..b81c5d37d808d79512a206e07f6f35f1e8e3a317 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ tools/venv *.tar *.tar.gz .ipynb_checkpoints +*.npz diff --git a/.notebook/u2_model.ipynb b/.notebook/u2_model.ipynb index 9658af0ef6b78da73755fedaeaeeea9f2cd872f0..7f17b921e17aa70c651cac9c58a54c0e2dc13865 100644 --- a/.notebook/u2_model.ipynb +++ b/.notebook/u2_model.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "comic-scotland", + "id": "warming-contrast", "metadata": {}, "outputs": [ { @@ -32,7 +32,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "trying-palestinian", + "id": "genuine-marker", "metadata": {}, "outputs": [ { @@ -42,39 +42,39 @@ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " def convert_to_list(value, n, name, dtype=np.int):\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:93] register user softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:109] register user relu to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 08:20:33 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n" + "[WARNING 2021/04/16 10:35:27 __init__.py:93] register user softmax to paddle, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:109] register user relu to paddle, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n", + "[WARNING 2021/04/16 10:35:27 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n" ] } ], @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "committed-glance", + "id": "accepting-genesis", "metadata": {}, "outputs": [ { @@ -100,8 +100,8 @@ "text": [ "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n", - "[INFO 2021/04/16 08:20:34 u2.py:834] U2 Encoder type: conformer\n", - "[INFO 2021/04/16 08:20:34 u2.py:834] U2 Encoder type: conformer\n" + "[INFO 2021/04/16 10:35:28 u2.py:834] U2 Encoder type: conformer\n", + "[INFO 2021/04/16 10:35:28 u2.py:834] U2 Encoder type: conformer\n" ] }, { @@ -608,11 +608,11 @@ "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n", "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4223, 256] | 1081088 | True\n", + "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", "decoder.after_norm.weight | [256] | 256 | True\n", "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4223] | 1081088 | True\n", - "decoder.output_layer.bias | [4223] | 4223 | True\n", + "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", + "decoder.output_layer.bias | [4233] | 4233 | True\n", "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", @@ -793,9 +793,9 @@ "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4223] | 1081088 | True\n", - "ctc.ctc_lo.bias | [4223] | 4223 | True\n", - "Total parameters: 687.0, 49347582.0 elements.\n" + "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", + "ctc.ctc_lo.bias | [4233] | 4233 | True\n", + "Total parameters: 687.0, 49355282.0 elements.\n" ] } ], @@ -803,7 +803,7 @@ "conf_str='examples/aishell/s1/conf/conformer.yaml'\n", "cfg = CN().load_cfg(open(conf_str))\n", "cfg.model.input_dim = 80\n", - "cfg.model.output_dim = 4223\n", + "cfg.model.output_dim = 4233\n", "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n", "cfg.model.cmvn_file_type = 'json'\n", "cfg.freeze()\n", @@ -815,7 +815,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "reserved-nightlife", + "id": "baking-ozone", "metadata": {}, "outputs": [ { @@ -1324,11 +1324,11 @@ "encoder.encoders.11.norm_final.bias | [256] | 256\n", "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n", "encoder.encoders.11.concat_linear.bias | [256] | 256\n", - "decoder.embed.0.weight | [4223, 256] | 1081088\n", + "decoder.embed.0.weight | [4233, 256] | 1083648\n", "decoder.after_norm.weight | [256] | 256\n", "decoder.after_norm.bias | [256] | 256\n", - "decoder.output_layer.weight | [256, 4223] | 1081088\n", - "decoder.output_layer.bias | [4223] | 4223\n", + "decoder.output_layer.weight | [256, 4233] | 1083648\n", + "decoder.output_layer.bias | [4233] | 4233\n", "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", @@ -1427,7 +1427,13 @@ "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n", "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n", + "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n", "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n", "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n", @@ -1509,9 +1515,9 @@ "decoder.decoders.5.concat_linear1.bias | [256] | 256\n", "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n", "decoder.decoders.5.concat_linear2.bias | [256] | 256\n", - "ctc.ctc_lo.weight | [256, 4223] | 1081088\n", - "ctc.ctc_lo.bias | [4223] | 4223\n", - "Total parameters: 689, 49347742 elements.\n" + "ctc.ctc_lo.weight | [256, 4233] | 1083648\n", + "ctc.ctc_lo.bias | [4233] | 4233\n", + "Total parameters: 689, 49355442 elements.\n" ] } ], @@ -1519,16 +1525,968 @@ "summary(model)" ] }, + { + "cell_type": "code", + "execution_count": 5, + "id": "committed-supplier", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataloader.ipynb mask_and_masked_fill_test.ipynb\r\n", + "dataloader_with_tokens_tokenids.ipynb model.npz\r\n", + "data.npz python_test.ipynb\r\n", + "encoder.npz train_test.ipynb\r\n", + "hack_api_test.ipynb u2_model.ipynb\r\n", + "jit_infer.ipynb\r\n" + ] + } + ], + "source": [ + "%ls .notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "wooden-rugby", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.load('.notebook/data.npz', allow_pickle=True)\n", + "keys=data['keys']\n", + "feat=data['feat']\n", + "feat_len=data['feat_len']\n", + "text=data['text']\n", + "text_len=data['text_len']" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "streaming-queue", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n", + " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n", + " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n", + " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n", + " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n", + " 'BAC009S0727W0418']\n", + "(16, 207, 80)\n", + "[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n", + " [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n", + " [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n", + " ...\n", + " [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n", + " [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n", + " [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n", + "\n", + " [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n", + " [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n", + " [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n", + " ...\n", + " [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n", + " [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n", + " [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n", + "\n", + " [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n", + " [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n", + " [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n", + " ...\n", + " [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]]\n", + "\n", + " ...\n", + "\n", + " [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n", + " [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n", + " [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n", + " ...\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]]\n", + "\n", + " [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n", + " [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n", + " [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n", + " ...\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]]\n", + "\n", + " [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n", + " [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n", + " [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n", + " ...\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]\n", + " [ 0. 0. 0. ... 0. 0. 0. ]]]\n", + "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n", + "[[2995 3116 1209 565 -1 -1]\n", + " [ 236 1176 331 66 3925 4077]\n", + " [2693 524 234 1145 366 -1]\n", + " [3875 4211 3062 700 -1 -1]\n", + " [ 272 987 1134 494 2959 -1]\n", + " [1936 3715 120 2553 2695 2710]\n", + " [ 25 1149 3930 -1 -1 -1]\n", + " [1753 1778 1237 482 3925 110]\n", + " [3703 2 565 3827 -1 -1]\n", + " [1150 2734 10 2478 3490 -1]\n", + " [ 426 811 95 489 144 -1]\n", + " [2313 2006 489 975 -1 -1]\n", + " [3702 3414 205 1488 2966 1347]\n", + " [ 70 1741 702 1666 -1 -1]\n", + " [ 703 1778 1030 849 -1 -1]\n", + " [ 814 1674 115 3827 -1 -1]]\n", + "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n" + ] + } + ], + "source": [ + "print(keys)\n", + "print(feat.shape)\n", + "print(feat)\n", + "print(feat_len)\n", + "print(text)\n", + "print(text_len)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cardiovascular-controversy", + "metadata": {}, + "outputs": [], + "source": [ + "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", + "# torch.Size([16, 207, 80])\n", + "# tensor([[[ 8.9935, 9.5377, 9.1939, ..., 10.4592, 9.5223, 8.2839],\n", + "# [ 9.7994, 10.4060, 9.2669, ..., 10.2340, 9.5668, 8.7706],\n", + "# [10.6888, 10.3949, 8.0560, ..., 9.9335, 10.1175, 8.1560],\n", + "# ...,\n", + "# [ 9.2174, 9.6504, 8.5052, ..., 9.6707, 8.7834, 8.0564],\n", + "# [10.1287, 9.9347, 9.3788, ..., 9.5698, 9.8277, 8.9262],\n", + "# [ 9.0959, 7.1305, 9.4666, ..., 9.5228, 8.9921, 7.4808]],\n", + "\n", + "# [[11.4309, 10.6716, 6.0973, ..., 9.3820, 8.7208, 7.6153],\n", + "# [ 9.7314, 7.8097, 7.5711, ..., 10.0005, 9.2962, 7.5479],\n", + "# [10.6502, 10.6007, 8.4671, ..., 9.2416, 9.2412, 8.1083],\n", + "# ...,\n", + "# [ 9.0977, 9.2650, 8.0763, ..., 8.3842, 8.4285, 8.0505],\n", + "# [10.4615, 10.1473, 6.7677, ..., 9.8455, 9.6548, 8.2006],\n", + "# [ 7.7949, 5.6219, 7.9746, ..., 9.9617, 9.8019, 8.0486]],\n", + "\n", + "# [[ 7.3481, 7.8987, 7.5786, ..., 11.6611, 10.4626, 9.0665],\n", + "# [ 8.6274, 8.4604, 7.4981, ..., 12.4233, 11.0101, 8.9767],\n", + "# [ 9.8315, 10.2812, 8.9717, ..., 12.1325, 10.4014, 9.0196],\n", + "# ...,\n", + "# [ 7.0872, 7.4009, 6.8090, ..., 9.3759, 9.2273, 8.1752],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + "# ...,\n", + "\n", + "# [[10.9333, 10.4647, 7.7200, ..., 10.3486, 9.2818, 7.2852],\n", + "# [10.4503, 9.9080, 9.0299, ..., 9.9633, 9.4876, 7.6330],\n", + "# [10.4877, 9.8130, 9.8961, ..., 9.6017, 9.3175, 7.6303],\n", + "# ...,\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + "# [[ 9.9448, 9.5868, 8.2200, ..., 11.6113, 11.0576, 8.7598],\n", + "# [ 7.6800, 8.3231, 7.5294, ..., 11.0965, 10.5442, 9.3556],\n", + "# [ 8.6248, 9.6746, 9.8406, ..., 11.4058, 10.9484, 8.9749],\n", + "# ...,\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + "# [[ 8.1097, 7.7619, 6.7079, ..., 12.6548, 11.4666, 11.0747],\n", + "# [11.3805, 11.2223, 8.6587, ..., 12.7926, 12.2433, 11.7217],\n", + "# [10.6778, 9.9210, 8.0447, ..., 13.5741, 12.5711, 11.1356],\n", + "# ...,\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", + "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", + "# 166, 163], dtype=torch.int32)\n", + "# tensor([[2995, 3116, 1209, 565, -1, -1],\n", + "# [ 236, 1176, 331, 66, 3925, 4077],\n", + "# [2693, 524, 234, 1145, 366, -1],\n", + "# [3875, 4211, 3062, 700, -1, -1],\n", + "# [ 272, 987, 1134, 494, 2959, -1],\n", + "# [1936, 3715, 120, 2553, 2695, 2710],\n", + "# [ 25, 1149, 3930, -1, -1, -1],\n", + "# [1753, 1778, 1237, 482, 3925, 110],\n", + "# [3703, 2, 565, 3827, -1, -1],\n", + "# [1150, 2734, 10, 2478, 3490, -1],\n", + "# [ 426, 811, 95, 489, 144, -1],\n", + "# [2313, 2006, 489, 975, -1, -1],\n", + "# [3702, 3414, 205, 1488, 2966, 1347],\n", + "# [ 70, 1741, 702, 1666, -1, -1],\n", + "# [ 703, 1778, 1030, 849, -1, -1],\n", + "# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", + "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "sorted-nursery", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataloader.ipynb\t\t mask_and_masked_fill_test.ipynb\r\n", + "dataloader_with_tokens_tokenids.ipynb model.npz\r\n", + "data.npz\t\t\t python_test.ipynb\r\n", + "encoder.npz\t\t\t train_test.ipynb\r\n", + "hack_api_test.ipynb\t\t u2_model.ipynb\r\n", + "jit_infer.ipynb\r\n" + ] + } + ], + "source": [ + "!ls .notebook\n", + "data = np.load('.notebook/model.npz', allow_pickle=True)\n", + "state_dict = data['state'].item()\n", + "\n", + "for key, _ in model.state_dict().items():\n", + " if key not in state_dict:\n", + " print(f\"{key} not find.\")\n", + "\n", + "model.set_state_dict(state_dict)\n", + "\n", + "now_state_dict = model.state_dict()\n", + "for key, value in now_state_dict.items():\n", + " if not np.allclose(value.numpy(), state_dict[key]):\n", + " print(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "typical-destruction", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "junior-toner", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:686: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " elif dtype == np.bool:\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[[-0.98870903, -0.07539634, 0.03839889, ..., 0.28173494, 0.36495337, -0.03607073],\n", + " [-0.36884263, -0.22921059, -0.09850989, ..., 1.17554832, 0.49281687, 1.22981191],\n", + " [-0.42831492, 0.14306782, -0.40504572, ..., 1.19258320, 0.28560629, 0.84126252],\n", + " ...,\n", + " [-1.14067113, -0.90225518, 0.08112312, ..., 0.22529972, 0.98848087, 1.42083788],\n", + " [-1.34911966, 0.18967032, 0.27775878, ..., 0.31862095, 0.63177413, 0.15082565],\n", + " [-0.95137477, -0.03690310, -0.21094164, ..., 0.99404806, 0.53174424, 1.83114266]],\n", + "\n", + " [[-0.40670884, 0.22098994, -0.52978617, ..., -0.16111313, 0.73881495, 0.01380203],\n", + " [-0.51442140, -0.45173034, -0.45147005, ..., 1.22010005, 1.24763870, 0.03303454],\n", + " [-0.65140647, 0.16316377, -0.43823493, ..., 1.64499593, 0.57617754, 0.28116497],\n", + " ...,\n", + " [-0.53139108, -0.20081151, 0.54881495, ..., 0.31859449, 1.30965185, 1.90029418],\n", + " [-1.31833756, 0.42574614, -0.10103188, ..., 0.32908860, -0.09044939, -0.02275553],\n", + " [-0.90923810, 0.04415442, 0.16781625, ..., 1.19873142, 0.70491177, 1.67834747]],\n", + "\n", + " [[-0.53979987, 0.18136497, -0.01803534, ..., 0.19695832, 1.25342798, -0.06128683],\n", + " [ 0.55232340, -0.64118379, -0.37508020, ..., 1.14505792, 1.61396039, 0.87614059],\n", + " [-1.02553070, -0.25136885, 0.34500298, ..., 1.65974748, 0.41719219, 0.66209674],\n", + " ...,\n", + " [-1.29586899, -0.31949744, 0.15714335, ..., 0.75515050, 0.94777793, 2.14865851],\n", + " [-1.39566910, 0.06694880, 0.34747776, ..., 0.71009159, 0.68929648, -0.16454494],\n", + " [-0.95307189, -0.09190658, -0.10012256, ..., 1.55584967, 0.73311400, 1.79356611]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.49390051, 0.14108033, -0.53815168, ..., 0.66417909, -0.43801153, 0.06367429],\n", + " [-0.13990593, -0.18394402, -0.51444548, ..., 1.64648640, 0.75647151, 0.73829728],\n", + " [-0.54492640, 0.11887605, 0.00587618, ..., 1.19514525, -0.07906327, 0.48107162],\n", + " ...,\n", + " [-1.33633518, -0.44442374, 0.00936849, ..., 0.91423398, 0.98535562, 0.98347098],\n", + " [-1.19861710, 0.70938700, 0.33154529, ..., 0.16847876, 0.02984418, -0.16296168],\n", + " [-0.89762348, 0.13328603, 0.37963712, ..., 1.21883786, 0.40238193, 1.44023502]],\n", + "\n", + " [[-0.08951648, 0.31010029, 0.40794152, ..., -0.10481174, 0.06963947, -0.45780548],\n", + " [ 0.62238014, -0.20880134, -0.22700992, ..., 1.21718991, 1.12063444, 0.40797234],\n", + " [-0.36213374, -0.26551899, 0.57684356, ..., 1.14578938, 0.28899658, 0.24930142],\n", + " ...,\n", + " [-0.88929099, -0.24094193, 0.38044125, ..., -0.01533419, 1.05152702, 0.98240042],\n", + " [-1.06873631, 0.38082325, 0.74675465, ..., -0.03644872, 0.26738623, -0.43120855],\n", + " [-0.94091892, -0.32104436, 0.47966722, ..., 0.61019003, 0.43108502, 1.11352766]],\n", + "\n", + " [[-0.03323537, 0.22007366, -0.03000726, ..., 0.36668554, 0.08975718, -0.25875339],\n", + " [ 0.40793720, -0.16809593, -0.73204160, ..., 1.41993105, 1.22917044, 0.72486037],\n", + " [-0.50788718, -0.43409127, 0.48296678, ..., 1.11637628, 0.16383135, 0.40282369],\n", + " ...,\n", + " [-0.74193639, -0.63939446, 0.55139303, ..., -0.00370563, 0.73491311, 1.21351111],\n", + " [-1.04918861, 0.59047806, 0.64082241, ..., 0.29343244, 0.25179449, -0.50433135],\n", + " [-0.86854327, -0.45206326, 0.32531947, ..., 0.38761431, 0.32762241, 1.13863206]]])\n", + "Tensor(shape=[16, 51, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[[ 0.01700813, 0.56431651, 0.45506364, ..., -0.90041381, 0.26185888, 0.68741971],\n", + " [-0.14328328, 0.16787037, 1.60204566, ..., -0.52197266, -0.27033603, 0.44314486],\n", + " [-0.22867197, 0.23935843, 1.40139520, ..., -0.58817720, 0.36277789, 0.60821676],\n", + " ...,\n", + " [-0.08569217, 0.53737843, 0.74085897, ..., -0.88298100, 0.06646422, 0.98183125],\n", + " [-0.33066741, 0.65147656, 0.50528461, ..., -0.88622850, 0.37098962, 1.03324938],\n", + " [ 0.39562812, 0.51454604, 0.33244559, ..., -0.73552674, -0.23745571, 0.55406201]],\n", + "\n", + " [[-0.38542494, 0.65172035, 0.47112849, ..., -0.60375690, 0.56403750, 0.86565256],\n", + " [-0.16968662, 0.32454279, 1.09088314, ..., -0.22235930, 0.33991110, 0.58421040],\n", + " [-0.32392421, 0.61689788, 0.94623339, ..., -0.51428318, 0.46278131, 0.49175799],\n", + " ...,\n", + " [ 0.18902412, 0.70370960, 0.22131878, ..., -0.49284744, -0.19460268, 0.56502676],\n", + " [-0.62619895, 1.07694829, 0.36491036, ..., -0.60827464, 0.18799752, 1.20347393],\n", + " [ 0.36972979, 0.69460171, 0.32603034, ..., -0.49348083, -0.15541299, 0.73012495]],\n", + "\n", + " [[-0.65226990, 0.72903591, -0.02955327, ..., -0.62513059, 0.78257781, 1.06949353],\n", + " [-0.31972992, 0.24137607, 1.32179105, ..., -0.31378460, 0.47126365, 0.50631112],\n", + " [-0.27153456, 0.61149585, 1.36779737, ..., -0.41040954, 0.19214611, 0.66955560],\n", + " ...,\n", + " [ 0.70239127, 0.59776336, 0.41315046, ..., -0.63964498, -0.05725104, 0.11523478],\n", + " [-0.61306721, 0.93517447, 0.13917899, ..., -1.07090628, 0.08259787, 1.05669415],\n", + " [ 0.30364236, 0.70674980, 0.27861559, ..., -0.45961899, -0.49536246, 0.42410135]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.64492232, 1.16129804, 0.41210422, ..., -0.19025707, 0.62510222, 0.93904167],\n", + " [-0.21679890, 0.01541298, 1.22289670, ..., 0.38956094, 0.33127683, 0.55291802],\n", + " [-0.52435982, 0.53476179, 1.36162400, ..., -0.24845126, 0.30851704, 0.73026729],\n", + " ...,\n", + " [ 0.33089486, 1.21250021, 0.50133944, ..., -0.23968413, -0.05249966, 0.33221221],\n", + " [-0.85425609, 0.91674101, 0.37947315, ..., -0.54663503, 0.32272232, 0.91941363],\n", + " [ 0.73812121, 1.22125304, 0.54933113, ..., -0.34835899, -0.45703983, 0.10094876]],\n", + "\n", + " [[-0.14543423, 0.59343618, 0.48727173, ..., -0.48721361, 0.23470370, 1.04386616],\n", + " [-0.37399894, 0.05687386, 0.98770601, ..., 0.20608327, 0.28952795, 0.69849032],\n", + " [-0.52618062, 0.19394255, 1.08136940, ..., -0.51677036, 0.21367601, 0.81429225],\n", + " ...,\n", + " [ 0.91529322, 0.82572049, 0.56763554, ..., -0.48792118, -0.20669226, 0.10400648],\n", + " [-0.65565026, 0.82217371, 0.45654771, ..., -0.70658189, -0.00154681, 1.01031244],\n", + " [ 0.85112470, 0.92439699, 0.51105708, ..., -0.57625800, -0.60960227, -0.02037612]],\n", + "\n", + " [[-0.51818568, 0.87956434, 0.36026087, ..., -0.60333908, 0.30989277, 0.92859864],\n", + " [-0.36991373, -0.02736802, 1.04911196, ..., 0.23815414, 0.36916631, 0.56326580],\n", + " [-0.58471107, 0.27818793, 1.23031902, ..., -0.47299296, -0.03227636, 0.80790430],\n", + " ...,\n", + " [ 0.74327284, 1.02660847, 0.59810358, ..., -0.35650834, -0.50914389, 0.08961441],\n", + " [-0.64146334, 0.82072812, 0.35041004, ..., -0.80564159, -0.01707828, 0.84261787],\n", + " [ 0.85794026, 1.18059790, 0.34535947, ..., -0.57844251, -0.85070610, 0.06602620]]])\n", + "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [139.62181091]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [36.32815552]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [380.64031982])\n" + ] + } + ], + "source": [ + "import paddle\n", + "feat=paddle.to_tensor(feat)\n", + "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n", + "text=paddle.to_tensor(text)\n", + "text_len=paddle.to_tensor(text_len, dtype='int64')\n", + "\n", + "model.eval()\n", + "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", + " text, text_len)\n", + "print(total_loss, attention_loss, ctc_loss )" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dense-brake", + "metadata": {}, + "outputs": [], + "source": [ + "# tensor(142.4635, grad_fn=) tensor(41.8416, grad_fn=) tensor(377.2479, grad_fn=)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "voluntary-arcade", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "surprising-teach", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[16, 51, 256]\n", + "[16, 1, 51]\n", + "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[-0.98870903, -0.07539634, 0.03839889, ..., 0.28173494, 0.36495337, -0.03607073],\n", + " [-0.36884263, -0.22921059, -0.09850989, ..., 1.17554832, 0.49281687, 1.22981191],\n", + " [-0.42831492, 0.14306782, -0.40504572, ..., 1.19258320, 0.28560629, 0.84126252],\n", + " ...,\n", + " [-1.14067113, -0.90225518, 0.08112312, ..., 0.22529972, 0.98848087, 1.42083788],\n", + " [-1.34911966, 0.18967032, 0.27775878, ..., 0.31862095, 0.63177413, 0.15082565],\n", + " [-0.95137477, -0.03690310, -0.21094164, ..., 0.99404806, 0.53174424, 1.83114266]])\n" + ] + } + ], + "source": [ + "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", + "print(encoder_out.shape)\n", + "print(encoder_mask.shape)\n", + "print(encoder_out[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "permanent-loading", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "criminal-setup", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "deepspeech examples README_cn.md\tsetup.sh tools\r\n", + "docs\t LICENSE README.md\t\ttests\t utils\r\n", + "env.sh\t log requirements.txt\tthird_party\r\n" + ] + } + ], + "source": [ + "!ls\n", + "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n", + "torch_mask = data['mask']\n", + "torch_encoder_out = data['out']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "brazilian-happening", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "separate-eligibility", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "False\n", + "[[-0.70815086 0.5656927 0.709813 ... 1.0982457 0.7758755\n", + " 1.1307045 ]\n", + " [-0.78350693 0.39481696 0.74499094 ... 1.2273936 0.8813775\n", + " 1.3142622 ]\n", + " [-0.9625825 0.63913065 0.90481734 ... 0.9587627 0.73829174\n", + " 1.2868171 ]\n", + " ...\n", + " [-1.089918 0.6853822 0.9498568 ... 0.8842667 0.81529033\n", + " 1.325533 ]\n", + " [-1.1811031 0.6971649 0.7225241 ... 1.200684 0.8006199\n", + " 1.4533575 ]\n", + " [-1.2878689 0.72914284 0.7896784 ... 0.916238 0.87275296\n", + " 1.2629912 ]]\n" + ] + } + ], + "source": [ + "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", + "print(torch_encoder_out[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "alternate-comment", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", + " [[[-0.70815086, 0.56569272, 0.70981300, ..., 1.09824574, 0.77587551, 1.13070452],\n", + " [-0.78350693, 0.39481696, 0.74499094, ..., 1.22739363, 0.88137752, 1.31426215],\n", + " [-0.96258253, 0.63913065, 0.90481734, ..., 0.95876271, 0.73829174, 1.28681707],\n", + " ...,\n", + " [-1.08991802, 0.68538219, 0.94985682, ..., 0.88426667, 0.81529033, 1.32553303],\n", + " [-1.18110311, 0.69716489, 0.72252411, ..., 1.20068395, 0.80061990, 1.45335746],\n", + " [-1.28786886, 0.72914284, 0.78967839, ..., 0.91623801, 0.87275296, 1.26299119]],\n", + "\n", + " [[-0.92869806, 0.66449726, 0.50940996, ..., 0.67377377, 0.75721473, 1.44601321],\n", + " [-0.58323175, 0.39969942, 0.46701184, ..., 0.76123071, 0.82149148, 1.53387356],\n", + " [-0.66912395, 0.42107889, 0.53314692, ..., 0.77352434, 0.73588967, 1.55955291],\n", + " ...,\n", + " [-0.91979462, 0.78827965, 0.51364565, ..., 0.92784536, 0.88741118, 1.55079234],\n", + " [-0.90603584, 0.75470775, 0.51157582, ..., 0.99914151, 0.87281585, 1.49555171],\n", + " [-0.94746929, 0.86679929, 0.65138626, ..., 0.94967902, 0.74416542, 1.38868642]],\n", + "\n", + " [[-0.48889187, 0.40629929, -0.03985359, ..., 0.96110481, 0.72562295, 1.63959312],\n", + " [-0.23216049, 0.47649717, 0.06432461, ..., 1.12634289, 0.79304028, 1.72600341],\n", + " [-0.35049179, 0.45091787, 0.23251781, ..., 1.18179774, 0.77048856, 1.67785954],\n", + " ...,\n", + " [-0.95444369, 0.62471539, 0.32779199, ..., 1.38101709, 1.06079900, 1.50111783],\n", + " [-0.92489260, 0.75614768, 0.24058929, ..., 1.37509775, 1.08733690, 1.56775463],\n", + " [-0.87840986, 0.66779983, 0.13002315, ..., 1.30724812, 1.16084790, 1.31587541]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.49060634, 0.58141297, 0.56432068, ..., 1.01875186, 0.29791155, 1.37021196],\n", + " [-0.68209082, 0.68713498, 0.37579197, ..., 1.03290558, 0.31765509, 1.85167778],\n", + " [-0.59835142, 0.70429099, 0.52930498, ..., 1.10545111, 0.27167040, 1.79945505],\n", + " ...,\n", + " [-0.92881185, 0.90744990, 0.30645573, ..., 1.21084821, 0.45378613, 1.54552865],\n", + " [-0.93471462, 0.98222488, 0.33421245, ..., 1.20006037, 0.48279485, 1.54707932],\n", + " [-0.88121003, 0.97045374, 0.41706085, ..., 1.17172730, 0.44657633, 1.51080203]],\n", + "\n", + " [[-0.75599599, -0.00976199, 0.22203811, ..., 0.83421057, 0.32212344, 1.65036464],\n", + " [-0.82587808, -0.04545709, 0.31506237, ..., 1.26919305, 0.44509020, 1.73162079],\n", + " [-0.76584357, -0.23916586, 0.41122752, ..., 1.08345842, 0.35172719, 1.59721172],\n", + " ...,\n", + " [-1.20936334, 0.74367058, 0.41594249, ..., 1.40040612, 0.81670052, 1.13627040],\n", + " [-1.21405351, 0.80623198, 0.41914314, ..., 1.40204942, 0.80985707, 1.16537964],\n", + " [-1.19519651, 0.79087526, 0.48453161, ..., 1.36768281, 0.76330566, 1.13404262]],\n", + "\n", + " [[-0.74483055, -0.33014604, 0.24039182, ..., 0.02945682, 0.71929377, 1.91275668],\n", + " [-0.56035036, -0.41564703, 0.36313012, ..., 0.41183007, 0.90209144, 1.80845654],\n", + " [-0.69359547, -0.13844451, 0.30018413, ..., 0.49444827, 0.56794512, 1.67332709],\n", + " ...,\n", + " [-0.99715513, 1.01512778, 0.43277434, ..., 1.09037900, 0.86760134, 1.29863596],\n", + " [-0.99962872, 1.07428896, 0.44226229, ..., 1.09051895, 0.88753319, 1.33773279],\n", + " [-0.98149830, 1.05249369, 0.51830143, ..., 1.04208529, 0.84298122, 1.31557417]]])\n", + "Tensor(shape=[16, 51, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[[-0.38691020, 1.10840213, -0.36066169, ..., -0.74562210, 0.38264662, 0.09510683],\n", + " [-0.40077940, 1.08683729, -0.28951788, ..., -0.79604387, 0.23650460, 0.21539813],\n", + " [-0.48349607, 1.13704205, -0.34528807, ..., -0.74176753, 0.15859264, 0.08665741],\n", + " ...,\n", + " [-0.33308679, 1.06052911, -0.28531107, ..., -0.56407875, 0.06546993, 0.34067774],\n", + " [-0.31819728, 1.02948821, -0.34244826, ..., -0.59871835, 0.13086139, 0.23477128],\n", + " [-0.46234924, 1.05966771, -0.25739416, ..., -0.73751336, 0.15748897, 0.26660469]],\n", + "\n", + " [[-0.31971461, 1.24201715, -0.42921415, ..., -1.03340065, 0.12717772, -0.02929212],\n", + " [-0.29465508, 1.20464718, -0.42703199, ..., -0.81277102, 0.03500172, 0.06429010],\n", + " [-0.23355499, 1.22438145, -0.37154198, ..., -0.80892444, -0.04463244, 0.14419895],\n", + " ...,\n", + " [-0.30919531, 1.24750948, -0.43951514, ..., -0.78709352, 0.09086802, 0.22021589],\n", + " [-0.33325344, 1.25496054, -0.43700716, ..., -0.88238114, 0.15829682, 0.27076158],\n", + " [-0.32431445, 1.20970893, -0.44767022, ..., -0.85771841, 0.15963244, 0.26043096]],\n", + "\n", + " [[-0.35240874, 1.21549594, -0.53064364, ..., -0.64634734, 0.05578946, -0.11943770],\n", + " [-0.57545149, 1.34280396, -0.46211162, ..., -0.65927368, 0.20014796, -0.09852441],\n", + " [-0.44432947, 1.32504761, -0.42148980, ..., -0.74191439, 0.19582249, -0.07732908],\n", + " ...,\n", + " [-0.44347334, 1.20052171, -0.54884982, ..., -0.68667632, 0.21917908, 0.24867907],\n", + " [-0.45897871, 1.27240980, -0.38485754, ..., -0.65063947, 0.28696120, 0.12868038],\n", + " [-0.39181367, 1.26035547, -0.59301054, ..., -0.83063912, 0.30225450, 0.48679376]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.40601602, 1.20973194, -0.49095234, ..., -0.85096097, 0.20683658, 0.04339755],\n", + " [-0.40158153, 1.20891953, -0.45978341, ..., -0.98469758, 0.15744446, 0.03174295],\n", + " [-0.31740543, 1.22869027, -0.51616722, ..., -1.15985453, 0.11959577, -0.09386670],\n", + " ...,\n", + " [-0.44220674, 1.22656810, -0.59912074, ..., -1.26186705, 0.33093452, 0.08640137],\n", + " [-0.44747517, 1.23198783, -0.61370420, ..., -1.28686309, 0.32481337, 0.06313021],\n", + " [-0.45813197, 1.19577587, -0.57291198, ..., -1.30331659, 0.31380397, 0.09586264]],\n", + "\n", + " [[-0.34510323, 1.13676333, -0.41883209, ..., -0.19890606, -0.03747968, 0.15454675],\n", + " [-0.40049601, 1.07489455, -0.20783955, ..., -0.38220686, -0.01861078, 0.21973050],\n", + " [-0.27853724, 1.00104034, -0.15550351, ..., -0.38109386, -0.04351424, 0.20367554],\n", + " ...,\n", + " [-0.51515359, 1.21439159, -0.54381990, ..., -0.88646334, 0.26562017, 0.44584516],\n", + " [-0.52407527, 1.21481705, -0.54217672, ..., -0.92878431, 0.23799631, 0.44936055],\n", + " [-0.53740996, 1.18220830, -0.50675553, ..., -0.93877101, 0.24513872, 0.46150753]],\n", + "\n", + " [[-0.32741299, 0.97497153, -0.00948974, ..., -0.39587873, 0.03406802, 0.24171287],\n", + " [-0.43713254, 0.97446007, -0.12497631, ..., -0.57407486, 0.05668554, 0.24453926],\n", + " [-0.21812496, 0.95889568, -0.10461410, ..., -0.71747971, -0.03854717, 0.17685428],\n", + " ...,\n", + " [-0.50795484, 1.18626249, -0.55178732, ..., -1.05484831, 0.28090888, 0.26255831],\n", + " [-0.51629281, 1.18509519, -0.54967672, ..., -1.09254313, 0.25126994, 0.26916048],\n", + " [-0.52659053, 1.15537941, -0.52105296, ..., -1.09778607, 0.25193223, 0.28108835]]])\n", + "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [377.24789429])\n", + "[1.]\n", + "[-9.1570435e+00 4.7310561e-02 -1.8856564e-01 ... 8.5132439e-03\n", + " 2.1997439e-02 2.7489617e-02]\n" + ] + } + ], + "source": [ + "from paddle.nn import functional as F\n", + "def ctc_loss(logits,\n", + " labels,\n", + " input_lengths,\n", + " label_lengths,\n", + " blank=0,\n", + " reduction='mean',\n", + " norm_by_times=False):\n", + " loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n", + " input_lengths, label_lengths)\n", + " loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n", + " assert reduction in ['mean', 'sum', 'none']\n", + " if reduction == 'mean':\n", + " loss_out = paddle.mean(loss_out / label_lengths)\n", + " elif reduction == 'sum':\n", + " loss_out = paddle.sum(loss_out)\n", + " return loss_out\n", + "\n", + "F.ctc_loss = ctc_loss\n", + "\n", + "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n", + "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n", + "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n", + "print(loss_ctc)\n", + "# ctc tensor(377.2479, device='cuda:0', grad_fn=)\n", + "loss_ctc.backward()\n", + "print(loss_ctc.grad)\n", + "print(model.ctc.ctc_lo.weight.grad)\n", + "print(model.ctc.ctc_lo.bias.grad)\n", + "# tensor([[ 3.2806e+00, -1.8297e+00, -2.5472e+00, ..., -4.4421e+00,\n", + "# -3.4516e+00, -6.8526e+00],\n", + "# [-1.5462e-02, 8.0163e-03, 1.3837e-02, ..., 2.4541e-02,\n", + "# 1.7295e-02, 3.5211e-02],\n", + "# [ 5.0349e-02, -4.5908e-02, -2.8797e-02, ..., -8.8659e-02,\n", + "# -6.3412e-02, -1.2411e-01],\n", + "# ...,\n", + "# [-2.4901e-03, 1.0179e-03, 2.3745e-03, ..., 4.3330e-03,\n", + "# 3.2267e-03, 6.2963e-03],\n", + "# [-6.0131e-03, 2.5570e-03, 6.0628e-03, ..., 1.1443e-02,\n", + "# 8.4951e-03, 1.6021e-02],\n", + "# [-7.2826e-03, 2.3929e-03, 7.7757e-03, ..., 1.4101e-02,\n", + "# 1.0566e-02, 2.0105e-02]], device='cuda:0')" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "ranking-beads", + "id": "polish-opportunity", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "improved-alabama", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [36.31002045]) 0.0\n" + ] + } + ], + "source": [ + "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n", + " text, text_len)\n", + "print(loss_att, acc_att)\n", + "#tensor(41.8416, device='cuda:0', grad_fn=) 0.0" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "metric-destruction", + "metadata": {}, + "outputs": [], + "source": [ + "# encoder, decoder + att_loss 不对齐" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "turkish-watch", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n", + "torch_decoder_out = data['decoder_out']" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "drawn-crash", "metadata": {}, "outputs": [], "source": [ - "total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,\n", - " self.text, self.text_len)" + "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n", + " ignore_id: int):\n", + " \"\"\"Add and labels.\n", + " Args:\n", + " ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n", + " sos (int): index of \n", + " eos (int): index of \n", + " ignore_id (int): index of padding\n", + " Returns:\n", + " ys_in (paddle.Tensor) : (B, Lmax + 1)\n", + " ys_out (paddle.Tensor) : (B, Lmax + 1)\n", + " Examples:\n", + " >>> sos_id = 10\n", + " >>> eos_id = 11\n", + " >>> ignore_id = -1\n", + " >>> ys_pad\n", + " tensor([[ 1, 2, 3, 4, 5],\n", + " [ 4, 5, 6, -1, -1],\n", + " [ 7, 8, 9, -1, -1]], dtype=paddle.int32)\n", + " >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n", + " >>> ys_in\n", + " tensor([[10, 1, 2, 3, 4, 5],\n", + " [10, 4, 5, 6, 11, 11],\n", + " [10, 7, 8, 9, 11, 11]])\n", + " >>> ys_out\n", + " tensor([[ 1, 2, 3, 4, 5, 11],\n", + " [ 4, 5, 6, 11, -1, -1],\n", + " [ 7, 8, 9, 11, -1, -1]])\n", + " \"\"\"\n", + " # TODO(Hui Zhang): using comment code, \n", + " #_sos = paddle.to_tensor(\n", + " # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", + " #_eos = paddle.to_tensor(\n", + " # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", + " #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", + " #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n", + " #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n", + " #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n", + " B = ys_pad.size(0)\n", + " _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n", + " _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n", + " ys_in = paddle.cat([_sos, ys_pad], dim=1)\n", + " mask_pad = (ys_in == ignore_id)\n", + " ys_in = ys_in.masked_fill(mask_pad, eos)\n", + " \n", + "\n", + " ys_out = paddle.cat([ys_pad, _eos], dim=1)\n", + " ys_out = ys_out.masked_fill(mask_pad, eos)\n", + " mask_eos = (ys_out == ignore_id)\n", + " ys_out = ys_out.masked_fill(mask_eos, eos)\n", + " ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n", + " return ys_in, ys_out" ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "informative-optics", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(shape=[16, 7], dtype=int32, place=CUDAPlace(0), stop_gradient=True,\n", + " [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n", + " [4232, 236 , 1176, 331 , 66 , 3925, 4077],\n", + " [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n", + " [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n", + " [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n", + " [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n", + " [4232, 25 , 1149, 3930, 4232, 4232, 4232],\n", + " [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n", + " [4232, 3703, 2 , 565 , 3827, 4232, 4232],\n", + " [4232, 1150, 2734, 10 , 2478, 3490, 4232],\n", + " [4232, 426 , 811 , 95 , 489 , 144 , 4232],\n", + " [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n", + " [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n", + " [4232, 70 , 1741, 702 , 1666, 4232, 4232],\n", + " [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n", + " [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n", + "Tensor(shape=[16, 7], dtype=int32, place=CUDAPlace(0), stop_gradient=True,\n", + " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", + " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", + " [2693, 524, 234, 1145, 366, 4232, -1 ],\n", + " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", + " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", + " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", + " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", + " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", + " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", + " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", + " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", + " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", + " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", + " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", + " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", + " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" + ] + } + ], + "source": [ + "ys_pad = text\n", + "ys_pad_lens = text_len\n", + "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", + " model.ignore_id)\n", + "ys_in_lens = ys_pad_lens + 1\n", + "print(ys_in_pad)\n", + "print(ys_out_pad)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "northern-advisory", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[16, 7, 4233]\n", + "Tensor(shape=[7, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[-0.37553221, -0.83114165, 0.70238966, ..., 0.30866742, 0.03037567, 0.43291825],\n", + " [-0.87047130, -0.32394654, 0.37882078, ..., 0.34444264, -0.12801090, -0.97179270],\n", + " [-0.43517584, 0.02496703, -0.32672805, ..., 0.04601809, -1.15214014, -0.23627253],\n", + " ...,\n", + " [ 0.42706215, 0.58341736, -0.01791662, ..., 0.34311637, 0.06014483, -0.34610766],\n", + " [-0.37887222, -0.81906295, 0.71680295, ..., 0.22679621, 0.01455487, 0.45493346],\n", + " [-0.38187075, -0.82030386, 0.70901453, ..., 0.22812662, 0.01431785, 0.45638454]])\n", + "False\n" + ] + } + ], + "source": [ + "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", + " ys_in_lens)\n", + "print(decoder_out.shape)\n", + "print(decoder_out[0])\n", + "print(np.allclose(decoder_out.numpy(), torch_decoder_out))" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "prospective-death", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", + " [41.84283447])\n", + "Tensor(shape=[16, 7], dtype=int32, place=CUDAPlace(0), stop_gradient=True,\n", + " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", + " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", + " [2693, 524, 234, 1145, 366, 4232, -1 ],\n", + " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", + " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", + " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", + " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", + " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", + " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", + " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", + " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", + " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", + " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", + " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", + " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", + " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" + ] + } + ], + "source": [ + "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n", + "print(loss_att)\n", + "print(ys_out_pad)\n", + "# tensor(41.8416, device='cuda:0', grad_fn=)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "closed-partner", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "silent-animal", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", + " ys_in_lens)\n", + "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fatal-board", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 83c1ff905fe420772b53bc2368c08f2b9d88196c..53ae26b809bfb045c8a710ed1b14b78dc73abbd1 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -101,7 +101,7 @@ class FeatureNormalizer(object): features.append( featurize_func(AudioSegment.from_file(instance["feat"]))) features = np.hstack(features) #(D, T) - self._mean = np.mean(features, axis=1).reshape([1, -1]) #(1, D) - std = np.std(features, axis=1).reshape([1, -1]) #(1, D) + self._mean = np.mean(features, axis=1) #(D,) + std = np.std(features, axis=1) #(D,) std = np.clip(std, eps, None) self._istd = 1.0 / std diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 6c423167c5558eef9dfb8c95921f549b72ea2ac7..227306447ace3704b6af964ff40f984df5a5ecfb 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -132,7 +132,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, ys_out = paddle.cat([ys_pad, _eos], dim=1) ys_out = ys_out.masked_fill(mask_pad, eos) - mask_eos = (ys_in == ignore_id) + mask_eos = (ys_out == ignore_id) ys_out = ys_out.masked_fill(mask_eos, eos) ys_out = ys_out.masked_fill(mask_pad, ignore_id) return ys_in, ys_out