From cb28428323a70e82ed84be1f2d9db507ff8f8a30 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Thu, 14 Sep 2017 01:03:32 +0800
Subject: [PATCH] Replace LoDTensor in elementwise_mul_op, pad_op and
 recurrent_op_utils.

---
 paddle/framework/operator.cc               | 31 +++++++++-------------
 paddle/framework/operator.h                | 25 ++++++++++++-----
 paddle/operators/elementwise_mul_op.cc     |  6 ++---
 paddle/operators/pad_op.cc                 |  9 ++++---
 paddle/operators/rnn/recurrent_op_utils.cc | 19 ++++++++-----
 5 files changed, 51 insertions(+), 39 deletions(-)

diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 25faeff0d1c..27e77849400 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -189,13 +189,7 @@ void OperatorBase::GenerateTemporaryNames() {
 template <>
 const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
   auto* var = InputVar(name);
-  if (var == nullptr) return nullptr;
-  if (var->IsType<LoDTensor>()) {
-    return &var->Get<LoDTensor>();
-  }
-  PADDLE_ENFORCE(var->IsType<Tensor>(),
-                 "The Input(%s) must be LoDTensor or Tensor.");
-  return &var->Get<Tensor>();
+  return var == nullptr ? nullptr : GetTensorFromVar(var);
 }
 
 template <>
@@ -204,9 +198,11 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
   auto names = op().Inputs(name);
   std::vector<const Tensor*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Input<Tensor>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope_.FindVar(sub_name);
+                   return var == nullptr ? nullptr : GetTensorFromVar(var);
+                 });
   return res;
 }
 
 template <>
@@ -214,12 +210,7 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
   auto* var = OutputVar(name);
   if (var == nullptr) return nullptr;
-  if (var->IsType<LoDTensor>()) {
-    return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
-  }
-  PADDLE_ENFORCE(var->IsType<Tensor>(),
-                 "The Input(%s) must be LoDTensor or Tensor.");
-  return const_cast<Tensor*>(&var->Get<Tensor>());
+  return GetTensorFromVar(var);
 }
 
 template <>
@@ -228,9 +219,11 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
   auto names = op().Outputs(name);
   std::vector<Tensor*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Output<Tensor>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope().FindVar(sub_name);
+                   return var == nullptr ? nullptr : GetTensorFromVar(var);
+                 });
   return res;
 }
 
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index b2d79084082..bbf9930f0af 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -306,9 +306,11 @@ class InferShapeContext {
     auto names = op_.Inputs(name);
     std::vector<const T*> res;
     res.reserve(names.size());
-    std::transform(
-        names.begin(), names.end(), std::back_inserter(res),
-        [&](const std::string& sub_name) { return Input<T>(sub_name); });
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [&](const std::string& sub_name) {
+                     auto var = scope_.FindVar(sub_name);
+                     return var == nullptr ? nullptr : &var->Get<T>();
+                   });
     return res;
   }
 
@@ -317,12 +319,23 @@ class InferShapeContext {
     auto names = op_.Outputs(name);
     std::vector<T*> res;
     res.reserve(names.size());
-    std::transform(
-        names.begin(), names.end(), std::back_inserter(res),
-        [&](const std::string& sub_name) { return Output<T>(sub_name); });
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [&](const std::string& sub_name) {
+                     auto var = scope_.FindVar(sub_name);
+                     return var == nullptr ? nullptr : var->GetMutable<T>();
+                   });
     return res;
   }
 
+  Tensor* GetTensorFromVar(const Variable* var) const {
+    if (var->IsType<LoDTensor>()) {
+      return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
+    }
+    PADDLE_ENFORCE(var->IsType<Tensor>(),
+                   "The Input(%s) must be LoDTensor or Tensor.");
+    return const_cast<Tensor*>(&var->Get<Tensor>());
+  }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc
index 1742925545d..ae88ec1b306 100644
--- a/paddle/operators/elementwise_mul_op.cc
+++ b/paddle/operators/elementwise_mul_op.cc
@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
     auto y_dim = ctx.Input<Tensor>("Y")->dims();
     PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                       "Rank of first input must >= rank of second input.")
-    ctx.Output<Tensor>("Out")->Resize(x_dim);
+    ctx.Output<framework::LoDTensor>("Out")->Resize(x_dim);
   }
 };
 
@@ -80,8 +80,8 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto y_dims = ctx.Input<Tensor>("Y")->dims();
     auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
-    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto *x_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *y_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
 
     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Rank of first input must >= rank of second input.")
diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
index 7e78b6ec133..6cf7bd6f35b 100644
--- a/paddle/operators/pad_op.cc
+++ b/paddle/operators/pad_op.cc
@@ -34,7 +34,8 @@ class PadOp : public framework::OperatorWithKernel {
     for (int i = 0; i < x_dim.size(); ++i) {
       out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
     }
-    ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(out_dims));
+    ctx.Output<framework::LoDTensor>("Out")->Resize(
+        framework::make_ddim(out_dims));
   }
 };
 
@@ -95,9 +96,9 @@ class PadOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
     auto x_dims = ctx.Input<Tensor>("X")->dims();
-    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    if (x_grad != nullptr) {
-      x_grad->Resize(x_dims);
+    auto *x_g = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (x_g != nullptr) {
+      x_g->Resize(x_dims);
     }
   }
 };
diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc
index 97872c67ac9..6c082cb1825 100644
--- a/paddle/operators/rnn/recurrent_op_utils.cc
+++ b/paddle/operators/rnn/recurrent_op_utils.cc
@@ -21,6 +21,7 @@ namespace rnn {
 namespace f = paddle::framework;
 
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks, const size_t seq_len,
@@ -31,7 +32,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
                    inlinks[i].external);
 
-    Tensor* input = input_var->GetMutable<Tensor>();
+    LoDTensor* input = input_var->GetMutable<LoDTensor>();
     f::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -40,6 +41,8 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
       if (!infer_shape_mode) {
+        // The input of operators of each step is Tensor here.
+        // Maybe need to modify Slice function.
         *step_input = input->Slice<float>(j, j + 1);
       }
       step_input->Resize(step_dims);
@@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
     auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
     PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
                    outlinks[i].external);
-    Tensor* output = output_var->GetMutable<Tensor>();
+    LoDTensor* output = output_var->GetMutable<LoDTensor>();
 
     if (infer_shape_mode) {
       auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
       PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
                      outlinks[i].internal);
-      f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
+      f::DDim step_dims =
+          step_scope_var->template GetMutable<LoDTensor>()->dims();
       std::vector<int> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
       output->Resize(f::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
-        Tensor* step_output =
-            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+        LoDTensor* step_output = step_scopes[j]
+                                     ->FindVar(outlinks[i].internal)
+                                     ->GetMutable<LoDTensor>();
         // TODO(luotao02) data type and platform::DeviceContext() should set
         // correctly
         (output->Slice<float>(j, j + 1))
@@ -94,8 +99,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
   auto scope = scopes[step_id];
   auto linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
-    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
+    auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
+    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
     if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
     } else {
--
GitLab
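
Note on the core change: every access path above now funnels through the new
InferShapeContext::GetTensorFromVar() helper, which prefers a variable's
LoDTensor payload and falls back to plain Tensor (in Paddle, LoDTensor derives
from Tensor). The x_grad-to-x_g rename in pad_op.cc appears to exist only to
keep the now-longer ctx.Output<framework::LoDTensor>(...) call within the line
length limit. For reference, here is a minimal self-contained sketch of the
dispatch pattern; Tensor, LoDTensor, and Variable below are simplified
stand-ins for the real paddle::framework classes, and assert() stands in for
PADDLE_ENFORCE:

// sketch.cc -- illustrative stand-ins, not the real paddle::framework types.
#include <cassert>
#include <typeindex>
#include <typeinfo>

struct Tensor {};
struct LoDTensor : public Tensor {};  // LoDTensor extends Tensor, as in Paddle

// A Variable that holds a pointer to either payload and remembers which.
class Variable {
 public:
  template <typename T>
  explicit Variable(T* t) : ptr_(t), type_(typeid(T)) {}

  template <typename T>
  bool IsType() const {
    return type_ == std::type_index(typeid(T));
  }

  template <typename T>
  const T& Get() const {
    return *static_cast<const T*>(ptr_);
  }

 private:
  const Tensor* ptr_;
  std::type_index type_;
};

// The dispatch introduced by the patch: prefer LoDTensor, fall back to
// plain Tensor, and reject anything else (assert here, PADDLE_ENFORCE there).
Tensor* GetTensorFromVar(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
  }
  assert(var->IsType<Tensor>() && "must be LoDTensor or Tensor");
  return const_cast<Tensor*>(&var->Get<Tensor>());
}

int main() {
  LoDTensor lod;
  Variable v(&lod);
  // A caller that only knows about Tensor* still reaches the LoDTensor
  // payload; the helper erases the distinction at the access boundary.
  Tensor* t = GetTensorFromVar(&v);
  assert(t == &lod);
  return 0;
}

The point of the pattern is migration: operator code written against Tensor*
keeps compiling and running while variables move to LoDTensor underneath.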