abtest.cpp 3.4 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

G
guru4elephant 已提交
15
#include "core/sdk-cpp/include/abtest.h"

#include <limits>
W
sdk-cpp  
wangguibao 已提交
16 17 18 19 20

namespace baidu {
namespace paddle_serving {
namespace sdk_cpp {

W
wangguibao 已提交
21 22 23 24 25
// Loads the per-variant routing weights from `conf`.
//
// `conf` must be a configure::WeightedRandomRenderConf. Its
// variant_weight_list() string is split on WEIGHT_SEPERATOR and each piece is
// parsed as an unsigned decimal weight (only the leading digits are read).
// The weights are stored in order in _variant_weight_list. Their total is
// stored in _normalized_sum, which route() uses for sampling. This call also
// reseeds the process-wide rand() generator.
//
// Returns 0 on success and -1 in any of these cases:
//   - `conf` has the wrong type;
//   - the list cannot be split;
//   - a piece does not start with a number or does not fit in uint32;
//   - the total overflows uint32;
//   - the total is zero.
int WeightedRandomRender::initialize(const google::protobuf::Message& conf) {
  srand(static_cast<unsigned>(time(NULL)));
  try {
    const configure::WeightedRandomRenderConf& weighted_random_render_conf =
        dynamic_cast<const configure::WeightedRandomRenderConf&>(conf);

    std::string weights = weighted_random_render_conf.variant_weight_list();

    std::vector<std::string> splits;
    if (str_split(weights, WEIGHT_SEPERATOR, &splits) != 0) {
      LOG(ERROR) << "Failed split string:" << weights;
      return -1;
    }

    // route() accumulates weights in a uint32_t, so the total must fit there.
    const uint64_t kMaxWeightSum = std::numeric_limits<uint32_t>::max();

    // Reset both the table and its total. Without clear(), a repeated call
    // would append to the previous weights while the sum restarts from zero,
    // leaving the two out of sync.
    _variant_weight_list.clear();
    _normalized_sum = 0;

    uint64_t total = 0;  // wider than uint32 so overflow can be detected
    uint32_t weight_size = splits.size();
    for (uint32_t wi = 0; wi < weight_size; ++wi) {
      const char* begin_pos = splits[wi].c_str();
      char* end_pos = NULL;
      unsigned long ratio = strtoul(begin_pos, &end_pos, 10);  // NOLINT
      // strtoul() maps out-of-range and negative input to large values
      // instead of failing, so reject anything that does not fit in uint32.
      if (end_pos == begin_pos || ratio > kMaxWeightSum) {
        LOG(ERROR) << "Error ratio(uint32) format:" << splits[wi] << " at "
                   << wi;
        return -1;
      }

      total += ratio;
      if (total > kMaxWeightSum) {
        LOG(ERROR) << "Weight sum overflows uint32 at " << wi
                   << ", weights: " << weights;
        return -1;
      }
      _variant_weight_list.push_back(static_cast<uint32_t>(ratio));
    }

    if (total == 0) {
      LOG(ERROR) << "Zero normalized weight sum";
      return -1;
    }
    _normalized_sum = static_cast<uint32_t>(total);

    VLOG(2) << "Succ read weights list: " << weights
            << ", count: " << _variant_weight_list.size()
            << ", normalized: " << _normalized_sum;
  } catch (const std::bad_cast& e) {
    LOG(ERROR) << "Failed init WeightedRandomRender "
               << "from configure, err:" << e.what();
    return -1;
  } catch (...) {
    LOG(ERROR) << "Failed init WeightedRandomRender "
               << "from configure, err message is unkown.";
    return -1;
  }

  return 0;
}

W
wangguibao 已提交
71 72 73
// Routing entry point that receives per-request parameters. Weighted random
// selection does not use them, so this forwards to the single-argument
// overload.
Variant* WeightedRandomRender::route(const VariantList& variants,
                                     const void* params) {
  static_cast<void>(params);  // intentionally unused by this strategy
  Variant* const chosen = route(variants);
  return chosen;
}

W
wangguibao 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89
// Chooses one entry of `variants` at random, with probability proportional
// to the weight that initialize() configured for the same index.
//
// Returns NULL in any of these cases:
//   - the number of variants differs from the number of weights;
//   - the weight total is zero (for example, initialize() never succeeded);
//   - the sample unexpectedly falls outside the accumulated weight range.
//
// NOTE(review): the sample is rand() % total. rand() never exceeds RAND_MAX,
// so if the total is above RAND_MAX the upper part of the range can never be
// selected, and there is a small modulo bias. Small percentage-style weights
// are unaffected.
Variant* WeightedRandomRender::route(const VariantList& variants) {
  if (variants.size() != _variant_weight_list.size()) {
    LOG(ERROR) << "#(Weights) is not equal #(Stubs)"
               << ", size: " << _variant_weight_list.size() << " vs. "
               << variants.size();
    return NULL;
  }

  // An empty weight table matches an empty variant list above, so guard the
  // modulo explicitly: `rand() % 0` is undefined behavior.
  if (_normalized_sum <= 0) {
    LOG(ERROR) << "Zero normalized weight sum, cannot route";
    return NULL;
  }

  uint32_t sample = rand() % _normalized_sum;  // NOLINT
  uint32_t cand_size = _variant_weight_list.size();
  uint32_t cur_total = 0;
  // Walk the cumulative weights. The first bucket whose running total
  // exceeds the sample is the chosen variant.
  for (uint32_t ci = 0; ci < cand_size; ++ci) {
    cur_total += _variant_weight_list[ci];
    if (sample < cur_total) {
      VLOG(2) << "Sample " << sample << " on " << ci
              << ", _normalized: " << _normalized_sum
              << ", weight: " << _variant_weight_list[ci];
      return variants[ci];
    }
  }

  // Unreachable while _normalized_sum equals the sum of the weights.
  LOG(ERROR) << "Errors accurs in sampling, sample:" << sample
             << ", total: " << _normalized_sum;

  return NULL;
}

W
wangguibao 已提交
103 104 105
}  // namespace sdk_cpp
}  // namespace paddle_serving
}  // namespace baidu