提交 c7bbfb33 编写于 作者: F fengjiayi

Fix a GPU bug

上级 24649a78
...@@ -39,11 +39,16 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) { ...@@ -39,11 +39,16 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0], rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor."); "Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data = offsets_tensor->data<int>(); const int* offsets_data;
res.resize(rank); framework::Tensor cpu_tmp_tensor;
for (size_t i = 0; i < rank; ++i) { if (platform::is_cpu_place(offsets_tensor->place())) {
res[i] = offsets_data[i]; offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
} }
res = std::vector<int>(offsets_data, offsets_data + rank);
} else { } else {
res = ctx.Attr<std::vector<int>>("offsets"); res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册