Commit a430c912, authored by Megvii Engine Team

feat(mgb/opr): let CondTake support empty input

GitOrigin-RevId: dfb401a945d5d75909f7b78448b3713623c28a2c
Parent 432fdb7e
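Before the diff itself, a minimal sketch of the behavior this commit enables, written against the Python API and assuming a MegEngine build that already contains the change: cond_take on an empty input should return two empty tensors instead of failing.

import numpy as np
import megengine.functional as F
from megengine import tensor

x_np = np.random.randn(3, 0, 3).astype("float32")  # empty input: one axis is 0
mask_np = x_np > 0                                  # empty bool mask, same shape
val, idx = F.cond_take(tensor(mask_np), tensor(x_np))
# expected with this change: val.shape == (0,) and idx.shape == (0,)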
@@ -18,6 +18,7 @@ import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
+import megengine.jit as jit
 from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
@@ -859,6 +860,35 @@ def test_condtake():
     np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
 
 
+# @pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_condtake(is_symbolic=None):
+    shapes = [
+        (3, 3, 3),
+        (0,),
+        (3, 0, 3),
+    ]
+
+    def fn(mask, data):
+        return F.cond_take(mask, data)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+
+    for shp in shapes:
+        x_np = np.random.randn(*shp).astype("float32")
+        mask_np = x_np > 0
+        x = tensor(x_np)
+        mask = tensor(mask_np)
+
+        ref_out = x_np[mask_np]
+        ref_idx = mask_np.flatten().nonzero()[0]
+        for i in range(3):
+            out, idx = fn(mask, x)
+            np.testing.assert_equal(out.numpy(), ref_out)
+            np.testing.assert_equal(idx.numpy(), ref_idx)
+            if is_symbolic is None:
+                break
+
+
 def test_condtake_is_same():
     op1 = builtin.CondTake()
     op2 = builtin.CondTake()
......
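The reference results in the new test come from plain NumPy, which already handles the empty case: boolean indexing of an empty array and nonzero() on an empty mask both give 1-D arrays of length 0. For illustration:

import numpy as np

x_np = np.random.randn(3, 0, 3).astype("float32")
mask_np = x_np > 0
print(x_np[mask_np].shape)                   # (0,)  -- values reference
print(mask_np.flatten().nonzero()[0].shape)  # (0,)  -- index reference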
@@ -45,25 +45,30 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     auto&& inp = inputs[0];
     auto&& msk = inputs[1];
+    SmallVector<TensorPtr> out;
     mgb_assert(inp->layout().eq_shape(msk->layout()),
             "input shape does not match mask shape");
     mgb_assert(msk->get_value().dtype().enumv() == DTypeEnum::Bool,
             "mask dtype must be bool");
-
-    DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
-    dnn_op.op->param().val = 1;
-
-    TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
-            dtype::Byte());
-    auto dnn_workspace = dnn_op.create_workspace(m_layout);
     MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};
-
-    dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
-            msk->dev_tensor().as_megdnn(),
-            dnn_workspace,
-            &policy);
-
-    SmallVector<TensorPtr> out;
+    if (inp->layout().is_empty()) {
+        // empty tensor
+        policy.alloc_output(0, inp->layout().dtype, {0}, nullptr);
+        policy.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
+        dnn_op.op->param().val = 1;
+        TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
+                dtype::Byte());
+        auto dnn_workspace = dnn_op.create_workspace(m_layout);
+        dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
+                msk->dev_tensor().as_megdnn(),
+                dnn_workspace,
+                &policy);
+    }
     out.push_back(policy.at(0));
     out.push_back(policy.at(1));
     return out;
......
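In the imperative path above, an empty input now skips the MegDNN CondTake kernel and the workspace query entirely; the allocation policy directly produces two zero-length outputs, values with the input dtype and indices as Int32. A small check of that contract from Python, again assuming a build that contains this change:

import numpy as np
import megengine.functional as F
from megengine import tensor

x_np = np.zeros((0,), dtype="float32")
mask_np = x_np > 0
val, idx = F.cond_take(tensor(mask_np), tensor(x_np))
# expected: val.dtype == float32, idx.dtype == int32, both with shape (0,)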
@@ -264,6 +264,15 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
 
+CondTake::NodeProp* CondTake::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
@@ -305,11 +314,21 @@ void CondTake::add_input_layout_constraint() {
 }
 
 void CondTake::scn_do_execute() {
+    auto&& data = input(0)->dev_tensor();
+    auto&& mask = input(1)->dev_tensor();
     intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
-    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
-                       input(1)->dev_tensor().as_megdnn(),
-                       intl::get_megdnn_workspace_from_var(output().back()),
-                       &dyn_malloc);
+    if (data.layout().is_empty()) {
+        mgb_assert(data.layout().eq_shape(mask.layout()),
+                "CondTake shape differs: data=%s mask=%s",
+                data.layout().TensorShape::to_string().c_str(),
+                mask.layout().TensorShape::to_string().c_str());
+        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
+        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        megdnn_opr()->exec(data.as_megdnn(), mask.as_megdnn(),
+                intl::get_megdnn_workspace_from_var(output().back()),
+                &dyn_malloc);
+    }
 }
 
 /* ================= TopK ================= */
......
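Two things change in the graph operator above: do_make_node_prop marks both inputs as VALUE_ALLOW_EMPTY so the graph accepts empty vars, and scn_do_execute allocates zero-length outputs instead of calling the MegDNN kernel when the data layout is empty. The traced variants of the new Python test exercise exactly this path; a small sketch along the same lines, with symbolic=True chosen only for illustration and assuming a build with this change:

import numpy as np
import megengine.functional as F
import megengine.jit as jit
from megengine import tensor

@jit.trace(symbolic=True)
def fn(mask, data):
    return F.cond_take(mask, data)

x_np = np.random.randn(3, 0, 3).astype("float32")
mask_np = x_np > 0
val, idx = fn(tensor(mask_np), tensor(x_np))
# expected: val.shape == (0,) and idx.shape == (0,)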
@@ -151,6 +151,7 @@ MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
     void init_output_static_infer_desc() override;
     void scn_do_execute() override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
 
     public:
         CondTake(VarNode *data, VarNode *mask,
......
@@ -256,20 +256,25 @@ TEST(TestOprMisc, CondTake) {
     run(mki({100}));
 }
 
-TEST(TestOprMisc, CondTakeEmptyOut) {
+TEST(TestOprMisc, CondTakeEmptyIO) {
     using Param = opr::CondTake::Param;
     HostTensorGenerator<> gen;
-    auto host_x = gen({1});
-    host_x->ptr<float>()[0] = 1;
-    auto graph = ComputingGraph::make();
-    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    auto out = opr::CondTake::make(x, x, {Param::Mode::LT});
-    HostTensorND host_out0, host_out1;
-    auto func = graph->compile({make_callback_copy(out[0], host_out0),
-                                make_callback_copy(out[1], host_out1)});
-    func->execute();
-    ASSERT_EQ(TensorShape{0}, host_out0.shape());
-    ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    auto check = [&](const TensorShape& shp) {
+        auto host_x = gen(shp);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+        auto y = x + 1;
+        auto out = opr::CondTake::make(x, y, {Param::Mode::EQ});
+        HostTensorND host_out0, host_out1;
+        auto func = graph->compile({make_callback_copy(out[0], host_out0),
+                                    make_callback_copy(out[1], host_out1)});
+        func->execute();
+        ASSERT_EQ(TensorShape{0}, host_out0.shape());
+        ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    };
+    check({1});
+    check({0});
+    check({1, 0});
 }
 
 TEST(TestOprMisc, TopKValueOnly) {
......