/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/extension/include/ext_tensor.h"

#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {

// Element-wise conversion functor: static_cast one value from InType to
// OutType. Marked HOSTDEVICE so the same functor is usable by both the CPU
// and the CUDA platform::Transform paths below.
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
  HOSTDEVICE inline OutType operator()(InType in) const {
    return static_cast<OutType>(in);
  }
};

// Visitor handed to framework::VisitDataType. VisitDataType selects OutType
// from a runtime dtype and calls apply<OutType>(), which fills `out_` (allocated
// on the same place as `in_`) with the elements of `in_` converted from InType.
template <typename InType>
struct CastDataType {
  CastDataType(const framework::Tensor &in, framework::Tensor *out,
               const platform::DeviceContext *ctx)
      : in_(in), out_(out), ctx_(ctx) {}
  const framework::Tensor in_;
  framework::Tensor *out_;
  const platform::DeviceContext *ctx_;

  template <typename OutType>
  void apply() {
    auto *in_begin = in_.data<InType>();
    auto *in_end = in_begin + in_.numel();
    auto *out_begin = out_->mutable_data<OutType>(in_.place());

    if (platform::is_cpu_place(in_.place())) {
      platform::Transform<platform::CPUDeviceContext> trans;
      auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
      trans(*context, in_begin, in_end, out_begin,
            CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
    } else if (platform::is_gpu_place(in_.place())) {
      platform::Transform<platform::CUDADeviceContext> trans;
      auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
      trans(*context, in_begin, in_end, out_begin,
            CastDataTypeFunctor<InType, OutType>());
      // Synchronize the device context before returning to the caller.
      context->Wait();
#endif
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Place type is not supported when casting data type."));
    }
  }
};

// Copies `ele_size` bytes between `src` and `dst` when at least one side is on
// the current GPU device (GPU->CPU, GPU->GPU, CPU->GPU), enqueueing the copy on
// that device's stream and synchronizing the stream before returning.
// Any other place combination throws Unavailable.
// NOTE: the whole body is compiled out unless PADDLE_WITH_CUDA is defined.
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
             int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  int device_num = paddle::platform::GetCurrentDeviceId();
  platform::CUDAPlace gpu_place(device_num);
  auto *dev_ctx =
      static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
  if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) {
    memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
                 ele_size, dev_ctx->stream());
  } else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) {
    memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
                 dev_ctx->stream());
  } else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) {
    memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
                 ele_size, dev_ctx->stream());
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Only GPU related Copy can reach this func."));
  }
  cudaStreamSynchronize(dev_ctx->stream());
#endif
}

// Lazily creates the backing framework::LoDTensor if this Tensor has none yet,
// then exposes it as a local `tensor` pointer. Used at the top of every member
// below, including const ones (tensor_ is assigned here, so it is presumably
// declared `mutable` in ext_tensor.h).
#define GET_CASTED_TENSOR                               \
  if (!tensor_) {                                       \
    tensor_ = std::make_shared<framework::LoDTensor>(); \
  }                                                     \
  auto *tensor = static_cast<framework::LoDTensor *>(tensor_.get());

// Sets the tensor's dims to `shape`. Only the dims are changed here; when the
// element count differs, a warning tells the user to call mutable_data<T>()
// afterwards to (re)allocate storage.
void Tensor::reshape(const std::vector<int64_t> &shape) {
  GET_CASTED_TENSOR
  auto new_dim = framework::make_ddim(shape);
  if (tensor->numel() != framework::product(new_dim)) {
    LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
                    "or smaller than original shape will not change your "
                    "tensor's memory. Please call "
                    "paddle::Tensor::mutable_data<T>() after to reallocate "
                    "your tensor's size."
                 << std::endl;
  }
  tensor->Resize(new_dim);
}

// Creates an empty tensor that will allocate on `place` at the next
// mutable_data<T>() call; no stream is attached.
Tensor::Tensor(const PlaceType &place)
    : tensor_(std::make_shared<framework::LoDTensor>()),
      place_(place),
      stream_(StreamWrapper()) {}

// Same as Tensor(place), but also sets the dims to `shape` (no allocation).
Tensor::Tensor(const PlaceType &place, const std::vector<int64_t> &shape)
    : tensor_(std::make_shared<framework::LoDTensor>()),
      place_(place),
      stream_(StreamWrapper()) {
  GET_CASTED_TENSOR
  tensor->Resize(framework::make_ddim(shape));
}

// Records `place` as the target place, then allocates via mutable_data<T>().
template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
  place_ = place;
  return mutable_data<T>();
}

