网上找了一堆代码,有用wininet的,还有用socket的,整理了半天,还是觉得socket靠谱。
只支持内存中断点续传。如果要加上在磁盘上断点续传,原理也差不多,不是本文重点。
注释:
1. CByteBufferVector是一个缓存池,用于动态分配BYTE型数组空间。代码略,可以简单看成BYTE数组。
2. GetStringA是一个CString转CStringA的函数,无需多说。
3. 除了win socket基本没有其它依赖,噢对,ATL::CString除外……
头文件:
class CSocketDownloader;
/**
* Download task: describes what to download and carries the resume state
* (bytes read so far / total size) between download attempts.
*/
class CDownloadTask
{
friend class CSocketDownloader;
public:
CDownloadTask();
// Download URL (m_strUrl) converted to a narrow CStringA.
CStringA GetUrlA() const;
// User agent (m_strAgent) converted to a narrow CStringA. (Name is spelled "Agnet"; kept as is for existing callers.)
CStringA GetAgnetA() const;
// Splits m_strUrl into host, path and query, and sets m_nPort from the host part.
void ParseUrl();
// Download progress in percent; 0 while the total size is still 0.
int Percentage() const;
// Estimated remaining seconds, from the time elapsed and the bytes transferred in that time.
DWORD RemainTimeSec(DWORD dwTickElapsed, unsigned int uBytesTransferred) const;
CString m_strUrl; // download URL
CString m_strAgent; // value sent in the User-Agent header
int m_nMaxTryCount; // maximum number of attempts (redirects are not counted; default 20)
int m_nTimeoutSec; // socket receive timeout in seconds (default 10)
int m_nPort; // server port (default 80); overwritten by ParseUrl()
HWND m_hWnd; // window that receives download progress messages; no messages when NULL
LONG *m_pTerminate; // points to the abort flag (nonzero = abort), usually set by the UI, e.g. a "Cancel" button; NULL after construction
protected:
CStringA m_strAbsoluteUrlA; // path part of the URL, filled by ParseUrl()
CStringA m_strQueryA; // query string without the leading '?', filled by ParseUrl()
CStringA m_strHostA; // host part of the URL (may include ":port"), filled by ParseUrl()
unsigned int m_uReadBytes; // bytes received so far; also the start offset of the next Range request
unsigned int m_uTotalBytes; // total size of the download; 0 until a response header has been parsed
};
/**
* Resumable HTTP downloader built directly on Winsock. Data is collected in
* memory; every retry resumes from CDownloadTask::m_uReadBytes using an HTTP
* Range request.
*/
class CSocketDownloader
{
public:
CSocketDownloader();
virtual ~CSocketDownloader();
// Download into a buffer, retrying failed attempts up to task.m_nMaxTryCount times.
DWORD DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);
// Download into memory, then write the downloaded bytes to strOutputFile.
DWORD DownloadToFile(CDownloadTask &task, CString strOutputFile);
protected:
// One attempt: parse the URL, create and connect a socket, download, close the socket.
DWORD DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);
// Resolve the host, connect hSocket, and set the receive timeout and keep-alive options.
DWORD ConnectServer(const CDownloadTask &task, SOCKET hSocket);
// Send the request on a connected socket and receive the response header and body.
DWORD DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket);
// Back-off before the next retry. NOTE: the value is in milliseconds (it is passed to ::Sleep), despite the name.
int GetSleepSecCount(int nTryCount) const;
// Size of the next receive block, at most BLOCK_SIZE.
int GetBufferSize(const CDownloadTask &task) const;
// Build the HTTP GET request for the task, including the Range header.
CStringA GenerateRequest(CDownloadTask &task) const;
};
CPP文件:
#include <math.h>
#include <time.h>
// Largest amount of body data requested from a single recv() call (64 KiB).
const int BLOCK_SIZE = 64 * 1024;
// Default number of download attempts for a task (redirects are not counted).
const int DEFAULT_MAX_TRY = 20;
// Default socket receive timeout, in seconds.
const int DEFAULT_TIMEOUT = 10;
//////////////////////////////////////////////////////////////////////////
// 下载任务
//////////////////////////////////////////////////////////////////////////
/**
 * Creates a task with default settings: DEFAULT_MAX_TRY attempts, a
 * DEFAULT_TIMEOUT-second receive timeout, the default HTTP port, no progress
 * window, no abort flag, and no bytes read yet.
 * Initializers are listed in member declaration order.
 */
