diff --git a/mindspore/nn/probability/distribution/_utils/custom_ops.py b/mindspore/nn/probability/distribution/_utils/custom_ops.py
index b81f5d186ab222ca32c4b7fbc39bb03a11483e4f..4c6213f7d2fbfed97855b0af32b61487c50a6775 100644
--- a/mindspore/nn/probability/distribution/_utils/custom_ops.py
+++ b/mindspore/nn/probability/distribution/_utils/custom_ops.py
@@ -17,6 +17,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
+
 def exp_by_step(input_x):
     """
     Log op on Ascend doesn't supprot int types.
@@ -24,23 +25,18 @@ def exp_by_step(input_x):
     """
     exp = P.Exp()
     cast = P.Cast()
-    dtype = P.DType()
-    checktype = P.IsSubClass()
-    if checktype(dtype(input_x), mstype.int_):
-        input_x = cast(input_x, mstype.float32)
-    elif checktype(dtype(input_x), mstype.float_):
-        pass
-    else:
-        return None
+    input_x = cast(input_x, mstype.float32)
     return exp(input_x)
 
+
 def expm1_by_step(input_x):
     """
     Expm1 ops under GPU context.
     """
     return exp_by_step(input_x) - 1.0
 
+
 def log_by_step(input_x):
     """
     Log op on Ascend is calculated as log(abs(x)).
@@ -56,14 +52,8 @@ def log_by_step(input_x):
     dtype = P.DType()
     shape = P.Shape()
     select = P.Select()
-    checktype = P.IsSubClass()
-    if checktype(dtype(input_x), mstype.int_):
-        input_x = cast(input_x, mstype.float32)
-    elif checktype(dtype(input_x), mstype.float_):
-        pass
-    else:
-        return None
+    input_x = cast(input_x, mstype.float32)
     nan = fill(dtype(input_x), shape(input_x), np.nan)
     inf = fill(dtype(input_x), shape(input_x), np.inf)
     neg_x = less(input_x, 0.0)
@@ -72,6 +62,7 @@
     result = select(nonpos_x, -inf, log_x)
     return select(neg_x, nan, result)
 
+
 def log1p_by_step(x):
     """
     Log1p ops on GPU device or when device_target == GPU.
diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index 9ad95394782af9a6edff9db2b5802fc66401e1c0..d49088ad48b6cd190e6cb83da6c932d00dbea8e0 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -14,15 +14,15 @@
 # ============================================================================
 """Utitly functions to help distribution class."""
 import numpy as np
-from mindspore.ops import _utils as utils
-from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore import context
 from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
+from mindspore.ops import _utils as utils
 from mindspore.ops import composite as C
-from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
 import mindspore.nn as nn
 import mindspore.nn.probability as msp
@@ -82,6 +82,24 @@ def convert_to_batch(t, batch_shape, required_type):
     return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
 
 
+def cast_type_for_device(dtype):
+    """
+    use the alternative dtype supported by the device.
+    Args:
+        dtype (mindspore.dtype): input dtype.
+    Returns:
+        mindspore.dtype.
+    """
+    if context.get_context("device_target") == "GPU":
+        if dtype in mstype.uint_type or dtype == mstype.int8:
+            return mstype.int16
+        if dtype == mstype.int64:
+            return mstype.int32
+        if dtype == mstype.float64:
+            return mstype.float32
+    return dtype
+
+
 def check_scalar_from_param(params):
     """
     Check if params are all scalars.
@@ -293,10 +311,10 @@ def raise_not_impl_error(name):
 def check_distribution_name(name, expected_name):
     if name is None:
         raise ValueError(
-            f"Distribution should be a constant which is not None.")
+            f"Input dist should be a constant which is not None.")
     if name != expected_name:
         raise ValueError(
-            f"Expected distribution name is {expected_name}, but got {name}.")
+            f"Expected dist input is {expected_name}, but got {name}.")
 
 
 class CheckTuple(PrimitiveWithInfer):
diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py
index 7a1385daede8f0db0d8321b83cdd936dea3268e9..203c97fcc6dd538dc8de76cadb6862e0baba2613 100644
--- a/mindspore/nn/probability/distribution/distribution.py
+++ b/mindspore/nn/probability/distribution/distribution.py
@@ -16,9 +16,10 @@
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
+from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
 from ._utils.utils import CheckTuple, CheckTensor
 
+
 class Distribution(Cell):
     """
     Base class for all mathematical distributions.
@@ -43,12 +44,12 @@ class Distribution(Cell):
         new distribution specified by the dist_spec_args. But it won't change the
         original distribuion.
     """
+
     def __init__(self,
                  seed,
                  dtype,
                  name,
                  param):
-
         """
         Constructor of distribution class.
         """
@@ -58,7 +59,7 @@ class Distribution(Cell):
         self._name = name
         self._seed = seed
-        self._dtype = dtype
+        self._dtype = cast_type_for_device(dtype)
         self._parameters = {}
         # parsing parameters
         for k in param.keys():
@@ -436,7 +437,6 @@ class Distribution(Cell):
         """
         return self._sample(*args, **kwargs)
 
-
    def construct(self, name, *args, **kwargs):
        """
        Override construct in Cell.