未验证 提交 546a663c 编写于 作者: M mapingshuo 提交者: GitHub

test=develop (#3943)

上级 845e80d9
......@@ -85,21 +85,31 @@ class SequenceConvComputeTester : public arena::TestCase {
auto output_dims = output->dims();
auto output_data = output->mutable_data<float>();
std::vector<std::vector<float>> res;
if (contextStart_ == -2) {
if (contextStart_ == -2 && lod_.size() == 1 &&
lod_[0] == std::vector<uint64_t>({0, 4})) {
res = {{-0.08867277f, -0.17257819f, -0.2564836f},
{0.194508f, 0.05720823f, -0.08009153f},
{0.73512584f, 0.5749428f, 0.41475973f},
{0.5635012f, 0.49485126f, 0.42620137f}};
} else if (contextStart_ == -1) {
} else if (contextStart_ == -1 && lod_.size() == 1 &&
lod_[0] == std::vector<uint64_t>({0, 4})) {
res = {{0.194508f, 0.05720823f, -0.08009153f},
{0.73512584f, 0.5749428f, 0.41475973f},
{0.5635012f, 0.49485126f, 0.42620137f},
{0.2517162f, 0.23646072f, 0.22120519f}};
} else if (contextStart_ == 0) {
} else if (contextStart_ == 0 && lod_.size() == 1 &&
lod_[0] == std::vector<uint64_t>({0, 4})) {
res = {{0.73512584f, 0.5749428f, 0.41475973f},
{0.5635012f, 0.49485126f, 0.42620137f},
{0.2517162f, 0.23646072f, 0.22120519f},
{0.02574372f, 0.03337148f, 0.04099924f}};
} else if (contextStart_ == -1 && lod_.size() == 1 &&
lod_[0] == std::vector<uint64_t>({0, 2, 4})) {
res = {{0.194508, 0.05720823, -0.08009153},
{0.7093821, 0.57208234, 0.43478262},
{0.19450802, 0.17925248, 0.16399695},
{0.2517162, 0.23646072, 0.22120519}};
} else {
fprintf(stderr, "not supported contextStart_\n");
exit(-1);
......@@ -136,12 +146,25 @@ void TestNormalCase(Place place, float abs_error = 2e-5) {
}
}
void TestBatchCase(Place place, float abs_error = 2e-5) {
std::vector<std::vector<uint64_t>> lod{{0, 2, 4}};
std::vector<int64_t> dims{4, 5};
std::vector<int> candidate_pad_idx{-1};
for (int pad_idx : candidate_pad_idx) {
std::unique_ptr<arena::TestCase> tester(new SequenceConvComputeTester(
place, "def", lod, DDim(dims), pad_idx, 1, 3, 3));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(sequence_conv, precision) {
#ifdef LITE_WITH_ARM
float abs_error = 2e-5;
Place place(TARGET(kARM));
TestNormalCase(place, abs_error);
TestBatchCase(place, abs_error);
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册