/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <fstream>
#include <memory>
#include <numeric>
#include <string>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"

namespace paddle {
namespace operators {

// Name of the scope variable whose std::string value is the destination path
// used when saving a SelectedRows variable. Checkpoint notify uses it to save
// lookup table variables to the directory it specifies.
static constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

class SaveOp : public framework::OperatorBase {
 public:
  SaveOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
43 44 45 46

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
Y
Yu Yang 已提交
47 48 49 50 51
    auto iname = Input("X");
    auto *var = scope.FindVar(iname);
    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
                   iname);

T
tangwei12 已提交
52
    if (var->IsType<framework::LoDTensor>()) {
T
bug fux  
tangwei12 已提交
53
      SaveLodTensor(place, var);
T
tangwei12 已提交
54
    } else if (var->IsType<framework::SelectedRows>()) {
T
tangwei12 已提交
55
      SaveSelectedRows(scope, place, var);
T
tangwei12 已提交
56 57 58 59 60 61 62
    } else {
      PADDLE_ENFORCE(
          false,
          "SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
          iname);
    }
  }
Y
Yu Yang 已提交
63

T
tangwei12 已提交
64
  void SaveLodTensor(const platform::Place &place,
T
bug fix  
tangwei12 已提交
65
                     framework::Variable *var) const {
T
bug fux  
tangwei12 已提交
66 67
    auto filename = Attr<std::string>("file_path");
    auto overwrite = Attr<bool>("overwrite");
D
dzhwinter 已提交
68
    auto format = Attr<std::string>("format");
T
bug fux  
tangwei12 已提交
69 70 71 72 73 74 75 76

    if (FileExists(filename) && !overwrite) {
      PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
                   filename, overwrite);
    }

    MkDirRecursively(DirName(filename).c_str());

Y
Yu Yang 已提交
77
    auto &tensor = var->Get<framework::LoDTensor>();
D
dzhwinter 已提交
78 79

    // get device context from pool
Y
Yu Yang 已提交
80 81
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
D
dzhwinter 已提交
82

T
tangwei12 已提交
83 84
    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
D
dzhwinter 已提交
85 86 87 88 89 90 91 92
    std::unique_ptr<std::ofstream> fout;
    if (format == "windows") {
      fout.reset(new std::ofstream(filename,
                                   std::ios_base::out | std::ios_base::binary));
    } else {
      fout.reset(new std::ofstream(filename));
    }
    PADDLE_ENFORCE(static_cast<bool>(*fout), "Cannot open %s to write",
T
tangwei12 已提交
93 94
                   filename);

T
bug fix  
tangwei12 已提交
95
    auto save_as_fp16 = Attr<bool>("save_as_fp16");
K
Kexin Zhao 已提交
96 97 98 99 100 101 102 103
    auto in_dtype = framework::ToDataType(tensor.type());
    auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;

    if (in_dtype != out_dtype) {
      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
      framework::LoDTensor out;
      framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
K
Kexin Zhao 已提交
104 105
      // copy LoD info to the new tensor
      out.set_lod(tensor.lod());
D
dzhwinter 已提交
106
      framework::SerializeToStream(*fout, out, dev_ctx);
K
Kexin Zhao 已提交
107
    } else {
D
dzhwinter 已提交
108
      framework::SerializeToStream(*fout, tensor, dev_ctx);
K
Kexin Zhao 已提交
109
    }
T
tangwei12 已提交
110 111
  }

T
tangwei12 已提交
112
  void SaveSelectedRows(const framework::Scope &scope,
T
bug fix  
tangwei12 已提交
113 114
                        const platform::Place &place,
                        framework::Variable *var) const {
T
tangwei12 已提交
115
    auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
T
tangwei12 已提交
116 117
    PADDLE_ENFORCE(
        lt_var != nullptr,
T
tangwei12 已提交
118
        "Can not find variable kLookupTablePath for SaveSelectedRows");
T
tangwei12 已提交
119
    std::string filename = lt_var->data();
D
dzhwinter 已提交
120
    auto format = Attr<std::string>("format");
T
tangwei12 已提交
121 122
    VLOG(4) << "SaveSelectedRows get File name: " << filename;

T
bug fix  
tangwei12 已提交
123 124
    MkDirRecursively(DirName(filename).c_str());

T
tangwei12 已提交
125 126 127 128 129 130 131 132
    auto &selectedRows = var->Get<framework::SelectedRows>();

    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);

    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
D
dzhwinter 已提交
133 134 135 136 137 138 139 140
    std::unique_ptr<std::ofstream> fout;
    if (format == "windows") {
      fout.reset(new std::ofstream(filename,
                                   std::ios_base::out | std::ios_base::binary));
    } else {
      fout.reset(new std::ofstream(filename));
    }
    PADDLE_ENFORCE(static_cast<bool>(*fout), "Cannot open %s to write",
T
tangwei12 已提交
141
                   filename);
D
dzhwinter 已提交
142
    framework::SerializeToStream(*fout, selectedRows, dev_ctx);
Y
Yu Yang 已提交
143 144 145 146 147
  }
};

// Declares the input, attributes and documentation of the "save" operator.
// Attribute doc strings are built from adjacent literals, so each fragment
// ends with a space where words would otherwise run together.
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input LoDTensor or SelectedRows to be saved");
    AddComment(R"DOC(
Save operator

This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC");
    AddAttr<bool>("overwrite",
                  "(boolean, default true) "
                  "Overwrite the output file if it exists.")
        .SetDefault(true);
    AddAttr<bool>("save_as_fp16",
                  "(boolean, default false) "
                  "If true, the tensor will be converted to float16 data "
                  "type and then saved. Otherwise, the tensor will be "
                  "directly saved without data type conversion.")
        .SetDefault(false);
    AddAttr<std::string>("file_path",
                         "(string) "
                         "The \"file_path\" where the variable will be saved.")
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
    // Only "windows" and "linux" are accepted. SaveOp opens the file in
    // binary mode for "windows" and in default (text) mode otherwise.
    AddAttr<std::string>(
        "format",
        "(windows|linux, default linux) The saved model file format. "
        "Windows and Linux use different newline conventions (\"\\r\\n\" on "
        "Windows, \"\\n\" on Linux). If format is windows, the model file is "
        "saved in binary mode and can be used on both Linux and Windows. If "
        "format is linux, the file is saved in text mode. Note that these "
        "two formats are not inter-compatible.")
        .SetDefault("linux")
        .AddCustomChecker([](const std::string &s) {
          return s == "windows" || s == "linux";
        });
  }
};

T
tangwei12 已提交
185 186 187 188
class SaveOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
T
tangwei12 已提交
189
    auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
T
tangwei12 已提交
190 191 192 193 194 195 196 197 198 199
    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
    auto var_type = framework::proto::VarType::RAW;
    out_var.SetType(var_type);
  }
};

// Shape inference for save_op is deliberately empty: the operator declares no
// outputs, so there are no shapes to infer.
class SaveOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*ctx*/) const override {
    // Nothing to infer.
  }
};
}  // namespace operators
}  // namespace paddle
Y
Yu Yang 已提交
202 203 204

namespace ops = paddle::operators;

T
tangwei12 已提交
205 206
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
                  ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
T
tangwei12 已提交
207
                  ops::SaveOpShapeInference);