recurrent_op_utils.cc 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/rnn/recurrent_op_utils.h"

namespace paddle {
namespace operators {
namespace rnn {

D
dongzhihong 已提交
21 22 23
namespace f = paddle::framework;

using Tensor = framework::Tensor;
24
using LoDTensor = framework::LoDTensor;
25 26

/**
 * Splits every global input ("inlink") along its first dimension and stores
 * the j-th slice in step scope j under the same variable name.
 *
 * @param step_scopes One scope per time step. Must be non-empty and hold at
 *                    least `seq_len` entries. The global inputs are looked up
 *                    in the parent of step_scopes[0].
 * @param inlinks     Names of the input variables to segment; must be
 *                    non-empty.
 * @param seq_len     Number of time steps; dims()[0] of every inlink must
 *                    equal this value.
 */
void SegmentInputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<std::string>& inlinks,
                   const size_t seq_len) {
  PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
  // step_scopes[0] and step_scopes[j] for every j < seq_len are dereferenced
  // below; reject a short scope list instead of indexing out of bounds.
  PADDLE_ENFORCE(!step_scopes.empty(), "no step scopes are provided.");
  PADDLE_ENFORCE_GE(step_scopes.size(), seq_len,
                    "step scopes' size [%d] is less than seq_len [%d]",
                    step_scopes.size(), seq_len);
  for (size_t i = 0; i < inlinks.size(); ++i) {
    // global inputs
    auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
    PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
                            inlinks[i]);

    LoDTensor* input = input_var->GetMutable<LoDTensor>();
    f::DDim dims = input->dims();
    PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
                      "all the inlinks be the same length");
    // Shape of a single step: the input shape without its leading dimension.
    f::DDim step_dims = slice_ddim(dims, 1, dims.size());
    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_input =
          step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
      // The input of operators of each step is Tensor here.
      // Maybe need to modify Slice function.
      *step_input = input->Slice<float>(j, j + 1);
      // Drop the leading length-1 dimension left by Slice(j, j + 1).
      step_input->Resize(step_dims);
    }
  }
}

/**
 * Gathers every per-step output ("outlink") into the global output variable
 * of the same name. The global output is resized to [seq_len, step_dims...],
 * where step_dims is the shape of the output in step 0, and row j receives a
 * copy of step j's output.
 *
 * @param step_scopes One scope per time step; must hold at least `seq_len`
 *                    entries when `outlinks` is non-empty. The global outputs
 *                    are looked up in the parent of step_scopes[0].
 * @param outlinks    Names of the output variables to gather. When empty the
 *                    function does nothing.
 * @param seq_len     Number of time steps to gather.
 */
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<std::string>& outlinks,
                   const size_t seq_len) {
  if (outlinks.empty()) {
    return;  // nothing to gather; step_scopes is never accessed
  }
  // step_scopes[0] and step_scopes[j] for every j < seq_len are dereferenced
  // below; reject a short scope list instead of indexing out of bounds.
  PADDLE_ENFORCE(!step_scopes.empty(), "no step scopes are provided.");
  PADDLE_ENFORCE_GE(step_scopes.size(), seq_len,
                    "step scopes' size [%d] is less than seq_len [%d]",
                    step_scopes.size(), seq_len);
  for (size_t i = 0; i < outlinks.size(); ++i) {
    auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
    PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
                            outlinks[i]);
    LoDTensor* output = output_var->GetMutable<LoDTensor>();

    // The per-step shape is taken from step 0; the global output gets an
    // extra leading dimension of size seq_len.
    auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
    PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
    f::DDim step_dims = step_scope_var->GetMutable<LoDTensor>()->dims();
    std::vector<int64_t> dims_vec = vectorize(step_dims);
    dims_vec.insert(dims_vec.begin(), static_cast<int64_t>(seq_len));
    output->Resize(f::make_ddim(dims_vec));
    output->mutable_data<float>(platform::CPUPlace());
    for (size_t j = 0; j < seq_len; j++) {
      // Every step, not only step 0, must actually define the output link.
      auto step_var = step_scopes[j]->FindVar(outlinks[i]);
      PADDLE_ENFORCE_NOT_NULL(step_var,
                              "output link [%s] is not in step scope [%d].",
                              outlinks[i], j);
      LoDTensor* step_output = step_var->GetMutable<LoDTensor>();
      // Row j of `output` was sized from step 0's shape; a differently shaped
      // step output would not fit that row.
      PADDLE_ENFORCE(step_output->dims() == step_dims,
                     "output link [%s] of step [%d] has a different shape "
                     "from that of step 0",
                     outlinks[i], j);
      // TODO(luotao02) data type and platform::DeviceContext() should set
      // correctly
      (output->Slice<float>(j, j + 1))
          .CopyFrom<float>(*step_output, platform::CPUPlace());
    }
  }
}

