未验证 提交 7c046ae7 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #12323 from reyoung/feature/polish_reshape_and_lod_tensor_blocking_queue

Feature/polish reshape and lod tensor blocking queue
...@@ -38,12 +38,10 @@ class LoDTensorBlockingQueue { ...@@ -38,12 +38,10 @@ class LoDTensorBlockingQueue {
public: public:
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) { bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
CheckDims(lod_tensor_vec);
return queue_.Send(lod_tensor_vec); return queue_.Send(lod_tensor_vec);
} }
bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) { bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
CheckDims(lod_tensor_vec);
return queue_.Send(std::move(lod_tensor_vec)); return queue_.Send(std::move(lod_tensor_vec));
} }
...@@ -65,21 +63,6 @@ class LoDTensorBlockingQueue { ...@@ -65,21 +63,6 @@ class LoDTensorBlockingQueue {
inline bool IsClosed() const { return queue_.IsClosed(); } inline bool IsClosed() const { return queue_.IsClosed(); }
private: private:
void CheckDims(
const std::vector<framework::LoDTensor>& lod_tensor_vec) const {
PADDLE_ENFORCE(dims_.size() == lod_tensor_vec.size(),
"Expect input size is %d but found %s", dims_.size(),
lod_tensor_vec.size());
for (size_t i = 0; i < dims_.size(); ++i) {
const auto& in_dims = framework::slice_ddim(
lod_tensor_vec[i].dims(), 1, lod_tensor_vec[i].dims().size());
const auto& expect_dims =
framework::slice_ddim(dims_[i], 1, dims_[i].size());
PADDLE_ENFORCE(in_dims == expect_dims,
"Dims of the %d-th input tensor do not match", i);
}
}
BlockingQueue<std::vector<framework::LoDTensor>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
std::vector<framework::DDim> dims_; std::vector<framework::DDim> dims_;
}; };
......
...@@ -216,7 +216,7 @@ class ReshapeKernel { ...@@ -216,7 +216,7 @@ class ReshapeKernel {
if (shape_tensor) { if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>(); auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor; framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>(); shape_data = cpu_shape_tensor.data<int>();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册