fake_quantize_op.cu 19.1 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
16
#include "paddle/fluid/memory/memcpy.h"
视言's avatar
视言 已提交
17 18 19 20 21 22 23
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Computes, for every block of the launch, the largest absolute value among
// the elements of `in` visited by that block and stores it in
// out[blockIdx.x].
//
//   in  : input values (n elements).
//   n   : number of elements in `in`.
//   out : one result slot per block (gridDim.x entries are written).
//
// Launch requirements: dynamic shared memory of at least
// blockDim.x * sizeof(T) bytes, and blockDim.x must be a power of two,
// otherwise the halving reduction below skips some shared-memory slots.
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ T shared_max_data[];

  // Each thread strides over the input by the total number of launched
  // threads, so all n elements are visited for any grid size, including a
  // single block with n > blockDim.x. Threads that own no element
  // contribute 0. A NaN element never replaces the running maximum because
  // the comparison below is false for NaN.
  T thread_max = T(0);
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T tmp = fabs(in[i]);
    if (tmp > thread_max) {
      thread_max = tmp;
    }
  }
  shared_max_data[tid] = thread_max;
  __syncthreads();

  // Tree reduction in shared memory; the barrier is reached by all threads
  // on every iteration.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
// Finds max(|in[i]|) over `num` device-resident values and writes the single
// result to the device location `out`.
//
// The reduction runs in two launches on ctx.stream(): the first produces one
// partial maximum per block into a temporary tensor, the second reduces
// those partials with a single block.
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    int block = 1024;
    int grid = (block - 1 + num) / block;
    // Cap the number of partial results at `block` so that the second,
    // single-block launch can reduce all of them.
    grid = (grid > block) ? block : grid;
    // A zero-sized grid is an invalid launch configuration. With one block
    // and num == 0 no input element is read and the result is 0.
    grid = (grid < 1) ? 1 : grid;

    framework::Tensor max;
    T* max_data =
        max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
        in, num, max_data);
    FindAbsMaxKernel<T><<<1, block, block * sizeof(T), ctx.stream()>>>(
        max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
视言's avatar
视言 已提交
76

77
// Per-channel absolute maximum when the channel is the leading axis.
//
// Block b scans the n / c consecutive values starting at in + b * (n / c)
// and writes their largest absolute value to out[b].
//
// Launch requirements: dynamic shared memory of at least
// blockDim.x * sizeof(T) bytes; the halving reduction covers every slot only
// when blockDim.x is a power of two.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
                                                  const int c, T* out) {
  extern __shared__ T shared_max_data[];

  const int lane = threadIdx.x;
  const int per_channel = n / c;
  const T* channel_begin = in + blockIdx.x * per_channel;

  // Each thread folds its strided share of the channel into a private
  // maximum before publishing it to shared memory.
  T running = T(0);
  for (int pos = lane; pos < per_channel; pos += blockDim.x) {
    const T magnitude = fabs(channel_begin[pos]);
    if (magnitude > running) {
      running = magnitude;
    }
  }
  shared_max_data[lane] = running;
  __syncthreads();

  for (int half = blockDim.x / 2; half > 0; half >>= 1) {
    if (lane < half && shared_max_data[lane] < shared_max_data[lane + half]) {
      shared_max_data[lane] = shared_max_data[lane + half];
    }
    __syncthreads();
  }

  if (lane == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
// Per-channel absolute maximum when the channel is the second axis.
//
// Block b handles output channel b. Thread t scans the n / (cin * cout)
// values starting at in + t * (n / cin) + b * (n / (cin * cout)). The block
// maximum is merged into out[b] only when it is larger than the value
// already stored there, so out[] must hold valid data before the launch.
//
// Launch requirements: dynamic shared memory of at least
// blockDim.x * sizeof(T) bytes. blockDim.x does not have to be a power of
// two; the reduction handles odd lengths.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                  const int cin, const int cout,
                                                  T* out) {
  extern __shared__ T shared_max_data[];

  const int row_stride = n / cin;
  const int inner = n / (cin * cout);

  const int lane = threadIdx.x;
  const int channel = blockIdx.x;
  const T* slice = in + lane * row_stride + channel * inner;

  T running = T(0);
  for (int k = 0; k < inner; k++) {
    const T magnitude = fabs(slice[k]);
    if (magnitude > running) {
      running = magnitude;
    }
  }
  shared_max_data[lane] = running;
  __syncthreads();

  // Reduction over `active` live slots: each step folds the upper part onto
  // the lower ceil(active / 2) slots and stops after the step whose half
  // size is 1.
  int active = blockDim.x;
  int half = (active + 1) / 2;
  while (half > 0) {
    if (lane < half && lane + half < active &&
        shared_max_data[lane] < shared_max_data[lane + half]) {
      shared_max_data[lane] = shared_max_data[lane + half];
    }
    __syncthreads();
    if (half == 1) {
      break;
    }
    active = half;
    half = (half + 1) / 2;
  }

  if (lane == 0 && shared_max_data[0] > out[channel]) {
    out[channel] = shared_max_data[0];
  }
}

139 140
// Computes the per-channel absolute maximum of `in_tensor` along
// `quant_axis` (0 or 1) and writes one value per channel to the device
// buffer `out_abs_max`. All work is enqueued on ctx.stream().
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_tensor, const int quant_axis,
                  T* out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T* in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      // One block per leading-axis channel.
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<
          T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
          in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;

      // The axis-1 kernel merges into out_abs_max, so it has to start from
      // zero. The memset is enqueued on the same stream as the kernels so
      // that it is ordered before them.
      // NOTE(review): the returned status is not checked here.
#ifdef PADDLE_WITH_HIP
      hipMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#else
      cudaMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#endif

      // Each thread of a launch handles one index of the leading axis, so a
      // full launch consumes max_threads rows of num / cin elements each.
      const int64_t chunk_stride =
          static_cast<int64_t>(max_threads) * (num / cin);
      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
        in_data += chunk_stride;
      }

      // Remaining rows that do not fill a whole max_threads-sized launch.
      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, in_dims[0], in_dims[1], out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

