psi_api_test.cc 1.9 KB
Newer Older
J
jingqinghe 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "psi_api.h"

#include <atomic>
#include <chrono>
#include <exception>
#include <set>
#include <string>
#include <thread>
#include <vector>

#include "gtest/gtest.h"

namespace psi {

class PsiAPITest : public ::testing::Test {
public:
J
jhjiangcs 已提交
25
    std::set<std::string> _input;
J
jingqinghe 已提交
26

J
jhjiangcs 已提交
27
    int _port;
J
jingqinghe 已提交
28

J
jhjiangcs 已提交
29
    static const int _s_test_size = 1e3;
J
jingqinghe 已提交
30
public:
J
jhjiangcs 已提交
31 32 33 34 35
    PsiAPITest() {
        for (int i = 0; i < _s_test_size; ++i) {
            _input.emplace(std::to_string(i));
        }
        _port = 45818;
J
jingqinghe 已提交
36 37
    }

J
jhjiangcs 已提交
38
    ~PsiAPITest() {}
J
jingqinghe 已提交
39 40 41
};

// End-to-end check: a sender thread and the receiving test thread run the
// protocol over localhost with identical input sets. Because the sets are
// identical, the receiver's output must equal the full input set.
TEST_F(PsiAPITest, full_test) {
    // The sender probes upward from _port until psi_send stops failing with a
    // "socket error". The current port is read by the receiving (main) thread,
    // so it must be atomic. Writing the plain _port member from the sender
    // thread while the main thread read it was a data race.
    std::atomic<int> port{_port};

    // An exception escaping a std::thread entry point calls std::terminate.
    // Capture it instead, and rethrow it on the test thread after join() so
    // gtest reports it as an ordinary test failure.
    std::exception_ptr send_error;

    auto test_send = [this, &port, &send_error]() {
        try {
            // find valid port
            for (;; ++port) {
                try {
                    psi_send(port.load(), _input, nullptr);
                    break;
                } catch (const std::exception& e) {
                    const std::string msg(e.what());
                    if (msg.find("socket error") == std::string::npos) {
                        throw;  // not a port problem: give up
                    }
                    // Port unavailable: retry with the next one.
                }
            }
        } catch (...) {
            send_error = std::current_exception();
        }
    };
    std::thread t_send(test_send);

    std::vector<std::string> output;

    // Give the sender time to start listening.
    // NOTE(review): this is still timing-based. If the sender needs longer than
    // one second, or must probe several ports, the receiver may connect to a
    // port the sender has since moved past. If psi_recv throws here, t_send
    // is destroyed while joinable (std::terminate), as in the original.
    std::this_thread::sleep_for(std::chrono::seconds(1));
    psi_recv("127.0.0.1", port.load(), _input, &output, nullptr);

    t_send.join();

    // join() synchronizes with the sender, so reading send_error is safe here.
    if (send_error) {
        std::rethrow_exception(send_error);
    }

    const std::set<std::string> out_set(output.begin(), output.end());
    ASSERT_EQ(out_set, _input);
}

} // namespace psi