chesspi_ai.cpp 5.4 KB
Newer Older
D
dev@dev.com 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#include <cstdio>
#include <memory.h>
#include <vector>
#include <string>
#include <cassert>
#include <algorithm>
#include <set>
#include <unordered_set>
#include <cmath>
#include <omp.h>
#include <atomic>
#include "chesspi.h"

D
dev@dev.com 已提交
14 15
// Search depth, in plies, read by build_tree: it derives its stopping depth
// from this value and switches expand_node to capture-only mode at this depth.
int max_depth{5};
// Base value of each piece, indexed by (piece index % 16).
// Order within one side: king, advisor x2, elephant x2, horse x2, rook x2,
// cannon x2, pawn x5.
static const unsigned int table_cost[16] = {100000,150,150,150,150,150,150,500,500,150,150,100,100,100,100,100};

// Row-dependent multiplier shared by horses, rooks and cannons:
// 1 + y/3 for piece indices below 16, 1 + (11 - y)/3 otherwise.
static double advance_factor(const int idx, const int y)
{
	return (idx < 16) ? 1 + (y / 3.0) : 1 + ((11 - y) / 3.0);
}

/*!
 * \brief calc_cost computes the cost of piece idx being captured after a move
 * \param coordx X coordinate of every piece
 * \param coordy Y coordinate of every piece
 * \param alive  alive flag of every piece (non-zero means still on the board)
 * \param killed total number of captures so far
 * \param idx    index of the captured piece
 * \return the capture cost
 */
float calc_cost(const int coordx[/*32*/], const int coordy[/*32*/],const int alive[/*32*/],const int killed,const int idx)
{
	// NOTE(review): the assert restricts idx to 16..31, yet the branches below
	// still handle idx < 16 (reachable when asserts are compiled out) - confirm
	// which contract is intended.
	assert(idx >= 16);

	const int kind = idx % 16;
	// The running value stays an unsigned integer, so every floating-point
	// multiplication below is truncated back to an integer.
	unsigned int rescost = table_cost[kind];

	// Positional weighting
	switch (kind) {
	// Advisors / elephants: worth double while the pair is complete.
	case 1:
	case 2:
	case 3:
	case 4:
	{
		// The two pieces of a pair sit at indices (first, first + 1).
		const int first = (idx-1)/2*2+1;
		const int partner = (idx == first) ? first + 1 : first;
		// Only the partner is tested: testing "either one alive" is always
		// true whenever the captured piece itself is still flagged alive.
		if (alive[partner])
			rescost *= 2;
		break;
	}
	// Horses: worth more the further they have advanced, and late in the game.
	case 5:
	case 6:
		rescost = static_cast<unsigned int>(rescost * advance_factor(idx, coordy[idx]));
		rescost = static_cast<unsigned int>(rescost * (1 + killed / 4.0));
		break;
	// Rooks: worth more the further they have advanced.
	case 7:
	case 8:
		rescost = static_cast<unsigned int>(rescost * advance_factor(idx, coordy[idx]));
		break;
	// Cannons: worth more early in the game (few captures so far).
	case 9:
	case 10:
		rescost = static_cast<unsigned int>(rescost * advance_factor(idx, coordy[idx]));
		rescost = static_cast<unsigned int>(rescost * (1 + (32 - killed) / 4.0));
		break;
	// Pawns: worth four times as much once across the river.
	case 11:
	case 12:
	case 13:
	case 14:
	case 15:
	{
		const bool crossed = (idx < 16) ? (coordy[idx] > 5) : (coordy[idx] < 6);
		if (crossed)
			rescost *= 4;
		// A pawn on file 5 (the centre file) is worth five times as much.
		if (coordx[idx]==5)
			rescost *= 5;
		break;
	}
	default:
		// kind 0 (the king) keeps its base value.
		break;
	}

	return rescost;
}

