Commit a5a60679 authored by Megvii Engine Team

feat(imperative/interpreter): add more dispatch mode in apply_op

GitOrigin-RevId: 2663504470e6cf83a4ce5d84131f0cbd2f39716e
Parent 45e20602
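
This commit introduces a DispatchMode enum (DEFAULT_CPU, KERNEL) and splits ChannelImpl::apply_op into dispatch_default_cpu and dispatch_kernel, chosen per op via OpDef::decide_dispatch_mode. The standalone sketch below only models that decision policy with made-up stand-in types (FakeInput) and an illustrative threshold of 7 (the real code uses TensorShape::MAX_NDIM); it is not the MegEngine API.

// Simplified, standalone model of the dispatch decision added in this commit.
// FakeInput and the threshold value are illustrative stand-ins, not real MegEngine types.
#include <cstddef>
#include <cstdio>
#include <vector>

enum DispatchMode { DEFAULT_CPU = 0, KERNEL = 1 };

struct FakeInput {
    bool has_host_value;    // stands in for "inp.value is non-empty with a valid layout"
    std::size_t nr_elems;   // stands in for inp.value.layout().total_nr_elems()
};

// Mirrors the policy used by e.g. the Elemwise trait: small, host-computable
// inputs are evaluated eagerly on the default CPU, everything else goes to a kernel.
DispatchMode decide_dispatch_mode(const std::vector<FakeInput>& inputs,
                                  std::size_t size_threshold) {
    for (const auto& inp : inputs) {
        if (!inp.has_host_value || inp.nr_elems == 0 || inp.nr_elems > size_threshold) {
            return KERNEL;
        }
    }
    return DEFAULT_CPU;
}

int main() {
    std::vector<FakeInput> small = {{true, 4}, {true, 4}};
    std::vector<FakeInput> large = {{true, 1u << 20}};
    // apply_op now switches on the decided mode:
    //   DEFAULT_CPU -> dispatch_default_cpu (compute immediately on host proxies)
    //   KERNEL      -> dispatch_kernel      (enqueue an ApplyOp command, as before)
    std::printf("small: %s\n",
                decide_dispatch_mode(small, 7) == DEFAULT_CPU ? "DEFAULT_CPU" : "KERNEL");
    std::printf("large: %s\n",
                decide_dispatch_mode(large, 7) == DEFAULT_CPU ? "DEFAULT_CPU" : "KERNEL");
    return 0;
}
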
......@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() {
return inst_;
}
void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
auto info = alloc();
info->desc.layout = value.layout();
info->desc.comp_node = value.comp_node();
......@@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
return info;
}
void* ChannelImpl::put(const DeviceTensorND& data) {
Handle ChannelImpl::put(const DeviceTensorND& data) {
auto info = alloc();
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
......@@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) {
return info;
}
void ChannelImpl::del(void* handle) {
void ChannelImpl::del(Handle handle) {
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
}
void ChannelImpl::swap_in(void* handle) {
void ChannelImpl::swap_in(Handle handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) {
}
}
void ChannelImpl::swap_out(void* handle) {
void ChannelImpl::swap_out(Handle handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) {
}
}
void ChannelImpl::drop(void* handle) {
void ChannelImpl::drop(Handle handle) {
if (m_enable_evict & DROP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) {
}
}
SmallVector<void*> ChannelImpl::apply_op(
void ChannelImpl::dispatch_default_cpu(
std::shared_ptr<OpDef> op,
const SmallVector<void*>& inputs) {
for (auto i : inputs) {
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
"invalid handle: %p", i);
}
SmallVector<TensorInfo*> input_infos;
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) {
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
SmallVector<DeviceTensorND> input_tensornds;
input_tensornds.reserve(input_descs.size());
CompNode output_cn;
{
MGB_LOCK_GUARD(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
for (auto&& info : input_infos) {
mgb_assert(info->ptr, "invalid tensor ptr!");
if (!output_cn.valid()) {
output_cn = info->ptr->comp_node();
} else {
mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
}
mgb_assert(info->ptr->try_get_value(), "no valid host value");
input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
}
}
outputs->reserve(output_descs.size());
SmallVector<DeviceTensorND> output_tensornds;
output_tensornds.reserve(output_descs.size());
for (auto&& desc : output_descs) {
// TODO: may conflict with condtake, which needs alloc inside
mgb_assert(!desc.layout.is_empty());
// use HostTensorND alloc_host for cuda pinned memory
output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
}
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
SmallVector<TensorInfo*> output_infos;
output_infos.reserve(output_descs.size());
for (auto&& tensornd : output_tensornds) {
// tensornd -> host_tensornd
HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
.proxy_to_comp_node(output_cn);
// tensornd -> desc
LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
// tensornd -> tensor
auto info = alloc();
info->desc = desc;
m_valid_handle.insert(info);
output_infos.push_back(info);
info->ptr = Tensor::make(host_tensornd, true); // host_only=true
info->value_fetched = true;
outputs->push_back(info);
}
if (m_enable_evict & DROP) {
for (auto out : output_infos) {
out->path.op = op;
for (auto out_ : output_infos) {
out->path.outputs.push_back(m_st.at(out_));
}
for (auto inp : input_infos) {
out->path.inputs.push_back(m_st.at(inp));
inp->path.dep_outputs.push_back(m_st.at(out));
}
}
}
}
void ChannelImpl::dispatch_kernel(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs) {
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
cmd.inputs = std::move(input_infos);
cmd.outputs.reserve(output_descs.size());
SmallVector<void*> outputs;
// FIXME: remove this check when op check is correct
bool validated_bkp = true;
for (size_t i = 0;i < output_descs.size();i ++) {
auto&& desc = output_descs[i];
if (desc.layout.ndim == 0) {
validated_bkp = false;
}
outputs->reserve(output_descs.size());
for (auto&& desc : output_descs) {
auto info = alloc();
info->desc = desc;
m_valid_handle.insert(info);
cmd.outputs.push_back(info);
outputs.push_back(info);
outputs->push_back(info);
}
if (m_enable_evict & DROP) {
for (auto out : cmd.outputs) {
......@@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op(
}
}
m_buffer.enqueue(std::move(cmd));
if (!(validated && validated_bkp) && m_async_level == 1) {
if (!validated && m_async_level == 1) {
sync();
} else if (m_async_level == 0) {
sync();
// check device error
for (auto&& oup : outputs) {
for (auto&& oup : *outputs) {
auto info = reinterpret_cast<TensorInfo*>(oup);
info->ptr->comp_node().sync();
}
}
}
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) {
for (auto i : inputs) {
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
"invalid handle: %p", i);
}
SmallVector<TensorInfo*> input_infos;
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
{
MGB_LOCK_GUARD(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
}
SmallVector<Handle> outputs;
switch (OpDef::decide_dispatch_mode(*op, input_descs)) {
case DEFAULT_CPU: {
dispatch_default_cpu(op, input_infos, input_descs, &outputs);
break;
}
case KERNEL: {
dispatch_kernel(op, input_infos, input_descs, &outputs);
break;
}
}
mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
return outputs;
}
HostTensorND ChannelImpl::get_value(void* handle) {
HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
return info->ptr->get_value();
}
TensorShape ChannelImpl::get_shape(void* handle) {
TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
return ret;
}
DType ChannelImpl::get_dtype(void* handle) {
DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) {
return ret;
}
CompNode ChannelImpl::get_device(void* handle) {
CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) {
return ret;
}
DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() {
}
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
if (notice) {
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
// if (dest->desc.layout.ndim) {
// mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
// }
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (m_waitee == dest) {
m_cv.notify_all();
}
} else {
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
: std::unique_lock<std::mutex>();
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (notice && m_waitee == dest) {
m_cv.notify_all();
}
}
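
The produce_tensor rewrite above folds the duplicated if/else bodies into one path by constructing the lock conditionally: a default-constructed std::unique_lock owns no mutex. A minimal, self-contained illustration of that idiom follows (update_state and its printf are invented for illustration only).

// Standalone illustration of the conditional-locking idiom used in produce_tensor:
// a default-constructed std::unique_lock owns no mutex, so the same code path can
// run with or without holding m_mutex.
#include <cstdio>
#include <mutex>

std::mutex m_mutex;

void update_state(bool notice) {
    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
                       : std::unique_lock<std::mutex>();
    // ... mutate shared fields here; protected only when notice == true ...
    std::printf("updated (holding lock: %d)\n", lock.owns_lock() ? 1 : 0);
}

int main() {
    update_state(true);   // takes the lock; the real code then notifies waiters
    update_state(false);  // skips locking entirely
    return 0;
}
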
......@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) {
dest->evict_type = SWAP;
dest->value_fetched = false;
// TODO: swap in parallel
dest->h_value.copy_from(dest->ptr->dev_tensor()).sync();
dest->h_value = dest->ptr->get_value();
dest->ptr.reset();
}
......
......@@ -198,6 +198,17 @@ private:
void do_drop(TensorInfo* dest);
void regenerate(TensorInfo* dest, bool must_drop);
void dispatch_default_cpu(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);
void dispatch_kernel(
std::shared_ptr<OpDef> op,
const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);
std::mutex m_mutex;
std::condition_variable m_cv;
MemPool<TensorInfo> m_pool;
......
......@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
return trait->make_from_op_node(node);
}
DispatchMode OpDef::decide_dispatch_mode(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
return def.trait()->decide_dispatch_mode(def, inputs);
}
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def,
SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}
void OpDef::apply_on_device_tensornd(
const OpDef& def,
const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
def.trait()->apply_on_device_tensornd(def, inputs, outputs);
return;
}
VarNodeArray OpDef::apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
......
......@@ -9,12 +9,16 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <exception>
#include <sstream>
#include <stdexcept>
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/tensor.h"
#include "./op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
......@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
}
}
DispatchMode fallback_decide_dispatch_mode(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
return KERNEL;
}
OpTraitRegistry& OpTraitRegistry::fallback() {
if (trait->apply_on_var_node) {
// fallback to proxy graph impl
......@@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
proxy_graph_detail::make_backward_graph;
}
}
if (!trait->decide_dispatch_mode) {
trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
}
return *this;
}
......
......@@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
using OpDefMaker = detail::OpMeth<
decltype(OpDef::make_from_op_node)>;
using DecideDispatchMode = detail::OpMeth<
decltype(OpDef::decide_dispatch_mode)>;
using ApplyOnPhysicalTensor = detail::OpMeth<
decltype(OpDef::apply_on_physical_tensor)>;
using ApplyOnDeviceTensorND = detail::OpMeth<
decltype(OpDef::apply_on_device_tensornd)>;
using ApplyOnVarNode = detail::OpMeth<
decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible = detail::OpMeth<
......@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
struct OpTrait {
const char* name;
OpDefMaker make_from_op_node;
DecideDispatchMode decide_dispatch_mode;
ApplyOnPhysicalTensor apply_on_physical_tensor;
ApplyOnDeviceTensorND apply_on_device_tensornd;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
GradMaker make_backward_graph;
......@@ -88,7 +94,9 @@ struct OpTrait {
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph) \
......
......@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
DispatchMode decide_dispatch_mode(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
const SmallVector<LogicalTensorDesc>& inputs) {
bool host_computable = true;
constexpr int size_threshhold = TensorShape::MAX_NDIM;
for (auto&& inp : inputs) {
if (inp.value.empty() || inp.value.layout().ndim == 0
|| inp.value.layout().total_nr_elems() > size_threshhold) {
host_computable = false;
break;
}
}
return host_computable ? DEFAULT_CPU : KERNEL;
}
void apply_on_device_tensornd(
const OpDef& def,
const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(inputs.size() == trait.arity,
"%s expects %u inputs; got %zu actually", trait.name,
trait.arity, inputs.size());
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
DeviceTensorND out;
SmallVector<DeviceTensorND> dt_inputs(inputs.size());
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
for (unsigned i = 0; i < inputs.size(); ++i){
dt_inputs[i] = inputs[i]->dev_tensor();
inp_tensornds[i] = inputs[i]->dev_tensor();
}
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
return {Tensor::make(out)};
SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
return {Tensor::make(oup_tensornds[0])};
}
MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
......@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.make_from_op_node(make_from_op_node)
.decide_dispatch_mode(decide_dispatch_mode)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
......
......@@ -15,8 +15,8 @@
#include "../op_trait.h"
namespace mgb::imperative {
namespace {
namespace get_var_shape {
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
......@@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node(
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
}
SmallVector<TensorPtr> apply_on_physical_tensor(
DispatchMode decide_dispatch_mode(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
const SmallVector<LogicalTensorDesc>& inputs) {
bool host_computable = true;
for (auto&& inp : inputs) {
// FIXME(czh): remove value check after proxy graph's
// apply_on_device_tensornd is supported and output Tensor
// is made before add_task.
// then if layout is valid, ptr->layout must be ready
if (inp.value.empty() || inp.value.layout().ndim == 0) {
host_computable = false;
break;
}
}
return host_computable ? DEFAULT_CPU : KERNEL;
}
void apply_on_device_tensornd(
const OpDef& def,
const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
auto&& op_def = def.cast_final_safe<GetVarShape>();
mgb_assert(inputs.size() == 1, "GetVarShape takes 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
auto&& shp = inp.layout();
mgb_assert(shp.ndim != 0, "input shape invalid");
mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
"GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
HostTensorND hv;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
for (size_t i = 0; i < shp.ndim; ++i) {
ptr[i] = shp.shape[i];
......@@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
axis += shp.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = shp.shape[axis];
}
return {Tensor::make(std::move(hv))};
(*outputs)[0] = DeviceTensorND::make_proxy(hv);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
SmallVector<DeviceTensorND> input_tensornds;
input_tensornds.reserve(inputs.size());
for (auto&& inp : inputs) {
input_tensornds.push_back(inp->dev_tensor());
}
SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
// restore to input comp_node
HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
.proxy_to_comp_node(inputs[0]->comp_node());
return {Tensor::make(std::move(host_tensornd))};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......@@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
DeviceTensorND value;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
for (size_t i = 0; i < desc.layout.ndim; ++i) {
......@@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
.make_from_op_node(make_from_op_node)
.decide_dispatch_mode(decide_dispatch_mode)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // get_var_shape
namespace param_pack {
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
TensorShapeArray ret;
for (auto&& i:shapes) {
......@@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.fallback();
} // namespace
} // param_pack
} // namespace mgb::imperative
......@@ -20,6 +20,11 @@ namespace imperative {
class OpDef;
struct OpTrait;
enum DispatchMode {
DEFAULT_CPU = 0,
KERNEL = 1
};
struct BackwardGraphResult {
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
......@@ -36,10 +41,31 @@ public:
static std::shared_ptr<OpDef> make_from_op_node(
cg::OperatorNodeBase* node);
/*!
* \brief Decide which dispatch mode should be used according to the inputs'
* host values and sizes.
*
* \param def Specific :c:expr:`OpDef` to be executed.
* \param inputs Input tensor descriptions.
* \return The DispatchMode to be used, such as `KERNEL` or `DEFAULT_CPU`.
*/
static DispatchMode decide_dispatch_mode(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
SmallVector<TensorPtr> inputs);
/*!
* \brief Call the corresponding dnn op to calculate the results. The output
* tensors' device memory should be allocated by the caller.
*/
static void apply_on_device_tensornd(
const OpDef& def,
const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs);
static cg::VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs);
......
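
The new OpDef::apply_on_device_tensornd documented above expects the caller to allocate output storage (as dispatch_default_cpu does with HostTensorND proxies). The toy program below only models that caller-allocates-outputs contract; FakeTensor and fake_add are invented for illustration and are not part of MegEngine.

// Standalone model of the "outputs are allocated by the caller" contract that
// apply_on_device_tensornd follows. The callee fills pre-sized outputs in place;
// it never resizes or reallocates them.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct FakeTensor {
    std::vector<float> data;   // stands in for caller-allocated device/host storage
};

void fake_add(const std::vector<FakeTensor>& inputs, std::vector<FakeTensor>* outputs) {
    assert(inputs.size() == 2 && outputs->size() == 1);
    auto& out = (*outputs)[0];
    assert(out.data.size() == inputs[0].data.size());  // the caller sized the output
    for (std::size_t i = 0; i < out.data.size(); ++i) {
        out.data[i] = inputs[0].data[i] + inputs[1].data[i];
    }
}

int main() {
    std::vector<FakeTensor> inputs = {{{1.f, 2.f}}, {{3.f, 4.f}}};
    std::vector<FakeTensor> outputs = {{std::vector<float>(2)}};  // allocated up front
    fake_add(inputs, &outputs);
    std::printf("%g %g\n", outputs[0].data[0], outputs[0].data[1]);
    return 0;
}
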