diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e4fe35b9b5c5a70f94c1edc17ee7cddc63180ed6..286a8684127a9fcbc42e98b89828d6acb87b859c 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -94,7 +94,7 @@ else() endif() cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim mixed_vector place tensor framework_proto version) -cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) +cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_utils lod_tensor memory) if(WITH_GPU) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index a4b9fff8ecd1530cf4b6569a4226784583b7ea59..ab2e30a15ea15c55c118225fdcaec4c66f95f20e 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -117,7 +117,8 @@ bool CheckLoD(const LoD &in, int tensor_height) { } // check: the lowest level's last offset should equals `tensor_height` if // tensor_height>0. - if (tensor_height > 0 && (size_t)tensor_height != in.back().back()) + if (tensor_height > 0 && + static_cast(tensor_height) != in.back().back()) return false; // check: the higher level's last offset should equals the lower level's @@ -150,7 +151,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height) { if (level.front() != 0) return false; if (tensor_height < 0) { tensor_height = level.back(); - } else if ((size_t)tensor_height != level.back()) { + } else if (static_cast(tensor_height) != level.back()) { return false; } } @@ -186,27 +187,6 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, return LoDAndOffset{sub_lod, {start_idx, end_idx}}; } -void AppendLoD(LoD *lod, const LoD &lod_length) { - PADDLE_ENFORCE( - lod->empty() || lod->size() == lod_length.size(), - platform::errors::InvalidArgument( - "The input LoD length should be equal to the appended LoD size, but " - "received input LoD length is %d, actual LoD size is %d.", - lod_length, lod->size())); - if (lod->empty()) { - for (size_t i = 0; i < lod_length.size(); ++i) { - lod->emplace_back(1, 0); // size = 1, value = 0; - } - *lod = LoD(lod_length.size(), std::vector({0})); - } - for (size_t i = 0; i < lod->size(); ++i) { - auto &level = (*lod)[i]; - for (size_t len : lod_length[i]) { - level.push_back(level.back() + len); - } - } -} - void SerializeToStream(std::ostream &os, const LoDTensor &tensor, const platform::DeviceContext &dev_ctx) { { // the 1st field, uint32_t version for LoDTensor @@ -313,22 +293,6 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, TensorFromStream(is, static_cast(tensor), dev_ctx); } -LoD ConvertToLengthBasedLoD(const LoD &offset_lod) { - LoD length_lod; - length_lod.reserve(offset_lod.size()); - for (size_t lvl = 0; lvl < offset_lod.size(); ++lvl) { - std::vector level; - if (offset_lod[lvl].size() > 0) { - level.reserve(offset_lod[lvl].size() - 1); - } - for (size_t idx = 0; idx < offset_lod[lvl].size() - 1; ++idx) { - level.push_back(offset_lod[lvl][idx + 1] - offset_lod[lvl][idx]); - } - length_lod.push_back(level); - } - return length_lod; -} - LoD ConvertToOffsetBasedLoD(const LoD &length_lod) { LoD offset_lod; offset_lod.reserve(length_lod.size()); diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 14727c190b581b2ab09087d3ca1b2d71f5c83770..63680c008bf66108d20ae9f579995695c1d5d26e 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -157,8 +157,6 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level, std::pair> GetSubLoDAndAbsoluteOffset( const LoD& lod, size_t start_idx, size_t end_idx, size_t start_level); -void AppendLoD(LoD* lod, const LoD& lod_length); - /* * Serialize/Desiralize LoDTensor to std::ostream * You can pass ofstream or ostringstream to serilize to file @@ -173,18 +171,6 @@ void DeserializeFromStream(std::istream& is, LoDTensor* tensor, const size_t& seek, const std::vector& shape); -/* - * Convert between length-based LoD and offset-based LoD. - * The implementation of LoDTensor class use offset-based LoD. - * However, we want to expose the more user-friendly length-based - * LoD to the Python side instead. - * - * Example: - * If offset_lod = [[0, 2, 3],[0, 3, 5, 9]] - * then length_lod = [[2, 1], [3, 2, 4]] - */ -LoD ConvertToLengthBasedLoD(const LoD& offset_lod); - LoD ConvertToOffsetBasedLoD(const LoD& length_lod); void SerializeToStream(std::ostream& os, const LoDTensor& tensor); diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index 917bb7cc096c266fa8df20a79dc2c1dac5b18f12..5e72c2d3d7e94532ee1081892fb280676b8cba48 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/pten/core/lod_utils.h" namespace paddle { namespace framework { @@ -98,7 +99,7 @@ TEST(LoD, AppendLoD) { origin.push_back(std::vector({0, 1, 6})); origin.push_back(std::vector({0, 2, 5, 7, 10, 12, 15})); - paddle::framework::AppendLoD(&origin, lod_lens); + pten::AppendLoD(&origin, lod_lens); LoD expected; expected.push_back(std::vector({0, 2, 4})); @@ -277,7 +278,7 @@ TEST(LoD, ConvertToLengthBasedLoD) { offset_lod.push_back(std::vector({0, 1, 3})); offset_lod.push_back(std::vector({0, 2, 4, 5})); - LoD length_lod = ConvertToLengthBasedLoD(offset_lod); + LoD length_lod = pten::ConvertToLengthBasedLoD(offset_lod); LoD expected; expected.push_back(std::vector({2})); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e69a6c2e88c6b8e7957a855d0374667106fd22b4..33a4e5d2f390611a3f079bff3232a1bd5f7b3ac0 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1978,6 +1978,10 @@ void OperatorWithKernel::BuildPtenKernelContext( std::type_index(typeid(std::string))) { pt_kernel_context->EmplaceBackAttr( std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int))) { + pt_kernel_context->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(int, attr)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to Scalar when construct " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index d28595a6a4c75d1ffa4d522d756cc7f7b8529254..fe60f05e1da431dc7ed7b45acebb8cffecc12941 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -438,6 +438,10 @@ static void BuildDygraphPtenKernelContext( std::type_index(typeid(std::string))) { kernel_ctx->EmplaceBackAttr( std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(int, attr)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to Scalar when construct " diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index 1680ad528abf9b63c42470ebf77b3457ec318ecf..a959067ddba62acf88d1caf19c58fe90cd8852d5 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/core/lod_utils.h" namespace paddle { namespace framework { @@ -168,7 +169,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { x[x_idx].lod(), idx, idx + 1, 0); auto &lod_length = lod_and_offset.first; - framework::AppendLoD(out_lod, lod_length); + pten::AppendLoD(out_lod, lod_length); size_t start_offset = lod_and_offset.second.first; size_t end_offset = lod_and_offset.second.second; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index e6b1f6a1c18c38d94d9e3bc4807de7d8b952d60d..9eba127a9b3ceace225e3d3dcf867df518c4477e 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -19,6 +19,8 @@ limitations under the License. */ #include #include +#include "paddle/pten/kernels/funcs/concat_funcs.h" + #ifdef PADDLE_WITH_MKLDNN #include #endif @@ -56,8 +58,8 @@ class ConcatOp : public framework::OperatorWithKernel { size_t axis = ComputeAxis(static_cast(ctx->Attrs().Get("axis")), static_cast(inputs_dims[0].size())); - framework::DDim out_dims = - ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims, axis); + framework::DDim out_dims = pten::funcs::ComputeAndCheckShape( + ctx->IsRuntime(), inputs_dims, axis); if (out_dims[axis] < 0) { out_dims[axis] = -1; } @@ -102,6 +104,15 @@ class ConcatOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + if (ctx.HasInput("AxisTensor")) { + return framework::KernelSignature("concat", {"X"}, {"AxisTensor"}, + {"Out"}); + } + return framework::KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); + } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index bb72174be5ed571dcc8d1467c71ef5980f2fb965..3eaffbdc8bf35be9af8da73d28c92f4d8f00f53b 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -22,54 +22,11 @@ limitations under the License. */ #include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/pten/kernels/concat_kernel.h" +#include "paddle/pten/kernels/funcs/concat_funcs.h" + namespace paddle { namespace operators { -static inline framework::DDim ComputeAndCheckShape( - const bool is_runtime, const std::vector& inputs_dims, - const size_t axis) { - const size_t n = inputs_dims.size(); - auto out_dims = inputs_dims[0]; - size_t in_zero_dims_size = out_dims.size(); - for (size_t i = 1; i < n; i++) { - PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(), - platform::errors::InvalidArgument( - "The shape of input[0] and input[%d] " - "is expected to be equal." - "But received input[0]'s shape = " - "[%s], input[%d]'s shape = [%s].", - i, inputs_dims[0], i, inputs_dims[i])); - for (size_t j = 0; j < in_zero_dims_size; j++) { - if (j == axis) { - if (is_runtime) { - out_dims[axis] += inputs_dims[i][j]; - } else { - if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { - out_dims[axis] = -1; - } else { - out_dims[axis] += inputs_dims[i][j]; - } - } - } else { - bool check_shape = - is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0); - if (check_shape) { - // check all shape in run time - PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j], - platform::errors::InvalidArgument( - "The %d-th dimension of input[0] and input[%d] " - "is expected to be equal." - "But received input[0]'s shape = " - "[%s], input[%d]'s shape = [%s].", - j, i, inputs_dims[0], i, inputs_dims[i])); - } - if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { - out_dims[j] = inputs_dims[i][j]; - } - } - } - } - return out_dims; -} static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { PADDLE_ENFORCE_EQ( @@ -109,67 +66,21 @@ class ConcatKernel : public framework::OpKernel { ins_dims[i] = ins[i]->dims(); } - framework::DDim out_dims = ComputeAndCheckShape(true, ins_dims, axis); + framework::DDim out_dims = + pten::funcs::ComputeAndCheckShape(true, ins_dims, axis); out->Resize(out_dims); } auto place = ctx.GetPlace(); out->mutable_data(place); - // If axis is 0, the lod of the output is not the same as inputs. - if (axis == 0 && ins[0]->lod().size() > 0) { - size_t lod_size_0 = ins[0]->lod().size(); - size_t lod_size = lod_size_0; - for (size_t i = 1; i < ins.size(); ++i) { - if (ins[i]->lod().size() > 0) { - PADDLE_ENFORCE_EQ( - ins[i]->lod().size(), lod_size_0, - platform::errors::Unimplemented( - "The lod level of all input LoDTensors should be same. " - "Maybe different lod level of input LoDTensors can concat," - "it is not supported currently. The lod level of %dth input " - "is %d and first input is %d.", - i, ins[i]->lod().size(), lod_size_0)); - } else { - lod_size = 0; - break; - } - } - if (lod_size) { - auto* out_lod = out->mutable_lod(); - for (size_t i = 1; i < ins.size(); ++i) { - auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod()); - AppendLoD(out_lod, in_lod); - } - } + // call new kernel + auto& dev_ctx = ctx.device_context(); + std::vector pt_ins; + for (auto& in : ins) { + pt_ins.push_back(*in); } - // Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && ins.size() < 10) { - size_t output_offset = 0; - for (auto* in : ins) { - if (!in || in->numel() == 0UL) { - continue; - } - auto in_stride = framework::stride_numel(in->dims()); - auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx.device_context(), axis, - out->data() + output_offset, out_stride, - in->data(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; - } - } else { - std::vector inputs; - for (size_t j = 0; j < ins.size(); ++j) { - if (ins[j] && ins[j]->numel() > 0) { - inputs.push_back(*ins[j]); - } else { - continue; - } - } - auto& dev_ctx = ctx.template device_context(); - paddle::operators::math::ConcatFunctor concat_functor; - concat_functor(dev_ctx, inputs, static_cast(axis), out); - } + pten::ConcatKernel(dev_ctx, pt_ins, axis, out); } }; diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc index 0ff11e11165f06628566c65050c90e0ce0f17240..aa10a58738bbd83d0ceae2e5820feeeb0ef3f0a9 100644 --- a/paddle/fluid/operators/concat_op_xpu.cc +++ b/paddle/fluid/operators/concat_op_xpu.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/pten/core/lod_utils.h" + namespace paddle { namespace operators { using Tensor = framework::Tensor; @@ -69,8 +71,8 @@ class ConcatXPUKernel : public framework::OpKernel { if (lod_size) { auto* out_lod = out->mutable_lod(); for (size_t i = 1; i < ins.size(); ++i) { - auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod()); - AppendLoD(out_lod, in_lod); + auto in_lod = pten::ConvertToLengthBasedLoD(ins[i]->lod()); + pten::AppendLoD(out_lod, in_lod); } } } diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index e02972bd753538b7c0119c015990a17a23e8230a..5f39a9afa94bad7c16fce74547e7616a2a685846 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/core/lod_utils.h" namespace paddle { namespace framework { @@ -134,7 +135,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase { auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( x.lod(), start_idx, start_idx + 1, rank_level + 1); auto &lod_length = lod_and_offset.first; - framework::AppendLoD(&lod, lod_length); + pten::AppendLoD(&lod, lod_length); size_t start_offset = lod_and_offset.second.first; size_t end_offset = lod_and_offset.second.second; copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset}); diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index 4f12630d1e02fbffa405b06060aef2b4f2a730a0..a9f2680660bd22788602cae51b3bf1eddeedc545 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/concat_and_split.h" + +#include "paddle/pten/kernels/cpu/concat_and_split.h" #ifdef PADDLE_WITH_ASCEND_CL #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #endif @@ -44,36 +46,9 @@ class ConcatFunctor { void operator()(const platform::CPUDeviceContext& context, const std::vector& input, int axis, framework::Tensor* output) { - // TODO(zcd): Add input data validity checking - size_t num = input.size(); - - int64_t rows = 1; - auto dim_0 = input[0].dims(); - for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; - } - int64_t out_rows = rows, out_cols = 0; - - std::vector input_cols(input.size()); - for (size_t i = 0; i < num; ++i) { - int64_t t_cols = input[i].numel() / rows; - out_cols += t_cols; - input_cols[i] = t_cols; - } - auto cpu_place = context.GetPlace(); - - // computation - auto output_data = output->data(); - int64_t col_idx = 0; - for (size_t j = 0; j < num; ++j) { - int64_t col_len = input_cols[j]; - auto input_data = input[j].data(); - for (int64_t k = 0; k < out_rows; ++k) { - memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place, - input_data + k * col_len, sizeof(T) * col_len); - } - col_idx += col_len; - } + std::vector pt_input{input.begin(), input.end()}; + pten::ConcatImpl(context, pt_input, axis, + output); } }; @@ -88,46 +63,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, const int axis, std::vector* outputs) { - // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 - // tensors of shape [0,1,4] - if (input.numel() == 0) { - return; - } - - // TODO(zcd): Add input data validity checking - size_t num = outputs->size(); - - int input_rows = 1; - auto dim_0 = ref_inputs[0]->dims(); - for (int i = 0; i < axis; ++i) { - input_rows *= dim_0[i]; - } - - int input_cols = 0; - - std::vector output_cols(outputs->size()); - for (size_t i = 0; i < num; ++i) { - int t_cols = ref_inputs[i]->numel() / input_rows; - input_cols += t_cols; - output_cols[i] = t_cols; - } - auto cpu_place = context.GetPlace(); - - // computation - for (int k = 0; k < input_rows; ++k) { - const T* src_ptr = input.data() + k * input_cols; - int col_idx = 0; - for (size_t j = 0; j < num; ++j) { - int col_len = output_cols[j]; - auto* out_tensor = outputs->at(j); - if (out_tensor != nullptr) { - T* dst_ptr = out_tensor->data() + k * col_len; - memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, - sizeof(T) * col_len); - } - col_idx += col_len; - } - } + std::vector pt_ref_inputs{ref_inputs.begin(), + ref_inputs.end()}; + std::vector pt_outputs{outputs->begin(), + outputs->end()}; + pten::SplitImpl( + context, input, pt_ref_inputs, axis, &pt_outputs); } }; diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index 5b99a62d78d2ace9c74d282346b6da5130ed18c7..4357a86b7e65dc2c06bba499f22d1bca08703e86 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -12,218 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "gflags/gflags.h" -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/platform/float16.h" +#include "paddle/pten/kernels/gpu/concat_and_split.h" namespace paddle { namespace operators { namespace math { -template -__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols, - int col_size, const int64_t output_rows, - const int64_t output_cols, T* output) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int curr_segment = 0; - int curr_offset = input_cols[0]; - for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { - int curr_col_offset = input_cols[curr_segment + 1]; - while (curr_col_offset <= tid_x) { - curr_offset = curr_col_offset; - ++curr_segment; - curr_col_offset = input_cols[curr_segment + 1]; - } - - int local_col = tid_x - curr_offset; - int segment_width = curr_col_offset - curr_offset; - - const T* input_ptr = inputs[curr_segment]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) - output[tid_y * output_cols + tid_x] = - input_ptr[tid_y * segment_width + local_col]; - } -} - -template -__device__ void ConcatKernelDetail(const T** inputs_data, - const int fixed_in_col, const int out_rows, - const int out_cols, T* output_data) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) { - int split = tid_x * 1.0 / fixed_in_col; - int in_offset = tid_x - split * fixed_in_col; - const T* input_ptr = inputs_data[split]; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) { - output_data[tid_y * out_cols + tid_x] = - input_ptr[tid_y * fixed_in_col + in_offset]; - } - } -} - -template -__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, - const int64_t fixed_in_col, const int64_t out_rows, - const int64_t out_cols, T* output_data) { - const T* inputs_data[2]; - inputs_data[0] = input_addr0; - inputs_data[1] = input_addr1; - ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, - output_data); -} - -template -__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, - const T* input_addr2, const int64_t fixed_in_col, - const int64_t out_rows, const int64_t out_cols, - T* output_data) { - const T* inputs_data[3]; - inputs_data[0] = input_addr0; - inputs_data[1] = input_addr1; - inputs_data[2] = input_addr2; - ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, - output_data); -} - -template -__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, - const T* input_addr2, const T* input_addr3, - const int64_t fixed_in_col, const int64_t out_rows, - const int64_t out_cols, T* output_data) { - const T* inputs_data[4]; - inputs_data[0] = input_addr0; - inputs_data[1] = input_addr1; - inputs_data[2] = input_addr2; - inputs_data[3] = input_addr3; - ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, - output_data); -} - -template -__global__ void ConcatKernel(const T** inputs_data, const int in_num, - const int64_t fixed_in_col, const int64_t out_rows, - const int64_t out_cols, T* output_data) { - ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, - output_data); -} - -template -__global__ void SplitKernel(const T* input_data, const int64_t in_row, - const int64_t in_col, const int64_t* out_cols, - int out_cols_size, T** outputs_data) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int curr_segment = 0; - int curr_offset = out_cols[0]; - for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { - int curr_col_offset = out_cols[curr_segment + 1]; - while (curr_col_offset <= tid_x) { - curr_offset = curr_col_offset; - ++curr_segment; - curr_col_offset = out_cols[curr_segment + 1]; - } - - int local_col = tid_x - curr_offset; - int segment_width = curr_col_offset - curr_offset; - T* output_ptr = outputs_data[curr_segment]; - if (output_ptr != nullptr) { - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * segment_width + local_col] = - input_data[tid_y * in_col + tid_x]; - } - } -} - -template -__device__ void SplitKernelDetail(const T* input_data, const int in_row, - const int in_col, const int fixed_out_col, - T** outputs_data) { - int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { - int split = tid_x / fixed_out_col; - int in_offset = tid_x - split * fixed_out_col; - T* output_ptr = outputs_data[split]; - if (output_ptr != nullptr) { - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) - output_ptr[tid_y * fixed_out_col + in_offset] = - input_data[tid_y * in_col + tid_x]; - } - } -} - -template -__global__ void SplitKernel(const T* input_data, const int64_t in_row, - const int64_t in_col, const int64_t fixed_out_col, - T** outputs_data) { - SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); -} - -template -__global__ void SplitKernel(const T* input_data, const int64_t in_row, - const int64_t in_col, const int64_t fixed_out_col, - T* outputs_addr0, T* outputs_addr1) { - T* outputs_data[2]; - outputs_data[0] = outputs_addr0; - outputs_data[1] = outputs_addr1; - SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); -} - -template -__global__ void SplitKernel(const T* input_data, const int64_t in_row, - const int64_t in_col, const int64_t fixed_out_col, - T* outputs_addr0, T* outputs_addr1, - T* outputs_addr2) { - T* outputs_data[3]; - outputs_data[0] = outputs_addr0; - outputs_data[1] = outputs_addr1; - outputs_data[2] = outputs_addr2; - SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); -} - -template -__global__ void SplitKernel(const T* input_data, const int64_t in_row, - const int64_t in_col, const int64_t fixed_out_col, - T* outputs_addr0, T* outputs_addr1, - T* outputs_addr2, T* outputs_addr3) { - T* outputs_data[4]; - outputs_data[0] = outputs_addr0; - outputs_data[1] = outputs_addr1; - outputs_data[2] = outputs_addr2; - outputs_data[3] = outputs_addr3; - SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); -} - -static inline void GetBlockDims(const platform::CUDADeviceContext& context, - int64_t num_rows, int64_t num_cols, - dim3* block_dims, dim3* grid_dims) { - // Set the thread block and grid according to CurrentDeviceId - const int kThreadsPerBlock = 1024; - int block_cols = kThreadsPerBlock; - if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32. - block_cols = ((num_cols + 31) >> 5) << 5; - } - int block_rows = kThreadsPerBlock / block_cols; - *block_dims = dim3(block_cols, block_rows, 1); - - int max_threads = context.GetMaxPhysicalThreadCount(); - int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1); - - int grid_cols = - std::min((num_cols + block_cols - 1) / block_cols, max_blocks); - int grid_rows = std::min(max_blocks / grid_cols, - std::max(num_rows / block_rows, (int64_t)1)); - *grid_dims = dim3(grid_cols, grid_rows, 1); -} - /* * All tensors' dimension should be the same and the values of * each dimension must be the same, except the axis dimension. @@ -234,112 +29,10 @@ class ConcatFunctor { void operator()(const platform::CUDADeviceContext& context, const std::vector& input, int axis, framework::Tensor* output) { - // TODO(zcd): Add input data validity checking - int in_num = input.size(); - int64_t in_row = 1; - auto dim_0 = input[0].dims(); - for (int i = 0; i < axis; ++i) { - in_row *= dim_0[i]; - } - int64_t in_col = input[0].numel() / in_row; - int64_t out_row = in_row, out_col = 0; - - int inputs_col_num = in_num + 1; - std::vector inputs_data_vec(in_num); - std::vector inputs_col_vec(inputs_col_num); - const T** inputs_data = inputs_data_vec.data(); - int64_t* inputs_col = inputs_col_vec.data(); - -// There are some differences between hip runtime and NV runtime. -// In NV, when the pageable memory data less than 64K is transferred from -// hosttodevice, it will be automatically asynchronous. -// However, only pinned memory in hip can copy asynchronously -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device -// 3.2.6.1. Concurrent Execution between Host and Device -// Memory copies from host to device of a memory block of 64 KB or less -#ifdef PADDLE_WITH_HIP - memory::AllocationPtr data_alloc, col_alloc; - data_alloc = - memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*)); - inputs_data = reinterpret_cast(data_alloc->ptr()); - col_alloc = memory::Alloc(platform::CUDAPinnedPlace(), - inputs_col_num * sizeof(int)); - inputs_col = reinterpret_cast(col_alloc->ptr()); -#endif - - inputs_col[0] = 0; - bool has_same_shape = true; - for (int i = 0; i < in_num; ++i) { - int64_t t_cols = input[i].numel() / in_row; - if (has_same_shape) { - if (t_cols != in_col) has_same_shape = false; - } - out_col += t_cols; - inputs_col[i + 1] = out_col; - inputs_data[i] = input[i].data(); - } - - dim3 block_dims; - dim3 grid_dims; - GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims); + std::vector pt_input{input.begin(), input.end()}; - memory::allocation::AllocationPtr tmp_dev_ins_data; - const T** dev_ins_data = nullptr; - if (!has_same_shape || in_num < 2 || in_num > 4) { - tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*)); - auto* restored = - platform::RestoreHostMemIfCapturingCUDAGraph(inputs_data, in_num); - memory::Copy(context.GetPlace(), tmp_dev_ins_data->ptr(), - platform::CPUPlace(), restored, in_num * sizeof(T*), - context.stream()); - dev_ins_data = reinterpret_cast(tmp_dev_ins_data->ptr()); - } - - if (has_same_shape) { - if (in_num == 2) { - ConcatKernel<<>>( - inputs_data[0], inputs_data[1], in_col, out_row, out_col, - output->data()); - } else if (in_num == 3) { - ConcatKernel<<>>( - inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row, - out_col, output->data()); - } else if (in_num == 4) { - ConcatKernel<<>>( - inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3], - in_col, out_row, out_col, output->data()); - } else { - ConcatKernel<<>>( - dev_ins_data, in_num, in_col, out_row, out_col, output->data()); - } - } else { - auto tmp_dev_ins_col_data = - memory::Alloc(context, inputs_col_num * sizeof(int64_t)); - - auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph( - inputs_col, inputs_col_num); - memory::Copy(context.GetPlace(), tmp_dev_ins_col_data->ptr(), - platform::CPUPlace(), restored, - inputs_col_num * sizeof(int64_t), context.stream()); - int64_t* dev_ins_col_data = - static_cast(tmp_dev_ins_col_data->ptr()); - - ConcatKernel<<>>( - dev_ins_data, dev_ins_col_data, static_cast(inputs_col_num), - out_row, out_col, output->data()); - } - -#ifdef PADDLE_WITH_HIP - // Prevent the pinned memory value from being covered and release the memory - // after the launch kernel of the stream is executed (reapply pinned memory - // next time) - auto* data_alloc_released = data_alloc.release(); - auto* col_alloc_released = col_alloc.release(); - context.AddStreamCallback([data_alloc_released, col_alloc_released] { - memory::allocation::Allocator::AllocationDeleter(data_alloc_released); - memory::allocation::Allocator::AllocationDeleter(col_alloc_released); - }); -#endif + pten::ConcatImpl(context, pt_input, axis, + output); } }; @@ -355,120 +48,12 @@ class SplitFunctor { const framework::Tensor& input, const std::vector& ref_inputs, int axis, std::vector* outputs) { - // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 - // tensors of shape [0,1,4] - if (input.numel() == 0) { - return; - } - - // TODO(zcd): Add input data validity checking - int o_num = outputs->size(); - int64_t out_row = 1; - auto dim_0 = ref_inputs[0]->dims(); - for (int i = 0; i < axis; ++i) { - out_row *= dim_0[i]; - } - - int64_t out0_col = ref_inputs[0]->numel() / out_row; - int64_t in_col = 0, in_row = out_row; - bool has_same_shape = true; - - int outputs_cols_num = o_num + 1; - std::vector outputs_data_vec(o_num); - std::vector outputs_cols_vec(outputs_cols_num); - T** outputs_data = outputs_data_vec.data(); - int64_t* outputs_cols = outputs_cols_vec.data(); - -// There are some differences between hip runtime and NV runtime. -// In NV, when the pageable memory data less than 64K is transferred from -// hosttodevice, it will be automatically asynchronous. -// However, only pinned memory in hip can copy asynchronously -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device -// 3.2.6.1. Concurrent Execution between Host and Device -// Memory copies from host to device of a memory block of 64 KB or less -#ifdef PADDLE_WITH_HIP - memory::AllocationPtr data_alloc, cols_alloc; - data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*)); - outputs_data = reinterpret_cast(data_alloc->ptr()); - cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(), - (outputs_cols_num) * sizeof(int64_t)); - outputs_cols = reinterpret_cast(cols_alloc->ptr()); -#endif - - outputs_cols[0] = 0; - for (int i = 0; i < o_num; ++i) { - int64_t t_col = ref_inputs.at(i)->numel() / out_row; - if (has_same_shape) { - if (t_col != out0_col) has_same_shape = false; - } - in_col += t_col; - outputs_cols[i + 1] = in_col; - if (outputs->at(i) != nullptr) { - outputs_data[i] = outputs->at(i)->data(); - } else { - outputs_data[i] = nullptr; - } - } - - dim3 block_dims; - dim3 grid_dims; - GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims); - - memory::allocation::AllocationPtr tmp_dev_outs_data; - T** dev_out_gpu_data = nullptr; - if (!has_same_shape || o_num < 2 || o_num > 4) { - tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*)); - auto* restored = - platform::RestoreHostMemIfCapturingCUDAGraph(outputs_data, o_num); - memory::Copy(context.GetPlace(), tmp_dev_outs_data->ptr(), - platform::CPUPlace(), restored, o_num * sizeof(T*), - context.stream()); - dev_out_gpu_data = reinterpret_cast(tmp_dev_outs_data->ptr()); - } - - if (has_same_shape) { - if (o_num == 2) { - SplitKernel<<>>( - input.data(), in_row, in_col, out0_col, outputs_data[0], - outputs_data[1]); - } else if (o_num == 3) { - SplitKernel<<>>( - input.data(), in_row, in_col, out0_col, outputs_data[0], - outputs_data[1], outputs_data[2]); - } else if (o_num == 4) { - SplitKernel<<>>( - input.data(), in_row, in_col, out0_col, outputs_data[0], - outputs_data[1], outputs_data[2], outputs_data[3]); - } else { - SplitKernel<<>>( - input.data(), in_row, in_col, out0_col, dev_out_gpu_data); - } - } else { - auto tmp_dev_ins_col_data = - memory::Alloc(context, outputs_cols_num * sizeof(int64_t)); - auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph( - outputs_cols, outputs_cols_num); - memory::Copy(context.GetPlace(), tmp_dev_ins_col_data->ptr(), - platform::CPUPlace(), restored, - outputs_cols_num * sizeof(int64_t), context.stream()); - int64_t* dev_outs_col_data = - reinterpret_cast(tmp_dev_ins_col_data->ptr()); - - SplitKernel<<>>( - input.data(), in_row, in_col, dev_outs_col_data, - static_cast(outputs_cols_num), dev_out_gpu_data); - } -#ifdef PADDLE_WITH_HIP - // Prevent the pinned memory value from being covered and release the memory - // after the launch kernel of the stream is executed (reapply pinned memory - // next time) - auto* data_alloc_released = data_alloc.release(); - auto* cols_alloc_released = cols_alloc.release(); - context.AddStreamCallback([data_alloc_released, cols_alloc_released] { - memory::allocation::Allocator::AllocationDeleter(data_alloc_released); - memory::allocation::Allocator::AllocationDeleter(cols_alloc_released); - }); -#endif + std::vector pt_ref_inputs{ref_inputs.begin(), + ref_inputs.end()}; + std::vector pt_outputs{outputs->begin(), + outputs->end()}; + pten::SplitImpl( + context, input, pt_ref_inputs, axis, &pt_outputs); } }; diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 653283b604f0723cbb1794dc25f46b617ae137ac..5ebaefcf808c3c26bd7aa78ad720450eda2acf61 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" +#include "paddle/pten/core/lod_utils.h" + namespace pten { class DenseTensor; } // namespace pten @@ -122,7 +124,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { input->lod(), *in_idx, (*in_idx) + 1, 0); auto &lod_length = lod_and_offset.first; - framework::AppendLoD(out_lod, lod_length); + pten::AppendLoD(out_lod, lod_length); size_t start_offset = lod_and_offset.second.first; size_t end_offset = lod_and_offset.second.second; diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc index f39a1c0a39d6eab789906987c9f7dea92a73d2e8..493073fadc2bd19de7044db880aee46a429e5340 100644 --- a/paddle/fluid/operators/shrink_rnn_memory_op.cc +++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/fluid/operators/array_operator.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/pten/core/lod_utils.h" + namespace paddle { namespace framework { class OpDesc; @@ -73,7 +75,7 @@ class ShrinkRNNMemoryOp : public ArrayOp { dst_num_rows, 0); height = lod_offset.second.second; auto out_lod = out_tensor.mutable_lod(); - framework::AppendLoD(out_lod, lod_offset.first); + pten::AppendLoD(out_lod, lod_offset.first); } if (dst_num_rows != 0) { diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index 9c22fa4797219f41978c6e07c2876d9cd9c15677..4cb2a292018f6ac770e385fade61b7936a11b738 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/pten/core/lod_utils.h" namespace pten { class DenseTensor; @@ -96,7 +97,7 @@ class SplitLoDTensorOp : public framework::OperatorBase { x_lod, start_idx, start_idx + 1, level); auto &lod_length = lod_and_offset.first; - framework::AppendLoD(lod, lod_length); + pten::AppendLoD(lod, lod_length); size_t start_offset = lod_and_offset.second.first; size_t end_offset = lod_and_offset.second.second; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index cdbfa11abec72df23ed5e0e7077b4910c9600731..454e3b524f5f14f3aa5b780eec2eac2305a1e1ed 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -43,7 +43,6 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/generate_pass.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/lod_rank_table.h" -#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/framework/op_info.h" @@ -75,6 +74,7 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/pybind/cuda_streams_py.h" +#include "paddle/pten/core/lod_utils.h" #ifndef PADDLE_ON_INFERENCE #include "paddle/fluid/pybind/eager.h" #endif @@ -1093,7 +1093,7 @@ PYBIND11_MODULE(core_noavx, m) { .def("recursive_sequence_lengths", [](framework::Tensor &self) -> std::vector> { // output the length-based lod info - LoD lod = ConvertToLengthBasedLoD(self.lod()); + LoD lod = pten::ConvertToLengthBasedLoD(self.lod()); std::vector> new_lod; new_lod.reserve(lod.size()); std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index 9b6e5d70cd8994c4aee86606c57804445dc96d39..cde5e719e316d9eb172b8bdcab0a3f0309149523 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -18,7 +18,7 @@ add_subdirectory(ops) add_subdirectory(tests) # make an unity target for compile deps -set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta) +set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils) get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) # keep this message for debug, remove it later if needless message(STATUS "All standard pten kernels: ${pten_kernels}") diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index 0b17415a6a98de623be97316b84ac50f6eddea03..e3929d59159c18b6c926fbe23fe336b8ce86b8cc 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -38,6 +38,11 @@ using cast_kernel = void (*)(const DeviceContext&, DataType, DenseTensor*); +using concat_kernel = void (*)(const DeviceContext&, + const std::vector&, + const Scalar&, + DenseTensor*); + using divide_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index 1420810007d1cb202a6a76a79a972e43918d8c07..2e94d508aec7df6ea85973118a272130a23db4d6 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -38,6 +38,11 @@ std::unique_ptr MakePtenDenseTensorBase( src.dims(), src.layout(), src.offset()}; + if (!src.IsInitialized()) { + return std::make_unique( + std::move(pten::make_intrusive(src.place())), + std::move(meta)); + } auto shared_storage = pten::make_intrusive(src.Holder()); return std::make_unique(std::move(shared_storage), std::move(meta)); @@ -247,7 +252,9 @@ std::unique_ptr MakePtenTensorBaseFromVar( if (variable.IsType()) { const auto& tensor = variable.Get(); - if (!platform::is_same_place(tensor.place(), expected_place)) { + + if (tensor.IsInitialized() && + !platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); return MakePtenDenseTensor(tmp_tensor); diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index eabc5a19babad95cd2f5f88c46c4c59078d3e156..d89b3c9fefb590c45bb6a3f611113b28da3e51aa 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -12,6 +12,7 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) +cc_library(lod_utils SRCS lod_utils.cc DEPS enforce mixed_vector) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h index 5559b348aa1c967712fae5ee60cdc48b2594227d..5dd2bf367b3b83fbef585239af6a11c552821398 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -92,7 +92,7 @@ class KernelContext { std::vector MoveInputsBetween(size_t start, size_t end) { std::vector v; for (size_t i = start; i < end; ++i) { - auto t = std::dynamic_pointer_cast(inputs_.at(i)); + auto t = static_cast(inputs_.at(i)); v.emplace_back(*t); inputs_.at(i) = nullptr; } diff --git a/paddle/pten/core/lod_utils.cc b/paddle/pten/core/lod_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad5ea6d39d39c9c9ef7dc2b514a3ffdbafc10964 --- /dev/null +++ b/paddle/pten/core/lod_utils.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/lod_utils.h" + +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +void AppendLoD(LoD *lod, const LoD &lod_length) { + PADDLE_ENFORCE( + lod->empty() || lod->size() == lod_length.size(), + paddle::platform::errors::InvalidArgument( + "The input LoD length should be equal to the appended LoD size, but " + "received input LoD length is %d, actual LoD size is %d.", + lod_length.size(), + lod->size())); + if (lod->empty()) { + for (size_t i = 0; i < lod_length.size(); ++i) { + lod->emplace_back(1, 0); // size = 1, value = 0; + } + *lod = LoD(lod_length.size(), std::vector({0})); + } + for (size_t i = 0; i < lod->size(); ++i) { + auto &level = (*lod)[i]; + for (size_t len : lod_length[i]) { + level.push_back(level.back() + len); + } + } +} + +LoD ConvertToLengthBasedLoD(const LoD &offset_lod) { + LoD length_lod; + length_lod.reserve(offset_lod.size()); + for (size_t lvl = 0; lvl < offset_lod.size(); ++lvl) { + std::vector level; + if (offset_lod[lvl].size() > 0) { + level.reserve(offset_lod[lvl].size() - 1); + } + for (size_t idx = 0; idx < offset_lod[lvl].size() - 1; ++idx) { + level.push_back(offset_lod[lvl][idx + 1] - offset_lod[lvl][idx]); + } + length_lod.push_back(level); + } + return length_lod; +} + +} // namespace pten diff --git a/paddle/pten/core/lod_utils.h b/paddle/pten/core/lod_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4c2547a43c0270d38c8fb56a96476dfc0eb7b71a --- /dev/null +++ b/paddle/pten/core/lod_utils.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/framework/mixed_vector.h" + +namespace pten { +using LoD = std::vector>; + +void AppendLoD(LoD* lod, const LoD& lod_length); + +/* + * Convert between length-based LoD and offset-based LoD. + * The implementation of LoDTensor class use offset-based LoD. + * However, we want to expose the more user-friendly length-based + * LoD to the Python side instead. + * + * Example: + * If offset_lod = [[0, 2, 3],[0, 3, 5, 9]] + * then length_lod = [[2, 1], [3, 2, 4]] + */ +LoD ConvertToLengthBasedLoD(const LoD& offset_lod); + +} // namespace pten diff --git a/paddle/pten/infermeta/multiary.cc b/paddle/pten/infermeta/multiary.cc index 5dbf3d58a1952576ab9dc5eee5073a7969499029..ecd0396a28688f5473e9ba0144e53a944ec135d7 100644 --- a/paddle/pten/infermeta/multiary.cc +++ b/paddle/pten/infermeta/multiary.cc @@ -14,4 +14,43 @@ limitations under the License. */ #include "paddle/pten/infermeta/multiary.h" -namespace pten {} // namespace pten +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/kernels/funcs/concat_funcs.h" +namespace pten { + +DenseTensorMeta ConcatInferMeta(const std::vector& x_meta, + const Scalar& axis_scalar, + bool is_runtime) { + PADDLE_ENFORCE_GE(x_meta.size(), + 0, + paddle::platform::errors::InvalidArgument( + "The size of input meta vector should be greater" + "than 0.")); + + int axis = axis_scalar.to(); + // 1. calculate axis + int rank = x_meta[0].dims.size(); + PADDLE_ENFORCE_EQ( + axis >= -rank && axis < rank, + true, + paddle::platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, + rank, + axis)); + if (axis < 0) { + axis = axis + rank; + } + + // 2. calculate out dims + std::vector x_dims; + for (auto meta : x_meta) { + x_dims.push_back(meta.dims); + } + pten::DDim out_dim = + pten::funcs::ComputeAndCheckShape(is_runtime, x_dims, axis); + + return {x_meta[0].dtype, out_dim, x_meta[0].layout}; +} + +} // namespace pten diff --git a/paddle/pten/infermeta/multiary.h b/paddle/pten/infermeta/multiary.h index 6aa15159630bc7ffee9f822b146ef07e27f1795a..f8d5468e50d47f787bdaba4deaecf8c53503926f 100644 --- a/paddle/pten/infermeta/multiary.h +++ b/paddle/pten/infermeta/multiary.h @@ -14,4 +14,13 @@ limitations under the License. */ #pragma once -namespace pten {} // namespace pten +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/tensor_meta.h" +namespace pten { + +// TODO(chentianyu03) use std::vector as InferMeta inputs +DenseTensorMeta ConcatInferMeta(const std::vector& x_meta, + const Scalar& axis_scalar, + bool is_runtime); + +} // namespace pten diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt index 45724e5d22abde9e218fd0a9decf31631979892a..76e112808892d79fa4143aebc715c4ff8e20c0c4 100644 --- a/paddle/pten/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -24,7 +24,7 @@ endif() # pten depends all pten kernel targets set_property(GLOBAL PROPERTY PTEN_KERNELS "") -set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils) +set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils) diff --git a/paddle/pten/kernels/concat_kernel.h b/paddle/pten/kernels/concat_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..310b9ba8c0c4c50521b2dcea2de62171557f83da --- /dev/null +++ b/paddle/pten/kernels/concat_kernel.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/infermeta/multiary.h" +#include "paddle/pten/kernels/empty_kernel.h" +namespace pten { + +template +void ConcatKernel(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis, + DenseTensor* out); + +template +DenseTensor Concat(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis) { + std::vector x_meta; + for (auto t : x) { + x_meta.push_back(t.meta()); + } + + auto out_meta = ConcatInferMeta(x_meta, axis.to(), true); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + ConcatKernel(dev_ctx, x, axis, &dense_out); + return dense_out; +} +} // namespace pten diff --git a/paddle/pten/kernels/cpu/concat_and_split.h b/paddle/pten/kernels/cpu/concat_and_split.h new file mode 100644 index 0000000000000000000000000000000000000000..664ec6f66fc99e34ed5cd9ae903dcbbc3f1ec3a8 --- /dev/null +++ b/paddle/pten/kernels/cpu/concat_and_split.h @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +/* + * \brief Concatenate the input tensors along the dimension axis. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input[0] = [[1,2],[3,4]] + * Input[1] = [[5,6]] + * axis = 0 + * + * Output = [[1,2], + * [3,4], + * [5,6]] + */ + +template +void ConcatImpl(const Context& context, + const std::vector& input, + int axis, + DenseTensor* output) { + // TODO(zcd): Add input data validity checking + size_t num = input.size(); + + int64_t rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int64_t out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (size_t i = 0; i < num; ++i) { + int64_t t_cols = input[i].numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + auto cpu_place = context.GetPlace(); + + // computation + auto output_data = output->data(); + int64_t col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int64_t col_len = input_cols[j]; + auto input_data = input[j].data(); + for (int64_t k = 0; k < out_rows; ++k) { + paddle::memory::Copy(cpu_place, + output_data + k * out_cols + col_idx, + cpu_place, + input_data + k * col_len, + sizeof(T) * col_len); + } + col_idx += col_len; + } +} + +/* + * \brief Split the input tensors along the dimension axis into outputs. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input = [[1,2], + * [3,4], + * [5,6]] + * axis = 0 + * + * Output[0] = [[1,2],[3,4]] + * Output[1] = [[5,6]] + */ +template +void SplitImpl(const Context& context, + const DenseTensor& input, + const std::vector& ref_inputs, + const int axis, + std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + + // TODO(zcd): Add input data validity checking + size_t num = outputs->size(); + + int input_rows = 1; + auto dim_0 = ref_inputs[0]->dims(); + for (int i = 0; i < axis; ++i) { + input_rows *= dim_0[i]; + } + + int input_cols = 0; + + std::vector output_cols(outputs->size()); + for (size_t i = 0; i < num; ++i) { + int t_cols = ref_inputs[i]->numel() / input_rows; + input_cols += t_cols; + output_cols[i] = t_cols; + } + auto cpu_place = context.GetPlace(); + + // computation + for (int k = 0; k < input_rows; ++k) { + const T* src_ptr = input.data() + k * input_cols; + int col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int col_len = output_cols[j]; + auto* out_tensor = outputs->at(j); + if (out_tensor != nullptr) { + T* dst_ptr = out_tensor->data() + k * col_len; + paddle::memory::Copy(cpu_place, + dst_ptr, + cpu_place, + src_ptr + col_idx, + sizeof(T) * col_len); + } + col_idx += col_len; + } + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/cpu/concat_kernel.cc b/paddle/pten/kernels/cpu/concat_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb59c9c6005ff7b5d9acd1480c7145225ea07378 --- /dev/null +++ b/paddle/pten/kernels/cpu/concat_kernel.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/concat_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/core/lod_utils.h" +#include "paddle/pten/kernels/cpu/concat_and_split.h" +#include "paddle/pten/kernels/funcs/concat_funcs.h" + +namespace pten { + +template +void ConcatKernel(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis_scalar, + DenseTensor* out) { + int64_t axis = axis_scalar.to(); + + axis = pten::funcs::ComputeAxis(axis, x[0].dims().size()); + + std::vector x_dims; + for (size_t i = 0; i < x.size(); ++i) { + x_dims.push_back(x[i].dims()); + } + + pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); + out->Resize(out_dims); + out->mutable_data(); + + // If axis is 0, the lod of the output is not the same as inputs. + if (axis == 0 && x[0].lod().size() > 0) { + size_t lod_size_0 = x[0].lod().size(); + size_t lod_size = lod_size_0; + for (size_t i = 1; i < x.size(); ++i) { + if (x[i].lod().size() > 0) { + PADDLE_ENFORCE_EQ( + x[i].lod().size(), + lod_size_0, + paddle::platform::errors::Unimplemented( + "The lod level of all input LoDTensors should be same. " + "Maybe different lod level of input LoDTensors can concat," + "it is not supported currently. The lod level of %dth input " + "is %d and first input is %d.", + i, + x[i].lod().size(), + lod_size_0)); + } else { + lod_size = 0; + break; + } + } + if (lod_size) { + auto* out_lod = out->mutable_lod(); + for (size_t i = 1; i < x.size(); ++i) { + auto in_lod = pten::ConvertToLengthBasedLoD(x[i].lod()); + pten::AppendLoD(out_lod, in_lod); + } + } + } + + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && x.size() < 10) { + size_t output_offset = 0; + for (auto& in : x) { + if (in.numel() == 0UL) { + continue; + } + auto in_stride = paddle::framework::stride_numel(in.dims()); + auto out_stride = paddle::framework::stride_numel(out->dims()); + paddle::operators::StridedNumelCopyWithAxis( + dev_ctx, + axis, + out->data() + output_offset, + out_stride, + in.data(), + in_stride, + in_stride[axis]); + output_offset += in_stride[axis]; + } + } else { + std::vector inputs; + for (size_t j = 0; j < x.size(); ++j) { + if (x[j].numel() > 0) { + inputs.push_back(x[j]); + } else { + continue; + } + } + ConcatImpl(dev_ctx, inputs, axis, out); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(concat, + CPU, + ALL_LAYOUT, + pten::ConcatKernel, + float, + double, + bool, + int64_t, + int, + uint8_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/funcs/concat_funcs.h b/paddle/pten/kernels/funcs/concat_funcs.h new file mode 100644 index 0000000000000000000000000000000000000000..8455b8096922c3c6c78a20ac7cf05895c668f0ca --- /dev/null +++ b/paddle/pten/kernels/funcs/concat_funcs.h @@ -0,0 +1,95 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" +namespace pten { +namespace funcs { + +static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { + PADDLE_ENFORCE_EQ( + axis >= -rank && axis < rank, + true, + paddle::platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, + rank, + axis)); + if (axis < 0) { + axis = axis + rank; + } + return axis > 0 ? axis : 0; +} + +static inline pten::DDim ComputeAndCheckShape( + const bool is_runtime, + const std::vector& inputs_dims, + const size_t axis) { + const size_t n = inputs_dims.size(); + auto out_dims = inputs_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + PADDLE_ENFORCE_EQ(inputs_dims[i].size(), + out_dims.size(), + paddle::platform::errors::InvalidArgument( + "The shape of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + i, + inputs_dims[0], + i, + inputs_dims[i])); + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + if (is_runtime) { + out_dims[axis] += inputs_dims[i][j]; + } else { + if (inputs_dims[i][j] == -1 || out_dims[j] == -1) { + out_dims[axis] = -1; + } else { + out_dims[axis] += inputs_dims[i][j]; + } + } + } else { + bool check_shape = + is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0); + if (check_shape) { + // check all shape in run time + PADDLE_ENFORCE_EQ(inputs_dims[0][j], + inputs_dims[i][j], + paddle::platform::errors::InvalidArgument( + "The %d-th dimension of input[0] and input[%d] " + "is expected to be equal." + "But received input[0]'s shape = " + "[%s], input[%d]'s shape = [%s].", + j, + i, + inputs_dims[0], + i, + inputs_dims[i])); + } + if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) { + out_dims[j] = inputs_dims[i][j]; + } + } + } + } + return out_dims; +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/gpu/concat_and_split.h b/paddle/pten/kernels/gpu/concat_and_split.h new file mode 100644 index 0000000000000000000000000000000000000000..66b21b5f5135166a24d1194dce1086417ba7bfc0 --- /dev/null +++ b/paddle/pten/kernels/gpu/concat_and_split.h @@ -0,0 +1,569 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" + +namespace pten { + +template +__global__ void ConcatKernel_(const T** inputs, + const int64_t* input_cols, + int col_size, + const int64_t output_rows, + const int64_t output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int curr_segment = 0; + int curr_offset = input_cols[0]; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + int curr_col_offset = input_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + curr_col_offset = input_cols[curr_segment + 1]; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + + const T* input_ptr = inputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * segment_width + local_col]; + } +} + +template +__device__ void ConcatKernelDetail(const T** inputs_data, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * 1.0 / fixed_in_col; + int in_offset = tid_x - split * fixed_in_col; + const T* input_ptr = inputs_data[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) { + output_data[tid_y * out_cols + tid_x] = + input_ptr[tid_y * fixed_in_col + in_offset]; + } + } +} + +template +__global__ void ConcatKernel_(const T* input_addr0, + const T* input_addr1, + const int64_t fixed_in_col, + const int64_t out_rows, + const int64_t out_cols, + T* output_data) { + const T* inputs_data[2]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel_(const T* input_addr0, + const T* input_addr1, + const T* input_addr2, + const int64_t fixed_in_col, + const int64_t out_rows, + const int64_t out_cols, + T* output_data) { + const T* inputs_data[3]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel_(const T* input_addr0, + const T* input_addr1, + const T* input_addr2, + const T* input_addr3, + const int64_t fixed_in_col, + const int64_t out_rows, + const int64_t out_cols, + T* output_data) { + const T* inputs_data[4]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + inputs_data[3] = input_addr3; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel_(const T** inputs_data, + const int in_num, + const int64_t fixed_in_col, + const int64_t out_rows, + const int64_t out_cols, + T* output_data) { + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void SplitKernel(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t* out_cols, + int out_cols_size, + T** outputs_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int curr_segment = 0; + int curr_offset = out_cols[0]; + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { + int curr_col_offset = out_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + curr_col_offset = out_cols[curr_segment + 1]; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* output_ptr = outputs_data[curr_segment]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input_data[tid_y * in_col + tid_x]; + } + } +} + +template +__device__ void SplitKernelDetail(const T* input_data, + const int in_row, + const int in_col, + const int fixed_out_col, + T** outputs_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x / fixed_out_col; + int in_offset = tid_x - split * fixed_out_col; + T* output_ptr = outputs_data[split]; + if (output_ptr != nullptr) { + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * fixed_out_col + in_offset] = + input_data[tid_y * in_col + tid_x]; + } + } +} + +template +__global__ void SplitKernel(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T** outputs_data) { + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + +template +__global__ void SplitKernel(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1) { + T* outputs_data[2]; + outputs_data[0] = outputs_addr0; + outputs_data[1] = outputs_addr1; + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + +template +__global__ void SplitKernel(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2) { + T* outputs_data[3]; + outputs_data[0] = outputs_addr0; + outputs_data[1] = outputs_addr1; + outputs_data[2] = outputs_addr2; + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + +template +__global__ void SplitKernel(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2, + T* outputs_addr3) { + T* outputs_data[4]; + outputs_data[0] = outputs_addr0; + outputs_data[1] = outputs_addr1; + outputs_data[2] = outputs_addr2; + outputs_data[3] = outputs_addr3; + SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); +} + +static inline void GetBlockDims( + const paddle::platform::CUDADeviceContext& context, + int64_t num_rows, + int64_t num_cols, + dim3* block_dims, + dim3* grid_dims) { + // Set the thread block and grid according to CurrentDeviceId + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((num_cols + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + *block_dims = dim3(block_cols, block_rows, 1); + + int max_threads = context.GetMaxPhysicalThreadCount(); + int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((num_cols + block_cols - 1) / block_cols, max_blocks); + int grid_rows = std::min(max_blocks / grid_cols, + std::max(num_rows / block_rows, (int64_t)1)); + *grid_dims = dim3(grid_cols, grid_rows, 1); +} + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +void ConcatImpl(const Context& context, + const std::vector& input, + int axis, + pten::DenseTensor* output) { + // TODO(zcd): Add input data validity checking + int in_num = input.size(); + int64_t in_row = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + in_row *= dim_0[i]; + } + int64_t in_col = input[0].numel() / in_row; + int64_t out_row = in_row, out_col = 0; + + int inputs_col_num = in_num + 1; + std::vector inputs_data_vec(in_num); + std::vector inputs_col_vec(inputs_col_num); + const T** inputs_data = inputs_data_vec.data(); + int64_t* inputs_col = inputs_col_vec.data(); + +// There are some differences between hip runtime and NV runtime. +// In NV, when the pageable memory data less than 64K is transferred from +// hosttodevice, it will be automatically asynchronous. +// However, only pinned memory in hip can copy asynchronously +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device +// 3.2.6.1. Concurrent Execution between Host and Device +// Memory copies from host to device of a memory block of 64 KB or less +#ifdef PADDLE_WITH_HIP + paddle::memory::AllocationPtr data_alloc, col_alloc; + // TODO(chentianyu03): try to find a method to remove the Alloc function + data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), + in_num * sizeof(T*)); + inputs_data = reinterpret_cast(data_alloc->ptr()); + // TODO(chentianyu03): try to find a method to remove the Alloc function + col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), + inputs_col_num * sizeof(int)); + inputs_col = reinterpret_cast(col_alloc->ptr()); +#endif + + inputs_col[0] = 0; + bool has_same_shape = true; + for (int i = 0; i < in_num; ++i) { + int64_t t_cols = input[i].numel() / in_row; + if (has_same_shape) { + if (t_cols != in_col) has_same_shape = false; + } + out_col += t_cols; + inputs_col[i + 1] = out_col; + inputs_data[i] = input[i].data(); + } + + dim3 block_dims; + dim3 grid_dims; + GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims); + + paddle::memory::allocation::AllocationPtr tmp_dev_ins_data; + const T** dev_ins_data = nullptr; + if (!has_same_shape || in_num < 2 || in_num > 4) { + tmp_dev_ins_data = paddle::memory::Alloc(context, in_num * sizeof(T*)); + auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph( + inputs_data, in_num); + paddle::memory::Copy(context.GetPlace(), + tmp_dev_ins_data->ptr(), + paddle::platform::CPUPlace(), + restored, + in_num * sizeof(T*), + context.stream()); + dev_ins_data = reinterpret_cast(tmp_dev_ins_data->ptr()); + } + + if (has_same_shape) { + if (in_num == 2) { + ConcatKernel_<<>>( + inputs_data[0], + inputs_data[1], + in_col, + out_row, + out_col, + output->data()); + } else if (in_num == 3) { + ConcatKernel_<<>>( + inputs_data[0], + inputs_data[1], + inputs_data[2], + in_col, + out_row, + out_col, + output->data()); + } else if (in_num == 4) { + ConcatKernel_<<>>( + inputs_data[0], + inputs_data[1], + inputs_data[2], + inputs_data[3], + in_col, + out_row, + out_col, + output->data()); + } else { + ConcatKernel_<<>>( + dev_ins_data, in_num, in_col, out_row, out_col, output->data()); + } + } else { + auto tmp_dev_ins_col_data = + paddle::memory::Alloc(context, inputs_col_num * sizeof(int64_t)); + + auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph( + inputs_col, inputs_col_num); + paddle::memory::Copy(context.GetPlace(), + tmp_dev_ins_col_data->ptr(), + paddle::platform::CPUPlace(), + restored, + inputs_col_num * sizeof(int64_t), + context.stream()); + int64_t* dev_ins_col_data = + static_cast(tmp_dev_ins_col_data->ptr()); + + ConcatKernel_<<>>( + dev_ins_data, + dev_ins_col_data, + static_cast(inputs_col_num), + out_row, + out_col, + output->data()); + } + +#ifdef PADDLE_WITH_HIP + // Prevent the pinned memory value from being covered and release the memory + // after the launch kernel of the stream is executed (reapply pinned memory + // next time) + auto* data_alloc_released = data_alloc.release(); + auto* col_alloc_released = col_alloc.release(); + context.AddStreamCallback([data_alloc_released, col_alloc_released] { + paddle::memory::allocation::Allocator::AllocationDeleter( + data_alloc_released); + paddle::memory::allocation::Allocator::AllocationDeleter( + col_alloc_released); + }); +#endif +} + +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +void SplitImpl(const Context& context, + const pten::DenseTensor& input, + const std::vector& ref_inputs, + int axis, + std::vector* outputs) { + // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 + // tensors of shape [0,1,4] + if (input.numel() == 0) { + return; + } + + // TODO(zcd): Add input data validity checking + int o_num = outputs->size(); + int64_t out_row = 1; + auto dim_0 = ref_inputs[0]->dims(); + for (int i = 0; i < axis; ++i) { + out_row *= dim_0[i]; + } + + int64_t out0_col = ref_inputs[0]->numel() / out_row; + int64_t in_col = 0, in_row = out_row; + bool has_same_shape = true; + + int outputs_cols_num = o_num + 1; + std::vector outputs_data_vec(o_num); + std::vector outputs_cols_vec(outputs_cols_num); + T** outputs_data = outputs_data_vec.data(); + int64_t* outputs_cols = outputs_cols_vec.data(); + +// There are some differences between hip runtime and NV runtime. +// In NV, when the pageable memory data less than 64K is transferred from +// hosttodevice, it will be automatically asynchronous. +// However, only pinned memory in hip can copy asynchronously +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device +// 3.2.6.1. Concurrent Execution between Host and Device +// Memory copies from host to device of a memory block of 64 KB or less +#ifdef PADDLE_WITH_HIP + paddle::memory::AllocationPtr data_alloc, cols_alloc; + // TODO(chentianyu03): try to find a method to remove the Alloc function + data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), + o_num * sizeof(T*)); + outputs_data = reinterpret_cast(data_alloc->ptr()); + // TODO(chentianyu03): try to find a method to remove the Alloc function + cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), + (outputs_cols_num) * sizeof(int64_t)); + outputs_cols = reinterpret_cast(cols_alloc->ptr()); +#endif + + outputs_cols[0] = 0; + for (int i = 0; i < o_num; ++i) { + int64_t t_col = ref_inputs.at(i)->numel() / out_row; + if (has_same_shape) { + if (t_col != out0_col) has_same_shape = false; + } + in_col += t_col; + outputs_cols[i + 1] = in_col; + if (outputs->at(i) != nullptr) { + outputs_data[i] = outputs->at(i)->data(); + } else { + outputs_data[i] = nullptr; + } + } + + dim3 block_dims; + dim3 grid_dims; + GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims); + + paddle::memory::allocation::AllocationPtr tmp_dev_outs_data; + T** dev_out_gpu_data = nullptr; + if (!has_same_shape || o_num < 2 || o_num > 4) { + // TODO(chentianyu03): try to find a method to remove the Alloc function + tmp_dev_outs_data = paddle::memory::Alloc(context, o_num * sizeof(T*)); + auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph( + outputs_data, o_num); + paddle::memory::Copy(context.GetPlace(), + tmp_dev_outs_data->ptr(), + paddle::platform::CPUPlace(), + restored, + o_num * sizeof(T*), + context.stream()); + dev_out_gpu_data = reinterpret_cast(tmp_dev_outs_data->ptr()); + } + + if (has_same_shape) { + if (o_num == 2) { + SplitKernel<<>>( + input.data(), + in_row, + in_col, + out0_col, + outputs_data[0], + outputs_data[1]); + } else if (o_num == 3) { + SplitKernel<<>>( + input.data(), + in_row, + in_col, + out0_col, + outputs_data[0], + outputs_data[1], + outputs_data[2]); + } else if (o_num == 4) { + SplitKernel<<>>( + input.data(), + in_row, + in_col, + out0_col, + outputs_data[0], + outputs_data[1], + outputs_data[2], + outputs_data[3]); + } else { + SplitKernel<<>>( + input.data(), in_row, in_col, out0_col, dev_out_gpu_data); + } + } else { + auto tmp_dev_ins_col_data = + // TODO(chentianyu03): try to find a method to remove the Alloc function + paddle::memory::Alloc(context, outputs_cols_num * sizeof(int64_t)); + auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph( + outputs_cols, outputs_cols_num); + paddle::memory::Copy(context.GetPlace(), + tmp_dev_ins_col_data->ptr(), + paddle::platform::CPUPlace(), + restored, + outputs_cols_num * sizeof(int64_t), + context.stream()); + int64_t* dev_outs_col_data = + reinterpret_cast(tmp_dev_ins_col_data->ptr()); + + SplitKernel<<>>( + input.data(), + in_row, + in_col, + dev_outs_col_data, + static_cast(outputs_cols_num), + dev_out_gpu_data); + } +#ifdef PADDLE_WITH_HIP + // Prevent the pinned memory value from being covered and release the memory + // after the launch kernel of the stream is executed (reapply pinned memory + // next time) + auto* data_alloc_released = data_alloc.release(); + auto* cols_alloc_released = cols_alloc.release(); + context.AddStreamCallback([data_alloc_released, cols_alloc_released] { + paddle::memory::allocation::Allocator::AllocationDeleter( + data_alloc_released); + paddle::memory::allocation::Allocator::AllocationDeleter( + cols_alloc_released); + }); +#endif +} + +} // namespace pten diff --git a/paddle/pten/kernels/gpu/concat_kernel.cu b/paddle/pten/kernels/gpu/concat_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..6ddfef460fc6cf2945903fbb70574272e4e18e55 --- /dev/null +++ b/paddle/pten/kernels/gpu/concat_kernel.cu @@ -0,0 +1,125 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/concat_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/core/lod_utils.h" +#include "paddle/pten/kernels/funcs/concat_funcs.h" +#include "paddle/pten/kernels/gpu/concat_and_split.h" + +namespace pten { + +template +void ConcatKernel(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis_scalar, + DenseTensor* out) { + int64_t axis = axis_scalar.to(); + + axis = pten::funcs::ComputeAxis(axis, x[0].dims().size()); + + std::vector x_dims; + for (size_t i = 0; i < x.size(); ++i) { + x_dims.push_back(x[i].dims()); + } + + pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); + out->Resize(out_dims); + out->mutable_data(); + + // If axis is 0, the lod of the output is not the same as inputs. + if (axis == 0 && x[0].lod().size() > 0) { + size_t lod_size_0 = x[0].lod().size(); + size_t lod_size = lod_size_0; + for (size_t i = 1; i < x.size(); ++i) { + if (x[i].lod().size() > 0) { + PADDLE_ENFORCE_EQ( + x[i].lod().size(), + lod_size_0, + paddle::platform::errors::Unimplemented( + "The lod level of all input LoDTensors should be same. " + "Maybe different lod level of input LoDTensors can concat," + "it is not supported currently. The lod level of %dth input " + "is %d and first input is %d.", + i, + x[i].lod().size(), + lod_size_0)); + } else { + lod_size = 0; + break; + } + } + if (lod_size) { + auto* out_lod = out->mutable_lod(); + for (size_t i = 1; i < x.size(); ++i) { + auto in_lod = pten::ConvertToLengthBasedLoD(x[i].lod()); + pten::AppendLoD(out_lod, in_lod); + } + } + } + + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && x.size() < 10) { + size_t output_offset = 0; + for (auto& in : x) { + if (in.numel() == 0UL) { + continue; + } + auto in_stride = paddle::framework::stride_numel(in.dims()); + auto out_stride = paddle::framework::stride_numel(out->dims()); + paddle::operators::StridedNumelCopyWithAxis( + dev_ctx, + axis, + out->data() + output_offset, + out_stride, + in.data(), + in_stride, + in_stride[axis]); + output_offset += in_stride[axis]; + } + } else { + std::vector inputs; + for (size_t j = 0; j < x.size(); ++j) { + if (x[j].numel() > 0) { + inputs.push_back(x[j]); + } else { + continue; + } + } + ConcatImpl(dev_ctx, inputs, axis, out); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(concat, + GPU, + ALL_LAYOUT, + pten::ConcatKernel, + float, + double, + bool, + int64_t, + int, + uint8_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index 79d9a3d82e69ed8c403fadc797e6397bca3dc30f..e9faa22c4eb7b109d4aa62e83b58f2be7e5279fe 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -21,3 +21,4 @@ cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_uti cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_concat_api.cc b/paddle/pten/tests/api/test_concat_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..e84aee0aaaf4ff6151327fc556595f13eb7efc1f --- /dev/null +++ b/paddle/pten/tests/api/test_concat_api.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/api/include/api.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace paddle { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chentianyu03): Remove this test after the API is used in the dygraph +TEST(API, concat) { + // 1. create tensor + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x->mutable_data(); + + auto dense_y = std::make_shared( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y->mutable_data(); + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + dense_y_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + paddle::experimental::Tensor x(dense_x); + paddle::experimental::Tensor y(dense_y); + + std::vector inputs{x, y}; + + // 2. test API + auto out = paddle::experimental::concat(inputs, 0); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 6); + ASSERT_EQ(out.dims()[1], 10); + ASSERT_EQ(out.numel(), 60); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto out_data = dense_out->data(); + for (size_t i = 0; i < 60; ++i) { + if (i < 30) { + ASSERT_NEAR(dense_x_data[i], out_data[i], 1e-6f); + } else { + ASSERT_NEAR(dense_y_data[i - 30], out_data[i], 1e-6f); + } + } +} + +} // namespace tests +} // namespace paddle diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index 6f70f2ca2c895a042c1aad0a43b4ff70966f256a..407e5c097aec44d9f70d0d774b04c49f283bdd0e 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -10,3 +10,4 @@ cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils) cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils) cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_concat_dev_api.cc b/paddle/pten/tests/kernels/test_concat_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5d979ad908fff52e4d8c95db5e01acc0c50a2f6 --- /dev/null +++ b/paddle/pten/tests/kernels/test_concat_dev_api.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/kernels/concat_kernel.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(DEV_API, concat) { + // 1. create tensor + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(); + + pten::DenseTensor dense_y(alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 10}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(); + + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + dense_y_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + std::vector inputs = {dense_x, dense_y}; + + // 2. test API + auto out = pten::Concat( + *(static_cast(dev_ctx)), inputs, 0); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 6); + ASSERT_EQ(out.dims()[1], 10); + ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + auto out_data = out.data(); + + for (size_t i = 0; i < 60; ++i) { + if (i < 30) { + ASSERT_NEAR(dense_x_data[i], out_data[i], 1e-6f); + } else { + ASSERT_NEAR(dense_y_data[i - 30], out_data[i], 1e-6f); + } + } +} + +} // namespace tests +} // namespace pten diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 562a726aa29f27bb6b7017c72ba86dc9f33372c3..1bf5344e8374677b87c7ae677a1db95e53b619e1 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -18,6 +18,16 @@ param : [x, out_dtype] data_type : x + +- api : concat + args : (const std::vector& x, const Scalar& axis) + output : Tensor + infer_meta : + func : ConcatInferMeta + param : [x, axis, true] + kernel : + func : concat + - api : conj args : (const Tensor& x) output : Tensor diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index e8539b11d14550dfa76eb9affd7da2ce2b7aed4c..c99473158524637de112289e58182cd14bea60fc 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -58,7 +58,10 @@ class API: f"Args declaration should start with '(' and end with ')', please check the args of {self.api} in api.yaml." args_str = args_str[1:-1] args_list = args_str.split(',') - input_types = ['const Tensor&', 'const Tensor &'] + input_types = [ + 'const Tensor&', 'const Tensor &', 'const std::vector&', + 'const std::vector &' + ] attr_types = ['const Scalar&', 'const Scalar &', 'const ScalarArray&', 'const ScalarArray &', \ 'int', 'int32_t', 'int64_t', 'size_t', 'float', 'double', 'bool', \ 'const std::vector&', 'Backend', 'DataLayout', 'DataType'] @@ -247,7 +250,7 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']}); param_code = "" for param in infer_meta_params: if param in input_names: - param_code = param_code + self.prefix_tensor_name + param + "->meta(), " + param_code = param_code + "GetDenseTensorMeta(" + self.prefix_tensor_name + param + "), " elif param in attr_names: param_code = param_code + param + ", " elif isinstance(param, str): @@ -267,7 +270,7 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']}); for input_name in input_names: # set input code input_tensor_code = input_tensor_code + f""" - auto {self.prefix_tensor_name}{input_name} = std::dynamic_pointer_cast({input_name}.impl());""" + auto {self.prefix_tensor_name}{input_name} = TensorToDenseTensor({input_name});""" attr_names = attrs['names'] if kernel_param is None: @@ -374,6 +377,35 @@ namespace experimental { """) +def tensor_to_densetensor(): + return """ + std::shared_ptr TensorToDenseTensor(const Tensor& tensor) { + return std::dynamic_pointer_cast(tensor.impl()); + } + + std::shared_ptr> TensorToDenseTensor(const std::vector& tensors) { + std::vector pt_tensors; + + for(auto & t : tensors) { + pt_tensors.push_back(*std::dynamic_pointer_cast(t.impl())); + } + return std::make_shared>(pt_tensors); + } + + const pten::DenseTensorMeta GetDenseTensorMeta(const std::shared_ptr & x) { + return x->meta(); + } + + const std::vector GetDenseTensorMeta(const std::shared_ptr>& x) { + std::vector metas; + for(auto& t : *x) { + metas.push_back(t.meta()); + } + return metas; + } +""" + + def generate_api(api_yaml_path, header_file_path, source_file_path): with open(api_yaml_path, 'r') as f: @@ -390,6 +422,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): include_header_file = "paddle/pten/api/include/api.h" source_file.write(source_include(include_header_file)) source_file.write(namespace[0]) + source_file.write(tensor_to_densetensor()) for api in apis: api_code = API(api)