CDownloadTask::CDownloadTask()
    : m_nMaxTryCount(DEFAULT_MAX_TRY)
    , m_nTimeoutSec(DEFAULT_TIMEOUT)
    , m_nPort(INTERNET_DEFAULT_HTTP_PORT)
    , m_hWnd(NULL)
    , m_pTerminate(NULL)
    , m_uReadBytes(0)
    , m_uTotalBytes(0)
{
    // String members are default-constructed.
}
/**
 * Returns m_strUrl converted to a narrow CStringA via GetStringA().
 */
CStringA CDownloadTask::GetUrlA() const
{
    CStringA strUrlA = GetStringA(m_strUrl);
    return strUrlA;
}
/**
 * Returns m_strAgent (the User-Agent value) converted to a narrow CStringA
 * via GetStringA().
 */
CStringA CDownloadTask::GetAgnetA() const
{
    CStringA strAgentA = GetStringA(m_strAgent);
    return strAgentA;
}
/**
 * Splits m_strUrl into its parts. A leading "http://" (any case) is optional.
 *
 * Results:
 *   m_strHostA        - "host" or "host:port" (the port is kept because it
 *                       also belongs in the Host header)
 *   m_strAbsoluteUrlA - path starting with '/'; "/" when the URL has no path
 *   m_strQueryA       - text after '?' without the '?'; empty when absent
 *   m_nPort           - port from "host:port"; INTERNET_DEFAULT_HTTP_PORT when
 *                       absent or not in 1..65535
 * A "#fragment" is ignored because it is never sent to the server.
 */
void CDownloadTask::ParseUrl()
{
    m_strAbsoluteUrlA = m_strHostA = m_strQueryA = "";
    m_nPort = INTERNET_DEFAULT_HTTP_PORT;

    CStringA strUrlA = this->GetUrlA();
    const char *p = strUrlA;
    const char *szHttpHead = "http://";
    const size_t nHeadLen = strlen(szHttpHead);
    if (_strnicmp(p, szHttpHead, nHeadLen) == 0)
    {
        p += nHeadLen;
    }

    // Everything from '#' on is client-side only.
    const int nUrlLen = (int)strcspn(p, "#");

    // The host ends at the first '/' or '?' ("example.com?x=1" has no path).
    const int nHostLen = (int)strcspn(p, "/?#");
    m_strHostA.Append(p, nHostLen);

    // Remainder: "/path?query", "?query" or "".
    const char *pRest = p + nHostLen;
    const int nRestLen = nUrlLen - nHostLen;
    const char *pQuery = (const char *)memchr(pRest, '?', nRestLen);
    const int nPathLen = (pQuery != NULL) ? (int)(pQuery - pRest) : nRestLen;
    if (nPathLen > 0)
    {
        m_strAbsoluteUrlA.Append(pRest, nPathLen);
    }
    else
    {
        // An empty path would yield the invalid request line "GET  HTTP/1.1".
        m_strAbsoluteUrlA = "/";
    }
    if (pQuery != NULL)
    {
        m_strQueryA.Append(pQuery + 1, nRestLen - nPathLen - 1);
    }

    // Port from "host:port"; keep the default for a missing or invalid value.
    const char *pColon = strchr(m_strHostA, ':');
    if (pColon != NULL)
    {
        const int nPort = atoi(pColon + 1);
        if (nPort > 0 && nPort <= 65535)
        {
            m_nPort = nPort;
        }
    }
}
/**
 * Download progress as an integer percentage of m_uTotalBytes.
 * Returns 0 while the total size is still 0 (not yet known).
 */
int CDownloadTask::Percentage() const
{
    if (0 == m_uTotalBytes)
    {
        return 0;
    }
    // Widen to 64 bits so that m_uReadBytes * 100 cannot overflow.
    const unsigned long long ullScaled = static_cast<unsigned long long>(m_uReadBytes) * 100ULL;
    return static_cast<int>(ullScaled / static_cast<unsigned long long>(m_uTotalBytes));
}
/**
 * Estimates the remaining download time from the average speed so far.
 *
 * @param dwTickElapsed      milliseconds since the speed measurement started
 *                           (the downloader passes a ::GetTickCount() difference)
 * @param uBytesTransferred  bytes received during that interval
 * @return estimated seconds left; 0 when no bytes have been transferred yet
 *         (speed unknown) or when nothing remains.
 */
