提交 ff390a95 编写于 作者: Y yangqingyou

limit r by rshift

上级 ba29f12b
...@@ -247,11 +247,7 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op, ...@@ -247,11 +247,7 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
} }
// r' // r'
aby3_ctx()->template gen_random_private(*temp[0]); aby3_ctx()->template gen_random_private(*temp[0]);
std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(), temp[0]->rshift(1, temp[0].get());
[] (T& a) {
a = (T) (a * std::pow(2, sizeof(T) * 8 - 2)
/ std::numeric_limits<T>::max());
});
//r'_0, r'_1 //r'_0, r'_1
aby3_ctx()->template gen_random_private(*temp[1]); aby3_ctx()->template gen_random_private(*temp[1]);
......
...@@ -3438,7 +3438,7 @@ TEST_F(FixedTensorTest, inv_sqrt_test) { ...@@ -3438,7 +3438,7 @@ TEST_F(FixedTensorTest, inv_sqrt_test) {
} }
#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1 #ifdef USE_ABY3_TRUNC1 //use aby3 trunc1
TEST_F(FixedTensorTest, truncate1_msb_failed) { TEST_F(FixedTensorTest, truncate1_msb_incorrect) {
std::vector<size_t> shape = { 1 }; std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) }; std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape), std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
...@@ -3498,7 +3498,7 @@ TEST_F(FixedTensorTest, truncate1_msb_failed) { ...@@ -3498,7 +3498,7 @@ TEST_F(FixedTensorTest, truncate1_msb_failed) {
} }
#else #else
TEST_F(FixedTensorTest, truncate3_msb_not_failed) { TEST_F(FixedTensorTest, truncate3_msb_correct) {
std::vector<size_t> shape = { 1 }; std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) }; std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape), std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册