提交 9779bc7f 编写于 作者: M Megvii Engine Team

fix(imperative): allow rng op infer shape fallible

GitOrigin-RevId: 687844500cc2cab18de576b1484215c72329e4b8
上级 8f7fa90c
......@@ -71,7 +71,8 @@ def test_dropout():
with gm:
out = F.nn.dropout(data, rate, training=True)
gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
assert not out.numpy().all()
if len(shape) != 0:
assert not out.numpy().all()
np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)
def test_multiple_dropout(shape, rate):
......@@ -99,6 +100,7 @@ def test_dropout():
out4 = F.nn.dropout(data, rate, training=True)
assert not (out1.numpy() == out4.numpy()).all()
test_dropout_with_shape([], 0.4)
test_dropout_with_shape([13, 17, 63, 21], 0.4)
test_dropout_with_shape([16, 32, 64], 0.3)
test_multiple_dropout([1024], 0.2)
......
......@@ -559,25 +559,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
return {{dest}, true};
return {{dest}, inputs[0].layout.ndim != 0};
}
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
SmallVector<LogicalTensorDesc> dests(2);
dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype;
dests[1].comp_node = inputs[0].comp_node;
dests[1].layout =
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
return {dests, true};
if (success) {
dests[1].layout =
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
} else {
dests[1].layout = TensorLayout(dtype::Int32());
}
return {dests, success};
}
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;
SmallVector<LogicalTensorDesc> dests(2);
auto cn = inputs[0].comp_node;
dests[0].comp_node = cn;
......@@ -590,8 +598,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
inputs[0].layout);
};
dests[1].comp_node = cn;
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return {dests, true};
if (success) {
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
} else {
dests[1].layout = TensorLayout(dtype::Byte());
}
return {dests, success};
}
template <typename Op>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册