Commit 84d1a440 authored by Megvii Engine Team

fix(imperative): do not use output_desc in rng ops

GitOrigin-RevId: e6a399be171ea93d8b1a79842a6c066b06e3843d
Parent 1ce78aa0
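The diff below replaces the per-op `infer_output_cns` helpers, which returned only a `SmallVector<CompNode>` and left the output layouts to the caller-supplied `output_descs`, with `infer_output_attrs` helpers that build full `LogicalTensorDesc` entries (comp_node plus layout) from the inputs and the RNG handle, so `apply_on_physical_tensor` no longer depends on `output_descs`. A minimal standalone sketch of that pattern follows; the struct and function names only loosely mirror the MegEngine types and are assumptions for illustration, not the library's actual API.

```cpp
// Hypothetical standalone sketch (not the MegEngine API): the output
// descriptor is derived from the inputs instead of being taken from a
// caller-provided output_descs vector.
#include <cstdio>
#include <string>
#include <vector>

struct Layout {
    std::vector<size_t> shape;
    std::string dtype;
};

struct TensorDesc {          // stands in for LogicalTensorDesc
    std::string comp_node;   // e.g. "gpu0"
    Layout layout;
};

// Before the fix: only the comp_node was inferred and the layout came from
// output_descs, which may be stale for RNG ops.
// After the fix: both fields are computed here, from the inputs alone.
TensorDesc infer_output_attrs(const TensorDesc& input) {
    TensorDesc dest;
    dest.comp_node = input.comp_node;  // same device as the input
    dest.layout = input.layout;        // RNG output matches the input shape/dtype
    return dest;
}

std::vector<TensorDesc> apply_on_physical_tensor(
        const std::vector<TensorDesc>& inputs) {
    // Allocate outputs from the freshly inferred descriptors, ignoring any
    // externally supplied output descriptors.
    return {infer_output_attrs(inputs[0])};
}

int main() {
    std::vector<TensorDesc> inputs{{"gpu0", {{4, 8}, "float32"}}};
    auto outputs = apply_on_physical_tensor(inputs);
    std::printf("output on %s, ndim=%zu\n", outputs[0].comp_node.c_str(),
                outputs[0].layout.shape.size());
    return 0;
}
```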
@@ -419,7 +419,8 @@ _INST_RNG_MAKER(2)
template <typename Op>
void exec(
const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs) {
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
@@ -450,56 +451,71 @@ void exec(
}
template <typename Op>
SmallVector<CompNode> infer_output_cns(
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
CompNode cn;
LogicalTensorDesc dest;
auto&& rng = op.cast_final_safe<Op>();
auto handle = rng.handle;
if (handle) {
cn = RNGDnnOpManager::get_comp_node(handle);
dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
cn = inputs[0]->comp_node();
dest.comp_node = inputs[0]->comp_node();
}
constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
if (!rng_with_shape) {
for (int i = 0; i < inputs.size(); ++i) {
mgb_assert(
inputs[i]->comp_node() == cn,
inputs[i]->comp_node() == dest.comp_node,
"%s expects the device of inputs[%d] to be same as the device of "
"handle; "
"got %s and %s actually",
rng.dyn_typeinfo()->name, i,
inputs[i]->comp_node().to_string().c_str(), cn.to_string().c_str());
inputs[i]->comp_node().to_string().c_str(),
dest.comp_node.to_string().c_str());
}
}
return {cn};
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
return {dest};
}
template <>
SmallVector<CompNode> infer_output_cns<ShuffleRNG>(
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<CompNode> cns(2);
SmallVector<LogicalTensorDesc> dests(2);
auto&& rng = op.cast_final_safe<ShuffleRNG>();
auto handle = rng.handle;
if (handle) {
cns[0] = RNGDnnOpManager::get_comp_node(handle);
cns[1] = RNGDnnOpManager::get_comp_node(handle);
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
cns[0] = inputs[0]->comp_node();
cns[1] = inputs[0]->comp_node();
dests[0].comp_node = inputs[0]->comp_node();
dests[1].comp_node = inputs[0]->comp_node();
}
return cns;
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
dests[1].layout =
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
return dests;
}
template <>
SmallVector<CompNode> infer_output_cns<Dropout>(
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<CompNode> cns(2);
SmallVector<LogicalTensorDesc> dests(2);
auto&& cn = inputs[0]->comp_node();
cns[0] = cn;
cns[1] = cn;
return cns;
dests[0].comp_node = cn;
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
auto get_mask_size = [&]() -> size_t {
auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
inputs[0]->layout());
};
dests[1].comp_node = cn;
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return dests;
}
template <typename Op>
@@ -507,11 +523,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<TensorPtr> outputs;
SmallVector<CompNode> cns = infer_output_cns<Op>(def, inputs);
for (size_t i = 0; i < cns.size(); i++) {
outputs.push_back(Tensor::make(output_descs[i].layout, cns[i]));
SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
exec<Op>(def, inputs, outputs);
exec<Op>(def, inputs, outputs, {});
return outputs;
}
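For ops whose second output is not shaped like the input (the ShuffleRNG index vector, the Dropout mask), the new helpers compute that layout themselves; in the Dropout case the byte size is queried from the backend via `get_mask_size_in_bytes`. The sketch below shows the idea with a hypothetical `query_mask_size_in_bytes` helper; the one-bit-per-element size is an assumption for illustration, the real value is backend-defined.

```cpp
// Hypothetical standalone sketch: the second (mask) output layout is derived
// by asking the backend, never read from caller-provided output descriptors.
#include <cstdio>
#include <vector>

struct Layout {
    std::vector<size_t> shape;
    const char* dtype;
};

// Stand-in for megdnn::Dropout::get_mask_size_in_bytes(); the real value is
// backend-defined, one bit per element is only an assumption for this sketch.
size_t query_mask_size_in_bytes(const Layout& input) {
    size_t elems = 1;
    for (size_t d : input.shape) elems *= d;
    return (elems + 7) / 8;
}

std::vector<Layout> infer_dropout_output_layouts(const Layout& input) {
    Layout out = input;                                      // data output mirrors the input
    Layout mask{{query_mask_size_in_bytes(input)}, "byte"};  // 1-D byte buffer for the mask
    return {out, mask};
}

int main() {
    Layout in{{32, 128}, "float32"};
    auto outs = infer_dropout_output_layouts(in);
    std::printf("mask bytes: %zu\n", outs[1].shape[0]);
    return 0;
}
```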