From 287cab49c299a9aae94e16edc08319df0b18487a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 26 Jul 2021 17:47:18 +0800 Subject: [PATCH] fix(mgb/sereg): fix rng operator compatibility GitOrigin-RevId: 66d1694035b026cb2b541f1249a4eadb2cbff50b --- dnn/scripts/opr_param_defs.py | 11 +++++++++-- src/opr/impl/rand.oprdecl | 4 ++-- src/opr/impl/rand.sereg.h | 17 ++++++++++------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 38dfe14b0..f1aa516c1 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -745,13 +745,20 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'dtype', Doc('dtype', 'data type of output value'), 'DTypeEnum::Float32')) -(pdef('UniformRNG'). +(pdef('UniformRNG', version=0, is_legacy=True). + add_fields('uint64', 'seed', 0)) + +(pdef('UniformRNG', version=1). add_fields('uint64', 'seed', 0). add_fields( 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), 'DTypeEnum::Float32')) -(pdef('GaussianRNG'). +(pdef('GaussianRNG', version=0, is_legacy=True). + add_fields('uint64', 'seed', 0). + add_fields('float32', 'mean', 0, 'std', 1)) + +(pdef('GaussianRNG', version=1). add_fields('uint64', 'seed', 0). add_fields('float32', 'mean', 0, 'std', 1). add_fields( diff --git a/src/opr/impl/rand.oprdecl b/src/opr/impl/rand.oprdecl index 860095c18..763fcd1f6 100644 --- a/src/opr/impl/rand.oprdecl +++ b/src/opr/impl/rand.oprdecl @@ -1,12 +1,12 @@ decl_opr('UniformRNG', pyname='_uniform_rng', inputs=['shape'], params='UniformRNG', - canonize_input_vars='canonize_shape_input') + canonize_input_vars='canonize_shape_input', version=1) decl_opr('GaussianRNG', pyname='_gaussian_rng', inputs=['shape'], params='GaussianRNG', - canonize_input_vars='canonize_shape_input') + canonize_input_vars='canonize_shape_input', version=1) inputs = [ Doc('shape', diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index 315b69756..8c5c6c22a 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -13,18 +13,21 @@ #include "megbrain/serialization/sereg.h" namespace mgb { + + namespace opr { - MGB_SEREG_OPR(UniformRNG, 1); - MGB_SEREG_OPR(GaussianRNG, 1); - MGB_SEREG_OPR(GammaRNG, 2); - MGB_SEREG_OPR(PoissonRNG, 1); - MGB_SEREG_OPR(PermutationRNG, 1); - MGB_SEREG_OPR(BetaRNG, 2); +using UniformRNGV1 = opr::UniformRNG; +MGB_SEREG_OPR(UniformRNGV1, 1); +using GaussianRNGV1 = opr::GaussianRNG; +MGB_SEREG_OPR(GaussianRNGV1, 1); +MGB_SEREG_OPR(GammaRNG, 2); +MGB_SEREG_OPR(PoissonRNG, 1); +MGB_SEREG_OPR(PermutationRNG, 1); +MGB_SEREG_OPR(BetaRNG, 2); } // namespace opr } // namespace mgb - // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab