提交 b836201b 编写于 作者: E eclipsess

code style

上级 8162ba9a
...@@ -51,38 +51,38 @@ class ConcatFunctor { ...@@ -51,38 +51,38 @@ class ConcatFunctor {
} }
} }
}; };
template <typename T> // template <typename T>
void StridedNumelCopyWithAxis(int64_t axis, T *dst, // void StridedNumelCopyWithAxis(int64_t axis, T *dst,
const framework::DDim &dst_stride_numel, // const framework::DDim &dst_stride_numel,
const T *src, // const T *src,
const framework::DDim &src_stride_numel, // const framework::DDim &src_stride_numel,
int64_t size) { // int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; // int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis]; // int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis]; // int64_t dst_after = dst_stride_numel[axis];
//
/// "src and dst tensor should have the same dims size." // /// "src and dst tensor should have the same dims size."
assert(src_stride_numel.size() == dst_stride_numel.size()); // assert(src_stride_numel.size() == dst_stride_numel.size());
//
for (int64_t i = 0; i < axis; ++i) { // for (int64_t i = 0; i < axis; ++i) {
if (i < axis) { // if (i < axis) {
/// src and dst should have the same elements // /// src and dst should have the same elements
/// except the specified axis. // /// except the specified axis.
assert(src_stride_numel[i] / src_stride_numel[axis] == // assert(src_stride_numel[i] / src_stride_numel[axis] ==
dst_stride_numel[i] / dst_stride_numel[axis]); // dst_stride_numel[i] / dst_stride_numel[axis]);
//
} else if (i == axis) { // } else if (i == axis) {
continue; // continue;
} else { // } else {
/// "src and dst should have the same elements " // /// "src and dst should have the same elements "
/// "except the specified axis." // /// "except the specified axis."
assert(src_stride_numel[i] == dst_stride_numel[i]); // assert(src_stride_numel[i] == dst_stride_numel[i]);
} // }
} // }
for (int64_t i = 0; i < before; ++i) { // for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size); // memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
} // }
} //}
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册