提交 ea284cf6 编写于 作者: M mxdg

first commit

上级 86eee60d
#!/usr/bin/env python3
# coding=utf-8
from __future__ import print_function, unicode_literals, division, absolute_import
import sys
import time
import binascii
import struct
import collections
import logging
import socket
import select
import threading
import traceback
import functools
import server_pool
try:
# for pycharm type hinting
from typing import Union, Callable
except:
pass
# socket recv buffer, 16384 bytes
RECV_BUFFER_SIZE = 2 ** 14
# default secretkey, use -k/--secretkey to change
#SECRET_KEY = "shootback"
# how long a SPARE slaver would keep
# once slaver received an heart-beat package from master,
# the TTL would be reset. And heart-beat delay is less than TTL,
# so, theoretically, spare slaver never timeout,
# except network failure
# notice: working slaver would NEVER timeout
SPARE_SLAVER_TTL = 300
# internal program version, appears in CtrlPkg
INTERNAL_VERSION = 0x000D
# version for human readable
__version__ = (2, 2, 8, INTERNAL_VERSION)
# just a logger
log = logging.getLogger(__name__)
def version_info():
"""get program version for human. eg: "2.1.0-r2" """
return "{}.{}.{}-r{}".format(*__version__)
def configure_logging(level):
logging.basicConfig(
level=level,
format='[%(levelname)s %(asctime)s] %(message)s',
)
def fmt_addr(socket):
"""(host, int(port)) --> "host:port" """
return "{}:{}".format(*socket)
def split_host(x):
""" "host:port" --> (host, int(port))"""
try:
host, port = x.split(":")
port = int(port)
except:
raise ValueError(
"wrong syntax, format host:port is "
"required, not {}".format(x))
else:
return host, port
def try_close(closable):
"""try close something
same as
try:
connection.close()
except:
pass
"""
try:
closable.close()
except:
pass
def select_recv(conn, buff_size, timeout=None):
"""add timeout for socket.recv()
:type conn: socket.SocketType
:type buff_size: int
:type timeout: float
:rtype: Union[bytes, None]
"""
rlist, _, _ = select.select([conn], [], [], timeout)
if not rlist:
# timeout
raise RuntimeError("recv timeout")
buff = conn.recv(buff_size)
if not buff:
raise RuntimeError("received zero bytes, socket was closed")
return buff
class SocketBridge:
"""
transfer data between sockets
"""
def __init__(self):
self.conn_rd = set() # record readable-sockets
self.map = {} # record sockets pairs
self.callbacks = {} # record callbacks
self.tmp_thread = None
def add_conn_pair(self, conn1, conn2,tmp=None, callback=None):
"""
transfer anything between two sockets
:type conn1: socket.SocketType
:type conn2: socket.SocketType
:param callback: callback in connection finish
:type callback: Callable
"""
# mark as readable
self.conn_rd.add(conn1)
self.conn_rd.add(conn2)
# record sockets pairs
self.map[conn1] = conn2
self.map[conn2] = conn1
# record callback
if callback is not None:
self.callbacks[conn1] = callback
if tmp is not None:
conn2.send(tmp)
logging.info("tmp send:{}".format(len(tmp)))
def get_thread(self):
return self.tmp_thread
def start_as_daemon(self):
t = threading.Thread(target=self.start)
t.daemon = True
t.start()
log.info("SocketBridge daemon started")
self.tmp_thread = t;
# return t
def start(self):
server_pool.ServerPool.bridgeAdd += 1
while True:
try:
self._start()
except:
log.error("FATAL ERROR! SocketBridge failed {}".format(
traceback.format_exc()
))
def _start(self):
# memoryview act as an recv buffer
# refer https://docs.python.org/3/library/stdtypes.html#memoryview
buff = memoryview(bytearray(RECV_BUFFER_SIZE))
while True:
if not self.conn_rd:
# sleep if there is no connections
time.sleep(0.06)
continue
# blocks until there is socket(s) ready for .recv
# notice: sockets which were closed by remote,
# are also regarded as read-ready by select()
r, w, e = select.select(self.conn_rd, [], [], 0.5)
for s in r: # iter every read-ready or closed sockets
try:
# here, we use .recv_into() instead of .recv()
# recv data directly into the pre-allocated buffer
# to avoid many unnecessary malloc()
# see https://docs.python.org/3/library/socket.html#socket.socket.recv_into
rec_len = s.recv_into(buff, RECV_BUFFER_SIZE)
# agre = "http"
# url = agre + '://' + heads['Host']
# heads = httphead(buff.tobytes().decode('utf-8'))
# logging.info("recv head:{}".format(heads))
except Exception as e:
# unable to read, in most cases, it's due to socket close
self._rd_shutdown(s)
continue
if not rec_len:
# read zero size, closed or shutdowned socket
self._rd_shutdown(s)
continue
try:
# send data, we use `buff[:rec_len]` slice because
# only the front of buff is filled
self.map[s].send(buff[:rec_len])
except Exception as e:
# unable to send, close connection
self._rd_shutdown(s)
continue
def _rd_shutdown(self, conn, once=False):
"""action when connection should be read-shutdown
:type conn: socket.SocketType
"""
if conn in self.conn_rd:
self.conn_rd.remove(conn)
try:
conn.shutdown(socket.SHUT_RD)
except:
pass
if not once and conn in self.map: # use the `once` param to avoid infinite loop
# if a socket is rd_shutdowned, then it's
# pair should be wr_shutdown.
self._wr_shutdown(self.map[conn], True)
if self.map.get(conn) not in self.conn_rd:
# if both two connection pair was rd-shutdowned,
# this pair sockets are regarded to be completed
# so we gonna close them
self._terminate(conn)
def _wr_shutdown(self, conn, once=False):
"""action when connection should be write-shutdown
:type conn: socket.SocketType
"""
try:
conn.shutdown(socket.SHUT_WR)
except:
pass
if not once and conn in self.map: # use the `once` param to avoid infinite loop
# pair should be rd_shutdown.
# if a socket is wr_shutdowned, then it's
self._rd_shutdown(self.map[conn], True)
def _terminate(self, conn):
"""terminate a sockets pair (two socket)
:type conn: socket.SocketType
:param conn: any one of the sockets pair
"""
try_close(conn) # close the first socket
server_pool.ServerPool.bridgeRemove += 1
# ------ close and clean the mapped socket, if exist ------
if conn in self.map:
_mapped_conn = self.map[conn]
try_close(_mapped_conn)
if _mapped_conn in self.map:
del self.map[_mapped_conn]
del self.map[conn] # clean the first socket
else:
_mapped_conn = None # just a fallback
# ------ callback --------
# because we are not sure which socket are assigned to callback,
# so we should try both
if conn in self.callbacks:
try:
self.callbacks[conn]()
except Exception as e:
log.error("traceback error: {}".format(e))
log.debug(traceback.format_exc())
del self.callbacks[conn]
elif _mapped_conn and _mapped_conn in self.callbacks:
try:
self.callbacks[_mapped_conn]()
except Exception as e:
log.error("traceback error: {}".format(e))
log.debug(traceback.format_exc())
del self.callbacks[_mapped_conn]
class CtrlPkg:
PACKAGE_SIZE = 2 ** 6 # 64 bytes
CTRL_PKG_TIMEOUT = 5 # CtrlPkg recv timeout, in second
# CRC32 for SECRET_KEY and Reversed(SECRET_KEY)
SECRET_KEY_CRC32 = 0# = binascii.crc32(SECRET_KEY.encode('utf-8')) & 0xffffffff
SECRET_KEY_REVERSED_CRC32 = 0# = binascii.crc32(SECRET_KEY[::-1].encode('utf-8')) & 0xffffffff
# Package Type
PTYPE_HS_S2M = -1 # handshake pkg, slaver to master
PTYPE_HEART_BEAT = 0 # heart beat pkg
PTYPE_HS_M2S = +1 # handshake pkg, Master to Slaver
TYPE_NAME_MAP = {
PTYPE_HS_S2M: "PTYPE_HS_S2M",
PTYPE_HEART_BEAT: "PTYPE_HEART_BEAT",
PTYPE_HS_M2S: "PTYPE_HS_M2S",
}
# formats
# see https://docs.python.org/3/library/struct.html#format-characters
# for format syntax
FORMAT_PKG = "!b b H 20x 40s"
FORMATS_DATA = {
PTYPE_HS_S2M: "!I 36x",
PTYPE_HEART_BEAT: "!40x",
PTYPE_HS_M2S: "!I 36x",
}
def __init__(self, pkg_ver=0x01, pkg_type=0,
prgm_ver=INTERNAL_VERSION, data=(),
raw=None,SECRET_KEY_CRC32=0,SECRET_KEY_REVERSED_CRC32=0
):
"""do not call this directly, use `CtrlPkg.pbuild_*` instead"""
self._cache_prebuilt_pkg = {} # cache
self.pkg_ver = pkg_ver
self.pkg_type = pkg_type
self.prgm_ver = prgm_ver
self.data = data
self.SECRET_KEY_CRC32 = SECRET_KEY_CRC32
self.SECRET_KEY_REVERSED_CRC32 = SECRET_KEY_REVERSED_CRC32
if raw:
self.raw = raw
else:
self._build_bytes()
@property
def type_name(self):
"""返回人类可读的包类型"""
return self.TYPE_NAME_MAP.get(self.pkg_type, "TypeUnknown")
def __str__(self):
return """pkg_ver: {} pkg_type:{} prgm_ver:{} data:{}""".format(
self.pkg_ver,
self.type_name,
self.prgm_ver,
self.data,
)
def __repr__(self):
return self.__str__()
def _build_bytes(self):
self.raw = struct.pack(
self.FORMAT_PKG,
self.pkg_ver,
self.pkg_type,
self.prgm_ver,
self.data_encode(self.pkg_type, self.data),
)
def _prebuilt_pkg(cls, pkg_type, fallback):
"""act as lru_cache"""
if pkg_type not in cls._cache_prebuilt_pkg:
pkg = fallback(force_rebuilt=True)
cls._cache_prebuilt_pkg[pkg_type] = pkg
logging.info("_prebuilt_pkg,id:{}".format(id(cls._cache_prebuilt_pkg)))
return cls._cache_prebuilt_pkg[pkg_type]
def recalc_crc32(cls,skey):
cls.skey = skey
cls.SECRET_KEY_CRC32 = binascii.crc32(skey.encode('utf-8')) & 0xffffffff
cls.SECRET_KEY_REVERSED_CRC32 = binascii.crc32(skey[::-1].encode('utf-8')) & 0xffffffff
logging.info("main key:{},id:{},{},{}".format(cls.skey,id(cls),cls.SECRET_KEY_CRC32,cls.SECRET_KEY_REVERSED_CRC32))
def clean_crc32(self):
self.skey = ""
self.SECRET_KEY_CRC32 = "closed hahaha"
self.SECRET_KEY_REVERSED_CRC32 = "closed hahaha"
def data_decode(cls, ptype, data_raw):
return struct.unpack(cls.FORMATS_DATA[ptype], data_raw)
def data_encode(cls, ptype, data):
return struct.pack(cls.FORMATS_DATA[ptype], *data)
def verify(self, pkg_type=None):
logging.info("verify 响应包 {},{},{}".format(self.data, self.SECRET_KEY_CRC32,self.SECRET_KEY_REVERSED_CRC32))
try:
if pkg_type is not None and self.pkg_type != pkg_type:
return False
elif self.pkg_type == self.PTYPE_HS_S2M:
# Slaver-->Master 的握手响应包
logging.info("Slaver-->Master 的握手响应包 {},{}".format(self.data[0],self.SECRET_KEY_REVERSED_CRC32))
return self.data[0] == self.SECRET_KEY_REVERSED_CRC32
elif self.pkg_type == self.PTYPE_HEART_BEAT:
# 心跳
return True
elif self.pkg_type == self.PTYPE_HS_M2S:
# Master-->Slaver 的握手包
logging.info("Master-->Slaver 的握手包".format(self.data[0], self.SECRET_KEY_CRC32))
return self.data[0] == self.SECRET_KEY_CRC32
else:
return True
except:
return False
def decode_only(cls, raw):
"""
decode raw bytes to CtrlPkg instance, no verify
use .decode_verify() if you also want verify
:param raw: raw bytes content of package
:type raw: bytes
:rtype: CtrlPkg
"""
if not raw or len(raw) != cls.PACKAGE_SIZE:
raise ValueError("content size should be {}, but {}".format(
cls.PACKAGE_SIZE, len(raw)
))
pkg_ver, pkg_type, prgm_ver, data_raw = struct.unpack(cls.FORMAT_PKG, raw)
logging.info("CtrlPkg,decode_only,,,,pkg_ver:{}, pkg_type:{}, prgm_ver:{}".format(pkg_ver, pkg_type, prgm_ver))
data = cls.data_decode(pkg_type, data_raw)
logging.info("CtrlPkg,decode_only,data:{}".format(data))
return CtrlPkg(
pkg_ver=pkg_ver, pkg_type=pkg_type,
prgm_ver=prgm_ver,
data=data,
raw=raw,
SECRET_KEY_CRC32=cls.SECRET_KEY_CRC32, SECRET_KEY_REVERSED_CRC32=cls.SECRET_KEY_REVERSED_CRC32
)
def decode_verify(cls, raw, pkg_type=None):
"""decode and verify a package
:param raw: raw bytes content of package
:type raw: bytes
:param pkg_type: assert this package's type,
if type not match, would be marked as wrong
:type pkg_type: int
:rtype: CtrlPkg, bool
:return: tuple(CtrlPkg, is_it_a_valid_package)
"""
try:
pkg = cls.decode_only(raw)
except:
return None, False
else:
return pkg, pkg.verify(pkg_type=pkg_type)
def pbuild_hs_m2s(cls, force_rebuilt=False):
"""pkg build: Handshake Master to Slaver"""
# because py27 do not have functools.lru_cache, so we must write our own
if force_rebuilt:
return CtrlPkg(
pkg_type=cls.PTYPE_HS_M2S,
data=(cls.SECRET_KEY_CRC32,),
SECRET_KEY_CRC32=cls.SECRET_KEY_CRC32, SECRET_KEY_REVERSED_CRC32=cls.SECRET_KEY_REVERSED_CRC32
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HS_M2S, cls.pbuild_hs_m2s)
def pbuild_hs_s2m(cls, force_rebuilt=False):
"""pkg build: Handshake Slaver to Master"""
if force_rebuilt:
return CtrlPkg(
pkg_type=cls.PTYPE_HS_S2M,
data=(cls.SECRET_KEY_REVERSED_CRC32,),
SECRET_KEY_CRC32=cls.SECRET_KEY_CRC32, SECRET_KEY_REVERSED_CRC32=cls.SECRET_KEY_REVERSED_CRC32
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HS_S2M, cls.pbuild_hs_s2m)
def pbuild_heart_beat(cls, force_rebuilt=False):
"""pkg build: Heart Beat Package"""
if force_rebuilt:
return CtrlPkg(
pkg_type=cls.PTYPE_HEART_BEAT,
SECRET_KEY_CRC32=cls.SECRET_KEY_CRC32, SECRET_KEY_REVERSED_CRC32=cls.SECRET_KEY_REVERSED_CRC32
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HEART_BEAT, cls.pbuild_heart_beat)
def recv(cls, sock, timeout=CTRL_PKG_TIMEOUT, expect_ptype=None):
"""just a shortcut function
:param sock: which socket to recv CtrlPkg from
:type sock: socket.SocketType
:rtype: CtrlPkg,bool
"""
logging.info("CtrlPkg,recv,sock:{},expect_ptype:{}".format(sock,expect_ptype))
buff = select_recv(sock, cls.PACKAGE_SIZE, timeout)
pkg, verify = cls.decode_verify(buff, pkg_type=expect_ptype) # type: CtrlPkg,bool
return pkg, verify
def httphead(request):
header = request.split('\r\n\r\n', 1)[0]
headers = dict()
for line in header.split('\r\n')[1:]:
key, val = line.split(': ', 1)
headers[key] = val
return headers
\ No newline at end of file
{
"http":{
"to_master":"0.0.0.0:10013",
"customer":"0.0.0.0:80",
"host":[
{
"domain":"pwd.ngrokhk.linkbus.xyz",
"auth":{
"username":"cn",
"password":"1234"
}
},
{
"domain":"nopwd.ngrokhk.linkbus.xyz"
}
]
},
"tcp":[
{
"master":"0.0.0.0:10013",
"customer":"0.0.0.0:10125",
"secretkey":"pwd001"
},
{
"master":"0.0.0.0:10014",
"customer":"0.0.0.0:10126",
"secretkey":"pwd002"
}
]
}
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from server_pool import ServerPool
class Dbv3Transfer(object):
@staticmethod
def thread_db(obj):
ServerPool.get_instance()
@staticmethod
def thread_db_stop():
ServerPool.get_instance().stop()
\ No newline at end of file
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import time
import select
import json
from master import *
import server_pool
from master2 import *
import socket
class EventLoop(object):
def __init__(self):
self._stopping = False
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.conn_rd =[sock]
self.portdict = {}
self.socketbridge = SocketBridge()
self.socketbridge.start_as_daemon()
@property
def getPortdict(self):
return self.portdict
def thread_stop(self):
self._stopping = True
exists = self.portdict.keys()
for cur in exists:
logging.info("loop dispose port {}".format(cur))
self.portdict[cur].dispose()
def run(self):
while not self._stopping:
# logging.debug('hh:{}'.format(r['hh']))
try:
f = open('config.json', 'r')
configs = json.load(f)
http = configs['http']
tcp = configs['tcp']
check = {}
if http:
host, port = http['customer'].split(":")
if self.portdict.get(port) is None:
self.portdict[port] = http_service(http)
self.portdict[port].updateconfig(http)
check[port] = http
for c in tcp:
host, port = c['master'].split(":")
check[port] = c
exists = self.portdict.keys()
removelist = []
for cur in exists:
if not check.get(cur):
logging.info("loop dispose port {}".format(cur))
self.portdict[cur].dispose()
removelist.append(cur)
for re in removelist:
self.portdict.pop(re)
for (port,c) in check.items():
if not self.portdict.get(port):
logging.info("run init {}".format(c['master']))
self.portdict[port] = Mastar_line(self.socketbridge)
self.portdict[port].main_master(c)
except Exception as e:
logging.info("fail config.json e:{}".format(e))
finally:
if f:
f.close()
# logging.debug('using event model: 123')
logging.info("bridgeAdd:{},bridgeRemove:{}".format(server_pool.ServerPool.bridgeAdd,server_pool.ServerPool.bridgeRemove))
time.sleep(10)
#!/usr/bin/env python3
# coding=utf-8
from common_func import *
import queue
import threading
#_listening_sockets = [] # for close at exit
# __author__ = "Aploium <i@z.codes>"
# __website__ = "https://github.com/aploium/shootback"
local = threading.local()
class Master:
def __init__(self, customer_listen_addr, communicate_addr=None,
slaver_pool=None,pkg=None, socketbridge=None,_listening_sockets=None):
"""
:param customer_listen_addr: equals to the -c/--customer param
:param communicate_addr: equals to the -m/--master param
"""
self.pkg = pkg
self._stopped = {"stop":False}
logging.info("Master__init__,{},{}".format(id(self),id(self.pkg)))
self._listening_sockets = []
self.thread_pool = {}
self.thread_pool["spare_slaver"] = {}
self.thread_pool["working_slaver"] = {}
self.working_pool = {}
self.socket_bridge = socketbridge
# a queue for customers who have connected to us,
# but not assigned a slaver yet
self.pending_customers = queue.Queue()
self.communicate_addr = communicate_addr
_fmt_communicate_addr = fmt_addr(self.communicate_addr)
if slaver_pool:
# 若使用外部slaver_pool, 就不再初始化listen
# 这是以后待添加的功能
self.external_slaver = True
self.thread_pool["listen_slaver"] = None
else:
# 自己listen来获取slaver
self.external_slaver = False
self.slaver_pool = collections.deque()
# prepare Thread obj, not activated yet
self.thread_pool["listen_slaver"] = threading.Thread(
target=self._listen_slaver,
name="listen_slaver-{}".format(_fmt_communicate_addr),
daemon=True,
)
logging.info("master init self.slaver_pool:{},{}".format(id(self.slaver_pool),self.slaver_pool))
# prepare Thread obj, not activated yet
self.customer_listen_addr = customer_listen_addr
self.thread_pool["listen_customer"] = threading.Thread(
target=self._listen_customer,
name="listen_customer-{}".format(_fmt_communicate_addr),
daemon=True,
)
# prepare Thread obj, not activated yet
self.thread_pool["heart_beat_daemon"] = threading.Thread(
target=self._heart_beat_daemon,
name="heart_beat_daemon-{}".format(_fmt_communicate_addr),
daemon=True,
)
# prepare assign_slaver_daemon
self.thread_pool["assign_slaver_daemon"] = threading.Thread(
target=self._assign_slaver_daemon,
name="assign_slaver_daemon-{}".format(_fmt_communicate_addr),
daemon=True,
)
def dispose(self):
self._stopped['stop'] = True
logging.info("master dispose {}".format(self._stopped))
while len(self.slaver_pool):
slaver = self.slaver_pool.pop()
try:
slaver['addr_slaver'].shutdown(socket.SHUT_WR)
except Exception as e:
pass
try:
slaver['conn_slaver'].shutdown(socket.SHUT_WR)
except Exception as e:
pass
try:
slaver['addr_slaver'].close()
except Exception as e:
pass
try:
slaver['conn_slaver'].close()
except Exception as e:
pass
self.working_pool = None
for sock in self._listening_sockets:
try:
sock.shutdown(socket.SHUT_RDWR)
except Exception as e:
pass
try:
sock.close()
except Exception as e:
pass
self.thread_pool["socket_bridge"] = None
self.pending_customers = None
self.pkg.clean_crc32()
def serve_forever(self):
if not self.external_slaver:
self.thread_pool["listen_slaver"].start()
self.thread_pool["heart_beat_daemon"].start()
self.thread_pool["listen_customer"].start()
self.thread_pool["assign_slaver_daemon"].start()
self.thread_pool["socket_bridge"] = self.socket_bridge.get_thread()
# while True:
# time.sleep(10)
def try_bind_port(self,sock, addr):
while not self._stopped['stop']:
try:
sock.bind(addr)
except Exception as e:
log.error((
"unable to bind {}, {}. If this port was used by the recently-closed shootback itself\n"
"then don't worry, it would be available in several seconds\n"
"we'll keep trying....").format(addr, e))
log.debug(traceback.format_exc())
time.sleep(3)
else:
break
def _transfer_complete(self, addr_customer):
"""a callback for SocketBridge, do some cleanup jobs"""
log.info("customer complete: {}".format(addr_customer))
del self.working_pool[addr_customer]
def _serve_customer(self, conn_customer, conn_slaver,tmp):
"""put customer and slaver sockets into SocketBridge, let them exchange data"""
self.socket_bridge.add_conn_pair(
conn_customer, conn_slaver,tmp,
functools.partial( # it's a callback
# 这个回调用来在传输完成后删除工作池中对应记录
self._transfer_complete,
conn_customer.getpeername()
)
)
def _send_heartbeat(self,conn_slaver):
"""send and verify heartbeat pkg"""
conn_slaver.send(self.pkg.pbuild_heart_beat().raw)
pkg, verify = self.pkg.recv(
conn_slaver, expect_ptype=CtrlPkg.PTYPE_HEART_BEAT) # type: CtrlPkg,bool
if not verify:
return False
if pkg.prgm_ver < 0x000B:
# shootback before 2.2.5-r10 use two-way heartbeat
# so there is no third pkg to send
pass
else:
# newer version use TCP-like 3-way heartbeat
# the older 2-way heartbeat can't only ensure the
# master --> slaver pathway is OK, but the reverse
# communicate may down. So we need a TCP-like 3-way
# heartbeat
conn_slaver.send(self.pkg.pbuild_heart_beat().raw)
return verify
def _heart_beat_daemon(self):
"""
每次取出slaver队列头部的一个, 测试心跳, 并把它放回尾部.
slaver若超过 SPARE_SLAVER_TTL 秒未收到心跳, 则会自动重连
所以睡眠间隔(delay)满足 delay * slaver总数 < TTL
使得一轮循环的时间小于TTL,
保证每个slaver都在过期前能被心跳保活
"""
default_delay = 5 + SPARE_SLAVER_TTL // 12
delay = default_delay
log.info("heart beat daemon start, delay: {}s".format(delay))
while not self._stopped['stop']:
time.sleep(delay)
# log.debug("heart_beat_daemon: hello! im weak")
# ---------------------- preparation -----------------------
slaver_count = len(self.slaver_pool)
# logging.info("_heart_beat_daemon test {},{}".format(id(self),self._stopped))
if not slaver_count:
log.warning("heart_beat_daemon: sorry, no slaver available, keep sleeping")
# restore default delay if there is no slaver
delay = default_delay
continue
else:
# notice this `slaver_count*2 + 1`
# slaver will expire and re-connect if didn't receive
# heartbeat pkg after SPARE_SLAVER_TTL seconds.
# set delay to be short enough to let every slaver receive heartbeat
# before expire
delay = 1 + SPARE_SLAVER_TTL // max(slaver_count * 2 + 1, 12)
# pop the oldest slaver
# heartbeat it and then put it to the end of queue
slaver = self.slaver_pool.popleft()
addr_slaver = slaver["addr_slaver"]
# ------------------ real heartbeat begin --------------------
start_time = time.perf_counter()
try:
hb_result = self._send_heartbeat(slaver["conn_slaver"])
except Exception as e:
log.warning("error during heartbeat to {}: {}".format(
fmt_addr(addr_slaver), e))
log.debug(traceback.format_exc())
hb_result = False
finally:
time_used = round((time.perf_counter() - start_time) * 1000.0, 2)
# ------------------ real heartbeat end ----------------------
if not hb_result:
log.warning("heart beat failed: {}, time: {}ms".format(
fmt_addr(addr_slaver), time_used))
try_close(slaver["conn_slaver"])
del slaver["conn_slaver"]
# if heartbeat failed, start the next heartbeat immediately
# because in most cases, all 5 slaver connection will
# fall and re-connect in the same time
delay = 0
else:
log.debug("heartbeat success: {}, time: {}ms".format(
fmt_addr(addr_slaver), time_used))
self.slaver_pool.append(slaver)
def _handshake(self,conn_slaver):
"""
handshake before real data transfer
it ensures:
1. client is alive and ready for transmission
2. client is shootback_slaver, not mistakenly connected other program
3. verify the SECRET_KEY
4. tell slaver it's time to connect target
handshake procedure:
1. master hello --> slaver
2. slaver verify master's hello
3. slaver hello --> master
4. (immediately after 3) slaver connect to target
4. master verify slaver
5. enter real data transfer
"""
conn_slaver.send(self.pkg.pbuild_hs_m2s().raw)
log.debug("CtrlPkg key{},{}".format(self.pkg.SECRET_KEY_CRC32,self.pkg.SECRET_KEY_REVERSED_CRC32))
buff = select_recv(conn_slaver, CtrlPkg.PACKAGE_SIZE, 2)
if buff is None:
return False
pkg, verify = self.pkg.decode_verify(buff, CtrlPkg.PTYPE_HS_S2M) # type: CtrlPkg,bool
log.debug("CtrlPkg from slaver {}: {}".format(conn_slaver.getpeername(), pkg))
return verify
def _get_an_active_slaver(self):
"""get and activate an slaver for data transfer"""
try_count = 10
while not self._stopped['stop']:
try:
logging.info("master _get_an_active_slaver self.slaver_pool:{},{}".format(id(self.slaver_pool), self.slaver_pool))
dict_slaver = self.slaver_pool.popleft()
except:
if try_count:
time.sleep(0.02)
try_count -= 1
if try_count % 10 == 0:
log.error("!!NO SLAVER AVAILABLE!! trying {}".format(try_count))
continue
return None
conn_slaver = dict_slaver["conn_slaver"]
try:
hs = self._handshake(conn_slaver)
except Exception as e:
log.warning("Handshake failed: {},key:{},{},{},{}".format(e,id(self),self.pkg.skey,self.pkg.SECRET_KEY_CRC32,self.pkg.SECRET_KEY_REVERSED_CRC32))
log.debug(traceback.format_exc())
hs = False
if hs:
return conn_slaver
else:
log.warning("slaver handshake failed: {}".format(dict_slaver["addr_slaver"]))
try_close(conn_slaver)
time.sleep(0.02)
def _assign_slaver_daemon(self):
"""assign slaver for customer"""
while not self._stopped['stop']:
# get a newly connected customer
conn_customer, addr_customer,tmp = self.pending_customers.get()
conn_slaver = self._get_an_active_slaver()
if conn_slaver is None:
log.warning("Closing customer[{}] because no available slaver found".format(
addr_customer))
try_close(conn_customer)
continue
else:
log.debug("Using slaver: {} for {}".format(conn_slaver.getpeername(), addr_customer))
self.working_pool[addr_customer] = {
"addr_customer": addr_customer,
"conn_customer": conn_customer,
"conn_slaver": conn_slaver,
"tmp":tmp
}
try:
self._serve_customer(conn_customer, conn_slaver,tmp)
except Exception as e:
try:
logging.info("_serve_customer fail e:{},{}".format(e,conn_customer))
try_close(conn_customer)
except:
pass
continue
def _listen_slaver(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.try_bind_port(sock, self.communicate_addr)
sock.listen(10)
self._listening_sockets.append(sock)
log.info("Listening for slavers: {}".format(
fmt_addr(self.communicate_addr)))
while not self._stopped['stop']:
logging.info("_listen_slaver stop:{},{}".format(id(self._stopped),self._stopped))
conn, addr = sock.accept()
self.slaver_pool.append({
"addr_slaver": addr,
"conn_slaver": conn,
})
log.info("{} Got slaver {} Total: {}".format(
self.communicate_addr,fmt_addr(addr), len(self.slaver_pool)
))
def _listen_customer(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.try_bind_port(sock, self.customer_listen_addr)
sock.listen(20)
self._listening_sockets.append(sock)
log.info("Listening for customers: {}".format(
fmt_addr(self.customer_listen_addr)))
while not self._stopped['stop']:
conn_customer, addr_customer = sock.accept()
log.info("Serving customer: {} Total customers: {}".format(
addr_customer, self.pending_customers.qsize() + 1
))
# just put it into the queue,
# let _assign_slaver_daemon() do the else
# don't block this loop
self.pending_customers.put((conn_customer, addr_customer,None))
def add_http_customer(self,conn,addr,tmp):
log.info("Serving customer: {} Total customers: {}".format(
conn, self.pending_customers.qsize() + 1
))
self.pending_customers.put((conn, addr,tmp))
class Mastar_line:
def __init__(self,socketbridge):
self.__website__ = "https://github.com/aploium/shootback"
self._listening_sockets = []
self.SPARE_SLAVER_TTL = 0
self.SECRET_KEY = ""
self.socketbridge = socketbridge
self.master = None
def dispose(self):
if self.master:
self.master.dispose()
self.SECRET_KEY = ""
self.socketbridge = None
def run_master(self,communicate_addr, customer_listen_addr,pkg):
log.info("shootback {} running as master".format(version_info()))
# log.info("author: {} site: {}".format(__author__, __website__))
log.info("slaver from: {} customer from: {}".format(
fmt_addr(communicate_addr), fmt_addr(customer_listen_addr)))
self.master = Master(customer_listen_addr, communicate_addr,self._listening_sockets,pkg,self.socketbridge)
self.master.serve_forever()
def main_master(self,args):
communicate_addr = split_host(args['master'])
customer_listen_addr = split_host(args['customer'])
self.SECRET_KEY = args['secretkey']
self.pkg = CtrlPkg()
self.pkg.recalc_crc32(self.SECRET_KEY)
logging.info("main_master,{},{},id:{},self.pkg:{}".format(args['master'],self.SECRET_KEY,id(self),id(self.pkg)))
local.SPARE_SLAVER_TTL = SPARE_SLAVER_TTL
# if args.quiet < 2:
# if args.verbose:
# level = logging.DEBUG
# elif args.quiet:
# level = logging.WARNING
# else:
# level = logging.INFO
configure_logging(logging.INFO)
self.run_master(communicate_addr, customer_listen_addr,self.pkg)
此差异已折叠。
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import threading
import time
import logging
import db_transfer
class MainThread(threading.Thread):
def __init__(self, obj):
threading.Thread.__init__(self)
self.obj = obj
def run(self):
self.obj.thread_db(self.obj)
def stop(self):
self.obj.thread_db_stop()
def main():
logging.basicConfig(
level=logging.DEBUG,
format='[%(levelname)s %(asctime)s] %(message)s',
)
thread = MainThread(db_transfer.Dbv3Transfer)
thread.start()
try:
while thread.is_alive():
time.sleep(10)
except (KeyboardInterrupt, IOError, OSError) as e:
import traceback
traceback.print_exc()
thread.stop()
if __name__ == '__main__':
main()
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import threading
import eventloop
class MainThread(threading.Thread):
def __init__(self, loop):
threading.Thread.__init__(self)
self.loop = loop
def run(self):
ServerPool._loop(self.loop)
def stop(self):
self.loop.thread_stop()
class ServerPool(object):
instance = None
bridgeAdd = 0
bridgeRemove = 0
def __init__(self):
self.loop = eventloop.EventLoop()
self.thread = MainThread(self.loop)
self.thread.start()
@property
def getloop(self):
return self.loop
@staticmethod
def get_instance():
if ServerPool.instance is None:
ServerPool.instance = ServerPool()
return ServerPool.instance
@staticmethod
def _loop(loop):
try:
loop.run()
except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e)
import traceback
traceback.print_exc()
exit(0)
except Exception as e:
logging.error(e)
import traceback
traceback.print_exc()
def stop(self):
self.thread.stop()
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function, unicode_literals, division, absolute_import
import sys
import time
import binascii
import struct
import collections
import logging
import socket
import select
import threading
import traceback
import functools
try:
# for pycharm type hinting
from typing import Union, Callable
except:
pass
# socket recv buffer, 16384 bytes
RECV_BUFFER_SIZE = 2 ** 14
# default secretkey, use -k/--secretkey to change
SECRET_KEY = "0"
# how long a SPARE slaver would keep
# once slaver received an heart-beat package from master,
# the TTL would be reset. And heart-beat delay is less than TTL,
# so, theoretically, spare slaver never timeout,
# except network failure
# notice: working slaver would NEVER timeout
SPARE_SLAVER_TTL = 300
# internal program version, appears in CtrlPkg
INTERNAL_VERSION = 0x000D
# version for human readable
__version__ = (2, 2, 8, INTERNAL_VERSION)
# just a logger
log = logging.getLogger(__name__)
def version_info():
"""get program version for human. eg: "2.1.0-r2" """
return "{}.{}.{}-r{}".format(*__version__)
def configure_logging(level):
logging.basicConfig(
level=level,
format='[%(levelname)s %(asctime)s] %(message)s',
)
def fmt_addr(socket):
"""(host, int(port)) --> "host:port" """
return "{}:{}".format(*socket)
def split_host(x):
""" "host:port" --> (host, int(port))"""
try:
host, port = x.split(":")
port = int(port)
except:
raise ValueError(
"wrong syntax, format host:port is "
"required, not {}".format(x))
else:
return host, port
def try_close(closable):
"""try close something
same as
try:
connection.close()
except:
pass
"""
try:
closable.close()
except:
pass
def select_recv(conn, buff_size, timeout=None):
"""add timeout for socket.recv()
:type conn: socket.SocketType
:type buff_size: int
:type timeout: float
:rtype: Union[bytes, None]
"""
rlist, _, _ = select.select([conn], [], [], timeout)
if not rlist:
# timeout
raise RuntimeError("recv timeout")
buff = conn.recv(buff_size)
if not buff:
raise RuntimeError("received zero bytes, socket was closed")
return buff
class SocketBridge:
"""
transfer data between sockets
"""
def __init__(self):
self.conn_rd = set() # record readable-sockets
self.map = {} # record sockets pairs
self.callbacks = {} # record callbacks
def add_conn_pair(self, conn1, conn2, callback=None):
"""
transfer anything between two sockets
:type conn1: socket.SocketType
:type conn2: socket.SocketType
:param callback: callback in connection finish
:type callback: Callable
"""
# mark as readable
self.conn_rd.add(conn1)
self.conn_rd.add(conn2)
# record sockets pairs
self.map[conn1] = conn2
self.map[conn2] = conn1
# record callback
if callback is not None:
self.callbacks[conn1] = callback
def start_as_daemon(self):
t = threading.Thread(target=self.start)
t.daemon = True
t.start()
log.info("SocketBridge daemon started")
return t
def start(self):
while True:
try:
self._start()
except:
log.error("FATAL ERROR! SocketBridge failed {}".format(
traceback.format_exc()
))
def _start(self):
# memoryview act as an recv buffer
# refer https://docs.python.org/3/library/stdtypes.html#memoryview
buff = memoryview(bytearray(RECV_BUFFER_SIZE))
while True:
if not self.conn_rd:
# sleep if there is no connections
time.sleep(0.06)
continue
# blocks until there is socket(s) ready for .recv
# notice: sockets which were closed by remote,
# are also regarded as read-ready by select()
r, w, e = select.select(self.conn_rd, [], [], 0.5)
for s in r: # iter every read-ready or closed sockets
try:
# here, we use .recv_into() instead of .recv()
# recv data directly into the pre-allocated buffer
# to avoid many unnecessary malloc()
# see https://docs.python.org/3/library/socket.html#socket.socket.recv_into
rec_len = s.recv_into(buff, RECV_BUFFER_SIZE)
except:
# unable to read, in most cases, it's due to socket close
self._rd_shutdown(s)
continue
if not rec_len:
# read zero size, closed or shutdowned socket
self._rd_shutdown(s)
continue
try:
# send data, we use `buff[:rec_len]` slice because
# only the front of buff is filled
self.map[s].send(buff[:rec_len])
except:
# unable to send, close connection
self._rd_shutdown(s)
continue
def _rd_shutdown(self, conn, once=False):
"""action when connection should be read-shutdown
:type conn: socket.SocketType
"""
if conn in self.conn_rd:
self.conn_rd.remove(conn)
try:
conn.shutdown(socket.SHUT_RD)
except:
pass
if not once and conn in self.map: # use the `once` param to avoid infinite loop
# if a socket is rd_shutdowned, then it's
# pair should be wr_shutdown.
self._wr_shutdown(self.map[conn], True)
if self.map.get(conn) not in self.conn_rd:
# if both two connection pair was rd-shutdowned,
# this pair sockets are regarded to be completed
# so we gonna close them
self._terminate(conn)
def _wr_shutdown(self, conn, once=False):
"""action when connection should be write-shutdown
:type conn: socket.SocketType
"""
try:
conn.shutdown(socket.SHUT_WR)
except:
pass
if not once and conn in self.map: # use the `once` param to avoid infinite loop
# pair should be rd_shutdown.
# if a socket is wr_shutdowned, then it's
self._rd_shutdown(self.map[conn], True)
def _terminate(self, conn):
"""terminate a sockets pair (two socket)
:type conn: socket.SocketType
:param conn: any one of the sockets pair
"""
try_close(conn) # close the first socket
# ------ close and clean the mapped socket, if exist ------
if conn in self.map:
_mapped_conn = self.map[conn]
try_close(_mapped_conn)
if _mapped_conn in self.map:
del self.map[_mapped_conn]
del self.map[conn] # clean the first socket
else:
_mapped_conn = None # just a fallback
# ------ callback --------
# because we are not sure which socket are assigned to callback,
# so we should try both
if conn in self.callbacks:
try:
self.callbacks[conn]()
except Exception as e:
log.error("traceback error: {}".format(e))
log.debug(traceback.format_exc())
del self.callbacks[conn]
elif _mapped_conn and _mapped_conn in self.callbacks:
try:
self.callbacks[_mapped_conn]()
except Exception as e:
log.error("traceback error: {}".format(e))
log.debug(traceback.format_exc())
del self.callbacks[_mapped_conn]
class CtrlPkg:
"""
Control Packages of shootback, not completed yet
current we have: handshake and heartbeat
NOTICE: If you are non-Chinese reader,
please contact me for the following Chinese comment's translation
http://github.com/aploium
控制包结构 总长64bytes CtrlPkg.FORMAT_PKG
使用 big-endian
体积 名称 数据类型 描述
1 pkg_ver unsigned char 包版本 *1
1 pkg_type signed char 包类型 *2
2 prgm_ver unsigned short 程序版本 *3
20 N/A N/A 预留
40 data bytes 数据区 *4
*1: 包版本. 包整体结构的定义版本, 目前只有 0x01
*2: 包类型. 除心跳外, 所有负数包代表由Slaver发出, 正数包由Master发出
-1: Slaver-->Master 的握手响应包 PTYPE_HS_S2M
0: 心跳包 PTYPE_HEART_BEAT
+1: Master-->Slaver 的握手包 PTYPE_HS_M2S
*3: 默认即为 INTERNAL_VERSION
*4: 数据区中的内容由各个类型的包自身定义
-------------- 数据区定义 ------------------
包类型: -1 (Slaver-->Master 的握手响应包)
体积 名称 数据类型 描述
4 crc32_s2m unsigned int 简单鉴权用 CRC32(Reversed(SECRET_KEY))
其余为空
*注意: -1握手包是把 SECRET_KEY 字符串翻转后取CRC32, +1握手包不预先反转
包类型: 0 (心跳)
数据区为空
包理性: +1 (Master-->Slaver 的握手包)
体积 名称 数据类型 描述
4 crc32_m2s unsigned int 简单鉴权用 CRC32(SECRET_KEY)
其余为空
"""
PACKAGE_SIZE = 2 ** 6 # 64 bytes
CTRL_PKG_TIMEOUT = 5 # CtrlPkg recv timeout, in second
# CRC32 for SECRET_KEY and Reversed(SECRET_KEY)
SECRET_KEY_CRC32 = binascii.crc32(SECRET_KEY.encode('utf-8')) & 0xffffffff
SECRET_KEY_REVERSED_CRC32 = binascii.crc32(SECRET_KEY[::-1].encode('utf-8')) & 0xffffffff
# Package Type
PTYPE_HS_S2M = -1 # handshake pkg, slaver to master
PTYPE_HEART_BEAT = 0 # heart beat pkg
PTYPE_HS_M2S = +1 # handshake pkg, Master to Slaver
TYPE_NAME_MAP = {
PTYPE_HS_S2M: "PTYPE_HS_S2M",
PTYPE_HEART_BEAT: "PTYPE_HEART_BEAT",
PTYPE_HS_M2S: "PTYPE_HS_M2S",
}
# formats
# see https://docs.python.org/3/library/struct.html#format-characters
# for format syntax
FORMAT_PKG = "!b b H 20x 40s"
FORMATS_DATA = {
PTYPE_HS_S2M: "!I 36x",
PTYPE_HEART_BEAT: "!40x",
PTYPE_HS_M2S: "!I 36x",
}
_cache_prebuilt_pkg = {} # cache
def __init__(self, pkg_ver=0x01, pkg_type=0,
prgm_ver=INTERNAL_VERSION, data=(),
raw=None,
):
"""do not call this directly, use `CtrlPkg.pbuild_*` instead"""
self.pkg_ver = pkg_ver
self.pkg_type = pkg_type
self.prgm_ver = prgm_ver
self.data = data
if raw:
self.raw = raw
else:
self._build_bytes()
@property
def type_name(self):
"""返回人类可读的包类型"""
return self.TYPE_NAME_MAP.get(self.pkg_type, "TypeUnknown")
def __str__(self):
return """pkg_ver: {} pkg_type:{} prgm_ver:{} data:{}""".format(
self.pkg_ver,
self.type_name,
self.prgm_ver,
self.data,
)
def __repr__(self):
return self.__str__()
def _build_bytes(self):
self.raw = struct.pack(
self.FORMAT_PKG,
self.pkg_ver,
self.pkg_type,
self.prgm_ver,
self.data_encode(self.pkg_type, self.data),
)
@classmethod
def _prebuilt_pkg(cls, pkg_type, fallback):
"""act as lru_cache"""
if pkg_type not in cls._cache_prebuilt_pkg:
pkg = fallback(force_rebuilt=True)
cls._cache_prebuilt_pkg[pkg_type] = pkg
return cls._cache_prebuilt_pkg[pkg_type]
@classmethod
def recalc_crc32(cls):
cls.SECRET_KEY_CRC32 = binascii.crc32(SECRET_KEY.encode('utf-8')) & 0xffffffff
cls.SECRET_KEY_REVERSED_CRC32 = binascii.crc32(SECRET_KEY[::-1].encode('utf-8')) & 0xffffffff
#logging.info("main key:{},{}".format(cls.SECRET_KEY_CRC32, cls.SECRET_KEY_REVERSED_CRC32))
@classmethod
def data_decode(cls, ptype, data_raw):
return struct.unpack(cls.FORMATS_DATA[ptype], data_raw)
@classmethod
def data_encode(cls, ptype, data):
return struct.pack(cls.FORMATS_DATA[ptype], *data)
def verify(self, pkg_type=None):
try:
if pkg_type is not None and self.pkg_type != pkg_type:
return False
elif self.pkg_type == self.PTYPE_HS_S2M:
# Slaver-->Master 的握手响应包
return self.data[0] == self.SECRET_KEY_REVERSED_CRC32
elif self.pkg_type == self.PTYPE_HEART_BEAT:
# 心跳
return True
elif self.pkg_type == self.PTYPE_HS_M2S:
# Master-->Slaver 的握手包
return self.data[0] == self.SECRET_KEY_CRC32
else:
return True
except:
return False
@classmethod
def decode_only(cls, raw):
"""
decode raw bytes to CtrlPkg instance, no verify
use .decode_verify() if you also want verify
:param raw: raw bytes content of package
:type raw: bytes
:rtype: CtrlPkg
"""
if not raw or len(raw) != cls.PACKAGE_SIZE:
raise ValueError("content size should be {}, but {}".format(
cls.PACKAGE_SIZE, len(raw)
))
pkg_ver, pkg_type, prgm_ver, data_raw = struct.unpack(cls.FORMAT_PKG, raw)
logging.info("CtrlPkg,decode_only,,,,pkg_ver:{}, pkg_type:{}, prgm_ver:{}".format(pkg_ver, pkg_type,prgm_ver))
data = cls.data_decode(pkg_type, data_raw)
return cls(
pkg_ver=pkg_ver, pkg_type=pkg_type,
prgm_ver=prgm_ver,
data=data,
raw=raw,
)
@classmethod
def decode_verify(cls, raw, pkg_type=None):
"""decode and verify a package
:param raw: raw bytes content of package
:type raw: bytes
:param pkg_type: assert this package's type,
if type not match, would be marked as wrong
:type pkg_type: int
:rtype: CtrlPkg, bool
:return: tuple(CtrlPkg, is_it_a_valid_package)
"""
try:
pkg = cls.decode_only(raw)
except:
return None, False
else:
return pkg, pkg.verify(pkg_type=pkg_type)
@classmethod
def pbuild_hs_m2s(cls, force_rebuilt=False):
"""pkg build: Handshake Master to Slaver"""
# because py27 do not have functools.lru_cache, so we must write our own
if force_rebuilt:
return cls(
pkg_type=cls.PTYPE_HS_M2S,
data=(cls.SECRET_KEY_CRC32,),
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HS_M2S, cls.pbuild_hs_m2s)
@classmethod
def pbuild_hs_s2m(cls, force_rebuilt=False):
"""pkg build: Handshake Slaver to Master"""
if force_rebuilt:
return cls(
pkg_type=cls.PTYPE_HS_S2M,
data=(cls.SECRET_KEY_REVERSED_CRC32,),
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HS_S2M, cls.pbuild_hs_s2m)
@classmethod
def pbuild_heart_beat(cls, force_rebuilt=False):
"""pkg build: Heart Beat Package"""
if force_rebuilt:
return cls(
pkg_type=cls.PTYPE_HEART_BEAT,
)
else:
return cls._prebuilt_pkg(cls.PTYPE_HEART_BEAT, cls.pbuild_heart_beat)
@classmethod
def recv(cls, sock, timeout=CTRL_PKG_TIMEOUT, expect_ptype=None):
"""just a shortcut function
:param sock: which socket to recv CtrlPkg from
:type sock: socket.SocketType
:rtype: CtrlPkg,bool
"""
logging.info("CtrlPkg,recv,sock:{},expect_ptype:{}".format(sock, expect_ptype))
buff = select_recv(sock, cls.PACKAGE_SIZE, timeout)
pkg, verify = CtrlPkg.decode_verify(buff, pkg_type=expect_ptype) # type: CtrlPkg,bool
return pkg, verify
#!/usr/bin/env python
# coding=utf-8
from __future__ import print_function, unicode_literals, division, absolute_import
from common_func import *
__author__ = "Aploium <i@z.codes>"
__website__ = "https://github.com/aploium/shootback"
class Slaver:
"""
slaver socket阶段
连接master->等待->心跳(重复)--->握手-->正式传输数据->退出
"""
def __init__(self, communicate_addr, target_addr, max_spare_count=5):
self.communicate_addr = communicate_addr
self.target_addr = target_addr
self.max_spare_count = max_spare_count
self.spare_slaver_pool = {}
self.working_pool = {}
self.socket_bridge = SocketBridge()
def _connect_master(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(self.communicate_addr)
self.spare_slaver_pool[sock.getsockname()] = {
"conn_slaver": sock,
}
return sock
def _connect_target(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(self.target_addr)
log.debug("connected to target[{}] at: {}".format(
sock.getpeername(),
sock.getsockname(),
))
return sock
def _response_heartbeat(self, conn_slaver, hb_from_master):
# assert isinstance(hb_from_master, CtrlPkg)
# assert isinstance(conn_slaver, socket.SocketType)
if hb_from_master.prgm_ver < 0x000B:
# shootback before 2.2.5-r10 use two-way heartbeat
# so just send a heart_beat pkg back
conn_slaver.send(CtrlPkg.pbuild_heart_beat().raw)
return True
else:
# newer version use TCP-like 3-way heartbeat
# the older 2-way heartbeat can't only ensure the
# master --> slaver pathway is OK, but the reverse
# communicate may down. So we need a TCP-like 3-way
# heartbeat
conn_slaver.send(CtrlPkg.pbuild_heart_beat().raw)
pkg, verify = CtrlPkg.recv(
conn_slaver,
expect_ptype=CtrlPkg.PTYPE_HEART_BEAT) # type: CtrlPkg,bool
if verify:
log.debug("heartbeat success {}".format(
fmt_addr(conn_slaver.getsockname())))
return True
else:
log.warning(
"received a wrong pkg[{}] during heartbeat, {}".format(
pkg, conn_slaver.getsockname()
))
return False
def _stage_ctrlpkg(self, conn_slaver):
"""
handling CtrlPkg until handshake
well, there is only one CtrlPkg: heartbeat, yet
it ensures:
1. network is ok, master is alive
2. master is shootback_master, not bad guy
3. verify the SECRET_KEY
4. tell slaver it's time to connect target
handshake procedure:
1. master hello --> slaver
2. slaver verify master's hello
3. slaver hello --> master
4. (immediately after 3) slaver connect to target
4. master verify slaver
5. enter real data transfer
"""
while True: # 可能会有一段时间的心跳包
# recv master --> slaver
# timeout is set to `SPARE_SLAVER_TTL`
# which means if not receive pkg from master in SPARE_SLAVER_TTL seconds,
# this connection would expire and re-connect
pkg, verify = CtrlPkg.recv(conn_slaver, SPARE_SLAVER_TTL) # type: CtrlPkg,bool
if not verify:
return False
log.debug("CtrlPkg from {}: {}".format(conn_slaver.getpeername(), pkg))
if pkg.pkg_type == CtrlPkg.PTYPE_HEART_BEAT:
# if the pkg is heartbeat pkg, enter handshake procedure
if not self._response_heartbeat(conn_slaver, pkg):
return False
elif pkg.pkg_type == CtrlPkg.PTYPE_HS_M2S:
# 拿到了开始传输的握手包, 进入工作阶段
break
# send slaver hello --> master
conn_slaver.send(CtrlPkg.pbuild_hs_s2m().raw)
return True
def _transfer_complete(self, addr_slaver):
"""a callback for SocketBridge, do some cleanup jobs"""
del self.working_pool[addr_slaver]
log.info("slaver complete: {}".format(addr_slaver))
def _slaver_working(self, conn_slaver):
addr_slaver = conn_slaver.getsockname()
addr_master = conn_slaver.getpeername()
# --------- handling CtrlPkg until handshake -------------
try:
hs = self._stage_ctrlpkg(conn_slaver)
except Exception as e:
log.warning("slaver{} waiting handshake failed {}".format(
fmt_addr(addr_slaver), e))
log.debug(traceback.print_exc())
hs = False
else:
if not hs:
log.warning("bad handshake or timeout between: {} and {}".format(
fmt_addr(addr_master), fmt_addr(addr_slaver)))
if not hs:
# handshake failed or timeout
del self.spare_slaver_pool[addr_slaver]
try_close(conn_slaver)
log.warning("a slaver[{}] abort due to handshake error or timeout".format(
fmt_addr(addr_slaver)))
return
else:
log.info("Success master handshake from: {} to {}".format(
fmt_addr(addr_master), fmt_addr(addr_slaver)))
# ----------- slaver activated! ------------
# move self from spare_slaver_pool to working_pool
self.working_pool[addr_slaver] = self.spare_slaver_pool.pop(addr_slaver)
# ----------- connecting to target ----------
try:
conn_target = self._connect_target()
except:
log.error("unable to connect target")
try_close(conn_slaver)
del self.working_pool[addr_slaver]
return
self.working_pool[addr_slaver]["conn_target"] = conn_target
# ----------- all preparation finished -----------
# pass two sockets to SocketBridge, and let it do the
# real data exchange task
self.socket_bridge.add_conn_pair(
conn_slaver, conn_target,
functools.partial(
# 这个回调用来在传输完成后删除工作池中对应记录
self._transfer_complete, addr_slaver
)
)
# this slaver thread exits here
return
def serve_forever(self):
self.socket_bridge.start_as_daemon() # hi, don't ignore me
# sleep between two retries if exception occurs
# eg: master down or network temporary failed
# err_delay would increase if err occurs repeatedly
# until `max_err_delay`
# would immediately decrease to 0 after a success connection
err_delay = 0
max_err_delay = 15
# spare_delay is sleep cycle if we are full of spare slaver
# would immediately decrease to 0 after a slaver lack
spare_delay = 0.08
default_spare_delay = 0.08
while True:
if len(self.spare_slaver_pool) >= self.max_spare_count:
time.sleep(spare_delay)
spare_delay = (spare_delay + default_spare_delay) / 2.0
continue
else:
spare_delay = 0.0
try:
conn_slaver = self._connect_master()
except Exception as e:
log.warning("unable to connect master {}".format(e))
log.debug(traceback.format_exc())
time.sleep(err_delay)
if err_delay < max_err_delay:
err_delay += 1
continue
try:
t = threading.Thread(target=self._slaver_working,
args=(conn_slaver,)
)
t.daemon = True
t.start()
log.info("connected to master[{}] at {} total: {}".format(
fmt_addr(conn_slaver.getpeername()),
fmt_addr(conn_slaver.getsockname()),
len(self.spare_slaver_pool),
))
except Exception as e:
log.error("unable create Thread: {}".format(e))
log.debug(traceback.format_exc())
time.sleep(err_delay)
if err_delay < max_err_delay:
err_delay += 1
continue
# set err_delay if everything is ok
err_delay = 0
def run_slaver(communicate_addr, target_addr, max_spare_count=5):
log.info("running as slaver, master addr: {} target: {}".format(
fmt_addr(communicate_addr), fmt_addr(target_addr)
))
Slaver(communicate_addr, target_addr, max_spare_count=max_spare_count).serve_forever()
def argparse_slaver():
import argparse
parser = argparse.ArgumentParser(
description="""shootback {ver}-slaver
A fast and reliable reverse TCP tunnel (this is slaver)
Help access local-network service from Internet.
https://github.com/aploium/shootback""".format(ver=version_info()),
epilog="""
Example1:
tunnel local ssh to public internet, assume master's ip is 1.2.3.4
Master(another public server): master.py -m 0.0.0.0:10000 -c 0.0.0.0:10022
Slaver(this pc): slaver.py -m 1.2.3.4:10000 -t 127.0.0.1:22
Customer(any internet user): ssh 1.2.3.4 -p 10022
the actual traffic is: customer <--> master(1.2.3.4) <--> slaver(this pc) <--> ssh(this pc)
Example2:
Tunneling for www.example.com
Master(this pc): master.py -m 127.0.0.1:10000 -c 127.0.0.1:10080
Slaver(this pc): slaver.py -m 127.0.0.1:10000 -t example.com:80
Customer(this pc): curl -v -H "host: example.com" 127.0.0.1:10080
Tips: ANY service using TCP is shootback-able. HTTP/FTP/Proxy/SSH/VNC/...
""",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("-m", "--master", required=True,
metavar="host:port",
help="master address, usually an Public-IP. eg: 2.3.3.3:5500")
parser.add_argument("-t", "--target", required=True,
metavar="host:port",
help="where the traffic from master should be tunneled to, usually not public. eg: 10.1.2.3:80")
parser.add_argument("-k", "--secretkey", default="shootback",
help="secretkey to identity master and slaver, should be set to the same value in both side")
parser.add_argument("-v", "--verbose", action="count", default=0,
help="verbose output")
parser.add_argument("-q", "--quiet", action="count", default=0,
help="quiet output, only display warning and errors, use two to disable output")
parser.add_argument("-V", "--version", action="version", version="shootback {}-slaver".format(version_info()))
parser.add_argument("--ttl", default=300, type=int, dest="SPARE_SLAVER_TTL",
help="standing-by slaver's TTL, default is 300. "
"this value is optimized for most cases")
parser.add_argument("--max-standby", default=5, type=int, dest="max_spare_count",
help="max standby slaver TCP connections count, default is 5. "
"which is enough for more than 800 concurrency. "
"while working connections are always unlimited")
return parser.parse_args()
def main_slaver():
global SPARE_SLAVER_TTL
global SECRET_KEY
global SECRET_KEY_CRC32
global SECRET_KEY_REVERSED_CRC32
args = argparse_slaver()
if args.verbose and args.quiet:
print("-v and -q should not appear together")
exit(1)
communicate_addr = split_host(args.master)
target_addr = split_host(args.target)
SECRET_KEY = args.secretkey
CtrlPkg.recalc_crc32()
CtrlPkg.SECRET_KEY_CRC32 = binascii.crc32(SECRET_KEY.encode('utf-8')) & 0xffffffff
CtrlPkg.SECRET_KEY_REVERSED_CRC32 = binascii.crc32(SECRET_KEY[::-1].encode('utf-8')) & 0xffffffff
SPARE_SLAVER_TTL = args.SPARE_SLAVER_TTL
max_spare_count = args.max_spare_count
if args.quiet < 2:
if args.verbose:
level = logging.DEBUG
elif args.quiet:
level = logging.WARNING
else:
level = logging.INFO
configure_logging(level)
log.info("shootback {} slaver running".format(version_info()))
log.info("author: {} site: {}".format(__author__, __website__))
log.info("Master: {}".format(fmt_addr(communicate_addr)))
log.info("Target: {}".format(fmt_addr(target_addr)))
# communicate_addr = ("localhost", 12345)
# target_addr = ("93.184.216.34", 80) # www.example.com
run_slaver(communicate_addr, target_addr, max_spare_count=max_spare_count)
if __name__ == '__main__':
main_slaver()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册