Commit 659161df authored by mindspore-ci-bot, committed by Gitee

!5029 Fix custom ops log/exp cast logic

Merge pull request !5029 from zichun_ye/fix_custom_ops
@@ -17,6 +17,7 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

def exp_by_step(input_x):
    """
    Log op on Ascend doesn't support int types.
@@ -24,23 +25,18 @@ def exp_by_step(input_x):
    """
    exp = P.Exp()
    cast = P.Cast()
-    dtype = P.DType()
-    checktype = P.IsSubClass()
-    if checktype(dtype(input_x), mstype.int_):
-        input_x = cast(input_x, mstype.float32)
-    elif checktype(dtype(input_x), mstype.float_):
-        pass
-    else:
-        return None
+    input_x = cast(input_x, mstype.float32)
    return exp(input_x)

def expm1_by_step(input_x):
    """
    Expm1 ops under GPU context.
    """
    return exp_by_step(input_x) - 1.0


def log_by_step(input_x):
    """
    Log op on Ascend is calculated as log(abs(x)).
@@ -56,14 +52,8 @@ def log_by_step(input_x):
    dtype = P.DType()
    shape = P.Shape()
    select = P.Select()
-    checktype = P.IsSubClass()
-    if checktype(dtype(input_x), mstype.int_):
-        input_x = cast(input_x, mstype.float32)
-    elif checktype(dtype(input_x), mstype.float_):
-        pass
-    else:
-        return None
+    input_x = cast(input_x, mstype.float32)
    nan = fill(dtype(input_x), shape(input_x), np.nan)
    inf = fill(dtype(input_x), shape(input_x), np.inf)
    neg_x = less(input_x, 0.0)
@@ -72,6 +62,7 @@
    result = select(nonpos_x, -inf, log_x)
    return select(neg_x, nan, result)

def log1p_by_step(x):
    """
    Log1p ops on GPU device or when device_target == GPU.
......
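In plain terms, this first file drops the IsSubClass/DType dispatch and now always casts the input to float32 before applying Exp or Log, so integer inputs no longer fall through to return None. Below is a minimal usage sketch of the revised exp helper, assuming a working MindSpore install; the sample tensor is illustrative and not part of the change.

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

def exp_by_step(input_x):
    # Same logic as the patched helper: cast to float32 first, then apply Exp.
    exp = P.Exp()
    cast = P.Cast()
    input_x = cast(input_x, mstype.float32)
    return exp(input_x)

# An int32 tensor is now handled by the cast instead of returning None.
print(exp_by_step(Tensor(np.array([0, 1, 2], dtype=np.int32))))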
@@ -14,15 +14,15 @@
# ============================================================================
"""Utitly functions to help distribution class.""" """Utitly functions to help distribution class."""
import numpy as np
-from mindspore.ops import _utils as utils
-from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore import context
from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
+from mindspore.ops import _utils as utils
from mindspore.ops import composite as C
-from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
import mindspore.nn as nn
import mindspore.nn.probability as msp
@@ -82,6 +82,24 @@ def convert_to_batch(t, batch_shape, required_type):
    return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
+
+def cast_type_for_device(dtype):
+    """
+    use the alternative dtype supported by the device.
+    Args:
+        dtype (mindspore.dtype): input dtype.
+    Returns:
+        mindspore.dtype.
+    """
+    if context.get_context("device_target") == "GPU":
+        if dtype in mstype.uint_type or dtype == mstype.int8:
+            return mstype.int16
+        if dtype == mstype.int64:
+            return mstype.int32
+        if dtype == mstype.float64:
+            return mstype.float32
+    return dtype
+

def check_scalar_from_param(params):
    """
    Check if params are all scalars.
@@ -293,10 +311,10 @@ def raise_not_impl_error(name):
def check_distribution_name(name, expected_name):
    if name is None:
        raise ValueError(
-            f"Distribution should be a constant which is not None.")
+            f"Input dist should be a constant which is not None.")
    if name != expected_name:
        raise ValueError(
-            f"Expected distribution name is {expected_name}, but got {name}.")
+            f"Expected dist input is {expected_name}, but got {name}.")

class CheckTuple(PrimitiveWithInfer):
......
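The new cast_type_for_device helper is the substantive addition in this file: on a GPU target it swaps dtypes the backend cannot handle for the nearest supported ones, and returns the dtype unchanged on other targets. A small sketch of the mapping follows; it uses a local copy of the helper for illustration, since the real one lives in an internal _utils module.

from mindspore import context
from mindspore.common import dtype as mstype

def cast_type_for_device(dtype):
    # Local copy of the helper added in this commit, for illustration only.
    if context.get_context("device_target") == "GPU":
        if dtype in mstype.uint_type or dtype == mstype.int8:
            return mstype.int16
        if dtype == mstype.int64:
            return mstype.int32
        if dtype == mstype.float64:
            return mstype.float32
    return dtype

# On GPU, 64-bit types fall back to 32-bit and unsigned ints or int8 to int16;
# on Ascend or CPU every dtype is returned as-is.
print(cast_type_for_device(mstype.int64))
print(cast_type_for_device(mstype.float64))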
@@ -16,9 +16,10 @@
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
+from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
from ._utils.utils import CheckTuple, CheckTensor

class Distribution(Cell):
    """
    Base class for all mathematical distributions.
@@ -43,12 +44,12 @@ class Distribution(Cell):
        new distribution specified by the dist_spec_args. But it won't change the
        original distribution.
""" """
def __init__(self, def __init__(self,
seed, seed,
dtype, dtype,
name, name,
param): param):
""" """
Constructor of distribution class. Constructor of distribution class.
""" """
@@ -58,7 +59,7 @@ class Distribution(Cell):
        self._name = name
        self._seed = seed
-        self._dtype = dtype
+        self._dtype = cast_type_for_device(dtype)
        self._parameters = {}
        # parsing parameters
        for k in param.keys():
@@ -436,7 +437,6 @@
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override construct in Cell.
......
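Net effect on the Distribution base class: the dtype passed to __init__ is now filtered through cast_type_for_device before being stored, so requesting a dtype the current device cannot handle silently yields a supported substitute. A hedged usage sketch follows; the msd.Normal constructor is assumed from the public probability API and is not part of this diff.

import mindspore.nn.probability.distribution as msd
from mindspore.common import dtype as mstype

# float32 is supported on every backend, so it is stored unchanged;
# on a GPU target a float64 request would be stored as float32 instead.
normal = msd.Normal(0.0, 1.0, dtype=mstype.float32)
print(normal.dtype)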