// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; inline std::vector GetNewDataFromShapeTensor( const Tensor* new_data_tensor) { if (framework::TransToProtoVarType(new_data_tensor->dtype()) == framework::proto::VarType::INT64) { auto* new_data = new_data_tensor->data(); framework::Tensor cpu_starts_tensor; if (platform::is_gpu_place(new_data_tensor->place())) { paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); new_data = cpu_starts_tensor.data(); } std::vector vec_new_data(new_data, new_data + new_data_tensor->numel()); return vec_new_data; } else if (framework::TransToProtoVarType(new_data_tensor->dtype()) == framework::proto::VarType::INT32) { auto* new_data = new_data_tensor->data(); std::vector vec_new_data; framework::Tensor cpu_starts_tensor; if (platform::is_gpu_place(new_data_tensor->place())) { paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); new_data = cpu_starts_tensor.data(); } for (int i = 0; i < new_data_tensor->numel(); ++i) { vec_new_data.push_back(static_cast(*(new_data + i))); } return vec_new_data; } else { PADDLE_THROW(platform::errors::InvalidArgument( "Expected dtype of ShapeTensor must be int32, int64. But got " "unsupport dtype: %s.", new_data_tensor->dtype())); } } inline std::vector GetNewDataFromShapeTensorList( const std::vector& list_new_shape_tensor) { std::vector vec_new_shape; vec_new_shape.reserve(list_new_shape_tensor.size()); for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { auto tensor = list_new_shape_tensor[i]; PADDLE_ENFORCE_EQ( tensor->dims(), phi::make_ddim({1}), platform::errors::InvalidArgument( "Shape of dim tensor in uniform_random_op should be [1]" "But received tensor's dim=%s.", tensor->dims())); if (framework::TransToProtoVarType(tensor->dtype()) == framework::proto::VarType::INT32) { if (platform::is_gpu_place(tensor->place())) { framework::Tensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); vec_new_shape.push_back(static_cast(*temp.data())); } else { vec_new_shape.push_back(static_cast(*tensor->data())); } } else if (framework::TransToProtoVarType(tensor->dtype()) == framework::proto::VarType::INT64) { if (platform::is_gpu_place(tensor->place())) { framework::Tensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); vec_new_shape.push_back(*temp.data()); } else { vec_new_shape.push_back(*tensor->data()); } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Expected dtype of ShapeTensorList of %d-th must be int32, int64. " "But got " "unsupport dtype: %s.", i, paddle::framework::DataTypeToString( framework::TransToProtoVarType(tensor->dtype())))); } } return vec_new_shape; } } // namespace operators } // namespace paddle