From 1501a80f74a6e9c129889c3c48bdf86829105e1c Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Sun, 27 Sep 2020 19:28:52 +0800 Subject: [PATCH] add support to float64 input of warpctc op. (#27399) * add float64 input to ctc_loss * modified error message of warpctc * update repo and tag of warpctc * add test for warpctc with float64 input * modified warpctc.cmake to make sure build always * resolved sample code bug of warpctc * add core.ops in warpctc dygraph * fix a bug of test --- cmake/external/warpctc.cmake | 5 +- paddle/fluid/operators/math/sequence_scale.cc | 1 + paddle/fluid/operators/math/sequence_scale.cu | 1 + paddle/fluid/operators/warpctc_op.cc | 10 +- paddle/fluid/operators/warpctc_op.cu.cc | 6 +- paddle/fluid/operators/warpctc_op.h | 84 ++++++++++--- paddle/fluid/platform/dynload/warpctc.h | 4 +- paddle/fluid/pybind/op_function_generator.cc | 1 + python/paddle/fluid/layers/loss.py | 23 +++- .../fluid/tests/unittests/test_warpctc_op.py | 119 ++++++++++++++++-- python/paddle/nn/functional/loss.py | 2 +- python/paddle/nn/layer/loss.py | 2 +- 12 files changed, 221 insertions(+), 37 deletions(-) diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index ac6cf624e82..7f2ab1fb11d 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git) -set(WARPCTC_TAG fc7f226b93758216a03b1be9d24593a12819b984) +set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8) SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" CACHE PATH "Warp-ctc Directory" FORCE) @@ -44,8 +44,9 @@ ExternalProject_Add( "${WARPCTC_DOWNLOAD_CMD}" PREFIX ${WARPCTC_PREFIX_DIR} SOURCE_DIR ${WARPCTC_SOURCE_DIR} - UPDATE_COMMAND "" + #UPDATE_COMMAND "" PATCH_COMMAND "" + BUILD_ALWAYS 1 CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} diff --git a/paddle/fluid/operators/math/sequence_scale.cc b/paddle/fluid/operators/math/sequence_scale.cc index 78cbdf311ad..8e58411a1f2 100644 --- a/paddle/fluid/operators/math/sequence_scale.cc +++ b/paddle/fluid/operators/math/sequence_scale.cc @@ -46,6 +46,7 @@ class ScaleLoDTensorFunctor { }; template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu index 079338c1d3d..4a952afe15f 100644 --- a/paddle/fluid/operators/math/sequence_scale.cu +++ b/paddle/fluid/operators/math/sequence_scale.cu @@ -52,6 +52,7 @@ class ScaleLoDTensorFunctor { }; template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index 5dcbabc96b4..f043b017949 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -103,13 +103,13 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { "Target sequence length for Label when Label is a 2-D tensor.") .AsDispensable(); AddOutput("WarpCTCGrad", - "(Tensor, default: Tensor), a temporary " + "(Tensor), a temporary " "output Tensor to store the gradients of warp-ctc, which is " "computed with loss together in one call. It is a 3-D Tensor of " "the shape [max_sequence_length, batch_size, num_classes + 1].") .AsIntermediate(); AddOutput("Loss", - "(Tensor, default: Tensor), the Connectionist " + "(Tensor), the Connectionist " "Temporal Classification (CTC) loss, which is a 2-D Tensor of " "the shape [batch_size, 1]"); AddAttr("blank", @@ -197,7 +197,9 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp, ops::WarpCTCGradOpNoNeedBufferVarInferer); REGISTER_OP_CPU_KERNEL( - warpctc, ops::WarpCTCKernel); + warpctc, ops::WarpCTCKernel, + ops::WarpCTCKernel); REGISTER_OP_CPU_KERNEL( warpctc_grad, - ops::WarpCTCGradKernel); + ops::WarpCTCGradKernel, + ops::WarpCTCGradKernel); diff --git a/paddle/fluid/operators/warpctc_op.cu.cc b/paddle/fluid/operators/warpctc_op.cu.cc index 6f8559f542f..a42093aaa29 100644 --- a/paddle/fluid/operators/warpctc_op.cu.cc +++ b/paddle/fluid/operators/warpctc_op.cu.cc @@ -16,7 +16,9 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - warpctc, ops::WarpCTCKernel); + warpctc, ops::WarpCTCKernel, + ops::WarpCTCKernel); REGISTER_OP_CUDA_KERNEL( warpctc_grad, - ops::WarpCTCGradKernel); + ops::WarpCTCGradKernel, + ops::WarpCTCGradKernel); diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 951a258fd21..8b9276d4fa0 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -27,7 +27,52 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const T* const activations, T* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, int alphabet_size, + int minibatch, T* costs, void* workspace, + ctcOptions options) { + return CTC_STATUS_EXECUTION_FAILED; + } +}; + +template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const float* const activations, float* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, int alphabet_size, + int minibatch, float* costs, void* workspace, + ctcOptions options) { + return platform::dynload::compute_ctc_loss( + activations, gradients, flat_labels, label_lengths, input_lengths, + static_cast(alphabet_size), static_cast(minibatch), costs, + workspace, options); + } +}; + template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const double* const activations, double* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, int alphabet_size, + int minibatch, double* costs, void* workspace, + ctcOptions options) { + return platform::dynload::compute_ctc_loss_double( + activations, gradients, flat_labels, label_lengths, input_lengths, + static_cast(alphabet_size), static_cast(minibatch), costs, + workspace, options); + } +}; + +template class WarpCTCFunctor { public: /* @@ -51,21 +96,29 @@ class WarpCTCFunctor { * \param blank blank label used in ctc loss function. * \param cpu_losss cost of each sequence in CPU memory. */ - void operator()(const framework::ExecutionContext& ctx, const float* input, - float* gradient, const int* cpu_labels, + void operator()(const framework::ExecutionContext& ctx, const T* input, + T* gradient, const int* cpu_labels, const int* cpu_label_lengths, const int* cpu_input_lengths, const size_t sequence_width, const size_t num_sequences, - const size_t blank, float* cpu_loss) { + const size_t blank, T* cpu_loss) { // Init warp-ctc options init(ctx, blank); // Compute the required workspace size. // There is no memory allocated operations within warp-ctc. size_t workspace_bytes = 0; - ctcStatus_t status = platform::dynload::get_workspace_size( - cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), - static_cast(num_sequences), options_, &workspace_bytes); - + ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR; + if (sizeof(T) == 4) { + status = platform::dynload::get_workspace_size( + cpu_label_lengths, cpu_input_lengths, + static_cast(sequence_width), static_cast(num_sequences), + options_, &workspace_bytes); + } else { + status = platform::dynload::get_workspace_size_double( + cpu_label_lengths, cpu_input_lengths, + static_cast(sequence_width), static_cast(num_sequences), + options_, &workspace_bytes); + } PADDLE_ENFORCE_EQ( CTC_STATUS_SUCCESS, status, platform::errors::PreconditionNotMet( @@ -79,17 +132,17 @@ class WarpCTCFunctor { workspace_bytes)); auto& dev_ctx = ctx.template device_context(); - size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; - Tensor workspace = ctx.AllocateTmpTensor( + size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL; + Tensor workspace = ctx.AllocateTmpTensor( framework::make_ddim({static_cast(workspace_elements)}), dev_ctx); - float* workspace_data = workspace.data(); - math::SetConstant()( + T* workspace_data = workspace.data(); + math::SetConstant()( ctx.template device_context(), &workspace, - static_cast(0)); + static_cast(0)); // compute loss and gradient - status = platform::dynload::compute_ctc_loss( + status = ComputeCtcLossFunctor()( input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), static_cast(num_sequences), cpu_loss, workspace_data, options_); @@ -112,7 +165,8 @@ class WarpCTCFunctor { ctx.device_context()) .stream(); #else - PADDLE_THROW("[warpctc init] GPU is not enabled."); + PADDLE_THROW(platform::errors::PreconditionNotMet( + "[warpctc init] GPU is not enabled.")); #endif } else { options_.loc = CTC_CPU; @@ -292,7 +346,7 @@ class WarpCTCKernel : public framework::OpKernel { const size_t blank = static_cast(ctx.Attr("blank")); - WarpCTCFunctor()( + WarpCTCFunctor()( ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data, warpctc_label_lengths.data(), warpctc_logits_lengths.data(), sequence_width, num_sequences, blank, warpctc_loss_data); diff --git a/paddle/fluid/platform/dynload/warpctc.h b/paddle/fluid/platform/dynload/warpctc.h index e10a7233b62..5f1b7612117 100644 --- a/paddle/fluid/platform/dynload/warpctc.h +++ b/paddle/fluid/platform/dynload/warpctc.h @@ -53,7 +53,9 @@ extern void* warpctc_dso_handle; __macro(get_warpctc_version); \ __macro(ctcGetStatusString); \ __macro(compute_ctc_loss); \ - __macro(get_workspace_size) + __macro(compute_ctc_loss_double); \ + __macro(get_workspace_size); \ + __macro(get_workspace_size_double) WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP); diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index d3052ebd351..9bc603c0ecc 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -48,6 +48,7 @@ std::map> op_ins_map = { {"collect_fpn_proposals", {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}}, + {"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index f468815c99e..037c7e85004 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -541,7 +541,7 @@ def warpctc(input, (not including the blank label). When it is a 3-D Tensor, its shape is `[max_logit_length, batch_size, num_classes + 1]`, where `max_logit_length` is the longest length of - input logit sequence. The data type must be float32. + input logit sequence. The data type should be float32 or float64. label (Variable): The ground truth of variable-length sequence, which must be a 2-D Tensor with LoD information or a 3-D Tensor without LoD information, needs to be consistent with the coressponding input. @@ -571,6 +571,7 @@ def warpctc(input, .. code-block:: python # using LoDTensor + import paddle import paddle.fluid as fluid import numpy as np @@ -581,6 +582,7 @@ def warpctc(input, # class num class_num = 5 + paddle.enable_static() logits = fluid.data(name='logits',shape=[None, class_num+1], dtype='float32',lod_level=1) label = fluid.data(name='label', shape=[None, 1], @@ -602,6 +604,7 @@ def warpctc(input, .. code-block:: python # using Tensor + import paddle import paddle.fluid as fluid import numpy as np @@ -613,6 +616,7 @@ def warpctc(input, batch_size = 16 # class num class_num = 5 + paddle.enable_static() logits = fluid.data(name='logits', shape=[max_seq_length, batch_size, class_num+1], dtype='float32') @@ -637,8 +641,23 @@ def warpctc(input, fetch_list=[cost.name]) print(output) """ + if in_dygraph_mode(): + if input_length is None or label_length is None: + raise ValueError( + "input_length and label_length must not be None in dygraph mode!" + ) + grad, loss_out = core.ops.warpctc( + input, + label, + input_length, + label_length, + 'blank', + blank, + 'norm_by_times', + norm_by_times, ) + return loss_out helper = LayerHelper('warpctc', **locals()) - check_variable_and_dtype(input, 'input', ['float32'], "warpctc") + check_variable_and_dtype(input, 'input', ['float32', 'float64'], "warpctc") check_variable_and_dtype(label, 'label', ['int32'], "warpctc") this_inputs = {'Logits': [input], 'Label': [label]} if input_length is not None and label_length is not None: diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py index c4155e0d826..b82ab04c986 100644 --- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py +++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py @@ -24,7 +24,7 @@ from paddle.fluid import Program, program_guard import paddle import paddle.nn.functional as F -CUDA_BLOCK_SIZE = 512 +CUDA_BLOCK_SIZE = 32 class CTCForward(object): @@ -41,8 +41,8 @@ class CTCForward(object): self.num_classes = num_classes self.batch_size = batch_size - self.loss = np.zeros([self.batch_size, 1], dtype="float32") - self.gradient = np.zeros(self.softmax.shape, dtype="float32") + self.loss = np.zeros([self.batch_size, 1], dtype=softmax.dtype) + self.gradient = np.zeros(self.softmax.shape, dtype=softmax.dtype) # float64 self.EXP_MAX = sys.float_info.max @@ -112,13 +112,15 @@ class CTCForward(object): # calculate the forward and backward variables, # reference Chapter 7.3 of "Alex Grave, Supervised Sequence # Labelling with Recurrent Neural Networks" - log_acts = np.zeros([total_times, self.num_classes], dtype="float32") + log_acts = np.zeros( + [total_times, self.num_classes], dtype=softmax_a_sequence.dtype) for i in range(total_times): for j in range(self.num_classes): log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j]) # calculate the forward variables - forward_vars = np.zeros([total_times, total_segments], dtype="float32") + forward_vars = np.zeros( + [total_times, total_segments], dtype=softmax_a_sequence.dtype) for i in range(total_times): for j in range(total_segments): forward_vars[i, j] = self.LOG_ZERO @@ -219,7 +221,7 @@ class TestWarpCTCOp(OpTest): self.logits_lod[0][i]) self.gradient = np.zeros( [max_sequence_length, self.batch_size, self.num_classes], - dtype="float32") + dtype=logits.dtype) self.inputs = { "Logits": (logits, self.logits_lod), @@ -287,7 +289,7 @@ class TestWarpCTCOpWithPadding(OpTest): # reshape logits to T*N*S new_logits = np.zeros( [max_sequence_length, self.batch_size, self.num_classes], - dtype="float32") + dtype=logits.dtype) cur = 0 for batch_id in range(self.batch_size): @@ -312,7 +314,7 @@ class TestWarpCTCOpWithPadding(OpTest): self.gradient = np.zeros( [max_sequence_length, self.batch_size, self.num_classes], - dtype="float32") + dtype=logits.dtype) self.inputs = { "Logits": new_logits, @@ -347,6 +349,90 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding): self.norm_by_times = False +class TestWarpCTCOpFp64(OpTest): + def config(self): + self.batch_size = 4 + self.num_classes = 8 + self.logits_lod = [[4, 1, 5, 5]] + self.labels_lod = [[3, 1, 4, 2]] + self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64) + self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64) + self.blank = self.num_classes - 1 + self.norm_by_times = False + + def setUp(self): + self.op_type = "warpctc" + self.config() + + logits = np.random.uniform( + 0.1, 1.0, + [sum(self.logits_length), self.num_classes]).astype("float64") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + # labels should not be blank + labels = np.random.randint( + 0, + self.num_classes - 1, [sum(self.labels_length), 1], + dtype="int32") + + ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod, + self.num_classes, self.batch_size, self.blank, + self.norm_by_times) + loss = ctc.forward() + + max_sequence_length = 0 + for i in range(self.batch_size): + max_sequence_length = max(max_sequence_length, + self.logits_length[i]) + # reshape logits to T*N*S + new_logits = np.zeros( + [max_sequence_length, self.batch_size, self.num_classes], + dtype=logits.dtype) + + cur = 0 + for batch_id in range(self.batch_size): + for i in range(self.logits_length[batch_id]): + for j in range(self.num_classes): + new_logits[i, batch_id, j] = logits[cur + i, j] + cur = cur + self.logits_length[batch_id] + + # reshape labels to N*S + max_target_seq_length = 0 + for i in range(self.batch_size): + max_target_seq_length = max(max_target_seq_length, + self.labels_length[i]) + new_labels = np.zeros( + [self.batch_size, max_target_seq_length], dtype="int32") + + cur = 0 + for batch_id in range(self.batch_size): + for i in range(self.labels_length[batch_id]): + new_labels[batch_id, i] = labels[cur + i] + cur = cur + self.labels_length[batch_id] + + self.gradient = np.zeros( + [max_sequence_length, self.batch_size, self.num_classes], + dtype=logits.dtype) + + self.inputs = { + "Logits": new_logits, + "Label": new_labels, + "LogitsLength": self.logits_length, + "LabelLength": self.labels_length + } + self.outputs = {"Loss": loss} + self.attrs = { + "blank": self.blank, + "norm_by_times": self.norm_by_times, + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.outputs['WarpCTCGrad'] = self.gradient + self.check_grad(["Logits"], "Loss") + + class TestWarpCTCOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): @@ -359,7 +445,7 @@ class TestWarpCTCOpError(unittest.TestCase): name='labels_length', shape=[None], dtype='int64') def test_logits_Variable(): - logits_data = np.random.rand(5, 16, 6).astype("float32") + logits_data = np.random.rand(5, 16, 6).astype(logits.dtype) fluid.layers.warpctc( input=logits_data, label=label, @@ -398,6 +484,21 @@ class TestWarpCTCOpError(unittest.TestCase): self.assertRaises(TypeError, test_label_len_Variable) + def test_dygraph_errors(self): + def test_dygraph_with_lod(): + + logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32") + # labels should not be blank + labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32") + softmax = paddle.to_variable(logits) + labels = paddle.to_variable(labels) + + fluid.layers.warpctc(input=softmax, label=labels) + + paddle.disable_static() + self.assertRaises(ValueError, test_dygraph_with_lod) + paddle.enable_static() + class TestCTCLossAPICase(unittest.TestCase): def test_functinal_api(self): diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 4395520eec7..d27bac14d0a 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -933,7 +933,7 @@ def ctc_loss(log_probs, is interated to the Warp-CTC library to normalize values for each row of the input tensor. Parameters: - log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32. + log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64. labels (Tensor): The ground truth sequence with padding, which must be a 3-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32. input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64. label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64. diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 271dc9b4e68..98048bb7e64 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -773,7 +773,7 @@ class CTCLoss(fluid.dygraph.Layer): reduction (string, optional): Indicate how to average the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``. Shape: - log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32. + log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64. labels (Tensor): The ground truth sequence with padding, which must be a 3-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32. input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64. label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64. -- GitLab