using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;

namespace Sodao.FastSocket.SocketBase
{
    /// <summary>
    /// Base host. Owns the collection of registered connections and a pool of
    /// <see cref="SocketAsyncEventArgs"/> shared by the connections it creates.
    /// </summary>
    public abstract class BaseHost : IHost
    {
        #region Members
        //process-wide counter; the first id handed out is 1001.
        private static long s_connectionID = 1000L;
        private readonly ConnectionCollection _listConnections = new ConnectionCollection();
        private readonly SocketAsyncEventArgsPool _saePool = null;
        #endregion

        #region Constructors
        /// <summary>
        /// new
        /// </summary>
        /// <param name="socketBufferSize">value applied to Socket.ReceiveBufferSize / Socket.SendBufferSize of new connections.</param>
        /// <param name="messageBufferSize">size of the byte buffer attached to every pooled SocketAsyncEventArgs.</param>
        /// <exception cref="ArgumentOutOfRangeException">socketBufferSize</exception>
        /// <exception cref="ArgumentOutOfRangeException">messageBufferSize</exception>
        protected BaseHost(int socketBufferSize, int messageBufferSize)
        {
            if (socketBufferSize < 1) throw new ArgumentOutOfRangeException("socketBufferSize");
            if (messageBufferSize < 1) throw new ArgumentOutOfRangeException("messageBufferSize");

            this.SocketBufferSize = socketBufferSize;
            this.MessageBufferSize = messageBufferSize;
            this._saePool = new SocketAsyncEventArgsPool(messageBufferSize);
        }
        #endregion

        #region IHost Members
        /// <summary>
        /// get socket buffer size
        /// </summary>
        public int SocketBufferSize { get; private set; }
        /// <summary>
        /// get message buffer size
        /// </summary>
        public int MessageBufferSize { get; private set; }

        /// <summary>
        /// Configures the socket (NoDelay, DontLinger, send/receive buffer sizes) and
        /// wraps it in a new connection carrying the next connection id.
        /// </summary>
        /// <param name="socket"></param>
        /// <returns>the new connection; it is not registered with the host by this method.</returns>
        /// <exception cref="ArgumentNullException">socket is null</exception>
        public virtual IConnection NewConnection(Socket socket)
        {
            if (socket == null) throw new ArgumentNullException("socket");

            socket.NoDelay = true;
            socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
            socket.ReceiveBufferSize = this.SocketBufferSize;
            socket.SendBufferSize = this.SocketBufferSize;
            return new DefaultConnection(this.NextConnectionID(), socket, this);
        }
        /// <summary>
        /// get connection by connectionID
        /// </summary>
        /// <param name="connectionID"></param>
        /// <returns>whatever the connection collection returns for this id.</returns>
        public IConnection GetConnectionByID(long connectionID)
        {
            return this._listConnections.Get(connectionID);
        }
        /// <summary>
        /// list all connections
        /// </summary>
        /// <returns></returns>
        public IConnection[] ListAllConnection()
        {
            return this._listConnections.ToArray();
        }
        /// <summary>
        /// get connection count.
        /// </summary>
        /// <returns></returns>
        public int CountConnection()
        {
            return this._listConnections.Count();
        }

        /// <summary>
        /// start (no-op in the base class)
        /// </summary>
        public virtual void Start()
        {
        }
        /// <summary>
        /// stop: asks the connection collection to disconnect every connection.
        /// </summary>
        public virtual void Stop()
        {
            this._listConnections.DisconnectAll();
        }
        #endregion

        #region Protected Methods
        /// <summary>
        /// generate the next connection id
        /// </summary>
        /// <returns>a value unique among all calls in this process.</returns>
        protected long NextConnectionID()
        {
            //return the incremented value directly: routing it through a shared
            //instance field lets two concurrent callers observe the same id.
            return Interlocked.Increment(ref BaseHost.s_connectionID);
        }
        /// <summary>
        /// register connection; only connections that are still active are added
        /// and reported through <see cref="OnConnected"/>.
        /// </summary>
        /// <param name="connection"></param>
        /// <exception cref="ArgumentNullException">connection is null</exception>
        protected void RegisterConnection(IConnection connection)
        {
            if (connection == null) throw new ArgumentNullException("connection");
            if (connection.Active)
            {
                this._listConnections.Add(connection);
                this.OnConnected(connection);
            }
        }

