diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 22a7d9728a05950d66a1acd23d6fb18263f4ff6b..8150bf923926ed871ee69f2dd8c588451d68af51 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -440,6 +440,9 @@ void OperatorWithKernel::Run(const Scope& scope, } kernel_iter->second->Compute(ctx); + + // Throws if the device context reports any error. + dev_ctx.Finish(); } } // namespace framework diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 41a54a359daa14a047c49728962ea15eefd12274..8b46510db05fbc87ed482bbcad29c9da2fdfb97c 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, op, value, grad, frameSize, batchSize, active_node, active_gate, active_state); } - - cudaStreamSynchronize(stream); - // TODO(qingqing): Add cuda error check for each kernel. - cudaError_t err = cudaGetLastError(); - PADDLE_ENFORCE(err, cudaGetErrorString(err)); } } // namespace detail diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 36450e926891342f37424447703781a33c1190ae..7afcdfce9371e29aad968a1729931173fb2309b5 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } +void CUDADeviceContext::Finish() const { + Wait(); + PADDLE_ENFORCE(cudaGetLastError()); +} + Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index ef5f19214d9ccb23b9c946bee28cb764122bd7cd..526d089e35da9c9f89a3852095ad3a4c82d4d85d 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -46,6 +46,8 @@ class DeviceContext { DeviceType* GetEigenDevice() const; 
virtual void Wait() const {} + + virtual void Finish() const {} }; class CPUDeviceContext : public DeviceContext { @@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Wait for all operations completion in the stream. */ void Wait() const override; + /*! \brief Wait for the stream, then check for errors from the CUDA kernel calls. */ + void Finish() const override; + /*! \brief Return place in the device context. */ Place GetPlace() const override;