From 1722678258fab032676bbd63aa3f95e6e925d1e4 Mon Sep 17 00:00:00 2001
From: whs
Date: Fri, 16 Nov 2018 16:08:22 +0800
Subject: [PATCH] Make nce support more distribution. (#13549)

* Fix truncated normal.

* Fix.

* Make nce support more distribution.

* Fix API.spec.

* Fix python API.

* Fix.
test=develop

* Fix API.spec
test=develop

* Fix sampler.

* Fix order of arguments in python API.
test=develop
---
 paddle/fluid/API.spec                         |   2 +-
 paddle/fluid/operators/CMakeLists.txt         |   1 +
 paddle/fluid/operators/math/CMakeLists.txt    |   1 +
 paddle/fluid/operators/math/sampler.cc        | 117 ++++++++++++++----
 paddle/fluid/operators/math/sampler.h         |  55 +++++---
 paddle/fluid/operators/nce_op.cc              |  19 +++
 paddle/fluid/operators/nce_op.h               | 101 ++++++++++-----
 python/paddle/fluid/layers/nn.py              |  47 ++++++-
 .../paddle/fluid/tests/unittests/test_nce.py |   4 +-
 9 files changed, 272 insertions(+), 75 deletions(-)

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index a23deebb25..da8941c351 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
 paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
-paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None))
+paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
 paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 2dc83c391b..0117a24c1b 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -308,6 +308,7 @@ op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)
 op_library(fake_quantize_op DEPS memory)
+op_library(nce_op DEPS sampler)
 if (NOT WIN32)
   op_library(crf_decoding_op DEPS jit_kernel)
   op_library(fusion_lstm_op DEPS jit_kernel)
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index cc3cc9787a..4cd014cbad 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -41,6 +41,7 @@ math_library(cross_entropy)
 math_library(cos_sim_functor)
 math_library(depthwise_conv)
 math_library(im2col)
+math_library(sampler)
 if (NOT WIN32) # windows do not support avx functions yet.
 math_library(gru_compute DEPS activation_functions math_function)
diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc
index 3066dc0ba2..690d6f6baa 100644
--- a/paddle/fluid/operators/math/sampler.cc
+++ b/paddle/fluid/operators/math/sampler.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/sampler.h"
+#include <memory>
+#include <queue>
+#include <utility>
+#include <vector>
 
 namespace paddle {
-namespace random {
+namespace operators {
+namespace math {
 
 Sampler::~Sampler() {}
 
-UniformSampler::UniformSampler(int64 range)
-    : Sampler(range), inv_range_(1.0 / range) {
-  random_engine_ = std::make_shared<std::mt19937>(seed_);
+UniformSampler::UniformSampler(int64_t range, unsigned int seed)
+    : Sampler(range, seed), inv_range_(1.0 / (range + 1)) {
+  random_engine_ = std::make_shared<std::mt19937_64>(seed_);
   dist_ = std::make_shared<std::uniform_int_distribution<int64_t>>(0, range);
 }
 
-UniformSampler::UniformSampler(int64 range, unsigned int seed)
-    : Sampler(range, seed), inv_range_(1.0 / range) {
-  random_engine_ = std::make_shared<std::mt19937>(seed_);
-  dist_ = std::make_shared<std::uniform_int_distribution<int64_t>>(0, range);
-}
-
-int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
+int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
 
-float UniformSampler::Probability(int64 value) const { return inv_range_; }
+float UniformSampler::Probability(int64_t value) const { return inv_range_; }
 
-LogUniformSampler::LogUniformSampler(int64 range)
-    : Sampler(range), log_range_(log(range + 1)) {
-  random_engine_ = std::make_shared<std::mt19937>(seed_);
-  dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
-}
-
-LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
+LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
     : Sampler(range, seed), log_range_(log(range + 1)) {
-  random_engine_ = std::make_shared<std::mt19937>(seed_);
+  random_engine_ = std::make_shared<std::mt19937_64>(seed_);
   dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
 }
 
-int64 LogUniformSampler::Sample() const {
+
+int64_t LogUniformSampler::Sample() const {
   // Got Log Uniform distribution from uniform distribution by
   // inverse_transform_sampling method
   // More details:
   // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
-  const int64 value =
-      static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
+  const int64_t value =
+      static_cast<int64_t>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
   // Mathematically, value should be <= range_, but might not be due to some
   // floating point roundoff, so we mod by range_.
   return value % range_;
 }
 
-float LogUniformSampler::Probability(int64 value) const {
+float LogUniformSampler::Probability(int64_t value) const {
   // Given f(x) = 1/[(x+1) * log_range_]
   // The value's probability is integral of f(x) from value to (value + 1)
   // More details:
@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const {
   return (log((value + 2.0) / (value + 1.0))) / log_range_;
 }
 
-}  // namespace random
+CustomSampler::CustomSampler(int64_t range, const float* probabilities,
+                             unsigned int seed)
+    : Sampler(range, seed) {
+  random_engine_ = std::make_shared<std::mt19937_64>(seed_);
+  real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
+  int_dist_ =
+      std::make_shared<std::uniform_int_distribution<int64_t>>(0, range);
+  alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
+  alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
+  probs_ = std::make_shared<std::vector<float>>(range + 1);
+
+  std::queue<std::pair<int64_t, float>> bigs;
+  std::queue<std::pair<int64_t, float>> littles;
+  for (int64_t i = 0; i <= range; ++i) {
+    (*probs_)[i] = probabilities[i];
+    float normal_prob = probabilities[i] * (range + 1);
+    if (normal_prob - 1.0 > 1e-4) {
+      bigs.emplace(i, normal_prob);
+    } else if (1.0 - normal_prob > 1e-4) {
+      littles.emplace(i, normal_prob);
+    } else {
+      (*alias_probs_)[i] = normal_prob;
+      (*alias_)[i] = -1;
+    }
+  }
+
+  while ((!littles.empty()) && (!bigs.empty())) {
+    auto big = bigs.front();
+    auto little = littles.front();
+    bigs.pop();
+    littles.pop();
+    (*alias_probs_)[little.first] = little.second;
+    (*alias_)[little.first] = big.first;
+    auto big_left = big.second - (1 - little.second);
+    if (big_left - 1.0 > 1e-4) {
+      bigs.emplace(big.first, big_left);
+    } else if (1.0 - big_left > 1e-4) {
+      littles.emplace(big.first, big_left);
+    } else {
+      (*alias_probs_)[big.first] = big_left;
+      (*alias_)[big.first] = -1;
+    }
+  }
+
+  if (!littles.empty()) {  // littles.second is close to 1.0
+    auto little = littles.front();
+    (*alias_probs_)[little.first] = 1.0;
+    (*alias_)[little.first] = -1;
+  }
+
+  if (!bigs.empty()) {  // bigs.second is close to 1.0
+    auto big = bigs.front();
+    (*alias_probs_)[big.first] = 1.0;
+    (*alias_)[big.first] = -1;
+  }
+}
+
+int64_t CustomSampler::Sample() const {
+  auto index = (*int_dist_)(*random_engine_);
+  auto p = (*real_dist_)(*random_engine_);
+  if (p > (*alias_probs_)[index]) {
+    return (*alias_)[index];
+  } else {
+    return index;
+  }
+}
+
+float CustomSampler::Probability(int64_t value) const {
+  return (*probs_)[value];
+}
+
+}  // namespace math
+}  // namespace operators
 }  // namespace paddle
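The CustomSampler added above implements the Walker/Vose alias method: each class's scaled probability is split across at most two table cells, so construction is O(n) and every draw costs O(1) (one uniform integer plus one uniform float), no matter how skewed the distribution is. Below is a minimal Python sketch of the same construction and sampling step; the names are illustrative and it uses exact comparisons where the C++ code uses a 1e-4 tolerance.

```python
import random

def build_alias_table(probs):
    """Vose/Walker alias method, mirroring CustomSampler's constructor in
    sampler.cc (simplified: stacks instead of queues, exact comparisons)."""
    n = len(probs)
    alias_probs = [0.0] * n
    alias = [-1] * n
    bigs, littles = [], []  # stacks behave the same as the C++ queues here
    for i, p in enumerate(probs):
        scaled = p * n  # "normal_prob" in the C++ code
        (bigs if scaled > 1.0 else littles).append((i, scaled))
    while littles and bigs:
        li, lp = littles.pop()
        bi, bp = bigs.pop()
        alias_probs[li] = lp    # keep class li with probability lp ...
        alias[li] = bi          # ... otherwise fall through to its alias bi
        left = bp - (1.0 - lp)  # bi donated (1 - lp) of its mass to li's cell
        (bigs if left > 1.0 else littles).append((bi, left))
    for i, _ in littles + bigs:  # leftovers equal 1.0 up to roundoff
        alias_probs[i] = 1.0
    return alias_probs, alias

def sample(alias_probs, alias, rng=random):
    # One uniform integer plus one uniform float per draw, exactly like
    # CustomSampler::Sample() with int_dist_ and real_dist_.
    i = rng.randrange(len(alias_probs))
    return i if rng.random() <= alias_probs[i] else alias[i]
```

A draw picks a cell uniformly, keeps the cell's own class with probability alias_probs[i], and falls through to the alias otherwise; this is why the C++ sampler carries both int_dist_ and real_dist_.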
diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h
index b82691f269..836cdad51f 100644
--- a/paddle/fluid/operators/math/sampler.h
+++ b/paddle/fluid/operators/math/sampler.h
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <cstdint>
 #include <memory>
 #include <random>
+#include <vector>
+
 namespace paddle {
 namespace operators {
 namespace math {
@@ -27,14 +29,14 @@ namespace math {
  */
 class Sampler {
  public:
-  explicit Sampler(int64_t range) : range_(range) {
-    PADDLE_ENFORCE_GT(range, 0);
-    std::random_device r;
-    seed_ = r();
-  }
-  explicit Sampler(int64_t range, unsigned int seed)
-      : range_(range), seed_(seed) {
-    PADDLE_ENFORCE_GT(range, 0);
+  explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
+    // PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0.");
+    if (seed == 0) {
+      std::random_device r;
+      seed_ = r();
+    } else {
+      seed_ = seed;
+    }
   }
   virtual ~Sampler();
   // Sample a single value
@@ -42,7 +44,7 @@ class Sampler {
   // The probability that a single call to Sample() returns the given value.
   virtual float Probability(int64_t value) const = 0;
 
-  int64 range() { return range_; }
+  int64_t range() { return range_; }
 
  protected:
   const int64_t range_;
@@ -56,13 +58,11 @@ class Sampler {
  */
 class UniformSampler : public Sampler {
 public:
-  explicit UniformSampler(int64_t range);
-
-  explicit UniformSampler(int64_t range, unsigned int seed);
+  explicit UniformSampler(int64_t range, unsigned int seed = 0UL);
 
   ~UniformSampler() override {}
 
-  int64 Sample() const override;
+  int64_t Sample() const override;
 
   float Probability(int64_t value) const override;
 
@@ -79,13 +79,11 @@ class UniformSampler : public Sampler {
  */
 class LogUniformSampler : public Sampler {
 public:
-  explicit LogUniformSampler(int64_t range);
-
-  explicit LogUniformSampler(int64_t range, unsigned int seed);
+  explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL);
 
   ~LogUniformSampler() override {}
 
-  int64 Sample() const override;
+  int64_t Sample() const override;
 
   float Probability(int64_t value) const override;
 
@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler {
   std::shared_ptr<std::mt19937_64> random_engine_;
   std::shared_ptr<std::uniform_real_distribution<>> dist_;
 };
 
+/**
+ * Sample integers from [0, range) from custom distribution.
+ */
+class CustomSampler : public Sampler {
+ public:
+  explicit CustomSampler(int64_t range, const float* probabilities,
+                         unsigned int seed = 0UL);
+
+  ~CustomSampler() override {}
+
+  int64_t Sample() const override;
+
+  float Probability(int64_t value) const override;
+
+ private:
+  std::shared_ptr<std::vector<float>> alias_probs_;
+  std::shared_ptr<std::vector<int64_t>> alias_;
+  std::shared_ptr<std::vector<float>> probs_;
+  std::shared_ptr<std::mt19937_64> random_engine_;
+  std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
+  std::shared_ptr<std::uniform_int_distribution<int64_t>> int_dist_;
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
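The header now exposes a single seed-taking constructor per sampler, with seed 0 meaning "seed from std::random_device". For readers new to LogUniformSampler, the draw it declares is plain inverse transform sampling; here is a Python sketch of the two members that sampler.cc defines (illustrative, not part of the patch):

```python
import math
import random

def log_uniform_sample(range_, rng=random):
    # If u ~ Uniform(0, 1), then floor(exp(u * log(range_ + 1))) - 1
    # follows P(x) proportional to 1 / (x + 1): small (frequent) class ids
    # are drawn more often, which suits frequency-sorted vocabularies.
    value = int(math.exp(rng.random() * math.log(range_ + 1))) - 1
    return value % range_  # guard against floating-point roundoff

def log_uniform_probability(value, range_):
    # Integral of f(x) = 1 / ((x + 1) * log(range_ + 1)) over [value, value + 1);
    # the terms telescope, so the probabilities over 0..range_-1 sum to 1.
    return math.log((value + 2.0) / (value + 1.0)) / math.log(range_ + 1)
```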
Classes " "in this list wiil be used as negative classes " diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 2c4c97f28b..e9af8ad4ce 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -19,29 +19,28 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using Sampler = math::Sampler; template using EigenMatrix = framework::EigenMatrix; template -void PrepareSamples(const framework::ExecutionContext& context) { +void PrepareSamples(const framework::ExecutionContext& context, + Sampler* sampler) { auto label = context.Input("Label"); const int64_t* label_data = label->data(); auto label_dims = label->dims(); - int num_total_classes = context.Attr("num_total_classes"); + // int num_total_classes = context.Attr("num_total_classes"); // for unitest std::vector custom_neg_classes = context.Attr>("custom_neg_classes"); - // random machine - std::random_device rd; - std::mt19937 rng(rd()); - std::uniform_int_distribution rand(0, num_total_classes - 1); auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); @@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { } else { for (; j < sample_labels_dims[1]; ++j) { // TODO(wanghaoshuang): support more distribution sampling - sample_labels_data[index++] = rand(rng); + sample_labels_data[index++] = sampler->Sample(); } } } @@ -72,7 +71,33 @@ template class NCEKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PrepareSamples(context); + int sampler_type = context.Attr("sampler"); + int seed = context.Attr("seed"); + int num_total_classes = context.Attr("num_total_classes"); + int num_neg_samples = context.Attr("num_neg_samples"); + + Sampler* sampler; + switch (sampler_type) { + case 0: { + sampler = new math::UniformSampler(num_total_classes - 1, seed); + break; + } + case 1: { + sampler = new math::LogUniformSampler(num_total_classes - 1, seed); + break; + } + case 2: { + auto custom_dist = context.Input("CustomDistribution"); + const float* custom_dist_data = custom_dist->data(); + PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); + sampler = new math::CustomSampler(num_total_classes - 1, + custom_dist_data, seed); + break; + } + default: { PADDLE_THROW("Unsupported SamplerType."); } + } + + PrepareSamples(context, sampler); auto sample_labels = context.Output("SampleLabels"); const int64_t* sample_labels_data = sample_labels->data(); auto sample_out = context.Output("SampleLogits"); @@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel { } auto out = context.Output("Cost"); T* out_data = out->mutable_data(context.GetPlace()); - int num_neg_samples = context.Attr("num_neg_samples"); - int num_total_classes = context.Attr("num_total_classes"); int64_t num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; } - T b = 1. / num_total_classes * num_neg_samples; + int64_t sampled_labels_num = sample_labels->dims()[1]; + // T b = 1. 
diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h
index 2c4c97f28b..e9af8ad4ce 100644
--- a/paddle/fluid/operators/nce_op.h
+++ b/paddle/fluid/operators/nce_op.h
@@ -19,29 +19,28 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/sampler.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using Sampler = math::Sampler;
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
-void PrepareSamples(const framework::ExecutionContext& context) {
+void PrepareSamples(const framework::ExecutionContext& context,
+                    Sampler* sampler) {
   auto label = context.Input<Tensor>("Label");
   const int64_t* label_data = label->data<int64_t>();
   auto label_dims = label->dims();
-  int num_total_classes = context.Attr<int>("num_total_classes");
+  // int num_total_classes = context.Attr<int>("num_total_classes");
   // for unittest
   std::vector<int> custom_neg_classes =
       context.Attr<std::vector<int>>("custom_neg_classes");
-  // random machine
-  std::random_device rd;
-  std::mt19937 rng(rd());
-  std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
 
   auto sample_labels = context.Output<Tensor>("SampleLabels");
   auto sample_labels_dims = sample_labels->dims();
@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) {
     } else {
       for (; j < sample_labels_dims[1]; ++j) {
         // TODO(wanghaoshuang): support more distribution sampling
-        sample_labels_data[index++] = rand(rng);
+        sample_labels_data[index++] = sampler->Sample();
       }
     }
   }
@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T>
 class NCEKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
-    PrepareSamples<DeviceContext, T>(context);
+    int sampler_type = context.Attr<int>("sampler");
+    int seed = context.Attr<int>("seed");
+    int num_total_classes = context.Attr<int>("num_total_classes");
+    int num_neg_samples = context.Attr<int>("num_neg_samples");
+
+    Sampler* sampler;
+    switch (sampler_type) {
+      case 0: {
+        sampler = new math::UniformSampler(num_total_classes - 1, seed);
+        break;
+      }
+      case 1: {
+        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
+        break;
+      }
+      case 2: {
+        auto custom_dist = context.Input<Tensor>("CustomDistribution");
+        const float* custom_dist_data = custom_dist->data<float>();
+        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
+        sampler = new math::CustomSampler(num_total_classes - 1,
+                                          custom_dist_data, seed);
+        break;
+      }
+      default: { PADDLE_THROW("Unsupported SamplerType."); }
+    }
+
+    PrepareSamples<DeviceContext, T>(context, sampler);
     auto sample_labels = context.Output<Tensor>("SampleLabels");
     const int64_t* sample_labels_data = sample_labels->data<int64_t>();
     auto sample_out = context.Output<Tensor>("SampleLogits");
@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> {
     }
     auto out = context.Output<Tensor>("Cost");
     T* out_data = out->mutable_data<T>(context.GetPlace());
-    int num_neg_samples = context.Attr<int>("num_neg_samples");
-    int num_total_classes = context.Attr<int>("num_total_classes");
     int64_t num_true_class = 1;
     if (label != nullptr) {
       num_true_class = label->dims()[1];
     }
-    T b = 1. / num_total_classes * num_neg_samples;
+    int64_t sampled_labels_num = sample_labels->dims()[1];
+    // T b = 1. / num_total_classes * num_neg_samples;
     // forward bias
     auto bias = context.Input<Tensor>("Bias");
     if (bias != nullptr) {
@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> {
     }
     // forward cost
     for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
-      int64_t j = 0;
       out_data[i] = 0;
       T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
-      // for true classes
-      for (; j < num_true_class; ++j) {
-        T o = sample_out_data[i * sample_out->dims()[1] + j];
-        T cost = -log(o / (o + b));
-        out_data[i] += w * cost;
-      }
-      // for sampled neg classes
-      for (; j < sample_labels->dims()[1]; ++j) {
-        T o = sample_out_data[i * sample_out->dims()[1] + j];
-        T cost = -log(b / (o + b));
+      for (int64_t j = 0; j < sampled_labels_num; ++j) {
+        int64_t target = sample_labels_data[i * sampled_labels_num + j];
+        T o = sample_out_data[i * sampled_labels_num + j];
+        float b = sampler->Probability(target) * num_neg_samples;
+        T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
         out_data[i] += w * cost;
       }
     }
+    delete sampler;
   }
 };
@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> {
     if (label != nullptr) {
       num_true_class = label->dims()[1];
     }
-    T b = 1. / num_total_classes * num_neg_samples;
+
+    int sampler_type = context.Attr<int>("sampler");
+    int seed = context.Attr<int>("seed");
+    Sampler* sampler;
+    switch (sampler_type) {
+      case 0: {
+        sampler = new math::UniformSampler(num_total_classes - 1, seed);
+        break;
+      }
+      case 1: {
+        sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
+        break;
+      }
+      case 2: {
+        auto custom_dist = context.Input<Tensor>("CustomDistribution");
+        const float* custom_dist_data = custom_dist->data<float>();
+        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
+        sampler = new math::CustomSampler(num_total_classes - 1,
+                                          custom_dist_data, seed);
+        break;
+      }
+      default: { PADDLE_THROW("Unsupported SamplerType."); }
+    }
+
+    // T b = 1. / num_total_classes * num_neg_samples;
     Tensor sample_grad;  // tmp tensor
     T* sample_grad_data =
         sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
     // backward cost
     for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+      int64_t label_idx = i % sample_labels->dims()[1];
+      int64_t sample_idx = i / sample_labels->dims()[1];
+      float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
       T o = sample_out_data[i];
-      T w = sample_weight == nullptr
-                ? 1
-                : sample_weight_data[i / sample_labels->dims()[1]];
-      sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
+      T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
+      sample_grad_data[i] = label_idx < num_true_class
                                 ? w * (b / (o + b)) * (o - 1)
                                 : w * (o * (1 - o) / (o + b));
-      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
+      sample_grad_data[i] *= d_out_data[sample_idx];
     }
     // get d_bias
     auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
             w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
       }
     }
+    delete sampler;
   }
 };
 }  // namespace operators
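Putting the forward pass together: the rewritten cost loop in NCEKernel::Compute treats the first num_true_class columns of the sampled logits as true classes and the rest as noise, recomputing b per target from the sampler. A compact Python restatement (an illustrative helper, assuming the o values are already sigmoid outputs):

```python
import math

def nce_cost(sample_logits, sample_labels, num_true, noise_prob, num_neg):
    # Mirrors the per-example loop in NCEKernel::Compute: b now depends on
    # the sampled class via noise_prob (Sampler::Probability in the C++).
    cost = 0.0
    for j, (o, c) in enumerate(zip(sample_logits, sample_labels)):
        b = noise_prob(c) * num_neg
        if j < num_true:
            cost += -math.log(o / (o + b))  # true-class term
        else:
            cost += -math.log(b / (o + b))  # sampled-noise term
    return cost
```

The gradient kernel applies the matching derivatives, w * (b / (o + b)) * (o - 1) for true classes and w * o * (1 - o) / (o + b) for noise, using the same per-class b.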
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 002d0f006b..af96f5de4f 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4313,7 +4313,10 @@ def nce(input,
         param_attr=None,
         bias_attr=None,
         num_neg_samples=None,
-        name=None):
+        name=None,
+        sampler="uniform",
+        custom_dist=None,
+        seed=0):
     """
     ${comment}
 
@@ -4336,6 +4339,14 @@ def nce(input,
         num_neg_samples (int): ${num_neg_samples_comment}
         name (str|None): A name for this layer(optional). If set None, the layer
             will be named automatically. Default: None.
+        sampler (str): The sampler used to sample negative classes.
+            It can be 'uniform', 'log_uniform' or 'custom_dist'.
+            default: 'uniform'.
+        custom_dist (Variable): A tensor with shape [num_total_classes].
+            It is used when sampler is set to 'custom_dist'.
+            custom_dist[i] is the probability of the i-th class being sampled.
+            default: None.
+        seed (int): The seed used in sampler. default: 0.
 
     Returns:
         Variable: The output nce loss.
@@ -4365,6 +4376,16 @@ def nce(input,
             loss = layers.nce(input=embs, label=words[label_word],
                           num_total_classes=dict_size, param_attr='nce.w',
                           bias_attr='nce.b')
+
+            # or use a custom distribution
+            dist = fluid.layers.assign(input=np.array([0.05, 0.5, 0.1, 0.3, 0.05]).astype("float32"))
+            loss = layers.nce(input=embs, label=words[label_word],
+                          num_total_classes=5, param_attr='nce.w',
+                          bias_attr='nce.b',
+                          num_neg_samples=3,
+                          sampler="custom_dist",
+                          custom_dist=dist)
+
     """
     helper = LayerHelper('nce', **locals())
     assert isinstance(input, Variable)
@@ -4399,9 +4420,31 @@ def nce(input,
     else:
         num_neg_samples = int(num_neg_samples)
 
+    inputs = {
+        'Input': input,
+        'Label': label,
+        'Weight': w,
+        'Bias': b,
+        'SampleWeight': sample_weight if sample_weight is not None else []
+    }
+
+    if sampler == "uniform":
+        sampler = 0
+    elif sampler == "log_uniform":
+        sampler = 1
+    elif sampler == "custom_dist":
+        assert custom_dist is not None
+        assert isinstance(custom_dist, Variable)
+        inputs['CustomDistribution'] = custom_dist
+        sampler = 2
+    else:
+        raise Exception("Unsupported sampler type.")
+
     attrs = {
         'num_total_classes': int(num_total_classes),
-        'num_neg_samples': num_neg_samples
+        'num_neg_samples': num_neg_samples,
+        'seed': seed,
+        'sampler': sampler
     }
 
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py
index 0745bd274f..c01fdd5ddd 100644
--- a/python/paddle/fluid/tests/unittests/test_nce.py
+++ b/python/paddle/fluid/tests/unittests/test_nce.py
@@ -68,7 +68,9 @@ class TestNCE(OpTest):
         self.attrs = {
             'num_total_classes': num_classes,
             'num_neg_samples': num_neg_samples,
-            'custom_neg_classes': list(range(num_neg_samples))
+            'custom_neg_classes': list(range(num_neg_samples)),
+            'seed': 0,
+            'sampler': 0
         }
         self.inputs = {
             'Input': input,
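The updated unit test only pins the new attributes for the uniform sampler (sampler=0, seed=0). One way to sanity-check the custom path is to compare the alias table's empirical frequencies against the target distribution; the standalone sketch below reuses the hypothetical build_alias_table/sample helpers from the alias-method example after sampler.cc, and is not part of the patch or of Paddle's test suite.

```python
import collections

dist = [0.05, 0.5, 0.1, 0.3, 0.05]  # toy distribution from the nce docstring
alias_probs, alias = build_alias_table(dist)

n = 200000
counts = collections.Counter(sample(alias_probs, alias) for _ in range(n))
for c, p in enumerate(dist):
    # 0.01 leaves roughly nine standard errors of headroom at n = 200000
    assert abs(counts[c] / n - p) < 0.01, (c, counts[c] / n, p)
```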