From 8b8ad6b1640aaeebcab852d776cb14f9f8ce565a Mon Sep 17 00:00:00 2001 From: caoying03 Date: Mon, 25 Sep 2017 09:56:32 +0800 Subject: [PATCH] fix implementations of supporting soft labels. --- paddle/operators/cross_entropy_op.cu | 4 +- paddle/operators/cross_entropy_op.h | 24 +-- paddle/operators/math/softmax.cc | 2 +- paddle/operators/math/softmax.h | 4 +- paddle/operators/softmax_op.h | 2 +- .../softmax_with_cross_entropy_op.cc | 79 +++++++--- .../softmax_with_cross_entropy_op.cu | 142 +++++++++++++----- .../operators/softmax_with_cross_entropy_op.h | 69 ++++++--- .../test_softmax_with_cross_entropy_op.py | 48 +++++- 9 files changed, 272 insertions(+), 102 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 1d6361a8147..2989e55075f 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -28,7 +28,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -tolerable_value(log(X[i * D + label[i]])); + Y[i] = -TolerableValue()(log(X[i * D + label[i]])); } } @@ -39,7 +39,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, i += blockDim.x * gridDim.x) { T sum = static_cast(0); for (int j = 0; j < D; j++) { - sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); + sum += label[i * D + j] * TolerableValue()(log(X[i * D + j])); } Y[i] = -sum; } diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 69caba5ff31..942a532f642 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -22,17 +22,16 @@ namespace operators { using Tensor = framework::Tensor; template -HOSTDEVICE T tolerable_value(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - if (x == INFINITY) { - return kApproInf; +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; } - if (x == -INFINITY) { - return -kApproInf; - } - return x; -} +}; template class CrossEntropyOpKernel : public framework::OpKernel { @@ -57,7 +56,8 @@ class CrossEntropyOpKernel : public framework::OpKernel { for (int i = 0; i < batch_size; ++i) { T sum = static_cast(0); for (int j = 0; j < class_num; ++j) { - sum += label_data[index] * tolerable_value(std::log(x_data[index])); + sum += + label_data[index] * TolerableValue()(std::log(x_data[index])); y_data[i] = -sum; index++; } @@ -66,7 +66,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { auto* label_data = ctx.Input("Label")->data(); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; - y_data[i] = -tolerable_value(std::log(x_data[index])); + y_data[i] = -TolerableValue()(std::log(x_data[index])); } } } diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index ac9f3c4bf61..1224c058105 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { namespace math { -template class SoftmaxFunctor; +template class SoftmaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index ce29a69bce3..08dafed971e 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -28,8 +28,8 @@ using EigenMatrix = framework::EigenMatrix; template class SoftmaxFunctor { public: - void operator()(const framework::Tensor* X, framework::Tensor* Y, - const framework::ExecutionContext& context) { + void operator()(const framework::ExecutionContext& context, + const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 18494e470a3..7220f486be0 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Y->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(X, Y, context); + math::SoftmaxFunctor()(context, X, Y); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 3dd21279add..cb2aa30055d 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -23,31 +23,31 @@ class SoftmaxWithCrossEntropyOpMaker SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - //(TODO caoying) replace int with boolean - AddAttr("soft_label", - "(int, default 0), A flag to indicate whether to interpretate " - "the given labels as soft labels.") - .SetDefault(0); + AddAttr( + "softLabel", + "(bool, default: false), A flag to indicate whether to interpretate " + "the given labels as soft labels.") + .SetDefault(false); AddInput("Logits", - "(Tensor, default Tensor), The unscaled log probabilities " + "(Tensor, default: Tensor), The unscaled log probabilities " "which is a 2-D tensor with shape [N x K]. N is the batch_size, " "and K is the class number.") .NotInGradient(); AddInput( "Label", - "(Tensor, default Tensor), The ground truth which is " - "a 1-D or 2-D tensor. " - "If soft_label is set to 0, Label is a Tensor with shape [N x 1]. " - "If soft_label is set to 1, Label is a Tensor " + "(Tensor, default: Tensor), The ground truth which is a 2-D " + "tensor. " + "If softLable is set to 0, Label is a Tensor with shape [N x 1]. " + "If softLable is set to 1, Label is a Tensor " "with shape [N x K]."); AddOutput( "Softmax", - "(Tensor, default Tensor), A 2-D tensor with shape [N x K]. " + "(Tensor, default: Tensor), A 2-D tensor with shape [N x K]. " "The outputs value of softmax activation by given the input batch, " "which will be used in backward calculation.") .AsIntermediate(); AddOutput("Loss", - "(Tensor, default Tensor), A 1-D tensor. The cross " + "(Tensor, default: Tensor), A 2-D tensor. The cross " "entropy loss with shape [N x 1]."); AddComment(R"DOC( Cross entropy loss with softmax are used as the output layer extensively. This @@ -83,15 +83,39 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"), + "Input(Logits) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) should be not null."); + + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"), + "Output(Softmax) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"), + "Output(Loss) should be not null."); + const Tensor* logits = ctx.Input("Logits"); + const Tensor* labels = ctx.Input("Label"); PADDLE_ENFORCE( logits->dims().size() == 2UL, - "The input of softmax_with_cross_entropy should be a 2-d tensor."); - PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 1UL, - "The label should be a 1-d tensor."); - - ctx.Output("Softmax")->Resize(logits->dims()); - ctx.Output("Loss")->Resize({logits->dims()[0], 1}); + "The input of softmax_with_cross_entropy should be a 2-D tensor."); + PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 2UL, + "The labels should be a 2-D tensor."); + + if (ctx.Attr("softLabel")) { + PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1], + "If Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); + } else { + PADDLE_ENFORCE_EQ(labels->dims()[1], 1, + "If Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } + + ctx.Output("Softmax")->Resize(logits->dims()); + ctx.Output("Loss")->Resize({logits->dims()[0], 1}); + + ctx.ShareLoD("Logits", /*->*/ "Softmax"); + ctx.ShareLoD("Logits", /*->*/ "Loss"); } }; @@ -102,11 +126,28 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), - "Input(Loss@Grad) should not be null"); + "Input(Loss@Grad) should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"), "Input(Softmax) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), "Input(Lable) should be not null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")), + "Output(Logits@Grad) should be not null."); + + const Tensor* softmax = ctx.Input("Softmax"); + const Tensor* labels = ctx.Input("Label"); + PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 2UL, + "The labels should be a 2-D tensor."); + + if (ctx.Attr("softLabel")) { + PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1], + "When Attr(softLabel) == true, the 2nd dimension of " + "Input(X) and Input(Label) should be equal."); + } else { + PADDLE_ENFORCE_EQ(labels->dims()[1], 1, + "When Attr(softLabel) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } ctx.Output(framework::GradVarName("Logits")) ->Resize(ctx.Input("Softmax")->dims()); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index 68bb85fa8ad..feae903dabc 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -24,25 +24,78 @@ namespace operators { using Tensor = framework::Tensor; template -__global__ void CrossEntropyKernel(T* out, const T* softmax_out, - const int* label, const int batch_size, - const int class_num) { +__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels, + const int batch_size, const int class_num) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < batch_size) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); - out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]])); + PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num); + out[i] = + -TolerableValue()(std::log(softmax_out[i * class_num + labels[i]])); } } template -__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out, - const int* label, - const int batch_size, - const int class_num) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < batch_size) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num); - softmax_out[i * class_num + label[i]] -= 1.; +__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad, + const int* labels, const int batch_size, + const int class_num) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int sample_idx = tid / class_num; + + if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx]; + __syncthreads(); + + if (tid < batch_size) { + PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num); + out_grad[tid * class_num + labels[tid]] -= 1.; + } +} + +template +__device__ __forceinline__ T sum_single_warp(T val) { + val += __shfl_down(val, 16); + val += __shfl_down(val, 8); + val += __shfl_down(val, 4); + val += __shfl_down(val, 2); + val += __shfl_down(val, 1); + return val; +} + +template +__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, + const int class_num) { + int tid = threadIdx.x; + extern __shared__ T d_sum[]; + d_sum[tid] = 0; + + int cur_idx = tid; + int next_idx = blockIdx.x * class_num + tid; + while (cur_idx < class_num) { + d_sum[tid] += TolerableValue()(std::log(X[next_idx])) * label[next_idx]; + next_idx += blockDim.x; + cur_idx += blockDim.x; + } + __syncthreads(); + + for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) { + if (tid < stride) d_sum[tid] += d_sum[tid + stride]; + __syncthreads(); + } + + T val = d_sum[tid]; + val = sum_single_warp(val); + if (tid == 0) Y[blockIdx.x] = -val; +} + +template +__global__ void SoftCrossEntropyGradientKernel(T* logit_grad, + const T* loss_grad, + const T* labels, + const int batch_size, + const int class_num) { + int ids = blockIdx.x * blockDim.x + threadIdx.x; + if (ids < batch_size * class_num) { + int row_ids = ids / class_num; + logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids]; } } @@ -52,27 +105,36 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), "This kernel only runs on GPU device."); + T* loss_data = + context.Output("Loss")->mutable_data(context.GetPlace()); - // Calculate ths softmax outputs. const Tensor* logits = context.Input("Logits"); Tensor* softmax = context.Output("Softmax"); - softmax->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(logits, softmax, context); - T* softmax_out = softmax->data(); - - // Calculate the cross entropy loss based on hard labels. - const int* label_data = context.Input("Label")->data(); - Tensor* loss = context.Output("Loss"); - loss->mutable_data(context.GetPlace()); - T* loss_data = loss->data(); + T* softmax_out = softmax->mutable_data(context.GetPlace()); + math::SoftmaxFunctor()(context, logits, softmax); const int batch_size = logits->dims()[0]; const int class_num = logits->dims()[1]; int block = 512; int grid = (batch_size + block - 1) / block; - CrossEntropyKernel<<>>(loss_data, softmax_out, label_data, - batch_size, class_num); + if (context.Attr("softLabel")) { + const T* label_data = context.Input("Label")->data(); + block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); + + SoftCrossEntropyKernel< + T><<( + context.device_context()) + .stream()>>>(loss_data, softmax_out, label_data, class_num); + } else { + const int* label_data = context.Input("Label")->data(); + CrossEntropy<<( + context.device_context()) + .stream()>>>(loss_data, softmax_out, label_data, + batch_size, class_num); + } } }; @@ -82,7 +144,9 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), "This kernel only runs on GPU device."); - + const Tensor* labels = context.Input("Label"); + const T* loss_grad_data = + context.Input(framework::GradVarName("Loss"))->data(); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); logit_grad->ShareDataWith(*context.Input("Softmax")); @@ -90,14 +154,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { const int batch_size = logit_grad->dims()[0]; const int class_num = logit_grad->dims()[1]; - - const int* label_data = context.Input("Label")->data(); - - const int block = 512; - const int grid = (batch_size + block - 1) / block; - - CrossEntropyWithSoftmaxGradKernel<<>>( - logit_grad_data, label_data, batch_size, class_num); + int block = 512; + int grid = (batch_size * class_num + block - 1) / block; + + if (context.Attr("softLabel")) { + const T* label_data = labels->data(); + SoftCrossEntropyGradientKernel<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(logit_grad_data, loss_grad_data, + label_data, batch_size, class_num); + } else { + const int* label_data = labels->data(); + CrossEntropyGrad<<< + grid, block, 0, reinterpret_cast( + context.device_context()) + .stream()>>>(logit_grad_data, loss_grad_data, + label_data, batch_size, class_num); + } } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 0ad48dae2cd..71705cedf26 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -32,28 +32,35 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), "This kernel only runs on CPU."); - - // Calculate ths softmax outputs. const Tensor* logits = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); - softmax->mutable_data(context.GetPlace()); - - math::SoftmaxFunctor()(logits, softmax, context); + Tensor* loss = context.Output("Loss"); - // Calculate the cross entropy loss based on hard labels. - T* softmax_out = softmax->data(); - const int* label_data = context.Input("Label")->data(); + T* softmax_data = softmax->mutable_data(context.GetPlace()); + T* loss_data = loss->mutable_data(context.GetPlace()); - Tensor* loss = context.Output("Loss"); - loss->mutable_data(context.GetPlace()); - T* loss_data = loss->data(); + math::SoftmaxFunctor()(context, logits, softmax); const int batch_size = logits->dims()[0]; - const int class_num = logits->dims()[1]; - - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - loss_data[i] = -tolerable_value(std::log(softmax_out[index])); + if (context.Attr("softLabel")) { + //(TODO caoying) the forward implementation can be further optimized. + // Current implementation is exactly cross entropy after softmax. + auto prob = EigenMatrix::From(*softmax); + auto lbl_mat = EigenMatrix::From(*labels); + auto loss_mat = EigenMatrix::From(*loss); + + loss_mat.device(context.GetEigenDevice()) = + -((lbl_mat * prob.log().unaryExpr(TolerableValue())) + .sum(Eigen::DSizes(1)) + .reshape(Eigen::DSizes(batch_size, 1))); + } else { + const int* label_data = labels->data(); + const int class_num = logits->dims()[1]; + + for (int i = 0; i < batch_size; ++i) + loss_data[i] = -TolerableValue()( + std::log(softmax_data[i * class_num + label_data[i]])); } } }; @@ -62,18 +69,34 @@ template class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const Tensor* out_grad = + context.Input(framework::GradVarName("Loss")); + const Tensor* labels = context.Input("Label"); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); logit_grad->ShareDataWith(*context.Input("Softmax")); - T* logit_grad_data = logit_grad->data(); - const int batch_size = logit_grad->dims()[0]; const int class_num = logit_grad->dims()[1]; - - const int* label_data = context.Input("Label")->data(); - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - logit_grad_data[index] -= 1.; + if (context.Attr("softLabel")) { + auto out_grad_mat = EigenMatrix::From(*out_grad); + auto logit_grad_mat = EigenMatrix::From(*logit_grad); + auto lbl_mat = EigenMatrix::From(*labels); + + logit_grad_mat.device(context.GetEigenDevice()) = + logit_grad_mat * + out_grad_mat.broadcast(Eigen::DSizes(1, class_num)) - + lbl_mat; + } else { + const int batch_size = logit_grad->dims()[0]; + const int* label_data = labels->data(); + const T* out_grad_data = out_grad->data(); + T* logit_grad_data = logit_grad->data(); + + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + logit_grad_data[index] = + (out_grad_data[i] * logit_grad_data[index] - 1.); + } } } }; diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py index 9c9ee77b734..428395b76c8 100644 --- a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py @@ -6,22 +6,23 @@ from test_softmax_op import stable_softmax class TestSoftmaxWithCrossEntropyOp(OpTest): + """ + Test softmax with cross entropy operator with discreate one-hot labels. + """ + def setUp(self): self.op_type = "softmax_with_cross_entropy" - - MAX_BATCH_SIZE = 23 - MAX_CLASS_NUM = 17 - - batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0] - class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0] + batch_size = 3 + class_num = 37 logits = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") softmax = np.apply_along_axis(stable_softmax, 1, logits) - labels = np.random.randint(0, class_num, batch_size, dtype="int32") + labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32") cross_entropy = np.asmatrix( - [[-np.log(softmax[i][labels[i]])] for i in range(softmax.shape[0])], + [[-np.log(softmax[i][labels[i][0]])] + for i in range(softmax.shape[0])], dtype="float32") self.inputs = {"Logits": logits, "Label": labels} @@ -34,5 +35,36 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.check_grad(["Logits"], "Loss", max_relative_error=0.05) +class TestSoftmaxWithCrossEntropyOp2(OpTest): + """ + Test softmax with cross entropy operator with soft labels. + """ + + def setUp(self): + self.op_type = "softmax_with_cross_entropy" + batch_size = 2 + class_num = 17 + + logits = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + labels /= np.sum(labels, axis=1, keepdims=True) + + cross_entropy = (-labels * np.log(softmax)).sum( + axis=1, keepdims=True).astype("float32") + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = {"Softmax": softmax, "Loss": cross_entropy} + self.attrs = {"softLabel": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["Logits"], "Loss", max_relative_error=0.05) + + if __name__ == "__main__": unittest.main() -- GitLab