Unverified commit 1f3be859, authored by Leo Chen and committed by GitHub

Fix bug of fetch_async_op_handle when fetching the feed variable (#28194)

* fix bug of fetch_async_op_handle

* revert some changes of test_buffer_shared_memory_reuse_pass

* revert some changes of test_buffer_shared_memory_reuse_pass
Parent: e7305160
@@ -13,8 +13,10 @@
 // limitations under the License.
 #include "paddle/fluid/framework/details/fetch_async_op_handle.h"
 #include <string>
 #include <utility>
 #include "paddle/fluid/platform/profiler.h"
 namespace paddle {
@@ -195,7 +197,7 @@ void FetchAsyncOpHandle::FetchMergedLodTensor(
 void FetchAsyncOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
-  WaitInputVarGenerated();
+  WaitInputVarGenerated(true);
   // get src vars
   auto &scopes = *local_exec_scopes_;
......
@@ -143,7 +143,7 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
   out->AddInput(this, this->Node());
 }
-void OpHandleBase::WaitInputVarGenerated() {
+void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
   for (auto in_var : inputs_) {
     if (NeedWait(in_var)) {
       // Dummy Variable is used to represent dependencies between operators, so
@@ -165,6 +165,30 @@ void OpHandleBase::WaitInputVarGenerated() {
         }
         // There are nothing to do when the place is CPUPlace.
       }
+    } else {
+      // NOTE(zhiqiu): Special case when using fetch_async_op_handle may lead to
+      // nondeterminism due to parallel execution of cuda memory operations. Eg:
+      // execute stream: CPU->GPU copy (feed)
+      // fetch stream: GPU->CUDAPinned (fetch)
+      if (in_var && wait_for_feed) {
+        auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
+        if (in_var_handle) {
+          auto &place = in_var_handle->place();
+          if (platform::is_gpu_place(place)) {
+#ifdef PADDLE_WITH_CUDA
+            platform::DeviceContextPool &pool =
+                platform::DeviceContextPool::Instance();
+            auto stream =
+                static_cast<platform::CUDADeviceContext *>(pool.Get(place))
+                    ->stream();
+            PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#else
+            PADDLE_THROW(platform::errors::PreconditionNotMet(
+                "Not compiled with CUDA."));
+#endif
+          }
+        }
+      }
     }
   }
 }
@@ -172,8 +196,8 @@ void OpHandleBase::WaitInputVarGenerated() {
 void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
   for (auto in_var : inputs_) {
     if (NeedWait(in_var)) {
-      // Dummy Variable is used to represent dependencies between operators, so
-      // there doesn't add event for it.
+      // Dummy Variable is used to represent dependencies between operators,
+      // so there doesn't add event for it.
       auto *in_var_handle = dynamic_cast<VarHandle *>(in_var);
       if (in_var_handle) {
         if (platform::is_gpu_place(in_var_handle->place())) {
......
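The NOTE(zhiqiu) comment in the hunk above describes the race this change closes: the feed copy (CPU->GPU) is issued on the execute stream, while fetch_async_op_handle issues its GPU->CUDAPinned copy on a separate fetch stream, and a feed variable has no generating op and therefore no event to wait on. The following standalone sketch reproduces the same pattern with plain CUDA runtime calls; all buffer names and sizes are made up for illustration, and this is not Paddle code.

    // race_sketch.cu -- minimal sketch of the feed/fetch stream ordering,
    // assuming a CUDA toolchain; not Paddle code.
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      const size_t n = 1 << 20;
      float *h_feed = nullptr, *h_fetch = nullptr, *d_buf = nullptr;
      cudaMallocHost(reinterpret_cast<void **>(&h_feed), n * sizeof(float));   // pinned "feed" source
      cudaMallocHost(reinterpret_cast<void **>(&h_fetch), n * sizeof(float));  // CUDAPinned "fetch" destination
      cudaMalloc(reinterpret_cast<void **>(&d_buf), n * sizeof(float));
      for (size_t i = 0; i < n; ++i) h_feed[i] = 1.0f;

      cudaStream_t exec_stream, fetch_stream;
      cudaStreamCreate(&exec_stream);
      cudaStreamCreate(&fetch_stream);

      // "feed": CPU -> GPU copy on the execute stream.
      cudaMemcpyAsync(d_buf, h_feed, n * sizeof(float), cudaMemcpyHostToDevice,
                      exec_stream);

      // Without this synchronization the fetch below may run before the feed
      // copy has finished -- the nondeterminism the hunk above guards against.
      cudaStreamSynchronize(exec_stream);

      // "fetch": GPU -> CUDAPinned copy on the separate fetch stream.
      cudaMemcpyAsync(h_fetch, d_buf, n * sizeof(float), cudaMemcpyDeviceToHost,
                      fetch_stream);
      cudaStreamSynchronize(fetch_stream);

      printf("h_fetch[0] = %f\n", h_fetch[0]);  // expected: 1.000000

      cudaStreamDestroy(exec_stream);
      cudaStreamDestroy(fetch_stream);
      cudaFree(d_buf);
      cudaFreeHost(h_feed);
      cudaFreeHost(h_fetch);
      return 0;
    }

Synchronizing only the stream of the feed variable's place, as the hunk above does, is cheaper than a device-wide synchronization while still ordering the fetch after the feed copy.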
@@ -81,12 +81,15 @@ class OpHandleBase {
   // This method adds the wait events of all the input on all the device
   // context.
-  // NODE: This Wait is asynchronous operation.
-  virtual void WaitInputVarGenerated();
+  // NOTE: This Wait is asynchronous operation.
+  // NOTE: wait_for_feed is added to wait for feed var, since it has no
+  // generated op, no event and cannot perform event wait. It is only
+  // used in fetch_async_op_handle currently.
+  virtual void WaitInputVarGenerated(bool wait_for_feed = false);
   // This method adds the wait events of all the input on the specified device
   // context.
-  // NODE: This Wait is asynchronous operation.
+  // NOTE: This Wait is asynchronous operation.
   virtual void WaitInputVarGenerated(const platform::Place &place);
   virtual bool NeedWait(VarHandleBase *in_var);
......
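As a hedged restatement of the declaration above (toy class names, not the real OpHandleBase hierarchy): the defaulted wait_for_feed parameter keeps every existing call site on the old event-based behavior, while FetchAsyncOpHandle opts in by passing true so that inputs without a generating op still trigger a stream synchronization.

    // contract_sketch.cc -- illustrative only; ToyOpHandle/ToyFetchHandle are
    // made-up names and do not exist in Paddle.
    #include <iostream>

    class ToyOpHandle {
     public:
      virtual ~ToyOpHandle() = default;

      // Mirrors WaitInputVarGenerated(bool wait_for_feed = false): the default
      // keeps the old event-based behavior, so existing call sites are unchanged.
      virtual void WaitInputVarGenerated(bool wait_for_feed = false) {
        std::cout << "wait on recorded events of generated inputs\n";
        if (wait_for_feed) {
          // Feed variables have no generating op and no event; the safe option
          // is to synchronize the stream of the place they live on.
          std::cout << "synchronize the stream of the feed variable's place\n";
        }
      }
    };

    class ToyFetchHandle : public ToyOpHandle {
     public:
      void RunImpl() { WaitInputVarGenerated(/*wait_for_feed=*/true); }
    };

    int main() {
      ToyOpHandle regular_op;
      regular_op.WaitInputVarGenerated();  // old behavior via the default argument

      ToyFetchHandle fetch_op;
      fetch_op.RunImpl();                  // opts in to the feed synchronization
      return 0;
    }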
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/memcpy.h"
 #include <cstring>  // for memcpy
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -267,6 +268,8 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
     const void* src, size_t num, cudaStream_t stream) {
   if (UNLIKELY(num == 0)) return;
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
   if (dst_place == src_place) {
     platform::SetDeviceId(src_place.device);
     if (stream) {
@@ -293,6 +296,8 @@ template <>
 void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
     platform::CPUPlace dst_place, void* dst,
     platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place;
   if (UNLIKELY(num == 0)) return;
   std::memcpy(dst, src, num);
 }
@@ -301,6 +306,8 @@ template <>
 void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
     platform::CUDAPinnedPlace dst_place, void* dst,
     platform::CPUPlace src_place, const void* src, size_t num) {
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place;
   if (UNLIKELY(num == 0)) return;
   std::memcpy(dst, src, num);
 }
@@ -309,6 +316,8 @@ template <>
 void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
     platform::CUDAPinnedPlace dst_place, void* dst,
     platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place;
   if (UNLIKELY(num == 0)) return;
   std::memcpy(dst, src, num);
 }
@@ -320,6 +329,8 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
     cudaStream_t stream) {
   if (UNLIKELY(num == 0)) return;
   platform::SetDeviceId(src_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
@@ -337,6 +348,8 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
   if (UNLIKELY(num == 0)) return;
   platform::SetDeviceId(dst_place.device);
+  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
+          << dst_place << " by stream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
......
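The VLOG(4) lines added above are only emitted when verbose logging is turned up. Assuming the standard glog behavior these macros come from, running the binary with GLOG_v=4 (and, for a quick check, GLOG_logtostderr=1) makes them visible. A minimal standalone example of the same logging call, not Paddle code, is:

    // vlog_sketch.cc -- minimal glog example; build with -lglog and run as
    // `GLOG_v=4 GLOG_logtostderr=1 ./vlog_sketch` to see the message.
    #include <glog/logging.h>

    int main(int argc, char* argv[]) {
      (void)argc;
      google::InitGoogleLogging(argv[0]);
      // Same shape as the messages added in this diff, with made-up values.
      VLOG(4) << "memory::Copy " << 4096
              << " Bytes from CUDAPlace(0) to CUDAPinnedPlace";
      return 0;
    }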
@@ -34,7 +34,6 @@ class InplaceTestBase(unittest.TestCase):
     def initParameter(self):
         self.use_cuda = True
         self.fuse_all_optimizer_ops = False
-        self.fuse_all_reduce_ops = False
     def setUp(self):
         paddle.enable_static()
@@ -94,7 +93,6 @@
             build_strategy.memory_optimize = memory_optimize
             build_strategy.enable_inplace = enable_inplace
             build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
-            build_strategy.fuse_all_reduce_ops = self.fuse_all_reduce_ops
             compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
                 loss_name=loss.name,
                 build_strategy=build_strategy,
@@ -117,15 +115,7 @@
                 fetch_val2, = exe.run(compiled_prog,
                                       feed=feed_dict,
                                       fetch_list=[fetch_var])
-                #NOTE(zhiqiu): Temporally changed from array_equal to allclose.
-                # The real root is fuse_all_reduce and fuse_all_optimizer_opss may
-                # result in diff because of the instruction set on the virtual machine.
-                # And the related unit tests: test_fuse_all_reduce_pass and test_fuse_optimizer_pass use "almostEqual" in their checks.
-                # There are also some related issues:
-                # https://github.com/PaddlePaddle/Paddle/issues/21270
-                # https://github.com/PaddlePaddle/Paddle/issues/21046
-                # https://github.com/PaddlePaddle/Paddle/issues/21045
-                self.assertTrue(np.allclose(fetch_val1, fetch_val2))
+                self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
     def check_multi_card_fetch_var(self):
         if self.is_invalid_test():
@@ -148,7 +138,6 @@
             build_strategy.memory_optimize = memory_optimize
             build_strategy.enable_inplace = enable_inplace
             build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
-            build_strategy.fuse_all_reduce_ops = self.fuse_all_reduce_ops
             compiled_program = fluid.CompiledProgram(
                 prog).with_data_parallel(
                     loss_name=loss.name,
@@ -170,15 +159,13 @@
                 fetch_vals.append(fetch_val)
             for item in fetch_vals:
-                # save above
-                self.assertTrue(np.allclose(fetch_vals[0], item))
+                self.assertTrue(np.array_equal(fetch_vals[0], item))
 class CUDAInplaceTest(InplaceTestBase):
     def initParameter(self):
         self.use_cuda = True
         self.fuse_all_optimizer_ops = False
-        self.fuse_all_reduce_ops = False
     def test_multi_card_fetch_var(self):
         self.check_multi_card_fetch_var()
@@ -191,7 +178,6 @@ class CPUInplaceTest(InplaceTestBase):
     def initParameter(self):
         self.use_cuda = False
         self.fuse_all_optimizer_ops = False
-        self.fuse_all_reduce_ops = False
     def test_multi_card_fetch_var(self):
         self.check_multi_card_fetch_var()
......