        /// <summary>
        /// OnConnected
        /// </summary>
        /// <param name="connection"></param>
        protected virtual void OnConnected(IConnection connection)
        {
            Log.Trace.Debug(string.Concat("socket connected, id:", connection.ConnectionID.ToString(),
                ", remot endPoint:", connection.RemoteEndPoint == null ? "unknow" : connection.RemoteEndPoint.ToString(),
                ", local endPoint:", connection.LocalEndPoint == null ? "unknow" : connection.LocalEndPoint.ToString()));
        }
        /// <summary>
        /// OnStartSending
        /// </summary>
        /// <param name="connection"></param>
        /// <param name="packet"></param>
        protected virtual void OnStartSending(IConnection connection, Packet packet)
        {
        }
        /// <summary>
        /// OnSendCallback
        /// </summary>
        /// <param name="connection"></param>
        /// <param name="packet"></param>
        /// <param name="isSuccess"></param>
        protected virtual void OnSendCallback(IConnection connection, Packet packet, bool isSuccess)
        {
        }
        /// <summary>
        /// OnMessageReceived
        /// </summary>
        /// <param name="connection"></param>
        /// <param name="e"></param>
        protected virtual void OnMessageReceived(IConnection connection, MessageReceivedEventArgs e)
        {
        }
        /// <summary>
        /// OnDisconnected: removes the connection from the host and logs the reason.
        /// </summary>
        /// <param name="connection"></param>
        /// <param name="ex">disconnect reason, may be null.</param>
        /// <exception cref="ArgumentNullException">connection is null</exception>
        protected virtual void OnDisconnected(IConnection connection, Exception ex)
        {
            if (connection == null) throw new ArgumentNullException("connection");

            this._listConnections.Remove(connection.ConnectionID);
            Log.Trace.Debug(string.Concat("socket disconnected, id:", connection.ConnectionID.ToString(),
                ", remot endPoint:", connection.RemoteEndPoint == null ? "unknow" : connection.RemoteEndPoint.ToString(),
                ", local endPoint:", connection.LocalEndPoint == null ? "unknow" : connection.LocalEndPoint.ToString(),
                ex == null ? string.Empty : string.Concat(", reason is: ", ex.ToString())));
        }
        /// <summary>
        /// OnError
        /// </summary>
        /// <param name="connection"></param>
        /// <param name="ex"></param>
        protected virtual void OnConnectionError(IConnection connection, Exception ex)
        {
            Log.Trace.Error(ex.Message, ex);
        }
        #endregion

        /// <summary>
        /// pool of <see cref="SocketAsyncEventArgs"/>, each with its own buffer of messageBufferSize bytes.
        /// </summary>
        private class SocketAsyncEventArgsPool
        {
            #region Private Members
            //upper bound of pooled instances; anything released beyond it is disposed.
            private const int MaxPooledCount = 10000;
            private readonly int _messageBufferSize;
            private readonly ConcurrentStack<SocketAsyncEventArgs> _pool = new ConcurrentStack<SocketAsyncEventArgs>();
            #endregion

            #region Constructors
            /// <summary>
            /// new
            /// </summary>
            /// <param name="messageBufferSize"></param>
            public SocketAsyncEventArgsPool(int messageBufferSize)
            {
                this._messageBufferSize = messageBufferSize;
            }
            #endregion

            #region Public Methods
            /// <summary>
            /// acquire
            /// </summary>
            /// <returns>an instance whose buffer window is [0, messageBufferSize).</returns>
            public SocketAsyncEventArgs Acquire()
            {
                SocketAsyncEventArgs e = null;
                if (this._pool.TryPop(out e))
                {
                    //a previous send may have left a shrunken or shifted window
                    //(SetBuffer(offset, count)); restore the full window so the
                    //instance is safe to reuse for receiving.
                    e.SetBuffer(0, this._messageBufferSize);
                    return e;
                }

                e = new SocketAsyncEventArgs();
                e.SetBuffer(new byte[this._messageBufferSize], 0, this._messageBufferSize);
                return e;
            }
            /// <summary>
            /// release: returns the instance to the pool, or disposes it when the pool is full.
            /// </summary>
            /// <param name="e"></param>
            public void Release(SocketAsyncEventArgs e)
            {
                if (this._pool.Count < MaxPooledCount)
                {
                    this._pool.Push(e);
                    return;
                }

                e.Dispose();
            }
            #endregion
        }

