未验证 提交 16ba0abc 编写于 作者: J JZ-LIANG 提交者: GitHub

Recompute Offload: fixed bug in memcpy (#30484)

上级 d8a9ba56
......@@ -38,10 +38,10 @@ class MemcpyFunctor {
void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
if (dst_place_type_ == 3) {
if (dst_place_type_ == 2) {
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
&out_tensor);
} else if (dst_place_type_ == 2) {
} else if (dst_place_type_ == 1) {
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
} else {
......
......@@ -4780,7 +4780,7 @@ class RecomputeOptimizer(Optimizer):
return
def _insert_async_memcpy_op(self, insert_idx, src_varname, dst_varname,
op_role, kind):
op_role, dst_place_type):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
self.block._insert_op_without_sync(
insert_idx,
......@@ -4789,8 +4789,10 @@ class RecomputeOptimizer(Optimizer):
outputs={
'Out': [self._main_program.global_block().var(dst_varname)]
},
attrs={"dst_place_type": int(kind),
OP_ROLE_KEY: op_role})
attrs={
"dst_place_type": int(dst_place_type),
OP_ROLE_KEY: op_role
})
def _insert_fetch_op(self, idx, varname):
assert varname in self.checkpoint_name2pinned_name, "Try to fetch {} from Pinned Memory, but it is NOT a checkpoint".format(
......@@ -4798,13 +4800,13 @@ class RecomputeOptimizer(Optimizer):
pinned_varname = self.checkpoint_name2pinned_name[varname]
fetch_varname = self.checkpoint_name2fetch_name[varname]
self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 2)
self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 1)
def _insert_offload_op(self, idx, varname):
assert varname in self.checkpoint_name2pinned_name, "Try to offload {} to Pinned Memory, but it is NOT a checkpoint".format(
varname)
pinned_varname = self.checkpoint_name2pinned_name[varname]
self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 3)
self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 2)
def _insert_sync_op(self, op_idx, checkpoint_name):
# single stream offload no need sync
......
......@@ -70,7 +70,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
type='memcpy',
inputs={'X': gpu_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
attrs={'dst_place_type': 2})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
......@@ -85,7 +85,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
type='memcpy',
inputs={'X': pinned_var},
outputs={'Out': gpu_var},
attrs={'dst_place_type': 2})
attrs={'dst_place_type': 1})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
......@@ -135,7 +135,7 @@ class TestMemcpyOPError(unittest.TestCase):
type='memcpy',
inputs={'X': selected_row_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
attrs={'dst_place_type': 2})
with self.assertRaises(NotImplementedError):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册