提交 3fce8544 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

refactor(mgb/parampacksplit): remove input offsets on device

GitOrigin-RevId: b2bf3bf15588d6f513c25cd000fd808584ba5107
上级 11388e58
......@@ -47,16 +47,10 @@ SymbolVarArray _Opr::param_pack_split(
}
auto cn = src.node()->comp_node();
auto offsets_val = megdnn::ParamPackConcat::gen_offsets(
auto offsets = megdnn::ParamPackConcat::gen_offsets(
shapearr, cn.get_mem_addr_alignment(), src.dtype().size());
if (config.has_comp_node_set()) {
cn = config.get_single_comp_node();
}
HostTensorND hv{cn, TensorShape{{offsets_val.size()}}, dtype::Int32{}};
memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int));
auto offsets = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv);
return mgb::opr::ParamPackSplit::make(src, offsets, offsets_val, shapearr, config);
return mgb::opr::ParamPackSplit::make(src, offsets, shapearr, config);
}
#if MGB_ENABLE_OPR_MM
......
......@@ -13,6 +13,7 @@
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/graph/event.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/utils/arith_helper.h"
......@@ -1434,15 +1435,13 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){
/* f{{{ ======================= ParamPackSplit ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets,
const std::vector<dt_int32> offsets_val,
ParamPackSplit::ParamPackSplit(VarNode* src,
const std::vector<dt_int32> offsets,
TensorShapeArray& shapes,
const OperatorNodeConfig& config)
: Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}},
m_shapes(shapes), m_offsets(offsets_val) {
mgb_assert(src->comp_node() == offsets->comp_node());
: Super{src->owner_graph(), config, "ParamPackSplit", {src}},
m_shapes(shapes), m_offsets(offsets) {
add_input({src});
add_input({offsets});
for (size_t i = 0; i < shapes.size(); i++) {
mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!");
......@@ -1456,14 +1455,13 @@ void ParamPackSplit::add_input_layout_constraint(){
}
SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
const std::vector<dt_int32> offsets,
TensorShapeArray shapes,
const OperatorNodeConfig& config) {
auto&& out = src.node()
->owner_graph()
->insert_opr(std::make_unique<ParamPackSplit>(
src.node(), offsets.node(), offsets_val,
src.node(), offsets,
shapes, config))
->output();
......@@ -1499,7 +1497,7 @@ void ParamPackSplit::init_output_static_infer_desc() {
using namespace std::placeholders;
auto&& mgr = owner_graph()->static_infer_manager();
DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}};
DepVal shp_deps{{input(0), DepType::SHAPE}};
for (size_t i = 0; i < output().size(); i++) {
auto ov = output(i);
......@@ -1519,9 +1517,17 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
}
grad.emplace_back(gval);
}
auto offsets_val = opr.get_offsets();
auto cn = opr.input(0)->comp_node();
if (opr.config().has_comp_node_set()) {
cn = opr.config().get_single_comp_node();
}
HostTensorND hv{cn, TensorShape{offsets_val.size()}, dtype::Int32{}};
memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int));
auto offsets = opr::ImmutableTensor::make(*opr.input(0)->owner_graph(), hv);
return ParamPackConcat::make(
grad, opr.input(1), opr.get_offsets(),
grad, offsets, offsets_val,
OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
.node();
}
......
......@@ -162,7 +162,7 @@ namespace opr {
auto &&offsets = opr.get_offsets();
auto &&shape = opr.get_output_shapes();
return ParamPackSplit::make(inputs[0], inputs[1], offsets, shape, config).at(0).
return ParamPackSplit::make(inputs[0], offsets, shape, config).at(0).
node()->owner_opr();
}
......
......@@ -600,12 +600,11 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
void add_input_layout_constraint() override;
public:
ParamPackSplit(VarNode* src, VarNode* offsets,
const std::vector<dt_int32> offsets_val,
ParamPackSplit(VarNode* src, const std::vector<dt_int32> offsets,
TensorShapeArray& shapes, const OperatorNodeConfig& config);
static SymbolVarArray make(const SymbolVar& src, const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
static SymbolVarArray make(const SymbolVar& src,
const std::vector<dt_int32> offsets,
TensorShapeArray shapes,
const OperatorNodeConfig& config = {});
......
......@@ -1952,9 +1952,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
.comp_node(cn)
.resize({offsets_val.size()})
.ptr<dt_int32>());
auto sym_offsets = opr::SharedDeviceTensor::make(
*inputs[0].node()->owner_graph(), offsets);
auto out = opr::ParamPackSplit::make(inputs[0], sym_offsets, offsets_val,
auto out = opr::ParamPackSplit::make(inputs[0], offsets_val,
shapes);
mgb_assert(out.size() == nr_out);
typename Checker::SymOutArray ret;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册