vc++ socket实现的支持断点续传的下载器

时间:2022-09-04 18:45:09

网上找了一堆代码,有用wininet的,还有用socket的,整理了半天,还是觉得socket靠谱。


只支持内存中断点续传。如果要加上在磁盘上断点续传,原理也差不多,不是本文重点。


注释:

1. CByteBufferVector是一个缓存池,动态分配BYTE形数组空间用的。代码略,可以简单看成BYTE数组。

2. GetStringA是一个CString转CStringA的函数,无需多说。

3. 除了win socket基本没有其它依赖,噢对,ATL::CString除外……


头文件:

class CSocketDownloader;


/**
 * 下载任务
 */
class CDownloadTask
{
    friend class CSocketDownloader;


public:


    CDownloadTask();


    CStringA GetUrlA() const;
    CStringA GetAgnetA() const;
    void ParseUrl();
    
    int Percentage() const;
    DWORD RemainTimeSec(DWORD dwTickElapsed, unsigned int uBytesTransferred) const;


    CString         m_strUrl;           // 下载地址
    CString         m_strAgent;         // 用户agent
    int             m_nMaxTryCount;     // 最多重试次数(重定向不算重试,默认20次)
    int             m_nTimeoutSec;      // socket超时(秒,默认10秒)
    int             m_nPort;            // 端口(默认80)
    HWND            m_hWnd;             // 接收下载进度消息的窗口句柄
    LONG            *m_pTerminate;      // 指向是否中止的标志位,一般由用户界面操作(如点击“取消”按钮)更改此值


protected:


    CStringA        m_strAbsoluteUrlA;
    CStringA        m_strQueryA;
    CStringA        m_strHostA;
    unsigned int    m_uReadBytes;
    unsigned int    m_uTotalBytes;
};


/**
 * socket实现的断点续传下载器
 */
class CSocketDownloader
{
public:
    CSocketDownloader();
    virtual ~CSocketDownloader();


    // 下载到一个buffer
    DWORD DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);


    // 下载到一个文件
    DWORD DownloadToFile(CDownloadTask &task, CString strOutputFile);


protected:


    DWORD DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);
    DWORD ConnectServer(const CDownloadTask &task, SOCKET hSocket);
    DWORD DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket);
    
    int GetSleepSecCount(int nTryCount) const;
    int GetBufferSize(const CDownloadTask &task) const;
    CStringA GenerateRequest(CDownloadTask &task) const;


};


CPP文件:

#include <math.h>
#include <time.h>


const int BLOCK_SIZE = 1024 * 64;
const int DEFAULT_MAX_TRY = 20;
const int DEFAULT_TIMEOUT = 10;


//////////////////////////////////////////////////////////////////////////
// 下载任务
//////////////////////////////////////////////////////////////////////////


CDownloadTask::CDownloadTask()
    : m_nPort(INTERNET_DEFAULT_HTTP_PORT),
    m_nMaxTryCount(DEFAULT_MAX_TRY),
    m_uReadBytes(0),
    m_uTotalBytes(0),
    m_nTimeoutSec(DEFAULT_TIMEOUT),
    m_hWnd(NULL),
    m_pTerminate(NULL)
{


}


CStringA CDownloadTask::GetUrlA() const
{
    return GetStringA(m_strUrl);
}


CStringA CDownloadTask::GetAgnetA() const
{
    return GetStringA(m_strAgent);
}


