Commit 78d7d400 authored by Megvii Engine Team

feat(opr): add a constant flavor for SharedDeviceTensor

Also, add a CONSTANT value inference tag to outputs of
MultipleDeviceTensorHolder.

GitOrigin-RevId: 82a805ed5fed68376c0638a902e303f5a651a478
Parent 14e1a205
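
For context, a minimal usage sketch of the new constant flavor (not part of this diff; the helper name make_weight is hypothetical, and the includes assume the usual MegBrain graph and I/O-operator headers):

    #include "megbrain/graph.h"
    #include "megbrain/opr/io.h"

    using namespace mgb;

    SymbolVar make_weight(ComputingGraph& graph,
                          const std::shared_ptr<DeviceTensorND>& dev_data,
                          bool writable) {
        if (writable) {
            // existing behavior: the tensor may be overwritten between runs,
            // so its value is not exposed to static value inference
            return opr::SharedDeviceTensor::make(graph, dev_data);
        }
        // new constant flavor, equivalent to make(graph, dev_data, true, {}):
        // the value is registered as a CONSTANT source for static value
        // inference, so downstream constant folding can use it
        return opr::SharedDeviceTensor::make_const(graph, dev_data);
    }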
@@ -570,10 +570,10 @@ void ParamFusePass::apply(OptState &state) const {
*var->owner_graph(), hv, var_namer.name(var));
} else {
if (is_default_format) {
new_var = opr::SharedDeviceTensor::make(
new_var = opr::SharedDeviceTensor::make_const(
*var->owner_graph(), inferred_val, var_namer.name(var));
} else {
new_var = opr::SharedDeviceTensorWithFormat::make(
new_var = opr::SharedDeviceTensorWithFormat::make_const(
*var->owner_graph(), inferred_val, var_namer.name(var));
}
}
......
@@ -281,11 +281,11 @@ void Host2DeviceCopy::record_execute_deps(ExecDependencyArray& deps) {
/* ===================== SharedDeviceTensor related ===================== */
intl::SharedDeviceTensorBase::SharedDeviceTensorBase(
ComputingGraph &graph, const std::shared_ptr<DeviceTensorND> &dev_data,
const OperatorNodeConfig &config):
Super{&graph, config, "shared", {}},
m_dev_data{dev_data}
{
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value, const OperatorNodeConfig& config)
: Super{&graph, config, "shared", {}},
m_dev_data{dev_data},
m_const_value(const_value) {
if (config.has_comp_node_set()) {
mgb_assert(config.get_single_comp_node() == dev_data->comp_node());
}
@@ -307,26 +307,42 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() {
comp_node(m_dev_data->comp_node());
}
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) {
if (m_const_value) {
if (dest) {
if (m_static_infer.empty()) {
m_static_infer.comp_node(CompNode::default_cpu())
.copy_from(*m_dev_data);
}
*dest = m_static_infer;
}
return true;
}
return false;
}
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const {
return cg::static_infer::SourceType::CONSTANT;
}
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph,
const std::shared_ptr<DeviceTensorND> &dev_data,
bool const_value,
const OperatorNodeConfig &config) {
return graph.insert_opr(std::make_unique<SharedDeviceTensor>(
graph, dev_data, config))->output(0);
graph, dev_data, const_value, config))->output(0);
}
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph,
const HostTensorND &value,
bool const_value,
const OperatorNodeConfig &config) {
auto cn = value.comp_node();
if (config.has_comp_node_set())
cn = config.get_single_comp_node();
auto dev_v = std::make_shared<DeviceTensorND>();
dev_v->comp_node(cn).copy_from(value).sync();
return make(graph, dev_v, config);
return make(graph, dev_v, const_value, config);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor);
@@ -342,7 +358,7 @@ SymbolVar VolatileSharedDeviceTensor::make(ComputingGraph &graph,
const std::shared_ptr<DeviceTensorND> &dev_data,
const OperatorNodeConfig &config) {
return graph.insert_opr(std::make_unique<VolatileSharedDeviceTensor>(
graph, dev_data, config))->output(0);
graph, dev_data, false, config))->output(0);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor);
@@ -354,10 +370,10 @@ void SharedDeviceTensorWithFormat::init_output_format() {
SymbolVar SharedDeviceTensorWithFormat::make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config) {
bool const_value, const OperatorNodeConfig& config) {
auto&& opr =
graph.insert_opr(std::make_unique<SharedDeviceTensorWithFormat>(
graph, dev_data, config))
graph, dev_data, const_value, config))
->cast_final_safe<SharedDeviceTensorWithFormat>();
return opr.output(0);
}
@@ -870,6 +886,24 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() {
};
mgr.register_shape_infer(output(i),
{SourceType::CONSTANT, {}, infer_shp});
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) {
if (m_host_values.empty()) {
m_host_values.resize(m_values.size());
}
if (m_host_values[i].empty()) {
m_host_values[i]
.comp_node(CompNode::default_cpu())
.copy_from(*m_values[i]);
}
if (!m_host_values[i].empty()) {
dest = m_host_values[i];
return true;
}
return false;
};
mgr.register_value_infer(output(i),
{SourceType::CONSTANT, {}, infer_val});
}
}
......
@@ -79,6 +79,10 @@ namespace serialization {
HostTensorND val;
val.copy_from(opr.get_dev_tensor()).sync();
ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED);
// Note that opr.m_const_value is not persisted: it does not affect
// correctness, and SharedDeviceTensor oprs are bundled into a
// MultipleDeviceTensorHolder by optimize_for_inference before being
// dumped.
}
static cg::OperatorNodeBase* load(
@@ -280,9 +284,10 @@ namespace opr {
const OperatorNodeConfig &config) {
mgb_assert(inputs.empty());
auto &&opr = opr_.cast_final_safe<Opr>();
return Opr::make(
*ctx.owner_graph(opr, inputs), opr.dev_data(), config).
node()->owner_opr();
return Opr::make(*ctx.owner_graph(opr, inputs), opr.dev_data(),
opr.const_value(), config)
.node()
->owner_opr();
}
cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor(
......
@@ -75,19 +75,22 @@ class DeviceTensorHolder: public HostIONodeBase {
*/
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
std::shared_ptr<DeviceTensorND> m_dev_data;
DeviceTensorND m_static_infer;
bool m_const_value;
const TensorShape& get_output_shape() override;
bool fill_in_static_infer(DeviceTensorND* dest) override {
MGB_MARK_USED_VAR(dest);
return false;
}
bool fill_in_static_infer(DeviceTensorND* dest) override;
void init_output_comp_node() override;
public:
//! const_value marks whether the device value of this operator should
//! be treated as constant during graph execution. Should be false in
//! most cases.
SharedDeviceTensorBase(ComputingGraph &graph,
const std::shared_ptr<DeviceTensorND> &dev_data,
bool const_value,
const OperatorNodeConfig &config);
const DeviceTensorND& get_dev_tensor() const override {
@@ -97,6 +100,8 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
const std::shared_ptr<DeviceTensorND>& dev_data() const {
return m_dev_data;
}
bool const_value() const { return m_const_value; }
};
/*!
@@ -104,6 +109,7 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
* device tensors
*
* This opr is used to speed up inference by packing params together.
* This operator assumes the device tensors are constant.
*/
MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase,
cg::OperatorNodeBase) // {
@@ -125,6 +131,8 @@ private:
void init_output_comp_node() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
SmallVector<DeviceTensorND> m_host_values;
};
} // namespace intl
@@ -249,16 +257,43 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
public:
using Super::Super;
static SymbolVar make(ComputingGraph &graph,
const std::shared_ptr<DeviceTensorND> &dev_data,
const OperatorNodeConfig &config = {});
static SymbolVar make(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value,
const OperatorNodeConfig& config);
static SymbolVar make(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {}) {
return make(graph, dev_data, false, config);
}
static SymbolVar make_const(
ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {}) {
return make(graph, dev_data, true, config);
}
/*!
* \brief make a SharedDeviceTensor by first copying from host to device
*
* See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value.
*/
static SymbolVar make(ComputingGraph &graph,
const HostTensorND &value,
const OperatorNodeConfig &config = {});
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value,
bool const_value,
const OperatorNodeConfig& config);
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value,
const OperatorNodeConfig& config = {}) {
return make(graph, value, false, config);
}
static SymbolVar make_const(ComputingGraph& graph,
const HostTensorND& value,
const OperatorNodeConfig& config = {}) {
return make(graph, value, true, config);
}
};
/*!
@@ -276,7 +311,19 @@ public:
static SymbolVar make(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {});
bool const_value, const OperatorNodeConfig& config);
static SymbolVar make(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {}) {
return make(graph, dev_data, false, config);
}
static SymbolVar make_const(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {}) {
return make(graph, dev_data, true, config);
}
};
/*!
@@ -297,6 +344,15 @@ MGB_DEFINE_OPR_CLASS(
static SymbolVar make(ComputingGraph &graph,
const std::shared_ptr<DeviceTensorND> &dev_data,
const OperatorNodeConfig &config = {});
//! adapter for io.sereg.h: opr_shallow_copy_shared_device_tensor
static SymbolVar make(ComputingGraph& graph,
const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value,
const OperatorNodeConfig& config) {
mgb_assert(!const_value);
return make(graph, dev_data, config);
}
};
/*!
......
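
As a rough illustration of what the CONSTANT value-inference registration enables, a sketch under the same assumed headers as the example above (static_infer_manager() and infer_value() are pre-existing MegBrain static-infer entry points, not part of this commit):

    void check_const_value(const std::shared_ptr<mgb::DeviceTensorND>& dev_data) {
        auto graph = mgb::ComputingGraph::make();
        auto w = mgb::opr::SharedDeviceTensor::make_const(*graph, dev_data);

        // With const_value set, the tensor's contents can be read through
        // static value inference without executing the graph; after this
        // commit the same holds for outputs of MultipleDeviceTensorHolder.
        auto&& mgr = graph->static_infer_manager();
        const mgb::DeviceTensorND& val = mgr.infer_value(w.node());
        mgb_assert(val.shape().eq_shape(dev_data->shape()));
    }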