视言's avatar
视言 已提交
191
// Element-wise clip-and-quantize with a single scale.
//
// For every i < n: clamps in[i] to [-s, s] with s = scale[0], multiplies by
// bin_cnt * inverse(s) and stores the rounded result in out[i].
// `inverse` is declared outside this file (presumably a guarded reciprocal;
// confirm in fake_quantize_op.h).
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  const int first = threadIdx.x + blockIdx.x * blockDim.x;

  T s = scale[0];
  T inv_s = inverse(s);
  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    T v = in[idx];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    out[idx] = round(v);
  }
}

208 209 210 211 212 213 214 215
// Element-wise clip, quantize and dequantize with a single scale.
//
// For every i < n: clamps in[i] to [-s, s] with s = scale[0], rounds
// bin_cnt * inverse(s) * value, then maps the result back with
// * s / bin_cnt and stores it in out[i].
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          T* out) {
  const int first = threadIdx.x + blockIdx.x * blockDim.x;

  T s = scale[0];
  T inv_s = inverse(s);
  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    T v = in[idx];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    out[idx] = round(v) * s / bin_cnt;
  }
}

226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
// Launches ClipAndQuantKernel over every element of `in` on ctx.stream(),
// using the single scale stored in `scale`, and writes the result to `out`.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const int total = in.numel();
    const int threads = 1024;
    // Ceiling division: enough blocks for one thread per element.
    const int blocks = (threads - 1 + total) / threads;

    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantKernel<T><<<blocks, threads, 0, ctx.stream()>>>(
        src, scale_ptr, bin_cnt, total, dst);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
// Launches ClipAndQuantDequantKernel over every element of `in` on
// ctx.stream(), using the single scale stored in `scale`, and writes the
// result to `out`.
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const int total = in.numel();
    const int threads = 1024;
    // Ceiling division: enough blocks for one thread per element.
    const int blocks = (threads - 1 + total) / threads;

    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantDequantKernel<T><<<blocks, threads, 0, ctx.stream()>>>(
        src, scale_ptr, bin_cnt, total, dst);
  }
};

template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                               float>;

267
// Channel-wise clip-and-quantize for quant_axis == 0.
//
// Block b processes the n / c consecutive elements of channel b with the
// scale s = scale[b]: each value is clamped to [-s, s], multiplied by
// bin_cnt * inverse(s) and rounded.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int n, const int c,
                                                    T* out) {
  const int per_channel = n / c;
  const T* src = in + blockIdx.x * per_channel;
  T* dst = out + blockIdx.x * per_channel;

  T s = scale[blockIdx.x];
  T inv_s = inverse(s);

  for (int pos = threadIdx.x; pos < per_channel; pos += blockDim.x) {
    T v = src[pos];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    dst[pos] = round(v);
  }
}