        #region DefaultConnection
        /// <summary>
        /// default socket connection
        /// </summary>
        private class DefaultConnection : IConnection
        {
            #region Private Members
            //1 while active; flipped to 0 exactly once by BeginDisconnect.
            private int _active = 1;
            private DateTime _latestActiveTime = Utils.Date.UtcNow;
            private readonly int _messageBufferSize;
            private readonly BaseHost _host = null;

            private readonly Socket _socket = null;

            private SocketAsyncEventArgs _saeSend = null;
            private Packet _currSendingPacket = null;
            private readonly PacketQueue _packetQueue = null;

            private SocketAsyncEventArgs _saeReceive = null;
            //holds bytes the message handler has not consumed yet (created on demand).
            private MemoryStream _tsStream = null;
            //guards BeginReceive so the receive loop is started only once.
            private int _isReceiving = 0;
            #endregion

            #region Constructors
            /// <summary>
            /// new
            /// </summary>
            /// <param name="connectionID"></param>
            /// <param name="socket"></param>
            /// <param name="host"></param>
            /// <exception cref="ArgumentNullException">socket is null</exception>
            /// <exception cref="ArgumentNullException">host is null</exception>
            public DefaultConnection(long connectionID, Socket socket, BaseHost host)
            {
                if (socket == null) throw new ArgumentNullException("socket");
                if (host == null) throw new ArgumentNullException("host");

                this.ConnectionID = connectionID;
                this._socket = socket;
                this._messageBufferSize = host.MessageBufferSize;
                this._host = host;

                try
                {
                    this.LocalEndPoint = (IPEndPoint)socket.LocalEndPoint;
                    this.RemoteEndPoint = (IPEndPoint)socket.RemoteEndPoint;
                }
                catch (Exception ex) { Log.Trace.Error("get socket endPoint error.", ex); }

                //init send
                this._saeSend = host._saePool.Acquire();
                this._saeSend.Completed += this.SendAsyncCompleted;
                this._packetQueue = new PacketQueue(this.SendPacketInternal);

                //init receive
                this._saeReceive = host._saePool.Acquire();
                this._saeReceive.Completed += this.ReceiveAsyncCompleted;
            }
            #endregion

            #region IConnection Members
            /// <summary>
            /// disconnected event
            /// </summary>
            public event DisconnectedHandler Disconnected;

            /// <summary>
            /// return the connection is active.
            /// </summary>
            public bool Active { get { return Thread.VolatileRead(ref this._active) == 1; } }
            /// <summary>
            /// get the connection latest active time.
            /// </summary>
            public DateTime LatestActiveTime { get { return this._latestActiveTime; } }
            /// <summary>
            /// get the connection id.
            /// </summary>
            public long ConnectionID { get; private set; }
            /// <summary>
            /// get the local endpoint (null when it could not be read from the socket)
            /// </summary>
            public IPEndPoint LocalEndPoint { get; private set; }
            /// <summary>
            /// get the remote endpoint (null when it could not be read from the socket)
            /// </summary>
            public IPEndPoint RemoteEndPoint { get; private set; }
            /// <summary>
            /// get or set user data
            /// </summary>
            public object UserData { get; set; }

            /// <summary>
            /// send a packet asynchronously; when the send queue is already closed
            /// the send callback is fired immediately with isSuccess = false.
            /// </summary>
            /// <param name="packet"></param>
            /// <exception cref="ArgumentNullException">packet is null</exception>
            public void BeginSend(Packet packet)
            {
                //reject null up front: a null packet would otherwise fail later
                //inside the send path and leave the queue stuck in SENDING.
                if (packet == null) throw new ArgumentNullException("packet");
                if (!this._packetQueue.TrySend(packet)) this.OnSendCallback(packet, false);
            }
            /// <summary>
            /// start receiving asynchronously; only the first call has an effect.
            /// </summary>
            public void BeginReceive()
            {
                if (Interlocked.CompareExchange(ref this._isReceiving, 1, 0) == 0) this.ReceiveInternal();
            }
            /// <summary>
            /// disconnect asynchronously; only the first call has an effect.
            /// </summary>
            /// <param name="ex">disconnect reason, may be null.</param>
            public void BeginDisconnect(Exception ex = null)
            {
                if (Interlocked.CompareExchange(ref this._active, 0, 1) == 1) this.DisconnectInternal(ex);
            }
            #endregion

