diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc
index 767494b4e32f2a64bb7d19acd338bb59340f2f88..0876e1c704860685e2510d0aa2ebf799148130e9 100644
--- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc
+++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc
@@ -28,7 +28,6 @@ limitations under the License. */
 namespace paddle {
 namespace distributed {
 
-using LoDTensor = phi::DenseTensor;
 using phi::SelectedRows;
 
 const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
@@ -97,11 +96,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
   regions.reserve(varnames.size());
   for (auto &t : varnames) {
     Variable *var = scope->Var(t);
-    LoDTensor *tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
     if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
       Variable *temp_var = xpu_temp_scope_->Var(t);
-      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
+      phi::DenseTensor *temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
       temp_tensor->Resize(tensor->dims());
       float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
       paddle::distributed::Region reg(temp_data, tensor->numel());
@@ -122,7 +121,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
 
   for (auto &t : varnames) {
     Variable *var = scope->FindVar(t);
-    LoDTensor *tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
     VLOG(3) << "Communicator::RecvNoBarrier Var " << t << " On gpu? "
             << platform::is_gpu_place(tensor->place());
@@ -132,8 +131,8 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
             << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
     if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
-      LoDTensor *temp_tensor =
-          xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
+      phi::DenseTensor *temp_tensor =
+          xpu_temp_scope_->FindVar(t)->GetMutable<phi::DenseTensor>();
       framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
       float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
       VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
@@ -157,11 +156,11 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
   for (auto &t : varnames) {
     Variable *var = scope.FindVar(t);
     CHECK(var != nullptr) << "var[" << t << "] not found";
-    LoDTensor *tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
     if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
       Variable *temp_var = xpu_temp_scope_->Var(t);
-      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
+      phi::DenseTensor *temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
       temp_tensor->Resize(tensor->dims());
       float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
       framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
@@ -203,7 +202,8 @@ void Communicator::RpcSendDense(const CommContext &ctx,
   float *data = dense_data->data();
   uint32_t pos = 0;
   for (size_t i = 0; i < var_names.size(); ++i) {
-    const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
+    const phi::DenseTensor tensor =
+        scope.FindVar(var_names[i])->Get<phi::DenseTensor>();
     size_t count = static_cast<size_t>(tensor.numel());
     const float *g = tensor.data<float>();
     CHECK(pos + count <= dense_data->size())
@@ -472,13 +472,13 @@ void AsyncCommunicator::RecvNoBarrier() {
     auto var_names = iter.second;
     for (auto &t : var_names) {
       Variable *var = recv_scope_->FindVar(t);
-      LoDTensor *tensor = var->GetMutable<LoDTensor>();
+      phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
       VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
               << platform::is_gpu_place(tensor->place());
       if (platform::is_gpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_CUDA
-        LoDTensor *temp_tensor =
-            xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
+        phi::DenseTensor *temp_tensor =
+            xpu_temp_scope_->FindVar(t)->GetMutable<phi::DenseTensor>();
         framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
 #endif
       }
@@ -591,8 +591,8 @@ void AsyncCommunicator::PullSparseToTensorSync(
     uint64_t padding_id,
     platform::Place place,
     bool is_training,
-    std::vector<const LoDTensor *> *inputs,
-    std::vector<LoDTensor *> *outputs) {
+    std::vector<const phi::DenseTensor *> *inputs,
+    std::vector<phi::DenseTensor *> *outputs) {
   std::vector<uint64_t> fea_keys;
   std::vector<float *> pull_result_ptr;
   fea_keys.reserve(MAX_FEASIGN_NUM / 100);
diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc
index 798d849366b5f4eca63eac8f30baeb7568af8c1c..a6d233ac6dc4efafaf07c38cac16aaaf86444570 100644
--- a/paddle/fluid/distributed/ps/wrapper/fleet.cc
+++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc
@@ -25,7 +25,6 @@ limitations under the License. */
 namespace paddle {
 namespace distributed {
 
-using LoDTensor = phi::DenseTensor;
 using framework::ProgramDesc;
 using framework::VarDesc;
 using framework::Variable;
@@ -232,7 +231,7 @@ std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
     if (var == nullptr) {
       continue;
     }
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
     int64_t* ids = tensor->data<int64_t>();
     size_t len = tensor->numel();
@@ -279,7 +278,7 @@ void FleetWrapper::PullSparseVarsSync(
     if (var == nullptr) {
       continue;
     }
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
     int64_t* ids = tensor->data<int64_t>();
     size_t len = tensor->numel();
@@ -327,13 +326,14 @@ void FleetWrapper::PullSparseVarsSync(
 
 // is_training is true means training, false means inference, the behavior is
 // different on pserver
-void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id,
-                                          int fea_dim,
-                                          uint64_t padding_id,
-                                          platform::Place place,
-                                          bool is_training,
-                                          std::vector<const LoDTensor*>* inputs,
-                                          std::vector<LoDTensor*>* outputs) {
+void FleetWrapper::PullSparseToTensorSync(
+    const uint64_t table_id,
+    int fea_dim,
+    uint64_t padding_id,
+    platform::Place place,
+    bool is_training,
+    std::vector<const phi::DenseTensor*>* inputs,
+    std::vector<phi::DenseTensor*>* outputs) {
   std::vector<uint64_t> fea_keys;
   std::vector<float*> pull_result_ptr;
   fea_keys.reserve(MAX_FEASIGN_NUM / 100);
@@ -398,7 +398,7 @@ void FleetWrapper::PullDenseVarsAsync(
       varname = var_names[i] + "pin";
     }
     Variable* var = scope.FindVar(varname);
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     float* w = tensor->data<float>();
     paddle::distributed::Region reg(w, tensor->numel());
     regions[i] = std::move(reg);
@@ -417,7 +417,7 @@ void FleetWrapper::PullDenseVarsSync(
   regions.reserve(var_names.size());
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     if (!platform::is_gpu_place(tensor->place())) {
       float* w = tensor->data<float>();
       paddle::distributed::Region reg(w, tensor->numel());
@@ -437,7 +437,7 @@ void FleetWrapper::PushDenseParamSync(
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
     CHECK(var != nullptr) << "var[" << t << "] not found";
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     if (!platform::is_gpu_place(tensor->place())) {
       float* g = tensor->mutable_data<float>(place);
       paddle::distributed::Region reg(g, tensor->numel());
@@ -468,7 +468,7 @@ void FleetWrapper::PushDenseVarsAsync(
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
     CHECK(var != nullptr) << "var[" << t << "] not found";
-    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
     int count = tensor->numel();
     float* g = tensor->mutable_data<float>(place);
     // TODO(zhaocaibei123): how to get batch_size in op?
@@ -544,8 +544,8 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
     const std::string& click_name,
     platform::Place place,
     const std::vector<std::string>& input_names,
-    std::vector<const LoDTensor*>* inputs,
-    std::vector<LoDTensor*>* outputs) {
+    std::vector<const phi::DenseTensor*>* inputs,
+    std::vector<phi::DenseTensor*>* outputs) {
   // not support
   return;
 }
@@ -555,11 +555,11 @@ void FleetWrapper::PushSparseFromTensorAsync(
     int fea_dim,
     uint64_t padding_id,
     platform::Place place,
-    std::vector<const LoDTensor*>* inputs,
+    std::vector<const phi::DenseTensor*>* inputs,
     std::vector<std::string>& slots,
-    const LoDTensor* shows,
-    const LoDTensor* clks,
-    std::vector<LoDTensor*>* outputs,
+    const phi::DenseTensor* shows,
+    const phi::DenseTensor* clks,
+    std::vector<phi::DenseTensor*>* outputs,
     bool use_cvm_op) {
   CHECK(slots.size() == inputs->size());
   int batch_size = -1;
@@ -777,7 +777,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id,
       Variable* var = scope->FindVar(name);
      CHECK(var != nullptr) << "var[" << name << "] not found";
       VLOG(3) << "prepare shrink dense batch_sum";
-      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+      phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
       float* g = tensor->data<float>();
 
       // show_batch_sum += N * log(decay)
@@ -787,7 +787,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id,
       Variable* var_size = scope->FindVar(size_name);
       CHECK(var_size != nullptr) << "var[" << size_name << "] not found";
       VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name;
-      float* g_size = var_size->GetMutable<LoDTensor>()->data<float>();
+      float* g_size = var_size->GetMutable<phi::DenseTensor>()->data<float>();
 
       for (int k = 0; k < tensor->numel(); k += emb_dim) {
         g[k] = g[k] + g_size[k] * log(decay);
@@ -797,7 +797,7 @@ void FleetWrapper::ShrinkDenseTable(int table_id,
     } else {
       Variable* var = scope->FindVar(name);
       CHECK(var != nullptr) << "var[" << name << "] not found";
-      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+      phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
       float* g = tensor->data<float>();
       paddle::distributed::Region reg(g, tensor->numel());
       regions.emplace_back(std::move(reg));
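The dense pull/push paths above all reduce to one pattern: expose a tensor's raw float buffer as a `paddle::distributed::Region` (a pointer-plus-length view) and hand the regions to the PS client. A condensed sketch of that pattern (hypothetical helper; headers as already included by fleet.cc):

```cpp
#include <string>
#include <vector>

#include "paddle/fluid/distributed/ps/wrapper/fleet.h"

// Hypothetical helper mirroring PullDenseVarsSync/PushDenseParamSync above.
std::vector<paddle::distributed::Region> DenseVarsToRegions(
    const paddle::framework::Scope &scope,
    const std::vector<std::string> &var_names) {
  std::vector<paddle::distributed::Region> regions;
  regions.reserve(var_names.size());
  for (const auto &name : var_names) {
    auto *tensor = scope.FindVar(name)->GetMutable<phi::DenseTensor>();
    float *w = tensor->data<float>();
    regions.emplace_back(w, tensor->numel());  // Region wraps (ptr, length)
  }
  return regions;
}
```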
diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h
index 438e3ff010d4d594e361ea440184b8055c5997a0..c7aaededa20ba955d57aa1423f9d517f44204bfb 100755
--- a/paddle/fluid/distributed/ps/wrapper/fleet.h
+++ b/paddle/fluid/distributed/ps/wrapper/fleet.h
@@ -47,7 +46,6 @@ namespace distributed {
 
 class PSCore;
 
-using LoDTensor = phi::DenseTensor;
 using framework::Scope;
 using framework::Variable;
 using phi::SelectedRows;
@@ -111,13 +110,14 @@ class FleetWrapper {
 
   // is_training is true means training, false means inference, the behavior is
   // different on pserver
-  void PullSparseToTensorSync(const uint64_t table_id,
-                              int fea_dim,
-                              uint64_t padding_id,
-                              platform::Place place,
-                              bool is_training,
-                              std::vector<const LoDTensor*>* inputs,  // NOLINT
-                              std::vector<LoDTensor*>* outputs);      // NOLINT
+  void PullSparseToTensorSync(
+      const uint64_t table_id,
+      int fea_dim,
+      uint64_t padding_id,
+      platform::Place place,
+      bool is_training,
+      std::vector<const phi::DenseTensor*>* inputs,  // NOLINT
+      std::vector<phi::DenseTensor*>* outputs);      // NOLINT
 
   // pull dense variables from server in sync mod
   // Param: scope, table_id, var_names
@@ -188,18 +188,18 @@ class FleetWrapper {
       const std::string& click_name,
       platform::Place place,
       const std::vector<std::string>& input_names,
-      std::vector<const LoDTensor*>* inputs,  // NOLINT
-      std::vector<LoDTensor*>* outputs);      // NOLINT
+      std::vector<const phi::DenseTensor*>* inputs,  // NOLINT
+      std::vector<phi::DenseTensor*>* outputs);      // NOLINT
 
   void PushSparseFromTensorAsync(const uint64_t table_id,
                                  int fea_dim,
                                  uint64_t padding_id,
                                  platform::Place place,
-                                 std::vector<const LoDTensor*>* inputs,
+                                 std::vector<const phi::DenseTensor*>* inputs,
                                  std::vector<std::string>& slots,  // NOLINT
-                                 const LoDTensor* shows,
-                                 const LoDTensor* clicks,
-                                 std::vector<LoDTensor*>* outputs,
+                                 const phi::DenseTensor* shows,
+                                 const phi::DenseTensor* clicks,
+                                 std::vector<phi::DenseTensor*>* outputs,
                                  bool use_cvm_op = false);
   // Push sparse variables to server in Async mode
   // Param: scope, table_id, fea_keys, sparse_grad_names
diff --git a/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h b/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
index 6de4048e701bd46702c3f394585d2b939f8bdbe7..c4c7c15fdb90b16cacfc1086c851efef5987e901 100644
--- a/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
+++ b/paddle/fluid/distributed/ps/wrapper/ps_wrapper.h
@@ -47,7 +46,6 @@ namespace distributed {
 
 class PSCore;
 
-using LoDTensor = phi::DenseTensor;
 using framework::Scope;
 using framework::Variable;
 using phi::SelectedRows;
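From a caller's perspective the header change is only the parameter types; the wrapper still takes vectors of tensor pointers. A hedged usage sketch (table id, feature dim, and variable names are invented, and the FleetWrapper singleton is assumed to be initialized elsewhere):

```cpp
#include <vector>

#include "paddle/fluid/distributed/ps/wrapper/fleet.h"

void PullEmbeddingsExample(paddle::framework::Scope *scope) {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();
  std::vector<const phi::DenseTensor *> inputs = {
      &scope->FindVar("slot_ids")->Get<phi::DenseTensor>()};
  std::vector<phi::DenseTensor *> outputs = {
      scope->FindVar("slot_emb")->GetMutable<phi::DenseTensor>()};
  fleet->PullSparseToTensorSync(/*table_id=*/0,
                                /*fea_dim=*/9,
                                /*padding_id=*/0,
                                paddle::platform::CPUPlace(),
                                /*is_training=*/true,
                                &inputs,
                                &outputs);
}
```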
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 27bbdc0bbf9bf3f4502a438fa45eccdc3ea1651b..69bb5b7ed8589a94b2b0bb31e3bb984a3da68719 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -35,7 +34,6 @@
 
 namespace paddle {
 
-using LoDTensor = phi::DenseTensor;
 using framework::Variable;
 using framework::ir::Graph;
 using platform::CPUPlace;
@@ -48,19 +47,19 @@ using EigenMatrixArray =
 using ConstEigenMatrixArrayMap = Eigen::Map<const EigenMatrixArray>;
 using string::PrettyLogH1;
 using VariableNameMap = std::map<std::string, std::vector<std::string>>;
-static LoDTensor CreateScaleTensor(int64_t channels_num = 1);
+static phi::DenseTensor CreateScaleTensor(int64_t channels_num = 1);
 
 static void check_var(const Variable* var, const std::string& var_name) {
   PADDLE_ENFORCE_NOT_NULL(
       var,
       platform::errors::PreconditionNotMet("%s is not in the scope", var_name));
   PADDLE_ENFORCE_EQ(
-      var->IsType<LoDTensor>(),
+      var->IsType<phi::DenseTensor>(),
       true,
       platform::errors::PreconditionNotMet("Only support lod tensor now."));
 }
 
-static void check_tensor(const LoDTensor& tensor) {
+static void check_tensor(const phi::DenseTensor& tensor) {
   PADDLE_ENFORCE_GT(
       tensor.dims().size(),
       0,
@@ -78,8 +77,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForRNNWeights(
   auto* wh_var = predictor_.sub_scope_->FindVar(wh_name);
   check_var(wx_var, wx_name);
   check_var(wh_var, wh_name);
-  LoDTensor* wx_tensor = wx_var->GetMutable<LoDTensor>();
-  LoDTensor* wh_tensor = wh_var->GetMutable<LoDTensor>();
+  phi::DenseTensor* wx_tensor = wx_var->GetMutable<phi::DenseTensor>();
+  phi::DenseTensor* wh_tensor = wh_var->GetMutable<phi::DenseTensor>();
   if (gru) {
     scales_[wx_name] = GetMaxChGRUScalingFactor(*wx_tensor, *wh_tensor);
   } else {
@@ -101,7 +100,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpInputs(
     if (scales_.find(var_name) != scales_.end()) continue;
     auto* var = predictor_.sub_scope_->FindVar(var_name);
     check_var(var, var_name);
-    LoDTensor* var_tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* var_tensor = var->GetMutable<phi::DenseTensor>();
     // force unsigned type if already know it
     bool is_unsigned = false;
     CalculateSingleScale(
@@ -118,7 +117,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
     if (scales_.find(var_name) != scales_.end()) continue;
     auto* var = predictor_.sub_scope_->FindVar(var_name);
     check_var(var, var_name);
-    LoDTensor* var_tensor = var->GetMutable<LoDTensor>();
+    phi::DenseTensor* var_tensor = var->GetMutable<phi::DenseTensor>();
     // force unsigned type if already know it
     bool is_unsigned = false;
     bool compute_scale = true;
@@ -183,7 +182,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
 
 bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
   PrettyLogH1("--- Calculating scales for quantization");
-  std::map<std::string, std::map<std::string, LoDTensor>> gathered_data;
+  std::map<std::string, std::map<std::string, phi::DenseTensor>> gathered_data;
   for (const auto* op : predictor_.inference_program_->Block(0).AllOps()) {
     if (platform::HasOpINT8DataType(op)) {
       // handle inputs first to let is_unsigned be inferred for the outputs
@@ -198,20 +197,20 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale(
     const std::string& op_type_name,
     const std::string& conn_name,
     const std::string& var_name,
-    const LoDTensor& var_tensor,
+    const phi::DenseTensor& var_tensor,
     bool is_unsigned) {
   auto rule = qconfig_->scale_algo(op_type_name, conn_name);
   if (rule == ScaleAlgo::NONE) return;
 
-  PADDLE_ENFORCE_GT(
-      var_tensor.numel(),
-      0,
-      platform::errors::InvalidArgument(
-          "MkldnnQuantizer: LoDTensor of variable %s for quantization of op "
-          "%s of connection %s should not be empty.",
-          var_name,
-          op_type_name,
-          conn_name));
+  PADDLE_ENFORCE_GT(var_tensor.numel(),
+                    0,
+                    platform::errors::InvalidArgument(
+                        "MkldnnQuantizer: phi::DenseTensor of variable %s for "
+                        "quantization of op "
+                        "%s of connection %s should not be empty.",
+                        var_name,
+                        op_type_name,
+                        conn_name));
 
   switch (rule) {
     case ScaleAlgo::MAX:
@@ -236,8 +235,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale(
   }
 }
 
-static LoDTensor CreateScaleTensor(int64_t channels_num) {
-  LoDTensor scale_tensor;
+static phi::DenseTensor CreateScaleTensor(int64_t channels_num) {
+  phi::DenseTensor scale_tensor;
   scale_tensor.Resize({channels_num});
   scale_tensor.mutable_data<double>(CPUPlace());
   return scale_tensor;
@@ -272,9 +271,9 @@ std::vector<int> AnalysisPredictor::MkldnnQuantizer::ExpandQuantizedBins(
   return expanded_quantized_bins;
 }
 
-std::pair<bool, LoDTensor>
+std::pair<bool, phi::DenseTensor>
 AnalysisPredictor::MkldnnQuantizer::GetKLScalingFactor(
-    const LoDTensor& var_tensor, bool is_unsigned) const {
+    const phi::DenseTensor& var_tensor, bool is_unsigned) const {
   ConstEigenVectorArrayMap eigen_tensor{
       var_tensor.data<float>(), var_tensor.numel(), 1};
   int precision_hist_num_bins = 2048;
@@ -381,15 +380,15 @@ AnalysisPredictor::MkldnnQuantizer::GetKLScalingFactor(
     min_kl_index = starting_iter;
   }
 
-  LoDTensor scale_tensor = CreateScaleTensor();
+  phi::DenseTensor scale_tensor = CreateScaleTensor();
   scale_tensor.data<double>()[0] = 1.0 / ((min_kl_index + 0.5) * bin_width);
 
   return std::make_pair(is_unsigned, scale_tensor);
 }
 
-std::pair<bool, LoDTensor>
+std::pair<bool, phi::DenseTensor>
 AnalysisPredictor::MkldnnQuantizer::GetMaxScalingFactor(
-    const LoDTensor& var_tensor, bool is_unsigned) const {
+    const phi::DenseTensor& var_tensor, bool is_unsigned) const {
   ConstEigenVectorArrayMap eigen_tensor{
       var_tensor.data<float>(), var_tensor.numel(), 1};
   float max_abs = eigen_tensor.abs().maxCoeff();
@@ -402,15 +401,17 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxScalingFactor(
       "Tensor is claimed to be unsigned, but its min value (%f) is < 0.0",
       min_val));
 
-  LoDTensor scale_tensor = CreateScaleTensor();
+  phi::DenseTensor scale_tensor = CreateScaleTensor();
   scale_tensor.data<double>()[0] = 1.0 / max_abs;
 
   return std::make_pair(is_unsigned, scale_tensor);
 }
 
-std::pair<bool, LoDTensor>
+std::pair<bool, phi::DenseTensor>
 AnalysisPredictor::MkldnnQuantizer::GetMaxChScalingFactor(
-    const LoDTensor& var_tensor, bool is_unsigned, bool is_transposed) const {
+    const phi::DenseTensor& var_tensor,
+    bool is_unsigned,
+    bool is_transposed) const {
   check_tensor(var_tensor);
 
   ConstEigenVectorArrayMap eigen_tensor{
@@ -438,16 +439,17 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxChScalingFactor(
   }
   int output_channel_axis = is_transposed;
   int channels = dims[output_channel_axis];
-  LoDTensor scale_tensor = CreateScaleTensor(channels);
+  phi::DenseTensor scale_tensor = CreateScaleTensor(channels);
   auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
   std::copy(scales.data(), scales.data() + scales.size(), scale_ptr);
 
   return std::make_pair(is_unsigned, scale_tensor);
 }
 
-std::pair<bool, LoDTensor>
+std::pair<bool, phi::DenseTensor>
 AnalysisPredictor::MkldnnQuantizer::GetMaxChGRUScalingFactor(
-    const LoDTensor& wx_tensor, const LoDTensor& wh_tensor) const {
+    const phi::DenseTensor& wx_tensor,
+    const phi::DenseTensor& wh_tensor) const {
   check_tensor(wx_tensor);
   check_tensor(wh_tensor);
@@ -494,16 +496,17 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxChGRUScalingFactor(
   transform(scale_ur.begin(), scale_ur.end(), scale_ur.begin(), [](float& c) {
     return 1 / c;
   });
 
-  LoDTensor scale_tensor = CreateScaleTensor(scale_ur.size());
+  phi::DenseTensor scale_tensor = CreateScaleTensor(scale_ur.size());
   auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
   std::copy(scale_ur.begin(), scale_ur.end(), scale_ptr);
   bool is_unsigned = false;
   return std::make_pair(is_unsigned, scale_tensor);
 }
 
-std::pair<bool, LoDTensor>
+std::pair<bool, phi::DenseTensor>
 AnalysisPredictor::MkldnnQuantizer::GetMaxChLSTMScalingFactor(
-    const LoDTensor& wx_tensor, const LoDTensor& wh_tensor) const {
+    const phi::DenseTensor& wx_tensor,
+    const phi::DenseTensor& wh_tensor) const {
   check_tensor(wx_tensor);
   check_tensor(wh_tensor);
@@ -530,7 +533,7 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxChLSTMScalingFactor(
   transform(scale.begin(), scale.end(), scale.begin(), [](float& c) {
     return 1 / c;
   });
 
-  LoDTensor scale_tensor = CreateScaleTensor(scale.size());
+  phi::DenseTensor scale_tensor = CreateScaleTensor(scale.size());
   auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
   std::copy(scale.begin(), scale.end(), scale_ptr);
   bool is_unsigned = false;
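All the scale tensors above are 1-D double tensors built by `CreateScaleTensor`; `GetMaxScalingFactor`, for instance, stores the reciprocal of the largest absolute value. The core arithmetic, reduced to plain standalone C++ (illustrative only, no Paddle types):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// What GetMaxScalingFactor boils down to: scale = 1 / max|x|.
double MaxScalingFactor(const std::vector<float> &data) {
  assert(!data.empty());  // mirrors the PADDLE_ENFORCE_GT(numel, 0) above
  float max_abs = 0.0f;
  for (float v : data) max_abs = std::max(max_abs, std::fabs(v));
  return 1.0 / max_abs;
}
```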
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h
index d57ccb19aea9f88117401f620df72b01271763b1..5810768f12a4481743397a6b61b98e1a571e75b7 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -24,8 +24,6 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-using LoDTensor = phi::DenseTensor;
-
 /*
  * Convert Input from Fluid to TensorRT Engine.
  * Convert Output from TensorRT Engine to Fluid.
@@ -38,13 +36,17 @@ class EngineIOConverter {
  public:
   EngineIOConverter() {}
 
-  virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
-  virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {}
+  virtual void operator()(const phi::DenseTensor& in,
+                          void* out,
+                          size_t max_size) {}
+  virtual void operator()(const void* in,
+                          phi::DenseTensor* out,
+                          size_t max_size) {}
 
   void SetStream(cudaStream_t* stream) { stream_ = stream; }
 
   static void ConvertInput(const std::string& op_type,
-                           const LoDTensor& in,
+                           const phi::DenseTensor& in,
                            void* out,
                            size_t max_size,
                            cudaStream_t* stream) {
@@ -63,7 +65,7 @@ class EngineIOConverter {
 
   static void ConvertOutput(const std::string& op_type,
                             const void* in,
-                            LoDTensor* out,
+                            phi::DenseTensor* out,
                             size_t max_size,
                             cudaStream_t* stream) {
     PADDLE_ENFORCE_NOT_NULL(stream,
diff --git a/paddle/fluid/platform/device/ipu/ipu_utils.h b/paddle/fluid/platform/device/ipu/ipu_utils.h
index 5e93ce4bf9385c4e23c0eba0732ef2a268afae78..5f50a54cda2972f260d81cf23fd09a05917e5fcb 100644
--- a/paddle/fluid/platform/device/ipu/ipu_utils.h
+++ b/paddle/fluid/platform/device/ipu/ipu_utils.h
@@ -30,8 +30,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/float16.h"
 
 using float16 = paddle::platform::float16;
-using Tensor = phi::DenseTensor;
-using LoDTensor = phi::DenseTensor;
 using Scope = paddle::framework::Scope;
 using OpDesc = paddle::framework::OpDesc;
 using Graph = paddle::framework::ir::Graph;
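With the `operator()` overloads now taking `phi::DenseTensor` directly, a concrete converter overrides those signatures. A hedged sketch of such a subclass (the class itself is illustrative; the `cudaMemcpyAsync` pattern follows the default converter in io_converter.cc, and `stream_` is assumed to be the protected member set through `SetStream`):

```cpp
#include <cuda_runtime_api.h>

#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"

// Illustrative only: copy a host-side DenseTensor into a TensorRT device
// buffer on the stream configured via SetStream().
class ExampleIOConverter
    : public paddle::inference::tensorrt::EngineIOConverter {
 public:
  void operator()(const phi::DenseTensor& in,
                  void* out,
                  size_t max_size) override {
    size_t size = in.numel() * sizeof(float);
    if (size > max_size) return;  // the real converters PADDLE_ENFORCE this
    cudaMemcpyAsync(out, in.data<float>(), size, cudaMemcpyHostToDevice,
                    *stream_);
  }
};
```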
diff --git a/paddle/fluid/platform/device/npu/npu_op_runner.h b/paddle/fluid/platform/device/npu/npu_op_runner.h
index 49d1699cca2245ab8104942ee23c506e62e273a0..d5e373036c95b7326e1ce839e7ae97dd28964dc3 100644
--- a/paddle/fluid/platform/device/npu/npu_op_runner.h
+++ b/paddle/fluid/platform/device/npu/npu_op_runner.h
@@ -28,7 +28,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = phi::DenseTensor;
 using DataLayout = phi::DataLayout;
 using NPUAttribute = framework::NPUAttribute;
 using NPUAttributeMap = framework::NPUAttributeMap;
@@ -39,8 +38,8 @@ class NpuOpRunner {
   NpuOpRunner();
   explicit NpuOpRunner(const std::string &op_type);
   NpuOpRunner(const std::string &op_type,
-              const std::vector<Tensor> &inputs = {},
-              const std::vector<Tensor> &outputs = {},
+              const std::vector<phi::DenseTensor> &inputs = {},
+              const std::vector<phi::DenseTensor> &outputs = {},
               const NPUAttributeMap &attrs = {});
 
   // NOTE(zhiqiu): why forbid copy and operator= ?
@@ -67,12 +66,12 @@ class NpuOpRunner {
 
   NpuOpRunner &AddAttrs(const NPUAttributeMap &attrs);
 
-  NpuOpRunner &AddInput(const Tensor &tensor);
+  NpuOpRunner &AddInput(const phi::DenseTensor &tensor);
 
   // NOTE(zhiqiu): CANN-5.0.2 support input tensors on host.
   // Specifically, the tensor of shape, tensor of dims, etc, which are small
   // vector/list.
-  NpuOpRunner &AddInput(const Tensor &tensor, aclMemType mem_type);
+  NpuOpRunner &AddInput(const phi::DenseTensor &tensor, aclMemType mem_type);
 
   NpuOpRunner &AddInput(std::vector<int32_t> &&dims);
@@ -82,13 +81,13 @@ class NpuOpRunner {
 
   NpuOpRunner &AddInput(std::vector<float> &&values);
 
-  NpuOpRunner &AddOutput(const Tensor &tensor);
+  NpuOpRunner &AddOutput(const phi::DenseTensor &tensor);
 
-  NpuOpRunner &AddInputs(const std::vector<Tensor> &tensors);
+  NpuOpRunner &AddInputs(const std::vector<phi::DenseTensor> &tensors);
 
   NpuOpRunner &AddInputNames(const std::vector<std::string> &names);
 
-  NpuOpRunner &AddOutputs(const std::vector<Tensor> &tensors);
+  NpuOpRunner &AddOutputs(const std::vector<phi::DenseTensor> &tensors);
 
   aclTensorDesc *GetInputDesc(size_t index);
@@ -105,21 +104,21 @@ class NpuOpRunner {
   void Run(aclrtStream stream = nullptr) const;
 
   static void TypeAdapter(
-      const std::vector<Tensor> &inputs,
-      const std::vector<Tensor> &outputs,
+      const std::vector<phi::DenseTensor> &inputs,
+      const std::vector<phi::DenseTensor> &outputs,
       const NPUAttributeMap &attrs,
       const platform::NPUDeviceContext &dev_ctx,
-      std::function<void(const std::vector<Tensor> &,
-                         const std::vector<Tensor> &,
+      std::function<void(const std::vector<phi::DenseTensor> &,
+                         const std::vector<phi::DenseTensor> &,
                          const NPUAttributeMap &,
                          const platform::NPUDeviceContext &)> op_runner,
       const std::vector<framework::proto::VarType::Type> &input_type,
       const std::vector<framework::proto::VarType::Type> &output_type);
 
  private:
-  aclTensorDesc *CreateTensorDesc(Tensor tensor,
+  aclTensorDesc *CreateTensorDesc(phi::DenseTensor tensor,
                                   aclMemType mem_type = ACL_MEMTYPE_DEVICE);
-  aclDataBuffer *CreateDataBuffer(Tensor tensor);
+  aclDataBuffer *CreateDataBuffer(phi::DenseTensor tensor);
 
  private:
   std::string op_type_;
@@ -127,7 +126,7 @@ class NpuOpRunner {
   std::vector<aclDataBuffer *> output_buffers_;
   std::vector<aclTensorDesc *> input_descs_;
   std::vector<aclTensorDesc *> output_descs_;
-  std::vector<Tensor> host_tensors_;
+  std::vector<phi::DenseTensor> host_tensors_;
   aclopAttr *attr_{nullptr};
 };
@@ -136,7 +135,7 @@ aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype);
 aclrtStream GetCurrentNPUStream(int device_id = -1);
 
 template <typename T>
-void FillNpuTensorWithConstant(Tensor *tensor, T val) {
+void FillNpuTensorWithConstant(phi::DenseTensor *tensor, T val) {
   PADDLE_ENFORCE_EQ(
       tensor->IsInitialized(),
       true,
@@ -148,7 +147,7 @@ void FillNpuTensorWithConstant(Tensor *tensor, T val) {
   int numel = tensor->numel();
 
   if (numel == 1) {
-    Tensor npu_pinned_tensor(tensor->dtype());
+    phi::DenseTensor npu_pinned_tensor(tensor->dtype());
     platform::NPUPinnedPlace npu_pinned_place;
     auto npu_pinned_ptr =
         npu_pinned_tensor.mutable_data<T>({1}, npu_pinned_place);
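After the rename, call sites build the runner from `phi::DenseTensor` values directly. A hedged sketch (the CANN op type and tensors are illustrative; assumes initialized device tensors and a valid NPU stream):

```cpp
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

// Illustrative only: run the CANN "Add" op on two device tensors.
void RunAddExample(const phi::DenseTensor &x,
                   const phi::DenseTensor &y,
                   phi::DenseTensor *out,
                   aclrtStream stream) {
  paddle::operators::NpuOpRunner runner("Add", {x, y}, {*out}, {});
  runner.Run(stream);
}
```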