/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include "paddle/phi/backends/all_context.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/data_type.h" namespace phi { namespace funcs { template void BatchTranspose(T* output, const T* input, int64_t batch, int64_t m, int64_t n, const phi::GPUContext* dev_ctx); template struct TransposeNormal { // for dims >= 7 situation void operator()(const DeviceContext& context, const phi::DenseTensor& in, phi::DenseTensor* out, const std::vector& axis); }; template struct Transpose { void operator()(const DeviceContext& context, const phi::DenseTensor& in, phi::DenseTensor* out, const std::vector& axis); }; template struct SetConstant { void operator()(const DeviceContext& context, phi::DenseTensor* tensor, T num); }; #ifdef PADDLE_WITH_XPU template struct SetConstant { void operator()(const phi::XPUContext& context, phi::DenseTensor* tensor, T num); }; #endif template void set_constant_with_place(const phi::DeviceContext& context, phi::DenseTensor* tensor, float value); void set_constant(const phi::DeviceContext& context, phi::DenseTensor* tensor, float value); template struct RowwiseAdd { void operator()(const DeviceContext& context, const phi::DenseTensor& input, const phi::DenseTensor& vec, phi::DenseTensor* output); }; template struct ColwiseSum { void operator()(const DeviceContext& context, const phi::DenseTensor& input, phi::DenseTensor* vec); }; template struct RowwiseSum { void operator()(const DeviceContext& context, const phi::DenseTensor& input, phi::DenseTensor* vec); }; template struct RowwiseMean { void operator()(const DeviceContext& context, const phi::DenseTensor& input, phi::DenseTensor* vec); }; #ifdef PADDLE_WITH_XPU template struct TensorSetConstantXPU { TensorSetConstantXPU(phi::DenseTensor* tensor, U value, phi::Place place) : tensor_(tensor), value_(value), place_(place) {} template void apply() const { auto* begin = tensor_->mutable_data(place_); int numel = tensor_->numel(); std::unique_ptr data_cpu(new T[numel]); std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(value_)); memory_utils::Copy(place_, begin, phi::CPUPlace(), static_cast(data_cpu.get()), numel * sizeof(T)); } phi::DenseTensor* tensor_; U value_; phi::Place place_; }; #endif template inline void TransCompute(const int dim, const Context& dev_ctx, const DenseTensor& in, DenseTensor* out, const std::vector& axis) { switch (dim) { case 1: Transpose trans1; trans1(dev_ctx, in, out, axis); break; case 2: Transpose trans2; trans2(dev_ctx, in, out, axis); break; case 3: Transpose trans3; trans3(dev_ctx, in, out, axis); break; case 4: Transpose trans4; trans4(dev_ctx, in, out, axis); break; case 5: Transpose trans5; trans5(dev_ctx, in, out, axis); break; case 6: Transpose trans6; trans6(dev_ctx, in, out, axis); break; default: // for dim >= 7 situation TransposeNormal trans_normal; trans_normal(dev_ctx, in, out, axis); } } } // namespace funcs } // namespace phi