// Allocates (or reuses) storage of numel() elements of T on the recorded
// place and returns a writable pointer.
// Throws PreconditionNotMet if numel() <= 0 (reshape() was not called), and
// Unavailable for places that are not compiled in.
template <typename T>
T *Tensor::mutable_data() {
  GET_CASTED_TENSOR
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
      platform::errors::PreconditionNotMet(
          "You should call Tensor::reshape(const std::vector<int64_t> "
          "&shape) function before retrieving mutable_data from input "
          "tensor."));
  switch (static_cast<int>(place_)) {
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(platform::CPUPlace());
    }
#ifdef PADDLE_WITH_CUDA
    case static_cast<int>(PlaceType::kGPU): {
      int device_num = platform::GetCurrentDeviceId();
      return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
    }
#endif
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Custom operator unsupported place id(%d)",
          static_cast<int>(place_)));
  }
}

// Returns a typed pointer to the existing data of the backing tensor.
template <typename T>
T *Tensor::data() const {
  GET_CASTED_TENSOR;
  auto *res = tensor->data<T>();
  return res;
}

// Maps the backing tensor's inner dtype to the public DataType enum.
// NOTE: inner dtypes without a mapping below are reported as FLOAT32.
DataType Tensor::type() const {
  GET_CASTED_TENSOR;
  switch (tensor->type()) {
    case framework::proto::VarType::FP32:
      return DataType::FLOAT32;
    case framework::proto::VarType::INT64:
      return DataType::INT64;
    case framework::proto::VarType::INT32:
      return DataType::INT32;
    case framework::proto::VarType::INT16:
      return DataType::INT16;
    case framework::proto::VarType::INT8:
      return DataType::INT8;
    case framework::proto::VarType::UINT8:
      return DataType::UINT8;
    case framework::proto::VarType::FP64:
      return DataType::FLOAT64;
    case framework::proto::VarType::BOOL:
      return DataType::BOOL;
    case framework::proto::VarType::COMPLEX64:
      return DataType::COMPLEX64;
    case framework::proto::VarType::COMPLEX128:
      return DataType::COMPLEX128;
    case framework::proto::VarType::FP16:
      return DataType::FLOAT16;
    default:
      break;
  }
  // TODO(JiabinYang) Support more dtype here
  return DataType::FLOAT32;
}

// Returns a new Tensor on `target_place` with the same shape and a copy of
// this tensor's elements (interpreted as T). CPU->CPU uses std::memcpy; every
// combination involving the GPU goes through GpuCopy.
// Throws PreconditionNotMet if this tensor has no elements.
template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const {
  GET_CASTED_TENSOR;
  // GT rather than GE: `numel() >= 0` cannot detect an un-reshaped/empty
  // source, which mutable_data<T>() on the target would reject anyway.
  PADDLE_ENFORCE_GT(tensor->numel(), 0,
                    platform::errors::PreconditionNotMet(
                        "You should call Tensor::reshape(const "
                        "std::vector<int64_t> &shape) function before "
                        "copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  auto *p_src_data = tensor->data<T>();
  auto src_place = place();
  Tensor target = Tensor(target_place);
  target.reshape(shape());
  auto *p_target_data = target.template mutable_data<T>();

  if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
    std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
  } else if ((src_place == PlaceType::kGPU) &&
             (target_place == PlaceType::kCPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else if ((src_place == PlaceType::kCPU) &&
             (target_place == PlaceType::kGPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else if ((src_place == PlaceType::kGPU) &&
             (target_place == PlaceType::kGPU)) {
    GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Not supported place transform of place: %d to place: %d",
        static_cast<int>(src_place), static_cast<int>(target_place)));
  }
  return target;
}

// Explicit instantiations exported from the library for every supported dtype.
template PD_DLL_DECL Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<double>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex64>(
    const PlaceType &target_place) const;
template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex128>(
    const PlaceType &target_place) const;
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;

template PD_DLL_DECL float *Tensor::data<float>() const;
template PD_DLL_DECL double *Tensor::data<double>() const;
template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL bool *Tensor::data<bool>() const;
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::data<paddle::platform::complex128>() const;
template PD_DLL_DECL paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;

template PD_DLL_DECL float *Tensor::mutable_data<float>();
template PD_DLL_DECL double *Tensor::mutable_data<double>();
template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>();
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>();
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>();
template PD_DLL_DECL bool *Tensor::mutable_data<bool>();
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>();
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();

template PD_DLL_DECL float *Tensor::mutable_data<float>(const PlaceType &place);
template PD_DLL_DECL double *Tensor::mutable_data<double>(
    const PlaceType &place);
template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>(
    const PlaceType &place);
template PD_DLL_DECL int32_t *Tensor::mutable_data<int32_t>(
    const PlaceType &place);
template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
    const PlaceType &place);
template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
    const PlaceType &place);
template PD_DLL_DECL int16_t *Tensor::mutable_data<int16_t>(
    const PlaceType &place);
template PD_DLL_DECL bool *Tensor::mutable_data<bool>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::complex128 *
Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template PD_DLL_DECL paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);