DWORD CDownloadTask::RemainTimeSec( DWORD dwTickElapsed, unsigned int uBytesTransferred ) const
{
    // Guard the unsigned subtraction so reading past the declared size cannot wrap.
    const unsigned int uRemainBytes = (m_uTotalBytes > m_uReadBytes) ? (m_uTotalBytes - m_uReadBytes) : 0;
    Log(_T("elapsed=%d, get=%d, remain=%d\n"), dwTickElapsed, uBytesTransferred, uRemainBytes);
    if (uBytesTransferred == 0 || uRemainBytes == 0)
    {
        // No speed sample yet (avoids division by zero), or already complete.
        return 0;
    }
    // seconds = elapsed_ms * remain_bytes / (transferred_bytes * 1000 ms/s).
    // The ticks are milliseconds, so the divisor is 1000, not CLOCKS_PER_SEC
    // (clock() units, which merely happen to equal 1000 on MSVC).
    const unsigned long long ullMsPerSec = 1000ULL;
    return (DWORD)((unsigned long long)dwTickElapsed * (unsigned long long)uRemainBytes
        / ((unsigned long long)uBytesTransferred * ullMsPerSec));
}
//////////////////////////////////////////////////////////////////////////
// socket下载器
//////////////////////////////////////////////////////////////////////////
// The downloader has no data members; every call works on the CDownloadTask
// passed in, so construction and destruction have nothing to do.
CSocketDownloader::CSocketDownloader()
{
    // nothing to initialize
}

CSocketDownloader::~CSocketDownloader()
{
    // nothing to release: sockets are closed inside DoDownloadToBuffer()
}
/**
 * Downloads task.m_strUrl into bufVec, retrying failed attempts.
 *
 * A failed attempt is retried until task.m_nMaxTryCount attempts have been
 * made, sleeping GetSleepSecCount() milliseconds before each retry; a retry
 * resumes from task.m_uReadBytes. Redirects do not count as attempts and are
 * followed immediately, but at most nMaxRedirects of them are followed so a
 * redirect loop (A -> B -> A) cannot spin forever.
 *
 * @return web::THE_SUCCEED, web::THE_USER_CANCELED, or the result of the last
 *         attempt (web::THE_REDIRECT when the redirect limit was exceeded).
 */
DWORD CSocketDownloader::DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)
{
    const int nMaxRedirects = 10;
    int nTryCount = 0;
    int nRedirectCount = 0;
    DWORD dwRet = this->DoDownloadToBuffer(task, bufVec);
    while (dwRet != web::THE_SUCCEED && dwRet != web::THE_USER_CANCELED)
    {
        if (dwRet == web::THE_REDIRECT)
        {
            // task.m_strUrl now holds the new location; no back-off needed.
            if (++nRedirectCount > nMaxRedirects)
            {
                break;
            }
        }
        else
        {
            if (++nTryCount >= task.m_nMaxTryCount)
            {
                break;
            }
            // GetSleepSecCount() returns milliseconds.
            ::Sleep(this->GetSleepSecCount(nTryCount));
        }
        dwRet = this->DoDownloadToBuffer(task, bufVec);
    }
    return dwRet;
}
/**
 * Downloads the task into memory, then writes the task.m_uTotalBytes
 * downloaded bytes to strOutputFile (created or overwritten).
 *
 * @return web::THE_SUCCEED; the download error if the download failed;
 *         web::THE_CREATE_FILE if the file cannot be opened; or
 *         web::THE_WRITE_FILE if not every byte was written.
 */
