提交 c2b1ddb6 编写于 作者: Y Yibing Liu

Correct the dropout_op's computation in test

上级 5b524810
...@@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto M = EigenMatrix<T>::Reshape(*mask, 1); auto M = EigenMatrix<T>::Reshape(*mask, 1);
Y.device(place) = X * M; Y.device(place) = X * M;
} else { } else {
Y.device(place) = X * dropout_prob; Y.device(place) = X * (1.0f - dropout_prob);
} }
} }
}; };
......
...@@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob; Y.device(place) = X * (1.0f - dropout_prob);
} }
} }
}; };
......
...@@ -47,7 +47,9 @@ class TestDropoutOp4(OpTest): ...@@ -47,7 +47,9 @@ class TestDropoutOp4(OpTest):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_test': True} self.attrs = {'dropout_prob': 0.35, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -58,7 +60,9 @@ class TestDropoutOp5(OpTest): ...@@ -58,7 +60,9 @@ class TestDropoutOp5(OpTest):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_test': True} self.attrs = {'dropout_prob': 0.75, 'is_test': True}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册