            #region Private Methods

            #region Free
            /// <summary>
            /// free send queue: closes the queue, fails every packet still waiting in it,
            /// and releases the send args when no send was in flight.
            /// </summary>
            private void FreeSendQueue()
            {
                var result = this._packetQueue.Close();
                if (result.BeforeState == PacketQueue.CLOSED) return;

                if (result.Packets != null)
                    foreach (var p in result.Packets) this.OnSendCallback(p, false);

                //when a send was in flight (SENDING) the send path releases the args itself.
                if (result.BeforeState == PacketQueue.IDLE) this.FreeSend();
            }
            /// <summary>
            /// free for send.
            /// </summary>
            private void FreeSend()
            {
                this._currSendingPacket = null;
                this._saeSend.Completed -= this.SendAsyncCompleted;
                this._host._saePool.Release(this._saeSend);
                this._saeSend = null;
            }
            /// <summary>
            /// free for receive.
            /// </summary>
            private void FreeReceive()
            {
                this._saeReceive.Completed -= this.ReceiveAsyncCompleted;
                this._host._saePool.Release(this._saeReceive);
                this._saeReceive = null;
                if (this._tsStream != null)
                {
                    this._tsStream.Close();
                    this._tsStream = null;
                }
            }
            #endregion

            #region Fire Events
            /// <summary>
            /// fire StartSending
            /// </summary>
            /// <param name="packet"></param>
            private void OnStartSending(Packet packet)
            {
                this._host.OnStartSending(this, packet);
            }
            /// <summary>
            /// fire SendCallback; a failed packet gets its SentSize reset to 0.
            /// </summary>
            /// <param name="packet"></param>
            /// <param name="isSuccess"></param>
            private void OnSendCallback(Packet packet, bool isSuccess)
            {
                if (isSuccess) this._latestActiveTime = Utils.Date.UtcNow;
                else packet.SentSize = 0;

                this._host.OnSendCallback(this, packet, isSuccess);
            }
            /// <summary>
            /// fire MessageReceived
            /// </summary>
            /// <param name="e"></param>
            private void OnMessageReceived(MessageReceivedEventArgs e)
            {
                this._latestActiveTime = Utils.Date.UtcNow;
                try { this._host.OnMessageReceived(this, e); }
                catch (Exception ex)
                {
                    //keep the handler's failure away from the I/O thread, but do not hide it.
                    Log.Trace.Error(ex.Message, ex);
                }
            }
            /// <summary>
            /// fire Disconnected
            /// </summary>
            /// <param name="ex">disconnect reason, may be null.</param>
            private void OnDisconnected(Exception ex)
            {
                //copy to a local so an unsubscribe between the null check and the call cannot race.
                var handler = this.Disconnected;
                if (handler != null) handler(this, ex);
                this._host.OnDisconnected(this, ex);
            }
            /// <summary>
            /// fire Error
            /// </summary>
            /// <param name="ex"></param>
            private void OnError(Exception ex)
            {
                this._host.OnConnectionError(this, ex);
            }
            #endregion

