提交 f77de54a 编写于 作者: D dinghao

fix tensor dirty

上级 9c1a5db4
......@@ -164,8 +164,9 @@ Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::arr
Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), dirty_(tensor.dirty_), device_address_(tensor.device_address_) {
: MetaTensor(tensor), device_address_(tensor.device_address_) {
init(tensor.data_, data_type);
dirty_ = tensor.is_dirty();
}
Tensor &Tensor::operator=(const Tensor &tensor) {
......@@ -291,6 +292,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
} else {
data_ = input;
}
dirty_ = true;
}
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
......
......@@ -127,6 +127,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->enable_pynative_infer()) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c(true))) {
......@@ -491,7 +492,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
need_sync = true;
}
} else {
if (tensor->is_dirty() || !AnfAlgo::IsParameterWeight(pk_node)) {
if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor->device_address() != device_address) {
(void)tensor->data_sync();
......
......@@ -51,19 +51,22 @@ def test_assign_add():
[[54, 57, 60],
[63, 66, 69],
[72, 75, 78]]]])
x = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
y = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
x1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
y1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
add = AssignAdd()
output1 = add(x, y)
output1 = add(x1, y1)
assert (output1.asnumpy() == expect1).all()
output2 = add(output1, y)
output2 = add(output1, y1)
assert (output2.asnumpy() == expect2).all()
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
add = AssignAdd()
output1 = add(x, y)
output1 = add(x2, y2)
assert (output1.asnumpy() == expect1).all()
output2 = add(output1, y)
output2 = add(output1, y2)
assert (output2.asnumpy() == expect2).all()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册