From 251eb372c2e78089cabb1923b57cff32a8a1b610 Mon Sep 17 00:00:00 2001 From: Dang Qingqing Date: Thu, 30 Aug 2018 22:07:23 +0800 Subject: [PATCH] Improve and fix fake_quantize_op. --- paddle/fluid/operators/CMakeLists.txt | 3 + paddle/fluid/operators/fake_quantize_op.cc | 105 +++++++++------ paddle/fluid/operators/fake_quantize_op.cu | 124 ++++++++++-------- paddle/fluid/operators/fake_quantize_op.h | 56 +++----- .../tests/unittests/test_fake_quantize_op.py | 45 ++++--- 5 files changed, 182 insertions(+), 151 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 68fbde2c09f..1f9a3be8b33 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -178,6 +178,8 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(relu);\n") elseif(${TARGET} STREQUAL "fake_dequantize") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") + elseif(${TARGET} STREQUAL "fake_quantize") + file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n") elseif(${TARGET} STREQUAL "tensorrt_engine_op") message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") elseif(${TARGET} STREQUAL "fc") @@ -291,6 +293,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(fake_quantize_op DEPS memory) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 05680345a5d..e608eba05d5 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -15,43 +15,55 @@ limitations under the License. */ #include "paddle/fluid/operators/fake_quantize_op.h" #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { -template -using EigenVectorArrayMap = Eigen::Map>; -template +template +using EigenVectorArrayMap = + Eigen::TensorMap>; + +template using ConstEigenVectorArrayMap = - Eigen::Map>; + Eigen::TensorMap>; template struct FindAbsMaxFunctor { - void operator()(const CPUDeviceContext& ctx, const T* in, const int num, - T* out) { - ConstEigenVectorArrayMap in_e(in, num); - EigenVectorArrayMap out_e(out, 1); + void operator()(const platform::CPUDeviceContext& ctx, const T* in, + const int num, T* out) { + Eigen::DSizes idim(num); + Eigen::DSizes odim(1); + Eigen::TensorMap> in_e(in, idim); + Eigen::TensorMap> out_e(out, odim); - auto& dev = ctx.eigen_device(); out_e = in_e.abs().maximum(); } }; +template struct FindAbsMaxFunctor; + template struct ClipAndFakeQuantFunctor { - void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor* scale, const int bin_cnt, - framework::Tensor* out) { - T s = scale->data()[0]; - Transform trans; + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + T s = scale.data()[0]; + platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); auto in_e = framework::EigenVector::Flatten(in); auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(dev) = (bin_cnt / s * in_e).round(); + + out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round(); } }; +template struct ClipAndFakeQuantFunctor; + template struct FindRangeAbsMaxFunctor { void operator()(const platform::CPUDeviceContext& ctx, @@ -59,10 +71,10 @@ struct FindRangeAbsMaxFunctor { const framework::Tensor& last_scale, const framework::Tensor& iter, const int window_size, framework::Tensor* scales_arr, framework::Tensor* out_scale) { - T* scale_arr = scales_arr->mutable_data(cxt.GetPlace()); - int it = iter.data()[0]; + T* scale_arr = scales_arr->mutable_data(ctx.GetPlace()); + int64_t it = iter.data()[0]; int idx = it % window_size; - T removd = scale_arr[idx]; + T removed = scale_arr[idx]; T cur = cur_scale.data()[0]; scale_arr[idx] = cur; @@ -74,10 +86,12 @@ struct FindRangeAbsMaxFunctor { FindAbsMaxFunctor()(ctx, scale_arr, size, &max); } - out_scale->mutable_data()[0] = max; + out_scale->mutable_data(ctx.GetPlace())[0] = max; } }; +template struct FindRangeAbsMaxFunctor; + class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { public: FakeQuantizeAbsMaxOp(const std::string& type, @@ -97,6 +111,14 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ctx->SetOutputDim("OutScale", {1}); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { @@ -126,10 +148,10 @@ $$Out = round(X/scale * range)$$ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { public: - FakeQuantizeOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + FakeQuantizeRangeAbsMaxOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext* ctx) const override { @@ -141,16 +163,22 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasOutput("OutScale"), "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"); - if (ctx->HasInput("InScales")) { - PADDLE_ENFORCE( - ctx->HasOutput("OutScales"), - "Output(OutScales) of FakeQuantizeRangeAbsMaxOp should not be null"); - ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales")); + if (ctx->HasOutput("OutScales")) { + int window_size = ctx->Attrs().Get("window_size"); + ctx->SetOutputDim("OutScales", {window_size}); } ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("OutScale", {1}); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; class FakeQuantizeRangeAbsMaxOpMaker @@ -158,10 +186,8 @@ class FakeQuantizeRangeAbsMaxOpMaker public: void Make() override { AddInput("X", "(Tensor) Input is float data type."); - AddInput("InScales", "(Tensor) scale buffer.").AsDispensable(); - AddInput("InScale", "Last scale.") - AddInput("Iter", "Global step iteration.") - .AsDispensable(); + AddInput("InScale", "Last scale."); + AddInput("Iter", "Global step iteration.").AsDispensable(); AddOutput("Out", "(Tensor) Output of quantized low level tensor."); AddOutput("OutScale", " Current scale"); AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable(); @@ -189,19 +215,16 @@ $$Out = round(X/scale * range)$$ } // namespace paddle namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp, ops::FakeQuantizeAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - fake_quantize_abs_max, - ops::FakeQuantizeKernel, - ops::FakeQuantizeKernel); +REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max, + ops::FakeQuantizeAbsMaxKernel); -REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeOp, +REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp, ops::FakeQuantizeRangeAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - fake_quantize_range_abs_max, - ops::FakeQuantizeKernel, - ops::FakeQuantizeKernel); +REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max, + ops::FakeQuantizeRangeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 17a451fcf62..7c65d6dba7d 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -24,53 +25,55 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - extern __shared__ T shared_max[]; + extern __shared__ T shared_max_data[]; if (gridDim.x > 1) { - shared_max[tid] = T(0); + shared_max_data[tid] = T(0); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T tmp = fabs(in[i]); - if (tmp > shared_max[tid]) { - shared_max[tid] = tmp; + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; } } } else { if (bid < n) { - shared_max[tid] = fabs(in[bid]); + shared_max_data[tid] = fabs(in[bid]); } else { - shared_max[tid] = T(0); + shared_max_data[tid] = T(0); } } __syncthreads(); for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && (shared_max[tid] < shared_max[tid + i])) { - shared_max[tid] = shared_max[tid + i]; + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; } __syncthreads(); } if (tid == 0) { - out[blockIdx.x] = shared_max[0]; + out[blockIdx.x] = shared_max_data[0]; } } template struct FindAbsMaxFunctor { - void operator()(const CUDADeviceContext& ctx, const T* in, const int num, - T* out) { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, T* out) { int block = 1024; int grid = (block - 1 + num) / block; grid = (grid > block) ? block : grid; - Tensor max; + framework::Tensor max; T* max_data = max.mutable_data(framework::make_ddim({grid}), ctx.GetPlace()); - FindAbsMaxKernel<<>>( - in_data, num, max_data); - FindAbsMaxKernel<<<1, block, block * sizeof(T), ctx.stream()>>>( + FindAbsMaxKernel<<>>( + in, num, max_data); + FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( max_data, grid, out); } }; +template struct FindAbsMaxFunctor; + template __global__ void ClipAndQuantKernel(const T* in, const T* scale, const int bin_cnt, const int n, T* out) { @@ -88,11 +91,25 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, } template -__global__ void FillScaleArray(T* scale_arr, T* out_scale, const int* it, - const int window_size, ) { - int tid = threadIdx.x; +__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, + const T* last_scale, + const int64_t* iter, + const int window_size, T* scale_arr, + T* out_scale, int* need_find_max, + int* out_size) { + int it = iter[0]; int idx = it % window_size; - // scale_arr[idx] = ; + T removed = scale_arr[idx]; + T cur = cur_scale[0]; + scale_arr[idx] = cur; + T max = last_scale[0]; + out_scale[0] = max < cur ? cur : max; + if (fabs(removed - max) < 1e-6) { + need_find_max[0] = 1; + out_size[0] = it > window_size ? window_size : it; + } else { + need_find_max[0] = 0; + } } template @@ -102,46 +119,44 @@ struct FindRangeAbsMaxFunctor { const framework::Tensor& last_scale, const framework::Tensor& iter, const int window_size, framework::Tensor* scales_arr, framework::Tensor* out_scale) { - T* scale_arr = scales_arr->mutable_data(cxt.GetPlace()); auto& gpu_place = boost::get(ctx.GetPlace()); - int it; - memory::Copy(platform::CPUPlace(), &it, gpu_place, iter.data(), - sizeof(int), ctx.stream()); - int idx = current_iter % window_size; - T removed; - memory::Copy(platform::CPUPlace(), &removed, gpu_place, scale_arr + idx, - sizeof(T), ctx.stream()); - T cur; - memory::Copy(gpu_place, &cur, gpu_place, cur_scale.data(), sizeof(T), - ctx.stream()); - - T max; - memory::Copy(platform::CPUPlace(), &max, gpu_place, last_scale.data(), - sizeof(T), ctx.stream()); + T* scale_arr = scales_arr->mutable_data(gpu_place); T* out_scale_data = out_scale->mutable_data(gpu_place); - if (max < cur) { - max = cur; - memory::Copy(gpu_place, out_scale_data, gpu_place, &max, sizeof(T), - ctx.stream()); - } else if (fabs(removed - max) < 1e-6) { - int size = (it > window_size) ? window_size : it; - FindAbsMaxFunctor()(ctx, scale_arr, size, - out_scale_data); + + framework::Tensor need_find_max, out_size; + int* find_max = need_find_max.mutable_data(gpu_place); + int* out_size_data = out_size.mutable_data(gpu_place); + + FindRangeAbsMaxAndFillArray<<<1, 1, 0, ctx.stream()>>>( + cur_scale.data(), last_scale.data(), iter.data(), + window_size, scale_arr, out_scale_data, find_max, out_size_data); + + int g_find_max; + memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, + sizeof(int), 0); + if (g_find_max) { + int len; + memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, + sizeof(int), 0); + FindAbsMaxFunctor()(ctx, scale_arr, len, + out_scale_data); } } }; +template struct FindRangeAbsMaxFunctor; + template -struct ClipAndFakeQuantFunctor { - void operator()(const CPUDeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor* scale, const int bin_cnt, - framework::Tensor* out) { +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { int num = in.numel(); int block = 1024; int grid = (block - 1 + num) / block; - T* in_data = in.data(); - T* scale_data = scale.data(); + const T* in_data = in.data(); + const T* scale_data = scale.data(); T* out_data = out->mutable_data(ctx.GetPlace()); ClipAndQuantKernel<<>>( @@ -149,11 +164,14 @@ struct ClipAndFakeQuantFunctor { } }; +template struct ClipAndFakeQuantFunctor; + } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(fake_quantize, - paddle::operators::FakeQuantizeCUDAKernel< - paddle::platform::CUDADeviceContext, float>, - paddle::operators::FakeQuantizeCUDAKernel< - paddle::platform::CUDADeviceContext, double>); +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max, + ops::FakeQuantizeAbsMaxKernel); +REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, + ops::FakeQuantizeRangeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index ad97ca91aec..7ace7573ec5 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -17,9 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { @@ -32,51 +30,26 @@ struct FindAbsMaxFunctor { template struct ClipAndFakeQuantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor* scale, const int bin_cnt, + const framework::Tensor& scale, const int bin_cnt, framework::Tensor* out); }; template struct FindRangeAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& cur_scale, + void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale, const framework::Tensor& last_scale, const framework::Tensor& iter, const int window_size, - framework::Tensor* scales_arr, framework::Tensor* out_scale, - framework::Tensor* out); + framework::Tensor* scales_arr, framework::Tensor* out_scale); }; -void FindRangeAbsMax(const platform::CUDADeviceContext& ctx, - framework::Tensor* scale_list, const T last_max_scale, - const T& cur_scale, int window_size, - int current_iter) const { - T* sl = scale_list->mutable_data(scale_list->place()); - T remove_tmp; - auto& gpu_place = boost::get(ctx.GetPlace()); - int idx = current_iter % window_size; - memory::Copy(platform::CPUPlace(), &remove_tmp, gpu_place, sl + idx, - sizeof(float), ctx.stream()); - memory::Copy(gpu_place, sl + idx, platform::CPUPlace(), &cur_scale, sizeof(T), - ctx.stream()); - T max_scale = last_max_scale; - if (max_scale < cur_scale) { - max_scale = cur_scale; - } else if (fabs(remove_tmp - max_scale) < 1e-6) { - int size = (current_iter > window_size) ? window_size : current_iter; - max_scale = T(FindAbsMaxGpu(ctx, scale_list->data(), size)); - } - return max_scale; -} - template class FakeQuantizeAbsMaxKernel : public framework::OpKernel { - virtual void Compute(const framework::ExecutionContext& context) const { + public: + void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - auto* in_scale = context.Input("InScale"); auto* out = context.Output("Out"); auto* out_scale = context.Output("OutScale"); - T* out_data = out->mutable_data(context.GetPlace()); T* out_s = out_scale->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); @@ -84,7 +57,7 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); const T* in_data = in->data(); - FindAbsMaxFunctor()(dev_ctx, in_data, in.numel(), out_s); + FindAbsMaxFunctor()(dev_ctx, in_data, in->numel(), out_s); ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, bin_cnt, out); } @@ -92,9 +65,10 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel { template class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { - virtual void Compute(const framework::ExecutionContext& context) const { + public: + void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - auto* in_scale = context.Input("X"); + auto* in_scale = context.Input("InScale"); auto* out = context.Output("Out"); out->mutable_data(context.GetPlace()); @@ -113,19 +87,19 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { // training auto* out_scale = context.Output("OutScale"); - auto* in_scales = context.Input("InScales"); - auto* out_scales = context.Input("OutScales"); + auto* out_scales = context.Output("OutScales"); auto* iter = context.Input("Iter"); - bool window_size = context.Attr("window_size"); + int window_size = context.Attr("window_size"); out_scale->mutable_data(context.GetPlace()); - Tensor cur_scale; + framework::Tensor cur_scale; T* cur_scale_data = cur_scale.mutable_data({1}, context.GetPlace()); FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), cur_scale_data); - FindRangeAbsMaxFunctor()( - dev_ctx, cur_scale, in_scale, iter, window_size, out_scale, out_scale); + FindRangeAbsMaxFunctor()(dev_ctx, cur_scale, *in_scale, + *iter, window_size, out_scales, + out_scale); ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, bin_cnt, out); } diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index cc0494774a5..820ad4af88e 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -21,28 +21,41 @@ from op_test import OpTest class TestFakeQuantizeOp(OpTest): def setUp(self): - self.op_type = "fake_quantize" + self.op_type = "fake_quantize_abs_max" + self.attrs = {'bit_length': 8} + self.inputs = {'X': np.random.random((124, 240)).astype("float32"), } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + self.outputs = { + 'Out': np.round(self.inputs['X'] / scale * ( + (1 << (self.attrs['bit_length'] - 1)) - 1)), + 'OutScale': np.array(scale).astype("float32"), + } + + def test_check_output(self): + self.check_output() + + +class TestFakeQuantizeOp(OpTest): + def setUp(self): + self.op_type = "fake_quantize_range_abs_max" self.attrs = { - 'bit_length': 8, - 'quantize_type': 'abs_max', - 'window_size': 10000 + 'bit_length': int(5), + 'window_size': int(1), + 'is_test': False } self.inputs = { - 'X': np.random.random((10, 10)).astype("float32"), - 'InScales': np.zeros(self.attrs['window_size']).astype("float32"), - 'InCurrentIter': np.zeros(1).astype("float32"), - 'InMovingScale': np.zeros(1).astype("float32") - } - self.scale = { - 'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32") + 'X': np.random.random((8, 16, 7, 7)).astype("float32"), + 'Iter': np.zeros(1).astype("int64"), + 'InScale': np.zeros(1).astype("float32") } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + out_scales = np.zeros(self.attrs['window_size']).astype("float32") + out_scales[0] = scale self.outputs = { - 'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * ( + 'Out': np.round(self.inputs['X'] / scale * ( (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScales': np.zeros(self.attrs['window_size']).astype("float32"), - 'OutMovingScale': - np.array([self.scale['abs_max']]).astype("float32"), - 'OutCurrentIter': np.zeros(1).astype("float32") + 'OutScale': scale, + 'OutScales': out_scales, } def test_check_output(self): -- GitLab