std::vector<chess_node> build_tree(const chess_node & root, const int side,const std::vector<chess_node> & history)
{
	std::vector<chess_node> tree;
	std::unordered_set <std::string> dict;
	for (const chess_node & n: history)
		dict.insert(node2hash(n.coords,n.alive));
	tree.push_back(root);
	tree[0].side = side % 2;
	tree[0].depth = 0;
D
dev@dev.com 已提交
102
	int max_nodes = 1000*1000*32;
D
dev@dev.com 已提交
103
	size_t curr_i = 0;
D
dev@dev.com 已提交
104 105 106 107

	//要停留在敌走的偶数步
	const int stop_depth = (max_depth+1)/2 * 2;

D
dev@dev.com 已提交
108 109 110 111 112 113 114 115
	while (tree.size()<=max_nodes && curr_i<tree.size())
	{
		const size_t ts = tree.size();
		const int cores = omp_get_num_procs();
		std::vector<std::vector<chess_node> > vec_appends;
		for (int i=0;i<cores;++i)
			vec_appends.push_back(std::vector<chess_node>());
		std::atomic<int> new_appends (0);
D
dev@dev.com 已提交
116

D
dev@dev.com 已提交
117 118 119 120 121 122
#pragma omp parallel for
		for (int i=curr_i;i<ts;++i)
		{
			if (new_appends + ts >=max_nodes)
				continue;
			const unsigned char clock = tree[i].depth;
D
dev@dev.com 已提交
123 124 125
			if (clock >= stop_depth)
				continue;
			bool onlykill = clock >=max_depth;
D
dev@dev.com 已提交
126 127 128 129 130
			const int tid = omp_get_thread_num();
			if ((tree[i].alive & 0x00010001)==0x00010001)
			{
				const int curr_side = (side + clock) % 2;
				std::vector<chess_node> next_status =
D
dev@dev.com 已提交
131
						expand_node(tree[i],curr_side,onlykill);
D
dev@dev.com 已提交
132 133 134 135 136 137 138 139

				const size_t sz = next_status.size();
				for (size_t j=0;j<sz;++j)
				{
					std::string ha = node2hash(next_status[j].coords,next_status[j].alive);
					bool needI = false;
#pragma omp critical
					{
D
dev@dev.com 已提交
140
						if (dict.find(ha)==dict.end() && clock+1 <= max_depth)
D
dev@dev.com 已提交
141 142 143 144
						{
							needI = true;
							dict.insert(ha);
						}
D
dev@dev.com 已提交
145 146 147 148 149 150 151 152
						else if (dict.find(ha)==dict.end() && clock + 1 <= stop_depth)
						{
							if (next_status[j].jump_cost[0]+next_status[j].jump_cost[1]>0)
							{
								needI = true;
								dict.insert(ha);
							}
						}
D
dev@dev.com 已提交
153 154 155 156 157 158 159 160 161
					}
					if (needI)
					{
						next_status[j].parent = i;
						next_status[j].side = curr_side;
						next_status[j].depth = clock+1;
						vec_appends[tid].push_back(next_status[j]);
						++tree[i].leaves;
						++new_appends;
D
dev@dev.com 已提交
162 163
						if (new_appends%1000==0)
						{
D
dev@dev.com 已提交
164
							printf ("Thinking.%d:%d...  \r",i,int(new_appends+ts));
D
dev@dev.com 已提交
165
						}
D
dev@dev.com 已提交
166 167 168 169 170 171 172 173 174 175 176 177
					}
				}
			}
		}
		for (int i=0;i<cores;++i)
		{
			if (vec_appends[i].size())
				std::move(vec_appends[i].begin(),vec_appends[i].end(),std::back_inserter(tree));
		}
		curr_i += (ts - curr_i);
	}

D
dev@dev.com 已提交
178
	printf ("\nDepth = %d                \n",tree.rbegin()->depth);
D
dev@dev.com 已提交
179 180 181 182 183 184 185 186 187 188 189 190 191 192 193

	return tree;
}

/*!
 * \brief judge_tree weights every node bottom-up and picks one child of the root
 * \param tree search tree; the weight and jump_cost fields are updated in place
 * \return index of the root child with the largest weight (1 if none exceeds 0),
 *         or 0 when the tree has fewer than two nodes
 */
size_t judge_tree(std::vector<chess_node> & tree)
{
	const size_t total_nodes = tree.size();
	if (total_nodes<2)
		return 0;

	// Walk from the last node down to node 1, folding each node into its parent.
	// Assumes every node is stored after its parent, so a node has received all
	// of its children's contributions before it is weighted itself.
	size_t i = total_nodes - 1;
	while (i > 0)
	{
		// Weight = sqrt((other + 1) / (own + 1)^2), where "own" is the
		// jump_cost entry selected by the node's side.
		if (tree[i].side==0)
		{
			float ratio = sqrt((tree[i].jump_cost[1]+1) / (tree[i].jump_cost[0]+1)/ (tree[i].jump_cost[0]+1));
			tree[i].weight = ratio;
		}
		else
		{
			float ratio = sqrt((tree[i].jump_cost[0]+1) / (tree[i].jump_cost[1]+1)/ (tree[i].jump_cost[1]+1));
			tree[i].weight = ratio;
		}
		// Propagate both costs to the parent, scaled by weight / depth^2.
		// depth is expected to be >= 1 for every non-root node.
		size_t parent = tree[i].parent;
		tree[parent].jump_cost[0] += tree[i].jump_cost[0] * tree[i].weight/tree[i].depth/tree[i].depth;
		tree[parent].jump_cost[1] += tree[i].jump_cost[1] * tree[i].weight/tree[i].depth/tree[i].depth;
		--i;
		// size_t must not be passed to %d; convert explicitly.
		if (i%1000==0)
			printf ("Sorting.%d...  \r",int(total_nodes - i));
	}

	// The root's children are the leading run of nodes whose parent index is 0.
	float max_v = 0;
	size_t max_p = 1;
	for (size_t p = 1; p<total_nodes; ++p)
	{
		if (tree[p].parent)
			break;
		// An alternative score, (jump_cost[1-side]+1)/jump_cost[side] with the
		// root's side, was tried here previously; the node weight is used instead.
		const float v = tree[p].weight;
		if (v > max_v)
		{
			max_v = v;
			max_p = p;
		}
	}
	return max_p;
}