using System;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
namespace Base
{
/// <summary>
/// Wraps a <see cref="System.Net.Sockets.Socket"/> and pushes its asynchronous completion
/// callbacks to a <see cref="TPoller"/> so that they are processed on the main thread.
/// </summary>
public class TSocket: IDisposable
{
    // Backlog passed to Listen by the listening constructor.
    private const int DefaultBacklog = 100;

    private readonly TPoller poller;

    // Set to null by Dispose; doubles as the "disposed" flag.
    private Socket socket;

    // One reusable args object per direction: innArgs is used by accept/receive,
    // outArgs is used by connect/send.
    private readonly SocketAsyncEventArgs innArgs = new SocketAsyncEventArgs();
    private readonly SocketAsyncEventArgs outArgs = new SocketAsyncEventArgs();

    /// <summary>Invoked with the result code when a connect operation completes.</summary>
    public Action<SocketError> OnConn;

    /// <summary>Invoked with (bytes transferred, result code) when a receive operation completes.</summary>
    public Action<int, SocketError> OnRecv;

    /// <summary>Invoked with (bytes transferred, result code) when a send operation completes.</summary>
    public Action<int, SocketError> OnSend;

    /// <summary>
    /// Invoked with the result code when a disconnect operation completes.
    /// NOTE(review): no method in this class starts a disconnect; confirm whether this is still used.
    /// </summary>
    public Action<SocketError> OnDisconnect;

    /// <summary>
    /// Creates an IPv4 TCP stream socket whose completions are forwarded to <paramref name="poller"/>.
    /// </summary>
    /// <param name="poller">Receives the completion callbacks via its <c>Add</c> method.</param>
    public TSocket(TPoller poller)
    {
        this.poller = poller;
        this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        this.innArgs.Completed += this.OnComplete;
        this.outArgs.Completed += this.OnComplete;
    }

    /// <summary>
    /// Creates a listening socket bound to <paramref name="host"/>:<paramref name="port"/>.
    /// </summary>
    /// <param name="poller">Receives the completion callbacks via its <c>Add</c> method.</param>
    /// <param name="host">IP address literal to bind to (parsed with <see cref="IPAddress.Parse(string)"/>).</param>
    /// <param name="port">Port to bind to.</param>
    public TSocket(TPoller poller, string host, int port): this(poller)
    {
        try
        {
            this.Bind(host, port);
            this.Listen(DefaultBacklog);
        }
        catch
        {
            // The caller never receives the instance when the constructor throws,
            // so release the socket here instead of leaking it.
            this.Dispose();
            throw;
        }
    }

    /// <summary>
    /// The "host:port" string recorded by the last <see cref="ConnectAsync"/> call;
    /// null if <see cref="ConnectAsync"/> has not been called.
    /// </summary>
    public string RemoteAddress { get; private set; }

    /// <summary>The underlying socket, or null once this instance has been disposed.</summary>
    public Socket Socket
    {
        get
        {
            return this.socket;
        }
    }

    /// <summary>
    /// Closes the socket and releases both event-args objects. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (this.socket == null)
        {
            return;
        }

        this.socket.Close();
        this.socket = null;

        // SocketAsyncEventArgs is IDisposable; the original never released these.
        this.innArgs.Dispose();
        this.outArgs.Dispose();
    }

    // Binds the socket to the given IP address literal and port.
    private void Bind(string host, int port)
    {
        this.socket.Bind(new IPEndPoint(IPAddress.Parse(host), port));
    }

    // Puts the socket into the listening state with the given backlog.
    private void Listen(int backlog)
    {
        this.socket.Listen(backlog);
    }

    // Fails fast with a descriptive exception instead of a NullReferenceException
    // when an operation is attempted after Dispose.
    private void ThrowIfDisposed()
    {
        if (this.socket == null)
        {
            throw new ObjectDisposedException(nameof(TSocket));
        }
    }

    /// <summary>
    /// Starts accepting one incoming connection into <paramref name="accpetSocket"/>.
    /// </summary>
    /// <param name="accpetSocket">The wrapper whose socket is handed to the accept operation.</param>
    /// <returns>
    /// A task that yields <c>true</c> when the accept succeeds and faults with an
    /// <see cref="Exception"/> carrying the socket error otherwise.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="accpetSocket"/> is null.</exception>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public Task<bool> AcceptAsync(TSocket accpetSocket)
    {
        if (accpetSocket == null)
        {
            throw new ArgumentNullException(nameof(accpetSocket));
        }
        this.ThrowIfDisposed();

        var tcs = new TaskCompletionSource<bool>();
        // The completion source travels with the args so OnAcceptComplete can resolve it.
        this.innArgs.UserToken = tcs;
        this.innArgs.AcceptSocket = accpetSocket.socket;

        // A false return means the operation finished synchronously and the
        // Completed event will not be raised, so complete it inline.
        if (!this.socket.AcceptAsync(this.innArgs))
        {
            OnAcceptComplete(this.innArgs);
        }
        return tcs.Task;
    }