void CDownloadTask::ParseUrl()
{
    m_strAbsoluteUrlA = m_strHostA = m_strQueryA = "";


    CStringA strUrlA = this->GetUrlA();
    const char *pUrl = strUrlA;
    const char *p = pUrl;
    const char *szHttpHead = "http://";


    if (_strnicmp(pUrl, szHttpHead, strlen(szHttpHead)) == 0)
    {
        p = pUrl + strlen(szHttpHead);
    }


    int nHostLen = 0;
    const char *q = strchr(p, '/');
    if (q != NULL)
    {
        nHostLen = q - p;
        int nPathLen = 0;
        const char *r = strchr(q, '?');
        if (r != NULL)
        {
            // 解析query
            r++;
            m_strQueryA = r;
            nPathLen = r - q - 1;
        }
        else
        {
            nPathLen = strlen(q);
        }


        // 解析abs_path
        m_strAbsoluteUrlA.Append(q, nPathLen);
    }
    else
    {
        nHostLen = strlen(p);
    }


    // 解析host
    m_strHostA.Append(p, nHostLen);


    // 解析port
    const char *r = strchr(m_strHostA, ':');
    if (r == 0)
    {
        m_nPort = INTERNET_DEFAULT_HTTP_PORT;
    }
    else
    {
        m_nPort = atoi(r + 1);
    }
}


int CDownloadTask::Percentage() const
{
    return (m_uTotalBytes == 0)
        ? 0
        : (int)((unsigned long long)m_uReadBytes * 100 / (unsigned long long) m_uTotalBytes);
}


DWORD CDownloadTask::RemainTimeSec( DWORD dwTickElapsed, unsigned int uBytesTransferred ) const
{
    unsigned long long uTickElapsed = (unsigned long long)dwTickElapsed;
    unsigned long long uBytes = (unsigned long long)uBytesTransferred;
    unsigned long long uRemain = (unsigned long long)(m_uTotalBytes - m_uReadBytes);
    Log(_T("elapsed=%d, get=%d, remain=%d\n"), dwTickElapsed, uBytesTransferred, m_uTotalBytes - m_uReadBytes);
    return (DWORD)(uTickElapsed * uRemain / (uBytes * CLOCKS_PER_SEC));
}


//////////////////////////////////////////////////////////////////////////
// socket下载器
//////////////////////////////////////////////////////////////////////////


CSocketDownloader::CSocketDownloader()
{


}


CSocketDownloader::~CSocketDownloader()
{


}


DWORD CSocketDownloader::DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)
{
    int nTryCount = 0;
    DWORD dwRet = this->DoDownloadToBuffer(task, bufVec);
    if (web::THE_REDIRECT != dwRet)
    {
        nTryCount++;
    }


    while (
        dwRet != web::THE_SUCCEED
        && dwRet != web::THE_USER_CANCELED
        && nTryCount < task.m_nMaxTryCount
        )
    {
        int nTime = this->GetSleepSecCount(nTryCount);
        ::Sleep(nTime);
        dwRet = this->DoDownloadToBuffer(task, bufVec);
        if (web::THE_REDIRECT != dwRet)
        {
            nTryCount++;
        }
    }


    return dwRet;
}


DWORD CSocketDownloader::DownloadToFile( CDownloadTask &task, CString strOutputFile )
{
    CByteBufferVector vec;


    DWORD dwRet = this->DownloadToBuffer(task, vec);
    if (web::THE_SUCCEED != dwRet)
    {
        return dwRet;
    }


    HANDLE hFile = ::CreateFile(strOutputFile, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        return web::THE_CREATE_FILE;
    }


    BYTE *pBuffer = vec.Ptr(0, task.m_uTotalBytes);
    DWORD dwBytesWritten = 0;
    ::WriteFile(hFile, pBuffer, task.m_uTotalBytes, &dwBytesWritten, NULL);
    ::CloseHandle(hFile);
    return (dwBytesWritten == task.m_uTotalBytes) ? web::THE_SUCCEED : web::THE_WRITE_FILE;
}


DWORD CSocketDownloader::DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)
{
    task.ParseUrl();


    SOCKET hSocket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (hSocket == INVALID_SOCKET)
    {
        return web::THE_CREATE_SOCKET;
    }


    DWORD dwRet = this->ConnectServer(task, hSocket);
    if (web::THE_SUCCEED != dwRet)
    {
        closesocket(hSocket);
        return dwRet;
    }


    dwRet =  this->DoDownloadToBufferInner(task, bufVec, hSocket);
    closesocket(hSocket);
    return dwRet;
}


