save_op.cc 6.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yu Yang 已提交
14 15 16 17 18

#include <stdint.h>
#include <fstream>
#include <numeric>

Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/data_type.h"
K
Kexin Zhao 已提交
20
#include "paddle/fluid/framework/data_type_transform.h"
Y
Yi Wang 已提交
21 22 23
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
T
tangwei12 已提交
24
#include "paddle/fluid/framework/selected_rows.h"
T
bug fix  
tangwei12 已提交
25
#include "paddle/fluid/framework/variable.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/platform/device_context.h"
D
dzhwinter 已提交
27
#include "paddle/fluid/platform/port.h"
Y
Yu Yang 已提交
28 29 30 31

namespace paddle {
namespace operators {

T
tangwei12 已提交
32 33
// Name of the scope variable whose std::string value is the destination path
// used when a SelectedRows variable (e.g. a lookup table saved on checkpoint
// notify) is written by the save operator.
constexpr char LOOKUP_TABLE_PATH[]{"kLookupTablePath"};
T
tangwei12 已提交
35

Y
Yu Yang 已提交
36 37 38 39 40 41
class SaveOp : public framework::OperatorBase {
 public:
  SaveOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
42 43 44 45

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
Y
Yu Yang 已提交
46 47 48 49 50
    auto iname = Input("X");
    auto *var = scope.FindVar(iname);
    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
                   iname);

T
tangwei12 已提交
51
    if (var->IsType<framework::LoDTensor>()) {
T
bug fux  
tangwei12 已提交
52
      SaveLodTensor(place, var);
T
tangwei12 已提交
53
    } else if (var->IsType<framework::SelectedRows>()) {
T
tangwei12 已提交
54
      SaveSelectedRows(scope, place, var);
T
tangwei12 已提交
55 56 57 58 59 60 61
    } else {
      PADDLE_ENFORCE(
          false,
          "SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
          iname);
    }
  }
Y
Yu Yang 已提交
62

T
tangwei12 已提交
63
  void SaveLodTensor(const platform::Place &place,
T
bug fix  
tangwei12 已提交
64
                     framework::Variable *var) const {
T
bug fux  
tangwei12 已提交
65 66 67 68 69 70 71 72 73 74
    auto filename = Attr<std::string>("file_path");
    auto overwrite = Attr<bool>("overwrite");

    if (FileExists(filename) && !overwrite) {
      PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
                   filename, overwrite);
    }

    MkDirRecursively(DirName(filename).c_str());

Y
Yu Yang 已提交
75
    auto &tensor = var->Get<framework::LoDTensor>();
D
dzhwinter 已提交
76 77

    // get device context from pool
Y
Yu Yang 已提交
78 79
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
D
dzhwinter 已提交
80

T
tangwei12 已提交
81 82 83 84 85 86
    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
    std::ofstream fout(filename);
    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
                   filename);

T
bug fix  
tangwei12 已提交
87
    auto save_as_fp16 = Attr<bool>("save_as_fp16");
K
Kexin Zhao 已提交
88 89 90 91 92 93 94 95
    auto in_dtype = framework::ToDataType(tensor.type());
    auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;

    if (in_dtype != out_dtype) {
      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
      framework::LoDTensor out;
      framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
K
Kexin Zhao 已提交
96 97
      // copy LoD info to the new tensor
      out.set_lod(tensor.lod());
K
Kexin Zhao 已提交
98 99 100 101
      framework::SerializeToStream(fout, out, dev_ctx);
    } else {
      framework::SerializeToStream(fout, tensor, dev_ctx);
    }
T
bug fix  
tangwei12 已提交
102
    fout.close();
T
tangwei12 已提交
103 104
  }

T
tangwei12 已提交
105
  void SaveSelectedRows(const framework::Scope &scope,
T
bug fix  
tangwei12 已提交
106 107
                        const platform::Place &place,
                        framework::Variable *var) const {
T
tangwei12 已提交
108
    auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
T
tangwei12 已提交
109 110
    PADDLE_ENFORCE(
        lt_var != nullptr,
T
tangwei12 已提交
111
        "Can not find variable kLookupTablePath for SaveSelectedRows");
T
tangwei12 已提交
112
    std::string filename = lt_var->data();
113
    VLOG(40) << "SaveSelectedRows get File name: " << filename;
T
tangwei12 已提交
114

T
bug fix  
tangwei12 已提交
115 116
    MkDirRecursively(DirName(filename).c_str());

T
tangwei12 已提交
117 118 119 120 121 122 123 124 125 126 127 128
    auto &selectedRows = var->Get<framework::SelectedRows>();

    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);

    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
    std::ofstream fout(filename);
    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
                   filename);
    framework::SerializeToStream(fout, selectedRows, dev_ctx);
T
bug fix  
tangwei12 已提交
129
    fout.close();
Y
Yu Yang 已提交
130 131 132 133 134
  }
};

// Declares the save operator's input, attributes and user-facing docs.
// Adjacent string literals below are concatenated by the compiler, so each
// leading "(type, default)" fragment ends with a space to keep the rendered
// description readable.
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input LoDTensor or SelectedRows to be saved");
    AddComment(R"DOC(
Save operator

This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC");
    AddAttr<bool>("overwrite",
                  "(boolean, default true) "
                  "Overwrite the output file if it exists")
        .SetDefault(true);
    AddAttr<bool>("save_as_fp16",
                  "(boolean, default false) "
                  "If true, the tensor will be converted to float16 data "
                  "type and then saved. Otherwise, the tensor will be "
                  "directly saved without data type conversion.")
        .SetDefault(false);
    AddAttr<std::string>("file_path",
                         "(string) "
                         "The \"file_path\" where the variable will be saved.")
        // An empty path is rejected at attribute-check time.
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
  }
};

T
tangwei12 已提交
160 161 162 163
// Variable-type inference for the save op: marks the first variable listed
// under the op's LOOKUP_TABLE_PATH output slot as a RAW variable, creating it
// in `block` (or an ancestor lookup) if needed.
class SaveOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    // NOTE(review): SaveOpProtoMaker in this file declares no output slot with
    // this name; presumably OpDesc::Output rejects unknown slots — confirm.
    const auto &out_var_names = op_desc.Output(LOOKUP_TABLE_PATH);
    // front() on an empty vector is undefined behavior; fail loudly instead.
    PADDLE_ENFORCE(!out_var_names.empty(),
                   "Output %s of save_op must name a variable",
                   LOOKUP_TABLE_PATH);
    auto &out_var = block->FindRecursiveOrCreateVar(out_var_names.front());
    out_var.SetType(framework::proto::VarType::RAW);
  }
};

// Shape inference for the save op. The op declares no outputs in
// SaveOpProtoMaker, so there is no shape to infer; this is a deliberate no-op.
class SaveOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*ctx*/) const override {
    // Intentionally empty.
  }
};
T
tangwei12 已提交
175 176
}  // namespace operators
}  // namespace paddle
Y
Yu Yang 已提交
177 178 179

namespace ops = paddle::operators;

T
tangwei12 已提交
180 181
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
                  ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
T
tangwei12 已提交
182
                  ops::SaveOpShapeInference);