DWORD CSocketDownloader::DownloadToFile( CDownloadTask &task, CString strOutputFile )
{
    CByteBufferVector vecData;
    const DWORD dwDownload = this->DownloadToBuffer(task, vecData);
    if (dwDownload != web::THE_SUCCEED)
    {
        return dwDownload;
    }

    HANDLE hOutFile = ::CreateFile(strOutputFile, GENERIC_WRITE, 0, NULL,
                                   CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
    if (INVALID_HANDLE_VALUE == hOutFile)
    {
        return web::THE_CREATE_FILE;
    }

    const DWORD dwToWrite = task.m_uTotalBytes;
    DWORD dwWritten = 0;
    ::WriteFile(hOutFile, vecData.Ptr(0, task.m_uTotalBytes), dwToWrite, &dwWritten, NULL);
    ::CloseHandle(hOutFile);

    if (dwWritten != task.m_uTotalBytes)
    {
        return web::THE_WRITE_FILE;
    }
    return web::THE_SUCCEED;
}
/**
 * Performs one download attempt: parses the URL, opens a TCP socket, connects,
 * downloads, and closes the socket on every path.
 *
 * @return web::THE_CREATE_SOCKET if no socket could be created; otherwise the
 *         ConnectServer() error, or the DoDownloadToBufferInner() result once
 *         connected.
 */
DWORD CSocketDownloader::DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)
{
    task.ParseUrl();

    const SOCKET hConn = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (INVALID_SOCKET == hConn)
    {
        return web::THE_CREATE_SOCKET;
    }

    DWORD dwResult = this->ConnectServer(task, hConn);
    if (web::THE_SUCCEED == dwResult)
    {
        dwResult = this->DoDownloadToBufferInner(task, bufVec, hConn);
    }
    closesocket(hConn);
    return dwResult;
}
DWORD CSocketDownloader::DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket)
{
// 发送请求
CStringA strRequest = this->GenerateRequest(task);
int nLen = send(hSocket, strRequest, strRequest.GetLength(), 0);
if (nLen <= 0)
{
return web::THE_SEND_HTTP_HEADER;
}
// 接收一部分数据(header部分,以"\r\n\r\n"为止)
CStringA strRecvBuf;
char szRecvBuf[MAX_PATH] = { 0 };
nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);
while (nLen > 0)
{
szRecvBuf[nLen] = 0;
strRecvBuf.Append(szRecvBuf);
if (strstr(szRecvBuf, "\r\n\r\n") != NULL)
{
break;
}
nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);
}
// 找到两个回车换行,即content起始位置。
const char *pData = strstr(szRecvBuf, "\r\n\r\n");
if (pData == NULL)
{
return web::THE_INVALID_RECV_END;
}
pData += 4;
const char *p = strchr(strRecvBuf, ' ');
if (p != NULL)
{
p++;
DWORD dwRet = atoi(p);
if (dwRet == HTTP_STATUS_PARTIAL_CONTENT) // 206: 断点续传
{
const char *q = strstr(strRecvBuf, "\r\nContent-Length:");
if (q == NULL)
{
return web::THE_NO_CONTENT_LENGTH;
}
task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);
}
else if (dwRet == HTTP_STATUS_OK) // 200: 重新下载(服务器不支持断点续传)
{
const char *q = strstr(strRecvBuf, "\r\nContent-Length:");
if (q == NULL)
{
return web::THE_NO_CONTENT_LENGTH;
}
task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);
// 清除已经下载的内容
task.m_uReadBytes = 0;
bufVec.Reset();
}
else if (dwRet == HTTP_STATUS_REDIRECT) // 302: 重定向
{
const char *q = strstr(strRecvBuf, "\r\nLocation:");
if (q == NULL)
{
return web::THE_NO_REDIRECT_LOCATION;
}
q += 12;
const char *r = strstr(q, "\r\n");
if (r == NULL)
{
return web::THE_REDIRECT_INVALID_FORMAT;
}
int nUrlLen = r - q;
CStringA strUrlA;
strUrlA.Append(q, nUrlLen);
task.m_strUrl = GetString(strUrlA);
return web::THE_REDIRECT;
}
else
{
return web::THE_INVALID_STAUS_CODE;
}
}
// 复制已传回来的第一部分content
int nSize = nLen - (pData - szRecvBuf);
BYTE *pBuffer = bufVec.Ptr(task.m_uReadBytes, nSize);
memcpy(pBuffer, pData, nSize);
task.m_uReadBytes += nSize;
// 继续接收http content,即下载内容。
int nBufferSize = this->GetBufferSize(task);
pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);
DWORD dwLastTick = 0;
// 下载测速
DWORD dwTickStart = ::GetTickCount();
unsigned int uReadBytesStart = task.m_uReadBytes;
while (true)
{
if (::InterlockedCompareExchange(task.m_pTerminate, 1, 1))
{
// 用户取消。
return web::THE_USER_CANCELED;
}
nLen = recv(hSocket, (char *)(pBuffer), nBufferSize, 0);
if (nLen < 0)
{
return web::THE_RECV_FAIL;
}
else if (nLen == 0)
{
break; // 接收完成
}
task.m_uReadBytes += nLen;
if (task.m_uReadBytes == task.m_uTotalBytes)
{
break; // 接收完成
}
nBufferSize = this->GetBufferSize(task);
pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);
if (NULL != task.m_hWnd)
{
DWORD dwTick = ::GetTickCount();
if (dwLastTick == 0 || (dwTick - dwLastTick >= 100)) // 每秒最多发10次消息
{
// 发送当前下载进度和剩余时间消息
dwLastTick = dwTick;
::PostMessage(task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,
static_cast<WPARAM>(task.Percentage()), static_cast<LPARAM>(task.RemainTimeSec(dwTick - dwTickStart, task.m_uReadBytes - uReadBytesStart))
);
}
}
}
DWORD dwTick = ::GetTickCount();
if (NULL != task.m_hWnd)
{
::PostMessage(
task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,
static_cast<WPARAM>(task.Percentage()), -1
);
}
return web::THE_SUCCEED;
}
/**
 * Resolves the task's host and connects hSocket to it on task.m_nPort, then
 * sets the receive timeout (task.m_nTimeoutSec seconds) and enables SO_KEEPALIVE.
 *
 * @return web::THE_SUCCEED, or the error code of the step that failed.
 */
