// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/extension.h" int64_t GetRows(std::vector shape, int64_t axis) { int64_t rows = 1; for (int64_t i = 0; i < axis; ++i) { rows *= shape[i]; } return rows; } std::vector GetCols(const std::vector& ins, int64_t rows, int64_t* cols) { std::vector cols_vec(ins.size()); for (size_t i = 0; i < ins.size(); ++i) { int64_t t_cols = ins[i].size() / rows; *cols += t_cols; cols_vec[i] = t_cols; } return cols_vec; } template void ConcatCpuKernel(const std::vector& ins, paddle::Tensor* out, int64_t axis) { size_t num = ins.size(); int64_t out_rows = GetRows(ins[0].shape(), axis); int64_t out_cols = 0; auto ins_cols = GetCols(ins, out_rows, &out_cols); auto* out_data = out->mutable_data(); int64_t col_idx = 0; for (size_t i = 0; i < num; ++i) { int64_t col_len = ins_cols[i]; auto* in_data = ins[i].data(); for (int j = 0; j < out_rows; ++j) { std::memcpy(out_data + j * out_cols + col_idx, in_data + j * col_len, sizeof(data_t) * col_len); } col_idx += col_len; } } template void SplitCpuKernel(const paddle::Tensor& in, const std::vector& ref_ins, std::vector* outs, int64_t axis) { size_t num = outs->size(); int64_t in_rows = GetRows(ref_ins[0].shape(), axis); int64_t in_cols = 0; auto out_cols = GetCols(ref_ins, in_rows, &in_cols); for (size_t i = 0; i < in_rows; ++i) { auto* in_data = in.data() + i * in_cols; int64_t col_idx = 0; for (size_t j = 0; j < num; ++j) { int64_t col_len = out_cols[j]; auto* out_data = outs->at(j).mutable_data() + i * col_len; std::memcpy(out_data, in_data + col_idx, sizeof(data_t) * col_len); col_idx += col_len; } } }