            #region Send
            /// <summary>
            /// internal send packet: makes the packet the current one and starts sending it.
            /// Invoked by the packet queue.
            /// </summary>
            /// <param name="packet"></param>
            private void SendPacketInternal(Packet packet)
            {
                this._currSendingPacket = packet;
                this.OnStartSending(packet);
                this.SendPacketInternal(this._saeSend);
            }
            /// <summary>
            /// internal send packet: copies the next chunk of the current packet into
            /// the send buffer and starts an asynchronous send.
            /// </summary>
            /// <param name="e"></param>
            private void SendPacketInternal(SocketAsyncEventArgs e)
            {
                var packet = this._currSendingPacket;

                //transmit in chunks of at most messageBufferSize bytes
                var length = Math.Min(packet.Payload.Length - packet.SentSize, this._messageBufferSize);

                var completedAsync = true;
                try
                {
                    //copy data to send buffer
                    Buffer.BlockCopy(packet.Payload, packet.SentSize, e.Buffer, 0, length);
                    e.SetBuffer(0, length);
                    completedAsync = this._socket.SendAsync(e);
                }
                catch (Exception ex)
                {
                    this.BeginDisconnect(ex);
                    this.FreeSend();
                    this.OnSendCallback(packet, false);
                    this.OnError(ex);
                }

                //SendAsync returned false: the Completed event will not be raised,
                //so run the completion callback ourselves on a pool thread.
                if (!completedAsync)
                    ThreadPool.QueueUserWorkItem(_ => this.SendAsyncCompleted(this, e));
            }
            /// <summary>
            /// async send callback
            /// </summary>
            /// <param name="sender"></param>
            /// <param name="e"></param>
            private void SendAsyncCompleted(object sender, SocketAsyncEventArgs e)
            {
                var packet = this._currSendingPacket;

                //send error!
                if (e.SocketError != SocketError.Success)
                {
                    this.BeginDisconnect(new SocketException((int)e.SocketError));
                    this.FreeSend();
                    this.OnSendCallback(packet, false);
                    return;
                }

                packet.SentSize += e.BytesTransferred;

                //e.Count is the size of the window that was just submitted; the
                //unsent tail is [Offset + BytesTransferred, Offset + Count).
                if (e.BytesTransferred < e.Count)
                {
                    //continue to send until all bytes of this chunk are sent!
                    var completedAsync = true;
                    try
                    {
                        e.SetBuffer(e.Offset + e.BytesTransferred, e.Count - e.BytesTransferred);
                        completedAsync = this._socket.SendAsync(e);
                    }
                    catch (Exception ex)
                    {
                        this.BeginDisconnect(ex);
                        this.FreeSend();
                        this.OnSendCallback(packet, false);
                        this.OnError(ex);
                    }

                    if (!completedAsync)
                        ThreadPool.QueueUserWorkItem(_ => this.SendAsyncCompleted(sender, e));
                }
                else
                {
                    if (packet.IsSent())
                    {
                        this._currSendingPacket = null;
                        this.OnSendCallback(packet, true);

                        //try send next packet; false means the queue was closed meanwhile.
                        if (!this._packetQueue.TrySendNext()) this.FreeSend();
                    }
                    else this.SendPacketInternal(e);//continue send this packet
                }
            }
            #endregion

