// optimizer.cc - committed by dzhwinter
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "optimizer.h"
#include <glog/logging.h>
#include <cstdlib>
#include <cstring>
#include <string>

#include "parameter_optimizer.h"
using paddle::optimizer::ParameterOptimizer;
using paddle::optimizer::Tensor;
// Maps a paddle_element_type tag to a C++ element type. The primary template
// is intentionally empty: only the specializations generated by
// MATCH_ENUM_TYPE provide a nested `Type`.
template <paddle_element_type kElementType>
struct EnumToType {};

// Maps a C++ element type to its paddle_element_type tag. Likewise empty;
// only types registered through MATCH_ENUM_TYPE are specialized.
template <typename T>
struct TypeToEnum {};

#define MATCH_ENUM_TYPE(TYPE, ENUM)                 \
  template <>                                       \
  struct TypeToEnum<TYPE> {                         \
    static paddle_element_type v() { return ENUM; } \
    static constexpr TYPE value = ENUM;             \
  };                                                \
  template <>                                       \
  struct EnumToType<ENUM> {                         \
    typedef TYPE Type;                              \
41
  }
42 43 44 45 46 47 48 49

MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32);
MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32);
MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64);
MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64);
MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32);
MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64);

// Handle returned through the C API. It only wraps the C++ optimizer instance
// built by paddle_create_optimizer.
struct paddle_optimizer {
  // Set by paddle_create_optimizer; deleted by paddle_release_optimizer.
  paddle::optimizer::ParameterOptimizer *impl;
};

// Creates an optimizer handle.
//   config_proto / config_proto_len: serialized optimizer configuration bytes,
//     passed verbatim to ParameterOptimizer::Create.
//   data_type: currently ignored; param_buffer is always read as floats.
//   param_buffer / num_bytes: parameter storage, wrapped in a Tensor of
//     num_bytes / sizeof(float) elements that is handed to Create.
//   state / state_len: optional previously serialized state. It is restored
//     via DeserializeState only when non-null and non-empty.
// Returns a heap-allocated handle; release it with paddle_release_optimizer.
paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
                                          const int config_proto_len,
                                          const paddle_element_type data_type,
                                          void* param_buffer,
                                          int num_bytes,
                                          const char* state,
                                          const int state_len) {
  paddle_optimizer* optimizer = new paddle_optimizer;

  // A negative length would form an invalid [p, p + len) range (UB), so only
  // copy the config when there actually are bytes to copy.
  std::string config;
  if (config_proto != nullptr && config_proto_len > 0) {
    config.assign(config_proto, config_proto + config_proto_len);
  }

  // Guard against a negative byte count wrapping to a huge element count when
  // converted to an unsigned size.
  const std::size_t num_elements =
      num_bytes > 0 ? static_cast<std::size_t>(num_bytes) / sizeof(float) : 0;
  Tensor* parameter =
      new Tensor(reinterpret_cast<float*>(param_buffer), num_elements);
  optimizer->impl = ParameterOptimizer::Create(config, parameter);

  // An empty state carries nothing to restore, and a negative length is an
  // invalid range, so only deserialize a non-empty state.
  if (state != nullptr && state_len > 0) {
    std::string s(state, state + state_len);
    optimizer->impl->DeserializeState(s);
  }
  return optimizer;
}

// Releases an optimizer created by paddle_create_optimizer. It destroys the
// wrapped ParameterOptimizer and then the handle itself. Passing nullptr is a
// no-op. `o` must not be used after this call.
int paddle_release_optimizer(paddle_optimizer* o) {
  if (o != nullptr) {
    delete o->impl;
    // The handle itself was allocated with `new` in paddle_create_optimizer;
    // without this every create/release pair leaks it.
    delete o;
  }
  return PADDLE_SUCCESS;
}

int paddle_update_parameter(paddle_optimizer* o,
79
                            const paddle_element_type data_type,
80 81
                            const void* grad_buffer,
                            int num_bytes) {
D
dzhwinter 已提交
82
  // TOOD(zhihong): datatype not work. need to add the runtime datatype
83
  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
84 85
  Tensor* gradient =
      new Tensor(const_cast<float*>(grad_type), num_bytes / sizeof(float));
D
dzhwinter 已提交
86
  o->impl->Update(gradient);
87 88 89
  return PADDLE_SUCCESS;
}

// Exposes the optimizer's current weights without copying. It stores the
// pointer returned by get_weight() in *param_buffer and returns the size that
// get_weight() reported through its out-parameter.
// NOTE(review): nothing is allocated here, so the buffer presumably stays
// owned by the optimizer and must not be freed by the caller. Confirm.
int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
  int reported_size = 0;
  const void* weights = o->impl->get_weight(&reported_size);
  *param_buffer = const_cast<void*>(weights);
  return reported_size;
}
// Serializes the optimizer state into a newly malloc'ed buffer.
// On success, *state points to the returned number of bytes, and the caller
// must release that buffer with free().
// When there is no state, or the allocation fails, *state is set to nullptr
// and 0 is returned. The caller therefore never receives an uninitialized
// pointer.
int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
  const std::string serialized = o->impl->SerializeState();
  const int state_len = static_cast<int>(serialized.size());

  // Always give the out-parameter a defined value.
  *state = nullptr;
  if (state_len <= 0) {
    return 0;
  }

  char* buffer = static_cast<char*>(std::malloc(state_len));
  if (buffer == nullptr) {
    // Copying into a null pointer would be UB; report "no state" instead.
    return 0;
  }
  std::memcpy(buffer, serialized.data(), state_len);
  *state = buffer;
  return state_len;
}