未验证 提交 4ad504e7 编写于 作者: Z zhupengyang 提交者: GitHub

hardshrink: support threshold < 0 (#26403)

上级 e92f770c
......@@ -388,9 +388,9 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
out.device(d) = x * (temp1 + temp2);
auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold);
out.device(d) = x * (temp1 + temp2 > 0).template cast<T>();
}
};
......@@ -405,9 +405,9 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
auto temp1 = x < static_cast<T>(threshold * -1.f);
auto temp2 = x > static_cast<T>(threshold);
dx.device(d) = dout * (temp1 + temp2 > 0).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
......
......@@ -453,20 +453,29 @@ class TestHardShrink(TestActivation):
self.op_type = "hard_shrink"
self.init_dtype()
threshold = 0.5
self.threshold = 0.5
self.set_attrs()
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) * 10
out = ref_hardshrink(x, threshold)
out = ref_hardshrink(x, self.threshold)
self.attrs = {'threshold': threshold}
self.attrs = {'threshold': self.threshold}
self.inputs = {'X': x}
self.outputs = {'Out': out}
def set_attrs(self):
pass
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestHardShrink_threshold_negative(TestHardShrink):
def set_attrs(self):
self.threshold = -0.1
class TestHardShrinkAPI(unittest.TestCase):
# test paddle.nn.Hardshrink, paddle.nn.functional.hardshrink
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册