提交 b836201b 编写于 作者: E eclipsess

code style

上级 8162ba9a
......@@ -51,38 +51,38 @@ class ConcatFunctor {
}
}
};
template <typename T>
void StridedNumelCopyWithAxis(int64_t axis, T *dst,
const framework::DDim &dst_stride_numel,
const T *src,
const framework::DDim &src_stride_numel,
int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
/// "src and dst tensor should have the same dims size."
assert(src_stride_numel.size() == dst_stride_numel.size());
for (int64_t i = 0; i < axis; ++i) {
if (i < axis) {
/// src and dst should have the same elements
/// except the specified axis.
assert(src_stride_numel[i] / src_stride_numel[axis] ==
dst_stride_numel[i] / dst_stride_numel[axis]);
} else if (i == axis) {
continue;
} else {
/// "src and dst should have the same elements "
/// "except the specified axis."
assert(src_stride_numel[i] == dst_stride_numel[i]);
}
}
for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
}
}
// template <typename T>
// void StridedNumelCopyWithAxis(int64_t axis, T *dst,
// const framework::DDim &dst_stride_numel,
// const T *src,
// const framework::DDim &src_stride_numel,
// int64_t size) {
// int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
// int64_t src_after = src_stride_numel[axis];
// int64_t dst_after = dst_stride_numel[axis];
//
// /// "src and dst tensor should have the same dims size."
// assert(src_stride_numel.size() == dst_stride_numel.size());
//
// for (int64_t i = 0; i < axis; ++i) {
// if (i < axis) {
// /// src and dst should have the same elements
// /// except the specified axis.
// assert(src_stride_numel[i] / src_stride_numel[axis] ==
// dst_stride_numel[i] / dst_stride_numel[axis]);
//
// } else if (i == axis) {
// continue;
// } else {
// /// "src and dst should have the same elements "
// /// "except the specified axis."
// assert(src_stride_numel[i] == dst_stride_numel[i]);
// }
// }
// for (int64_t i = 0; i < before; ++i) {
// memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
// }
//}
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册