            #region Receive
            /// <summary>
            /// receive
            /// </summary>
            private void ReceiveInternal()
            {
                //capture the args once so the queued callback below does not re-read a mutable field.
                var e = this._saeReceive;
                bool completed = true;
                try { completed = this._socket.ReceiveAsync(e); }
                catch (Exception ex)
                {
                    this.BeginDisconnect(ex);
                    this.FreeReceive();
                    this.OnError(ex);
                }

                //ReceiveAsync returned false: the Completed event will not be raised.
                if (!completed)
                    ThreadPool.QueueUserWorkItem(_ => this.ReceiveAsyncCompleted(this, e));
            }
            /// <summary>
            /// async receive callback
            /// </summary>
            /// <param name="sender"></param>
            /// <param name="e"></param>
            private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e)
            {
                if (e.SocketError != SocketError.Success)
                {
                    this.BeginDisconnect(new SocketException((int)e.SocketError));
                    this.FreeReceive();
                    return;
                }
                //no bytes transferred: treated as a disconnect.
                if (e.BytesTransferred < 1)
                {
                    this.BeginDisconnect();
                    this.FreeReceive();
                    return;
                }

                ArraySegment<byte> buffer;
                var ts = this._tsStream;
                if (ts == null || ts.Length == 0)
                    buffer = new ArraySegment<byte>(e.Buffer, e.Offset, e.BytesTransferred);
                else
                {
                    //append to the bytes left over from earlier receives.
                    ts.Write(e.Buffer, e.Offset, e.BytesTransferred);
                    buffer = new ArraySegment<byte>(ts.GetBuffer(), 0, (int)ts.Length);
                }

                this.OnMessageReceived(new MessageReceivedEventArgs(buffer, this.MessageProcessCallback));
            }
            /// <summary>
            /// message process callback
            /// </summary>
            /// <param name="payload">the bytes that were offered to the handler.</param>
            /// <param name="readlength">number of bytes of payload the handler consumed.</param>
            /// <exception cref="ArgumentOutOfRangeException">readlength less than 0 or greater than payload.Count.</exception>
            private void MessageProcessCallback(ArraySegment<byte> payload, int readlength)
            {
                if (readlength < 0 || readlength > payload.Count)
                    throw new ArgumentOutOfRangeException("readlength", "readlength less than 0 or greater than payload.Count.");

                var ts = this._tsStream;
                //nothing consumed: keep the whole payload and wait for more data.
                if (readlength == 0)
                {
                    if (ts == null) this._tsStream = ts = new MemoryStream(this._messageBufferSize);
                    else ts.SetLength(0);

                    ts.Write(payload.Array, payload.Offset, payload.Count);
                    this.ReceiveInternal();
                    return;
                }

                //everything consumed: drop any leftover and receive again.
                if (readlength == payload.Count)
                {
                    if (ts != null) ts.SetLength(0);
                    this.ReceiveInternal();
                    return;
                }

                //several messages arrived coalesced: offer the unread remainder to the handler again.
                this.OnMessageReceived(new MessageReceivedEventArgs(
                    new ArraySegment<byte>(payload.Array, payload.Offset + readlength, payload.Count - readlength),
                    this.MessageProcessCallback));
            }
            #endregion

            #region Disconnect
            /// <summary>
            /// disconnect
            /// </summary>
            /// <param name="reason">disconnect reason, may be null; carried in UserToken.</param>
            private void DisconnectInternal(Exception reason)
            {
                var e = new SocketAsyncEventArgs();
                e.Completed += this.DisconnectAsyncCompleted;
                e.UserToken = reason;

                var completedAsync = true;
                try
                {
                    this._socket.Shutdown(SocketShutdown.Both);
                    completedAsync = this._socket.DisconnectAsync(e);
                }
                catch (Exception ex)
                {
                    Log.Trace.Error(ex.Message, ex);
                    //still run the completion path so the connection gets cleaned up.
                    ThreadPool.QueueUserWorkItem(_ => this.DisconnectAsyncCompleted(this, e));
                    return;
                }

                if (!completedAsync)
                    ThreadPool.QueueUserWorkItem(_ => this.DisconnectAsyncCompleted(this, e));
            }
            /// <summary>
            /// async disconnect callback
            /// </summary>
            /// <param name="sender"></param>
            /// <param name="e"></param>
            private void DisconnectAsyncCompleted(object sender, SocketAsyncEventArgs e)
            {
                //dispose socket
                try { this._socket.Close(); }
                catch (Exception ex) { Log.Trace.Error(ex.Message, ex); }

                //dispose socketAsyncEventArgs
                var reason = e.UserToken as Exception;
                e.Completed -= this.DisconnectAsyncCompleted;
                e.Dispose();

                //fire disconnected
                this.OnDisconnected(reason);
                //close send queue
                this.FreeSendQueue();
            }
            #endregion

            #endregion

            #region PacketQueue
            /// <summary>
            /// packet queue: serializes sends through a small state machine driven by
            /// Interlocked.CompareExchange on the state field.
            /// </summary>
            private class PacketQueue
            {
                #region Private Members
                public const int IDLE = 1;     //idle, nothing is being sent
                public const int SENDING = 2;  //a packet is being sent
                public const int ENQUEUE = 3;  //a thread is adding to the queue
                public const int DEQUEUE = 4;  //a thread is taking from the queue
                public const int CLOSED = 5;   //closed

                private int _state = IDLE;     //current state
                private Queue<Packet> _queue = new Queue<Packet>();
                //kept for the lifetime of the queue: a sender that already won the
                //state transition may still be about to invoke it when Close runs.
                private readonly Action<Packet> _sendAction = null;
                #endregion

