trinitycore 魔兽服务器源码分析(二) 网络

时间:2021-04-21 09:19:40

trinitycore 魔兽服务器源码分析(二) 网络

书接上文 继续分析Socket.h SocketMgr.h

template<class T>
class Socket : public std::enable_shared_from_this<T>

按照智能指针的使用规则，类的成员函数如果需要得到指向自身的 shared_ptr，就必须继承 std::enable_shared_from_this<> 并调用 shared_from_this()。如果直接用 this 再构造一个 shared_ptr，会生成第二个独立的控制块，最终导致对象被重复释放。

class Socket封装了asio中的socket类 获取远端ip 端口等功能, 并且额外提供异步读写的功能

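类中的两个原子变量 _closed 和 _closing 记录 socket 的关闭状态：
- _closed 表示已经关闭，CloseSocket() 只会真正执行一次；
- _closing 表示已通过 DelayedCloseSocket() 请求延迟关闭，即等写队列发送完毕后再关闭。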

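bool Update() 只在非 IOCP 模式下处理写队列：
- 如果有异步写操作尚未完成，或者写队列为空且没有请求延迟关闭，就直接返回；
- 否则循环调用 HandleQueue() 同步发送队列中的数据。

IOCP 模式下写操作由完成回调驱动，Update() 不处理写队列。socket 已关闭时 Update() 返回 false。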

void AsyncRead()  void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))

两者都会发起一次异步读取 socket。读取完成后，前者调用默认处理函数 ReadHandlerInternal()，后者调用调用者传入的 T 类成员函数 callback。

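异步回调执行时对象必须仍然存活，因此 bind 时传入的是 shared_from_this() 返回的 shared_ptr，而不是裸指针 this，这样回调对象会持有一份引用计数。AsyncReadWithCallback 绑定的是 T 类的成员函数，需要 shared_ptr<T>，所以开头继承的是 std::enable_shared_from_this<T>，其中 T 是派生出来的具体 socket 类。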

这种写法初看有些奇怪，因为 std::enable_shared_from_this<> 的一般用法是以继承它的类自身作为模板参数，例如：

// Minimal illustration of the usual enable_shared_from_this pattern: the class
// passes itself as the template argument, so shared_from_this() yields a
// std::shared_ptr<self> that shares ownership with the existing owner.
class self : public std::enable_shared_from_this<self> {
public:
    void test() {
        // only for test: the bound callable holds a shared_ptr<self>, so *this
        // stays alive for as long as the callable does. *this must already be
        // owned by a std::shared_ptr (e.g. created with std::make_shared),
        // otherwise shared_from_this() fails.
        auto task = std::bind(&self::test, shared_from_this());
        (void)task;
    }
};

异步写的流程类似，由 bool AsyncProcessQueue() 函数发起。

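IOCP 模式下，使用 asio 的 async_write_some 异步发送写队列队首的数据，完成后回调 WriteHandler()。非 IOCP 模式下，则用 async_write_some(null_buffers()) 等待 socket 变为可写，再由回调 WriteHandlerWrapper() 调用 HandleQueue() 同步发送。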

不过需要结合MessageBuffer 一起跟进流程

类代码如下

 /*
* Copyright (C) 2008-2017 TrinityCore <http://www.trinitycore.org/>
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef __SOCKET_H__
#define __SOCKET_H__

#include "MessageBuffer.h"
#include "Log.h"
#include <atomic>
#include <queue>
#include <memory>
#include <functional>
#include <type_traits>
#include <boost/asio/ip/tcp.hpp>

using boost::asio::ip::tcp;

// Initial size of every socket's read buffer (applied in the Socket constructor).
#define READ_BLOCK_SIZE 4096

// With IOCP support in Boost.Asio, queued writes are sent by chained
// async_write_some calls; otherwise they are flushed from Update().
#ifdef BOOST_ASIO_HAS_IOCP
#define TC_SOCKET_USE_IOCP
#endif

/// Base class for one TCP connection wrapped around a boost::asio socket.
///
/// T is the concrete socket type deriving from Socket<T>. Completion handlers
/// bind this->shared_from_this() (a std::shared_ptr<T>), which keeps the
/// connection alive while an asynchronous operation is pending and allows
/// AsyncReadWithCallback() to bind member functions of T.
template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
    /// Takes ownership of a connected socket and caches its remote endpoint.
    /// remote_endpoint() (the non-error_code overload) throws
    /// boost::system::system_error if the endpoint cannot be retrieved.
    explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
        _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
    {
        _readBuffer.Resize(READ_BLOCK_SIZE);
    }

    virtual ~Socket()
    {
        _closed = true;
        boost::system::error_code error;
        _socket.close(error); // error_code overload: never throw from a destructor
    }

    /// Entry point implemented by the concrete socket type.
    virtual void Start() = 0;

    /// Periodic tick. Returns false once the socket has been closed, true otherwise.
    /// Without IOCP this is also where queued writes are flushed synchronously.
    virtual bool Update()
    {
        if (_closed)
            return false;

#ifndef TC_SOCKET_USE_IOCP
        // Nothing to do while an async writability wait is pending, or when no
        // data is queued and no delayed close has been requested.
        if (_isWritingAsync || (_writeQueue.empty() && !_closing))
            return true;

        // HandleQueue() returns true while fully-sent messages leave more data queued.
        while (HandleQueue())
            ;
#endif

        return true;
    }

    boost::asio::ip::address GetRemoteIpAddress() const
    {
        return _remoteAddress;
    }

    uint16 GetRemotePort() const
    {
        return _remotePort;
    }

    /// Starts an asynchronous read into the free space of the read buffer;
    /// completion goes to ReadHandlerInternal(). No-op if closed or closing.
    void AsyncRead()
    {
        if (!IsOpen())
            return;

        _readBuffer.Normalize();
        _readBuffer.EnsureFreeSpace();
        _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
            std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
    }

    /// Like AsyncRead(), but completion is dispatched to the given member function
    /// of T instead of ReadHandlerInternal(), so that handler must itself handle
    /// errors and account for the received bytes in the read buffer.
    void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
    {
        if (!IsOpen())
            return;

        _readBuffer.Normalize();
        _readBuffer.EnsureFreeSpace();
        _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
            std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
    }

    /// Appends a buffer to the outgoing queue. With IOCP sending starts right
    /// away; otherwise the queue is flushed by Update().
    void QueuePacket(MessageBuffer&& buffer)
    {
        _writeQueue.push(std::move(buffer));

#ifdef TC_SOCKET_USE_IOCP
        AsyncProcessQueue();
#endif
    }

    bool IsOpen() const { return !_closed && !_closing; }

    /// Closes the connection at most once: shuts down the sending direction
    /// and calls OnClose(). Later calls return immediately.
    void CloseSocket()
    {
        if (_closed.exchange(true))
            return;

        boost::system::error_code shutdownError;
        _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
        if (shutdownError)
            TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(),
                shutdownError.value(), shutdownError.message().c_str());

        OnClose();
    }

    /// Marks the socket for closing after write buffer becomes empty
    /// NOTE(review): with IOCP the close is only performed from WriteHandler(),
    /// i.e. after a write completes; an already idle socket is not closed here.
    void DelayedCloseSocket() { _closing = true; }

    MessageBuffer& GetReadBuffer() { return _readBuffer; }

protected:
    virtual void OnClose() { }

    /// Called after received bytes were committed to the read buffer.
    virtual void ReadHandler() = 0;

    /// Starts one asynchronous write step unless one is already pending.
    /// IOCP: sends the front queued message (the queue must not be empty).
    /// Otherwise: waits for the socket to become writable (null_buffers).
    /// Always returns false, which ends the drain loop in Update().
    bool AsyncProcessQueue()
    {
        if (_isWritingAsync)
            return false;

        _isWritingAsync = true;

#ifdef TC_SOCKET_USE_IOCP
        MessageBuffer& buffer = _writeQueue.front();
        _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
            this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
#else
        _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandlerWrapper,
            this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
#endif

        return false;
    }

    void SetNoDelay(bool enable)
    {
        boost::system::error_code err;
        _socket.set_option(tcp::no_delay(enable), err);
        if (err)
            TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)",
                GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str());
    }

private:
    /// Default read completion: closes on error, otherwise commits the bytes
    /// and hands control to the derived ReadHandler().
    void ReadHandlerInternal(boost::system::error_code error, std::size_t transferredBytes)
    {
        if (error)
        {
            CloseSocket();
            return;
        }

        _readBuffer.WriteCompleted(transferredBytes);
        ReadHandler();
    }

#ifdef TC_SOCKET_USE_IOCP

    /// Completion of an async send: consumes the sent bytes, continues with the
    /// rest of the queue, or performs a pending delayed close once it is empty.
    void WriteHandler(boost::system::error_code error, std::size_t transferredBytes)
    {
        if (!error)
        {
            _isWritingAsync = false;
            _writeQueue.front().ReadCompleted(transferredBytes);
            if (!_writeQueue.front().GetActiveSize())
                _writeQueue.pop();

            if (!_writeQueue.empty())
                AsyncProcessQueue();
            else if (_closing)
                CloseSocket();
        }
        else
            CloseSocket();
    }

#else

    /// Completion of the writability wait. The error is ignored; HandleQueue()
    /// retries the write and sees any real socket error from write_some.
    void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferredBytes*/)
    {
        _isWritingAsync = false;
        HandleQueue();
    }

    /// Removes the front message and completes a pending delayed close once
    /// the queue has drained.
    void PopWriteQueue()
    {
        _writeQueue.pop();
        if (_closing && _writeQueue.empty())
            CloseSocket();
    }

    /// Synchronously writes as much of the front message as the socket accepts.
    /// Returns true only when that message was fully sent and more data is
    /// queued, so Update() keeps draining; returns false otherwise.
    bool HandleQueue()
    {
        if (_writeQueue.empty())
        {
            // A delayed close requested while nothing was queued completes here;
            // otherwise DelayedCloseSocket() on an idle socket never closed it.
            if (_closing)
                CloseSocket();
            return false;
        }

        MessageBuffer& queuedMessage = _writeQueue.front();
        std::size_t bytesToSend = queuedMessage.GetActiveSize();

        boost::system::error_code error;
        std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);

        if (error)
        {
            // Kernel send buffer is full: wait for writability, retry later.
            if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
                return AsyncProcessQueue();

            // NOTE(review): any other error drops the message but leaves the socket open.
            PopWriteQueue();
            return false;
        }

        if (bytesSent == 0)
        {
            PopWriteQueue();
            return false;
        }

        if (bytesSent < bytesToSend)
        {
            // Partial write: keep the remainder and continue once writable.
            queuedMessage.ReadCompleted(bytesSent);
            return AsyncProcessQueue();
        }

        PopWriteQueue();
        return !_writeQueue.empty();
    }

#endif

    tcp::socket _socket;

    boost::asio::ip::address _remoteAddress;
    uint16 _remotePort;

    MessageBuffer _readBuffer;
    std::queue<MessageBuffer> _writeQueue;

    std::atomic<bool> _closed;
    std::atomic<bool> _closing;

    // True while an asynchronous write step started by AsyncProcessQueue() is pending.
    bool _isWritingAsync;
};

#endif // __SOCKET_H__

//======================================================

template<class SocketType>
class SocketMgr

将之前的Socket NetworkThread AsyncAcceptor

整合了起来

virtual bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port, int threadCount)函数

开启threadCount个NetworkThread

创建一个AsyncAcceptor 异步ACCEPT连接

uint32 SelectThreadWithMinConnections() 函数会返回连接数目最少的NetworkThread 的线程索引

std::pair<tcp::socket*, uint32> GetSocketForAccept() 返回一对值：连接数目最少的线程用于接受新连接的 socket 指针，以及该线程的索引。

其余的start stop 就没什么了

值得关注的是virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)

当继承SocketMgr的服务器在accept的时候会调用该函数

函数的功能是用 accept 得到的 socket 构造一个 SocketType 对象，并调用它的 Start() 函数，

然后将该 Socket 加入到对应 NetworkThread 的 Socket 容器中（AddSocket 函数）。

整个类的代码如下

 /*
* Copyright (C) 2008-2017 TrinityCore <http://www.trinitycore.org/>
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef SocketMgr_h__
#define SocketMgr_h__

#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include <boost/asio/ip/tcp.hpp>
#include <memory>
#include <string>
#include <utility>

using boost::asio::ip::tcp;

/// Owns the listening AsyncAcceptor and an array of NetworkThread<SocketType>
/// created by the derived class (CreateThreads()). New connections are wrapped
/// in a SocketType and assigned to a thread. StopNetwork() must be called
/// before destruction (asserted in the destructor).
template<class SocketType>
class SocketMgr
{
public:
    virtual ~SocketMgr()
    {
        ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
    }

    /// Creates and binds the acceptor on bindIp:port, then creates and starts
    /// threadCount network threads.
    /// @return false if the acceptor could not be created or bound; nothing
    ///         stays allocated and no member is modified in that case.
    virtual bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port, int threadCount)
    {
        ASSERT(threadCount > 0);

        // Owned locally until fully set up, so every failure path frees it.
        std::unique_ptr<AsyncAcceptor> acceptor;
        try
        {
            acceptor.reset(new AsyncAcceptor(service, bindIp, port));
        }
        catch (boost::system::system_error const& err)
        {
            TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork (%s:%u): %s", bindIp.c_str(), port, err.what());
            return false;
        }

        if (!acceptor->Bind())
        {
            TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
            return false; // previously leaked the unbound acceptor
        }

        _acceptor = acceptor.release();
        _threadCount = threadCount;
        _threads = CreateThreads();

        ASSERT(_threads);

        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Start();

        return true;
    }

    /// Closes the acceptor, stops and joins all network threads and releases
    /// them. Safe to call when StartNetwork() did not succeed.
    virtual void StopNetwork()
    {
        // _acceptor is still null if StartNetwork() failed or was never called.
        if (_acceptor)
            _acceptor->Close();

        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Stop();

        Wait();

        delete _acceptor;
        _acceptor = nullptr;
        delete[] _threads;
        _threads = nullptr;
        _threadCount = 0;
    }

    /// Blocks until every network thread has finished.
    void Wait()
    {
        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Wait();
    }

    /// Wraps an accepted socket in a SocketType, calls Start() on it and hands
    /// it to the network thread threadIndex. A system_error thrown while
    /// constructing the socket object is logged and the connection dropped.
    virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
    {
        try
        {
            std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
            newSocket->Start();

            _threads[threadIndex].AddSocket(newSocket);
        }
        catch (boost::system::system_error const& err)
        {
            TC_LOG_WARN("network", "Failed to retrieve client's remote address %s", err.what());
        }
    }

    int32 GetNetworkThreadCount() const { return _threadCount; }

    /// Index of the network thread with the fewest connections (first one on ties).
    uint32 SelectThreadWithMinConnections() const
    {
        uint32 min = 0;

        for (int32 i = 0; i < _threadCount; ++i)
            if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
                min = static_cast<uint32>(i);

        return min;
    }

    /// Returns the accept socket of the least loaded thread together with that thread's index.
    std::pair<tcp::socket*, uint32> GetSocketForAccept()
    {
        uint32 threadIndex = SelectThreadWithMinConnections();
        return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
    }

protected:
    SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0)
    {
    }

    /// Allocates the array of _threadCount network threads (freed with delete[] in StopNetwork()).
    virtual NetworkThread<SocketType>* CreateThreads() const = 0;

    AsyncAcceptor* _acceptor;
    NetworkThread<SocketType>* _threads;
    int32 _threadCount;
};

#endif // SocketMgr_h__