提交 499b3c70 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4338 from reyoung/feature/add_wait_to_device_ctx

Sync computation when Python invoke `run`
......@@ -34,13 +34,14 @@ class DeviceContext {
template <typename DeviceType>
DeviceType* get_eigen_device() const;
virtual void Wait() const {}
};
class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const;
......@@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext {
virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */
void Wait() const;
void Wait() const override;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......
......@@ -237,7 +237,13 @@ All parameter, weight, gradient are variables in Paddle.
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("run",
[](OperatorBase &self,
const Scope &scope,
const platform::DeviceContext &dev_ctx) {
self.Run(scope, dev_ctx);
dev_ctx.Wait();
})
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册