Unverified commit 91dd8a2e, authored by: 张春乔, committed by: GitHub

Replace LoDTensor with phi::DenseTensor in fluid\operators (#48417)

* replace LoDTensor with phi::DenseTensor in fluid\operators

* replace LoDTensor with phi::DenseTensor in fluid\operators

* Update split_lod_tensor_op.cc

* Update warpctc_op.cc

* Update broadcast_tensors_op.cc

* Update crf_decoding_op.cc

* Update lstm_op.cc

* Update lstm_op.cc

* Update lod_reset_op.cc

* Update gru_op.cc

* Update linear_chain_crf_op.cc

* resume 2 files for conflict

* Update gru_op.cc

* Update linear_chain_crf_op.cc

* Update lstm_op.cc
Parent 30315ac9
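The change repeated throughout the diff below is mechanical: the per-file `using LoDTensor = phi::DenseTensor;` alias is deleted, every `Input<LoDTensor>` / `Output<LoDTensor>` (and `Get<LoDTensor>` / `GetMutable<LoDTensor>`) access is retyped to `phi::DenseTensor` directly, and the operator docstrings are reworded to match. A minimal sketch of what an affected kernel looks like after the change follows; the kernel name, the "X"/"Out" slots, and the chosen headers are illustrative assumptions, not code from this PR.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {
namespace operators {

// Hypothetical kernel showing the post-refactor style: no local
// `using LoDTensor = phi::DenseTensor;` alias, the phi type is spelled out
// at every ExecutionContext access.
template <typename T>
class ExampleCopyKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Before this PR:  auto* x = ctx.Input<LoDTensor>("X");
    auto* x = ctx.Input<phi::DenseTensor>("X");
    // Before this PR:  auto* out = ctx.Output<LoDTensor>("Out");
    auto* out = ctx.Output<phi::DenseTensor>("Out");
    // LoD metadata still lives on phi::DenseTensor, so lod() / set_lod()
    // keep working unchanged after the retyping.
    out->set_lod(x->lod());
    framework::TensorCopySync(*x, ctx.GetPlace(), out);
  }
};

}  // namespace operators
}  // namespace paddle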
......@@ -157,7 +157,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
return table_items[a].index < table_items[b].index;
});
// Build LoDTensor `out`
// Build phi::DenseTensor `out`
framework::LoD *out_lod = out->mutable_lod();
out_lod->clear();
auto prefix_lod = rank_table.coarse_lod();
......@@ -215,16 +215,18 @@ class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor.");
"be casted to a big phi::DenseTensor.");
AddInput("RankTable",
"(LoDRankTable) RankTable provides the coarse lod information to "
"build the output LoDTensor. See "
"build the output phi::DenseTensor. See "
"'paddle/framework/lod_rank_table.h' for more details.");
AddOutput("Out", "(LoDTensor) The LoDTensor formed by input tensor array.");
AddOutput("Out",
"(phi::DenseTensor) The phi::DenseTensor formed by input tensor "
"array.");
AddComment(
R"DOC(This Op build a big LoDTensor from a std::vector<LoDTensor>
R"DOC(This Op build a big phi::DenseTensor from a std::vector<phi::DenseTensor>
and a LoDRankTable. It is supposed to be used in getting dynamic RNN's
outputs back to a normal LoDTensor. The std::vector<LoDTensor>
outputs back to a normal phi::DenseTensor. The std::vector<phi::DenseTensor>
would be the output of RNN Op and the LoDRankTable would be build
with RNN's input.)DOC");
}
......@@ -247,9 +249,9 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
// detail kernel implementation.
context->SetOutputDim("Out", context->GetInputDim("X"));
// The output LoDTensor's lod_level should be input X's lod_level + 1.
// For compile-time, we call SetLoDLevel to set output's lod_level.
// For runtime, output LoDTensor's lod is determined by input X's lod and
// The output phi::DenseTensor's lod_level should be input X's lod_level
// + 1. For compile-time, we call SetLoDLevel to set output's lod_level. For
// runtime, output phi::DenseTensor's lod is determined by input X's lod and
// the level specified by input RandTable.
// We cannot get X's detail lod and RankTable's level in this function, so
// leave this work to the detail kernel implementation.
......
......@@ -41,8 +41,6 @@ const char kSummarize[] = "summarize";
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
class AssertOp : public framework::OperatorBase {
public:
AssertOp(const std::string &type,
......@@ -58,7 +56,7 @@ class AssertOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL(cond_var_ptr,
platform::errors::NotFound(
"Input(Condition) of AssertOp is not found."));
const LoDTensor &cond = cond_var_ptr->Get<LoDTensor>();
const phi::DenseTensor &cond = cond_var_ptr->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(),
phi::make_ddim({1}),
......@@ -78,7 +76,7 @@ class AssertOp : public framework::OperatorBase {
const std::vector<std::string> &x_names = Inputs(kData);
for (const std::string &name : x_names) {
const framework::Variable *x_var_ptr = scope.FindVar(name);
const phi::DenseTensor &x_tensor = x_var_ptr->Get<LoDTensor>();
const phi::DenseTensor &x_tensor = x_var_ptr->Get<phi::DenseTensor>();
formatter.Print(x_tensor, name);
}
......
......@@ -79,16 +79,19 @@ class AssignInferVarType : public framework::VarTypeInference {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
AddInput(
"X",
"(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The input "
"variable "
"could be phi::DenseTensor, SelectedRows or phi::DenseTensorArray.")
.AsDispensable();
AddOutput("Out",
"(LoDTensor, SelectedRows or LoDTensorArray) The type of output "
"(phi::DenseTensor, SelectedRows or phi::DenseTensorArray) The "
"type of output "
"is the same as input X.");
AddComment(R"DOC(Assign Operator
Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray]
Out = X, when type in [phi::DenseTensor/SelectedRows/phi::DenseTensorArray]
raise error if the type is not listed above.
)DOC");
}
......
......@@ -59,13 +59,14 @@ class AssignPosCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
// assign pos decides which tokens should be fetched belong to specially
// counter orderingly.
auto cum_count = context.Input<LoDTensor>(
auto cum_count = context.Input<phi::DenseTensor>(
"cum_count"); // (counter number) int32 | int64
auto numbers =
context.Input<LoDTensor>("X"); // (batch_size * seq_len, topk) int32
auto numbers = context.Input<phi::DenseTensor>(
"X"); // (batch_size * seq_len, topk) int32
auto eff_num_len =
context.Input<LoDTensor>("eff_num_len"); // (sum(cum_count))
auto out = context.Output<LoDTensor>("Out"); // (cum_count) value ranges
context.Input<phi::DenseTensor>("eff_num_len"); // (sum(cum_count))
auto out =
context.Output<phi::DenseTensor>("Out"); // (cum_count) value ranges
// from 0 to batch_size *
// seq_len * topk
auto place = context.GetPlace();
......
......@@ -20,8 +20,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
template <typename T>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
......
......@@ -205,11 +205,12 @@ framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
}
void AttentionLSTMOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput(
"X",
"(phi::DenseTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this phi::DenseTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("C0",
"(Tensor) LSTM C0"
"This is a tensor with shape (N x D), where N is the batch size, D "
......@@ -247,12 +248,14 @@ void AttentionLSTMOpMaker::Make() {
"Note: we should add the bias of hidden and context accorindg to "
"the same gate: "
"{B_forget, B_input, B_output, B_cell}");
AddOutput("Hidden",
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput(
"Hidden",
"(phi::DenseTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput(
"Cell",
"(phi::DenseTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("AttentionedX",
"(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
......@@ -339,7 +342,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = phi::CPUContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* h0 = ctx.Input<phi::DenseTensor>("H0");
auto* c0 = ctx.Input<phi::DenseTensor>("C0");
auto* atten_w = ctx.Input<phi::DenseTensor>("AttentionWeight");
......@@ -350,8 +353,8 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto* lstm_w = ctx.Input<phi::DenseTensor>("LSTMWeight");
auto* lstm_b = ctx.Input<phi::DenseTensor>("LSTMBias");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* hidden_out = ctx.Output<phi::DenseTensor>("Hidden");
auto* cell_out = ctx.Output<phi::DenseTensor>("Cell");
auto* atted_x = ctx.Output<phi::DenseTensor>("AttentionedX");
auto* fc_out = ctx.Output<phi::DenseTensor>("AttentionFCOut");
auto* lstm_x = ctx.Output<phi::DenseTensor>("LSTMX");
......
......@@ -18,7 +18,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
class AttentionLSTMOp : public framework::OperatorWithKernel {
......
......@@ -383,8 +383,8 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
......@@ -525,8 +525,8 @@ framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType(
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
......
......@@ -28,7 +28,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
template <typename T>
......
......@@ -23,8 +23,8 @@ namespace operators {
struct BeamSearchDecodeFunctor {
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor,
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor,
size_t beam_size,
int end_id)
: beam_size_(beam_size),
......@@ -119,8 +119,8 @@ struct BeamSearchDecodeFunctor {
const LoDTensorArray& step_scores_origin_;
LoDTensorArray step_ids_ = LoDTensorArray();
LoDTensorArray step_scores_ = LoDTensorArray();
LoDTensor* id_tensor_;
LoDTensor* score_tensor_;
phi::DenseTensor* id_tensor_;
phi::DenseTensor* score_tensor_;
};
template <typename DeviceContext, typename T>
......@@ -164,8 +164,10 @@ class BeamSearchDecodeOpKernel : public framework::OpKernel<T> {
int end_id = context.Attr<int>("end_id");
// prepare output
LoDTensor* sentenceIds = context.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = context.Output<LoDTensor>("SentenceScores");
phi::DenseTensor* sentenceIds =
context.Output<phi::DenseTensor>("SentenceIds");
phi::DenseTensor* sentenceScores =
context.Output<phi::DenseTensor>("SentenceScores");
BeamSearchDecodeFunctor bs(
*ids, *scores, sentenceIds, sentenceScores, beam_size, end_id);
......
......@@ -23,7 +23,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using LoDTensorArray = framework::LoDTensorArray;
// all the lod have 2 levels.
......@@ -54,15 +53,15 @@ struct BeamSearchDecoder {
* with word score.
* Param:
* sentence_vector_list: sentence_vector for each source sentence.
* id_tensor: result LoDTensor for sentences of id.
* score_tensor: result LoDTensor for sentences of score.
* id_tensor: result phi::DenseTensor for sentences of id.
* score_tensor: result phi::DenseTensor for sentences of score.
* reverse: whether ids of sentence in sentence_vector_list is reversed
* sort_by_score: whether to sort hypotheses of each sentence by scores.
*/
void ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list,
LoDTensor* id_tensor,
LoDTensor* score_tensor,
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor,
bool reverse = true,
bool sort_by_score = true) const;
......@@ -72,8 +71,8 @@ struct BeamSearchDecoder {
*/
void Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor) const;
size_t beam_size_;
int end_id_;
......@@ -82,8 +81,8 @@ struct BeamSearchDecoder {
template <typename T>
void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list,
LoDTensor* id_tensor,
LoDTensor* score_tensor,
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor,
bool reverse,
bool sort_by_score) const {
size_t src_num = sentence_vector_list.size();
......@@ -158,8 +157,8 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
template <typename T>
void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor) const {
PADDLE_ENFORCE_NE(
step_ids.empty(),
true,
......
......@@ -18,7 +18,6 @@ limitations under the License. */
using CPUPlace = paddle::platform::CPUPlace;
using LoD = paddle::framework::LoD;
using LoDTensor = phi::DenseTensor;
using LoDTensorArray = paddle::framework::LoDTensorArray;
template <typename T>
......@@ -59,7 +58,7 @@ void GenerateExample(const std::vector<size_t>& level_0,
lod.push_back(level_1);
// Ids
LoDTensor tensor_id;
phi::DenseTensor tensor_id;
tensor_id.set_lod(lod);
tensor_id.Resize({static_cast<int64_t>(data.size())});
// malloc memory
......@@ -69,7 +68,7 @@ void GenerateExample(const std::vector<size_t>& level_0,
}
// Scores
LoDTensor tensor_score;
phi::DenseTensor tensor_score;
tensor_score.set_lod(lod);
tensor_score.Resize({static_cast<int64_t>(data.size())});
// malloc memory
......@@ -124,8 +123,8 @@ void BeamSearchDecodeTestFrame() {
BeamSearchDecoder<T> helper(2, 1); // beam_size = 2, end_id = 1
LoDTensor id_tensor;
LoDTensor score_tensor;
phi::DenseTensor id_tensor;
phi::DenseTensor score_tensor;
helper.Backtrace(ids, scores, &id_tensor, &score_tensor);
LoD lod = id_tensor.lod();
......
......@@ -62,20 +62,21 @@ class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
int end_id = context.Attr<int>("end_id");
// prepare output
LoDTensor* sentenceIds = nullptr;
LoDTensor* sentenceScores = nullptr;
phi::DenseTensor* sentenceIds = nullptr;
phi::DenseTensor* sentenceScores = nullptr;
LoDTensor* sentenceIds_temp = context.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores_temp =
context.Output<LoDTensor>("SentenceScores");
phi::DenseTensor* sentenceIds_temp =
context.Output<phi::DenseTensor>("SentenceIds");
phi::DenseTensor* sentenceScores_temp =
context.Output<phi::DenseTensor>("SentenceScores");
if (platform::is_xpu_place(ids->at(0).place())) {
sentenceIds = new LoDTensor();
sentenceIds = new phi::DenseTensor();
sentenceIds->set_lod(sentenceIds_temp->lod());
}
if (platform::is_xpu_place(ids->at(0).place())) {
sentenceScores = new LoDTensor();
sentenceScores = new phi::DenseTensor();
sentenceScores->set_lod(sentenceScores_temp->lod());
}
......
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
int SetMeta(const LoDTensor& srcTensor, LoDTensor* dstTensor) {
int SetMeta(const phi::DenseTensor& srcTensor, phi::DenseTensor* dstTensor) {
if (srcTensor.dtype() == paddle::experimental::DataType::INT32 ||
srcTensor.dtype() == paddle::experimental::DataType::INT64 ||
srcTensor.dtype() == paddle::experimental::DataType::FLOAT32 ||
......@@ -33,8 +33,8 @@ int SetMeta(const LoDTensor& srcTensor, LoDTensor* dstTensor) {
return xpu::Error_t::SUCCESS;
}
template <typename T>
int CopyTensorByXPU(const LoDTensor& srcTensor,
LoDTensor* dstTensor,
int CopyTensorByXPU(const phi::DenseTensor& srcTensor,
phi::DenseTensor* dstTensor,
int flag,
const Place& place) {
const T* srcData = srcTensor.template data<T>();
......@@ -67,8 +67,8 @@ int CopyTensorByXPU(const LoDTensor& srcTensor,
return xpu::Error_t::SUCCESS;
}
const int CopyTensorByType(const LoDTensor& srcTensor,
LoDTensor* dstTensor,
const int CopyTensorByType(const phi::DenseTensor& srcTensor,
phi::DenseTensor* dstTensor,
int flag,
const Place& place) {
int r = 0;
......@@ -97,8 +97,8 @@ const int CopyTensorByType(const LoDTensor& srcTensor,
struct BeamSearchDecodeXPUFunctor {
BeamSearchDecodeXPUFunctor(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor,
phi::DenseTensor* id_tensor,
phi::DenseTensor* score_tensor,
size_t beam_size,
int end_id)
: beam_size_(beam_size),
......@@ -164,8 +164,8 @@ struct BeamSearchDecodeXPUFunctor {
// scenarios.
LoDTensorArray step_ids_ = LoDTensorArray();
LoDTensorArray step_scores_ = LoDTensorArray();
LoDTensor* id_tensor_;
LoDTensor* score_tensor_;
phi::DenseTensor* id_tensor_;
phi::DenseTensor* score_tensor_;
};
} // namespace operators
......
......@@ -19,7 +19,6 @@ limitations under the License. */
using CPUPlace = paddle::platform::CPUPlace;
using XPUPlace = paddle::platform::XPUPlace;
using LoD = paddle::framework::LoD;
using LoDTensor = phi::DenseTensor;
using LoDTensorArray = paddle::framework::LoDTensorArray;
template <typename T>
......@@ -67,7 +66,7 @@ void GenerateXPUExample(const std::vector<size_t>& level_0,
lod.push_back(level_1);
// Ids
LoDTensor tensor_id_cpu;
phi::DenseTensor tensor_id_cpu;
tensor_id_cpu.set_lod(lod);
tensor_id_cpu.Resize({static_cast<int64_t>(data.size())});
// malloc memory
......@@ -76,7 +75,7 @@ void GenerateXPUExample(const std::vector<size_t>& level_0,
id_cpu_ptr[i] = static_cast<int64_t>(data.at(i));
}
LoDTensor tensor_id;
phi::DenseTensor tensor_id;
const phi::DenseTensorMeta meta_data_id(paddle::experimental::DataType::INT64,
tensor_id_cpu.dims());
tensor_id.set_meta(meta_data_id);
......@@ -90,7 +89,7 @@ void GenerateXPUExample(const std::vector<size_t>& level_0,
tensor_id_cpu.numel() * sizeof(int64_t));
// Scores
LoDTensor tensor_score_cpu;
phi::DenseTensor tensor_score_cpu;
tensor_score_cpu.set_lod(lod);
tensor_score_cpu.Resize({static_cast<int64_t>(data.size())});
// malloc memory
......@@ -99,7 +98,7 @@ void GenerateXPUExample(const std::vector<size_t>& level_0,
score_cpu_ptr[i] = static_cast<T>(data.at(i));
}
LoDTensor tensor_score;
phi::DenseTensor tensor_score;
if (std::is_same<float, T>::value) {
const phi::DenseTensorMeta meta_data_score(
......@@ -178,8 +177,8 @@ void BeamSearchDecodeTestByXPUFrame() {
ASSERT_EQ(ids.size(), 5UL);
ASSERT_EQ(scores.size(), 5UL);
LoDTensor id_tensor_cpu;
LoDTensor score_tensor_cpu;
phi::DenseTensor id_tensor_cpu;
phi::DenseTensor score_tensor_cpu;
paddle::operators::BeamSearchDecodeXPUFunctor bs_xpu(
ids, scores, &id_tensor_cpu, &score_tensor_cpu, 2, 1);
......
......@@ -27,37 +27,42 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids",
"(LoDTensor) The LoDTensor containing the selected ids at the "
"(phi::DenseTensor) The phi::DenseTensor containing the selected "
"ids at the "
"previous step. It should be a tensor with shape (batch_size, 1) "
"and lod `[[0, 1, ... , batch_size], [0, 1, ..., batch_size]]` at "
"the first step.");
AddInput("pre_scores",
"(LoDTensor) The LoDTensor containing the accumulated "
"scores corresponding to the selected ids at the previous step.");
AddInput(
"pre_scores",
"(phi::DenseTensor) The phi::DenseTensor containing the accumulated "
"scores corresponding to the selected ids at the previous step.");
AddInput("ids",
"(LoDTensor) The LoDTensor containing the candidates ids. Its "
"(phi::DenseTensor) The phi::DenseTensor containing the "
"candidates ids. Its "
"shape should be (batch_size * beam_size, W). If not set, it will "
"be calculated out according to Input(scores) in this operator.")
.AsDispensable();
AddInput("scores",
"(LoDTensor) The LoDTensor containing the current scores "
"corresponding to Input(ids). If Input(ids) is not nullptr, its "
"shape is the same as that of Input(ids)."
"If is_accumulated is true, Input(scores) is accumulated scores "
"and will be used derectedly. Else, each score will be "
"transformed to the log field and accumulate Input(pre_sores) "
"first.");
AddInput(
"scores",
"(phi::DenseTensor) The phi::DenseTensor containing the current scores "
"corresponding to Input(ids). If Input(ids) is not nullptr, its "
"shape is the same as that of Input(ids)."
"If is_accumulated is true, Input(scores) is accumulated scores "
"and will be used derectedly. Else, each score will be "
"transformed to the log field and accumulate Input(pre_sores) "
"first.");
AddOutput("selected_ids",
"A LodTensor that stores the IDs selected by beam search.");
AddOutput("selected_scores",
"A LoDTensor containing the accumulated scores corresponding to "
"Output(selected_ids).");
AddOutput(
"selected_scores",
"A phi::DenseTensor containing the accumulated scores corresponding to "
"Output(selected_ids).");
AddOutput("parent_idx",
"A Tensor preserving the selected_ids' parent index in pre_ids.")
.AsDispensable();
// Attributes stored in AttributeMap
AddAttr<int>("level", "the level of LoDTensor");
AddAttr<int>("level", "the level of phi::DenseTensor");
AddAttr<int>("beam_size", "beam size for beam search");
AddAttr<int>("end_id",
"the token id which indicates the end of a sequence");
......
......@@ -41,12 +41,13 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
class BroadcastTensorsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"A Varaible list. The shape and data type of the list elements"
"should be consistent. Variable can be multi-dimensional Tensor"
"or LoDTensor, and data types can be: bool, float16, float32, "
"float64, int32, "
"int64.")
AddInput(
"X",
"A Varaible list. The shape and data type of the list elements"
"should be consistent. Variable can be multi-dimensional Tensor"
"or phi::DenseTensor, and data types can be: bool, float16, float32, "
"float64, int32, "
"int64.")
.AsDuplicable();
AddOutput("Out",
"the sum of input :code:`x`. its shape and data types are "
......@@ -54,7 +55,7 @@ class BroadcastTensorsOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddComment(
R"DOC(This OP is used to broadcast a vector of inputs
with Tensor or LoDTensor type, following broadcast semantics.)DOC");
with phi::DenseTensor type, following broadcast semantics.)DOC");
}
};
......
......@@ -31,12 +31,13 @@ class CheckMemoryContinueOp : public framework::OperatorWithKernel {
class CheckMemoryContinueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(vector<LoDTensor>) The input tensors.").AsDuplicable();
AddOutput("Out", "(LoDTensor) The output tensor.").AsDuplicable();
AddOutput(
"XOut",
"(vector<LoDTensor>) The output tensors which are the same as x. It is "
"used to build the graph dependency");
AddInput("X", "(vector<phi::DenseTensor>) The input tensors.")
.AsDuplicable();
AddOutput("Out", "(phi::DenseTensor) The output tensor.").AsDuplicable();
AddOutput("XOut",
"(vector<phi::DenseTensor>) The output tensors which are the "
"same as x. It is "
"used to build the graph dependency");
AddComment(R"DOC(
CheckMemoryContinue Operator.
......
......@@ -23,8 +23,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class ChunkEvalKernel : public framework::OpKernel<T> {
public:
......@@ -187,9 +185,9 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
context.Attr<std::vector<int>>("excluded_chunk_types").begin(),
context.Attr<std::vector<int>>("excluded_chunk_types").end());
auto* inference = context.Input<LoDTensor>("Inference");
auto* inference = context.Input<phi::DenseTensor>("Inference");
auto place = inference->place();
auto* label = context.Input<LoDTensor>("Label");
auto* label = context.Input<phi::DenseTensor>("Label");
auto* precision = context.Output<phi::DenseTensor>("Precision");
auto* recall = context.Output<phi::DenseTensor>("Recall");
auto* f1 = context.Output<phi::DenseTensor>("F1-Score");
......
......@@ -120,7 +120,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
in_var_names.size(),
out_var_names.size()));
// Input & Output check: only support LoDTensor
// Input & Output check: only support phi::DenseTensor
bool has_not_init_in_vars = false;
for (size_t i = 0; i < in_tensors.size(); ++i) {
PADDLE_ENFORCE_NOT_NULL(
......@@ -426,17 +426,17 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(vector<LoDTensor>) The input tensors of"
"(vector<phi::DenseTensor>) The input tensors of"
" coalesce_tensor operator.")
.AsDuplicable();
AddOutput("Output",
"(vector<LoDTensor>) The output "
"(vector<phi::DenseTensor>) The output "
"tensors of coalesce_tensor operator. And the address "
"of output tensors are continuous, they are sliced from the "
"tensor of FusedOutput.")
.AsDuplicable();
AddOutput("FusedOutput",
"(LoDTensor) The output tensor "
"(phi::DenseTensor) The output tensor "
"of coalesce_tensor operator. And the tensors of"
" Output is sliced from the tensor of FusedOutput.");
AddAttr<int>("dtype", "The output data type.");
......
......@@ -154,7 +154,7 @@ void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) {
ctx->GetInputsVarType(y_name).front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The var type of input %s should be LoDTensor, but got %s.",
"The var type of input %s should be phi::DenseTensor, but got %s.",
ctx->Inputs(y_name).front(),
ctx->GetInputsVarType(y_name).front()));
......
......@@ -30,7 +30,6 @@ class OpBase;
} // namespace imperative
} // namespace paddle
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
namespace paddle {
......@@ -64,7 +63,7 @@ class CopyCrossScopeOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL(
id_var,
platform::errors::NotFound("No variable with name %s found.", id_name));
auto id_tensor = id_var->GetMutable<LoDTensor>();
auto id_tensor = id_var->GetMutable<phi::DenseTensor>();
auto it = scope.kids().begin();
phi::DenseTensor cpu_id_tensor;
paddle::framework::TensorCopySync(
......@@ -88,8 +87,8 @@ class CopyCrossScopeOp : public framework::OperatorBase {
platform::errors::NotFound(
"No variable with name %s found in destination scope.",
x_name));
auto dst_tensor = dst_var->GetMutable<LoDTensor>();
auto main_tensor = main_var->GetMutable<LoDTensor>();
auto dst_tensor = dst_var->GetMutable<phi::DenseTensor>();
auto main_tensor = main_var->GetMutable<phi::DenseTensor>();
paddle::framework::TensorCopySync(
*dst_tensor, main_tensor->place(), main_tensor);
}
......@@ -109,8 +108,8 @@ class CopyCrossScopeOp : public framework::OperatorBase {
dst_var,
platform::errors::NotFound(
"No variable with name %s found in destination scope.", x_name));
auto src_tensor = source_var->GetMutable<LoDTensor>();
auto dst_tensor = dst_var->GetMutable<LoDTensor>();
auto src_tensor = source_var->GetMutable<phi::DenseTensor>();
auto dst_tensor = dst_var->GetMutable<phi::DenseTensor>();
paddle::framework::TensorCopySync(
*src_tensor, dst_tensor->place(), dst_tensor);
......@@ -120,7 +119,7 @@ class CopyCrossScopeOp : public framework::OperatorBase {
main_var,
platform::errors::NotFound(
"No variable with name %s found in destination scope.", x_name));
auto main_tensor = main_var->GetMutable<LoDTensor>();
auto main_tensor = main_var->GetMutable<phi::DenseTensor>();
paddle::framework::TensorCopySync(
*dst_tensor, main_tensor->place(), main_tensor);
}
......
......@@ -21,7 +21,8 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput(
"Emission",
"(Tensor/LoDTensor). For a LoDTensor input, its shape is [N x D] "
"(Tensor/phi::DenseTensor). For a phi::DenseTensor input, its shape is "
"[N x D] "
"where N is the total sequence length of the mini-batch and D is "
"the total tag number. While for a tensor input, its shape is "
"[B X S X D] with B the batch size and S the sequence length of each "
......@@ -39,14 +40,14 @@ class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
"The data type is the same as Input(Emission).");
AddInput(
"Label",
"(Tensor/LoDTensor). The ground truth with shape "
"[N x 1] (for LoDTensor) or [B x S] (for Tensor). This input is "
"(phi::DenseTensor). The ground truth with shape "
"[N x 1] (for phi::DenseTensor) or [B x S] (for Tensor). This input is "
"optional. See more details in the operator's comments. The data type "
"is int64.")
.AsDispensable();
AddOutput(
"ViterbiPath",
"(Tensor/LoDTensor). The decoding results. What to "
"(phi::DenseTensor). The decoding results. What to "
"return changes depending on whether the Input(Label) (the ground "
"truth) is given. See more details in the operator's comment. "
"The data type is int64.");
......
......@@ -24,15 +24,14 @@ namespace paddle {
namespace operators {
using framework::LoD;
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
auto* emission_weights = ctx.Input<phi::DenseTensor>("Emission");
auto* transition_weights = ctx.Input<phi::DenseTensor>("Transition");
auto* label = ctx.Input<LoDTensor>("Label");
auto* label = ctx.Input<phi::DenseTensor>("Label");
auto* decoded_path = ctx.Output<phi::DenseTensor>("ViterbiPath");
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
......
......@@ -85,8 +85,8 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"CTCAlign operator CUDA kernel must use CUDAPlace "
"rather than CPUPlace."));
auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output");
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* output = ctx.Output<phi::DenseTensor>("Output");
const int blank = ctx.Attr<int>("blank");
const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated"));
......@@ -99,9 +99,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
auto input_dims = input->dims();
T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
ctx.GetPlace());
auto* input_length = ctx.Input<LoDTensor>("InputLength");
auto* input_length = ctx.Input<phi::DenseTensor>("InputLength");
const T* input_length_data = input_length->data<T>();
auto* output_length = ctx.Output<LoDTensor>("OutputLength");
auto* output_length = ctx.Output<phi::DenseTensor>("OutputLength");
T* output_length_data =
output_length->mutable_data<T>({input_dims[0], 1}, ctx.GetPlace());
PaddingMergeAndDelCudaKernel<T>
......
......@@ -25,14 +25,13 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class CTCAlignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output");
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* output = ctx.Output<phi::DenseTensor>("Output");
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
......@@ -43,10 +42,10 @@ class CTCAlignKernel : public framework::OpKernel<T> {
if (input->lod().empty()) {
size_t padding_value =
static_cast<size_t>(ctx.Attr<int>("padding_value"));
auto* input_length = ctx.Input<LoDTensor>("InputLength");
auto* input_length = ctx.Input<phi::DenseTensor>("InputLength");
const T* input_length_data = input_length->data<T>();
auto* output_length = ctx.Output<LoDTensor>("OutputLength");
auto* output_length = ctx.Output<phi::DenseTensor>("OutputLength");
T* output_length_data = output_length->mutable_data<T>(ctx.GetPlace());
for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
......
......@@ -26,7 +26,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename T, typename Type>
......
......@@ -23,7 +23,6 @@ namespace operators {
using phi::PADDLE_CUDA_NUM_THREADS;
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
__global__ void CvmComputeKernel(const bool use_cvm,
......@@ -87,7 +86,7 @@ template <typename T>
class CVMCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<LoDTensor>("X");
const auto* x = context.Input<phi::DenseTensor>("X");
const T* x_data = x->data<T>();
auto batch_size = x->dims()[0];
......@@ -95,7 +94,7 @@ class CVMCUDAKernel : public framework::OpKernel<T> {
auto item_size = numel / batch_size;
auto use_cvm = context.Attr<bool>("use_cvm");
auto* y = context.Output<LoDTensor>("Y");
auto* y = context.Output<phi::DenseTensor>("Y");
T* y_data = y->mutable_data<T>(context.GetPlace());
// for Input X do not have Lod Information.
......@@ -128,7 +127,7 @@ template <typename T>
class CVMGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* dx = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
const phi::DenseTensor* cvm = context.Input<phi::DenseTensor>("CVM");
......
......@@ -20,7 +20,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
void CvmComputeKernel(const bool use_cvm,
......@@ -61,14 +60,14 @@ template <typename T>
class CVMOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<LoDTensor>("X");
const auto* x = context.Input<phi::DenseTensor>("X");
const T* x_data = x->data<T>();
auto batch_size = x->dims()[0];
auto item_size = x->numel() / batch_size;
auto use_cvm = context.Attr<bool>("use_cvm");
auto* y = context.Output<LoDTensor>("Y");
auto* y = context.Output<phi::DenseTensor>("Y");
T* y_data = y->mutable_data<T>(context.GetPlace());
// for Input X do not have Lod Information.
......@@ -102,7 +101,7 @@ template <typename T>
class CVMGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* dx = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
const phi::DenseTensor* cvm = context.Input<phi::DenseTensor>("CVM");
......
......@@ -24,7 +24,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
template <typename T>
......@@ -487,8 +486,8 @@ class DataNormGradOp : public framework::OperatorWithKernel {
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -27,7 +27,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
using phi::PADDLE_CUDA_NUM_THREADS;
......
......@@ -33,9 +33,9 @@ class DeformablePSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"H is height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(LoDTensor), "
"(phi::DenseTensor), "
"ROIs (Regions of Interest) to pool over. "
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"ROIs should be a 2-D phi::DenseTensor of shape (num_rois, 4) "
"given as [[x1, y1, x2, y2], ...]. "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
......@@ -149,7 +149,8 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel {
rois_dims.size(),
2,
platform::errors::InvalidArgument(
"Input(ROIs) should be a 2-D LoDTensor of shape (num_rois, 4) "
"Input(ROIs) should be a 2-D phi::DenseTensor of shape (num_rois, "
"4) "
"given as [[ x1, y1, x2, y2], ...]. The rank of Input(ROIs) should "
"be 2, but received ROIs rank is:%d, ROIs shape is:[%s].",
rois_dims.size(),
......
......@@ -40,7 +40,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using phi::PADDLE_CUDA_NUM_THREADS;
static inline int GET_BLOCKS(const int N) {
......@@ -185,7 +184,7 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const phi::DenseTensor* input = ctx.Input<phi::DenseTensor>("Input");
const LoDTensor* rois = ctx.Input<LoDTensor>("ROIs");
const phi::DenseTensor* rois = ctx.Input<phi::DenseTensor>("ROIs");
const phi::DenseTensor* trans = ctx.Input<phi::DenseTensor>("Trans");
phi::DenseTensor* out = ctx.Output<phi::DenseTensor>("Output");
out->mutable_data<T>(ctx.GetPlace());
......@@ -486,7 +485,7 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const phi::DenseTensor* input = ctx.Input<phi::DenseTensor>("Input");
const LoDTensor* rois = ctx.Input<LoDTensor>("ROIs");
const phi::DenseTensor* rois = ctx.Input<phi::DenseTensor>("ROIs");
const phi::DenseTensor* trans = ctx.Input<phi::DenseTensor>("Trans");
const phi::DenseTensor* top_count = ctx.Input<phi::DenseTensor>("TopCount");
const phi::DenseTensor* output_grad =
......
......@@ -34,7 +34,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
T bilinear_interp(
......@@ -80,7 +79,7 @@ void DeformablePSROIPoolForwardCPUKernel(const int count,
T* top_count,
const int batch_size,
int* roi_batch_id_data,
const LoDTensor* rois) {
const phi::DenseTensor* rois) {
for (int ix = 0; ix < count; ix++) {
int pw = ix % pooled_width;
int ph = (ix / pooled_width) % pooled_height;
......@@ -174,7 +173,7 @@ class DeformablePSROIPoolCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* rois = ctx.Input<phi::DenseTensor>("ROIs");
auto* trans = ctx.Input<phi::DenseTensor>("Trans");
auto* out = ctx.Output<phi::DenseTensor>("Output");
out->mutable_data<T>(ctx.GetPlace());
......@@ -316,7 +315,7 @@ void DeformablePSROIPoolBackwardAccCPUKernel(const int count,
const int channels_each_class,
const int batch_size,
int* roi_batch_id_data,
const LoDTensor* rois) {
const phi::DenseTensor* rois) {
for (int index = 0; index < count; index++) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
......@@ -476,7 +475,7 @@ class DeformablePSROIPoolGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* rois = ctx.Input<phi::DenseTensor>("ROIs");
auto* trans = ctx.Input<phi::DenseTensor>("Trans");
auto* top_count = ctx.Input<phi::DenseTensor>("TopCount");
auto* output_grad =
......
......@@ -20,7 +20,6 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
using LoDTensor = phi::DenseTensor;
using LoDTensorBlockingQueueHolder =
paddle::operators::reader::LoDTensorBlockingQueueHolder;
......@@ -59,7 +58,7 @@ class DequeueOp : public framework::OperatorBase {
out_var,
platform::errors::NotFound("No variable with name %s found",
out_names[i]));
auto* out_tensor = out_var->GetMutable<LoDTensor>();
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_NOT_NULL(
out_tensor,
platform::errors::InvalidArgument(
......
......@@ -103,7 +103,8 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"(phi::DenseTensor) A 2-D phi::DenseTensor with shape [M, 6] "
"represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detect results in this mini-batch. For each instance, "
......@@ -111,7 +112,7 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("Label",
"(LoDTensor) A 2-D LoDTensor represents the"
"(phi::DenseTensor) A 2-D phi::DenseTensor represents the"
"Labeled ground-truth data. Each row has 6 values: "
"[label, xmin, ymin, xmax, ymax, is_difficult] or 5 values: "
"[label, xmin, ymin, xmax, ymax], where N is the total "
......@@ -135,14 +136,16 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"current mini-batch are calculated.")
.AsDispensable();
AddInput("TruePos",
"(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
"(phi::DenseTensor) A 2-D phi::DenseTensor with shape [Ntp, 2], "
"store the "
"input true positive example of each class."
"This input is used to pass the AccumTruePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddInput("FalsePos",
"(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
"(phi::DenseTensor) A 2-D phi::DenseTensor with shape [Nfp, 2], "
"store the "
"input false positive example of each class."
"This input is used to pass the AccumFalsePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
......@@ -153,16 +156,18 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
"positive example count of each class. It combines the input "
"input(PosCount) and the positive example count computed from "
"input(Detection) and input(Label).");
AddOutput("AccumTruePos",
"(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("AccumFalsePos",
"(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label).");
AddOutput(
"AccumTruePos",
"(phi::DenseTensor) A phi::DenseTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label).");
AddOutput(
"AccumFalsePos",
"(phi::DenseTensor) A phi::DenseTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("MAP",
"(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection.");
......
......@@ -35,11 +35,11 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Hyps",
"2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
"2-D Tensor<int64_t>, or 2-D phi::DenseTensor<int64_t> with last "
"dimension being 1. "
"The indices for hypothesis strings.");
AddInput("Refs",
"2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
"2-D Tensor<int64_t>, or 2-D phi::DenseTensor<int64_t> with last "
"dimension being 1. "
"The indices for reference strings.");
AddInput("HypsLength",
......@@ -75,7 +75,7 @@ insertion:
So the edit distance between A and B is 3.
Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
Input(Hyps) is a 2-D Tensor or a 2-D phi::DenseTensor consisting of all the hypothesis strings.
And the `batch_size` reference strings are arranged in order in the same way in the
Input(Refs).
......
......@@ -31,7 +31,6 @@ class OpBase;
} // namespace imperative
} // namespace paddle
using LoDTensor = phi::DenseTensor;
using LoDTensorBlockingQueueHolder =
paddle::operators::reader::LoDTensorBlockingQueueHolder;
......@@ -61,7 +60,7 @@ class EnqueueOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL(in_var,
platform::errors::NotFound(
"No variable with name %s found.", var_name));
auto* in_tensor = in_var->GetMutable<LoDTensor>();
auto* in_tensor = in_var->GetMutable<phi::DenseTensor>();
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
......
......@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOpMLUKernel : public framework::OpKernel<T> {
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *in = ctx.Input<phi::DenseTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
// set the correct batch size for the phi::DenseTensor.
auto odims = out->dims();
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
......
......@@ -35,7 +35,7 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> {
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *in = ctx.Input<phi::DenseTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
// set the correct batch size for the phi::DenseTensor.
auto odims = out->dims();
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
......
......@@ -27,7 +27,7 @@ class FillOpMaker : public framework::OpProtoAndCheckerMaker {
Fill an tensor with `value` and `shape`. The type of the tensor is specify by
`dtype`.
)DOC");
AddOutput("Out", "(LoDTensor) The output tensor.");
AddOutput("Out", "(phi::DenseTensor) The output tensor.");
AddAttr<std::vector<float>>(
"value", "The float values of tensor, which are flatten in row major");
AddAttr<std::vector<int>>("shape", "The shape of output tensor");
......
......@@ -69,16 +69,17 @@ class FilterByInstagOp : public framework::OperatorWithKernel {
class FilterByInstagOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ins", "(LoDTensor) embeded tensor");
AddInput("Ins_tag", "(LoDTensor) ins tag list");
AddInput("Ins", "(phi::DenseTensor) embeded tensor");
AddInput("Ins_tag", "(phi::DenseTensor) ins tag list");
AddInput("Filter_tag", "(1D Tensor) filter tag list");
AddAttr<bool>("is_lod", "is Ins with LoD info or not, default True");
AddAttr<int64_t>("out_val_if_empty",
"if the output after filter is empty, the output value")
.SetDefault(0);
AddOutput("Out", "(LoDTensor) embeded tensor filtered by instag");
AddOutput("Out", "(phi::DenseTensor) embeded tensor filtered by instag");
AddOutput("LossWeight", "(Tensor) loss weight.");
AddOutput("IndexMap", "(LoDTensor) mapping from Out rows to X1 rows");
AddOutput("IndexMap",
"(phi::DenseTensor) mapping from Out rows to X1 rows");
AddComment(R"DOC(
Filter By Instag Op
......
......@@ -45,7 +45,6 @@ namespace operators {
using Tensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = phi::DenseTensor;
template <typename T>
using Vector = framework::Vector<T>;
......@@ -341,7 +340,7 @@ class FilterByInstagGPUKernel : public framework::OpKernel<T> {
// context.cuda_device_context().GetMaxThreadsPerBlock();
// X1 is global FC output
// Dim [batch size, embedding size]
const LoDTensor* x1 = context.Input<LoDTensor>("Ins");
const phi::DenseTensor* x1 = context.Input<phi::DenseTensor>("Ins");
bool is_lod = context.Attr<bool>("is_lod");
int is_x1_lod = -1;
......@@ -354,7 +353,7 @@ class FilterByInstagGPUKernel : public framework::OpKernel<T> {
size_t x1_embed_size = x1->dims()[1];
// X2 is ins tag list
// LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]]
const LoDTensor* x2 = context.Input<LoDTensor>("Ins_tag");
const phi::DenseTensor* x2 = context.Input<phi::DenseTensor>("Ins_tag");
// expected auto = const int64_t
const int64_t* x2_data = x2->data<int64_t>();
......@@ -389,7 +388,7 @@ class FilterByInstagGPUKernel : public framework::OpKernel<T> {
x1_lods.push_back(i + 1);
}
} else {
// x1_lods = context.Input<LoDTensor>("Ins")->lod()[0];
// x1_lods = context.Input<phi::DenseTensor>("Ins")->lod()[0];
// new: lod_level=0 => lod() return {}
if (x1->lod().size() != 0) { // lod_level = 1
x1_lods = x1->lod()[0];
......@@ -412,9 +411,10 @@ class FilterByInstagGPUKernel : public framework::OpKernel<T> {
// for those whose ins been dropout, set 0 for whole lines.
// otherwise, copy whole line
// Dim [local fc count, batch size, embedding size]
LoDTensor* out = context.Output<LoDTensor>("Out");
LoDTensor* map = context.Output<LoDTensor>("IndexMap");
LoDTensor* loss_weight = context.Output<LoDTensor>("LossWeight");
phi::DenseTensor* out = context.Output<phi::DenseTensor>("Out");
phi::DenseTensor* map = context.Output<phi::DenseTensor>("IndexMap");
phi::DenseTensor* loss_weight =
context.Output<phi::DenseTensor>("LossWeight");
int out_first = x1_lods.back();
......@@ -563,13 +563,15 @@ class FilterByInstagGradGPUKernel : public framework::OpKernel<T> {
auto gpu_place = context.GetPlace();
gpuStream_t current_stream = context.cuda_device_context().stream();
auto max_thread_num_per_block = 1024;
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x1_grad = context.Output<LoDTensor>(framework::GradVarName("Ins"));
auto* loss_weight = context.Input<LoDTensor>("LossWeight");
auto* mmap = context.Input<LoDTensor>("IndexMap");
auto* x1 = context.Input<LoDTensor>("Ins");
x1_grad->set_lod(context.Input<LoDTensor>("Ins")->lod());
auto* output_grad =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto* x1_grad =
context.Output<phi::DenseTensor>(framework::GradVarName("Ins"));
auto* loss_weight = context.Input<phi::DenseTensor>("LossWeight");
auto* mmap = context.Input<phi::DenseTensor>("IndexMap");
auto* x1 = context.Input<phi::DenseTensor>("Ins");
x1_grad->set_lod(context.Input<phi::DenseTensor>("Ins")->lod());
x1_grad->Resize(x1->dims());
auto* mmap_data = mmap->data<int64_t>();
......
......@@ -31,7 +31,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = phi::DenseTensor;
template <typename T>
using Vector = framework::Vector<T>;
......@@ -42,12 +41,12 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
// X1 is global FC output
// Dim [batch size, embedding size]
auto* x1 = context.Input<LoDTensor>("Ins");
auto* x1 = context.Input<phi::DenseTensor>("Ins");
bool is_x1_lod = context.Attr<bool>("is_lod");
int64_t out_val_if_empty = context.Attr<int64_t>("out_val_if_empty");
// X2 is ins tag list
// LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]]
auto* x2 = context.Input<LoDTensor>("Ins_tag");
auto* x2 = context.Input<phi::DenseTensor>("Ins_tag");
// X3 is local fc tag list
// LoD [[0, Sum(fc1), Sum(fc1, fc2) ...]]
auto* x3 = context.Input<phi::DenseTensor>("Filter_tag");
......@@ -107,9 +106,10 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
// for those whose ins been dropout, set 0 for whole lines.
// otherwise, copy whole line
// Dim [local fc count, batch size, embedding size]
LoDTensor* out = context.Output<LoDTensor>("Out");
LoDTensor* map = context.Output<LoDTensor>("IndexMap");
LoDTensor* loss_weight = context.Output<LoDTensor>("LossWeight");
phi::DenseTensor* out = context.Output<phi::DenseTensor>("Out");
phi::DenseTensor* map = context.Output<phi::DenseTensor>("IndexMap");
phi::DenseTensor* loss_weight =
context.Output<phi::DenseTensor>("LossWeight");
// expected auto = const T
auto* x1_data = x1->data<T>();
// expected auto = T
......@@ -196,12 +196,14 @@ template <typename T>
class FilterByInstagGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x1_grad = context.Output<LoDTensor>(framework::GradVarName("Ins"));
auto* loss_weight = context.Input<LoDTensor>("LossWeight");
auto* mmap = context.Input<LoDTensor>("IndexMap");
auto* x1 = context.Input<LoDTensor>("Ins");
x1_grad->set_lod(context.Input<LoDTensor>("Ins")->lod());
auto* output_grad =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto* x1_grad =
context.Output<phi::DenseTensor>(framework::GradVarName("Ins"));
auto* loss_weight = context.Input<phi::DenseTensor>("LossWeight");
auto* mmap = context.Input<phi::DenseTensor>("IndexMap");
auto* x1 = context.Input<phi::DenseTensor>("Ins");
x1_grad->set_lod(context.Input<phi::DenseTensor>("Ins")->lod());
x1_grad->Resize(x1->dims());
auto mmap_data = mmap->data<int64_t>();
// expected auto = T
......
......@@ -35,13 +35,14 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
"but the received is %s",
ctx->Inputs("X").front(),
ctx->GetInputsVarType("X").front()));
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The output Out(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Outputs("Out").front(),
ctx->GetOutputsVarType("Out").front()));
PADDLE_ENFORCE_EQ(
ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The output Out(%s)'s type should be phi::DenseTensor, "
"but the received is %s",
ctx->Outputs("Out").front(),
ctx->GetOutputsVarType("Out").front()));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
......@@ -72,7 +73,7 @@ class GetTensorFromSelectedRowsOpProtoMaker
public:
void Make() override {
AddInput("X", "The input type is SelectedRows.");
AddOutput("Out", "The output type is LoDTensor.");
AddOutput("Out", "The output type is phi::DenseTensor.");
AddComment(
R"DOC(
GetTensorFromSelectedRows Operator
......
......@@ -29,7 +29,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
class GroupNormOp : public framework::OperatorWithKernel {
......@@ -127,8 +126,8 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
PADDLE_ENFORCE_NOT_NULL(
t,
......
......@@ -29,7 +29,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
template <typename DeviceContext, typename T>
......
......@@ -115,11 +115,12 @@ class GRUOp : public framework::OperatorWithKernel {
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput(
"Input",
"(phi::DenseTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in "
"this phi::DenseTensor is a matrix with shape (T X 3D), where, T is "
"the total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
......@@ -136,35 +137,38 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, optional) Bias vector with shape (1 x 3D) concating "
"bias of the update gate, reset gate and output candidate.")
.AsDispensable();
AddOutput("BatchGate",
"(LoDTensor) To compute with batches, sequence data will be "
"reorganized into several successive batches each containing "
"data from the same time step. The LoDTensor BatchGate contains "
"the update gate, reset gate and output candidate values "
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
AddOutput(
"BatchGate",
"(phi::DenseTensor) To compute with batches, sequence data will be "
"reorganized into several successive batches each containing "
"data from the same time step. The phi::DenseTensor BatchGate contains "
"the update gate, reset gate and output candidate values "
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
.AsIntermediate()
.AsExtra();
AddOutput(
"BatchResetHiddenPrev",
"(LoDTensor) The reset hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
AddOutput("BatchResetHiddenPrev",
"(phi::DenseTensor) The reset hidden state phi::DenseTensor "
"organized in batches. "
"This phi::DenseTensor is a matrix with shape (T X D) and has "
"the same LoD "
"with `BatchGate`.")
.AsIntermediate()
.AsExtra();
AddOutput(
"BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
AddOutput("BatchHidden",
"(phi::DenseTensor) The hidden state phi::DenseTensor organized "
"in batches. "
"This phi::DenseTensor is a matrix with shape (T X D) and has "
"the same LoD "
"with `BatchGate`.")
.AsIntermediate()
.AsExtra();
AddOutput(
"Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.");
AddOutput("Hidden",
"(phi::DenseTensor) the hidden state phi::DenseTensor organized "
"in sequences. "
"This phi::DenseTensor is a matrix with shape (T X D) and has "
"the same LoD with `BatchGate`.");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
......@@ -314,23 +318,24 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = phi::CPUContext;
using LodTensorPtr = LoDTensor*;
using LodTensorPtr = phi::DenseTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* input = context.Input<phi::DenseTensor>("Input");
auto* h0 = context.Input<phi::DenseTensor>("H0");
auto* weight = context.Input<phi::DenseTensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<phi::DenseTensor>("Bias");
auto* hidden = context.Output<LoDTensor>("Hidden");
auto* hidden = context.Output<phi::DenseTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp,
batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
......@@ -341,10 +346,10 @@ class GRUCPUKernel : public framework::OpKernel<T> {
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_gate = context.Output<phi::DenseTensor>("BatchGate");
batch_hidden = context.Output<phi::DenseTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
context.Output<phi::DenseTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
......
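The BatchGate / BatchResetHiddenPrev / BatchHidden outputs above all come from the sequence-to-batch reorganization done by phi::funcs::LoDTensor2BatchFunctor. As a rough, standalone illustration of the index bookkeeping only (plain C++ containers; the sorted-by-length ordering is assumed, and `batch_offsets` / `index_map` are illustrative names, not Paddle API):

```cpp
// Illustrative sketch of LoDTensor-to-batch index bookkeeping; not Paddle API.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Level-0 LoD offsets: three sequences occupying rows [0,4), [4,6), [6,9).
  std::vector<size_t> lod = {0, 4, 6, 9};

  // (sequence index, length), sorted by length in descending order.
  std::vector<std::pair<size_t, size_t>> seqs;
  for (size_t i = 0; i + 1 < lod.size(); ++i)
    seqs.emplace_back(i, lod[i + 1] - lod[i]);
  std::sort(seqs.begin(), seqs.end(),
            [](auto& a, auto& b) { return a.second > b.second; });

  // Batch t gathers the t-th row of every sequence that is longer than t.
  std::vector<size_t> batch_offsets = {0};  // first LoD level of BatchGate
  std::vector<size_t> index_map;            // second LoD level: raw-row indexes
  size_t max_len = seqs.front().second;
  for (size_t t = 0; t < max_len; ++t) {
    for (auto& [seq, len] : seqs) {
      if (t < len) index_map.push_back(lod[seq] + t);
    }
    batch_offsets.push_back(index_map.size());
  }

  for (size_t v : batch_offsets) std::cout << v << ' ';  // 0 3 6 8 9
  std::cout << '\n';
  for (size_t v : index_map) std::cout << v << ' ';      // 0 6 4 1 7 5 2 8 3
  std::cout << '\n';
}
```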
......@@ -21,23 +21,24 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using LodTensorPtr = LoDTensor*;
using LodTensorPtr = phi::DenseTensor*;
bool is_test = context.Attr<bool>("is_test");
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* input = context.Input<phi::DenseTensor>("Input");
auto* h0 = context.Input<phi::DenseTensor>("H0");
auto* weight = context.Input<phi::DenseTensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<phi::DenseTensor>("Bias");
auto* hidden = context.Output<LoDTensor>("Hidden");
auto* hidden = context.Output<phi::DenseTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
auto input_dims = input->dims();
auto hidden_dims = hidden->dims();
LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden;
LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp;
phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp,
batch_hidden_tmp;
if (is_test) {
batch_gate = &batch_gate_tmp;
batch_gate->Resize(input_dims);
......@@ -48,10 +49,10 @@ class GRUKernel : public framework::OpKernel<T> {
batch_hidden = &batch_hidden_tmp;
batch_hidden->Resize(hidden_dims);
} else {
batch_gate = context.Output<LoDTensor>("BatchGate");
batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_gate = context.Output<phi::DenseTensor>("BatchGate");
batch_hidden = context.Output<phi::DenseTensor>("BatchHidden");
batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
context.Output<phi::DenseTensor>("BatchResetHiddenPrev");
}
batch_gate->mutable_data<T>(context.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
......
......@@ -25,7 +25,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
......@@ -47,15 +46,15 @@ class GRUGradKernel : public framework::OpKernel<T> {
auto* h0 = context.Input<phi::DenseTensor>("H0");
auto* weight = context.Input<phi::DenseTensor>("Weight");
const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
auto* batch_gate = context.Input<phi::DenseTensor>("BatchGate");
auto* batch_reset_hidden_prev =
context.Input<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
auto* hidden = context.Input<LoDTensor>("Hidden");
context.Input<phi::DenseTensor>("BatchResetHiddenPrev");
auto* batch_hidden = context.Input<phi::DenseTensor>("BatchHidden");
auto* hidden = context.Input<phi::DenseTensor>("Hidden");
auto* hidden_grad =
context.Input<LoDTensor>(framework::GradVarName("Hidden"));
context.Input<phi::DenseTensor>(framework::GradVarName("Hidden"));
auto* input_grad =
context.Output<LoDTensor>(framework::GradVarName("Input"));
context.Output<phi::DenseTensor>(framework::GradVarName("Input"));
auto* h0_grad =
context.Output<phi::DenseTensor>(framework::GradVarName("H0"));
auto* weight_grad =
......@@ -68,7 +67,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
int frame_size = hidden_dims[1];
phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
phi::DenseTensor batch_hidden_grad, batch_gate_grad,
batch_reset_hidden_prev_grad;
batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
......
......@@ -82,43 +82,44 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, required) The input tensor with shape [N, D], "
"(phi::DenseTensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size.");
AddInput("W",
"(LoDTensor, required), The parameters of hierarchical "
"(phi::DenseTensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[K, D]. Which K is the num of non-leaf node in Path Tree");
AddInput("Label",
"(LoDTensor, required), The labels of training data. It's a"
"(phi::DenseTensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("PathTable",
"(LoDTensor, optional), The Path Table from root to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
AddInput(
"PathCode",
"(LoDTensor, optional), The Code on each Node of the Path from root "
"to current word"
"PathTable",
"(phi::DenseTensor, optional), The Path Table from root to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
AddInput("PathCode",
"(phi::DenseTensor, optional), The Code on each Node of the Path "
"from root "
"to current word"
"it should have shape like [N, L], L is the length of the Path")
.AsDispensable();
AddInput("Bias",
"(LoDTensor, optional), The bias is a tensor with shape or "
"(phi::DenseTensor, optional), The bias is a tensor with shape or "
"[num_classes, 1]"
"[num_classes - 1, 1].")
.AsDispensable();
AddOutput(
"Out",
"(LoDTensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1].");
AddOutput("Out",
"(phi::DenseTensor, required) The output of hierarchical sigmoid "
"operator."
"The shape is [N, 1].");
AddOutput("PreOut",
"(LoDTensor, required) A intermedia 2-D tensor with shape "
"(phi::DenseTensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddOutput(
"W_Out",
"(LoDTensor, optional) using input 'W' as Output to make it mutable"
"When we are using prefetch")
AddOutput("W_Out",
"(phi::DenseTensor, optional) using input 'W' as Output to make "
"it mutable"
"When we are using prefetch")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
.SetDefault(2);
......@@ -227,7 +228,8 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
auto bias_grad_var_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_var_name)) {
VLOG(3) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
<< framework::GradVarName("Bias")
<< " is set to phi::DenseTensor";
ctx->SetOutputType(bias_grad_var_name,
framework::proto::VarType::LOD_TENSOR);
}
......@@ -241,7 +243,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
<< " is set to phi::DenseTensor";
ctx->SetOutputType(w_grad_var_name,
framework::proto::VarType::LOD_TENSOR);
}
......
......@@ -27,7 +27,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
inline int Im2SeqOutputSize(
int input_size, int filter_size, int padding_0, int padding_1, int stride) {
......@@ -41,7 +40,7 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const phi::DenseTensor* in = ctx.Input<phi::DenseTensor>("X");
LoDTensor* out = ctx.Output<LoDTensor>("Out");
phi::DenseTensor* out = ctx.Output<phi::DenseTensor>("Out");
auto in_dim = in->dims();
int batch_size = in_dim[0];
int img_channels = in_dim[1];
......
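The Im2SeqOutputSize declaration above has its body elided in this view; assuming it follows the usual convolution-style output-size rule (an assumption, not taken from the diff), it would be along these lines:

```cpp
// Assumed convolution-style output size; the real body is elided above.
inline int Im2SeqOutputSize(
    int input_size, int filter_size, int padding_0, int padding_1, int stride) {
  return (input_size + padding_0 + padding_1 - filter_size) / stride + 1;
}
// e.g. a 28-pixel edge, 3x3 filter, padding 1/1 and stride 1 keeps the size at 28.
```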
......@@ -23,21 +23,20 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectInner(const framework::ExecutionContext& context,
LoDTensor* input,
const LoDTensor& index,
LoDTensor* output,
phi::DenseTensor* input,
const phi::DenseTensor& index,
phi::DenseTensor* output,
int dim) {
auto input_dim = input->dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();
auto index_size = index.dims()[0];
LoDTensor index_cpu_copy;
phi::DenseTensor index_cpu_copy;
if (!platform::is_cpu_place(index.place())) {
framework::TensorCopySync(index, platform::CPUPlace(), &index_cpu_copy);
}
......@@ -127,9 +126,9 @@ struct IndexSelectAdd<
template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
const LoDTensor& out_grad,
const LoDTensor& index,
LoDTensor* x_grad,
const phi::DenseTensor& out_grad,
const phi::DenseTensor& index,
phi::DenseTensor* x_grad,
int dim) {
const T* input_data = out_grad.data<T>();
const IndexT* index_data = index.data<IndexT>();
......
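IndexSelectInner above gathers slices of `input` along `dim` according to `index`; for intuition, a self-contained sketch of the dim == 0 case on a row-major buffer (plain C++, no Paddle types, illustrative names):

```cpp
// index_select along dim 0 for a row-major [rows x cols] buffer; illustrative only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> IndexSelectRows(const std::vector<float>& input,
                                   int64_t cols,
                                   const std::vector<int64_t>& index) {
  std::vector<float> out(index.size() * cols);
  for (size_t i = 0; i < index.size(); ++i)
    for (int64_t j = 0; j < cols; ++j)
      out[i * cols + j] = input[index[i] * cols + j];
  return out;
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6};  // shape [3, 2]
  auto y = IndexSelectRows(x, 2, {2, 0, 2});  // gather rows 2, 0, 2
  for (float v : y) std::cout << v << ' ';    // 5 6 1 2 5 6
}
```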
......@@ -147,8 +147,8 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
const phi::DenseTensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
......
......@@ -108,8 +108,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
......@@ -129,8 +129,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
......
......@@ -23,7 +23,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
class InstanceNormOp : public framework::OperatorWithKernel {
......
......@@ -19,15 +19,14 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
class LabelSmoothMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<LoDTensor>("X");
auto* in_t = ctx.Input<phi::DenseTensor>("X");
auto* dist_t = ctx.Input<phi::DenseTensor>("PriorDist");
auto* out_t = ctx.Output<LoDTensor>("Out");
auto* out_t = ctx.Output<phi::DenseTensor>("Out");
auto epsilon = ctx.Attr<float>("epsilon");
auto epsilon_gt = 1.0f - epsilon;
......
......@@ -19,7 +19,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
void LabelSmoothMuls(const platform::Place& place,
......@@ -58,8 +57,8 @@ template <typename T>
class LabelSmoothNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out_t = ctx.Output<LoDTensor>("Out");
auto* in_t = ctx.Input<LoDTensor>("X");
auto* out_t = ctx.Output<phi::DenseTensor>("Out");
auto* in_t = ctx.Input<phi::DenseTensor>("X");
auto* dist_t = ctx.Input<phi::DenseTensor>("PriorDist");
auto epsilon = ctx.Attr<float>("epsilon");
......
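Both the MLU and NPU kernels above compute the same label-smoothing blend, out = (1 - epsilon) * x + epsilon * prior, which matches the `epsilon_gt = 1.0f - epsilon` term above; the uniform fallback when `PriorDist` is absent is an assumption. A scalar sketch:

```cpp
// Label smoothing on a one-hot row; illustrative, not the Paddle kernel.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> LabelSmooth(const std::vector<float>& onehot, float epsilon,
                               const std::vector<float>* prior = nullptr) {
  const float uniform = 1.0f / onehot.size();  // assumed fallback prior
  std::vector<float> out(onehot.size());
  for (size_t i = 0; i < onehot.size(); ++i) {
    float p = prior ? (*prior)[i] : uniform;
    out[i] = (1.0f - epsilon) * onehot[i] + epsilon * p;  // epsilon_gt * x + epsilon * prior
  }
  return out;
}

int main() {
  auto y = LabelSmooth({0.f, 1.f, 0.f, 0.f}, 0.1f);
  for (float v : y) std::cout << v << ' ';  // 0.025 0.925 0.025 0.025
}
```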
......@@ -21,7 +21,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using DataLayout = phi::DataLayout;
class LayerNormOp : public framework::OperatorWithKernel {
......@@ -214,8 +213,8 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::DenseTensor>()) {
t = &var->Get<phi::DenseTensor>();
}
PADDLE_ENFORCE_NOT_NULL(
t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
......
......@@ -28,7 +28,6 @@
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename T>
......
......@@ -23,23 +23,24 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Emission",
"(LoDTensor/Tensor<float>). When a LoDTensor input,A 2-D LoDTensor"
"(phi::DenseTensor<float>). When a phi::DenseTensor "
"input,A 2-D phi::DenseTensor"
" with shape [N x D], where N is the size of the "
"mini-batch and D is the total tag number. The unscaled emission "
"weight matrix for the linear chain CRF. When a Tensor input,"
"A Tensor with shape [N x S x D], where N is batch number,"
"S is max length of sequences, D is the total tag number."
"A LoDTensor or Tensor with type float32, float64.");
"A phi::DenseTensor with type float32, float64.");
AddInput("Transition",
"(Tensor, default Tensor<float>) A 2-D Tensor with shape "
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
"operator. See more details in the operator's comments.");
AddInput("Label",
"(LoDTensor/Tensor<int64_t>), when a LoDTensor input, "
"(phi::DenseTensor<int64_t>), when a phi::DenseTensor input, "
"[N x 1], where N is the total element number in a mini-batch. "
"when a Tensor input, [N x S], where N is batch number. "
"S is max length of sequences. The ground truth."
"A LoDTensor or Tensor with int64.");
"A phi::DenseTensor with int64.");
AddInput("Length",
"(Tensor, default Tensor<int64_t>) A Tensor with shape "
"[M x 1], where M is the sequence number in a mini-batch."
......@@ -63,7 +64,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
"The exponentials of Input(Emission). This is an intermediate "
"computational result in forward computation, and will be reused in "
"backward computation."
"A LoDTensor or Tensor with type float32, float64.")
"A phi::DenseTensor with type float32, float64.")
.AsIntermediate();
AddOutput(
"TransitionExps",
......@@ -71,7 +72,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
"[(D + 2) x D]. The exponentials of Input(Transition). This is an "
"intermediate computational result in forward computation, and "
"will be reused in backward computation."
"A LoDTensor or Tensor with type float32, float64.")
"A phi::DenseTensor with type float32, float64.")
.AsIntermediate();
AddOutput(
"LogLikelihood",
......
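For context on the shapes above (Emission [N x D], Transition [(D + 2) x D], the EmissionExps/TransitionExps intermediates and LogLikelihood), this is the standard linear-chain CRF objective; a hedged summary in standard notation, taking the two extra Transition rows to be the start/stop weights a and b, with w its D x D interior and x the emission weights:

```latex
% Linear-chain CRF over a length-L tag sequence with D tags (standard formulation).
P(y \mid x) = \frac{1}{Z(x)}
  \exp\!\Big( a_{y_1} + b_{y_L}
              + \sum_{t=1}^{L} x_{t,\,y_t}
              + \sum_{t=2}^{L} w_{y_{t-1},\,y_t} \Big)
```

LogLikelihood is then log P(y | x); the `alpha` buffer used in the kernel code that follows holds the forward-recursion partial sums from which log Z(x) is evaluated.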
......@@ -47,7 +47,6 @@ struct ScalarMul {
};
using framework::LoD;
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
......@@ -114,7 +113,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
phi::funcs::set_constant(ctx.device_context(), emission_exps, 0.0);
phi::funcs::set_constant(ctx.device_context(), alpha, 0.0);
} else {
in_lod = ctx.Input<LoDTensor>("Label")->lod();
in_lod = ctx.Input<phi::DenseTensor>("Label")->lod();
PADDLE_ENFORCE_NE(in_lod.size(),
0,
platform::errors::InvalidArgument(
......@@ -286,7 +285,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
emission_exps_tmp.Resize(
{emission_dims[0] * emission_dims[1], emission_dims[2]});
} else {
in_lod = ctx.Input<LoDTensor>("Label")->lod();
in_lod = ctx.Input<phi::DenseTensor>("Label")->lod();
PADDLE_ENFORCE_NE(in_lod.size(),
0,
platform::errors::InvalidArgument(
......
......@@ -62,7 +62,7 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file, which could be
LoadCombine operator loads phi::DenseTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load
......
......@@ -37,7 +37,7 @@ class LoadOp : public framework::OperatorWithKernel {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddOutput("Out", "The phi::DenseTensor / SelectedRows need to be loaded");
AddAttr<bool>(
"load_as_fp16",
"If true, the tensor will be first loaded and then "
......@@ -54,7 +54,8 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from "
"Load operator will load a phi::DenseTensor / SelectedRows variable "
"from "
"disk "
"file.");
}
......
......@@ -54,7 +54,8 @@ class LoadOpKernel : public framework::OpKernel<T> {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Load operator only supports loading LoDTensor and SelectedRows "
"Load operator only supports loading phi::DenseTensor and "
"SelectedRows "
"variable, %s has wrong type",
out_var_name));
}
......
......@@ -52,13 +52,14 @@ class LoDRankTableOp : public framework::OperatorBase {
class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor) input lod tensor, must contain lod information.");
AddInput(
"X",
"(phi::DenseTensor) input lod tensor, must contain lod information.");
AddOutput("Out", "(LoDRankTable) The rank table of specific level.");
AddAttr<int>("level", "(int) the specific lod level to rank.")
.SetDefault(0)
.EqualGreaterThan(0);
AddComment(R"DOC(Create LoDRanTable by LoDTensor
AddComment(R"DOC(Create LoDRanTable by phi::DenseTensor
LoD Rank Table stores the `level` of `lod` which is ordered by sequence
length in descending order. It is useful when implement dynamic RNN and is
......
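A standalone sketch of what the rank table described above holds for a level-0 LoD: (original sequence index, length) items, longest sequence first (plain C++, not the framework::LoDRankTable class):

```cpp
// Build (index, length) items from level-0 LoD offsets, longest sequence first.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<size_t> lod = {0, 3, 4, 9};        // sequence lengths 3, 1, 5
  std::vector<std::pair<size_t, size_t>> items;  // (index, length)
  for (size_t i = 0; i + 1 < lod.size(); ++i)
    items.emplace_back(i, lod[i + 1] - lod[i]);
  std::stable_sort(items.begin(), items.end(),
                   [](auto& a, auto& b) { return a.second > b.second; });
  for (auto& [idx, len] : items)
    std::cout << idx << ':' << len << ' ';  // 2:5 0:3 1:1
}
```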
......@@ -105,18 +105,20 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output "
"(Tensor, phi::DenseTensor) Input variable of LoDResetOp which "
"could be a Tensor or phi::DenseTensor, where the data of output "
"variable inherits from.");
AddInput("Y",
"(Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
"(phi::DenseTensor, optional) If provided and Y is "
"phi::DenseTensor, "
"lod of Input(Y) would be considered as the target lod first, "
"otherwise data of Input(Y) would be considered as the "
"target lod.")
.AsDispensable();
AddOutput("Out",
"(LoDTensor) Output variable of LoDResetOp which should be a "
"LoDTensor.");
AddOutput(
"Out",
"(phi::DenseTensor) Output variable of LoDResetOp which should be a "
"phi::DenseTensor.");
AddAttr<std::vector<int>>("target_lod",
"The target level 0 LoD from Attr().")
.SetDefault(std::vector<int>{});
......@@ -124,7 +126,7 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(LoDReset operator
Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
provided and `Y` is a phi::DenseTensor, `Y.lod` would be considered as target LoD
first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
provided, target LoD should be specified by attribute `target_lod`.
If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
......@@ -132,7 +134,7 @@ is supported.
Example 1:
Given a 1-level LoDTensor input(X):
Given a 1-level phi::DenseTensor input(X):
    X.lod = [[ 0, 2, 5, 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
......@@ -146,7 +148,7 @@ then we get a 1-level LoDTensor:
Example 2:
Given a 1-level LoDTensor input(X):
Given a 1-level phi::DenseTensor input(X):
    X.lod = [[ 0, 2, 5, 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
......@@ -162,7 +164,7 @@ then we get a 1-level LoDTensor:
Example 3:
Given a 1-level LoDTensor input(X):
Given a 1-level phi::DenseTensor input(X):
    X.lod = [[ 0, 2, 5, 6 ]]
X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
X.dims = [6, 1]
......
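A miniature version of Example 1 above: the six data rows are untouched and only the level-0 offsets are replaced. The target offsets [0, 4, 6] are illustrative (the example's target_lod value is elided in this view), and the real op does fuller validation:

```cpp
// LoD reset in miniature: same 6 rows, offsets replaced by the target LoD.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};  // X.data, [6, 1]
  std::vector<size_t> lod = {0, 2, 5, 6};      // X.lod
  std::vector<size_t> target_lod = {0, 4, 6};  // attribute target_lod (illustrative)

  // Basic validation: offsets start at 0, are non-decreasing, end at the row count.
  bool ok = target_lod.front() == 0 && target_lod.back() == data.size();
  for (size_t i = 1; ok && i < target_lod.size(); ++i)
    ok = target_lod[i - 1] <= target_lod[i];

  lod = ok ? target_lod : lod;  // Out.data == X.data, Out.lod == [[0, 4, 6]]
  for (size_t v : lod) std::cout << v << ' ';
}
```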
......@@ -125,11 +125,11 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
PADDLE_ENFORCE_LT(
rank_level,
x.lod().size(),
platform::errors::InvalidArgument(
"Input should be a LoDTensor, and its lod_level should be at "
"least %d, but given is %d.",
rank_level + 1,
x.lod().size()));
platform::errors::InvalidArgument("Input should be a phi::DenseTensor, "
"and its lod_level should be at "
"least %d, but given is %d.",
rank_level + 1,
x.lod().size()));
out.resize(max_seq_len);
std::vector<std::vector<CopyRange>> copy_ranges(max_seq_len);
......@@ -189,14 +189,15 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor), the input lod tensor is a minibatch of sequences, "
"and will be split to a tensor_array according to "
"Input(RankTable).");
AddInput(
"X",
"(phi::DenseTensor), the input lod tensor is a minibatch of sequences, "
"and will be split to a tensor_array according to "
"Input(RankTable).");
AddInput("RankTable", "(LoDRankTable), the rank table.");
AddOutput("Out",
"(LoDTensorArray), the result tensor_array, which is actually a "
"std::vector<LoDTensor>.");
"std::vector<phi::DenseTensor>.");
AddComment(R"DOC(LoDTensorToArray operator.
Input(X) is a minibatch of sequences. Input(RankTable) stores the order of the input sequences.
The lod_tensor_to_array operator will split the input sequences into a tensor_array, with each
......@@ -234,9 +235,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
// kernel implementation.
context->SetOutputDim("Out", x_dim);
// The output LoDTensor's lod_level should be input X's lod_level - 1.
// For compile time, we call SetLoDLevel to set output's lod_level.
// For runtime, output LoDTensor's lod is determined by input X's lod and
// The output phi::DenseTensor's lod_level should be input X's lod_level
// - 1. For compile time, we call SetLoDLevel to set output's lod_level. For
// runtime, output phi::DenseTensor's lod is determined by input X's lod and
// the level specified by input RandTable.
// We cannot get X's detail lod and RankTable's level in this function, so
// leave this work to the detail kernel implementation.
......
......@@ -28,7 +28,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;
......@@ -52,8 +51,8 @@ template <typename T>
class LookupTableDequantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *ids_t = context.Input<phi::DenseTensor>("Ids"); // int tensor
auto *output_t = context.Output<phi::DenseTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto id_name = context.InputNames("Ids").front();
......@@ -66,9 +65,9 @@ class LookupTableDequantKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GE(
table_var->Type(),
framework::VarTypeTrait<LoDTensor>::kId,
framework::VarTypeTrait<phi::DenseTensor>::kId,
platform::errors::InvalidArgument("lookup table must be LodTensor"));
auto *table_t = context.Input<LoDTensor>("W");
auto *table_t = context.Input<phi::DenseTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t quant_number = table_t->dims()[1];
int64_t row_width = (quant_number - 2) * 4;
......
......@@ -212,7 +212,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
<< " is set to phi::DenseTensor";
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
......
......@@ -103,9 +103,9 @@ template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
auto *table_t = context.Input<phi::DenseTensor>("W");
auto *ids_t = context.Input<phi::DenseTensor>("Ids");
auto *output_t = context.Output<phi::DenseTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto id_name = context.InputNames("Ids").front();
......@@ -157,9 +157,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *ids = context.Input<phi::DenseTensor>("Ids");
auto *table = context.Input<phi::DenseTensor>("W");
auto *d_output =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<phi::SelectedRows>(framework::GradVarName("W"));
......@@ -209,9 +210,11 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
stream);
} else {
auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
auto ids_t = context.Input<phi::DenseTensor>("Ids");
auto d_output_t =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto d_table_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
......
......@@ -27,7 +27,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;
......@@ -37,8 +36,8 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *ids_t = context.Input<phi::DenseTensor>("Ids"); // int tensor
auto *output_t = context.Output<phi::DenseTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto id_name = context.InputNames("Ids").front();
......@@ -51,8 +50,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
if (table_var->IsType<phi::DenseTensor>()) {
auto *table_t = context.Input<phi::DenseTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
......@@ -165,15 +164,15 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
if (table_var->IsType<phi::DenseTensor>()) {
table_dim = context.Input<phi::DenseTensor>("W")->dims();
} else if (table_var->IsType<phi::SelectedRows>()) {
auto *table_t = context.Input<phi::SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows"));
"must be either phi::DenseTensor or SelectedRows"));
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
......@@ -181,8 +180,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *ids = context.Input<phi::DenseTensor>("Ids");
auto *d_output =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<phi::SelectedRows>(framework::GradVarName("W"));
......@@ -216,9 +216,11 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto *ids = context.Input<phi::DenseTensor>("Ids");
auto *d_output =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *d_table =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>();
......
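The forward path above boils down to a row gather with a padding rule; a self-contained sketch (plain C++, illustrative only, not the Paddle kernel):

```cpp
// Embedding lookup: out[i] = table[ids[i]]; rows equal to padding_idx become zeros.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

void LookupRows(const float* table, int64_t row_width,
                const int64_t* ids, int64_t ids_numel,
                int64_t padding_idx, float* out) {
  for (int64_t i = 0; i < ids_numel; ++i) {
    if (ids[i] == padding_idx) {
      std::memset(out + i * row_width, 0, row_width * sizeof(float));
    } else {
      std::memcpy(out + i * row_width, table + ids[i] * row_width,
                  row_width * sizeof(float));
    }
  }
}

int main() {
  // 3-row, 2-wide table; look up rows 2, 0 and the padding index 1.
  std::vector<float> table = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
  std::vector<int64_t> ids = {2, 0, 1};
  std::vector<float> out(ids.size() * 2);
  LookupRows(table.data(), 2, ids.data(), ids.size(), /*padding_idx=*/1, out.data());
  for (float v : out) std::cout << v << ' ';  // 0.5 0.6 0.1 0.2 0 0
}
```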
......@@ -156,7 +156,7 @@ class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference {
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
<< " is set to phi::DenseTensor";
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
......
......@@ -28,7 +28,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using DDim = framework::DDim;
......@@ -57,7 +56,7 @@ struct LookupTableV2CPUFunctor {
template <typename IdT>
void apply() {
auto *output_t = context_.Output<LoDTensor>("Out"); // float tensor
auto *output_t = context_.Output<phi::DenseTensor>("Out"); // float tensor
auto *table_var = context_.InputVar("W");
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
......@@ -65,8 +64,8 @@ struct LookupTableV2CPUFunctor {
auto ids = CopyIdsToVector<IdT, int64_t>(*ids_t_);
auto ids_numel = static_cast<int64_t>(ids.size());
if (table_var->template IsType<LoDTensor>()) {
const auto &table_t = table_var->template Get<LoDTensor>();
if (table_var->template IsType<phi::DenseTensor>()) {
const auto &table_t = table_var->template Get<phi::DenseTensor>();
int64_t row_number = table_t.dims()[0];
int64_t row_width = table_t.dims()[1];
......@@ -168,15 +167,15 @@ struct LookupTableV2GradCPUFunctor {
void apply() {
auto *table_var = context_.InputVar("W");
DDim table_dim;
if (table_var->template IsType<LoDTensor>()) {
table_dim = context_.Input<LoDTensor>("W")->dims();
if (table_var->template IsType<phi::DenseTensor>()) {
table_dim = context_.Input<phi::DenseTensor>("W")->dims();
} else if (table_var->template IsType<phi::SelectedRows>()) {
auto *table_t = context_.Input<phi::SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The parameter W of a LookupTableV2 "
"must be either LoDTensor or SelectedRows"));
"must be either phi::DenseTensor or SelectedRows"));
}
int64_t padding_idx = context_.Attr<int64_t>("padding_idx");
......@@ -188,7 +187,8 @@ struct LookupTableV2GradCPUFunctor {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_output =
context_.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *d_table =
context_.Output<phi::SelectedRows>(framework::GradVarName("W"));
......@@ -219,8 +219,10 @@ struct LookupTableV2GradCPUFunctor {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto *d_output = context_.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context_.Output<LoDTensor>(framework::GradVarName("W"));
auto *d_output =
context_.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *d_table =
context_.Output<phi::DenseTensor>(framework::GradVarName("W"));
auto *ids_data = ids.data();
int64_t N = table_dim[0];
......
......@@ -32,7 +32,7 @@ class LookupTableV2MLUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
table_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument("mlu only accept LoDTensor"));
platform::errors::InvalidArgument("mlu only accept phi::DenseTensor"));
output_t->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc ids_desc(*ids_t);
......@@ -55,11 +55,12 @@ class LookupTableV2GradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(table_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PermissionDenied(
"Unsupported Variable Type , idx in "
"LookupTableV2GradMLUKernel should be LoDTensor."));
PADDLE_ENFORCE_EQ(
table_var->IsType<phi::DenseTensor>(),
true,
platform::errors::PermissionDenied(
"Unsupported Variable Type , idx in "
"LookupTableV2GradMLUKernel should be phi::DenseTensor."));
bool is_sparse = ctx.Attr<bool>("is_sparse");
PADDLE_ENFORCE_EQ(
is_sparse,
......
......@@ -37,7 +37,7 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
table_var->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument("npu only accept LoDTensor"));
platform::errors::InvalidArgument("npu only accept phi::DenseTensor"));
output_t->mutable_data<T>(ctx.GetPlace());
int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
......
......@@ -146,11 +146,12 @@ class LSTMOp : public framework::OperatorWithKernel {
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput(
"Input",
"(phi::DenseTensor) the first input is a phi::DenseTensor, which "
"support variable-time length input sequence. The underlying tensor in "
"this phi::DenseTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
......@@ -176,23 +177,26 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. "
"(phi::DenseTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state of LSTM operator. "
"(phi::DenseTensor) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape as the reorganized input, which "
"is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
AddOutput(
"BatchGate",
"(phi::DenseTensor) This phi::DenseTensor contains input gate, forget "
"gate "
"and output gate after the nonlinear computation. This "
"phi::DenseTensor has the same shape as the reorganized input, which "
"is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate()
.AsExtra();
AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"(phi::DenseTensor) This phi::DenseTensor is obtained in the "
"forward and used "
"in the backward.")
.AsIntermediate()
.AsExtra();
......
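For reference, the (T x 4D) gate layout and the bias layout {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc} above correspond to the standard LSTM step with optional peephole terms; a hedged summary in standard notation (not copied from the elided operator comment):

```latex
% Standard LSTM step with peepholes W_{ic}, W_{fc}, W_{oc}; \sigma is the gate activation.
i_t = \sigma(W_{ix} x_t + W_{ih} h_{t-1} + W_{ic} c_{t-1} + b_i) \\
f_t = \sigma(W_{fx} x_t + W_{fh} h_{t-1} + W_{fc} c_{t-1} + b_f) \\
\tilde{c}_t = \mathrm{act}_c(W_{cx} x_t + W_{ch} h_{t-1} + b_c) \\
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t = \sigma(W_{ox} x_t + W_{oh} h_{t-1} + W_{oc} c_t + b_o) \\
h_t = o_t \odot \mathrm{act}_h(c_t)
```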
......@@ -24,7 +24,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
......@@ -44,25 +43,25 @@ class LSTMKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
bool is_test = ctx.Attr<bool>("is_test");
auto* input = ctx.Input<LoDTensor>("Input");
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* weight = ctx.Input<phi::DenseTensor>("Weight");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* hidden_t0 = ctx.Input<phi::DenseTensor>("H0");
auto* cell_t0 = ctx.Input<phi::DenseTensor>("C0");
LoDTensor* batch_gate = nullptr;
LoDTensor batch_gate_temp;
phi::DenseTensor* batch_gate = nullptr;
phi::DenseTensor batch_gate_temp;
if (is_test) {
batch_gate = &batch_gate_temp;
batch_gate->Resize(input->dims());
} else {
batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate = ctx.Output<phi::DenseTensor>("BatchGate");
}
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* hidden_out = ctx.Output<phi::DenseTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* cell_out = ctx.Output<phi::DenseTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
bool is_reverse = ctx.Attr<bool>("is_reverse");
......@@ -110,12 +109,12 @@ class LSTMKernel : public framework::OpKernel<T> {
}
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
LoDTensor* batch_cell_pre_act;
phi::DenseTensor batch_hidden, batch_cell, batch_cell_pre_act_temp;
phi::DenseTensor* batch_cell_pre_act;
if (is_test) {
batch_cell_pre_act = &batch_cell_pre_act_temp;
} else {
batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
batch_cell_pre_act = ctx.Output<phi::DenseTensor>("BatchCellPreAct");
}
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
......@@ -191,11 +190,11 @@ class LSTMKernel : public framework::OpKernel<T> {
phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
// restore the output hidden in phi::DenseTensor from the batch hidden
to_seq(device_ctx, batch_hidden, hidden_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
// restore the output cell state in phi::DenseTensor from the batch cell
to_seq(device_ctx, batch_cell, cell_out);
}
};
......@@ -204,19 +203,20 @@ template <typename DeviceContext, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* weight = ctx.Input<phi::DenseTensor>("Weight");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto* hidden_out = ctx.Input<phi::DenseTensor>("Hidden");
auto* cell_out = ctx.Input<phi::DenseTensor>("Cell");
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* batch_gate = ctx.Input<phi::DenseTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<phi::DenseTensor>("BatchCellPreAct");
auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* hidden_g =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Hidden"));
auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
auto* in_g = ctx.Output<phi::DenseTensor>(framework::GradVarName("Input"));
auto* weight_g =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<phi::DenseTensor>(framework::GradVarName("Bias"));
......@@ -301,12 +301,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
to_batch(ctx, src, &dst, false);
};
LoDTensor batch_hidden, batch_hidden_g, batch_cell;
phi::DenseTensor batch_hidden, batch_hidden_g, batch_cell;
ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
LoDTensor batch_cell_g, batch_gate_g;
phi::DenseTensor batch_cell_g, batch_gate_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
......
......@@ -154,11 +154,12 @@ class LSTMPOp : public framework::OperatorWithKernel {
class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(LoDTensor) the input for sequence data, which supports "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput(
"Input",
"(phi::DenseTensor) the input for sequence data, which supports "
"variable-time length input sequence. The underlying tensor in "
"this phi::DenseTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
......@@ -190,29 +191,34 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Projection",
"(LoDTensor) the projection of the hidden state of LSTMP "
"(phi::DenseTensor) the projection of the hidden state of LSTMP "
"operator. The shape is (T x P), and LoD is the same with the "
"`Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state of LSTMP operator. "
"(phi::DenseTensor) the cell state of LSTMP operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the activations. This LoDTensor has the "
"same shape as the reorganized input, which is also be called "
"batch input. The LoD size is 2. The first-level LoD is the "
"batch offsets and the second contains the indices, which "
"denotes the position of reorganized sequence in the raw input.")
AddOutput(
"BatchGate",
"(phi::DenseTensor) This phi::DenseTensor contains input gate, forget "
"gate "
"and output gate after the activations. This phi::DenseTensor has the "
"same shape as the reorganized input, which is also be called "
"batch input. The LoD size is 2. The first-level LoD is the "
"batch offsets and the second contains the indices, which "
"denotes the position of reorganized sequence in the raw input.")
.AsIntermediate();
AddOutput("BatchCellPreAct",
"(LoDTensor) the pre-activation cell state reorganized in batch. "
"This LoDTensor is obtained in the forward and used in the "
"backward.")
AddOutput(
"BatchCellPreAct",
"(phi::DenseTensor) the pre-activation cell state reorganized in "
"batch. "
"This phi::DenseTensor is obtained in the forward and used in the "
"backward.")
.AsIntermediate();
AddOutput("BatchHidden",
"(LoDTensor) the hidden state reorganized in batch. "
"This LoDTensor is obtained in the forward and used in the "
"backward.")
AddOutput(
"BatchHidden",
"(phi::DenseTensor) the hidden state reorganized in batch. "
"This phi::DenseTensor is obtained in the forward and used in the "
"backward.")
.AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, default: True) "
......
......@@ -29,7 +29,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
using platform::Transform;
......@@ -107,7 +106,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* weight = ctx.Input<phi::DenseTensor>("Weight");
auto* proj_weight = ctx.Input<phi::DenseTensor>("ProjWeight");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
......@@ -118,11 +117,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
auto* batch_gate = ctx.Output<phi::DenseTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* proj_out = ctx.Output<LoDTensor>("Projection");
auto* proj_out = ctx.Output<phi::DenseTensor>("Projection");
proj_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* cell_out = ctx.Output<phi::DenseTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
bool is_reverse = ctx.Attr<bool>("is_reverse");
......@@ -172,10 +171,10 @@ class LSTMPKernel : public framework::OpKernel<T> {
}
// Use the local variable as here.
LoDTensor batch_proj, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
phi::DenseTensor batch_proj, batch_cell;
auto* batch_cell_pre_act = ctx.Output<phi::DenseTensor>("BatchCellPreAct");
batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
auto* batch_hidden = ctx.Output<LoDTensor>("BatchHidden");
auto* batch_hidden = ctx.Output<phi::DenseTensor>("BatchHidden");
batch_hidden->mutable_data<T>(dims, ctx.GetPlace()); // T x D
batch_proj.mutable_data<T>(proj_dims, ctx.GetPlace()); // T x P
batch_cell.mutable_data<T>(dims, ctx.GetPlace()); // T x D
......@@ -272,11 +271,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
phi::funcs::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_proj.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
// restore the output hidden in phi::DenseTensor from the batch hidden
to_seq(device_ctx, batch_proj, proj_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
// restore the output cell state in phi::DenseTensor from the batch cell
to_seq(device_ctx, batch_cell, cell_out);
}
};
......@@ -310,20 +309,20 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto* proj_weight = ctx.Input<phi::DenseTensor>("ProjWeight");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* proj_out = ctx.Input<LoDTensor>("Projection");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto* proj_out = ctx.Input<phi::DenseTensor>("Projection");
auto* cell_out = ctx.Input<phi::DenseTensor>("Cell");
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");
auto* batch_gate = ctx.Input<phi::DenseTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<phi::DenseTensor>("BatchCellPreAct");
auto* batch_hidden = ctx.Input<phi::DenseTensor>("BatchHidden");
auto* projection_g =
ctx.Input<LoDTensor>(framework::GradVarName("Projection"));
ctx.Input<phi::DenseTensor>(framework::GradVarName("Projection"));
auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
auto* in_g = ctx.Output<phi::DenseTensor>(framework::GradVarName("Input"));
auto* weight_g =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Weight"));
auto* proj_weight_g =
......@@ -415,13 +414,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
to_batch(ctx, src, &dst, false);
};
LoDTensor batch_hidden_g, batch_proj, batch_proj_g, batch_cell;
phi::DenseTensor batch_hidden_g, batch_proj, batch_proj_g, batch_cell;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
ToBatch(device_ctx, *proj_out, proj_dims, batch_proj); // T x P
ToBatch(device_ctx, *projection_g, proj_dims, batch_proj_g); // T x P
ToBatch(device_ctx, *cell_out, out_dims, batch_cell); // T x D
LoDTensor batch_cell_g, batch_gate_g;
phi::DenseTensor batch_cell_g, batch_gate_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
......
......@@ -25,7 +25,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using LoD = framework::LoD;
void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
......@@ -92,7 +91,7 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->IsRuntime()) {
framework::Variable* x_var =
PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
const auto& x_lod = x_var->Get<LoDTensor>().lod();
const auto& x_lod = x_var->Get<phi::DenseTensor>().lod();
PADDLE_ENFORCE_EQ(x_lod.empty(),
false,
platform::errors::InvalidArgument(
......@@ -117,7 +116,7 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
framework::Variable* y_var =
PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("Y")[0]);
const auto& y_lod = y_var->Get<LoDTensor>().lod();
const auto& y_lod = y_var->Get<phi::DenseTensor>().lod();
PADDLE_ENFORCE_EQ(y_lod.empty(),
false,
platform::errors::InvalidArgument(
......@@ -213,18 +212,22 @@ void MatchMatrixTensorOpGrad::InferShape(
void MatchMatrixTensorOpMaker::Make() {
AddInput("X",
"X (LoDTensor, default LoDTensor<float>) Input variable which "
"X (phi::DenseTensor, default phi::DenseTensor<float>) Input "
"variable which "
"should contain lod information.");
AddInput("Y",
"Y (LoDTensor, default LoDTensor<float>) Input variable which "
"Y (phi::DenseTensor, default phi::DenseTensor<float>) Input "
"variable which "
"should contain lod information.");
AddInput("W", "W (Tensor), The weight of X and Y.");
AddAttr<int>("dim_t", "the dim of W").SetDefault(1);
AddOutput("Out",
"(LoDTensor, default LoDTensor<float>) Output variable which "
"(phi::DenseTensor, default phi::DenseTensor<float>) Output "
"variable which "
"is X * W * Y");
AddOutput("Tmp",
"(LoDTensor, default LoDTensor<float>) tmp variable which is "
"(phi::DenseTensor, default phi::DenseTensor<float>) tmp variable "
"which is "
"used for X * W");
AddComment(R"DOC(
Match Matrix Tensor Operator
......@@ -242,11 +245,11 @@ template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* y = ctx.Input<LoDTensor>("Y");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* y = ctx.Input<phi::DenseTensor>("Y");
auto* w = ctx.Input<phi::DenseTensor>("W");
auto* out = ctx.Output<LoDTensor>("Out");
auto* tmp = ctx.Output<LoDTensor>("Tmp");
auto* out = ctx.Output<phi::DenseTensor>("Out");
auto* tmp = ctx.Output<phi::DenseTensor>("Tmp");
int dim_t = ctx.Attr<int>("dim_t");
int64_t dim_in = x->dims()[1];
......@@ -322,10 +325,10 @@ template <typename DeviceContext, typename T>
class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* y = ctx.Input<LoDTensor>("Y");
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* y = ctx.Input<phi::DenseTensor>("Y");
auto* w = ctx.Input<phi::DenseTensor>("W");
auto* tmp = ctx.Input<LoDTensor>("Tmp");
auto* tmp = ctx.Input<phi::DenseTensor>("Tmp");
int dim_t = ctx.Attr<int>("dim_t");
int64_t dim_in = x->dims()[1];
......@@ -346,9 +349,9 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
auto* bottom_r_data = y->data<T>();
auto* bottom_l_trans_data = tmp->data<T>();
auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* d_y = ctx.Output<LoDTensor>(framework::GradVarName("Y"));
auto* d_out = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto* d_y = ctx.Output<phi::DenseTensor>(framework::GradVarName("Y"));
Tensor tmp_grad;
tmp_grad.Resize(tmp->dims());
......
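
As a rough illustration of what the kernels above compute per sequence pair, here is a small standalone C++ sketch (plain std::vector, no Paddle types; the shapes and the per-sequence LoD slicing are simplified assumptions): it first forms Tmp = X * W and then Out = Tmp * Y^T, matching the "Out ... is X * W * Y" and "Tmp ... used for X * W" descriptions in the proto maker above.

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical shapes: X is [len_x, dim_in], one dim_t slice of W is
      // [dim_in, dim_in], Y is [len_y, dim_in]; Out for that slice is [len_x, len_y].
      const int len_x = 2, len_y = 3, dim_in = 4;
      std::vector<float> x(len_x * dim_in, 1.0f);
      std::vector<float> w(dim_in * dim_in, 0.5f);
      std::vector<float> y(len_y * dim_in, 2.0f);

      // Tmp = X * W  (the role of the "Tmp" output declared above).
      std::vector<float> tmp(len_x * dim_in, 0.0f);
      for (int i = 0; i < len_x; ++i)
        for (int k = 0; k < dim_in; ++k)
          for (int j = 0; j < dim_in; ++j)
            tmp[i * dim_in + j] += x[i * dim_in + k] * w[k * dim_in + j];

      // Out = Tmp * Y^T, i.e. X * W * Y as the op comment puts it.
      std::vector<float> out(len_x * len_y, 0.0f);
      for (int i = 0; i < len_x; ++i)
        for (int j = 0; j < len_y; ++j)
          for (int k = 0; k < dim_in; ++k)
            out[i * len_y + j] += tmp[i * dim_in + k] * y[j * dim_in + k];

      std::printf("out[0][0] = %.1f\n", out[0]);  // tmp entries are 2.0, so 4 * 2 * 2 = 16.0
      return 0;
    }
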
......@@ -83,9 +83,9 @@ class MemcpyD2HKernel {
class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input variable ");
AddInput("X", "(phi::DenseTensor) The input variable ");
AddOutput("Out",
"(LoDTensor) The type of output "
"(phi::DenseTensor) The type of output "
"is the same as input X.");
AddAttr<int>(
"dst_place_type",
......@@ -98,7 +98,7 @@ class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker {
MemcpyD2H Operator.
By now, it ONLY supports the memcopy between NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU.
You would have to update it if you want other more capacities.
Out = X, when type in [LoDTensor]
Out = X, when type in [phi::DenseTensor]
raise error if the type is not listed above.
)DOC");
}
......
......@@ -84,9 +84,9 @@ class MemcpyH2DKernel {
class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input variable ");
AddInput("X", "(phi::DenseTensor) The input variable ");
AddOutput("Out",
"(LoDTensor) The type of output "
"(phi::DenseTensor) The type of output "
"is the same as input X.");
AddAttr<int>("dst_place_type",
"Determine the dst place of tensor copy. "
......@@ -100,7 +100,7 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker {
MemcpyH2D Operator.
By now, it ONLY supports the memcopy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace.
You would have to update it if you want other more capacities.
Out = X, when type in [LoDTensor]
Out = X, when type in [phi::DenseTensor]
raise error if the type is not listed above.
)DOC");
}
......
......@@ -100,9 +100,9 @@ class MemcpyKernel {
class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input variable ");
AddInput("X", "(phi::DenseTensor) The input variable ");
AddOutput("Out",
"(LoDTensor) The type of output "
"(phi::DenseTensor) The type of output "
"is the same as input X.");
AddAttr<int>("dst_place_type",
"Determine the dst place of tensor copy. "
......@@ -122,7 +122,7 @@ class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
NPUPlace <-> CPUPlace, and used as an internal op by Recompute-Offload.
You would have to update it if you want other more capacities.
Out = X, when type in [LoDTensor]
Out = X, when type in [phi::DenseTensor]
raise error if the type is not listed above.
)DOC");
}
......
......@@ -104,7 +104,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
out_lod->clear();
size_t out_offset = 0;
// Build LoDTensor `out`
// Build phi::DenseTensor `out`
size_t in_true_idx = 0;
size_t in_false_idx = 0;
......@@ -182,18 +182,18 @@ class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input LoDTensor, contains complete lod information to "
"The input phi::DenseTensor, contains complete lod information to "
"construct the output");
AddInput("Mask", "A bool column vector which mask the input");
AddInput("InTrue", "The True branch to be merged");
AddInput("InFalse", "The False branch to be merged");
AddOutput("Out", "The merged output LoDTensor");
AddOutput("Out", "The merged output phi::DenseTensor");
AddAttr<int>("level", "(int) the specific lod level to rank.")
.SetDefault(0)
.EqualGreaterThan(0);
AddComment(
R"DOC(
Merge True and False branches of LoDTensor into a single Output,
Merge True and False branches of phi::DenseTensor into a single Output,
with a mask at certain lod level. X is used to obtain complete
lod information. Please refer to SplitLoDTensorOp.)DOC");
}
......
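
Setting the LoD bookkeeping aside, the merge described in the comment above walks Mask in order and takes the next row from InTrue or InFalse accordingly. A minimal standalone sketch of that interleaving (scalar "rows", made-up values, no lod levels handled):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<bool>  mask     = {true, false, false, true};
      std::vector<float> in_true  = {10.f, 40.f};   // rows routed to the True branch
      std::vector<float> in_false = {20.f, 30.f};   // rows routed to the False branch

      std::vector<float> out;
      size_t t = 0, f = 0;
      for (bool m : mask) {
        out.push_back(m ? in_true[t++] : in_false[f++]);
      }
      for (float v : out) std::printf("%.0f ", v);  // prints: 10 20 30 40
      std::printf("\n");
      return 0;
    }
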
......@@ -300,7 +300,7 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to LoDTensor";
<< " is set to phi::DenseTensor";
ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
......
......@@ -32,7 +32,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
using SelectedRows = phi::SelectedRows;
using Sampler = math::Sampler;
using DDim = framework::DDim;
......@@ -395,15 +394,15 @@ class NCEGradKernel : public framework::OpKernel<T> {
auto *table_var = context.InputVar("Weight");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("Weight")->dims();
if (table_var->IsType<phi::DenseTensor>()) {
table_dim = context.Input<phi::DenseTensor>("Weight")->dims();
} else if (table_var->IsType<phi::SelectedRows>()) {
auto *table_t = context.Input<phi::SelectedRows>("Weight");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The parameter Weight of a NCE_OP "
"must be either LoDTensor or SelectedRows"));
"must be either phi::DenseTensor or SelectedRows"));
}
auto d_w =
......
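
A note on the mechanical hunks like the one above: LoDTensor was only a type alias for phi::DenseTensor (the very `using` line this commit deletes), so spelling the type out changes nothing at compile time. A tiny self-contained illustration; the phi::DenseTensor below is a local stand-in struct, not the real class:

    #include <type_traits>

    namespace phi { struct DenseTensor { int rows = 0; }; }  // stand-in, not the real class
    using LoDTensor = phi::DenseTensor;  // the alias this commit removes

    static_assert(std::is_same<LoDTensor, phi::DenseTensor>::value,
                  "alias and aliased type are the exact same type");

    int main() {
      LoDTensor a;              // old spelling
      phi::DenseTensor& b = a;  // new spelling binds to the same object, no conversion
      b.rows = 3;
      return a.rows == 3 ? 0 : 1;  // exits 0
    }
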
......@@ -37,7 +37,6 @@ static inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename T>
......@@ -86,9 +85,9 @@ template <typename T>
class NumberCountOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto numbers = context.Input<LoDTensor>("numbers");
auto numbers = context.Input<phi::DenseTensor>("numbers");
auto upper_range = context.Attr<int>("upper_range");
auto number_count = context.Output<LoDTensor>("Out");
auto number_count = context.Output<phi::DenseTensor>("Out");
int64_t batch_size = numbers->numel();
auto place = context.GetPlace();
......
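
The GET_BLOCKS helper above (and the CEIL_DIV macro further down in the partial_concat/partial_sum kernels) is just rounding-up integer division used to pick a CUDA grid size. A quick host-side check of that arithmetic, with CUDA_NUM_THREADS assumed to be 512 for the example:

    #include <cassert>

    constexpr int kThreads = 512;  // assumed example value of CUDA_NUM_THREADS
    constexpr int CeilDiv(int n, int d) { return (n + d - 1) / d; }

    int main() {
      assert(CeilDiv(1024, kThreads) == 2);  // exact multiple of the block size
      assert(CeilDiv(1025, kThreads) == 3);  // one extra element needs one extra block
      assert(CeilDiv(1, kThreads) == 1);
      return 0;
    }
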
......@@ -79,7 +79,8 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"(phi::DenseTensor, phi::DenseTensor<int>) Input variable with "
"rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index "
"to indicate the position.");
AddInput("depth_tensor", "(Tensor, Tensor<int>), Length of one-hot vector")
......
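
For context on the one_hot kernels touched below: per the doc string above, each value of X is an index, and the output expands it into a depth-length indicator vector. A bare-bones sketch of that semantic (flat arrays, an arbitrary example depth, no depth_tensor or allow_out_of_range handling):

    #include <cstdio>
    #include <vector>

    int main() {
      const int depth = 4;             // assumed example value of the "depth" attribute
      std::vector<int> x = {1, 0, 3};  // each entry is an index into [0, depth)
      std::vector<float> out(x.size() * depth, 0.0f);

      for (size_t i = 0; i < x.size(); ++i) {
        out[i * depth + x[i]] = 1.0f;  // one hot position per input index
      }

      for (size_t i = 0; i < x.size(); ++i) {
        for (int d = 0; d < depth; ++d) std::printf("%.0f ", out[i * depth + d]);
        std::printf("\n");             // prints: 0 1 0 0 / 1 0 0 0 / 0 0 0 1
      }
      return 0;
    }
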
......@@ -60,13 +60,12 @@ struct OneHotOpCUDAFunctor {
}
};
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class OneHotCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
int depth = -1;
if (context.HasInput("depth_tensor")) {
......
......@@ -76,14 +76,13 @@ struct OneHotOpFunctor {
}
};
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class OneHotKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
int depth = context.Attr<int>("depth");
bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
if (context.HasInput("depth_tensor")) {
......
......@@ -25,8 +25,8 @@ class OneHotNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* in = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto* in = ctx.Input<phi::DenseTensor>("X");
auto* out = ctx.Output<phi::DenseTensor>("Out");
int depth = ctx.Attr<int>("depth");
if (ctx.HasInput("depth_tensor")) {
......
......@@ -22,15 +22,14 @@
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class OneHotXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
const auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
// get depth from attr
int depth = context.Attr<int>("depth");
......
......@@ -52,7 +52,8 @@ class OneHotV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"(phi::DenseTensor, phi::DenseTensor<int>) Input variable with "
"rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index "
"to indicate the position.");
AddInput("depth_tensor", "(Tensor, Tensor<int>), Length of one-hot vector")
......
......@@ -20,7 +20,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
class OneHotV2MLUKernel : public framework::OpKernel<T> {
......@@ -28,8 +27,8 @@ class OneHotV2MLUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto* in = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto* in = ctx.Input<phi::DenseTensor>("X");
auto* out = ctx.Output<phi::DenseTensor>("Out");
int depth = ctx.Attr<int>("depth");
if (ctx.HasInput("depth_tensor")) {
std::vector<int32_t> depth_data;
......
......@@ -18,7 +18,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename T>
class OneHotV2NPUKernel : public framework::OpKernel<T> {
......@@ -26,8 +25,8 @@ class OneHotV2NPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* in = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto* in = ctx.Input<phi::DenseTensor>("X");
auto* out = ctx.Output<phi::DenseTensor>("Out");
int depth = ctx.Attr<int>("depth");
if (ctx.HasInput("depth_tensor")) {
......
......@@ -23,7 +23,6 @@ namespace operators {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <class T>
......@@ -154,8 +153,8 @@ class PartialConcatGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<phi::DenseTensor>("X");
auto outs = ctx.MultiOutput<phi::DenseTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(ins[0] != nullptr,
true,
......
......@@ -23,7 +23,6 @@ namespace operators {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
template <class T>
......@@ -153,8 +152,8 @@ class PartialSumGradOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *out_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<phi::DenseTensor>("X");
auto outs = ctx.MultiOutput<phi::DenseTensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
ins[0] != nullptr,
......
......@@ -20,7 +20,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
template <typename DeviceContext, typename T>
class PositiveNegativePairKernel : public framework::OpKernel<T> {
......
......@@ -20,7 +20,6 @@ namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using LoDTensor = phi::DenseTensor;
class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
......@@ -33,9 +32,9 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"H is the height of the input feature map, and "
"W is the width.");
AddInput("ROIs",
"(LoDTensor), "
"(phi::DenseTensor), "
"ROIs (Regions of Interest) to pool over. "
"should be a 2-D LoDTensor of shape (num_rois, 4) "
"should be a 2-D phi::DenseTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]. "
"where (x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates. "
......@@ -95,13 +94,13 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
rois_dims.size(),
2,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"ROIs should be a 2-D phi::DenseTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
PADDLE_ENFORCE_EQ(
rois_dims[1],
4,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"ROIs should be a 2-D phi::DenseTensor of shape (num_rois, 4) "
"given as [(x1, y1, x2, y2), ...]"));
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
......
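
The checks above only constrain the ROIs layout: a (num_rois, 4) matrix of (x1, y1, x2, y2) corners that the pooled_height/pooled_width attributes then carve into bins. As a rough, assumption-laden illustration of how one such box maps to bin sizes (spatial_scale and the precise integral-based pooling itself are omitted entirely):

    #include <cstdio>

    int main() {
      // One (x1, y1, x2, y2) row of ROIs, using made-up coordinates.
      const float x1 = 0.f, y1 = 0.f, x2 = 4.f, y2 = 2.f;
      const int pooled_height = 2, pooled_width = 2;  // example attribute values
      const float bin_w = (x2 - x1) / pooled_width;
      const float bin_h = (y2 - y1) / pooled_height;
      std::printf("each output bin covers %.1f x %.1f input units\n", bin_w, bin_h);  // 2.0 x 1.0
      return 0;
    }
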