提交 58db07b7 编写于 作者: Q qingqing01 提交者: QI JUN

Check errors for the cuda kernel calls. (#5436)

上级 6f43c936
......@@ -440,6 +440,9 @@ void OperatorWithKernel::Run(const Scope& scope,
}
kernel_iter->second->Compute(ctx);
// throws errors if have.
dev_ctx.Finish();
}
} // namespace framework
......
......@@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
}
cudaStreamSynchronize(stream);
// TODO(qingqing): Add cuda error check for each kernel.
cudaError_t err = cudaGetLastError();
PADDLE_ENFORCE(err, cudaGetErrorString(err));
}
} // namespace detail
......
......@@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}
void CUDADeviceContext::Finish() const {
Wait();
PADDLE_ENFORCE(cudaGetLastError());
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
......
......@@ -46,6 +46,8 @@ class DeviceContext {
DeviceType* GetEigenDevice() const;
virtual void Wait() const {}
virtual void Finish() const {}
};
class CPUDeviceContext : public DeviceContext {
......@@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */
void Wait() const override;
/*! \brief Check potential errors for the cuda kernel calls. */
void Finish() const override;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册