    // Resolves the TaskCompletionSource stored in UserToken from the accept result.
    private static void OnAcceptComplete(SocketAsyncEventArgs e)
    {
        var tcs = (TaskCompletionSource<bool>)e.UserToken;
        e.UserToken = null;
        if (e.SocketError != SocketError.Success)
        {
            tcs.SetException(new Exception($"socket error: {e.SocketError}"));
            return;
        }
        tcs.SetResult(true);
    }

    // Completed handler shared by both args objects: selects the matching
    // completion routine and queues it on the poller instead of running it here.
    private void OnComplete(object sender, SocketAsyncEventArgs e)
    {
        Action action;
        switch (e.LastOperation)
        {
            case SocketAsyncOperation.Connect:
                action = () => OnConnectComplete(e);
                break;
            case SocketAsyncOperation.Receive:
                action = () => OnRecvComplete(e);
                break;
            case SocketAsyncOperation.Send:
                action = () => OnSendComplete(e);
                break;
            case SocketAsyncOperation.Disconnect:
                action = () => OnDisconnectComplete(e);
                break;
            case SocketAsyncOperation.Accept:
                action = () => OnAcceptComplete(e);
                break;
            default:
                throw new Exception($"socket error: {e.LastOperation}");
        }

        // Hand the callback to the poller so it is processed on the main thread.
        this.poller.Add(action);
    }

    /// <summary>
    /// Starts connecting to <paramref name="host"/>:<paramref name="port"/> and records
    /// the target in <see cref="RemoteAddress"/>.
    /// </summary>
    /// <param name="host">IP address literal of the remote endpoint.</param>
    /// <param name="port">Port of the remote endpoint.</param>
    /// <returns>
    /// <c>true</c> if the operation is pending and <see cref="OnConn"/> will be raised later;
    /// <c>false</c> if it completed synchronously, in which case <see cref="OnConn"/> has
    /// already been invoked before this method returns.
    /// </returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public bool ConnectAsync(string host, int port)
    {
        this.ThrowIfDisposed();

        this.RemoteAddress = $"{host}:{port}";
        this.outArgs.RemoteEndPoint = new IPEndPoint(IPAddress.Parse(host), port);
        if (this.socket.ConnectAsync(this.outArgs))
        {
            return true;
        }
        OnConnectComplete(this.outArgs);
        return false;
    }

    // Reports the connect result to OnConn, if a handler is attached.
    private void OnConnectComplete(SocketAsyncEventArgs e)
    {
        this.OnConn?.Invoke(e.SocketError);
    }

    /// <summary>
    /// Starts receiving up to <paramref name="count"/> bytes into <paramref name="buffer"/>
    /// beginning at <paramref name="offset"/>.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the operation is pending and <see cref="OnRecv"/> will be raised later;
    /// <c>false</c> if it completed synchronously, in which case <see cref="OnRecv"/> has
    /// already been invoked before this method returns.
    /// </returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="Exception">The buffer range was rejected; the cause is the inner exception.</exception>
    public bool RecvAsync(byte[] buffer, int offset, int count)
    {
        this.ThrowIfDisposed();

        SetBuffer(this.innArgs, buffer, offset, count);
        if (this.socket.ReceiveAsync(this.innArgs))
        {
            return true;
        }
        OnRecvComplete(this.innArgs);
        return false;
    }

    // Reports the receive result to OnRecv, if a handler is attached.
    private void OnRecvComplete(SocketAsyncEventArgs e)
    {
        this.OnRecv?.Invoke(e.BytesTransferred, e.SocketError);
    }

    /// <summary>
    /// Starts sending <paramref name="count"/> bytes from <paramref name="buffer"/>
    /// beginning at <paramref name="offset"/>.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the operation is pending and <see cref="OnSend"/> will be raised later;
    /// <c>false</c> if it completed synchronously, in which case <see cref="OnSend"/> has
    /// already been invoked before this method returns.
    /// </returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    /// <exception cref="Exception">The buffer range was rejected; the cause is the inner exception.</exception>
    public bool SendAsync(byte[] buffer, int offset, int count)
    {
        this.ThrowIfDisposed();

        SetBuffer(this.outArgs, buffer, offset, count);
        if (this.socket.SendAsync(this.outArgs))
        {
            return true;
        }
        OnSendComplete(this.outArgs);
        return false;
    }

    // Reports the send result to OnSend, if a handler is attached.
    private void OnSendComplete(SocketAsyncEventArgs e)
    {
        this.OnSend?.Invoke(e.BytesTransferred, e.SocketError);
    }

    // Reports the disconnect result to OnDisconnect, if a handler is attached.
    private void OnDisconnectComplete(SocketAsyncEventArgs e)
    {
        this.OnDisconnect?.Invoke(e.SocketError);
    }

    // Assigns the I/O buffer on the args object, rewrapping any failure with the
    // offending buffer length/offset/count so the caller's mistake is visible in the log.
    private static void SetBuffer(SocketAsyncEventArgs args, byte[] buffer, int offset, int count)
    {
        try
        {
            args.SetBuffer(buffer, offset, count);
        }
        catch (Exception e)
        {
            throw new Exception($"socket set buffer error: {buffer?.Length}, {offset}, {count}", e);
        }
    }
}
}