提交 1a1748da 编写于 作者: M Megvii Engine Team

feat(opr): let Argsort support empty IO

GitOrigin-RevId: 05fcac6e472e9d7e868516c210c96d8d2987b5dc
上级 7234efe1
......@@ -110,16 +110,42 @@ def test_sort():
data2_shape = (12, 2)
data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32)
output0 = [np.sort(data1), np.argsort(data1).astype(np.int32)]
output1 = [np.sort(data2), np.argsort(data2).astype(np.int32)]
output1 = [np.sort(data1), np.argsort(data1).astype(np.int32)]
output2 = [np.sort(data2), np.argsort(data2).astype(np.int32)]
cases = [
{"input": data1, "output": output0},
{"input": data2, "output": output1},
{"input": data1, "output": output1},
{"input": data2, "output": output2},
]
opr_test(cases, F.sort)
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_sort_empty(is_symbolic):
data_shapes = [
(0,),
(10, 0),
]
def fn(x):
return F.sort(x)
for shape in data_shapes:
if is_symbolic is not None:
fn_ = jit.trace(symbolic=is_symbolic)(fn)
else:
fn_ = fn
data = np.random.random(shape).astype(np.float32)
for _ in range(3):
outs = fn_(tensor(data))
ref_outs = (np.sort(data), np.argsort(data))
assert len(ref_outs) == len(outs)
for i in range(len(outs)):
np.testing.assert_equal(outs[i].numpy(), ref_outs[i])
if is_symbolic is None:
break
def test_normalize():
cases = [
......
......@@ -75,7 +75,16 @@ MEGDNN_OPR_INIT1(Argmin, "argmin")
/* ================= ArgsortForward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward);
MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")
// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")
ArgsortForward::ArgsortForward(VarNode *i0, const Param &param, const OperatorNodeConfig &config):
Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, "argsort", {i0}} ) {
init_megdnn_opr(*this, param);
add_input({i0});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted value
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // sorted index
intl::MegDNNOprInitPostCtor<ArgsortForward>::apply(*this);
}
std::array<SymbolVar, 2> ArgsortForward::make(
SymbolVar in_tensor, const Param &param,
......@@ -87,6 +96,32 @@ std::array<SymbolVar, 2> ArgsortForward::make(
return {node->output(0), node->output(1)};
}
void ArgsortForward::scn_do_execute() {
if (input(0)->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty() &&
output(1)->dev_tensor().empty());
return;
}
mgb_assert(!output(0)->dev_tensor().empty() &&
!output(1)->dev_tensor().empty());
Super::scn_do_execute();
}
void ArgsortForward::get_output_var_shape(
const TensorShapeArray &inp_shape,
TensorShapeArray &out_shape) const {
mgb_assert(inp_shape.size() == 1 && out_shape.size() == 2);
out_shape[0] = inp_shape[0];
out_shape[1] = inp_shape[0];
}
ArgsortForward::NodeProp* ArgsortForward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ArgsortForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
......
......@@ -55,6 +55,12 @@ MGB_DEFINE_OPR_CLASS(Argmin,
*/
MGB_DEFINE_OPR_CLASS(ArgsortForward,
intl::MegDNNOprWrapperFwd<megdnn::ArgsortForward>) // {
protected:
NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;
void get_output_var_shape(
const TensorShapeArray &inp_shape,
TensorShapeArray &out_shape) const override;
public:
ArgsortForward(VarNode *in_tensor,
const Param &param,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册