提交 c7bbfb33 编写于 作者: F fengjiayi

Fix a GPU bug

上级 24649a78
......@@ -39,11 +39,16 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
PADDLE_ENFORCE_EQ(
rank, offsets_tensor->dims()[0],
"Offsets size should be equal to dimension size of input tensor.");
const int* offsets_data = offsets_tensor->data<int>();
res.resize(rank);
for (size_t i = 0; i < rank; ++i) {
res[i] = offsets_data[i];
const int* offsets_data;
framework::Tensor cpu_tmp_tensor;
if (platform::is_cpu_place(offsets_tensor->place())) {
offsets_data = offsets_tensor->data<int>();
} else {
framework::TensorCopySync(*offsets_tensor, platform::CPUPlace(),
&cpu_tmp_tensor);
offsets_data = cpu_tmp_tensor.data<int>();
}
res = std::vector<int>(offsets_data, offsets_data + rank);
} else {
res = ctx.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册