提交 ba1f5b5c 编写于 作者: Y Yu Yang

Sync computation when Python invoke `run`

* Since GPU is an async device by default. We should sync computation
  when Python invoke `run`. So Python can get the correct computation
  result
上级 7d33447d
...@@ -34,13 +34,14 @@ class DeviceContext { ...@@ -34,13 +34,14 @@ class DeviceContext {
template <typename DeviceType> template <typename DeviceType>
DeviceType* get_eigen_device() const; DeviceType* get_eigen_device() const;
virtual void Wait() const {}
}; };
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext(); CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place); explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const; Eigen::DefaultDevice* eigen_device() const;
...@@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext {
virtual ~CUDADeviceContext(); virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */ /*! \brief Wait for all operations completion in the stream. */
void Wait() const; void Wait() const override;
/*! \brief Return place in the device context. */ /*! \brief Return place in the device context. */
Place GetPlace() const override; Place GetPlace() const override;
......
...@@ -238,7 +238,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -238,7 +238,13 @@ All parameter, weight, gradient are variables in Paddle.
return Backward(forwardOp, no_grad_vars).release(); return Backward(forwardOp, no_grad_vars).release();
}) })
.def("infer_shape", &OperatorBase::InferShape) .def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run) .def("run",
[](OperatorBase &self,
const Scope &scope,
const platform::DeviceContext &dev_ctx) {
self.Run(scope, dev_ctx);
dev_ctx.Wait();
})
.def("type", .def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); }) [](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs", .def("outputs",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册