提交 9779bc7f 编写于 作者: M Megvii Engine Team

fix(imperative): allow rng op shape inference to be fallible

GitOrigin-RevId: 687844500cc2cab18de576b1484215c72329e4b8
上级 8f7fa90c
@@ -71,6 +71,7 @@ def test_dropout():
with gm: with gm:
out = F.nn.dropout(data, rate, training=True) out = F.nn.dropout(data, rate, training=True)
gm.backward(out, tensor(np.ones(shape, dtype=np.float32))) gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
if len(shape) != 0:
assert not out.numpy().all() assert not out.numpy().all()
np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7) np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)
@@ -99,6 +100,7 @@ def test_dropout():
out4 = F.nn.dropout(data, rate, training=True) out4 = F.nn.dropout(data, rate, training=True)
assert not (out1.numpy() == out4.numpy()).all() assert not (out1.numpy() == out4.numpy()).all()
test_dropout_with_shape([], 0.4)
test_dropout_with_shape([13, 17, 63, 21], 0.4) test_dropout_with_shape([13, 17, 63, 21], 0.4)
test_dropout_with_shape([16, 32, 64], 0.3) test_dropout_with_shape([16, 32, 64], 0.3)
test_multiple_dropout([1024], 0.2) test_multiple_dropout([1024], 0.2)
......
@@ -559,25 +559,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
} }
dest.comp_node = inputs[0].comp_node; dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
return {{dest}, true}; return {{dest}, inputs[0].layout.ndim != 0};
} }
template <> template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible< std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
SmallVector<LogicalTensorDesc> dests(2); SmallVector<LogicalTensorDesc> dests(2);
dests[0].comp_node = inputs[0].comp_node; dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(inputs[0].layout); dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype; dests[0].layout.dtype = inputs[0].layout.dtype;
dests[1].comp_node = inputs[0].comp_node; dests[1].comp_node = inputs[0].comp_node;
if (success) {
dests[1].layout = dests[1].layout =
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32()); TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
return {dests, true}; } else {
dests[1].layout = TensorLayout(dtype::Int32());
}
return {dests, success};
} }
template <> template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) { const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
SmallVector<LogicalTensorDesc> dests(2); SmallVector<LogicalTensorDesc> dests(2);
auto cn = inputs[0].comp_node; auto cn = inputs[0].comp_node;
dests[0].comp_node = cn; dests[0].comp_node = cn;
@@ -590,8 +598,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
inputs[0].layout); inputs[0].layout);
}; };
dests[1].comp_node = cn; dests[1].comp_node = cn;
if (success) {
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte()); dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return {dests, true}; } else {
dests[1].layout = TensorLayout(dtype::Byte());
}
return {dests, success};
} }
template <typename Op> template <typename Op>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册