Unverified commit e480168f — authored by huangjun12, committed via GitHub

fix dropout bug in backward when input is 1d tensor (#26837)

* fix dropout bug in backward when input is 1d tensor, test=develop

* add test case and refine error message, test=develop

* refine error message, test=develop
Parent commit: 2f50aa22
...@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Rank-1 Eigen tensor view alias; used to Flatten() tensors of any rank so
// the backward pass also works for 1-D inputs (the previous
// EigenMatrix::Reshape path broke on 1-D tensors — see commit message).
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> { class CPUDropoutKernel : public framework::OpKernel<T> {
public: public:
...@@ -116,9 +120,9 @@ class DropoutGradKernel : public framework::OpKernel<T> { ...@@ -116,9 +120,9 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto* mask = context.Input<Tensor>("Mask"); auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace()); grad_x->mutable_data<T>(context.GetPlace());
auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1); auto M = EigenVector<uint8_t>::Flatten(*mask);
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1); auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1); auto dY = EigenVector<T>::Flatten(*grad_y);
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
......
...@@ -40,6 +40,23 @@ class TestDropoutOp(OpTest): ...@@ -40,6 +40,23 @@ class TestDropoutOp(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestDropoutOpInput1d(OpTest):
    """Regression test: dropout forward/backward must accept rank-1 inputs."""

    def setUp(self):
        self.op_type = "dropout"
        # With dropout_prob == 0 every element is kept, so Out equals X and
        # the mask is all ones.
        data = np.random.random(2000).astype("float32")
        self.inputs = {'X': data}
        self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
        self.outputs = {
            'Out': data,
            'Mask': np.ones(2000).astype('uint8'),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')
class TestDropoutOp2(TestDropoutOp): class TestDropoutOp2(TestDropoutOp):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
...@@ -436,6 +453,13 @@ class TestDropoutFAPIError(unittest.TestCase): ...@@ -436,6 +453,13 @@ class TestDropoutFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_axis_max) self.assertRaises(ValueError, test_axis_max)
def test_axis_min():
    # A negative axis entry is invalid: every axis value must be >= 0,
    # so dropout(..., axis=[0, -1]) is expected to raise ValueError.
    x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
    paddle.nn.functional.dropout(x2, axis=[0, -1])

# NOTE(review): this call sits in an enclosing unittest method not visible
# in this chunk; indentation was lost in extraction — confirm placement.
self.assertRaises(ValueError, test_axis_min)
def test_axis_len(): def test_axis_len():
# length of axis should not greater than dimensions of x # length of axis should not greater than dimensions of x
x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32")
......
...@@ -910,12 +910,12 @@ def dropout(x, ...@@ -910,12 +910,12 @@ def dropout(x,
#get mask shape #get mask shape
input_shape = x.shape input_shape = x.shape
drop_axes = [axis] if isinstance(axis, int) else axis drop_axes = [axis] if isinstance(axis, int) else axis
if max(drop_axes) > len(input_shape) - 1: if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
raise ValueError("axis value should less than dimensions of x:{}, but get drop_axes value:{} " \ raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
.format(len(input_shape), max(drop_axes))) .format(len(input_shape), max(drop_axes)))
if len(drop_axes) > len(input_shape): if len(drop_axes) > len(input_shape):
raise ValueError( raise ValueError(
"length of axis should not greater than dimensions of x:{}, but get length of drop axes: {}". "length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
format(len(input_shape), len(drop_axes))) format(len(input_shape), len(drop_axes)))
mask_shape = [1] * len(input_shape) mask_shape = [1] * len(input_shape)
for i in drop_axes: for i in drop_axes:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.