提交 ff390a95 编写于 作者: Y yangqingyou

limit r by rshift

上级 ba29f12b
......@@ -247,11 +247,7 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
}
// r'
aby3_ctx()->template gen_random_private(*temp[0]);
std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(),
[] (T& a) {
a = (T) (a * std::pow(2, sizeof(T) * 8 - 2)
/ std::numeric_limits<T>::max());
});
temp[0]->rshift(1, temp[0].get());
//r'_0, r'_1
aby3_ctx()->template gen_random_private(*temp[1]);
......
......@@ -3438,7 +3438,7 @@ TEST_F(FixedTensorTest, inv_sqrt_test) {
}
#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1
TEST_F(FixedTensorTest, truncate1_msb_failed) {
TEST_F(FixedTensorTest, truncate1_msb_incorrect) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
......@@ -3498,7 +3498,7 @@ TEST_F(FixedTensorTest, truncate1_msb_failed) {
}
#else
TEST_F(FixedTensorTest, truncate3_msb_not_failed) {
TEST_F(FixedTensorTest, truncate3_msb_correct) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册