未验证 提交 16ba0abc 编写于 作者: J JZ-LIANG 提交者: GitHub

Recompute Offload: fixed bug in memcpy (#30484)

上级 d8a9ba56
@@ -38,10 +38,10 @@ class MemcpyFunctor {
   void operator()(const framework::LoDTensor &lod_tensor) const {
     auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
-    if (dst_place_type_ == 3) {
+    if (dst_place_type_ == 2) {
       framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
                             &out_tensor);
-    } else if (dst_place_type_ == 2) {
+    } else if (dst_place_type_ == 1) {
       framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
                             &out_tensor);
     } else {
......
@@ -4780,7 +4780,7 @@ class RecomputeOptimizer(Optimizer):
         return
     def _insert_async_memcpy_op(self, insert_idx, src_varname, dst_varname,
-                                op_role, kind):
+                                op_role, dst_place_type):
         OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
         self.block._insert_op_without_sync(
             insert_idx,
@@ -4789,8 +4789,10 @@ class RecomputeOptimizer(Optimizer):
             outputs={
                 'Out': [self._main_program.global_block().var(dst_varname)]
             },
-            attrs={"dst_place_type": int(kind),
-                   OP_ROLE_KEY: op_role})
+            attrs={
+                "dst_place_type": int(dst_place_type),
+                OP_ROLE_KEY: op_role
+            })
     def _insert_fetch_op(self, idx, varname):
         assert varname in self.checkpoint_name2pinned_name, "Try to fetch {} from Pinned Memory, but it is NOT a checkpoint".format(
@@ -4798,13 +4800,13 @@ class RecomputeOptimizer(Optimizer):
         pinned_varname = self.checkpoint_name2pinned_name[varname]
         fetch_varname = self.checkpoint_name2fetch_name[varname]
-        self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 2)
+        self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 1)
     def _insert_offload_op(self, idx, varname):
         assert varname in self.checkpoint_name2pinned_name, "Try to offload {} to Pinned Memory, but it is NOT a checkpoint".format(
             varname)
         pinned_varname = self.checkpoint_name2pinned_name[varname]
-        self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 3)
+        self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 2)
     def _insert_sync_op(self, op_idx, checkpoint_name):
         # single stream offload no need sync
......
@@ -70,7 +70,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
             type='memcpy',
             inputs={'X': gpu_var},
             outputs={'Out': pinned_var},
-            attrs={'dst_place_type': 3})
+            attrs={'dst_place_type': 2})
         place = fluid.CUDAPlace(0)
         exe = fluid.Executor(place)
         gpu_, pinned_ = exe.run(main_program,
@@ -85,7 +85,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
             type='memcpy',
             inputs={'X': pinned_var},
             outputs={'Out': gpu_var},
-            attrs={'dst_place_type': 2})
+            attrs={'dst_place_type': 1})
         place = fluid.CUDAPlace(0)
         exe = fluid.Executor(place)
         gpu_, pinned_ = exe.run(main_program,
@@ -135,7 +135,7 @@ class TestMemcpyOPError(unittest.TestCase):
             type='memcpy',
             inputs={'X': selected_row_var},
             outputs={'Out': pinned_var},
-            attrs={'dst_place_type': 3})
+            attrs={'dst_place_type': 2})
         with self.assertRaises(NotImplementedError):
             place = fluid.CUDAPlace(0)
             exe = fluid.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册