diff --git a/core/privc3/fixedpoint_tensor_imp.h b/core/privc3/fixedpoint_tensor_imp.h index 404694e787c1326b61dd6b8131099ddec8b6bf50..e4bf60c7a43c9f4f651b216af13665990764799d 100644 --- a/core/privc3/fixedpoint_tensor_imp.h +++ b/core/privc3/fixedpoint_tensor_imp.h @@ -247,11 +247,7 @@ void FixedPointTensor::truncate(const FixedPointTensor* op, } // r' aby3_ctx()->template gen_random_private(*temp[0]); - std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(), - [] (T& a) { - a = (T) (a * std::pow(2, sizeof(T) * 8 - 2) - / std::numeric_limits::max()); - }); + temp[0]->rshift(1, temp[0].get()); //r'_0, r'_1 aby3_ctx()->template gen_random_private(*temp[1]); diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc index 11feb42870533b0fe04139e984dd9ced12e9df9a..c2f83189fc46535932fad8048d720aa408013a95 100644 --- a/core/privc3/fixedpoint_tensor_test.cc +++ b/core/privc3/fixedpoint_tensor_test.cc @@ -3438,7 +3438,7 @@ TEST_F(FixedTensorTest, inv_sqrt_test) { } #ifdef USE_ABY3_TRUNC1 //use aby3 trunc1 -TEST_F(FixedTensorTest, truncate1_msb_failed) { +TEST_F(FixedTensorTest, truncate1_msb_incorrect) { std::vector shape = { 1 }; std::shared_ptr> sl[3] = { gen(shape), gen(shape), gen(shape) }; std::shared_ptr> sout[6] = { gen(shape), gen(shape), gen(shape), @@ -3498,7 +3498,7 @@ TEST_F(FixedTensorTest, truncate1_msb_failed) { } #else -TEST_F(FixedTensorTest, truncate3_msb_not_failed) { +TEST_F(FixedTensorTest, truncate3_msb_correct) { std::vector shape = { 1 }; std::shared_ptr> sl[3] = { gen(shape), gen(shape), gen(shape) }; std::shared_ptr> sout[6] = { gen(shape), gen(shape), gen(shape),