291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
// Channel-wise clip-and-quantize for quant_axis == 1.
//
// Block b processes the n / (cin * cout) consecutive elements starting at
// offset b * n / (cin * cout) and uses the scale of channel b % cout.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int n, const int cin,
                                                    const int cout, T* out) {
  T s = scale[blockIdx.x % cout];
  T inv_s = inverse(s);

  const int inner = n / (cin * cout);
  const T* src = in + blockIdx.x * inner;
  T* dst = out + blockIdx.x * inner;

  for (int pos = threadIdx.x; pos < inner; pos += blockDim.x) {
    T v = src[pos];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    dst[pos] = round(v);
  }
}

313 314 315 316
// Channel-wise fake quantization of `in` into `out` along `quant_axis`
// (0 or 1), with one scale per channel read from `scale`. The kernel is
// enqueued on ctx.stream().
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();
    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    switch (quant_axis) {
      case 0: {
        // One block per leading-axis channel.
        int grid = in_dims[0];
        int block = 1024;
        ChannelClipAndQuantKernelQuantAxis0<
            T><<<grid, block, 0, ctx.stream()>>>(
            in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
        break;
      }
      case 1: {
        // One block per (dim0, dim1) pair.
        int grid = in_dims[0] * in_dims[1];
        int block = 1024;
        ChannelClipAndQuantKernelQuantAxis1<
            T><<<grid, block, 0, ctx.stream()>>>(
            in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1],
            out_data);
        break;
      }
      default:
        break;
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

视言's avatar
视言 已提交
348
// Updates the sliding window of scales and the running maximum.
//
// Stores cur_scale[0] into slot iter[0] % window_size of scale_arr, writes
// max(last_scale[0], cur_scale[0]) to out_scale[0], and reports through
// need_find_max[0] whether the overwritten slot was within 1e-6 of the
// previous maximum. When it was, out_size[0] receives
// min(iter[0], window_size); otherwise out_size is left untouched.
//
// The body uses no thread indexing, so every launched thread performs the
// same reads and writes.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  const int step = iter[0];
  const int slot = step % window_size;

  const T evicted = scale_arr[slot];
  const T current = cur_scale[0];
  scale_arr[slot] = current;

  const T previous_max = last_scale[0];
  out_scale[0] = previous_max < current ? current : previous_max;

  const bool evicted_was_max = fabs(evicted - previous_max) < 1e-6;
  need_find_max[0] = evicted_was_max ? 1 : 0;
  if (evicted_was_max) {
    out_size[0] = step > window_size ? window_size : step;
  }
}

// Maintains the range-abs-max scale over a window of recent scales.
//
// Runs FindRangeAbsMaxAndFillArray with a single thread to update
// `scales_arr` and `out_scale`. If the kernel reports that the replaced
// entry matched the previous maximum, the maximum is recomputed over the
// filled part of the window with FindAbsMaxFunctor.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    // Single-element device tensors used to return flags from the kernel.
    framework::Tensor need_find_max, out_size;
    int* find_max = need_find_max.mutable_data<int>({1}, gpu_place);
    int* out_size_data = out_size.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, find_max, out_size_data);

    // Copies one int from device memory to the host and waits for it.
    auto fetch_int = [&](const int* device_src) {
      int host_value;
      memory::Copy(platform::CPUPlace(), &host_value, gpu_place, device_src,
                   sizeof(int), ctx.stream());
      ctx.Wait();
      return host_value;
    };

    if (!fetch_int(find_max)) {
      return;
    }

    const int len = fetch_int(out_size_data);
    FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
                                                        out_scale_data);
  }
};

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
视言's avatar
视言 已提交
406