/**
 * For every memory attribute, makes the `pre_var` tensor in scopes[step_id]
 * take the dims of the `var` tensor in scopes[step_id + offset] and share its
 * data (ShareDataWith<float>).
 *
 * @param scopes   The step scopes.
 * @param memories Memory attributes. Each one names the variable to link
 *                 (`pre_var`, in the current step) and the variable it is
 *                 linked to (`var`, in the offset step).
 * @param step_id  Index of the current step; must be < scopes.size().
 * @param offset   Relative index of the linked step. step_id + offset must lie
 *                 in [0, scopes.size()).
 */
void LinkMemories(const std::vector<Scope*>& scopes,
                  const std::vector<rnn::MemoryAttr>& memories,
                  const size_t step_id, const int offset) {
  PADDLE_ENFORCE_LT(step_id, scopes.size(),
                    "step [%d] is out of range of step scopes' size [%d]",
                    step_id, scopes.size());
  PADDLE_ENFORCE_GE(static_cast<int>(step_id) + offset, 0,
                    "offset [%d] must be large than -[%d]", offset, step_id);
  PADDLE_ENFORCE_LT(
      step_id + offset, scopes.size(),
      "offset [%d] is out of range, it must be less than (%d - %d)", offset,
      scopes.size(), step_id);
  // The checks above guarantee step_id + offset is a valid, non-negative index.
  const size_t linked_id = step_id + offset;
  auto scope = scopes[step_id];
  auto linked_scope = scopes[linked_id];
  for (const auto& attr : memories) {
    // FindVar may return null; fail with a clear message instead of
    // dereferencing it.
    auto mem_var = scope->FindVar(attr.pre_var);
    PADDLE_ENFORCE_NOT_NULL(mem_var, "memory [%s] is not in step scope [%d].",
                            attr.pre_var, step_id);
    auto linked_var = linked_scope->FindVar(attr.var);
    PADDLE_ENFORCE_NOT_NULL(linked_var,
                            "memory [%s] is not in step scope [%d].", attr.var,
                            linked_id);
    auto mem = mem_var->GetMutable<LoDTensor>();
    auto linked_mem = linked_var->GetMutable<LoDTensor>();
    mem->Resize(linked_mem->dims());
    mem->ShareDataWith<float>(*linked_mem);
  }
}

void InitArgument(const ArgumentName& name, Argument* arg,
S
superjom 已提交
103 104 105
                  const framework::OperatorBase& op, bool is_grad) {
  arg->step_scopes =
      is_grad ? op.Input(name.step_scopes) : op.Output(name.step_scopes);
S
superjom 已提交
106 107
  arg->inlinks = op.Inputs(name.inlinks);
  arg->outlinks = op.Outputs(name.outlinks);
108

S
superjom 已提交
109 110
  auto boot_memories =
      is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories);
111
  // attributes
Y
Yu Yang 已提交
112 113
  auto memories = op.Attr<std::vector<std::string>>(name.memories);
  auto pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134

  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                 "the size of memories, boot_memories don't match:%d,%d",
                 memories.size(), boot_memories.size());
  PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
                 "the size of pre_memories, boot_memories don't match:%d,%d",
                 pre_memories.size(), boot_memories.size());
  PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");

  for (size_t i = 0; i < memories.size(); ++i) {
    rnn::MemoryAttr mem_attr;
    mem_attr.var = memories[i];
    mem_attr.pre_var = pre_memories[i];
    mem_attr.boot_var = boot_memories[i];
    (arg->memories).push_back(mem_attr);
  }
}

}  // namespace rnn
}  // namespace operators
}  // namespace paddle