                #region Constructors
                /// <summary>
                /// new
                /// </summary>
                /// <param name="sendAction">invoked with each packet when its turn to be sent comes.</param>
                /// <exception cref="ArgumentNullException">sendAction is null.</exception>
                public PacketQueue(Action<Packet> sendAction)
                {
                    if (sendAction == null) throw new ArgumentNullException("sendAction");
                    this._sendAction = sendAction;
                }
                #endregion

                #region Public Methods
                /// <summary>
                /// try send packet: sends immediately when idle, enqueues when a send is in progress.
                /// </summary>
                /// <param name="packet"></param>
                /// <returns>if CLOSED return false.</returns>
                public bool TrySend(Packet packet)
                {
                    var spin = true;
                    while (spin)
                    {
                        switch (this._state)
                        {
                            case IDLE:
                                if (Interlocked.CompareExchange(ref this._state, SENDING, IDLE) == IDLE)
                                    spin = false;
                                break;
                            case SENDING:
                                if (Interlocked.CompareExchange(ref this._state, ENQUEUE, SENDING) == SENDING)
                                {
                                    this._queue.Enqueue(packet);
                                    this._state = SENDING;
                                    return true;
                                }
                                break;
                            case ENQUEUE:
                            case DEQUEUE:
                                //another thread is touching the queue; let it finish.
                                Thread.Yield();
                                break;
                            case CLOSED:
                                return false;
                        }
                    }
                    this._sendAction(packet);
                    return true;
                }
                /// <summary>
                /// close
                /// </summary>
                /// <returns>the state before closing and the packets that were still queued
                /// (Packets is null when the queue was already closed).</returns>
                public CloseResult Close()
                {
                    var spin = true;
                    int beforeState = -1;
                    while (spin)
                    {
                        switch (this._state)
                        {
                            case IDLE:
                                if (Interlocked.CompareExchange(ref this._state, CLOSED, IDLE) == IDLE)
                                {
                                    spin = false;
                                    beforeState = IDLE;
                                }
                                break;
                            case SENDING:
                                if (Interlocked.CompareExchange(ref this._state, CLOSED, SENDING) == SENDING)
                                {
                                    spin = false;
                                    beforeState = SENDING;
                                }
                                break;
                            case ENQUEUE:
                            case DEQUEUE:
                                Thread.Yield();
                                break;
                            case CLOSED:
                                return new CloseResult(CLOSED, null);
                        }
                    }

                    var arrPackets = this._queue.ToArray();
                    this._queue.Clear();
                    this._queue = null;
                    return new CloseResult(beforeState, arrPackets);
                }
                /// <summary>
                /// try send next packet; switches to IDLE when the queue is empty.
                /// </summary>
                /// <returns>if CLOSED return false.</returns>
                public bool TrySendNext()
                {
                    var spin = true;
                    Packet packet = null;
                    while (spin)
                    {
                        switch (this._state)
                        {
                            case SENDING:
                                if (Interlocked.CompareExchange(ref this._state, DEQUEUE, SENDING) == SENDING)
                                {
                                    if (this._queue.Count == 0)
                                    {
                                        this._state = IDLE;
                                        return true;
                                    }

                                    packet = this._queue.Dequeue();
                                    this._state = SENDING;
                                    spin = false;
                                }
                                break;
                            case ENQUEUE:
                                Thread.Yield();
                                break;
                            case CLOSED:
                                return false;
                        }
                    }
                    this._sendAction(packet);
                    return true;
                }
                #endregion

                #region CloseResult
                /// <summary>
                /// close queue result
                /// </summary>
                public sealed class CloseResult
                {
                    /// <summary>
                    /// before close state
                    /// </summary>
                    public readonly int BeforeState;
                    /// <summary>
                    /// wait sending packet array
                    /// </summary>
                    public readonly Packet[] Packets;

                    /// <summary>
                    /// new
                    /// </summary>
                    /// <param name="beforeState"></param>
                    /// <param name="packets"></param>
                    public CloseResult(int beforeState, Packet[] packets)
                    {
                        this.BeforeState = beforeState;
                        this.Packets = packets;
                    }
                }
                #endregion
            }
            #endregion
        }
        #endregion
    }
}