407 408 409 410 411 412 413
// Moving-average update of the quantization scale, computed on the host.
//
// Reads one value each from in_accum, in_state and cur_scale (device
// memory), then computes
//   state = rate * state + 1
//   accum = rate * accum + cur_scale
//   scale = accum / state
// and writes the three results to out_accum, out_state and out_scale.
// The call blocks on ctx.Wait() after the downloads and after the uploads.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
    const auto cpu_place = platform::CPUPlace();

    // Host copies of the three scalars.
    T host_accum;
    T host_state;
    T host_scale;

    auto download = [&](T* host_dst, const T* device_src) {
      memory::Copy(cpu_place, host_dst, gpu_place, device_src, sizeof(T),
                   ctx.stream());
    };
    download(&host_accum, in_accum.data<T>());
    download(&host_state, in_state.data<T>());
    download(&host_scale, cur_scale);
    ctx.Wait();

    host_state = rate * host_state + 1;
    host_accum = rate * host_accum + host_scale;
    host_scale = host_accum / host_state;

    auto upload = [&](T* device_dst, const T* host_src) {
      memory::Copy(gpu_place, device_dst, cpu_place, host_src, sizeof(T),
                   ctx.stream());
    };
    upload(out_accum->mutable_data<T>(gpu_place), &host_accum);
    upload(out_state->mutable_data<T>(gpu_place), &host_state);
    upload(out_scale->mutable_data<T>(gpu_place), &host_scale);
    ctx.Wait();
  }
};

H
huangxu96 已提交
440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
// Channel-wise clip, quantize and dequantize for quant_axis == 0.
//
// Block b processes the n / c consecutive elements of channel b with the
// scale s = scale[b]: each value is clamped to [-s, s], scaled by
// bin_cnt * inverse(s), rounded, and mapped back with * s / bin_cnt.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
    const T* in, const T* scale, const int bin_cnt, const int n, const int c,
    T* out) {
  const int per_channel = n / c;
  const T* src = in + blockIdx.x * per_channel;
  T* dst = out + blockIdx.x * per_channel;

  T s = scale[blockIdx.x];
  T inv_s = inverse(s);

  for (int pos = threadIdx.x; pos < per_channel; pos += blockDim.x) {
    T v = src[pos];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    dst[pos] = round(v) * s / bin_cnt;
  }
}

// Channel-wise clip, quantize and dequantize for quant_axis == 1.
//
// Block b processes the n / (cin * cout) consecutive elements starting at
// offset b * n / (cin * cout) and uses the scale of channel b % cout.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
    const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
    const int cout, T* out) {
  T s = scale[blockIdx.x % cout];
  T inv_s = inverse(s);

  const int inner = n / (cin * cout);
  const T* src = in + blockIdx.x * inner;
  T* dst = out + blockIdx.x * inner;

  for (int pos = threadIdx.x; pos < inner; pos += blockDim.x) {
    T v = src[pos];
    if (v > s) {
      v = s;
    }
    if (v < -s) {
      v = -s;
    }
    v = bin_cnt * inv_s * v;
    dst[pos] = round(v) * s / bin_cnt;
  }
}

// Channel-wise fake quantize-dequantize of `in` into `out` along
// `quant_axis` (0 or 1), with one scale per channel read from `scale`. The
// kernel is enqueued on ctx.stream().
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    // Channel-wise quantization currently supports conv2d, depthwise_conv2d,
    // conv2d_transpose and mul.
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    switch (quant_axis) {
      case 0: {
        // One block per leading-axis channel.
        int grid = in_dims[0];
        int block = 1024;
        ChannelClipAndQuantDequantKernelQuantAxis0<
            T><<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
                                                 num, in_dims[0], out_data);
        break;
      }
      case 1: {
        // One block per (dim0, dim1) pair.
        int grid = in_dims[0] * in_dims[1];
        int block = 1024;
        ChannelClipAndQuantDequantKernelQuantAxis1<
            T><<<grid, block, 0, ctx.stream()>>>(
            in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1],
            out_data);
        break;
      }
      default:
        break;
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                                   float>;
524

视言's avatar
视言 已提交
525 526 527
}  // namespace operators
}  // namespace paddle

528 529 530 531
// Registration of the CUDA kernels for the fake-quantize operator family.
// Every operator is registered for float only.
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// Single-scale (per-tensor) abs-max operators.
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>);

// Channel-wise quantize operator.
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);

// Range and moving-average abs-max operators.
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);

// Gradient of the quantize-dequantize operators.
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
                        ops::FakeQuantDequantGradKernel<CUDA, float>);

// Channel-wise quantize-dequantize operator.
REGISTER_OP_CUDA_KERNEL(
    fake_channel_wise_quantize_dequantize_abs_max,
    ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);