未验证 提交 5dc069d0 编写于 作者: L Li Fuchen 提交者: GitHub

OP(warpctc, add_position_encoding, scaled_dot_product_attention) error message enhancement (#24261)

* enhance add_position_encoding error message, test=develop

* enhance warpctc & scaled_dot_product_attention error message, test=develop

* modified error message and ctest of scaled_dot_product_attention, test=develop
上级 19511dfa
...@@ -23,11 +23,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { ...@@ -23,11 +23,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AddPositionEncoding");
"X(Input) of add_position_encoding_op should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE( "AddPositionEncoding");
ctx->HasOutput("Out"),
"Out(Output) of add_position_encoding_op should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
......
...@@ -39,25 +39,40 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> { ...@@ -39,25 +39,40 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> {
int enc_size = 0; int enc_size = 0;
if (x_lod.empty()) { if (x_lod.empty()) {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(x_dim.size(), 3,
x_dim.size() == 3UL, platform::errors::InvalidArgument(
"The input X of Add Position Encoding should be 3-D Tensor!"); "The input(X)'s dimension of AddPositionEncodingOp "
"should be equal to "
"3, but received %d. ",
x_dim.size()));
batch_size = x_dim[0]; batch_size = x_dim[0];
max_seq_len = x_dim[1]; max_seq_len = x_dim[1];
enc_size = x_dim[2]; enc_size = x_dim[2];
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(x_dim.size(), 2,
x_dim.size() == 2UL, platform::errors::InvalidArgument(
"The input X of Add Position Encoding should be 2-D LoDTensor!"); "The input(X)'s dimension of AddPositionEncodingOp "
PADDLE_ENFORCE( "should be equal to "
x_lod.size() == 1UL, "2, but received %d. ",
"The Add Position Encoding Op only supports lod_level == 1!"); x_dim.size()));
PADDLE_ENFORCE_EQ(x_lod.size(), 1,
platform::errors::InvalidArgument(
"The input(X)'s lod level of AddPositionEncodingOp "
"should be equal to "
"1, but received %d. ",
x_lod.size()));
batch_size = x_lod[0].size() - 1; batch_size = x_lod[0].size() - 1;
max_seq_len = -1; max_seq_len = -1;
enc_size = x_dim[1]; enc_size = x_dim[1];
} }
PADDLE_ENFORCE(enc_size % 2 == 0, "Only support even encode size!"); PADDLE_ENFORCE_EQ(enc_size % 2, 0,
platform::errors::InvalidArgument(
"The input(X)'s feature size of "
"AddPositionEncodingOp only support even, "
"but received an odd number: %d. ",
enc_size));
const int half_size = enc_size / 2; const int half_size = enc_size / 2;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
...@@ -28,14 +28,11 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -28,14 +28,11 @@ class WarpCTCOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"), OP_INOUT_CHECK(ctx->HasInput("Logits"), "Input", "Logits", "WarpCTC");
"Input(Logits) of WarpCTCOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "WarpCTC");
PADDLE_ENFORCE(ctx->HasInput("Label"), OP_INOUT_CHECK(ctx->HasOutput("WarpCTCGrad"), "Output", "WarpCTCGrad",
"Input(Label) of WarpCTCOp should not be null."); "WarpCTC");
PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"), OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "WarpCTC");
"Output(WarpCTCGrad) of WarpCTCOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of WarpCTCOp should not be null.");
auto logits_dims = ctx->GetInputDim("Logits"); auto logits_dims = ctx->GetInputDim("Logits");
int blank = ctx->Attrs().Get<int>("blank"); int blank = ctx->Attrs().Get<int>("blank");
...@@ -47,9 +44,18 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -47,9 +44,18 @@ class WarpCTCOp : public framework::OperatorWithKernel {
sequence_width = sequence_width =
static_cast<int>(framework::product(logits_dims) / logits_dims[0]); static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
} }
PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
"The value of Attr(blank) should be in interval [0, %d).", PADDLE_ENFORCE_GE(
sequence_width); blank, 0, platform::errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
blank));
PADDLE_ENFORCE_LT(
blank, sequence_width,
platform::errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
blank));
// TODO(liuyiqun): it is tricky to set the wrong dimension here. // TODO(liuyiqun): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Loss", {-1, 1}); ctx->SetOutputDim("Loss", {-1, 1});
...@@ -160,10 +166,10 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -160,10 +166,10 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"), OP_INOUT_CHECK(ctx->HasInput("WarpCTCGrad"), "Input", "WarpCTCGrad",
"Input(WarpCTCGrad) of WarpCTCGradOp should not be null."); "WarpCTCGrad");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Logits")), "Output",
"Output(Logits@GRAD) of WarpCTCGradOp should not be null."); "WarpCTCGrad", "WarpCTCGrad");
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Logits")); ctx->GetInputDim("Logits"));
ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits")); ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
......
...@@ -65,13 +65,18 @@ class WarpCTCFunctor { ...@@ -65,13 +65,18 @@ class WarpCTCFunctor {
ctcStatus_t status = platform::dynload::get_workspace_size( ctcStatus_t status = platform::dynload::get_workspace_size(
cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width), cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width),
static_cast<int>(num_sequences), options_, &workspace_bytes); static_cast<int>(num_sequences), options_, &workspace_bytes);
PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
"warp-ctc [version %d] Error in get_workspace_size: ", PADDLE_ENFORCE_EQ(
warpctc_version_, CTC_STATUS_SUCCESS, status,
platform::dynload::ctcGetStatusString(status)); platform::errors::PreconditionNotMet(
PADDLE_ENFORCE_GT(workspace_bytes, 0UL, "warp-ctc [version %d] Error in get_workspace_size: %s",
"Bytes of workspace got by warp-ctc function, " warpctc_version_, platform::dynload::ctcGetStatusString(status)));
"get_workspace_size(), should be larger than 0."); PADDLE_ENFORCE_GT(
workspace_bytes, 0UL,
platform::errors::InvalidArgument(
"Bytes of workspace got by warp-ctc function, "
"get_workspace_size() should be larger than 0, but received %d",
workspace_bytes));
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
...@@ -88,10 +93,12 @@ class WarpCTCFunctor { ...@@ -88,10 +93,12 @@ class WarpCTCFunctor {
input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences), static_cast<int>(sequence_width), static_cast<int>(num_sequences),
cpu_loss, workspace_data, options_); cpu_loss, workspace_data, options_);
PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
"warp-ctc [version %d] Error in compute_ctc_loss: ", PADDLE_ENFORCE_EQ(
warpctc_version_, CTC_STATUS_SUCCESS, status,
platform::dynload::ctcGetStatusString(status)); platform::errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
warpctc_version_, platform::dynload::ctcGetStatusString(status)));
} }
protected: protected:
...@@ -156,23 +163,40 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -156,23 +163,40 @@ class WarpCTCKernel : public framework::OpKernel<T> {
labels_length_cpu.data<int64_t>()[i]); labels_length_cpu.data<int64_t>()[i]);
} }
} else { } else {
PADDLE_ENFORCE_GT(logits->NumLevels(), 0UL,
platform::errors::InvalidArgument(
"Input(Logits) Tensor of WarpCTC "
"does not contain LoD information."));
PADDLE_ENFORCE_GT(label->NumLevels(), 0UL,
platform::errors::InvalidArgument(
"Input(Label) Tensor of WarpCTC "
"does not contain LoD information."));
logits_lod = framework::ToAbsOffset(logits->lod())[0]; logits_lod = framework::ToAbsOffset(logits->lod())[0];
auto logits_dims = logits->dims(); auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
logits_dims[0], static_cast<int64_t>(logits_lod.back()), logits_dims[0], static_cast<int64_t>(logits_lod.back()),
"The first dimension of Input(Logits) should be equal to " platform::errors::InvalidArgument(
"the sum of all sequences' lengths."); "The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths = %d., but received %d. ",
static_cast<int64_t>(logits_lod.back()), logits_dims[0]));
label_lod = framework::ToAbsOffset(label->lod())[0]; label_lod = framework::ToAbsOffset(label->lod())[0];
auto label_dims = label->dims(); auto label_dims = label->dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(label_dims[1], 1,
label_dims[0], label->numel(), platform::errors::InvalidArgument(
"The width of each timestep in Input(Label) should be 1."); "The last dimension of Input(Label) should be 1, "
"but received %d",
label_dims[1]));
num_sequences = logits_lod.size() - 1; num_sequences = logits_lod.size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod.size() - 1, PADDLE_ENFORCE_EQ(
"The number of sequences of Input(Logits) should be " num_sequences, label_lod.size() - 1,
"equal to that of Input(Label)."); platform::errors::InvalidArgument(
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label) = %d, but received %d",
label_lod.size() - 1, num_sequences));
sequence_width = logits->numel() / logits_dims[0]; sequence_width = logits->numel() / logits_dims[0];
max_sequence_length = math::MaximumSequenceLength(logits_lod); max_sequence_length = math::MaximumSequenceLength(logits_lod);
......
...@@ -616,8 +616,14 @@ def warpctc(input, ...@@ -616,8 +616,14 @@ def warpctc(input,
print(output) print(output)
""" """
helper = LayerHelper('warpctc', **locals()) helper = LayerHelper('warpctc', **locals())
check_variable_and_dtype(input, 'input', ['float32'], "warpctc")
check_variable_and_dtype(label, 'label', ['int32'], "warpctc")
this_inputs = {'Logits': [input], 'Label': [label]} this_inputs = {'Logits': [input], 'Label': [label]}
if input_length is not None and label_length is not None: if input_length is not None and label_length is not None:
check_variable_and_dtype(input_length, 'LogitsLength', ['int64'],
"warpctc")
check_variable_and_dtype(label_length, 'LabelLength', ['int64'],
"warpctc")
this_inputs['LogitsLength'] = [input_length] this_inputs['LogitsLength'] = [input_length]
this_inputs['LabelLength'] = [label_length] this_inputs['LabelLength'] = [label_length]
......
...@@ -12463,6 +12463,8 @@ def add_position_encoding(input, alpha, beta, name=None): ...@@ -12463,6 +12463,8 @@ def add_position_encoding(input, alpha, beta, name=None):
""" """
helper = LayerHelper('add_position_encoding', **locals()) helper = LayerHelper('add_position_encoding', **locals())
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
"add_position_encoding")
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import six import six
from . import layers from . import layers
from .data_feeder import check_variable_and_dtype, convert_dtype
__all__ = [ __all__ = [
"simple_img_conv_pool", "simple_img_conv_pool",
...@@ -410,9 +411,10 @@ def scaled_dot_product_attention(queries, ...@@ -410,9 +411,10 @@ def scaled_dot_product_attention(queries,
Multi-Head Attention. Multi-Head Attention.
Raises: Raises:
TypeError: The dtype of inputs keys, values and queries should be the same.
ValueError: Inputs queries, keys and values should all be 3-D tensors. ValueError: Inputs queries, keys and values should all be 3-D tensors.
ValueError: The hidden size of queries and keys should be the same. ValueError: The hidden size of queries and keys should be the same.
ValueError: The max sequence length in query batch and in key batch should be the same. ValueError: The max sequence length in value batch and in key batch should be the same.
ValueError: he hidden size of keys must be divisible by the number of attention heads. ValueError: he hidden size of keys must be divisible by the number of attention heads.
ValueError: he hidden size of values must be divisible by the number of attention heads. ValueError: he hidden size of values must be divisible by the number of attention heads.
...@@ -427,17 +429,38 @@ def scaled_dot_product_attention(queries, ...@@ -427,17 +429,38 @@ def scaled_dot_product_attention(queries,
contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values) contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values)
contexts.shape # [3, 5, 10] contexts.shape # [3, 5, 10]
""" """
check_variable_and_dtype(queries, 'queries', ['float32', 'float64'],
"scaled_dot_product_attention")
check_variable_and_dtype(keys, 'keys', ['float32', 'float64'],
"scaled_dot_product_attention")
check_variable_and_dtype(values, 'values', ['float32', 'float64'],
"scaled_dot_product_attention")
if not (queries.dtype == keys.dtype == values.dtype):
raise TypeError(
"The dtype of keys, values and queries should be the same."
"But received queries.dtype = %s, "
" keys.dtype = %s, values.dtype) = %s." %
(convert_dtype(queries.dtype), convert_dtype(keys.dtype),
convert_dtype(values.dtype)))
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError( raise ValueError(
"Inputs queries, keys and values should all be 3-D tensors.") "Inputs queries, keys and values should all be 3-D tensors."
"But received len(queries.shape) = %d, "
"len(keys.shape) = %d, len(values.shape) = %d." %
(len(queries.shape), len(keys.shape), len(values.shape)))
if queries.shape[-1] != keys.shape[-1]: if queries.shape[-1] != keys.shape[-1]:
raise ValueError( raise ValueError(
"The hidden size of queries and keys should be the same.") "The hidden size of queries and keys should be the same."
"But received queries' hidden size = %d and keys' hidden size = %d."
% (queries.shape[-1], keys.shape[-1]))
if keys.shape[-2] != values.shape[-2]: if keys.shape[-2] != values.shape[-2]:
raise ValueError( raise ValueError(
"The max sequence length in query batch and in key batch " "The max sequence length in value batch and in key batch "
"should be the same.") "should be the same. But received max sequence length in value batch "
"= %d, in key batch = %d." % (values.shape[-2], keys.shape[-2]))
if keys.shape[-1] % num_heads != 0: if keys.shape[-1] % num_heads != 0:
raise ValueError("The hidden size of keys (%d) must be divisible " raise ValueError("The hidden size of keys (%d) must be divisible "
"by the number of attention heads (%d)." % "by the number of attention heads (%d)." %
......
...@@ -16,6 +16,8 @@ import numpy as np ...@@ -16,6 +16,8 @@ import numpy as np
import math import math
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestAddPositionEncodingTensorOp(OpTest): class TestAddPositionEncodingTensorOp(OpTest):
...@@ -130,5 +132,18 @@ class TestAddPositionEncodingLoDTensorOp(OpTest): ...@@ -130,5 +132,18 @@ class TestAddPositionEncodingLoDTensorOp(OpTest):
start += max_length start += max_length
class TestAddPositionEncodingOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.random((4, 16, 8)).astype("float32")
def test_Variable():
# the input type must be Variable
fluid.layers.add_position_encoding(
input=input_data, alpha=1.0, beta=1.0)
self.assertRaises(TypeError, test_Variable)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestScaledDotProductAttentionError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
queries = fluid.data(
name="queries", shape=[3, 5, 9], dtype="float32")
keys = fluid.data(name="keys", shape=[3, 6, 9], dtype="float32")
values = fluid.data(
name="values", shape=[3, 6, 10], dtype="float32")
def test_queries_Variable():
queries_data = np.random.rand(3, 5, 9).astype("float32")
fluid.nets.scaled_dot_product_attention(queries_data, keys,
values)
self.assertRaises(TypeError, test_queries_Variable)
def test_keys_Variable():
keys_data = np.random.rand(3, 6, 9).astype("float32")
fluid.nets.scaled_dot_product_attention(queries, keys_data,
values)
self.assertRaises(TypeError, test_keys_Variable)
def test_values_Variable():
values_data = np.random.rand(3, 6, 10).astype("float32")
fluid.nets.scaled_dot_product_attention(queries, keys,
values_data)
self.assertRaises(TypeError, test_values_Variable)
def test_diff_dtype():
keys_error = fluid.data(
name="keys_error", shape=[3, 6, 9], dtype="float64")
values_error = fluid.data(
name="values_error", shape=[3, 6, 10], dtype="float64")
fluid.nets.scaled_dot_product_attention(queries, keys_error,
values_error)
self.assertRaises(TypeError, test_diff_dtype)
def test_diff_dim():
keys_error_dim = fluid.data(
name="keys_error_dim", shape=[3, 6], dtype="float32")
values_error_dim = fluid.data(
name="values_error_dim", shape=[3], dtype="float32")
fluid.nets.scaled_dot_product_attention(queries, keys_error_dim,
values_error_dim)
self.assertRaises(ValueError, test_diff_dim)
def test_diff_hidden_size():
queries_error_hs = fluid.data(
name="queries_error_hs", shape=[3, 5, 9], dtype="float32")
keys_error_hs = fluid.data(
name="keys_error_hs", shape=[3, 6, 10], dtype="float32")
fluid.nets.scaled_dot_product_attention(queries_error_hs,
keys_error_hs, values)
self.assertRaises(ValueError, test_diff_hidden_size)
def test_diff_max_len():
keys_error_len = fluid.data(
name="keys_error_len", shape=[3, 7, 9], dtype="float32")
values_error_len = fluid.data(
name="values_error_len", shape=[3, 6, 10], dtype="float32")
fluid.nets.scaled_dot_product_attention(queries, keys_error_len,
values_error_len)
self.assertRaises(ValueError, test_diff_max_len)
if __name__ == "__main__":
unittest.main()
...@@ -19,6 +19,8 @@ import unittest ...@@ -19,6 +19,8 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
CUDA_BLOCK_SIZE = 512 CUDA_BLOCK_SIZE = 512
...@@ -335,5 +337,57 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding): ...@@ -335,5 +337,57 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
self.norm_by_times = False self.norm_by_times = False
class TestWarpCTCOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
logits = fluid.data(
name='logits', shape=[5, 16, 6], dtype='float32')
logits_length = fluid.data(
name='logits_length', shape=[None], dtype='int64')
label = fluid.data(name='label', shape=[16, 3], dtype='int32')
label_length = fluid.data(
name='labels_length', shape=[None], dtype='int64')
def test_logits_Variable():
logits_data = np.random.rand(5, 16, 6).astype("float32")
fluid.layers.warpctc(
input=logits_data,
label=label,
input_length=logits_length,
label_length=label_length)
self.assertRaises(TypeError, test_logits_Variable)
def test_label_Variable():
label_data = np.random.randint(0, 5, [5, 1]).astype("int32")
fluid.layers.warpctc(
input=logits,
label=label_data,
input_length=logits_length,
label_length=label_length)
self.assertRaises(TypeError, test_label_Variable)
def test_logits_len_Variable():
logits_length_data = np.array([5] * 16).astype("int64")
fluid.layers.warpctc(
input=logits,
label=label,
input_length=logits_length_data,
label_length=label_length)
self.assertRaises(TypeError, test_logits_len_Variable)
def test_label_len_Variable():
label_length_data = np.array([3] * 16).astype("int64")
fluid.layers.warpctc(
input=logits,
label=label,
input_length=logits_length,
label_length=label_length_data)
self.assertRaises(TypeError, test_label_len_Variable)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册