/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdint.h>
#include <fstream>
#include <numeric>
18
#include <vector>
Y
Yu Yang 已提交
19

Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/data_type.h"
K
Kexin Zhao 已提交
21
#include "paddle/fluid/framework/data_type_transform.h"
Y
Yi Wang 已提交
22 23 24
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
T
tangwei12 已提交
25
#include "paddle/fluid/framework/selected_rows.h"
T
bug fix  
tangwei12 已提交
26
#include "paddle/fluid/framework/variable.h"
Y
Yi Wang 已提交
27
#include "paddle/fluid/platform/device_context.h"
D
dzhwinter 已提交
28
#include "paddle/fluid/platform/port.h"
Y
Yu Yang 已提交
29 30 31 32

namespace paddle {
namespace operators {

33 34 35
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Name of the dispensable output slot of save_op. A checkpoint notify uses it
// to tell the operator where lookup-table (SelectedRows) variables should be
// written; the slot's variable holds the target path as a std::string.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

class SaveOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
41
 public:
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {}

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

/// Declares the interface of the `save` operator:
/// - input "X";
/// - the "overwrite", "save_as_fp16" and "file_path" attributes;
/// - the dispensable LOOKUP_TABLE_PATH output used when saving lookup tables.
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved");
    AddComment(R"DOC(
Save operator

This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC");
    // Adjacent string literals are concatenated verbatim, so each "(type ...)"
    // prefix carries a trailing space to keep it apart from the description.
    AddAttr<bool>("overwrite",
                  "(boolean, default true) "
                  "Overwrite the output file if exist")
        .SetDefault(true);
    AddAttr<bool>("save_as_fp16",
                  "(boolean, default false) "
                  "If true, the tensor will be converted to float16 data "
                  "type and then saved. Otherwise, the tensor will be "
                  "directly saved without data type conversion.")
        .SetDefault(false);
    // No SetDefault here; the checker rejects an empty path.
    AddAttr<std::string>("file_path",
                         "(string) "
                         "The \"file_path\" where the variable will be saved.")
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
    AddOutput(LOOKUP_TABLE_PATH,
              "(string) "
              "for pserver: The \"kLookupTablePath\" where checkpoint notify "
              "to save lookup table variables"
              " to directory specified.")
        .AsDispensable();
  }
};

/// Marks the LOOKUP_TABLE_PATH output as a RAW variable. It carries a plain
/// std::string (a file path) rather than tensor data.
class SaveOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SetType(LOOKUP_TABLE_PATH, framework::proto::VarType::RAW);
  }
};

/// Shape inference for `save`: the operator writes to disk and defines no
/// tensor output shapes, so there is nothing to infer.
class SaveOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*ctx*/) const override {}
};

/// Kernel of the `save` operator.
///
/// Serializes the variable bound to input "X" to a file:
///   * LoDTensor    -> path from the "file_path" attribute. Honors the
///                     "overwrite" attribute and, when "save_as_fp16" is set,
///                     converts the tensor to FP16 before writing.
///   * SelectedRows -> path taken from the std::string held by the
///                     LOOKUP_TABLE_PATH output variable.
/// Any other variable type is rejected with an error. Before writing, the
/// target's parent directory is passed to MkDirRecursively. An error is
/// raised if the file cannot be opened or the data cannot be written.
template <typename DeviceContext, typename T>
class SaveOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();

    auto *input_var = ctx.InputVar("X");
    // Inputs("X") lists the bound variable names; format the name itself.
    // Passing the vector's data() (a `const std::string *`) to "%s" would
    // print a pointer value instead of the variable name.
    const auto &input_names = ctx.Inputs("X");
    const std::string iname =
        input_names.empty() ? std::string() : input_names.front();
    PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op",
                   iname);

    if (input_var->IsType<framework::LoDTensor>()) {
      SaveLodTensor(ctx, place, input_var);
    } else if (input_var->IsType<framework::SelectedRows>()) {
      SaveSelectedRows(ctx, place, input_var);
    } else {
      PADDLE_THROW(
          "SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
          iname);
    }
  }

  /// Writes the LoDTensor held by `var` to the "file_path" attribute path.
  ///
  /// @param ctx   execution context supplying the "file_path", "overwrite"
  ///              and "save_as_fp16" attributes.
  /// @param place place whose device context is used for serialization.
  /// @param var   variable holding the LoDTensor to save.
  void SaveLodTensor(const framework::ExecutionContext &ctx,
                     const platform::Place &place,
                     const framework::Variable *var) const {
    auto filename = ctx.Attr<std::string>("file_path");
    auto overwrite = ctx.Attr<bool>("overwrite");

    // The format string has a single "%s", so pass exactly one argument.
    if (FileExists(filename) && !overwrite) {
      PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
                   filename);
    }

    MkDirRecursively(DirName(filename).c_str());

    auto &tensor = var->Get<framework::LoDTensor>();

    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);

    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
    std::ofstream fout(filename, std::ios::binary);
    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
                   filename);

    auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
    auto in_dtype = tensor.type();
    auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;

    if (in_dtype != out_dtype) {
      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
      framework::LoDTensor out;
      framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
      // copy LoD info to the new tensor
      out.set_lod(tensor.lod());
      framework::SerializeToStream(fout, out, dev_ctx);
    } else {
      framework::SerializeToStream(fout, tensor, dev_ctx);
    }
    fout.close();
    // close() flushes buffered data. A failed write or flush leaves the stream
    // in a failed state; report it instead of returning as if it succeeded.
    PADDLE_ENFORCE(static_cast<bool>(fout), "Failed to write %s", filename);
  }

  /// Writes the SelectedRows held by `var` to the path stored in the
  /// LOOKUP_TABLE_PATH output variable (read as a std::string).
  ///
  /// @param ctx   execution context providing the LOOKUP_TABLE_PATH output.
  /// @param place place whose device context is used for serialization.
  /// @param var   variable holding the SelectedRows to save.
  void SaveSelectedRows(const framework::ExecutionContext &ctx,
                        const platform::Place &place,
                        const framework::Variable *var) const {
    framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
    PADDLE_ENFORCE(
        out_put_var != nullptr,
        "Can not find variable kLookupTablePath for SaveSelectedRows");
    auto *lt_var = out_put_var->GetMutable<std::string>();

    const std::string filename = *lt_var;
    VLOG(4) << "SaveSelectedRows get File name: " << filename;

    MkDirRecursively(DirName(filename).c_str());

    auto &selectedRows = var->Get<framework::SelectedRows>();

    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);

    // FIXME(yuyang18): We save variable to local file now, but we should change
    // it to save an output stream.
    std::ofstream fout(filename, std::ios::binary);
    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
                   filename);
    framework::SerializeToStream(fout, selectedRows, dev_ctx);
    fout.close();
    // Same post-write check as SaveLodTensor: surface write/flush failures.
    PADDLE_ENFORCE(static_cast<bool>(fout), "Failed to write %s", filename);
  }
};

}  // namespace operators
}  // namespace paddle
Y
Yu Yang 已提交
198 199 200

namespace ops = paddle::operators;

201 202 203 204 205 206 207 208 209
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
                  ops::SaveOpVarTypeInference, ops::SaveOpShapeInference);

REGISTER_OP_CPU_KERNEL(
    save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);