提交 9eb2d7b3 编写于 作者: J jerrywgz

refine code, test=develop

上级 6dfd789b
...@@ -171,15 +171,18 @@ void SliceOneClass(const platform::DeviceContext& ctx, ...@@ -171,15 +171,18 @@ void SliceOneClass(const platform::DeviceContext& ctx,
const T* items_data = items.data<T>(); const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0]; const int64_t num_item = items.dims()[0];
const int class_num = items.dims()[1]; const int class_num = items.dims()[1];
int item_size = 1;
if (items.dims().size() == 3) { if (items.dims().size() == 3) {
item_size = items.dims()[2]; int item_size = items.dims()[2];
}
for (int i = 0; i < num_item; ++i) { for (int i = 0; i < num_item; ++i) {
std::memcpy(item_data + i * item_size, std::memcpy(item_data + i * item_size,
items_data + i * class_num * item_size + class_id * item_size, items_data + i * class_num * item_size + class_id * item_size,
sizeof(T) * item_size); sizeof(T) * item_size);
} }
} else {
for (int i = 0; i < num_item; ++i) {
item_data[i] = items_data[i * class_num + class_id];
}
}
} }
template <typename T> template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册