GradientMachine.cpp 1.3 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
#include "PaddleCAPI.h"

#include <memory>

#include "PaddleCAPIPrivate.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"

// Convert an opaque PD_GradiemtMachine handle back into the internal
// paddle::capi::CGradientMachine object (presumably a pointer to it).
// NOTE: this is a function-like macro named `cast`, so every later
// `cast(...)` token sequence in this translation unit is rewritten; keep it
// after all #includes. The inner `paddle::capi::cast<...>` is not re-expanded
// (it is followed by `<`, and macros never expand themselves recursively).
#define cast(handle) \
  paddle::capi::cast<paddle::capi::CGradientMachine>(handle)

/**
 * Creation modes passed as the `mode` argument of
 * paddle::GradientMachine::create().
 *
 * NOTE(review): these numeric values must stay in sync with paddle's own
 * GradientMachine creation modes (presumably kNormal = 0 and kTesting = 4) —
 * confirm against GradientMachine.h before changing them.
 */
enum GradientMachineCreateMode {
  CREATE_MODE_NORMAL = 0,
  CREATE_MODE_TESTING = 4
};

// Backward-compatible alias for the original, misspelled type name.
using GradientMatchineCreateMode = GradientMachineCreateMode;

namespace paddle {

/**
 * Minimal NeuralNetwork subclass that only forwards its constructor
 * arguments to NeuralNetwork(name, network); it adds no members and no
 * overrides.
 *
 * NOTE(review): presumably this exists because that base constructor is not
 * publicly reachable from a free function — confirm in NeuralNetwork.h.
 */
class MyNeuralNetwork : public NeuralNetwork {
public:
  MyNeuralNetwork(const std::string& networkName, NeuralNetwork* rootNetwork)
      : NeuralNetwork(networkName, rootNetwork) {}
};

/**
 * Allocate a new MyNeuralNetwork with the given name and network pointer.
 *
 * @param name    name forwarded to the NeuralNetwork constructor.
 * @param network network pointer forwarded unchanged to the constructor.
 * @return a heap-allocated network; the caller takes ownership.
 *
 * NOTE(review): the spelling "Nerual" presumably has to match a declaration
 * elsewhere in paddle that looks this hook up by name — do not rename it
 * without checking the gserver sources.
 */
NeuralNetwork* newCustomNerualNetwork(const std::string& name,
                                      NeuralNetwork* network) {
  NeuralNetwork* const created = new MyNeuralNetwork(name, network);
  return created;
}

}  // namespace paddle

extern "C" {
/**
 * Create a gradient machine for prediction from a serialized ModelConfig.
 *
 * @param machine [out] receives the new handle on success; left untouched on
 *        any failure. Release it with PDGradientMachineDestroy.
 * @param modelConfigProtobuf bytes of a serialized paddle::ModelConfig.
 * @param size number of bytes in modelConfigProtobuf.
 * @return PD_NO_ERROR on success; PD_NULLPTR if machine or
 *         modelConfigProtobuf is null; PD_PROTOBUF_ERROR if size is negative
 *         or the bytes do not parse into an initialized ModelConfig.
 */
int PDGradientMachineCreateForPredict(PD_GradiemtMachine* machine,
                                      void* modelConfigProtobuf,
                                      int size) {
  // Both pointers are dereferenced below; reject null instead of crashing.
  if (machine == nullptr || modelConfigProtobuf == nullptr) return PD_NULLPTR;
  // A negative length can never describe a valid buffer; do not hand it to
  // protobuf.
  if (size < 0) return PD_PROTOBUF_ERROR;

  paddle::ModelConfig config;
  if (!config.ParseFromArray(modelConfigProtobuf, size) ||
      !config.IsInitialized()) {
    return PD_PROTOBUF_ERROR;
  }

  // Own the wrapper with unique_ptr so it is freed if creating the underlying
  // machine throws; ownership passes to the caller only after construction
  // has fully succeeded.
  std::unique_ptr<paddle::capi::CGradientMachine> ptr(
      new paddle::capi::CGradientMachine());
  // Testing mode with only PARAMETER_VALUE buffers requested — presumably
  // inference-only (no gradient storage); confirm against GradientMachine.
  ptr->machine.reset(paddle::GradientMachine::create(
      config, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
  *machine = ptr.release();
  return PD_NO_ERROR;
}

/**
 * Destroy a handle created by PDGradientMachineCreateForPredict, deleting the
 * wrapper object and thereby releasing its hold on the underlying machine.
 *
 * NOTE(review): a null handle is a no-op assuming `cast` maps null to
 * nullptr (delete on nullptr does nothing).
 *
 * @return always PD_NO_ERROR.
 */
int PDGradientMachineDestroy(PD_GradiemtMachine machine) {
  delete cast(machine);
  return PD_NO_ERROR;
}
}  // extern "C"