DWORD CSocketDownloader::ConnectServer(const CDownloadTask &task, SOCKET hSocket)
{
    // m_strHostA may be "host:port" (the port stays there for the Host header);
    // the resolver needs the bare host name, otherwise the lookup fails.
    CStringA strHostNameA = task.m_strHostA;
    const int nColon = strHostNameA.Find(':');
    if (nColon >= 0)
    {
        strHostNameA = strHostNameA.Left(nColon);
    }

    PHOSTENT pHostent = gethostbyname(strHostNameA);
    if (pHostent == NULL || pHostent->h_addrtype != AF_INET || pHostent->h_addr_list[0] == NULL)
    {
        return web::THE_GET_HOST_BY_NAME;
    }

    sockaddr_in addrSvr;
    memset(&addrSvr, 0, sizeof(addrSvr)); // clears sin_zero as well
    addrSvr.sin_family = AF_INET;
    addrSvr.sin_port = htons((u_short)task.m_nPort);
    // h_addr_list holds char pointers; copy the bytes instead of type-punning through ULONG*.
    memcpy(&addrSvr.sin_addr, pHostent->h_addr_list[0], sizeof(addrSvr.sin_addr));
    if (SOCKET_ERROR == connect(hSocket, (sockaddr*)&addrSvr, sizeof(addrSvr)))
    {
        return web::THE_CONNECT_SOCKET;
    }

    int nTimeoutMs = task.m_nTimeoutSec * 1000;
    if (0 != setsockopt(hSocket, SOL_SOCKET, SO_RCVTIMEO, (const char *)&nTimeoutMs, sizeof(nTimeoutMs)))
    {
        return web::THE_SET_SOCK_OPT1;
    }

    BOOL bKeepAlive = TRUE;
    if (0 != setsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (const char *)&bKeepAlive, sizeof(bKeepAlive)))
    {
        return web::THE_SET_SOCK_OPT2;
    }
    return web::THE_SUCCEED;
}
/**
 * Back-off delay used before the next retry: (nTryCount + 1) seconds.
 * NOTE: despite the name, the value is returned in milliseconds; the caller
 * passes it directly to ::Sleep.
 */
int CSocketDownloader::GetSleepSecCount( int nTryCount ) const
{
    const int nMillisPerSecond = 1000;
    const int nSeconds = nTryCount + 1;
    return nSeconds * nMillisPerSecond;
}
int CSocketDownloader::GetBufferSize( const CDownloadTask &task ) const
{
return std::min<int>(BLOCK_SIZE, task.m_uTotalBytes - task.m_uReadBytes);
}
/**
 * Builds the HTTP/1.1 GET request for task. It contains the request line, a
 * Host header, a Range header that resumes at task.m_uReadBytes, and the
 * User-Agent, Accept, Accept-Encoding and Connection headers. The request is
 * terminated by an empty line.
 */
CStringA CSocketDownloader::GenerateRequest( CDownloadTask &task ) const
{
    // An empty path would give the invalid request line "GET  HTTP/1.1".
    const char *pszPath = task.m_strAbsoluteUrlA.IsEmpty() ? "/" : task.m_strAbsoluteUrlA.GetString();

    CStringA strRequest;
    CStringA strTemp;
    if (task.m_strQueryA.IsEmpty())
    {
        strTemp.Format(
            "GET %s HTTP/1.1\r\nHOST: %s\r\n",
            pszPath, task.m_strHostA.GetString()
            );
    }
    else
    {
        strTemp.Format(
            "GET %s?%s HTTP/1.1\r\nHOST: %s\r\n",
            pszPath, task.m_strQueryA.GetString(), task.m_strHostA.GetString()
            );
    }
    strRequest.Append(strTemp);
    // m_uReadBytes is unsigned; %d would print offsets >= 2 GiB as negative numbers.
    strTemp.Format("Range: bytes=%u-\r\n", task.m_uReadBytes);
    strRequest.Append(strTemp);
    strTemp.Format("User-Agent: %s\r\n", task.GetAgnetA().GetString());
    strRequest.Append(strTemp);
    strRequest.Append("Accept: */*\r\n");
    // The body is stored byte-for-byte and never decompressed. Advertising
    // "gzip, deflate" would let the server send a compressed entity, whose
    // Content-Length and byte ranges refer to the compressed bytes.
    strRequest.Append("Accept-Encoding: identity\r\n");
    strRequest.Append("Connection: Keep-Alive\r\n\r\n");
    return strRequest;
}