diff --git a/core/privc3/fixedpoint_tensor_imp.h b/core/privc3/fixedpoint_tensor_imp.h index e92bd6840c8789e1ce6558bf51986313ed788897..adf660a7d55d2282b65b35fc41b30d0101513115 100644 --- a/core/privc3/fixedpoint_tensor_imp.h +++ b/core/privc3/fixedpoint_tensor_imp.h @@ -211,10 +211,10 @@ void FixedPointTensor::truncate(const FixedPointTensor* op, #else // use truncate3 -// Protocol. `truncate3` +// Protocol. `truncate3` (illustrated for data type T = int64_t) // motivation: // truncates in aby3 may cause msb error with small probability -// the reason is that before rishft op, its masked value e.g., x' - r' may overflow in int64 +// the reason is that before rishft op, its masked value e.g., x' - r' may overflow in int64_t // so that, in `truncate3`, we limit r' in (-2^62, 2^62) to avoid the problem. // notice: @@ -245,17 +245,14 @@ void FixedPointTensor::truncate(const FixedPointTensor* op, temp.emplace_back( tensor_factory()->template create(op->shape())); } - // r', contraint in (-2^62, 2^62) + // r', contraint in (constraint_low, contraint_upper) aby3_ctx()->template gen_random_private(*temp[0]); - int64_t contraint_upper = ~((uint64_t) 1 << 62); - int64_t contraint_low = (uint64_t) 1 << 62; + T contraint_upper = (T) 1 << (sizeof(T) * 8 - 2); + T contraint_low = - contraint_upper; std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(), [&contraint_upper, &contraint_low] (T& a) { - // contraint -2^62 < a < 2^62 - if (a >= 0) { - a &= contraint_upper; - } else { - a |= contraint_low; + while ((a > contraint_upper || a < contraint_low)) { + a = aby3_ctx()->template gen_random_private(); } });