DWORD CSocketDownloader::DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket)
{
    // 发送请求
    CStringA strRequest = this->GenerateRequest(task);
    int nLen = send(hSocket, strRequest, strRequest.GetLength(), 0);
    if (nLen <= 0)
    {
        return web::THE_SEND_HTTP_HEADER;
    }


    // 接收一部分数据(header部分,以"\r\n\r\n"为止)
    CStringA strRecvBuf;
    char szRecvBuf[MAX_PATH] = { 0 };
    nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);
    while (nLen > 0)
    {
        szRecvBuf[nLen] = 0;
        strRecvBuf.Append(szRecvBuf);
        if (strstr(szRecvBuf, "\r\n\r\n") != NULL)
        {
            break;
        }
        nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);
    }
    
    // 找到两个回车换行,即content起始位置。
    const char *pData = strstr(szRecvBuf, "\r\n\r\n");
    if (pData == NULL)
    {
        return web::THE_INVALID_RECV_END;
    }


    pData += 4;


    const char *p = strchr(strRecvBuf, ' ');
    if (p != NULL)
    {
        p++;
        DWORD dwRet = atoi(p);


        if (dwRet == HTTP_STATUS_PARTIAL_CONTENT)        // 206: 断点续传
        {
            const char *q = strstr(strRecvBuf, "\r\nContent-Length:");
            if (q == NULL)
            {
                return web::THE_NO_CONTENT_LENGTH;
            }
            task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);
        }
        else if (dwRet == HTTP_STATUS_OK)               // 200: 重新下载(服务器不支持断点续传)
        {
            const char *q = strstr(strRecvBuf, "\r\nContent-Length:");
            if (q == NULL)
            {
                return web::THE_NO_CONTENT_LENGTH;
            }
            task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);
            // 清除已经下载的内容
            task.m_uReadBytes = 0;
            bufVec.Reset();
        }
        else if (dwRet == HTTP_STATUS_REDIRECT)         // 302: 重定向
        {
            const char *q = strstr(strRecvBuf, "\r\nLocation:");
            if (q == NULL)
            {
                return web::THE_NO_REDIRECT_LOCATION;
            }
            q += 12;
            const char *r = strstr(q, "\r\n");
            if (r == NULL)
            {
                return web::THE_REDIRECT_INVALID_FORMAT;
            }


            int nUrlLen = r - q;
            CStringA strUrlA;
            strUrlA.Append(q, nUrlLen);
            task.m_strUrl = GetString(strUrlA);
            return web::THE_REDIRECT;
        }
        else
        {
            return web::THE_INVALID_STAUS_CODE;
        }
    }
    
    // 复制已传回来的第一部分content
    int nSize = nLen - (pData - szRecvBuf);
    BYTE *pBuffer = bufVec.Ptr(task.m_uReadBytes, nSize);
    memcpy(pBuffer, pData, nSize);
    task.m_uReadBytes += nSize;


    // 继续接收http content,即下载内容。
    int nBufferSize = this->GetBufferSize(task);
    pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);


    DWORD dwLastTick = 0;


    // 下载测速
    DWORD dwTickStart = ::GetTickCount();
    unsigned int uReadBytesStart = task.m_uReadBytes;


    while (true)
    {
        if (::InterlockedCompareExchange(task.m_pTerminate, 1, 1))
        {
            // 用户取消。
            return web::THE_USER_CANCELED;
        }


        nLen = recv(hSocket, (char *)(pBuffer), nBufferSize, 0);
        if (nLen < 0)
        {
            return web::THE_RECV_FAIL;
        }
        else if (nLen == 0)
        {
            break;  // 接收完成
        }


        task.m_uReadBytes += nLen;
        if (task.m_uReadBytes == task.m_uTotalBytes)
        {
            break;  // 接收完成
        }


        nBufferSize = this->GetBufferSize(task);
        pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);


        if (NULL != task.m_hWnd)
        {
            DWORD dwTick = ::GetTickCount();
            if (dwLastTick == 0 || (dwTick - dwLastTick >= 100))        // 每秒最多发10次消息
            {
                // 发送当前下载进度和剩余时间消息
                dwLastTick = dwTick;
                ::PostMessage(task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,
                    static_cast<WPARAM>(task.Percentage()), static_cast<LPARAM>(task.RemainTimeSec(dwTick - dwTickStart, task.m_uReadBytes - uReadBytesStart))
                    );
            }
        }
    }


    DWORD dwTick = ::GetTickCount();
    if (NULL != task.m_hWnd)
    {
        ::PostMessage(
            task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,
            static_cast<WPARAM>(task.Percentage()), -1
            );
    }


    return web::THE_SUCCEED;
}


