提交 58db07b7 编写于 作者: Q qingqing01 提交者: QI JUN

Check errors for the cuda kernel calls. (#5436)

上级 6f43c936
...@@ -440,6 +440,9 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -440,6 +440,9 @@ void OperatorWithKernel::Run(const Scope& scope,
} }
kernel_iter->second->Compute(ctx); kernel_iter->second->Compute(ctx);
// throws errors if have.
dev_ctx.Finish();
} }
} // namespace framework } // namespace framework
......
...@@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, ...@@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
op, value, grad, frameSize, batchSize, active_node, active_gate, op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state); active_state);
} }
cudaStreamSynchronize(stream);
// TODO(qingqing): Add cuda error check for each kernel.
cudaError_t err = cudaGetLastError();
PADDLE_ENFORCE(err, cudaGetErrorString(err));
} }
} // namespace detail } // namespace detail
......
...@@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const { ...@@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
} }
void CUDADeviceContext::Finish() const {
Wait();
PADDLE_ENFORCE(cudaGetLastError());
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get(); return eigen_device_.get();
} }
......
...@@ -46,6 +46,8 @@ class DeviceContext { ...@@ -46,6 +46,8 @@ class DeviceContext {
DeviceType* GetEigenDevice() const; DeviceType* GetEigenDevice() const;
virtual void Wait() const {} virtual void Wait() const {}
virtual void Finish() const {}
}; };
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
...@@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext { ...@@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */ /*! \brief Wait for all operations completion in the stream. */
void Wait() const override; void Wait() const override;
/*! \brief Check potential errors for the cuda kernel calls. */
void Finish() const override;
/*! \brief Return place in the device context. */ /*! \brief Return place in the device context. */
Place GetPlace() const override; Place GetPlace() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册