提交 dea52781 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mgb/opr): let PowC & TypeCvt support empty IO

GitOrigin-RevId: f97b3005fd3d60c7c8d1159672debf3a5e30cc64
上级 2f68aeb9
...@@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest, ...@@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest,
intl::UniqPtrWithCN<megdnn::TypeCvt> &opr) { intl::UniqPtrWithCN<megdnn::TypeCvt> &opr) {
mgb_assert(src.comp_node() == opr.comp_node()); mgb_assert(src.comp_node() == opr.comp_node());
mgb_assert(dest_type.valid()); mgb_assert(dest_type.valid());
if (src.empty()) {
mgb_assert(dest.empty());
return;
}
if (src.dtype() == dest_type) { if (src.dtype() == dest_type) {
dest.copy_from(src); dest.copy_from(src);
return; return;
...@@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) { ...@@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC); MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC);
// NOTE(review): the next line is side-by-side diff residue. It fuses the removed
// MEGDNN_OPR_CTOR_INIT1(...) invocation (old side) with the signature of the
// hand-written PowC constructor that replaces it (new side).
MEGDNN_OPR_CTOR_INIT1(PowC, ssprintf("powc_%g", param.exp)) PowC::PowC(VarNode *i0, const Param &param, const OperatorNodeConfig &config)
: Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, ssprintf("powc_%g", param.exp), {i0}} ) {
// The operator name passed to Super is "powc_<exp>", formatted with %g.
init_megdnn_opr(*this, param);
add_input({i0});
// Flag the single output var with ALLOW_EMPTY_SHAPE; this is the only step
// visible here that the old macro line did not spell out.
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
intl::MegDNNOprInitPostCtor<PowC>::apply(*this);
}
SymbolVar PowC::make(SymbolVar x, const Param& param, SymbolVar PowC::make(SymbolVar x, const Param& param,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
...@@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() { ...@@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
} }
/*!
 * Execute PowC, short-circuiting when the input tensor is empty.
 *
 * A non-empty input must come with a non-empty output and is forwarded to
 * Super::scn_do_execute(). An empty input must come with an empty output,
 * and nothing further is executed for it.
 */
void PowC::scn_do_execute() {
    if (!input(0)->dev_tensor().empty()) {
        mgb_assert(!output(0)->dev_tensor().empty());
        Super::scn_do_execute();
        return;
    }
    // empty input: only verify that the output is empty as well
    mgb_assert(output(0)->dev_tensor().empty());
}
/*!
 * Build the node property via Super, then register input(0) with dependency
 * type VALUE_ALLOW_EMPTY before returning the property.
 */
PowC::NodeProp* PowC::do_make_node_prop() const {
    using PropDep = NodeProp::DepType;
    auto prop = Super::do_make_node_prop();
    prop->add_dep_type_existing_var(input(0), PropDep::VALUE_ALLOW_EMPTY);
    return prop;
}
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PowC) { MGB_IMPL_OPR_GRAD(PowC) {
auto exp = opr.param().exp; auto exp = opr.param().exp;
......
...@@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // { ...@@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
void add_input_layout_constraint() override; void add_input_layout_constraint() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
void mem_plan_fwd_in2out_writable() override; void mem_plan_fwd_in2out_writable() override;
NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;
public: public:
PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config); PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config);
......
...@@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) { ...@@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) {
ASSERT_EQ(TensorShape({2}), host_y.shape()); ASSERT_EQ(TensorShape({2}), host_y.shape());
} }
// TypeCvt::perform on a zero-element source must not throw and must leave the
// destination empty, with the same shape as the source.
TEST(TestOprBasicArith, TypeCvtPerformEmptyIO) {
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("xpu0");
    auto host_x = gen({2, 0, 3, 4});  // second axis has extent 0

    // source tensor on the device, copied from the generated host tensor
    DeviceTensorND dev_x(cn);
    dev_x.copy_from(*host_x);

    // destination tensor: Int32, resized to the source shape
    DeviceTensorND dev_y(cn, dtype::Int32{});
    dev_y.resize(dev_x.shape());

    auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::TypeCvt>(cn);
    ASSERT_NO_THROW(
            opr::TypeCvt::perform(dev_y, dtype::Int32{}, dev_x, dnn_opr));

    ASSERT_TRUE(dev_y.empty());
    ASSERT_TRUE(dev_y.shape().is_empty());
    MGB_ASSERT_SHAPE_EQ(dev_x.shape(), dev_y.shape());
}
TEST(TestOprBasicArith, ElemwiseMemFwd) { TEST(TestOprBasicArith, ElemwiseMemFwd) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0; graph->options().graph_opt_level = 0;
...@@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) { ...@@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) {
run(true); run(true);
} }
// Running PowC in a compiled graph on a zero-element input must not throw and
// must produce an empty output whose shape equals the input shape.
TEST(TestOprBasicArith, PowCEmptyIO) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();

    // input with a zero-extent axis
    auto host_x = gen({4, 0, 2, 3});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::PowC::make(x, 3.f);

    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    ASSERT_NO_THROW(func->execute().wait());

    ASSERT_TRUE(host_y.empty());
    ASSERT_TRUE(host_y.shape().is_empty());
    MGB_ASSERT_SHAPE_EQ(host_x->shape(), host_y.shape());
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册