提交 661f9dfa 编写于 作者: C chenzomi

add dropout primtive

上级 3d3b9d54
......@@ -25,6 +25,7 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore.common.api import ms_function
from mindspore import context
from ..cell import Cell
from .activation import get_activation
from ..._checkparam import Validator as validator
......@@ -84,8 +85,19 @@ class Dropout(Cell):
self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_gpu = context.get_context('device_target') in ["GPU"]
if self.is_gpu:
self.dropout = P.Dropout(keep_prob)
def construct(self, x):
if not self.training:
return x
if self.is_gpu:
out, _ = self.dropout(x)
return out
shape = self.get_shape(x)
dtype = P.DType()(x)
keep_prob = self.cast(self.keep_prob, dtype)
......
......@@ -643,3 +643,17 @@ def get_bprop_binary_cross_entropy(self):
return dx, zeros_like(y), zeros_like(weight)
return bprop
@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
"""Grad definition for `Dropout` operation."""
grad = P.DropoutGrad(self.drop_prob)
def bprop(x, out, dout):
_, mask = out
dy, _ = dout
dx = grad(dy, mask)
return (dx,)
return bprop
......@@ -52,7 +52,7 @@ from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask,
DropoutDoMask, DropoutGrad, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss,
......@@ -157,6 +157,8 @@ __all__ = [
'Shape',
'DropoutDoMask',
'DropoutGenMask',
'DropoutGrad',
'Dropout',
'Neg',
'Slice',
'DType',
......
......@@ -2762,3 +2762,68 @@ class ConfusionMulGrad(PrimitiveWithInfer):
validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
return input0_dtype, input1_dtype
class Dropout(PrimitiveWithInfer):
"""
During training, randomly zeroes some of the elements of the input tensor with probability.
Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
Inputs:
- **shape** (tuple[int]) - The shape of target mask.
Outputs:
Tensor, the value of generated mask for input shape.
Examples:
>>> dropout = P.Dropout(drop_prob=0.5)
>>> in = Tensor((20, 16, 50, 50))
>>> out = dropout(in)
"""
@prim_attr_register
def __init__(self, drop_prob=0):
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
mask_shape = x_shape
return x_shape, mask_shape
def infer_dtype(self, x_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
return x_dtype, x_dtype
class DropoutGrad(PrimitiveWithInfer):
"""
The gradient of Dropout. During training, randomly zeroes some of the elements
of the input tensor with probability.
Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
Inputs:
- **shape** (tuple[int]) - The shape of target mask.
Outputs:
Tensor, the value of generated mask for input shape.
Examples:
>>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
>>> in = Tensor((20, 16, 50, 50))
>>> out = dropout_grad(in)
"""
@prim_attr_register
def __init__(self, drop_prob=0):
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
def infer_shape(self, dy_shape, mask_shape):
return dy_shape
def infer_dtype(self, dy_dtype, mask_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
return dy_dtype
......@@ -17,7 +17,9 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(device_target="Ascend")
def test_check_dropout_3():
Tensor(np.ones([20, 16, 50]).astype(np.int32))
......
......@@ -19,26 +19,26 @@ from mindspore.common.api import _executor
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import context
context.set_context(device_target="Ascend")
def test_check_dropout_1():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.8)
with pytest.raises(NotImplementedError):
m(x)
m(x)
def test_check_dropout_2():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.3, seed0=1)
with pytest.raises(NotImplementedError):
m(x)
m(x)
def test_check_dropout_3():
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
m = nn.Dropout(0.3, seed0=1, seed1=1)
with pytest.raises(NotImplementedError):
m(x)
m(x)
class Net_Dropout(nn.Cell):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册