Unverified commit f9065e15, authored by tianshuo78520a, committed by GitHub

del sequence_enumerate_op (#54177)

* del sequence_enumerate_op

* del analyzer_pyramid_dnn_tester

* fix
Parent ecda253a
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h"
namespace paddle {
namespace operators {
class SequenceEnumerateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceEnumerate");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceEnumerate");
const auto x_dims = ctx->GetInputDim("X");
const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size});
ctx->ShareLoD("X", "Out");
}
};
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(2-D phi::DenseTensor with the 2nd dimension equal to 1) "
"Input phi::DenseTensor of SequenceEnumerate operator.");
AddOutput("Out",
"(2-D phi::DenseTensor with the 2nd dimension equal to win_size) "
"Output phi::DenseTensor of SequenceEnumerate operator.");
AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
.AddCustomChecker([](const int& win_size) {
PADDLE_ENFORCE_GE(win_size,
2,
platform::errors::InvalidArgument(
"The window size should be not less than 2."
"Received window size is %d",
win_size));
});
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
AddComment(R"DOC(
Sequence Enumerate Operator.
Generate a new sequence for the input index sequence, enumerating all
sub-sequences of length `win_size` in the input.
The enumerated sequence has the same 1st dimension as the input, and its
2nd dimension is `win_size`, padded with `pad_value` where necessary.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate,
ops::SequenceEnumerateOp,
ops::SequenceEnumerateOpMaker);
PD_REGISTER_STRUCT_KERNEL(sequence_enumerate,
CPU,
ALL_LAYOUT,
ops::SequenceEnumerateKernel,
int32_t,
int64_t) {}
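
The DOC string above walks through Case 1 by hand. As a rough cross-check, a minimal NumPy sketch of the same enumeration follows; the helper name enumerate_rows, the offset-based lod_offsets argument, and the sample data are assumptions made for illustration, not part of the operator's code.

import numpy as np

def enumerate_rows(x, lod_offsets, win_size, pad_value):
    # x: (N, 1) integer array; lod_offsets: offset-based LoD, e.g. [0, 3, 5].
    out = np.full((x.shape[0], win_size), pad_value, dtype=x.dtype)
    for start, end in zip(lod_offsets[:-1], lod_offsets[1:]):
        for row in range(start, end):
            take = min(win_size, end - row)  # clip the window at the segment end
            out[row, :take] = x[row:row + take, 0]
    return out

print(enumerate_rows(np.array([[1], [2], [3], [4], [5]]), [0, 3, 5], 2, 0))
# -> [[1 2], [2 3], [3 0], [4 5], [5 0]], matching Case 1 above.
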
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void CalcOutPut(const T* in_data,
const size_t* in_lod,
const size_t lod_len,
const int64_t win_size,
const int64_t pad_value,
T* out_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < in_lod[lod_len - 1]) {
int end_idx = 0;
// Get LoD interval of index
for (int i = 1; i < lod_len; ++i) {
if (index < in_lod[i]) {
end_idx = in_lod[i];
break;
}
}
for (size_t i = 0; i < win_size; ++i) {
int word_pos = index + i;
out_data[index * win_size + i] =
word_pos < end_idx ? in_data[word_pos] : pad_value;
}
}
}
template <typename T, typename DeviceContext>
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
int win_size = context.Attr<int>("win_size");
int pad_value = context.Attr<int>("pad_value");
auto in_dims = in->dims();
auto in_lod = in->lod();
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]),
in_lod[0].back(),
platform::errors::InvalidArgument(
"The actual input data's size mismatched with LoD information."
"Received input data size is %d (actual) vs %d (loD information).",
static_cast<uint64_t>(in_dims[0]),
in_lod[0].back()));
/* Generate enumerate sequence set */
auto stream = context.cuda_device_context().stream();
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
phi::MixVector<size_t> mixv_lod0(&lod0);
const size_t* dev_in_lod_ptr = mixv_lod0.CUDAData(context.GetPlace());
// Calc output tensor
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
out->set_lod(in->lod());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(sequence_enumerate,
GPU,
ALL_LAYOUT,
ops::SequenceEnumerateOpCUDAKernel,
int32_t,
int64_t) {}
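
Each CUDA thread in CalcOutPut above handles one input row: it finds the end offset of the LoD segment containing that row and writes win_size outputs, padding once it passes the segment end. A hedged CPU re-expression of that per-thread work (calc_one_row and the launch note are illustrative, not Paddle code):

def calc_one_row(index, in_data, lod, win_size, pad_value, out_data):
    # Mirrors one thread of CalcOutPut; `lod` is the offset-based LoD
    # (e.g. [0, 3, 5]) that the kernel reads from device memory.
    if index >= lod[-1]:
        return
    end_idx = 0
    for offset in lod[1:]:  # find the end of the segment containing `index`
        if index < offset:
            end_idx = offset
            break
    for i in range(win_size):
        word_pos = index + i
        out_data[index * win_size + i] = (
            in_data[word_pos] if word_pos < end_idx else pad_value
        )

# The launch above uses one thread per input row:
#   blocks = (in_len - 1) // PADDLE_CUDA_NUM_THREADS + 1,
#   each block has PADDLE_CUDA_NUM_THREADS threads.
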
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T, typename DeviceContext>
class SequenceEnumerateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
int win_size = context.Attr<int>("win_size");
auto pad_value = static_cast<T>(context.Attr<int>("pad_value"));
PADDLE_ENFORCE_EQ(
in->lod().empty(),
false,
platform::errors::InvalidArgument(
"Input(X) phi::DenseTensor of SequenceEnumerateOp does not contain "
"LoD information."));
auto in_dims = phi::vectorize<int>(in->dims());
auto lod0 = in->lod()[0];
PADDLE_ENFORCE_EQ(
static_cast<uint64_t>(in_dims[0]),
lod0.back(),
platform::errors::InvalidArgument(
"The actual input data's size mismatched with LoD information."
"Received input data size is %d (actual) vs %d (loD information).",
static_cast<uint64_t>(in_dims[0]),
lod0.back()));
PADDLE_ENFORCE_EQ(
in_dims.size(),
2UL,
platform::errors::InvalidArgument(
"Input(X) of SequenceEnumerate operator's rank should be 2."
"Received %d instead.",
in_dims.size()));
PADDLE_ENFORCE_EQ(in_dims[1],
1,
platform::errors::InvalidArgument(
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1. Received %d instead.",
in_dims[1]));
// Generate enumerate sequence set
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
out->set_lod(in->lod());
auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) {
if (lod0[i] == lod0[i + 1]) continue;
int start = lod0[i];
int end = lod0[i + 1];
int copy_size = win_size < end - start + 1 ? win_size : end - start + 1;
int mid = end + 1 - copy_size;
int pad_num = win_size - copy_size;
copy_size *= sizeof(T);
for (int idx = start; idx < mid; ++idx) {
std::memcpy(out_data, in_data + idx, copy_size);
out_data += win_size;
}
for (int idx = mid; idx < end; ++idx) {
copy_size -= sizeof(T);
pad_num++;
std::memcpy(out_data, in_data + idx, copy_size);
T* pdata = out_data + copy_size / sizeof(T);
for (int i = 0; i < pad_num; ++i) {
pdata[i] = pad_value;
}
out_data += win_size;
}
}
}
};
} // namespace operators
} // namespace paddle
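
The CPU kernel above splits each LoD segment [start, end) into rows that still fit a full window, copied with a single memcpy, and tail rows that get a shrinking copy plus pad_value padding. A small sketch of that copy_size/mid/pad_num arithmetic, assuming plain Python lists (enumerate_segment is a made-up name):

def enumerate_segment(in_data, start, end, win_size, pad_value):
    rows = []
    copy_size = min(win_size, end - start + 1)
    mid = end + 1 - copy_size
    pad_num = win_size - copy_size
    for idx in range(start, mid):      # rows with a full window
        rows.append(list(in_data[idx:idx + copy_size]))
    for idx in range(mid, end):        # tail rows: shorter copy plus padding
        copy_size -= 1
        pad_num += 1
        rows.append(list(in_data[idx:idx + copy_size]) + [pad_value] * pad_num)
    return rows

# For the DOC Case 1 segment [0, 3) with win_size=2 this yields
# [[1, 2], [2, 3], [3, 0]].
print(enumerate_segment([1, 2, 3, 4, 5], 0, 3, 2, 0))
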
@@ -8,7 +8,6 @@ register_unity_group(
  cc
  sequence_concat_op.cc
  sequence_conv_op.cc
- sequence_enumerate_op.cc
  sequence_erase_op.cc
  sequence_expand_op.cc
  sequence_mask_op.cc
...
@@ -25,7 +24,6 @@ register_unity_group(
  sequence_conv_op.cu.cc)
register_unity_group(
  cu
- sequence_enumerate_op.cu
  sequence_erase_op.cu
  sequence_expand_op.cu
  sequence_pad_op.cu
...
@@ -427,15 +427,6 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST)
  inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR}
                              analyzer_lac_tester.cc)
- # Pyramid DNN
- set(PYRAMID_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/pyramid_dnn")
- download_model_and_data_without_verify(
-   ${PYRAMID_DNN_INSTALL_DIR} "PyramidDNN_model.tar.gz"
-   "PyramidDNN_data.txt.tar.gz")
- inference_analysis_api_test(
-   test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR}
-   analyzer_pyramid_dnn_tester.cc)
  # Ernie
  set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie")
  download_model_and_data(
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/cpp/inference/api/tester_helper.h"
namespace paddle {
namespace inference {
struct DataRecord {
std::vector<std::vector<int64_t>> query_basic, query_phrase, title_basic,
title_phrase;
std::vector<size_t> lod1, lod2, lod3, lod4;
size_t batch_iter{0}, batch_size{1}, num_samples; // total number of samples
DataRecord() = default;
explicit DataRecord(const std::string &path, int batch_size = 1)
: batch_size(batch_size) {
Load(path);
}
DataRecord NextBatch() {
DataRecord data;
size_t batch_end = batch_iter + batch_size;
// NOTE: skip the final batch if not enough data is provided.
if (batch_end <= query_basic.size()) {
GetInputPerBatch(
query_basic, &data.query_basic, &data.lod1, batch_iter, batch_end);
GetInputPerBatch(
query_phrase, &data.query_phrase, &data.lod2, batch_iter, batch_end);
GetInputPerBatch(
title_basic, &data.title_basic, &data.lod3, batch_iter, batch_end);
GetInputPerBatch(
title_phrase, &data.title_phrase, &data.lod4, batch_iter, batch_end);
}
batch_iter += batch_size;
return data;
}
void Load(const std::string &path) {
std::ifstream file(path);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
std::vector<std::string> data;
split(line, ';', &data);
// load query data
std::vector<int64_t> query_basic_data;
split_to_int64(data[1], ' ', &query_basic_data);
std::vector<int64_t> query_phrase_data;
split_to_int64(data[2], ' ', &query_phrase_data);
// load title data
std::vector<int64_t> title_basic_data;
split_to_int64(data[3], ' ', &title_basic_data);
std::vector<int64_t> title_phrase_data;
split_to_int64(data[4], ' ', &title_phrase_data);
// filter the empty data
bool flag =
data[1].size() && data[2].size() && data[3].size() && data[4].size();
if (flag) {
query_basic.push_back(std::move(query_basic_data));
query_phrase.push_back(std::move(query_phrase_data));
title_basic.push_back(std::move(title_basic_data));
title_phrase.push_back(std::move(title_phrase_data));
num_lines++;
}
}
num_samples = num_lines;
}
};
void PrepareInputs(std::vector<PaddleTensor> *input_slots,
DataRecord *data,
int batch_size) {
PaddleTensor query_basic_tensor, query_phrase_tensor, title_basic_tensor,
title_phrase_tensor;
query_basic_tensor.name = "query_basic";
query_phrase_tensor.name = "query_phrase";
title_basic_tensor.name = "pos_title_basic";
title_phrase_tensor.name = "pos_title_phrase";
auto one_batch = data->NextBatch();
// assign data
TensorAssignData<int64_t>(
&query_basic_tensor, one_batch.query_basic, one_batch.lod1);
TensorAssignData<int64_t>(
&query_phrase_tensor, one_batch.query_phrase, one_batch.lod2);
TensorAssignData<int64_t>(
&title_basic_tensor, one_batch.title_basic, one_batch.lod3);
TensorAssignData<int64_t>(
&title_phrase_tensor, one_batch.title_phrase, one_batch.lod4);
// Set inputs.
input_slots->assign({query_basic_tensor,
query_phrase_tensor,
title_basic_tensor,
title_phrase_tensor});
for (auto &tensor : *input_slots) {
tensor.dtype = PaddleDType::INT64;
}
}
void SetConfig(AnalysisConfig *cfg) {
cfg->SetModel(FLAGS_infer_model);
cfg->DisableGpu();
cfg->SwitchSpecifyInputNames();
cfg->SwitchIrOptim();
cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false);
}
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
std::vector<PaddleTensor> input_slots;
int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
for (int bid = 0; bid < epoch; ++bid) {
PrepareInputs(&input_slots, &data, FLAGS_batch_size);
(*inputs).emplace_back(input_slots);
}
}
// Easy to profile independently.
TEST(Analyzer_Pyramid_DNN, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all,
&outputs,
FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
PADDLE_ENFORCE_GT(outputs.size(),
0,
paddle::platform::errors::Fatal(
"The size of output should be greater than 0."));
auto output = outputs.back();
PADDLE_ENFORCE_EQ(output.size(),
1UL,
paddle::platform::errors::Fatal(
"The size of output should be equal to 1."));
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size,
0,
paddle::platform::errors::Fatal(
"The size of output should be greater than 0."));
float *result = static_cast<float *>(output[0].data.data());
// output is probability, which is in (0, 1).
for (size_t i = 0; i < size; i++) {
EXPECT_GT(result[i], 0);
EXPECT_LT(result[i], 1);
}
}
}
// Check the fuse status
TEST(Analyzer_Pyramid_DNN, fuse_statis) {
AnalysisConfig cfg;
SetConfig(&cfg);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
}
// Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_Pyramid_DNN, compare) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare result of AnalysisConfig and AnalysisConfig + ZeroCopy
TEST(Analyzer_Pyramid_DNN, compare_zero_copy) {
AnalysisConfig cfg;
SetConfig(&cfg);
AnalysisConfig cfg1;
SetConfig(&cfg1);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
std::vector<std::string> outputs_name;
outputs_name.emplace_back("cos_sim_2.tmp_0");
CompareAnalysisAndZeroCopy(reinterpret_cast<PaddlePredictor::Config *>(&cfg),
reinterpret_cast<PaddlePredictor::Config *>(&cfg1),
input_slots_all,
outputs_name);
}
// Compare Deterministic result
TEST(Analyzer_Pyramid_DNN, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
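
For context, DataRecord::Load above splits every line of FLAGS_infer_data on ';' and parses fields 1 through 4 as space-separated int64 id lists for query_basic, query_phrase, title_basic and title_phrase; field 0 is not used. A hedged sketch of that line format, with an invented sample line:

def parse_line(line):
    def to_ids(field):
        return [int(tok) for tok in field.split()]
    fields = line.split(';')
    return {
        'query_basic': to_ids(fields[1]),
        'query_phrase': to_ids(fields[2]),
        'title_basic': to_ids(fields[3]),
        'title_phrase': to_ids(fields[4]),
    }

print(parse_line("0;12 7 3;45 9;8 8 21;13 2"))
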
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("../../python/paddle/fluid/tests/unittests")
from eager_op_test import OpTest
def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
    lod0 = [0]
    for i in range(0, len(in_lod[0])):
        lod0.append(lod0[i] + in_lod[0][i])
    out_seq = []
    for i in range(0, len(lod0) - 1):
        for idx in range(lod0[i], lod0[i + 1]):
            single_seq = []
            for word_idx in range(win_size):
                word_pos = idx + word_idx
                dat = (
                    input_seq[word_pos] if word_pos < lod0[i + 1] else pad_value
                )
                single_seq.append(dat)
            out_seq.append(single_seq)
    return out_seq
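
# Note (illustration added here, not part of the original test): the tests
# below pass LoD in length-based form, which the first loop above converts to
# offset form, e.g. [[9, 4, 11, 6]] -> lod0 == [0, 9, 13, 24, 30].
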
class TestSequenceEnumerateOp(OpTest):
    def setUp(self):
        self.op_type = "sequence_enumerate"
        self.init_test_case()
        self.inputs = {'X': (self.in_seq, self.lod)}
        self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value}
        self.outputs = {'Out': (self.out_seq, self.lod)}

    def test_check_output(self):
        self.check_output()

    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 2
        self.pad_value = 0
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int32")


class TestSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 2
        self.pad_value = 0
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int64")


class TestSequenceEnumerateOpLargeWinSize(TestSequenceEnumerateOp):
    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 5
        self.pad_value = 0
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int32")


class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 30
        self.pad_value = 0
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int32")


class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp):
    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[9, 4, 11, 6]]
        self.win_size = 5
        self.pad_value = 5
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int32")


class TestSequenceEnumerateOpLargePadValueSeqLen0(TestSequenceEnumerateOp):
    def init_test_case(self):
        self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
        self.lod = [[0, 14, 0, 16, 0]]
        self.win_size = 5
        self.pad_value = 5
        out_seq = sequence_enumerate(
            self.in_seq, self.lod, self.win_size, self.pad_value
        )
        self.out_seq = np.array(out_seq).astype("int32")


if __name__ == "__main__":
    unittest.main()
@@ -1293,7 +1293,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
    'test_adaptive_avg_pool3d',
    'test_paddle_save_load_binary',
    'test_fused_fc_elementwise_layernorm_op',
-   'test_sequence_enumerate_op',
    'test_lgamma_op',
    'test_modified_huber_loss_op',
    'trt_quant_int8_test',
...
@@ -2722,7 +2721,6 @@ TWO_PARALLEL_JOB = [
    'test_conv_shift_op',
    'test_sequence_expand_as',
    'test_cos_sim_op',
-   'test_sequence_enumerate_op',
    'test_sequence_concat',
    'test_data_norm_op',
    'test_decoupled_py_reader_data_check',
...
@@ -539,7 +539,6 @@ STATIC_MODE_TESTING_LIST = [
    'test_layers',
    'test_sequence_concat',
    'test_sequence_conv',
-   'test_sequence_enumerate_op',
    'test_sequence_erase_op',
    'test_sequence_expand',
    'test_sequence_expand_as',
...