DWORD CSocketDownloader::ConnectServer(const CDownloadTask &task, SOCKET hSocket)
{
    PHOSTENT pHostent = gethostbyname(task.m_strHostA);
    if (pHostent == NULL)
    {
        return web::THE_GET_HOST_BY_NAME;
    }


    sockaddr_in addrSvr;
    addrSvr.sin_port = htons((u_short)task.m_nPort);
    addrSvr.sin_family = AF_INET;
    addrSvr.sin_addr.s_addr = *(ULONG*)pHostent->h_addr_list[0];
    if (SOCKET_ERROR == connect(hSocket, (sockaddr*)&addrSvr, sizeof(addrSvr)))
    {
        return web::THE_CONNECT_SOCKET;
    }


    int opt = task.m_nTimeoutSec * 1000;
    if (0 != setsockopt(hSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&opt, sizeof(opt)))
    {
        return web::THE_SET_SOCK_OPT1;
    }


    BOOL bKeepAlive = TRUE;  
    int len = sizeof(bKeepAlive);
    getsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char*)&bKeepAlive, &len);


    bKeepAlive = TRUE;
    if (0 != setsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char *)&bKeepAlive, sizeof(BOOL)))
    {
        return web::THE_SET_SOCK_OPT2;
    }


    return web::THE_SUCCEED;
}


int CSocketDownloader::GetSleepSecCount( int nTryCount ) const
{
    return (nTryCount + 1) * 1000;
}


int CSocketDownloader::GetBufferSize( const CDownloadTask &task ) const
{
    return std::min<int>(BLOCK_SIZE, task.m_uTotalBytes - task.m_uReadBytes);
}


CStringA CSocketDownloader::GenerateRequest( CDownloadTask &task ) const
{
    CStringA strRequest;


    CStringA strTemp;
    if (task.m_strQueryA.IsEmpty())
    {
        strTemp.Format(
            "GET %s HTTP/1.1\r\nHOST: %s\r\n",
            task.m_strAbsoluteUrlA.GetString(), task.m_strHostA.GetString()
            );
    }
    else
    {
        strTemp.Format(
            "GET %s?%s HTTP/1.1\r\nHOST: %s\r\n",
            task.m_strAbsoluteUrlA.GetString(), task.m_strQueryA.GetString(), task.m_strHostA.GetString()
            );
    }
    strRequest.Append(strTemp);


    strTemp.Format("Range: bytes=%d-\r\n", task.m_uReadBytes);
    strRequest.Append(strTemp);


    strTemp.Format("User-Agent: %s\r\n", task.GetAgnetA().GetString());
    strRequest.Append(strTemp);


    strRequest.Append("Accept: */*\r\n");
    strRequest.Append("Accept-Encoding: gzip, deflate\r\n");
    strRequest.Append("Connection: Keep-Alive\r\n\r\n");
    
    return strRequest;
}