// Returns the current dims of the backing tensor.
std::vector<int64_t> Tensor::shape() const {
  GET_CASTED_TENSOR
  return framework::vectorize<int64_t>(tensor->dims());
}

// Refreshes place_ from the backing tensor's actual place and returns it.
// Throws Unimplemented if the tensor lives on neither CPU nor GPU.
const PlaceType &Tensor::place() const {
  GET_CASTED_TENSOR;
  if (platform::is_cpu_place(tensor->place())) {
    place_ = PlaceType::kCPU;
  } else if (platform::is_gpu_place(tensor->place())) {
    place_ = PlaceType::kGPU;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Current Tensor hold unsupported Place Type, Please Init it "
        "using Tensor::mutable_data<T>(PaddlePlace) where PaddlePlace is "
        "either PlaceType::kCPU or PlaceType::kGPU"));
  }
  return place_;
}

// Returns a new Tensor on the same place and with the same shape, whose
// elements are this tensor's elements converted to `target_type`.
// Throws Unimplemented for source dtypes not listed in the switch.
Tensor Tensor::cast(const DataType &target_type) const {
  GET_CASTED_TENSOR;
  Tensor rlt = Tensor(place());
  rlt.reshape(this->shape());
  auto rlt_tensor_ = static_cast<framework::LoDTensor *>(rlt.tensor_.get());
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto ctx = pool.Get(tensor->place());
  auto src_type = tensor->type();
  auto dst_type =
      framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
  // The outer switch fixes InType from the source dtype; VisitDataType then
  // dispatches CastDataType<InType>::apply<OutType>() on the target dtype.
  switch (src_type) {
    case framework::proto::VarType::FP32:
      framework::VisitDataType(dst_type,
                               CastDataType<float>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::FP64:
      framework::VisitDataType(dst_type,
                               CastDataType<double>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT32:
      framework::VisitDataType(dst_type,
                               CastDataType<int>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT64:
      framework::VisitDataType(
          dst_type, CastDataType<int64_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::BOOL:
      framework::VisitDataType(dst_type,
                               CastDataType<bool>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT16:
      framework::VisitDataType(
          dst_type, CastDataType<int16_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::INT8:
      // INT8 is reported by type() and instantiated for data/mutable_data/
      // copy_to, so it must be castable as a source dtype as well.
      framework::VisitDataType(
          dst_type, CastDataType<int8_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::UINT8:
      framework::VisitDataType(
          dst_type, CastDataType<uint8_t>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::COMPLEX64:
      framework::VisitDataType(
          dst_type,
          CastDataType<paddle::platform::complex64>(*tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::COMPLEX128:
      framework::VisitDataType(dst_type,
                               CastDataType<paddle::platform::complex128>(
                                   *tensor, rlt_tensor_, ctx));
      break;
    case framework::proto::VarType::FP16:
      framework::VisitDataType(
          dst_type,
          CastDataType<paddle::platform::float16>(*tensor, rlt_tensor_, ctx));
      break;
    // TODO(JiabinYang) Support more dtype here
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          framework::DataTypeToString(src_type)));
  }
  return rlt;
}

// Returns the number of elements of the backing tensor.
int64_t Tensor::size() const {
  GET_CASTED_TENSOR;
  return tensor->numel();
}

#ifdef PADDLE_WITH_CUDA
// Returns the CUDA stream attached by the framework.
// Throws PreconditionNotMet when no stream has been set on this tensor.
cudaStream_t Tensor::stream() const {
  if (!stream_.IsStreamSet()) {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "Stream is not Set, only input tensor will have "
        "stream which is set by framework "));
  } else {
    return reinterpret_cast<cudaStream_t>(stream_.GetStream());
  }
}
#endif

namespace framework {

// Makes the framework LoDTensor pointed to by `dst` share the storage held by
// `src`. Assumes `src` already owns a backing tensor (tensor_ is dereferenced
// without a null check).
void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
  static_cast<framework::LoDTensor *>(dst)->ShareDataWith(
      *static_cast<framework::LoDTensor *>(src.tensor_.get()));
}

// Makes `dst` share the storage of the framework LoDTensor pointed to by
// `src`, creating dst's backing tensor first if it has none.
void CustomTensorUtils::ShareDataFrom(const void *src,
                                      const paddle::Tensor &dst) {
  if (!dst.tensor_) {
    dst.tensor_ = std::make_shared<framework::LoDTensor>();
  }
  auto *tensor = static_cast<framework::LoDTensor *>(dst.tensor_.get());
  tensor->ShareDataWith(*static_cast<const framework::LoDTensor *>(src));
}

}  // namespace framework
}  // namespace paddle