提交 e8bf5bc0 编写于 作者: M Megvii Engine Team

fix(mge): ensure contiguous when passing tensor from graph rt to imperative rt

GitOrigin-RevId: e4d944343d051d9263498c200f17d951af67e9b5
上级 3af10563
......@@ -406,3 +406,16 @@ def test_clip():
for i in range(3):
f(x, tensor([0]), tensor([1]))
# test returning noncontiguous tensor from trace
def test_slice():
@trace
def f(x):
return x[:, 1::2]
x = F.arange(8).reshape(2, 4)
f(x)
y = f(x)
np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
y + y
......@@ -156,6 +156,12 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
return prop;
}
void OutputCallback::add_input_layout_constraint() {
if (m_param.require_contiguous) {
input(0)->add_layout_constraint_contiguous();
}
}
void OutputCallback::scn_do_execute() {
if (m_use_host_value) {
m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0)));
......
......@@ -62,6 +62,7 @@ public:
callback_t callback;
bool borrow = false; // do not obtain shared ownership on DeviceTensorND
bool prefer_host_value = false; // use host value when possible
bool require_contiguous = true;
};
OutputCallback(Param param,
const VarNodeArray& inputs,
......@@ -80,6 +81,7 @@ protected:
void scn_do_execute() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
void add_input_layout_constraint() override;
private:
Param m_param;
mutable bool m_use_host_value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册