diff --git a/src/frame/video_frame.cpp b/src/frame/video_frame.cpp index dc872dc..eb0426e 100644 --- a/src/frame/video_frame.cpp +++ b/src/frame/video_frame.cpp @@ -12,7 +12,7 @@ VideoFrame::VideoFrame(size_t size) { height_ = 0; } -VideoFrame::VideoFrame(size_t size, size_t width, size_t height) { +VideoFrame::VideoFrame(size_t size, uint32_t width, uint32_t height) { buffer_ = new uint8_t[size]; size_ = size; width_ = width; @@ -27,8 +27,8 @@ VideoFrame::VideoFrame(const uint8_t *buffer, size_t size) { height_ = 0; } -VideoFrame::VideoFrame(const uint8_t *buffer, size_t size, size_t width, - size_t height) { +VideoFrame::VideoFrame(const uint8_t *buffer, size_t size, uint32_t width, + uint32_t height) { buffer_ = new uint8_t[size]; memcpy(buffer_, buffer, size); size_ = size; diff --git a/src/frame/video_frame.h b/src/frame/video_frame.h index f57d01c..01c4c03 100644 --- a/src/frame/video_frame.h +++ b/src/frame/video_frame.h @@ -14,9 +14,10 @@ class VideoFrame { public: VideoFrame(); VideoFrame(size_t size); - VideoFrame(size_t size, size_t width, size_t height); + VideoFrame(size_t size, uint32_t width, uint32_t height); VideoFrame(const uint8_t *buffer, size_t size); - VideoFrame(const uint8_t *buffer, size_t size, size_t width, size_t height); + VideoFrame(const uint8_t *buffer, size_t size, uint32_t width, + uint32_t height); VideoFrame(const VideoFrame &video_frame); VideoFrame(VideoFrame &&video_frame); VideoFrame &operator=(const VideoFrame &video_frame); @@ -27,18 +28,18 @@ class VideoFrame { public: const uint8_t *Buffer() { return buffer_; } size_t Size() { return size_; } - size_t Width() { return width_; } - size_t Height() { return height_; } + uint32_t Width() { return width_; } + uint32_t Height() { return height_; } void SetSize(size_t size) { size_ = size; } - void SetWidth(size_t width) { width_ = width; } - void SetHeight(size_t height) { height_ = height; } + void SetWidth(uint32_t width) { width_ = width; } + void SetHeight(uint32_t height) { height_ = height; } private: uint8_t *buffer_ = nullptr; size_t size_ = 0; - size_t width_ = 0; - size_t height_ = 0; + uint32_t width_ = 0; + uint32_t height_ = 0; }; #endif \ No newline at end of file diff --git a/src/ice/ice_agent.cpp b/src/ice/ice_agent.cpp index 75a56c4..e2a24a7 100644 --- a/src/ice/ice_agent.cpp +++ b/src/ice/ice_agent.cpp @@ -190,11 +190,10 @@ int IceAgent::CreateIceAgent(nice_cb_state_changed_t on_state_changed, return 0; } -void cb_closed(GObject *src, GAsyncResult *res, gpointer data) { - NiceAgent *agent = NICE_AGENT(src); - g_debug("test-turn:%s: %p", G_STRFUNC, agent); - - *((gboolean *)data) = TRUE; +void cb_closed(GObject *src, [[maybe_unused]] GAsyncResult *res, + [[maybe_unused]] gpointer data) { + [[maybe_unused]] NiceAgent *agent = NICE_AGENT(src); + LOG_INFO("Nice agent closed"); } int IceAgent::DestroyIceAgent() { @@ -384,11 +383,11 @@ int IceAgent::Send(const char *data, size_t size) { // return -1; // } - int ret = nice_agent_send(agent_, stream_id_, 1, size, data); + bool ret = nice_agent_send(agent_, stream_id_, 1, (guint)size, data); #ifdef SAVE_IO_STREAM fwrite(data, 1, size, file_out_); #endif - return 0; + return ret ? 0 : -1; } \ No newline at end of file diff --git a/src/log/log.cpp b/src/log/log.cpp index 363e650..d690207 100644 --- a/src/log/log.cpp +++ b/src/log/log.cpp @@ -6,12 +6,20 @@ std::shared_ptr get_logger() { } auto now = std::chrono::system_clock::now() + std::chrono::hours(8); - auto timet = std::chrono::system_clock::to_time_t(now); - auto localTime = *std::gmtime(&timet); + auto now_time = std::chrono::system_clock::to_time_t(now); + + std::tm tm_info; + +#ifdef _WIN32 + gmtime_s(&tm_info, &now_time); +#else + std::gmtime_r(&now_time, &tm_info); +#endif + std::stringstream ss; std::string filename; ss << LOGGER_NAME; - ss << std::put_time(&localTime, "-%Y%m%d-%H%M%S.log"); + ss << std::put_time(&tm_info, "-%Y%m%d-%H%M%S.log"); ss >> filename; std::string path = "logs/" + filename; diff --git a/src/media/audio/decode/audio_decoder.cpp b/src/media/audio/decode/audio_decoder.cpp index 767f25c..eb281a7 100644 --- a/src/media/audio/decode/audio_decoder.cpp +++ b/src/media/audio/decode/audio_decoder.cpp @@ -36,11 +36,11 @@ int AudioDecoder::Init() { } int AudioDecoder::Decode( - const uint8_t* data, int size, + const uint8_t* data, size_t size, std::function on_receive_decoded_frame) { // LOG_ERROR("input opus size = {}", size); - auto frame_size = - opus_decode(opus_decoder_, data, size, out_data, MAX_FRAME_SIZE, 0); + auto frame_size = opus_decode(opus_decoder_, data, (opus_int32)size, out_data, + MAX_FRAME_SIZE, 0); if (frame_size < 0) { LOG_ERROR("Decode opus frame failed"); diff --git a/src/media/audio/decode/audio_decoder.h b/src/media/audio/decode/audio_decoder.h index 9f6e21d..4cd694f 100644 --- a/src/media/audio/decode/audio_decoder.h +++ b/src/media/audio/decode/audio_decoder.h @@ -26,7 +26,7 @@ class AudioDecoder { public: int Init(); - int Decode(const uint8_t *data, int size, + int Decode(const uint8_t *data, size_t size, std::function on_receive_decoded_frame); std::string GetDecoderName() { return "Opus"; } diff --git a/src/media/audio/encode/audio_encoder.cpp b/src/media/audio/encode/audio_encoder.cpp index 162af6c..6489990 100644 --- a/src/media/audio/encode/audio_encoder.cpp +++ b/src/media/audio/encode/audio_encoder.cpp @@ -52,7 +52,7 @@ int AudioEncoder::Init() { } int AudioEncoder::Encode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_encoded_audio_buffer) { if (!on_encoded_audio_buffer_) { @@ -67,7 +67,7 @@ int AudioEncoder::Encode( // printf("1 Time cost: %d size: %d\n", now_ts - last_ts, size); // last_ts = now_ts; - auto ret = opus_encode(opus_encoder_, (opus_int16 *)data, size, out_data, + auto ret = opus_encode(opus_encoder_, (opus_int16 *)data, (int)size, out_data, MAX_PACKET_SIZE); if (ret < 0) { printf("opus decode failed, %d\n", ret); @@ -76,15 +76,7 @@ int AudioEncoder::Encode( if (on_encoded_audio_buffer_) { on_encoded_audio_buffer_((char *)out_data, ret); - } else { - OnEncodedAudioBuffer((char *)out_data, ret); } return 0; } - -int AudioEncoder::OnEncodedAudioBuffer(char *encoded_audio_buffer, - size_t size) { - LOG_INFO("OnEncodedAudioBuffer not implemented"); - return 0; -} diff --git a/src/media/audio/encode/audio_encoder.h b/src/media/audio/encode/audio_encoder.h index 718074b..2d0d589 100644 --- a/src/media/audio/encode/audio_encoder.h +++ b/src/media/audio/encode/audio_encoder.h @@ -23,12 +23,10 @@ class AudioEncoder { public: int Init(); - int Encode(const uint8_t* data, int size, + int Encode(const uint8_t* data, size_t size, std::function on_encoded_audio_buffer); - int OnEncodedAudioBuffer(char* encoded_audio_buffer, size_t size); - std::string GetEncoderName() { return "Opus"; } private: diff --git a/src/media/nvcodec/Logger.h b/src/media/nvcodec/Logger.h index c6ae4c4..4d7b227 100644 --- a/src/media/nvcodec/Logger.h +++ b/src/media/nvcodec/Logger.h @@ -27,230 +27,232 @@ #pragma once -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + #ifdef _WIN32 -#include #include +#include #pragma comment(lib, "ws2_32.lib") #undef ERROR #else -#include -#include -#include #include +#include +#include +#include + #define SOCKET int #define INVALID_SOCKET -1 #endif -enum LogLevel { - TRACE, - INFO, - WARNING, - ERROR, - FATAL -}; +enum LogLevel { TRACE, INFO, WARNING, ERROR, FATAL }; -namespace simplelogger{ +namespace simplelogger { class Logger { -public: - Logger(LogLevel level, bool bPrintTimeStamp) : level(level), bPrintTimeStamp(bPrintTimeStamp) {} - virtual ~Logger() {} - virtual std::ostream& GetStream() = 0; - virtual void FlushStream() {} - bool ShouldLogFor(LogLevel l) { - return l >= level; + public: + Logger(LogLevel level, bool bPrintTimeStamp) + : level(level), bPrintTimeStamp(bPrintTimeStamp) {} + virtual ~Logger() {} + virtual std::ostream &GetStream() = 0; + virtual void FlushStream() {} + bool ShouldLogFor(LogLevel l) { return l >= level; } + char *GetLead(LogLevel l, [[maybe_unused]] const char *szFile, + [[maybe_unused]] int nLine, + [[maybe_unused]] const char *szFunc) { + if (l < TRACE || l > FATAL) { + sprintf(szLead, "[?????] "); + return szLead; } - char* GetLead(LogLevel l, const char *szFile, int nLine, const char *szFunc) { - if (l < TRACE || l > FATAL) { - sprintf(szLead, "[?????] "); - return szLead; - } - const char *szLevels[] = {"TRACE", "INFO", "WARN", "ERROR", "FATAL"}; - if (bPrintTimeStamp) { - time_t t = time(NULL); - struct tm *ptm = localtime(&t); - sprintf(szLead, "[%-5s][%02d:%02d:%02d] ", - szLevels[l], ptm->tm_hour, ptm->tm_min, ptm->tm_sec); - } else { - sprintf(szLead, "[%-5s] ", szLevels[l]); - } - return szLead; + + const char *szLevels[] = {"TRACE", "INFO", "WARN", "ERROR", "FATAL"}; + if (bPrintTimeStamp) { + time_t t = time(NULL); + struct tm *ptm = localtime(&t); + sprintf(szLead, "[%-5s][%02d:%02d:%02d] ", szLevels[l], ptm->tm_hour, + ptm->tm_min, ptm->tm_sec); + } else { + sprintf(szLead, "[%-5s] ", szLevels[l]); } - void EnterCriticalSection() { - mtx.lock(); - } - void LeaveCriticalSection() { - mtx.unlock(); - } -private: - LogLevel level; - char szLead[80]; - bool bPrintTimeStamp; - std::mutex mtx; + return szLead; + } + void EnterCriticalSection() { mtx.lock(); } + void LeaveCriticalSection() { mtx.unlock(); } + + private: + LogLevel level; + char szLead[80]; + bool bPrintTimeStamp; + std::mutex mtx; }; class LoggerFactory { -public: - static Logger* CreateFileLogger(std::string strFilePath, - LogLevel level = INFO, bool bPrintTimeStamp = true) { - return new FileLogger(strFilePath, level, bPrintTimeStamp); - } - static Logger* CreateConsoleLogger(LogLevel level = INFO, - bool bPrintTimeStamp = true) { - return new ConsoleLogger(level, bPrintTimeStamp); - } - static Logger* CreateUdpLogger(char *szHost, unsigned uPort, LogLevel level = INFO, - bool bPrintTimeStamp = true) { - return new UdpLogger(szHost, uPort, level, bPrintTimeStamp); - } -private: - LoggerFactory() {} + public: + static Logger *CreateFileLogger(std::string strFilePath, + LogLevel level = INFO, + bool bPrintTimeStamp = true) { + return new FileLogger(strFilePath, level, bPrintTimeStamp); + } + static Logger *CreateConsoleLogger(LogLevel level = INFO, + bool bPrintTimeStamp = true) { + return new ConsoleLogger(level, bPrintTimeStamp); + } + static Logger *CreateUdpLogger(char *szHost, unsigned uPort, + LogLevel level = INFO, + bool bPrintTimeStamp = true) { + return new UdpLogger(szHost, uPort, level, bPrintTimeStamp); + } - class FileLogger : public Logger { - public: - FileLogger(std::string strFilePath, LogLevel level, bool bPrintTimeStamp) + private: + LoggerFactory() {} + + class FileLogger : public Logger { + public: + FileLogger(std::string strFilePath, LogLevel level, bool bPrintTimeStamp) : Logger(level, bPrintTimeStamp) { - pFileOut = new std::ofstream(); - pFileOut->open(strFilePath.c_str()); - } - ~FileLogger() { - pFileOut->close(); - } - std::ostream& GetStream() { - return *pFileOut; - } - private: - std::ofstream *pFileOut; - }; + pFileOut = new std::ofstream(); + pFileOut->open(strFilePath.c_str()); + } + ~FileLogger() { pFileOut->close(); } + std::ostream &GetStream() { return *pFileOut; } - class ConsoleLogger : public Logger { - public: - ConsoleLogger(LogLevel level, bool bPrintTimeStamp) + private: + std::ofstream *pFileOut; + }; + + class ConsoleLogger : public Logger { + public: + ConsoleLogger(LogLevel level, bool bPrintTimeStamp) : Logger(level, bPrintTimeStamp) {} - std::ostream& GetStream() { - return std::cout; + std::ostream &GetStream() { return std::cout; } + }; + + class UdpLogger : public Logger { + private: + class UdpOstream : public std::ostream { + public: + UdpOstream(char *szHost, unsigned short uPort) + : std::ostream(&sb), socket(INVALID_SOCKET) { +#ifdef _WIN32 + WSADATA w; + if (WSAStartup(0x0101, &w) != 0) { + fprintf(stderr, "WSAStartup() failed.\n"); + return; } +#endif + socket = ::socket(AF_INET, SOCK_DGRAM, 0); + if (socket == INVALID_SOCKET) { +#ifdef _WIN32 + WSACleanup(); +#endif + fprintf(stderr, "socket() failed.\n"); + return; + } +#ifdef _WIN32 + unsigned int b1, b2, b3, b4; + sscanf(szHost, "%u.%u.%u.%u", &b1, &b2, &b3, &b4); + struct in_addr addr = {(unsigned char)b1, (unsigned char)b2, + (unsigned char)b3, (unsigned char)b4}; +#else + struct in_addr addr = {inet_addr(szHost)}; +#endif + struct sockaddr_in s = {AF_INET, htons(uPort), addr}; + server = s; + } + ~UdpOstream() throw() { + if (socket == INVALID_SOCKET) { + return; + } +#ifdef _WIN32 + closesocket(socket); + WSACleanup(); +#else + close(socket); +#endif + } + void Flush() { + if (sendto(socket, sb.str().c_str(), (int)sb.str().length() + 1, 0, + (struct sockaddr *)&server, + (int)sizeof(sockaddr_in)) == -1) { + fprintf(stderr, "sendto() failed.\n"); + } + sb.str(""); + } + + private: + std::stringbuf sb; + SOCKET socket; + struct sockaddr_in server; }; - class UdpLogger : public Logger { - private: - class UdpOstream : public std::ostream { - public: - UdpOstream(char *szHost, unsigned short uPort) : std::ostream(&sb), socket(INVALID_SOCKET){ -#ifdef _WIN32 - WSADATA w; - if (WSAStartup(0x0101, &w) != 0) { - fprintf(stderr, "WSAStartup() failed.\n"); - return; - } -#endif - socket = ::socket(AF_INET, SOCK_DGRAM, 0); - if (socket == INVALID_SOCKET) { -#ifdef _WIN32 - WSACleanup(); -#endif - fprintf(stderr, "socket() failed.\n"); - return; - } -#ifdef _WIN32 - unsigned int b1, b2, b3, b4; - sscanf(szHost, "%u.%u.%u.%u", &b1, &b2, &b3, &b4); - struct in_addr addr = {(unsigned char)b1, (unsigned char)b2, (unsigned char)b3, (unsigned char)b4}; -#else - struct in_addr addr = {inet_addr(szHost)}; -#endif - struct sockaddr_in s = {AF_INET, htons(uPort), addr}; - server = s; - } - ~UdpOstream() throw() { - if (socket == INVALID_SOCKET) { - return; - } -#ifdef _WIN32 - closesocket(socket); - WSACleanup(); -#else - close(socket); -#endif - } - void Flush() { - if (sendto(socket, sb.str().c_str(), (int)sb.str().length() + 1, - 0, (struct sockaddr *)&server, (int)sizeof(sockaddr_in)) == -1) { - fprintf(stderr, "sendto() failed.\n"); - } - sb.str(""); - } + public: + UdpLogger(char *szHost, unsigned uPort, LogLevel level, + bool bPrintTimeStamp) + : Logger(level, bPrintTimeStamp), + udpOut(szHost, (unsigned short)uPort) {} + UdpOstream &GetStream() { return udpOut; } + virtual void FlushStream() { udpOut.Flush(); } - private: - std::stringbuf sb; - SOCKET socket; - struct sockaddr_in server; - }; - public: - UdpLogger(char *szHost, unsigned uPort, LogLevel level, bool bPrintTimeStamp) - : Logger(level, bPrintTimeStamp), udpOut(szHost, (unsigned short)uPort) {} - UdpOstream& GetStream() { - return udpOut; - } - virtual void FlushStream() { - udpOut.Flush(); - } - private: - UdpOstream udpOut; - }; + private: + UdpOstream udpOut; + }; }; class LogTransaction { -public: - LogTransaction(Logger *pLogger, LogLevel level, const char *szFile, const int nLine, const char *szFunc) : pLogger(pLogger), level(level) { - if (!pLogger) { - std::cout << "[-----] "; - return; - } - if (!pLogger->ShouldLogFor(level)) { - return; - } - pLogger->EnterCriticalSection(); - pLogger->GetStream() << pLogger->GetLead(level, szFile, nLine, szFunc); + public: + LogTransaction(Logger *pLogger, LogLevel level, const char *szFile, + const int nLine, const char *szFunc) + : pLogger(pLogger), level(level) { + if (!pLogger) { + std::cout << "[-----] "; + return; } - ~LogTransaction() { - if (!pLogger) { - std::cout << std::endl; - return; - } - if (!pLogger->ShouldLogFor(level)) { - return; - } - pLogger->GetStream() << std::endl; - pLogger->FlushStream(); - pLogger->LeaveCriticalSection(); - if (level == FATAL) { - exit(1); - } + if (!pLogger->ShouldLogFor(level)) { + return; } - std::ostream& GetStream() { - if (!pLogger) { - return std::cout; - } - if (!pLogger->ShouldLogFor(level)) { - return ossNull; - } - return pLogger->GetStream(); + pLogger->EnterCriticalSection(); + pLogger->GetStream() << pLogger->GetLead(level, szFile, nLine, szFunc); + } + ~LogTransaction() { + if (!pLogger) { + std::cout << std::endl; + return; } -private: - Logger *pLogger; - LogLevel level; - std::ostringstream ossNull; + if (!pLogger->ShouldLogFor(level)) { + return; + } + pLogger->GetStream() << std::endl; + pLogger->FlushStream(); + pLogger->LeaveCriticalSection(); + if (level == FATAL) { + exit(1); + } + } + std::ostream &GetStream() { + if (!pLogger) { + return std::cout; + } + if (!pLogger->ShouldLogFor(level)) { + return ossNull; + } + return pLogger->GetStream(); + } + + private: + Logger *pLogger; + LogLevel level; + std::ostringstream ossNull; }; -} +} // namespace simplelogger extern simplelogger::Logger *logger; -#define LOG(level) simplelogger::LogTransaction(logger, level, __FILE__, __LINE__, __FUNCTION__).GetStream() +#define LOG(level) \ + simplelogger::LogTransaction(logger, level, __FILE__, __LINE__, \ + __FUNCTION__) \ + .GetStream() diff --git a/src/media/nvcodec/NvDecoder.cpp b/src/media/nvcodec/NvDecoder.cpp index e1aa575..5c0109d 100644 --- a/src/media/nvcodec/NvDecoder.cpp +++ b/src/media/nvcodec/NvDecoder.cpp @@ -24,6 +24,8 @@ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR * OTHER DEALINGS IN THE SOFTWARE. */ +#pragma warning(push) +#pragma warning(disable : 4244) #include "NvDecoder.h" @@ -222,7 +224,7 @@ int NvDecoder::HandleVideoSequence(CUVIDEOFORMAT *pVideoFormat) { if (!decodecaps.bIsSupported) { NVDEC_THROW_ERROR("Codec not supported on this GPU", CUDA_ERROR_NOT_SUPPORTED); - return nDecodeSurface; + // return nDecodeSurface; } if ((pVideoFormat->coded_width > decodecaps.nMaxWidth) || @@ -237,7 +239,7 @@ int NvDecoder::HandleVideoSequence(CUVIDEOFORMAT *pVideoFormat) { const std::string cErr = errorString.str(); NVDEC_THROW_ERROR(cErr, CUDA_ERROR_NOT_SUPPORTED); - return nDecodeSurface; + // return nDecodeSurface; } if ((pVideoFormat->coded_width >> 4) * (pVideoFormat->coded_height >> 4) > @@ -254,7 +256,7 @@ int NvDecoder::HandleVideoSequence(CUVIDEOFORMAT *pVideoFormat) { const std::string cErr = errorString.str(); NVDEC_THROW_ERROR(cErr, CUDA_ERROR_NOT_SUPPORTED); - return nDecodeSurface; + // return nDecodeSurface; } if (m_nWidth && m_nLumaHeight && m_nChromaHeight) { @@ -571,7 +573,7 @@ int NvDecoder::setReconfigParams(const Rect *pCropRect, const Dim *pResizeDim) { int NvDecoder::HandlePictureDecode(CUVIDPICPARAMS *pPicParams) { if (!m_hDecoder) { NVDEC_THROW_ERROR("Decoder not initialized.", CUDA_ERROR_NOT_INITIALIZED); - return false; + // return false; } m_nPicNumInDecodeOrder[pPicParams->CurrPicIdx] = m_nDecodePicCnt++; CUDA_DRVAPI_CALL(cuCtxPushCurrent(m_cuContext)); @@ -921,3 +923,4 @@ void NvDecoder::UnlockFrame(uint8_t **pFrame) { uint64_t timestamp[2] = {0}; m_vTimestamp.insert(m_vTimestamp.end(), ×tamp[0], ×tamp[1]); } +#pragma warning(pop) \ No newline at end of file diff --git a/src/media/nvcodec/NvEncoder.cpp b/src/media/nvcodec/NvEncoder.cpp index 1074726..5217d9a 100644 --- a/src/media/nvcodec/NvEncoder.cpp +++ b/src/media/nvcodec/NvEncoder.cpp @@ -769,12 +769,11 @@ uint32_t NvEncoder::GetWidthInBytes(const NV_ENC_BUFFER_FORMAT bufferFormat, return width * 4; default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return 0; + // return 0; } } -uint32_t NvEncoder::GetNumChromaPlanes( - const NV_ENC_BUFFER_FORMAT bufferFormat) { +int32_t NvEncoder::GetNumChromaPlanes(const NV_ENC_BUFFER_FORMAT bufferFormat) { switch (bufferFormat) { case NV_ENC_BUFFER_FORMAT_NV12: case NV_ENC_BUFFER_FORMAT_YUV420_10BIT: @@ -792,12 +791,12 @@ uint32_t NvEncoder::GetNumChromaPlanes( return 0; default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return -1; + // return -1; } } -uint32_t NvEncoder::GetChromaPitch(const NV_ENC_BUFFER_FORMAT bufferFormat, - const uint32_t lumaPitch) { +int32_t NvEncoder::GetChromaPitch(const NV_ENC_BUFFER_FORMAT bufferFormat, + const uint32_t lumaPitch) { switch (bufferFormat) { case NV_ENC_BUFFER_FORMAT_NV12: case NV_ENC_BUFFER_FORMAT_YUV420_10BIT: @@ -815,7 +814,7 @@ uint32_t NvEncoder::GetChromaPitch(const NV_ENC_BUFFER_FORMAT bufferFormat, return 0; default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return -1; + // return -1; } } @@ -871,7 +870,7 @@ uint32_t NvEncoder::GetChromaHeight(const NV_ENC_BUFFER_FORMAT bufferFormat, return 0; default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return 0; + // return 0; } } @@ -897,7 +896,7 @@ uint32_t NvEncoder::GetChromaWidthInBytes( return 0; default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return 0; + // return 0; } } @@ -934,7 +933,7 @@ int NvEncoder::GetFrameSize() const { return 4 * GetEncodeWidth() * GetEncodeHeight(); default: NVENC_THROW_ERROR("Invalid Buffer format", NV_ENC_ERR_INVALID_PARAM); - return 0; + // return 0; } } diff --git a/src/media/nvcodec/NvEncoder.h b/src/media/nvcodec/NvEncoder.h index ed847a1..a81b0ab 100644 --- a/src/media/nvcodec/NvEncoder.h +++ b/src/media/nvcodec/NvEncoder.h @@ -317,14 +317,14 @@ class NvEncoder { * @brief This a static function to get the chroma plane pitch for YUV planar * formats. */ - static uint32_t GetChromaPitch(const NV_ENC_BUFFER_FORMAT bufferFormat, - const uint32_t lumaPitch); + static int32_t GetChromaPitch(const NV_ENC_BUFFER_FORMAT bufferFormat, + const uint32_t lumaPitch); /** * @brief This a static function to get the number of chroma planes for YUV * planar formats. */ - static uint32_t GetNumChromaPlanes(const NV_ENC_BUFFER_FORMAT bufferFormat); + static int32_t GetNumChromaPlanes(const NV_ENC_BUFFER_FORMAT bufferFormat); /** * @brief This a static function to get the chroma plane width in bytes for diff --git a/src/media/nvcodec/NvEncoderCLIOptions.h b/src/media/nvcodec/NvEncoderCLIOptions.h index de6518c..12a10f1 100644 --- a/src/media/nvcodec/NvEncoderCLIOptions.h +++ b/src/media/nvcodec/NvEncoderCLIOptions.h @@ -26,14 +26,15 @@ */ #pragma once -#include -#include #include -#include -#include -#include #include #include +#include +#include +#include +#include +#include + #include "Logger.h" #include "nvEncodeAPI.h" @@ -41,11 +42,11 @@ extern simplelogger::Logger *logger; #ifndef _WIN32 inline bool operator==(const GUID &guid1, const GUID &guid2) { - return !memcmp(&guid1, &guid2, sizeof(GUID)); + return !memcmp(&guid1, &guid2, sizeof(GUID)); } inline bool operator!=(const GUID &guid1, const GUID &guid2) { - return !(guid1 == guid2); + return !(guid1 == guid2); } #endif @@ -56,785 +57,1140 @@ inline bool operator!=(const GUID &guid1, const GUID &guid2) { * initialization parameters. */ class NvEncoderInitParam { -public: - NvEncoderInitParam(const char *szParam = "", - std::function *pfuncInit = NULL, bool _bLowLatency = false) - : strParam(szParam), bLowLatency(_bLowLatency) - { - if (pfuncInit) { - funcInit = *pfuncInit; - } - - std::transform(strParam.begin(), strParam.end(), strParam.begin(), tolower); - std::istringstream ss(strParam); - tokens = std::vector { - std::istream_iterator(ss), - std::istream_iterator() - }; - - for (unsigned i = 0; i < tokens.size(); i++) - { - if (tokens[i] == "-codec" && ++i != tokens.size()) - { - ParseString("-codec", tokens[i], vCodec, szCodecNames, &guidCodec); - continue; - } - if (tokens[i] == "-preset" && ++i != tokens.size()) { - ParseString("-preset", tokens[i], vPreset, szPresetNames, &guidPreset); - continue; - } - if (tokens[i] == "-tuninginfo" && ++i != tokens.size()) - { - ParseString("-tuninginfo", tokens[i], vTuningInfo, szTuningInfoNames, &m_TuningInfo); - continue; - } - } - } - virtual ~NvEncoderInitParam() {} - virtual bool IsCodecH264() { - return GetEncodeGUID() == NV_ENC_CODEC_H264_GUID; + public: + NvEncoderInitParam( + const char *szParam = "", + std::function *pfuncInit = NULL, + bool _bLowLatency = false) + : strParam(szParam), bLowLatency(_bLowLatency) { + if (pfuncInit) { + funcInit = *pfuncInit; } - virtual bool IsCodecHEVC() { - return GetEncodeGUID() == NV_ENC_CODEC_HEVC_GUID; + std::transform(strParam.begin(), strParam.end(), strParam.begin(), tolower); + std::istringstream ss(strParam); + tokens = std::vector{std::istream_iterator(ss), + std::istream_iterator()}; + + for (unsigned i = 0; i < tokens.size(); i++) { + if (tokens[i] == "-codec" && ++i != tokens.size()) { + ParseString("-codec", tokens[i], vCodec, szCodecNames, &guidCodec); + continue; + } + if (tokens[i] == "-preset" && ++i != tokens.size()) { + ParseString("-preset", tokens[i], vPreset, szPresetNames, &guidPreset); + continue; + } + if (tokens[i] == "-tuninginfo" && ++i != tokens.size()) { + ParseString("-tuninginfo", tokens[i], vTuningInfo, szTuningInfoNames, + &m_TuningInfo); + continue; + } + } + } + virtual ~NvEncoderInitParam() {} + virtual bool IsCodecH264() { + return GetEncodeGUID() == NV_ENC_CODEC_H264_GUID; + } + + virtual bool IsCodecHEVC() { + return GetEncodeGUID() == NV_ENC_CODEC_HEVC_GUID; + } + + virtual bool IsCodecAV1() { return GetEncodeGUID() == NV_ENC_CODEC_AV1_GUID; } + + std::string GetHelpMessage(bool bMeOnly = false, bool bUnbuffered = false, + bool bHide444 = false, + bool bOutputInVidMem = false) { + std::ostringstream oss; + + if (bOutputInVidMem && bMeOnly) { + oss << "-codec Codec: " + << "h264" << std::endl; + } else { + oss << "-codec Codec: " << szCodecNames << std::endl; } - virtual bool IsCodecAV1() { - return GetEncodeGUID() == NV_ENC_CODEC_AV1_GUID; + oss << "-preset Preset: " << szPresetNames << std::endl + << "-profile H264: " << szH264ProfileNames; + + if (bOutputInVidMem && bMeOnly) { + oss << std::endl; + } else { + oss << "; HEVC: " << szHevcProfileNames; + oss << "; AV1: " << szAV1ProfileNames << std::endl; } - std::string GetHelpMessage(bool bMeOnly = false, bool bUnbuffered = false, bool bHide444 = false, bool bOutputInVidMem = false) - { - std::ostringstream oss; - - if (bOutputInVidMem && bMeOnly) - { - oss << "-codec Codec: " << "h264" << std::endl; - } - else - { - oss << "-codec Codec: " << szCodecNames << std::endl; - } - - oss << "-preset Preset: " << szPresetNames << std::endl - << "-profile H264: " << szH264ProfileNames; - - if (bOutputInVidMem && bMeOnly) - { - oss << std::endl; - } - else - { - oss << "; HEVC: " << szHevcProfileNames; - oss << "; AV1: " << szAV1ProfileNames << std::endl; - } - - if (!bMeOnly) - { - if (bLowLatency == false) - oss << "-tuninginfo TuningInfo: " << szTuningInfoNames << std::endl; - else - oss << "-tuninginfo TuningInfo: " << szLowLatencyTuningInfoNames << std::endl; - oss << "-multipass Multipass: " << szMultipass << std::endl; - } - - if (!bHide444 && !bLowLatency) - { - oss << "-444 (Only for RGB input) YUV444 encode. Not valid for AV1 Codec" << std::endl; - } - if (bMeOnly) return oss.str(); - oss << "-fps Frame rate" << std::endl; - - if (!bUnbuffered && !bLowLatency) - { - oss << "-bf Number of consecutive B-frames" << std::endl; - } - - if (!bLowLatency) - { - oss << "-rc Rate control mode: " << szRcModeNames << std::endl - << "-gop Length of GOP (Group of Pictures)" << std::endl - << "-bitrate Average bit rate, can be in unit of 1, K, M" << std::endl - << "Note: Fps or Average bit rate values for each session can be specified in the form of v1,v1,v3 (no space) for AppTransOneToN" << std::endl - << " If the number of 'bitrate' or 'fps' values specified are less than the number of sessions, then the last specified value will be considered for the remaining sessions" << std::endl - << "-maxbitrate Max bit rate, can be in unit of 1, K, M" << std::endl - << "-vbvbufsize VBV buffer size in bits, can be in unit of 1, K, M" << std::endl - << "-vbvinit VBV initial delay in bits, can be in unit of 1, K, M" << std::endl - << "-aq Enable spatial AQ and set its stength (range 1-15, 0-auto)" << std::endl - << "-temporalaq (No value) Enable temporal AQ" << std::endl - << "-cq Target constant quality level for VBR mode (range 1-51, 0-auto)" << std::endl; - } - if (!bUnbuffered && !bLowLatency) - { - oss << "-lookahead Maximum depth of lookahead (range 0-(31 - number of B frames))" << std::endl; - } - oss << "-qmin Min QP value" << std::endl - << "-qmax Max QP value" << std::endl - << "-initqp Initial QP value" << std::endl; - if (!bLowLatency) - { - oss << "-constqp QP value for constqp rate control mode" << std::endl - << "Note: QP value can be in the form of qp_of_P_B_I or qp_P,qp_B,qp_I (no space)" << std::endl; - } - if (bUnbuffered && !bLowLatency) - { - oss << "Note: Options -bf and -lookahead are unavailable for this app" << std::endl; - } - return oss.str(); + if (!bMeOnly) { + if (bLowLatency == false) + oss << "-tuninginfo TuningInfo: " << szTuningInfoNames << std::endl; + else + oss << "-tuninginfo TuningInfo: " << szLowLatencyTuningInfoNames + << std::endl; + oss << "-multipass Multipass: " << szMultipass << std::endl; } - /** - * @brief Generate and return a string describing the values of the main/common - * encoder initialization parameters - */ - std::string MainParamToString(const NV_ENC_INITIALIZE_PARAMS *pParams) { - std::ostringstream os; - os - << "Encoding Parameters:" - << std::endl << "\tcodec : " << ConvertValueToString(vCodec, szCodecNames, pParams->encodeGUID) - << std::endl << "\tpreset : " << ConvertValueToString(vPreset, szPresetNames, pParams->presetGUID); - if (pParams->tuningInfo) - { - os << std::endl << "\ttuningInfo : " << ConvertValueToString(vTuningInfo, szTuningInfoNames, pParams->tuningInfo); - } - os - << std::endl << "\tprofile : " << ConvertValueToString(vProfile, szProfileNames, pParams->encodeConfig->profileGUID) - << std::endl << "\tchroma : " << ConvertValueToString(vChroma, szChromaNames, (pParams->encodeGUID == NV_ENC_CODEC_H264_GUID) ? pParams->encodeConfig->encodeCodecConfig.h264Config.chromaFormatIDC : - (pParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) ? pParams->encodeConfig->encodeCodecConfig.hevcConfig.chromaFormatIDC : - pParams->encodeConfig->encodeCodecConfig.av1Config.chromaFormatIDC) - << std::endl << "\tbitdepth : " << ((pParams->encodeGUID == NV_ENC_CODEC_H264_GUID) ? pParams->encodeConfig->encodeCodecConfig.h264Config.inputBitDepth : (pParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) ? - pParams->encodeConfig->encodeCodecConfig.hevcConfig.inputBitDepth : pParams->encodeConfig->encodeCodecConfig.av1Config.inputBitDepth) - << std::endl << "\trc : " << ConvertValueToString(vRcMode, szRcModeNames, pParams->encodeConfig->rcParams.rateControlMode) - ; - if (pParams->encodeConfig->rcParams.rateControlMode == NV_ENC_PARAMS_RC_CONSTQP) { - os << " (P,B,I=" << pParams->encodeConfig->rcParams.constQP.qpInterP << "," << pParams->encodeConfig->rcParams.constQP.qpInterB << "," << pParams->encodeConfig->rcParams.constQP.qpIntra << ")"; - } - os - << std::endl << "\tfps : " << pParams->frameRateNum << "/" << pParams->frameRateDen - << std::endl << "\tgop : " << (pParams->encodeConfig->gopLength == NVENC_INFINITE_GOPLENGTH ? "INF" : std::to_string(pParams->encodeConfig->gopLength)) - << std::endl << "\tbf : " << pParams->encodeConfig->frameIntervalP - 1 - << std::endl << "\tmultipass : " << pParams->encodeConfig->rcParams.multiPass - << std::endl << "\tsize : " << pParams->encodeWidth << "x" << pParams->encodeHeight - << std::endl << "\tbitrate : " << pParams->encodeConfig->rcParams.averageBitRate - << std::endl << "\tmaxbitrate : " << pParams->encodeConfig->rcParams.maxBitRate - << std::endl << "\tvbvbufsize : " << pParams->encodeConfig->rcParams.vbvBufferSize - << std::endl << "\tvbvinit : " << pParams->encodeConfig->rcParams.vbvInitialDelay - << std::endl << "\taq : " << (pParams->encodeConfig->rcParams.enableAQ ? (pParams->encodeConfig->rcParams.aqStrength ? std::to_string(pParams->encodeConfig->rcParams.aqStrength) : "auto") : "disabled") - << std::endl << "\ttemporalaq : " << (pParams->encodeConfig->rcParams.enableTemporalAQ ? "enabled" : "disabled") - << std::endl << "\tlookahead : " << (pParams->encodeConfig->rcParams.enableLookahead ? std::to_string(pParams->encodeConfig->rcParams.lookaheadDepth) : "disabled") - << std::endl << "\tcq : " << (unsigned int)pParams->encodeConfig->rcParams.targetQuality - << std::endl << "\tqmin : P,B,I=" << (int)pParams->encodeConfig->rcParams.minQP.qpInterP << "," << (int)pParams->encodeConfig->rcParams.minQP.qpInterB << "," << (int)pParams->encodeConfig->rcParams.minQP.qpIntra - << std::endl << "\tqmax : P,B,I=" << (int)pParams->encodeConfig->rcParams.maxQP.qpInterP << "," << (int)pParams->encodeConfig->rcParams.maxQP.qpInterB << "," << (int)pParams->encodeConfig->rcParams.maxQP.qpIntra - << std::endl << "\tinitqp : P,B,I=" << (int)pParams->encodeConfig->rcParams.initialRCQP.qpInterP << "," << (int)pParams->encodeConfig->rcParams.initialRCQP.qpInterB << "," << (int)pParams->encodeConfig->rcParams.initialRCQP.qpIntra - ; - return os.str(); + if (!bHide444 && !bLowLatency) { + oss << "-444 (Only for RGB input) YUV444 encode. Not valid for " + "AV1 Codec" + << std::endl; + } + if (bMeOnly) return oss.str(); + oss << "-fps Frame rate" << std::endl; + + if (!bUnbuffered && !bLowLatency) { + oss << "-bf Number of consecutive B-frames" << std::endl; } -public: - virtual GUID GetEncodeGUID() { return guidCodec; } - virtual GUID GetPresetGUID() { return guidPreset; } - virtual NV_ENC_TUNING_INFO GetTuningInfo() { return m_TuningInfo; } + if (!bLowLatency) { + oss << "-rc Rate control mode: " << szRcModeNames << std::endl + << "-gop Length of GOP (Group of Pictures)" << std::endl + << "-bitrate Average bit rate, can be in unit of 1, K, M" + << std::endl + << "Note: Fps or Average bit rate values for each session can " + "be specified in the form of v1,v1,v3 (no space) for " + "AppTransOneToN" + << std::endl + << " If the number of 'bitrate' or 'fps' values " + "specified are less than the number of sessions, then the last " + "specified value will be considered for the remaining sessions" + << std::endl + << "-maxbitrate Max bit rate, can be in unit of 1, K, M" << std::endl + << "-vbvbufsize VBV buffer size in bits, can be in unit of 1, K, M" + << std::endl + << "-vbvinit VBV initial delay in bits, can be in unit of 1, K, M" + << std::endl + << "-aq Enable spatial AQ and set its stength (range 1-15, " + "0-auto)" + << std::endl + << "-temporalaq (No value) Enable temporal AQ" << std::endl + << "-cq Target constant quality level for VBR mode (range " + "1-51, 0-auto)" + << std::endl; + } + if (!bUnbuffered && !bLowLatency) { + oss << "-lookahead Maximum depth of lookahead (range 0-(31 - number of " + "B frames))" + << std::endl; + } + oss << "-qmin Min QP value" << std::endl + << "-qmax Max QP value" << std::endl + << "-initqp Initial QP value" << std::endl; + if (!bLowLatency) { + oss << "-constqp QP value for constqp rate control mode" << std::endl + << "Note: QP value can be in the form of qp_of_P_B_I or " + "qp_P,qp_B,qp_I (no space)" + << std::endl; + } + if (bUnbuffered && !bLowLatency) { + oss << "Note: Options -bf and -lookahead are unavailable for this app" + << std::endl; + } + return oss.str(); + } - /* - * @brief Set encoder initialization parameters based on input options - * This method parses the tokens formed from the command line options - * provided to the application and sets the fields from NV_ENC_INITIALIZE_PARAMS - * based on the supplied values. - */ + /** + * @brief Generate and return a string describing the values of the + * main/common encoder initialization parameters + */ + std::string MainParamToString(const NV_ENC_INITIALIZE_PARAMS *pParams) { + std::ostringstream os; + os << "Encoding Parameters:" << std::endl + << "\tcodec : " + << ConvertValueToString(vCodec, szCodecNames, pParams->encodeGUID) + << std::endl + << "\tpreset : " + << ConvertValueToString(vPreset, szPresetNames, pParams->presetGUID); + if (pParams->tuningInfo) { + os << std::endl + << "\ttuningInfo : " + << ConvertValueToString(vTuningInfo, szTuningInfoNames, + pParams->tuningInfo); + } + os << std::endl + << "\tprofile : " + << ConvertValueToString(vProfile, szProfileNames, + pParams->encodeConfig->profileGUID) + << std::endl + << "\tchroma : " + << ConvertValueToString(vChroma, szChromaNames, + (pParams->encodeGUID == NV_ENC_CODEC_H264_GUID) + ? pParams->encodeConfig->encodeCodecConfig + .h264Config.chromaFormatIDC + : (pParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) + ? pParams->encodeConfig->encodeCodecConfig + .hevcConfig.chromaFormatIDC + : pParams->encodeConfig->encodeCodecConfig + .av1Config.chromaFormatIDC) + << std::endl + << "\tbitdepth : " + << ((pParams->encodeGUID == NV_ENC_CODEC_H264_GUID) + ? pParams->encodeConfig->encodeCodecConfig.h264Config + .inputBitDepth + : (pParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) + ? pParams->encodeConfig->encodeCodecConfig.hevcConfig + .inputBitDepth + : pParams->encodeConfig->encodeCodecConfig.av1Config + .inputBitDepth) + << std::endl + << "\trc : " + << ConvertValueToString(vRcMode, szRcModeNames, + pParams->encodeConfig->rcParams.rateControlMode); + if (pParams->encodeConfig->rcParams.rateControlMode == + NV_ENC_PARAMS_RC_CONSTQP) { + os << " (P,B,I=" << pParams->encodeConfig->rcParams.constQP.qpInterP + << "," << pParams->encodeConfig->rcParams.constQP.qpInterB << "," + << pParams->encodeConfig->rcParams.constQP.qpIntra << ")"; + } + os << std::endl + << "\tfps : " << pParams->frameRateNum << "/" + << pParams->frameRateDen << std::endl + << "\tgop : " + << (pParams->encodeConfig->gopLength == NVENC_INFINITE_GOPLENGTH + ? "INF" + : std::to_string(pParams->encodeConfig->gopLength)) + << std::endl + << "\tbf : " << pParams->encodeConfig->frameIntervalP - 1 + << std::endl + << "\tmultipass : " << pParams->encodeConfig->rcParams.multiPass + << std::endl + << "\tsize : " << pParams->encodeWidth << "x" + << pParams->encodeHeight << std::endl + << "\tbitrate : " << pParams->encodeConfig->rcParams.averageBitRate + << std::endl + << "\tmaxbitrate : " << pParams->encodeConfig->rcParams.maxBitRate + << std::endl + << "\tvbvbufsize : " << pParams->encodeConfig->rcParams.vbvBufferSize + << std::endl + << "\tvbvinit : " << pParams->encodeConfig->rcParams.vbvInitialDelay + << std::endl + << "\taq : " + << (pParams->encodeConfig->rcParams.enableAQ + ? (pParams->encodeConfig->rcParams.aqStrength + ? std::to_string( + pParams->encodeConfig->rcParams.aqStrength) + : "auto") + : "disabled") + << std::endl + << "\ttemporalaq : " + << (pParams->encodeConfig->rcParams.enableTemporalAQ ? "enabled" + : "disabled") + << std::endl + << "\tlookahead : " + << (pParams->encodeConfig->rcParams.enableLookahead + ? std::to_string(pParams->encodeConfig->rcParams.lookaheadDepth) + : "disabled") + << std::endl + << "\tcq : " + << (unsigned int)pParams->encodeConfig->rcParams.targetQuality + << std::endl + << "\tqmin : P,B,I=" + << (int)pParams->encodeConfig->rcParams.minQP.qpInterP << "," + << (int)pParams->encodeConfig->rcParams.minQP.qpInterB << "," + << (int)pParams->encodeConfig->rcParams.minQP.qpIntra << std::endl + << "\tqmax : P,B,I=" + << (int)pParams->encodeConfig->rcParams.maxQP.qpInterP << "," + << (int)pParams->encodeConfig->rcParams.maxQP.qpInterB << "," + << (int)pParams->encodeConfig->rcParams.maxQP.qpIntra << std::endl + << "\tinitqp : P,B,I=" + << (int)pParams->encodeConfig->rcParams.initialRCQP.qpInterP << "," + << (int)pParams->encodeConfig->rcParams.initialRCQP.qpInterB << "," + << (int)pParams->encodeConfig->rcParams.initialRCQP.qpIntra; + return os.str(); + } - virtual void setTransOneToN(bool isTransOneToN) - { - bTransOneToN = isTransOneToN; + public: + virtual GUID GetEncodeGUID() { return guidCodec; } + virtual GUID GetPresetGUID() { return guidPreset; } + virtual NV_ENC_TUNING_INFO GetTuningInfo() { return m_TuningInfo; } + + /* + * @brief Set encoder initialization parameters based on input options + * This method parses the tokens formed from the command line options + * provided to the application and sets the fields from + * NV_ENC_INITIALIZE_PARAMS based on the supplied values. + */ + + virtual void setTransOneToN(bool isTransOneToN) { + bTransOneToN = isTransOneToN; + } + + virtual void SetInitParams(NV_ENC_INITIALIZE_PARAMS *pParams, + NV_ENC_BUFFER_FORMAT eBufferFormat) { + NV_ENC_CONFIG &config = *pParams->encodeConfig; + int nGOPOption = 0, nBFramesOption = 0; + for (unsigned i = 0; i < tokens.size(); i++) { + if (tokens[i] == "-codec" && ++i || tokens[i] == "-preset" && ++i || + tokens[i] == "-tuninginfo" && ++i || + tokens[i] == "-multipass" && ++i != tokens.size() && + ParseString("-multipass", tokens[i], vMultiPass, szMultipass, + &config.rcParams.multiPass) || + tokens[i] == "-profile" && ++i != tokens.size() && + (IsCodecH264() + ? ParseString("-profile", tokens[i], vH264Profile, + szH264ProfileNames, &config.profileGUID) + : IsCodecHEVC() + ? ParseString("-profile", tokens[i], vHevcProfile, + szHevcProfileNames, &config.profileGUID) + : ParseString("-profile", tokens[i], vAV1Profile, + szAV1ProfileNames, &config.profileGUID)) || + tokens[i] == "-rc" && ++i != tokens.size() && + ParseString("-rc", tokens[i], vRcMode, szRcModeNames, + &config.rcParams.rateControlMode) || + tokens[i] == "-fps" && ++i != tokens.size() && + ParseInt("-fps", tokens[i], &pParams->frameRateNum) || + tokens[i] == "-bf" && ++i != tokens.size() && + ParseInt("-bf", tokens[i], &config.frameIntervalP) && + ++config.frameIntervalP && ++nBFramesOption || + tokens[i] == "-bitrate" && ++i != tokens.size() && + ParseBitRate("-bitrate", tokens[i], + &config.rcParams.averageBitRate) || + tokens[i] == "-maxbitrate" && ++i != tokens.size() && + ParseBitRate("-maxbitrate", tokens[i], + &config.rcParams.maxBitRate) || + tokens[i] == "-vbvbufsize" && ++i != tokens.size() && + ParseBitRate("-vbvbufsize", tokens[i], + &config.rcParams.vbvBufferSize) || + tokens[i] == "-vbvinit" && ++i != tokens.size() && + ParseBitRate("-vbvinit", tokens[i], + &config.rcParams.vbvInitialDelay) || + tokens[i] == "-cq" && ++i != tokens.size() && + ParseInt("-cq", tokens[i], &config.rcParams.targetQuality) || + tokens[i] == "-initqp" && ++i != tokens.size() && + ParseQp("-initqp", tokens[i], &config.rcParams.initialRCQP) && + (config.rcParams.enableInitialRCQP = true) || + tokens[i] == "-qmin" && ++i != tokens.size() && + ParseQp("-qmin", tokens[i], &config.rcParams.minQP) && + (config.rcParams.enableMinQP = true) || + tokens[i] == "-qmax" && ++i != tokens.size() && + ParseQp("-qmax", tokens[i], &config.rcParams.maxQP) && + (config.rcParams.enableMaxQP = true) || + tokens[i] == "-constqp" && ++i != tokens.size() && + ParseQp("-constqp", tokens[i], &config.rcParams.constQP) || + tokens[i] == "-temporalaq" && + (config.rcParams.enableTemporalAQ = true)) { + continue; + } + if (tokens[i] == "-lookahead" && ++i != tokens.size() && + ParseInt("-lookahead", tokens[i], &config.rcParams.lookaheadDepth)) { + config.rcParams.enableLookahead = config.rcParams.lookaheadDepth > 0; + continue; + } + int aqStrength = 0; + if (tokens[i] == "-aq" && ++i != tokens.size() && + ParseInt("-aq", tokens[i], &aqStrength)) { + config.rcParams.enableAQ = true; + config.rcParams.aqStrength = aqStrength; + continue; + } + + if (tokens[i] == "-gop" && ++i != tokens.size() && + ParseInt("-gop", tokens[i], &config.gopLength)) { + nGOPOption = 1; + if (IsCodecH264()) { + config.encodeCodecConfig.h264Config.idrPeriod = config.gopLength; + } else if (IsCodecHEVC()) { + config.encodeCodecConfig.hevcConfig.idrPeriod = config.gopLength; + } else { + config.encodeCodecConfig.av1Config.idrPeriod = config.gopLength; + } + continue; + } + + if (tokens[i] == "-444") { + if (IsCodecH264()) { + config.encodeCodecConfig.h264Config.chromaFormatIDC = 3; + } else if (IsCodecHEVC()) { + config.encodeCodecConfig.hevcConfig.chromaFormatIDC = 3; + } else { + std::ostringstream errmessage; + errmessage << "Incorrect Parameter: YUV444 Input not supported with " + "AV1 Codec" + << std::endl; + throw std::invalid_argument(errmessage.str()); + } + continue; + } + + std::ostringstream errmessage; + errmessage << "Incorrect parameter: " << tokens[i] << std::endl; + errmessage << "Re-run the application with the -h option to get a list " + "of the supported options."; + errmessage << std::endl; + + throw std::invalid_argument(errmessage.str()); } - virtual void SetInitParams(NV_ENC_INITIALIZE_PARAMS *pParams, NV_ENC_BUFFER_FORMAT eBufferFormat) - { - NV_ENC_CONFIG &config = *pParams->encodeConfig; - int nGOPOption = 0, nBFramesOption = 0; - for (unsigned i = 0; i < tokens.size(); i++) - { - if ( - tokens[i] == "-codec" && ++i || - tokens[i] == "-preset" && ++i || - tokens[i] == "-tuninginfo" && ++i || - tokens[i] == "-multipass" && ++i != tokens.size() && ParseString("-multipass", tokens[i], vMultiPass, szMultipass, &config.rcParams.multiPass) || - tokens[i] == "-profile" && ++i != tokens.size() && (IsCodecH264() ? - ParseString("-profile", tokens[i], vH264Profile, szH264ProfileNames, &config.profileGUID) : IsCodecHEVC() ? - ParseString("-profile", tokens[i], vHevcProfile, szHevcProfileNames, &config.profileGUID) : - ParseString("-profile", tokens[i], vAV1Profile, szAV1ProfileNames, &config.profileGUID)) || - tokens[i] == "-rc" && ++i != tokens.size() && ParseString("-rc", tokens[i], vRcMode, szRcModeNames, &config.rcParams.rateControlMode) || - tokens[i] == "-fps" && ++i != tokens.size() && ParseInt("-fps", tokens[i], &pParams->frameRateNum) || - tokens[i] == "-bf" && ++i != tokens.size() && ParseInt("-bf", tokens[i], &config.frameIntervalP) && ++config.frameIntervalP && ++nBFramesOption || - tokens[i] == "-bitrate" && ++i != tokens.size() && ParseBitRate("-bitrate", tokens[i], &config.rcParams.averageBitRate) || - tokens[i] == "-maxbitrate" && ++i != tokens.size() && ParseBitRate("-maxbitrate", tokens[i], &config.rcParams.maxBitRate) || - tokens[i] == "-vbvbufsize" && ++i != tokens.size() && ParseBitRate("-vbvbufsize", tokens[i], &config.rcParams.vbvBufferSize) || - tokens[i] == "-vbvinit" && ++i != tokens.size() && ParseBitRate("-vbvinit", tokens[i], &config.rcParams.vbvInitialDelay) || - tokens[i] == "-cq" && ++i != tokens.size() && ParseInt("-cq", tokens[i], &config.rcParams.targetQuality) || - tokens[i] == "-initqp" && ++i != tokens.size() && ParseQp("-initqp", tokens[i], &config.rcParams.initialRCQP) && (config.rcParams.enableInitialRCQP = true) || - tokens[i] == "-qmin" && ++i != tokens.size() && ParseQp("-qmin", tokens[i], &config.rcParams.minQP) && (config.rcParams.enableMinQP = true) || - tokens[i] == "-qmax" && ++i != tokens.size() && ParseQp("-qmax", tokens[i], &config.rcParams.maxQP) && (config.rcParams.enableMaxQP = true) || - tokens[i] == "-constqp" && ++i != tokens.size() && ParseQp("-constqp", tokens[i], &config.rcParams.constQP) || - tokens[i] == "-temporalaq" && (config.rcParams.enableTemporalAQ = true) - ) - { - continue; - } - if (tokens[i] == "-lookahead" && ++i != tokens.size() && ParseInt("-lookahead", tokens[i], &config.rcParams.lookaheadDepth)) - { - config.rcParams.enableLookahead = config.rcParams.lookaheadDepth > 0; - continue; - } - int aqStrength; - if (tokens[i] == "-aq" && ++i != tokens.size() && ParseInt("-aq", tokens[i], &aqStrength)) { - config.rcParams.enableAQ = true; - config.rcParams.aqStrength = aqStrength; - continue; - } - - if (tokens[i] == "-gop" && ++i != tokens.size() && ParseInt("-gop", tokens[i], &config.gopLength)) - { - nGOPOption = 1; - if (IsCodecH264()) - { - config.encodeCodecConfig.h264Config.idrPeriod = config.gopLength; - } - else if (IsCodecHEVC()) - { - config.encodeCodecConfig.hevcConfig.idrPeriod = config.gopLength; - } - else - { - config.encodeCodecConfig.av1Config.idrPeriod = config.gopLength; - } - continue; - } - - if (tokens[i] == "-444") - { - if (IsCodecH264()) - { - config.encodeCodecConfig.h264Config.chromaFormatIDC = 3; - } - else if (IsCodecHEVC()) - { - config.encodeCodecConfig.hevcConfig.chromaFormatIDC = 3; - } - else - { - std::ostringstream errmessage; - errmessage << "Incorrect Parameter: YUV444 Input not supported with AV1 Codec" << std::endl; - throw std::invalid_argument(errmessage.str()); - } - continue; - } - - std::ostringstream errmessage; - errmessage << "Incorrect parameter: " << tokens[i] << std::endl; - errmessage << "Re-run the application with the -h option to get a list of the supported options."; - errmessage << std::endl; - - throw std::invalid_argument(errmessage.str()); - } - - if (IsCodecHEVC()) - { - if (eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV420_10BIT || eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV444_10BIT) - { - config.encodeCodecConfig.hevcConfig.inputBitDepth = NV_ENC_BIT_DEPTH_10; - config.encodeCodecConfig.hevcConfig.outputBitDepth = NV_ENC_BIT_DEPTH_10; - } - } - - if (IsCodecAV1()) - { - if (eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV420_10BIT) - { - config.encodeCodecConfig.av1Config.inputBitDepth = NV_ENC_BIT_DEPTH_10; - config.encodeCodecConfig.av1Config.outputBitDepth = NV_ENC_BIT_DEPTH_10; - } - } - - if (nGOPOption && nBFramesOption && (config.gopLength < ((uint32_t)config.frameIntervalP))) - { - std::ostringstream errmessage; - errmessage << "gopLength (" << config.gopLength << ") must be greater or equal to frameIntervalP (number of B frames + 1) (" << config.frameIntervalP << ")\n"; - throw std::invalid_argument(errmessage.str()); - } - - funcInit(pParams); - LOG(INFO) << NvEncoderInitParam().MainParamToString(pParams); - LOG(TRACE) << NvEncoderInitParam().FullParamToString(pParams); + if (IsCodecHEVC()) { + if (eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV420_10BIT || + eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV444_10BIT) { + config.encodeCodecConfig.hevcConfig.inputBitDepth = NV_ENC_BIT_DEPTH_10; + config.encodeCodecConfig.hevcConfig.outputBitDepth = + NV_ENC_BIT_DEPTH_10; + } } -private: - /* - * Helper methods for parsing tokens (generated by splitting the command line) - * and performing conversions to the appropriate target type/value. - */ - template - bool ParseString(const std::string &strName, const std::string &strValue, const std::vector &vValue, const std::string &strValueNames, T *pValue) { - std::vector vstrValueName = split(strValueNames, ' '); - auto it = std::find(vstrValueName.begin(), vstrValueName.end(), strValue); - if (it == vstrValueName.end()) { - LOG(ERROR) << strName << " options: " << strValueNames; - return false; - } - *pValue = vValue[it - vstrValueName.begin()]; - return true; - } - template - std::string ConvertValueToString(const std::vector &vValue, const std::string &strValueNames, T value) { - auto it = std::find(vValue.begin(), vValue.end(), value); - if (it == vValue.end()) { - LOG(ERROR) << "Invalid value. Can't convert to one of " << strValueNames; - return std::string(); - } - return split(strValueNames, ' ')[it - vValue.begin()]; - } - bool ParseBitRate(const std::string &strName, const std::string &strValue, unsigned *pBitRate) { - if(bTransOneToN) - { - std::vector oneToNBitrate = split(strValue, ','); - std::string currBitrate; - if ((bitrateCnt + 1) > oneToNBitrate.size()) - { - currBitrate = oneToNBitrate[oneToNBitrate.size() - 1]; - } - else - { - currBitrate = oneToNBitrate[bitrateCnt]; - bitrateCnt++; - } - - try { - size_t l; - double r = std::stod(currBitrate, &l); - char c = currBitrate[l]; - if (c != 0 && c != 'k' && c != 'm') { - LOG(ERROR) << strName << " units: 1, K, M (lower case also allowed)"; - } - *pBitRate = (unsigned)((c == 'm' ? 1000000 : (c == 'k' ? 1000 : 1)) * r); - } - catch (std::invalid_argument) { - return false; - } - return true; - } - - else - { - try { - size_t l; - double r = std::stod(strValue, &l); - char c = strValue[l]; - if (c != 0 && c != 'k' && c != 'm') { - LOG(ERROR) << strName << " units: 1, K, M (lower case also allowed)"; - } - *pBitRate = (unsigned)((c == 'm' ? 1000000 : (c == 'k' ? 1000 : 1)) * r); - } - catch (std::invalid_argument) { - return false; - } - return true; - } - } - template - bool ParseInt(const std::string &strName, const std::string &strValue, T *pInt) { - if (bTransOneToN) - { - std::vector oneToNFps = split(strValue, ','); - std::string currFps; - if ((fpsCnt + 1) > oneToNFps.size()) - { - currFps = oneToNFps[oneToNFps.size() - 1]; - } - else - { - currFps = oneToNFps[fpsCnt]; - fpsCnt++; - } - - try { - *pInt = std::stoi(currFps); - } - catch (std::invalid_argument) { - LOG(ERROR) << strName << " need a value of positive number"; - return false; - } - return true; - } - else - { - try { - *pInt = std::stoi(strValue); - } - catch (std::invalid_argument) { - LOG(ERROR) << strName << " need a value of positive number"; - return false; - } - return true; - } - } - bool ParseQp(const std::string &strName, const std::string &strValue, NV_ENC_QP *pQp) { - std::vector vQp = split(strValue, ','); - try { - if (vQp.size() == 1) { - unsigned qp = (unsigned)std::stoi(vQp[0]); - *pQp = {qp, qp, qp}; - } else if (vQp.size() == 3) { - *pQp = {(unsigned)std::stoi(vQp[0]), (unsigned)std::stoi(vQp[1]), (unsigned)std::stoi(vQp[2])}; - } else { - LOG(ERROR) << strName << " qp_for_P_B_I or qp_P,qp_B,qp_I (no space is allowed)"; - return false; - } - } catch (std::invalid_argument) { - return false; - } - return true; - } - std::vector split(const std::string &s, char delim) { - std::stringstream ss(s); - std::string token; - std::vector tokens; - while (getline(ss, token, delim)) { - tokens.push_back(token); - } - return tokens; + if (IsCodecAV1()) { + if (eBufferFormat == NV_ENC_BUFFER_FORMAT_YUV420_10BIT) { + config.encodeCodecConfig.av1Config.inputBitDepth = NV_ENC_BIT_DEPTH_10; + config.encodeCodecConfig.av1Config.outputBitDepth = NV_ENC_BIT_DEPTH_10; + } } -private: - std::string strParam; - std::function funcInit = [](NV_ENC_INITIALIZE_PARAMS *pParams){}; - std::vector tokens; - GUID guidCodec = NV_ENC_CODEC_H264_GUID; - GUID guidPreset = NV_ENC_PRESET_P3_GUID; - NV_ENC_TUNING_INFO m_TuningInfo = NV_ENC_TUNING_INFO_HIGH_QUALITY; - bool bLowLatency = false; - uint32_t bitrateCnt = 0; - uint32_t fpsCnt = 0; - bool bTransOneToN = 0; - - const char *szCodecNames = "h264 hevc av1"; - std::vector vCodec = std::vector { - NV_ENC_CODEC_H264_GUID, - NV_ENC_CODEC_HEVC_GUID, - NV_ENC_CODEC_AV1_GUID - }; - - const char *szChromaNames = "yuv420 yuv444"; - std::vector vChroma = std::vector - { - 1, 3 - }; - - const char *szPresetNames = "p1 p2 p3 p4 p5 p6 p7"; - std::vector vPreset = std::vector { - NV_ENC_PRESET_P1_GUID, - NV_ENC_PRESET_P2_GUID, - NV_ENC_PRESET_P3_GUID, - NV_ENC_PRESET_P4_GUID, - NV_ENC_PRESET_P5_GUID, - NV_ENC_PRESET_P6_GUID, - NV_ENC_PRESET_P7_GUID, - }; - - const char *szH264ProfileNames = "baseline main high high444"; - std::vector vH264Profile = std::vector { - NV_ENC_H264_PROFILE_BASELINE_GUID, - NV_ENC_H264_PROFILE_MAIN_GUID, - NV_ENC_H264_PROFILE_HIGH_GUID, - NV_ENC_H264_PROFILE_HIGH_444_GUID, - }; - const char *szHevcProfileNames = "main main10 frext"; - std::vector vHevcProfile = std::vector { - NV_ENC_HEVC_PROFILE_MAIN_GUID, - NV_ENC_HEVC_PROFILE_MAIN10_GUID, - NV_ENC_HEVC_PROFILE_FREXT_GUID, - }; - const char *szAV1ProfileNames = "main"; - std::vector vAV1Profile = std::vector{ - NV_ENC_AV1_PROFILE_MAIN_GUID, - }; - - const char *szProfileNames = "(default) auto baseline(h264) main(h264) high(h264) high444(h264)" - " stereo(h264) progressiv_high(h264) constrained_high(h264)" - " main(hevc) main10(hevc) frext(hevc)" - " main(av1) high(av1)"; - std::vector vProfile = std::vector { - GUID{}, - NV_ENC_CODEC_PROFILE_AUTOSELECT_GUID, - NV_ENC_H264_PROFILE_BASELINE_GUID, - NV_ENC_H264_PROFILE_MAIN_GUID, - NV_ENC_H264_PROFILE_HIGH_GUID, - NV_ENC_H264_PROFILE_HIGH_444_GUID, - NV_ENC_H264_PROFILE_STEREO_GUID, - NV_ENC_H264_PROFILE_PROGRESSIVE_HIGH_GUID, - NV_ENC_H264_PROFILE_CONSTRAINED_HIGH_GUID, - NV_ENC_HEVC_PROFILE_MAIN_GUID, - NV_ENC_HEVC_PROFILE_MAIN10_GUID, - NV_ENC_HEVC_PROFILE_FREXT_GUID, - NV_ENC_AV1_PROFILE_MAIN_GUID, - }; - - const char *szLowLatencyTuningInfoNames = "lowlatency ultralowlatency"; - const char *szTuningInfoNames = "hq lowlatency ultralowlatency lossless uhq"; - std::vector vTuningInfo = std::vector{ - NV_ENC_TUNING_INFO_HIGH_QUALITY, - NV_ENC_TUNING_INFO_LOW_LATENCY, - NV_ENC_TUNING_INFO_ULTRA_LOW_LATENCY, - NV_ENC_TUNING_INFO_LOSSLESS, - NV_ENC_TUNING_INFO_ULTRA_HIGH_QUALITY - }; - - const char *szRcModeNames = "constqp vbr cbr"; - std::vector vRcMode = std::vector { - NV_ENC_PARAMS_RC_CONSTQP, - NV_ENC_PARAMS_RC_VBR, - NV_ENC_PARAMS_RC_CBR, - }; - - const char *szMultipass = "disabled qres fullres"; - std::vector vMultiPass = std::vector{ - NV_ENC_MULTI_PASS_DISABLED, - NV_ENC_TWO_PASS_QUARTER_RESOLUTION, - NV_ENC_TWO_PASS_FULL_RESOLUTION, - }; - - const char *szQpMapModeNames = "disabled emphasis_level_map delta_qp_map qp_map"; - std::vector vQpMapMode = std::vector { - NV_ENC_QP_MAP_DISABLED, - NV_ENC_QP_MAP_EMPHASIS, - NV_ENC_QP_MAP_DELTA, - NV_ENC_QP_MAP, - }; - - -public: - /* - * Generates and returns a string describing the values for each field in - * the NV_ENC_INITIALIZE_PARAMS structure (i.e. a description of the entire - * set of initialization parameters supplied to the API). - */ - std::string FullParamToString(const NV_ENC_INITIALIZE_PARAMS *pInitializeParams) { - std::ostringstream os; - os << "NV_ENC_INITIALIZE_PARAMS:" << std::endl - << "encodeGUID: " << ConvertValueToString(vCodec, szCodecNames, pInitializeParams->encodeGUID) << std::endl - << "presetGUID: " << ConvertValueToString(vPreset, szPresetNames, pInitializeParams->presetGUID) << std::endl; - if (pInitializeParams->tuningInfo) - { - os << "tuningInfo: " << ConvertValueToString(vTuningInfo, szTuningInfoNames, pInitializeParams->tuningInfo) << std::endl; - } - os - << "encodeWidth: " << pInitializeParams->encodeWidth << std::endl - << "encodeHeight: " << pInitializeParams->encodeHeight << std::endl - << "darWidth: " << pInitializeParams->darWidth << std::endl - << "darHeight: " << pInitializeParams->darHeight << std::endl - << "frameRateNum: " << pInitializeParams->frameRateNum << std::endl - << "frameRateDen: " << pInitializeParams->frameRateDen << std::endl - << "enableEncodeAsync: " << pInitializeParams->enableEncodeAsync << std::endl - << "reportSliceOffsets: " << pInitializeParams->reportSliceOffsets << std::endl - << "enableSubFrameWrite: " << pInitializeParams->enableSubFrameWrite << std::endl - << "enableExternalMEHints: " << pInitializeParams->enableExternalMEHints << std::endl - << "enableMEOnlyMode: " << pInitializeParams->enableMEOnlyMode << std::endl - << "enableWeightedPrediction: " << pInitializeParams->enableWeightedPrediction << std::endl - << "maxEncodeWidth: " << pInitializeParams->maxEncodeWidth << std::endl - << "maxEncodeHeight: " << pInitializeParams->maxEncodeHeight << std::endl - << "maxMEHintCountsPerBlock: " << pInitializeParams->maxMEHintCountsPerBlock << std::endl - ; - NV_ENC_CONFIG *pConfig = pInitializeParams->encodeConfig; - os << "NV_ENC_CONFIG:" << std::endl - << "profile: " << ConvertValueToString(vProfile, szProfileNames, pConfig->profileGUID) << std::endl - << "gopLength: " << pConfig->gopLength << std::endl - << "frameIntervalP: " << pConfig->frameIntervalP << std::endl - << "monoChromeEncoding: " << pConfig->monoChromeEncoding << std::endl - << "frameFieldMode: " << pConfig->frameFieldMode << std::endl - << "mvPrecision: " << pConfig->mvPrecision << std::endl - << "NV_ENC_RC_PARAMS:" << std::endl - << " rateControlMode: 0x" << std::hex << pConfig->rcParams.rateControlMode << std::dec << std::endl - << " constQP: " << pConfig->rcParams.constQP.qpInterP << ", " << pConfig->rcParams.constQP.qpInterB << ", " << pConfig->rcParams.constQP.qpIntra << std::endl - << " averageBitRate: " << pConfig->rcParams.averageBitRate << std::endl - << " maxBitRate: " << pConfig->rcParams.maxBitRate << std::endl - << " vbvBufferSize: " << pConfig->rcParams.vbvBufferSize << std::endl - << " vbvInitialDelay: " << pConfig->rcParams.vbvInitialDelay << std::endl - << " enableMinQP: " << pConfig->rcParams.enableMinQP << std::endl - << " enableMaxQP: " << pConfig->rcParams.enableMaxQP << std::endl - << " enableInitialRCQP: " << pConfig->rcParams.enableInitialRCQP << std::endl - << " enableAQ: " << pConfig->rcParams.enableAQ << std::endl - << " qpMapMode: " << ConvertValueToString(vQpMapMode, szQpMapModeNames, pConfig->rcParams.qpMapMode) << std::endl - << " multipass: " << ConvertValueToString(vMultiPass, szMultipass, pConfig->rcParams.multiPass) << std::endl - << " enableLookahead: " << pConfig->rcParams.enableLookahead << std::endl - << " disableIadapt: " << pConfig->rcParams.disableIadapt << std::endl - << " disableBadapt: " << pConfig->rcParams.disableBadapt << std::endl - << " enableTemporalAQ: " << pConfig->rcParams.enableTemporalAQ << std::endl - << " zeroReorderDelay: " << pConfig->rcParams.zeroReorderDelay << std::endl - << " enableNonRefP: " << pConfig->rcParams.enableNonRefP << std::endl - << " strictGOPTarget: " << pConfig->rcParams.strictGOPTarget << std::endl - << " aqStrength: " << pConfig->rcParams.aqStrength << std::endl - << " minQP: " << pConfig->rcParams.minQP.qpInterP << ", " << pConfig->rcParams.minQP.qpInterB << ", " << pConfig->rcParams.minQP.qpIntra << std::endl - << " maxQP: " << pConfig->rcParams.maxQP.qpInterP << ", " << pConfig->rcParams.maxQP.qpInterB << ", " << pConfig->rcParams.maxQP.qpIntra << std::endl - << " initialRCQP: " << pConfig->rcParams.initialRCQP.qpInterP << ", " << pConfig->rcParams.initialRCQP.qpInterB << ", " << pConfig->rcParams.initialRCQP.qpIntra << std::endl - << " temporallayerIdxMask: " << pConfig->rcParams.temporallayerIdxMask << std::endl - << " temporalLayerQP: " << (int)pConfig->rcParams.temporalLayerQP[0] << ", " << (int)pConfig->rcParams.temporalLayerQP[1] << ", " << (int)pConfig->rcParams.temporalLayerQP[2] << ", " << (int)pConfig->rcParams.temporalLayerQP[3] << ", " << (int)pConfig->rcParams.temporalLayerQP[4] << ", " << (int)pConfig->rcParams.temporalLayerQP[5] << ", " << (int)pConfig->rcParams.temporalLayerQP[6] << ", " << (int)pConfig->rcParams.temporalLayerQP[7] << std::endl - << " targetQuality: " << pConfig->rcParams.targetQuality << std::endl - << " lookaheadDepth: " << pConfig->rcParams.lookaheadDepth << std::endl; - if (pInitializeParams->encodeGUID == NV_ENC_CODEC_H264_GUID) { - os - << "NV_ENC_CODEC_CONFIG (H264):" << std::endl - << " enableStereoMVC: " << pConfig->encodeCodecConfig.h264Config.enableStereoMVC << std::endl - << " hierarchicalPFrames: " << pConfig->encodeCodecConfig.h264Config.hierarchicalPFrames << std::endl - << " hierarchicalBFrames: " << pConfig->encodeCodecConfig.h264Config.hierarchicalBFrames << std::endl - << " outputBufferingPeriodSEI: " << pConfig->encodeCodecConfig.h264Config.outputBufferingPeriodSEI << std::endl - << " outputPictureTimingSEI: " << pConfig->encodeCodecConfig.h264Config.outputPictureTimingSEI << std::endl - << " outputAUD: " << pConfig->encodeCodecConfig.h264Config.outputAUD << std::endl - << " disableSPSPPS: " << pConfig->encodeCodecConfig.h264Config.disableSPSPPS << std::endl - << " outputFramePackingSEI: " << pConfig->encodeCodecConfig.h264Config.outputFramePackingSEI << std::endl - << " outputRecoveryPointSEI: " << pConfig->encodeCodecConfig.h264Config.outputRecoveryPointSEI << std::endl - << " enableIntraRefresh: " << pConfig->encodeCodecConfig.h264Config.enableIntraRefresh << std::endl - << " enableConstrainedEncoding: " << pConfig->encodeCodecConfig.h264Config.enableConstrainedEncoding << std::endl - << " repeatSPSPPS: " << pConfig->encodeCodecConfig.h264Config.repeatSPSPPS << std::endl - << " enableVFR: " << pConfig->encodeCodecConfig.h264Config.enableVFR << std::endl - << " enableLTR: " << pConfig->encodeCodecConfig.h264Config.enableLTR << std::endl - << " qpPrimeYZeroTransformBypassFlag: " << pConfig->encodeCodecConfig.h264Config.qpPrimeYZeroTransformBypassFlag << std::endl - << " useConstrainedIntraPred: " << pConfig->encodeCodecConfig.h264Config.useConstrainedIntraPred << std::endl - << " level: " << pConfig->encodeCodecConfig.h264Config.level << std::endl - << " idrPeriod: " << pConfig->encodeCodecConfig.h264Config.idrPeriod << std::endl - << " separateColourPlaneFlag: " << pConfig->encodeCodecConfig.h264Config.separateColourPlaneFlag << std::endl - << " disableDeblockingFilterIDC: " << pConfig->encodeCodecConfig.h264Config.disableDeblockingFilterIDC << std::endl - << " numTemporalLayers: " << pConfig->encodeCodecConfig.h264Config.numTemporalLayers << std::endl - << " spsId: " << pConfig->encodeCodecConfig.h264Config.spsId << std::endl - << " ppsId: " << pConfig->encodeCodecConfig.h264Config.ppsId << std::endl - << " adaptiveTransformMode: " << pConfig->encodeCodecConfig.h264Config.adaptiveTransformMode << std::endl - << " fmoMode: " << pConfig->encodeCodecConfig.h264Config.fmoMode << std::endl - << " bdirectMode: " << pConfig->encodeCodecConfig.h264Config.bdirectMode << std::endl - << " entropyCodingMode: " << pConfig->encodeCodecConfig.h264Config.entropyCodingMode << std::endl - << " stereoMode: " << pConfig->encodeCodecConfig.h264Config.stereoMode << std::endl - << " intraRefreshPeriod: " << pConfig->encodeCodecConfig.h264Config.intraRefreshPeriod << std::endl - << " intraRefreshCnt: " << pConfig->encodeCodecConfig.h264Config.intraRefreshCnt << std::endl - << " maxNumRefFrames: " << pConfig->encodeCodecConfig.h264Config.maxNumRefFrames << std::endl - << " sliceMode: " << pConfig->encodeCodecConfig.h264Config.sliceMode << std::endl - << " sliceModeData: " << pConfig->encodeCodecConfig.h264Config.sliceModeData << std::endl - << " NV_ENC_CONFIG_H264_VUI_PARAMETERS:" << std::endl - << " overscanInfoPresentFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.overscanInfoPresentFlag << std::endl - << " overscanInfo: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.overscanInfo << std::endl - << " videoSignalTypePresentFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.videoSignalTypePresentFlag << std::endl - << " videoFormat: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.videoFormat << std::endl - << " videoFullRangeFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.videoFullRangeFlag << std::endl - << " colourDescriptionPresentFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.colourDescriptionPresentFlag << std::endl - << " colourPrimaries: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.colourPrimaries << std::endl - << " transferCharacteristics: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.transferCharacteristics << std::endl - << " colourMatrix: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.colourMatrix << std::endl - << " chromaSampleLocationFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.chromaSampleLocationFlag << std::endl - << " chromaSampleLocationTop: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.chromaSampleLocationTop << std::endl - << " chromaSampleLocationBot: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.chromaSampleLocationBot << std::endl - << " bitstreamRestrictionFlag: " << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.bitstreamRestrictionFlag << std::endl - << " ltrNumFrames: " << pConfig->encodeCodecConfig.h264Config.ltrNumFrames << std::endl - << " ltrTrustMode: " << pConfig->encodeCodecConfig.h264Config.ltrTrustMode << std::endl - << " chromaFormatIDC: " << pConfig->encodeCodecConfig.h264Config.chromaFormatIDC << std::endl - << " maxTemporalLayers: " << pConfig->encodeCodecConfig.h264Config.maxTemporalLayers << std::endl; - } else if (pInitializeParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) { - os - << "NV_ENC_CODEC_CONFIG (HEVC):" << std::endl - << " level: " << pConfig->encodeCodecConfig.hevcConfig.level << std::endl - << " tier: " << pConfig->encodeCodecConfig.hevcConfig.tier << std::endl - << " minCUSize: " << pConfig->encodeCodecConfig.hevcConfig.minCUSize << std::endl - << " maxCUSize: " << pConfig->encodeCodecConfig.hevcConfig.maxCUSize << std::endl - << " useConstrainedIntraPred: " << pConfig->encodeCodecConfig.hevcConfig.useConstrainedIntraPred << std::endl - << " disableDeblockAcrossSliceBoundary: " << pConfig->encodeCodecConfig.hevcConfig.disableDeblockAcrossSliceBoundary << std::endl - << " outputBufferingPeriodSEI: " << pConfig->encodeCodecConfig.hevcConfig.outputBufferingPeriodSEI << std::endl - << " outputPictureTimingSEI: " << pConfig->encodeCodecConfig.hevcConfig.outputPictureTimingSEI << std::endl - << " outputAUD: " << pConfig->encodeCodecConfig.hevcConfig.outputAUD << std::endl - << " enableLTR: " << pConfig->encodeCodecConfig.hevcConfig.enableLTR << std::endl - << " disableSPSPPS: " << pConfig->encodeCodecConfig.hevcConfig.disableSPSPPS << std::endl - << " repeatSPSPPS: " << pConfig->encodeCodecConfig.hevcConfig.repeatSPSPPS << std::endl - << " enableIntraRefresh: " << pConfig->encodeCodecConfig.hevcConfig.enableIntraRefresh << std::endl - << " chromaFormatIDC: " << pConfig->encodeCodecConfig.hevcConfig.chromaFormatIDC << std::endl - << " inputBitDepth: " << pConfig->encodeCodecConfig.hevcConfig.inputBitDepth << std::endl - << " outputBitDepth: " << pConfig->encodeCodecConfig.hevcConfig.outputBitDepth << std::endl - << " idrPeriod: " << pConfig->encodeCodecConfig.hevcConfig.idrPeriod << std::endl - << " intraRefreshPeriod: " << pConfig->encodeCodecConfig.hevcConfig.intraRefreshPeriod << std::endl - << " intraRefreshCnt: " << pConfig->encodeCodecConfig.hevcConfig.intraRefreshCnt << std::endl - << " maxNumRefFramesInDPB: " << pConfig->encodeCodecConfig.hevcConfig.maxNumRefFramesInDPB << std::endl - << " ltrNumFrames: " << pConfig->encodeCodecConfig.hevcConfig.ltrNumFrames << std::endl - << " vpsId: " << pConfig->encodeCodecConfig.hevcConfig.vpsId << std::endl - << " spsId: " << pConfig->encodeCodecConfig.hevcConfig.spsId << std::endl - << " ppsId: " << pConfig->encodeCodecConfig.hevcConfig.ppsId << std::endl - << " sliceMode: " << pConfig->encodeCodecConfig.hevcConfig.sliceMode << std::endl - << " sliceModeData: " << pConfig->encodeCodecConfig.hevcConfig.sliceModeData << std::endl - << " maxTemporalLayersMinus1: " << pConfig->encodeCodecConfig.hevcConfig.maxTemporalLayersMinus1 << std::endl - << " NV_ENC_CONFIG_HEVC_VUI_PARAMETERS:" << std::endl - << " overscanInfoPresentFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.overscanInfoPresentFlag << std::endl - << " overscanInfo: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.overscanInfo << std::endl - << " videoSignalTypePresentFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.videoSignalTypePresentFlag << std::endl - << " videoFormat: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.videoFormat << std::endl - << " videoFullRangeFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.videoFullRangeFlag << std::endl - << " colourDescriptionPresentFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.colourDescriptionPresentFlag << std::endl - << " colourPrimaries: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.colourPrimaries << std::endl - << " transferCharacteristics: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.transferCharacteristics << std::endl - << " colourMatrix: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.colourMatrix << std::endl - << " chromaSampleLocationFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.chromaSampleLocationFlag << std::endl - << " chromaSampleLocationTop: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.chromaSampleLocationTop << std::endl - << " chromaSampleLocationBot: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.chromaSampleLocationBot << std::endl - << " bitstreamRestrictionFlag: " << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.bitstreamRestrictionFlag << std::endl - << " ltrTrustMode: " << pConfig->encodeCodecConfig.hevcConfig.ltrTrustMode << std::endl; - } else if (pInitializeParams->encodeGUID == NV_ENC_CODEC_AV1_GUID) { - os - << "NV_ENC_CODEC_CONFIG (AV1):" << std::endl - << " level: " << pConfig->encodeCodecConfig.av1Config.level << std::endl - << " tier: " << pConfig->encodeCodecConfig.av1Config.tier << std::endl - << " minPartSize: " << pConfig->encodeCodecConfig.av1Config.minPartSize << std::endl - << " maxPartSize: " << pConfig->encodeCodecConfig.av1Config.maxPartSize << std::endl - << " outputAnnexBFormat: " << pConfig->encodeCodecConfig.av1Config.outputAnnexBFormat << std::endl - << " enableTimingInfo: " << pConfig->encodeCodecConfig.av1Config.enableTimingInfo << std::endl - << " enableDecoderModelInfo: " << pConfig->encodeCodecConfig.av1Config.enableDecoderModelInfo << std::endl - << " enableFrameIdNumbers: " << pConfig->encodeCodecConfig.av1Config.enableFrameIdNumbers << std::endl - << " disableSeqHdr: " << pConfig->encodeCodecConfig.av1Config.disableSeqHdr << std::endl - << " repeatSeqHdr: " << pConfig->encodeCodecConfig.av1Config.repeatSeqHdr << std::endl - << " enableIntraRefresh: " << pConfig->encodeCodecConfig.av1Config.enableIntraRefresh << std::endl - << " chromaFormatIDC: " << pConfig->encodeCodecConfig.av1Config.chromaFormatIDC << std::endl - << " enableBitstreamPadding: " << pConfig->encodeCodecConfig.av1Config.enableBitstreamPadding << std::endl - << " enableCustomTileConfig: " << pConfig->encodeCodecConfig.av1Config.enableCustomTileConfig << std::endl - << " enableFilmGrainParams: " << pConfig->encodeCodecConfig.av1Config.enableFilmGrainParams << std::endl - << " inputBitDepth: " << pConfig->encodeCodecConfig.av1Config.inputBitDepth << std::endl - << " outputBitDepth: " << pConfig->encodeCodecConfig.av1Config.outputBitDepth << std::endl - << " idrPeriod: " << pConfig->encodeCodecConfig.av1Config.idrPeriod << std::endl - << " intraRefreshPeriod: " << pConfig->encodeCodecConfig.av1Config.intraRefreshPeriod << std::endl - << " intraRefreshCnt: " << pConfig->encodeCodecConfig.av1Config.intraRefreshCnt << std::endl - << " maxNumRefFramesInDPB: " << pConfig->encodeCodecConfig.av1Config.maxNumRefFramesInDPB << std::endl - << " numTileColumns: " << pConfig->encodeCodecConfig.av1Config.numTileColumns << std::endl - << " numTileRows: " << pConfig->encodeCodecConfig.av1Config.numTileRows << std::endl - << " maxTemporalLayersMinus1: " << pConfig->encodeCodecConfig.av1Config.maxTemporalLayersMinus1 << std::endl - << " colorPrimaries: " << pConfig->encodeCodecConfig.av1Config.colorPrimaries << std::endl - << " transferCharacteristics: " << pConfig->encodeCodecConfig.av1Config.transferCharacteristics << std::endl - << " matrixCoefficients: " << pConfig->encodeCodecConfig.av1Config.matrixCoefficients << std::endl - << " colorRange: " << pConfig->encodeCodecConfig.av1Config.colorRange << std::endl - << " chromaSamplePosition: " << pConfig->encodeCodecConfig.av1Config.chromaSamplePosition << std::endl - << " useBFramesAsRef: " << pConfig->encodeCodecConfig.av1Config.useBFramesAsRef << std::endl - << " numFwdRefs: " << pConfig->encodeCodecConfig.av1Config.numFwdRefs << std::endl - << " numBwdRefs: " << pConfig->encodeCodecConfig.av1Config.numBwdRefs << std::endl; - if (pConfig->encodeCodecConfig.av1Config.filmGrainParams != NULL) - { - os - << " NV_ENC_FILM_GRAIN_PARAMS_AV1:" << std::endl - << " applyGrain: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->applyGrain << std::endl - << " chromaScalingFromLuma: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->chromaScalingFromLuma << std::endl - << " overlapFlag: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->overlapFlag << std::endl - << " clipToRestrictedRange: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->clipToRestrictedRange << std::endl - << " grainScalingMinus8: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->grainScalingMinus8 << std::endl - << " arCoeffLag: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->arCoeffLag << std::endl - << " numYPoints: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numYPoints << std::endl - << " numCbPoints: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numCbPoints << std::endl - << " numCrPoints: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numCrPoints << std::endl - << " arCoeffShiftMinus6: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->arCoeffShiftMinus6 << std::endl - << " grainScaleShift: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->grainScaleShift << std::endl - << " cbMult: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbMult << std::endl - << " cbLumaMult: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbLumaMult << std::endl - << " cbOffset: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbOffset << std::endl - << " crMult: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crMult << std::endl - << " crLumaMult: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crLumaMult << std::endl - << " crOffset: " << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crOffset << std::endl; - } - } - - return os.str(); + if (nGOPOption && nBFramesOption && + (config.gopLength < ((uint32_t)config.frameIntervalP))) { + std::ostringstream errmessage; + errmessage << "gopLength (" << config.gopLength + << ") must be greater or equal to frameIntervalP (number of B " + "frames + 1) (" + << config.frameIntervalP << ")\n"; + throw std::invalid_argument(errmessage.str()); } + + funcInit(pParams); + LOG(INFO) << NvEncoderInitParam().MainParamToString(pParams); + LOG(TRACE) << NvEncoderInitParam().FullParamToString(pParams); + } + + private: + /* + * Helper methods for parsing tokens (generated by splitting the command line) + * and performing conversions to the appropriate target type/value. + */ + template + bool ParseString(const std::string &strName, const std::string &strValue, + const std::vector &vValue, + const std::string &strValueNames, T *pValue) { + std::vector vstrValueName = split(strValueNames, ' '); + auto it = std::find(vstrValueName.begin(), vstrValueName.end(), strValue); + if (it == vstrValueName.end()) { + LOG(ERROR) << strName << " options: " << strValueNames; + return false; + } + *pValue = vValue[it - vstrValueName.begin()]; + return true; + } + template + std::string ConvertValueToString(const std::vector &vValue, + const std::string &strValueNames, T value) { + auto it = std::find(vValue.begin(), vValue.end(), value); + if (it == vValue.end()) { + LOG(ERROR) << "Invalid value. Can't convert to one of " << strValueNames; + return std::string(); + } + return split(strValueNames, ' ')[it - vValue.begin()]; + } + bool ParseBitRate(const std::string &strName, const std::string &strValue, + unsigned *pBitRate) { + if (bTransOneToN) { + std::vector oneToNBitrate = split(strValue, ','); + std::string currBitrate; + if ((bitrateCnt + 1) > oneToNBitrate.size()) { + currBitrate = oneToNBitrate[oneToNBitrate.size() - 1]; + } else { + currBitrate = oneToNBitrate[bitrateCnt]; + bitrateCnt++; + } + + try { + size_t l; + double r = std::stod(currBitrate, &l); + char c = currBitrate[l]; + if (c != 0 && c != 'k' && c != 'm') { + LOG(ERROR) << strName << " units: 1, K, M (lower case also allowed)"; + } + *pBitRate = + (unsigned)((c == 'm' ? 1000000 : (c == 'k' ? 1000 : 1)) * r); + } catch (std::invalid_argument) { + return false; + } + return true; + } + + else { + try { + size_t l; + double r = std::stod(strValue, &l); + char c = strValue[l]; + if (c != 0 && c != 'k' && c != 'm') { + LOG(ERROR) << strName << " units: 1, K, M (lower case also allowed)"; + } + *pBitRate = + (unsigned)((c == 'm' ? 1000000 : (c == 'k' ? 1000 : 1)) * r); + } catch (std::invalid_argument) { + return false; + } + return true; + } + } + template + bool ParseInt(const std::string &strName, const std::string &strValue, + T *pInt) { + if (bTransOneToN) { + std::vector oneToNFps = split(strValue, ','); + std::string currFps; + if ((fpsCnt + 1) > oneToNFps.size()) { + currFps = oneToNFps[oneToNFps.size() - 1]; + } else { + currFps = oneToNFps[fpsCnt]; + fpsCnt++; + } + + try { + *pInt = (T)std::stoi(currFps); + } catch (std::invalid_argument) { + LOG(ERROR) << strName << " need a value of positive number"; + return false; + } + return true; + } else { + try { + *pInt = (T)std::stoi(strValue); + } catch (std::invalid_argument) { + LOG(ERROR) << strName << " need a value of positive number"; + return false; + } + return true; + } + } + bool ParseQp(const std::string &strName, const std::string &strValue, + NV_ENC_QP *pQp) { + std::vector vQp = split(strValue, ','); + try { + if (vQp.size() == 1) { + unsigned qp = (unsigned)std::stoi(vQp[0]); + *pQp = {qp, qp, qp}; + } else if (vQp.size() == 3) { + *pQp = {(unsigned)std::stoi(vQp[0]), (unsigned)std::stoi(vQp[1]), + (unsigned)std::stoi(vQp[2])}; + } else { + LOG(ERROR) << strName + << " qp_for_P_B_I or qp_P,qp_B,qp_I (no space is allowed)"; + return false; + } + } catch (std::invalid_argument) { + return false; + } + return true; + } + std::vector split(const std::string &s, char delim) { + std::stringstream ss(s); + std::string token; + std::vector tks; + while (getline(ss, token, delim)) { + tks.push_back(token); + } + return tks; + } + + private: + std::string strParam; + std::function funcInit = + []([[maybe_unused]] NV_ENC_INITIALIZE_PARAMS *pParams) {}; + std::vector tokens; + GUID guidCodec = NV_ENC_CODEC_H264_GUID; + GUID guidPreset = NV_ENC_PRESET_P3_GUID; + NV_ENC_TUNING_INFO m_TuningInfo = NV_ENC_TUNING_INFO_HIGH_QUALITY; + bool bLowLatency = false; + uint32_t bitrateCnt = 0; + uint32_t fpsCnt = 0; + bool bTransOneToN = 0; + + const char *szCodecNames = "h264 hevc av1"; + std::vector vCodec = std::vector{ + NV_ENC_CODEC_H264_GUID, NV_ENC_CODEC_HEVC_GUID, NV_ENC_CODEC_AV1_GUID}; + + const char *szChromaNames = "yuv420 yuv444"; + std::vector vChroma = std::vector{1, 3}; + + const char *szPresetNames = "p1 p2 p3 p4 p5 p6 p7"; + std::vector vPreset = std::vector{ + NV_ENC_PRESET_P1_GUID, NV_ENC_PRESET_P2_GUID, NV_ENC_PRESET_P3_GUID, + NV_ENC_PRESET_P4_GUID, NV_ENC_PRESET_P5_GUID, NV_ENC_PRESET_P6_GUID, + NV_ENC_PRESET_P7_GUID, + }; + + const char *szH264ProfileNames = "baseline main high high444"; + std::vector vH264Profile = std::vector{ + NV_ENC_H264_PROFILE_BASELINE_GUID, + NV_ENC_H264_PROFILE_MAIN_GUID, + NV_ENC_H264_PROFILE_HIGH_GUID, + NV_ENC_H264_PROFILE_HIGH_444_GUID, + }; + const char *szHevcProfileNames = "main main10 frext"; + std::vector vHevcProfile = std::vector{ + NV_ENC_HEVC_PROFILE_MAIN_GUID, + NV_ENC_HEVC_PROFILE_MAIN10_GUID, + NV_ENC_HEVC_PROFILE_FREXT_GUID, + }; + const char *szAV1ProfileNames = "main"; + std::vector vAV1Profile = std::vector{ + NV_ENC_AV1_PROFILE_MAIN_GUID, + }; + + const char *szProfileNames = + "(default) auto baseline(h264) main(h264) high(h264) high444(h264)" + " stereo(h264) progressiv_high(h264) constrained_high(h264)" + " main(hevc) main10(hevc) frext(hevc)" + " main(av1) high(av1)"; + std::vector vProfile = std::vector{ + GUID{}, + NV_ENC_CODEC_PROFILE_AUTOSELECT_GUID, + NV_ENC_H264_PROFILE_BASELINE_GUID, + NV_ENC_H264_PROFILE_MAIN_GUID, + NV_ENC_H264_PROFILE_HIGH_GUID, + NV_ENC_H264_PROFILE_HIGH_444_GUID, + NV_ENC_H264_PROFILE_STEREO_GUID, + NV_ENC_H264_PROFILE_PROGRESSIVE_HIGH_GUID, + NV_ENC_H264_PROFILE_CONSTRAINED_HIGH_GUID, + NV_ENC_HEVC_PROFILE_MAIN_GUID, + NV_ENC_HEVC_PROFILE_MAIN10_GUID, + NV_ENC_HEVC_PROFILE_FREXT_GUID, + NV_ENC_AV1_PROFILE_MAIN_GUID, + }; + + const char *szLowLatencyTuningInfoNames = "lowlatency ultralowlatency"; + const char *szTuningInfoNames = "hq lowlatency ultralowlatency lossless uhq"; + std::vector vTuningInfo = std::vector{ + NV_ENC_TUNING_INFO_HIGH_QUALITY, NV_ENC_TUNING_INFO_LOW_LATENCY, + NV_ENC_TUNING_INFO_ULTRA_LOW_LATENCY, NV_ENC_TUNING_INFO_LOSSLESS, + NV_ENC_TUNING_INFO_ULTRA_HIGH_QUALITY}; + + const char *szRcModeNames = "constqp vbr cbr"; + std::vector vRcMode = + std::vector{ + NV_ENC_PARAMS_RC_CONSTQP, + NV_ENC_PARAMS_RC_VBR, + NV_ENC_PARAMS_RC_CBR, + }; + + const char *szMultipass = "disabled qres fullres"; + std::vector vMultiPass = std::vector{ + NV_ENC_MULTI_PASS_DISABLED, + NV_ENC_TWO_PASS_QUARTER_RESOLUTION, + NV_ENC_TWO_PASS_FULL_RESOLUTION, + }; + + const char *szQpMapModeNames = + "disabled emphasis_level_map delta_qp_map qp_map"; + std::vector vQpMapMode = std::vector{ + NV_ENC_QP_MAP_DISABLED, + NV_ENC_QP_MAP_EMPHASIS, + NV_ENC_QP_MAP_DELTA, + NV_ENC_QP_MAP, + }; + + public: + /* + * Generates and returns a string describing the values for each field in + * the NV_ENC_INITIALIZE_PARAMS structure (i.e. a description of the entire + * set of initialization parameters supplied to the API). + */ + std::string FullParamToString( + const NV_ENC_INITIALIZE_PARAMS *pInitializeParams) { + std::ostringstream os; + os << "NV_ENC_INITIALIZE_PARAMS:" << std::endl + << "encodeGUID: " + << ConvertValueToString(vCodec, szCodecNames, + pInitializeParams->encodeGUID) + << std::endl + << "presetGUID: " + << ConvertValueToString(vPreset, szPresetNames, + pInitializeParams->presetGUID) + << std::endl; + if (pInitializeParams->tuningInfo) { + os << "tuningInfo: " + << ConvertValueToString(vTuningInfo, szTuningInfoNames, + pInitializeParams->tuningInfo) + << std::endl; + } + os << "encodeWidth: " << pInitializeParams->encodeWidth << std::endl + << "encodeHeight: " << pInitializeParams->encodeHeight << std::endl + << "darWidth: " << pInitializeParams->darWidth << std::endl + << "darHeight: " << pInitializeParams->darHeight << std::endl + << "frameRateNum: " << pInitializeParams->frameRateNum << std::endl + << "frameRateDen: " << pInitializeParams->frameRateDen << std::endl + << "enableEncodeAsync: " << pInitializeParams->enableEncodeAsync + << std::endl + << "reportSliceOffsets: " << pInitializeParams->reportSliceOffsets + << std::endl + << "enableSubFrameWrite: " << pInitializeParams->enableSubFrameWrite + << std::endl + << "enableExternalMEHints: " << pInitializeParams->enableExternalMEHints + << std::endl + << "enableMEOnlyMode: " << pInitializeParams->enableMEOnlyMode + << std::endl + << "enableWeightedPrediction: " + << pInitializeParams->enableWeightedPrediction << std::endl + << "maxEncodeWidth: " << pInitializeParams->maxEncodeWidth << std::endl + << "maxEncodeHeight: " << pInitializeParams->maxEncodeHeight << std::endl + << "maxMEHintCountsPerBlock: " + << pInitializeParams->maxMEHintCountsPerBlock << std::endl; + NV_ENC_CONFIG *pConfig = pInitializeParams->encodeConfig; + os << "NV_ENC_CONFIG:" << std::endl + << "profile: " + << ConvertValueToString(vProfile, szProfileNames, pConfig->profileGUID) + << std::endl + << "gopLength: " << pConfig->gopLength << std::endl + << "frameIntervalP: " << pConfig->frameIntervalP << std::endl + << "monoChromeEncoding: " << pConfig->monoChromeEncoding << std::endl + << "frameFieldMode: " << pConfig->frameFieldMode << std::endl + << "mvPrecision: " << pConfig->mvPrecision << std::endl + << "NV_ENC_RC_PARAMS:" << std::endl + << " rateControlMode: 0x" << std::hex + << pConfig->rcParams.rateControlMode << std::dec << std::endl + << " constQP: " << pConfig->rcParams.constQP.qpInterP << ", " + << pConfig->rcParams.constQP.qpInterB << ", " + << pConfig->rcParams.constQP.qpIntra << std::endl + << " averageBitRate: " << pConfig->rcParams.averageBitRate + << std::endl + << " maxBitRate: " << pConfig->rcParams.maxBitRate << std::endl + << " vbvBufferSize: " << pConfig->rcParams.vbvBufferSize + << std::endl + << " vbvInitialDelay: " << pConfig->rcParams.vbvInitialDelay + << std::endl + << " enableMinQP: " << pConfig->rcParams.enableMinQP << std::endl + << " enableMaxQP: " << pConfig->rcParams.enableMaxQP << std::endl + << " enableInitialRCQP: " << pConfig->rcParams.enableInitialRCQP + << std::endl + << " enableAQ: " << pConfig->rcParams.enableAQ << std::endl + << " qpMapMode: " + << ConvertValueToString(vQpMapMode, szQpMapModeNames, + pConfig->rcParams.qpMapMode) + << std::endl + << " multipass: " + << ConvertValueToString(vMultiPass, szMultipass, + pConfig->rcParams.multiPass) + << std::endl + << " enableLookahead: " << pConfig->rcParams.enableLookahead + << std::endl + << " disableIadapt: " << pConfig->rcParams.disableIadapt << std::endl + << " disableBadapt: " << pConfig->rcParams.disableBadapt << std::endl + << " enableTemporalAQ: " << pConfig->rcParams.enableTemporalAQ + << std::endl + << " zeroReorderDelay: " << pConfig->rcParams.zeroReorderDelay + << std::endl + << " enableNonRefP: " << pConfig->rcParams.enableNonRefP << std::endl + << " strictGOPTarget: " << pConfig->rcParams.strictGOPTarget + << std::endl + << " aqStrength: " << pConfig->rcParams.aqStrength << std::endl + << " minQP: " << pConfig->rcParams.minQP.qpInterP << ", " + << pConfig->rcParams.minQP.qpInterB << ", " + << pConfig->rcParams.minQP.qpIntra << std::endl + << " maxQP: " << pConfig->rcParams.maxQP.qpInterP << ", " + << pConfig->rcParams.maxQP.qpInterB << ", " + << pConfig->rcParams.maxQP.qpIntra << std::endl + << " initialRCQP: " << pConfig->rcParams.initialRCQP.qpInterP << ", " + << pConfig->rcParams.initialRCQP.qpInterB << ", " + << pConfig->rcParams.initialRCQP.qpIntra << std::endl + << " temporallayerIdxMask: " << pConfig->rcParams.temporallayerIdxMask + << std::endl + << " temporalLayerQP: " << (int)pConfig->rcParams.temporalLayerQP[0] + << ", " << (int)pConfig->rcParams.temporalLayerQP[1] << ", " + << (int)pConfig->rcParams.temporalLayerQP[2] << ", " + << (int)pConfig->rcParams.temporalLayerQP[3] << ", " + << (int)pConfig->rcParams.temporalLayerQP[4] << ", " + << (int)pConfig->rcParams.temporalLayerQP[5] << ", " + << (int)pConfig->rcParams.temporalLayerQP[6] << ", " + << (int)pConfig->rcParams.temporalLayerQP[7] << std::endl + << " targetQuality: " << pConfig->rcParams.targetQuality << std::endl + << " lookaheadDepth: " << pConfig->rcParams.lookaheadDepth + << std::endl; + if (pInitializeParams->encodeGUID == NV_ENC_CODEC_H264_GUID) { + os << "NV_ENC_CODEC_CONFIG (H264):" << std::endl + << " enableStereoMVC: " + << pConfig->encodeCodecConfig.h264Config.enableStereoMVC << std::endl + << " hierarchicalPFrames: " + << pConfig->encodeCodecConfig.h264Config.hierarchicalPFrames + << std::endl + << " hierarchicalBFrames: " + << pConfig->encodeCodecConfig.h264Config.hierarchicalBFrames + << std::endl + << " outputBufferingPeriodSEI: " + << pConfig->encodeCodecConfig.h264Config.outputBufferingPeriodSEI + << std::endl + << " outputPictureTimingSEI: " + << pConfig->encodeCodecConfig.h264Config.outputPictureTimingSEI + << std::endl + << " outputAUD: " << pConfig->encodeCodecConfig.h264Config.outputAUD + << std::endl + << " disableSPSPPS: " + << pConfig->encodeCodecConfig.h264Config.disableSPSPPS << std::endl + << " outputFramePackingSEI: " + << pConfig->encodeCodecConfig.h264Config.outputFramePackingSEI + << std::endl + << " outputRecoveryPointSEI: " + << pConfig->encodeCodecConfig.h264Config.outputRecoveryPointSEI + << std::endl + << " enableIntraRefresh: " + << pConfig->encodeCodecConfig.h264Config.enableIntraRefresh + << std::endl + << " enableConstrainedEncoding: " + << pConfig->encodeCodecConfig.h264Config.enableConstrainedEncoding + << std::endl + << " repeatSPSPPS: " + << pConfig->encodeCodecConfig.h264Config.repeatSPSPPS << std::endl + << " enableVFR: " << pConfig->encodeCodecConfig.h264Config.enableVFR + << std::endl + << " enableLTR: " << pConfig->encodeCodecConfig.h264Config.enableLTR + << std::endl + << " qpPrimeYZeroTransformBypassFlag: " + << pConfig->encodeCodecConfig.h264Config + .qpPrimeYZeroTransformBypassFlag + << std::endl + << " useConstrainedIntraPred: " + << pConfig->encodeCodecConfig.h264Config.useConstrainedIntraPred + << std::endl + << " level: " << pConfig->encodeCodecConfig.h264Config.level + << std::endl + << " idrPeriod: " << pConfig->encodeCodecConfig.h264Config.idrPeriod + << std::endl + << " separateColourPlaneFlag: " + << pConfig->encodeCodecConfig.h264Config.separateColourPlaneFlag + << std::endl + << " disableDeblockingFilterIDC: " + << pConfig->encodeCodecConfig.h264Config.disableDeblockingFilterIDC + << std::endl + << " numTemporalLayers: " + << pConfig->encodeCodecConfig.h264Config.numTemporalLayers << std::endl + << " spsId: " << pConfig->encodeCodecConfig.h264Config.spsId + << std::endl + << " ppsId: " << pConfig->encodeCodecConfig.h264Config.ppsId + << std::endl + << " adaptiveTransformMode: " + << pConfig->encodeCodecConfig.h264Config.adaptiveTransformMode + << std::endl + << " fmoMode: " << pConfig->encodeCodecConfig.h264Config.fmoMode + << std::endl + << " bdirectMode: " + << pConfig->encodeCodecConfig.h264Config.bdirectMode << std::endl + << " entropyCodingMode: " + << pConfig->encodeCodecConfig.h264Config.entropyCodingMode << std::endl + << " stereoMode: " + << pConfig->encodeCodecConfig.h264Config.stereoMode << std::endl + << " intraRefreshPeriod: " + << pConfig->encodeCodecConfig.h264Config.intraRefreshPeriod + << std::endl + << " intraRefreshCnt: " + << pConfig->encodeCodecConfig.h264Config.intraRefreshCnt << std::endl + << " maxNumRefFrames: " + << pConfig->encodeCodecConfig.h264Config.maxNumRefFrames << std::endl + << " sliceMode: " << pConfig->encodeCodecConfig.h264Config.sliceMode + << std::endl + << " sliceModeData: " + << pConfig->encodeCodecConfig.h264Config.sliceModeData << std::endl + << " NV_ENC_CONFIG_H264_VUI_PARAMETERS:" << std::endl + << " overscanInfoPresentFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .overscanInfoPresentFlag + << std::endl + << " overscanInfo: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.overscanInfo + << std::endl + << " videoSignalTypePresentFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .videoSignalTypePresentFlag + << std::endl + << " videoFormat: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.videoFormat + << std::endl + << " videoFullRangeFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .videoFullRangeFlag + << std::endl + << " colourDescriptionPresentFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .colourDescriptionPresentFlag + << std::endl + << " colourPrimaries: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .colourPrimaries + << std::endl + << " transferCharacteristics: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .transferCharacteristics + << std::endl + << " colourMatrix: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters.colourMatrix + << std::endl + << " chromaSampleLocationFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .chromaSampleLocationFlag + << std::endl + << " chromaSampleLocationTop: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .chromaSampleLocationTop + << std::endl + << " chromaSampleLocationBot: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .chromaSampleLocationBot + << std::endl + << " bitstreamRestrictionFlag: " + << pConfig->encodeCodecConfig.h264Config.h264VUIParameters + .bitstreamRestrictionFlag + << std::endl + << " ltrNumFrames: " + << pConfig->encodeCodecConfig.h264Config.ltrNumFrames << std::endl + << " ltrTrustMode: " + << pConfig->encodeCodecConfig.h264Config.ltrTrustMode << std::endl + << " chromaFormatIDC: " + << pConfig->encodeCodecConfig.h264Config.chromaFormatIDC << std::endl + << " maxTemporalLayers: " + << pConfig->encodeCodecConfig.h264Config.maxTemporalLayers + << std::endl; + } else if (pInitializeParams->encodeGUID == NV_ENC_CODEC_HEVC_GUID) { + os << "NV_ENC_CODEC_CONFIG (HEVC):" << std::endl + << " level: " << pConfig->encodeCodecConfig.hevcConfig.level + << std::endl + << " tier: " << pConfig->encodeCodecConfig.hevcConfig.tier + << std::endl + << " minCUSize: " << pConfig->encodeCodecConfig.hevcConfig.minCUSize + << std::endl + << " maxCUSize: " << pConfig->encodeCodecConfig.hevcConfig.maxCUSize + << std::endl + << " useConstrainedIntraPred: " + << pConfig->encodeCodecConfig.hevcConfig.useConstrainedIntraPred + << std::endl + << " disableDeblockAcrossSliceBoundary: " + << pConfig->encodeCodecConfig.hevcConfig + .disableDeblockAcrossSliceBoundary + << std::endl + << " outputBufferingPeriodSEI: " + << pConfig->encodeCodecConfig.hevcConfig.outputBufferingPeriodSEI + << std::endl + << " outputPictureTimingSEI: " + << pConfig->encodeCodecConfig.hevcConfig.outputPictureTimingSEI + << std::endl + << " outputAUD: " << pConfig->encodeCodecConfig.hevcConfig.outputAUD + << std::endl + << " enableLTR: " << pConfig->encodeCodecConfig.hevcConfig.enableLTR + << std::endl + << " disableSPSPPS: " + << pConfig->encodeCodecConfig.hevcConfig.disableSPSPPS << std::endl + << " repeatSPSPPS: " + << pConfig->encodeCodecConfig.hevcConfig.repeatSPSPPS << std::endl + << " enableIntraRefresh: " + << pConfig->encodeCodecConfig.hevcConfig.enableIntraRefresh + << std::endl + << " chromaFormatIDC: " + << pConfig->encodeCodecConfig.hevcConfig.chromaFormatIDC << std::endl + << " inputBitDepth: " + << pConfig->encodeCodecConfig.hevcConfig.inputBitDepth << std::endl + << " outputBitDepth: " + << pConfig->encodeCodecConfig.hevcConfig.outputBitDepth << std::endl + << " idrPeriod: " << pConfig->encodeCodecConfig.hevcConfig.idrPeriod + << std::endl + << " intraRefreshPeriod: " + << pConfig->encodeCodecConfig.hevcConfig.intraRefreshPeriod + << std::endl + << " intraRefreshCnt: " + << pConfig->encodeCodecConfig.hevcConfig.intraRefreshCnt << std::endl + << " maxNumRefFramesInDPB: " + << pConfig->encodeCodecConfig.hevcConfig.maxNumRefFramesInDPB + << std::endl + << " ltrNumFrames: " + << pConfig->encodeCodecConfig.hevcConfig.ltrNumFrames << std::endl + << " vpsId: " << pConfig->encodeCodecConfig.hevcConfig.vpsId + << std::endl + << " spsId: " << pConfig->encodeCodecConfig.hevcConfig.spsId + << std::endl + << " ppsId: " << pConfig->encodeCodecConfig.hevcConfig.ppsId + << std::endl + << " sliceMode: " << pConfig->encodeCodecConfig.hevcConfig.sliceMode + << std::endl + << " sliceModeData: " + << pConfig->encodeCodecConfig.hevcConfig.sliceModeData << std::endl + << " maxTemporalLayersMinus1: " + << pConfig->encodeCodecConfig.hevcConfig.maxTemporalLayersMinus1 + << std::endl + << " NV_ENC_CONFIG_HEVC_VUI_PARAMETERS:" << std::endl + << " overscanInfoPresentFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .overscanInfoPresentFlag + << std::endl + << " overscanInfo: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.overscanInfo + << std::endl + << " videoSignalTypePresentFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .videoSignalTypePresentFlag + << std::endl + << " videoFormat: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.videoFormat + << std::endl + << " videoFullRangeFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .videoFullRangeFlag + << std::endl + << " colourDescriptionPresentFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .colourDescriptionPresentFlag + << std::endl + << " colourPrimaries: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .colourPrimaries + << std::endl + << " transferCharacteristics: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .transferCharacteristics + << std::endl + << " colourMatrix: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters.colourMatrix + << std::endl + << " chromaSampleLocationFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .chromaSampleLocationFlag + << std::endl + << " chromaSampleLocationTop: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .chromaSampleLocationTop + << std::endl + << " chromaSampleLocationBot: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .chromaSampleLocationBot + << std::endl + << " bitstreamRestrictionFlag: " + << pConfig->encodeCodecConfig.hevcConfig.hevcVUIParameters + .bitstreamRestrictionFlag + << std::endl + << " ltrTrustMode: " + << pConfig->encodeCodecConfig.hevcConfig.ltrTrustMode << std::endl; + } else if (pInitializeParams->encodeGUID == NV_ENC_CODEC_AV1_GUID) { + os << "NV_ENC_CODEC_CONFIG (AV1):" << std::endl + << " level: " << pConfig->encodeCodecConfig.av1Config.level + << std::endl + << " tier: " << pConfig->encodeCodecConfig.av1Config.tier + << std::endl + << " minPartSize: " + << pConfig->encodeCodecConfig.av1Config.minPartSize << std::endl + << " maxPartSize: " + << pConfig->encodeCodecConfig.av1Config.maxPartSize << std::endl + << " outputAnnexBFormat: " + << pConfig->encodeCodecConfig.av1Config.outputAnnexBFormat << std::endl + << " enableTimingInfo: " + << pConfig->encodeCodecConfig.av1Config.enableTimingInfo << std::endl + << " enableDecoderModelInfo: " + << pConfig->encodeCodecConfig.av1Config.enableDecoderModelInfo + << std::endl + << " enableFrameIdNumbers: " + << pConfig->encodeCodecConfig.av1Config.enableFrameIdNumbers + << std::endl + << " disableSeqHdr: " + << pConfig->encodeCodecConfig.av1Config.disableSeqHdr << std::endl + << " repeatSeqHdr: " + << pConfig->encodeCodecConfig.av1Config.repeatSeqHdr << std::endl + << " enableIntraRefresh: " + << pConfig->encodeCodecConfig.av1Config.enableIntraRefresh << std::endl + << " chromaFormatIDC: " + << pConfig->encodeCodecConfig.av1Config.chromaFormatIDC << std::endl + << " enableBitstreamPadding: " + << pConfig->encodeCodecConfig.av1Config.enableBitstreamPadding + << std::endl + << " enableCustomTileConfig: " + << pConfig->encodeCodecConfig.av1Config.enableCustomTileConfig + << std::endl + << " enableFilmGrainParams: " + << pConfig->encodeCodecConfig.av1Config.enableFilmGrainParams + << std::endl + << " inputBitDepth: " + << pConfig->encodeCodecConfig.av1Config.inputBitDepth << std::endl + << " outputBitDepth: " + << pConfig->encodeCodecConfig.av1Config.outputBitDepth << std::endl + << " idrPeriod: " << pConfig->encodeCodecConfig.av1Config.idrPeriod + << std::endl + << " intraRefreshPeriod: " + << pConfig->encodeCodecConfig.av1Config.intraRefreshPeriod << std::endl + << " intraRefreshCnt: " + << pConfig->encodeCodecConfig.av1Config.intraRefreshCnt << std::endl + << " maxNumRefFramesInDPB: " + << pConfig->encodeCodecConfig.av1Config.maxNumRefFramesInDPB + << std::endl + << " numTileColumns: " + << pConfig->encodeCodecConfig.av1Config.numTileColumns << std::endl + << " numTileRows: " + << pConfig->encodeCodecConfig.av1Config.numTileRows << std::endl + << " maxTemporalLayersMinus1: " + << pConfig->encodeCodecConfig.av1Config.maxTemporalLayersMinus1 + << std::endl + << " colorPrimaries: " + << pConfig->encodeCodecConfig.av1Config.colorPrimaries << std::endl + << " transferCharacteristics: " + << pConfig->encodeCodecConfig.av1Config.transferCharacteristics + << std::endl + << " matrixCoefficients: " + << pConfig->encodeCodecConfig.av1Config.matrixCoefficients << std::endl + << " colorRange: " + << pConfig->encodeCodecConfig.av1Config.colorRange << std::endl + << " chromaSamplePosition: " + << pConfig->encodeCodecConfig.av1Config.chromaSamplePosition + << std::endl + << " useBFramesAsRef: " + << pConfig->encodeCodecConfig.av1Config.useBFramesAsRef << std::endl + << " numFwdRefs: " + << pConfig->encodeCodecConfig.av1Config.numFwdRefs << std::endl + << " numBwdRefs: " + << pConfig->encodeCodecConfig.av1Config.numBwdRefs << std::endl; + if (pConfig->encodeCodecConfig.av1Config.filmGrainParams != NULL) { + os << " NV_ENC_FILM_GRAIN_PARAMS_AV1:" << std::endl + << " applyGrain: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->applyGrain + << std::endl + << " chromaScalingFromLuma: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams + ->chromaScalingFromLuma + << std::endl + << " overlapFlag: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->overlapFlag + << std::endl + << " clipToRestrictedRange: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams + ->clipToRestrictedRange + << std::endl + << " grainScalingMinus8: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams + ->grainScalingMinus8 + << std::endl + << " arCoeffLag: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->arCoeffLag + << std::endl + << " numYPoints: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numYPoints + << std::endl + << " numCbPoints: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numCbPoints + << std::endl + << " numCrPoints: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->numCrPoints + << std::endl + << " arCoeffShiftMinus6: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams + ->arCoeffShiftMinus6 + << std::endl + << " grainScaleShift: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams + ->grainScaleShift + << std::endl + << " cbMult: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbMult + << std::endl + << " cbLumaMult: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbLumaMult + << std::endl + << " cbOffset: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->cbOffset + << std::endl + << " crMult: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crMult + << std::endl + << " crLumaMult: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crLumaMult + << std::endl + << " crOffset: " + << pConfig->encodeCodecConfig.av1Config.filmGrainParams->crOffset + << std::endl; + } + } + + return os.str(); + } }; diff --git a/src/media/video/decode/aom/aom_av1_decoder.cpp b/src/media/video/decode/aom/aom_av1_decoder.cpp index fb3acd7..9e9ba83 100644 --- a/src/media/video/decode/aom/aom_av1_decoder.cpp +++ b/src/media/video/decode/aom/aom_av1_decoder.cpp @@ -2,23 +2,27 @@ #include "log.h" -#define SAVE_RECEIVED_AV1_STREAM 0 -#define SAVE_DECODED_NV12_STREAM 0 +// #define SAVE_DECODED_NV12_STREAM +// #define SAVE_RECEIVED_AV1_STREAM AomAv1Decoder::AomAv1Decoder() {} AomAv1Decoder::~AomAv1Decoder() { - if (SAVE_RECEIVED_AV1_STREAM && file_av1_) { - fflush(file_av1_); - fclose(file_av1_); - file_av1_ = nullptr; - } - - if (SAVE_DECODED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_DECODED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif + +#ifdef SAVE_RECEIVED_AV1_STREAM + if (file_av1_) { + fflush(file_av1_); + fclose(file_av1_); + file_av1_ = nullptr; + } +#endif if (nv12_frame_) { delete nv12_frame_; @@ -43,29 +47,29 @@ int AomAv1Decoder::Init() { aom_codec_control(&aom_av1_decoder_ctx_, AV1D_GET_IMG_FORMAT, AOM_IMG_FMT_NV12); - if (SAVE_RECEIVED_AV1_STREAM) { - file_av1_ = fopen("received_av1_stream.ivf", "w+b"); - if (!file_av1_) { - LOG_WARN("Fail to open received_av1_stream.ivf"); - } +#ifdef SAVE_DECODED_NV12_STREAM + file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_WARN("Fail to open decoded_nv12_stream.yuv"); } +#endif - if (SAVE_DECODED_NV12_STREAM) { - file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_WARN("Fail to open decoded_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_AV1_STREAM + file_av1_ = fopen("received_av1_stream.ivf", "w+b"); + if (!file_av1_) { + LOG_WARN("Fail to open received_av1_stream.ivf"); } +#endif return 0; } int AomAv1Decoder::Decode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_receive_decoded_frame) { - if (SAVE_RECEIVED_AV1_STREAM) { - fwrite((unsigned char *)data, 1, size, file_av1_); - } +#ifdef SAVE_RECEIVED_AV1_STREAM + fwrite((unsigned char *)data, 1, size, file_av1_); +#endif aom_codec_iter_t iter = nullptr; aom_codec_err_t ret = @@ -105,8 +109,8 @@ int AomAv1Decoder::Decode( } } int corrupted = 0; - int ret = aom_codec_control(&aom_av1_decoder_ctx_, AOMD_GET_FRAME_CORRUPTED, - &corrupted); + ret = aom_codec_control(&aom_av1_decoder_ctx_, AOMD_GET_FRAME_CORRUPTED, + &corrupted); if (ret != AOM_CODEC_OK) { LOG_ERROR("Failed to get frame corrupted"); return -1; @@ -140,10 +144,10 @@ int AomAv1Decoder::Decode( on_receive_decoded_frame(*nv12_frame_); - if (SAVE_DECODED_NV12_STREAM) { - fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), - file_nv12_); - } +#ifdef SAVE_DECODED_NV12_STREAM + fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), + file_nv12_); +#endif return 0; } diff --git a/src/media/video/decode/aom/aom_av1_decoder.h b/src/media/video/decode/aom/aom_av1_decoder.h index bcca059..4901582 100644 --- a/src/media/video/decode/aom/aom_av1_decoder.h +++ b/src/media/video/decode/aom/aom_av1_decoder.h @@ -22,7 +22,7 @@ class AomAv1Decoder : public VideoDecoder { public: int Init(); - int Decode(const uint8_t *data, int size, + int Decode(const uint8_t *data, size_t size, std::function on_receive_decoded_frame); std::string GetDecoderName() { return "AomAv1"; } @@ -32,8 +32,8 @@ class AomAv1Decoder : public VideoDecoder { int nv12_frame_capacity_ = 0; int nv12_frame_size_ = 0; - int frame_width_ = 0; - int frame_height_ = 0; + uint32_t frame_width_ = 0; + uint32_t frame_height_ = 0; FILE *file_av1_ = nullptr; FILE *file_nv12_ = nullptr; diff --git a/src/media/video/decode/dav1d/dav1d_av1_decoder.cpp b/src/media/video/decode/dav1d/dav1d_av1_decoder.cpp index 3f73f42..5dd815f 100644 --- a/src/media/video/decode/dav1d/dav1d_av1_decoder.cpp +++ b/src/media/video/decode/dav1d/dav1d_av1_decoder.cpp @@ -2,8 +2,8 @@ #include "log.h" -#define SAVE_RECEIVED_AV1_STREAM 0 -#define SAVE_DECODED_NV12_STREAM 0 +// #define SAVE_DECODED_NV12_STREAM +// #define SAVE_RECEIVED_AV1_STREAM #include "libyuv.h" @@ -28,7 +28,8 @@ class ScopedDav1dData { }; // Calling `dav1d_data_wrap` requires a `free_callback` to be registered. -void NullFreeCallback(const uint8_t *buffer, void *opaque) {} +void NullFreeCallback([[maybe_unused]] const uint8_t *buffer, + [[maybe_unused]] void *opaque) {} void Yuv420pToNv12(unsigned char *SrcY, unsigned char *SrcU, unsigned char *SrcV, int y_stride, int uv_stride, @@ -49,17 +50,21 @@ void Yuv420pToNv12(unsigned char *SrcY, unsigned char *SrcU, Dav1dAv1Decoder::Dav1dAv1Decoder() {} Dav1dAv1Decoder::~Dav1dAv1Decoder() { - if (SAVE_RECEIVED_AV1_STREAM && file_av1_) { - fflush(file_av1_); - fclose(file_av1_); - file_av1_ = nullptr; - } - - if (SAVE_DECODED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_DECODED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif + +#ifdef SAVE_RECEIVED_AV1_STREAM + if (file_av1_) { + fflush(file_av1_); + fclose(file_av1_); + file_av1_ = nullptr; + } +#endif if (nv12_frame_) { delete nv12_frame_; @@ -83,29 +88,29 @@ int Dav1dAv1Decoder::Init() { LOG_ERROR("Dav1d AV1 decoder open failed"); } - if (SAVE_RECEIVED_AV1_STREAM) { - file_av1_ = fopen("received_av1_stream.ivf", "w+b"); - if (!file_av1_) { - LOG_WARN("Fail to open received_av1_stream.ivf"); - } +#ifdef SAVE_DECODED_NV12_STREAM + file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_WARN("Fail to open decoded_nv12_stream.yuv"); } +#endif - if (SAVE_DECODED_NV12_STREAM) { - file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_WARN("Fail to open decoded_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_AV1_STREAM + file_av1_ = fopen("received_av1_stream.ivf", "w+b"); + if (!file_av1_) { + LOG_WARN("Fail to open received_av1_stream.ivf"); } +#endif return 0; } int Dav1dAv1Decoder::Decode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_receive_decoded_frame) { - if (SAVE_RECEIVED_AV1_STREAM) { - fwrite((unsigned char *)data, 1, size, file_av1_); - } +#ifdef SAVE_RECEIVED_AV1_STREAM + fwrite((unsigned char *)data, 1, size, file_av1_); +#endif ScopedDav1dData scoped_dav1d_data; Dav1dData &dav1d_data = scoped_dav1d_data.Data(); @@ -176,14 +181,14 @@ int Dav1dAv1Decoder::Decode( Yuv420pToNv12((unsigned char *)dav1d_picture.data[0], (unsigned char *)dav1d_picture.data[1], (unsigned char *)dav1d_picture.data[2], - dav1d_picture.stride[0], dav1d_picture.stride[1], + (int)dav1d_picture.stride[0], (int)dav1d_picture.stride[1], (unsigned char *)nv12_frame_->Buffer(), frame_width_, frame_height_); } else { libyuv::I420ToNV12( - (const uint8_t *)dav1d_picture.data[0], dav1d_picture.stride[0], - (const uint8_t *)dav1d_picture.data[1], dav1d_picture.stride[1], - (const uint8_t *)dav1d_picture.data[2], dav1d_picture.stride[1], + (const uint8_t *)dav1d_picture.data[0], (int)dav1d_picture.stride[0], + (const uint8_t *)dav1d_picture.data[1], (int)dav1d_picture.stride[1], + (const uint8_t *)dav1d_picture.data[2], (int)dav1d_picture.stride[1], (uint8_t *)nv12_frame_->Buffer(), frame_width_, (uint8_t *)nv12_frame_->Buffer() + frame_width_ * frame_height_, frame_width_, frame_width_, frame_height_); @@ -191,10 +196,10 @@ int Dav1dAv1Decoder::Decode( on_receive_decoded_frame(*nv12_frame_); - if (SAVE_DECODED_NV12_STREAM) { - fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), - file_nv12_); - } +#ifdef SAVE_DECODED_NV12_STREAM + fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), + file_nv12_); +#endif return 0; } \ No newline at end of file diff --git a/src/media/video/decode/dav1d/dav1d_av1_decoder.h b/src/media/video/decode/dav1d/dav1d_av1_decoder.h index e32a4cc..b0dcf0b 100644 --- a/src/media/video/decode/dav1d/dav1d_av1_decoder.h +++ b/src/media/video/decode/dav1d/dav1d_av1_decoder.h @@ -20,18 +20,18 @@ class Dav1dAv1Decoder : public VideoDecoder { public: int Init(); - int Decode(const uint8_t *data, int size, + int Decode(const uint8_t *data, size_t size, std::function on_receive_decoded_frame); std::string GetDecoderName() { return "Dav1dAv1"; } private: VideoFrame *nv12_frame_ = 0; - int nv12_frame_capacity_ = 0; - int nv12_frame_size_ = 0; + size_t nv12_frame_capacity_ = 0; + size_t nv12_frame_size_ = 0; - int frame_width_ = 0; - int frame_height_ = 0; + uint32_t frame_width_ = 0; + uint32_t frame_height_ = 0; FILE *file_av1_ = nullptr; FILE *file_nv12_ = nullptr; diff --git a/src/media/video/decode/nvcodec/nvidia_video_decoder.cpp b/src/media/video/decode/nvcodec/nvidia_video_decoder.cpp index 1ac7303..3735505 100644 --- a/src/media/video/decode/nvcodec/nvidia_video_decoder.cpp +++ b/src/media/video/decode/nvcodec/nvidia_video_decoder.cpp @@ -3,22 +3,26 @@ #include "log.h" #include "nvcodec_api.h" -#define SAVE_RECEIVED_H264_STREAM 0 -#define SAVE_DECODED_NV12_STREAM 0 +// #define SAVE_DECODED_NV12_STREAM +// #define SAVE_RECEIVED_H264_STREAM NvidiaVideoDecoder::NvidiaVideoDecoder() {} NvidiaVideoDecoder::~NvidiaVideoDecoder() { - if (SAVE_RECEIVED_H264_STREAM && file_h264_) { - fflush(file_h264_); - fclose(file_h264_); - file_h264_ = nullptr; - } - - if (SAVE_DECODED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_DECODED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif + +#ifdef SAVE_RECEIVED_H264_STREAM + if (file_h264_) { + fflush(file_h264_); + fclose(file_h264_); + file_h264_ = nullptr; + } +#endif } int NvidiaVideoDecoder::Init() { @@ -42,55 +46,55 @@ int NvidiaVideoDecoder::Init() { decoder = new NvDecoder(cuContext, false, cudaVideoCodec_H264, true); - if (SAVE_RECEIVED_H264_STREAM) { - file_h264_ = fopen("received_h264_stream.h264", "w+b"); - if (!file_h264_) { - LOG_WARN("Fail to open received_h264_stream.h264"); - } +#ifdef SAVE_DECODED_NV12_STREAM + file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_WARN("Fail to open decoded_nv12_stream.yuv"); } +#endif - if (SAVE_DECODED_NV12_STREAM) { - file_nv12_ = fopen("decoded_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_WARN("Fail to open decoded_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_H264_STREAM + file_h264_ = fopen("received_h264_stream.h264", "w+b"); + if (!file_h264_) { + LOG_WARN("Fail to open received_h264_stream.h264"); } +#endif return 0; } int NvidiaVideoDecoder::Decode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_receive_decoded_frame) { if (!decoder) { return -1; } - if (SAVE_RECEIVED_H264_STREAM) { - fwrite((unsigned char *)data, 1, size, file_h264_); - } +#ifdef SAVE_RECEIVED_H264_STREAM + fwrite((unsigned char *)data, 1, size, file_h264_); +#endif if ((*(data + 4) & 0x1f) == 0x07) { // LOG_WARN("Receive key frame"); } - int num_frame_returned = decoder->Decode(data, size); - + int num_frame_returned = decoder->Decode(data, (int)size); for (size_t i = 0; i < num_frame_returned; ++i) { cudaVideoSurfaceFormat format = decoder->GetOutputFormat(); if (format == cudaVideoSurfaceFormat_NV12) { - uint8_t *data = nullptr; - data = decoder->GetFrame(); - if (data) { + uint8_t *decoded_frame_buffer = nullptr; + decoded_frame_buffer = decoder->GetFrame(); + if (decoded_frame_buffer) { if (on_receive_decoded_frame) { VideoFrame decoded_frame( - data, decoder->GetWidth() * decoder->GetHeight() * 3 / 2, + decoded_frame_buffer, + decoder->GetWidth() * decoder->GetHeight() * 3 / 2, decoder->GetWidth(), decoder->GetHeight()); on_receive_decoded_frame(decoded_frame); - if (SAVE_DECODED_NV12_STREAM) { - fwrite((unsigned char *)decoded_frame.Buffer(), 1, - decoded_frame.Size(), file_nv12_); - } +#ifdef SAVE_DECODED_NV12_STREAM + fwrite((unsigned char *)decoded_frame.Buffer(), 1, + decoded_frame.Size(), file_nv12_); +#endif } } } diff --git a/src/media/video/decode/nvcodec/nvidia_video_decoder.h b/src/media/video/decode/nvcodec/nvidia_video_decoder.h index a99affe..82556cf 100644 --- a/src/media/video/decode/nvcodec/nvidia_video_decoder.h +++ b/src/media/video/decode/nvcodec/nvidia_video_decoder.h @@ -14,7 +14,7 @@ class NvidiaVideoDecoder : public VideoDecoder { public: int Init(); - int Decode(const uint8_t* data, int size, + int Decode(const uint8_t* data, size_t size, std::function on_receive_decoded_frame); std::string GetDecoderName() { return "NvidiaH264"; } diff --git a/src/media/video/decode/openh264/openh264_decoder.cpp b/src/media/video/decode/openh264/openh264_decoder.cpp index 998a569..e1879dc 100644 --- a/src/media/video/decode/openh264/openh264_decoder.cpp +++ b/src/media/video/decode/openh264/openh264_decoder.cpp @@ -5,8 +5,8 @@ #include "libyuv.h" #include "log.h" -#define SAVE_NV12_STREAM 0 -#define SAVE_H264_STREAM 0 +// #define SAVE_DECODED_NV12_STREAM +// #define SAVE_RECEIVED_H264_STREAM void CopyYuvWithStride(uint8_t *src_y, uint8_t *src_u, uint8_t *src_v, int width, int height, int stride_y, int stride_u, @@ -65,31 +65,35 @@ OpenH264Decoder::~OpenH264Decoder() { delete[] yuv420p_frame_; } - if (SAVE_H264_STREAM && h264_stream_) { - fflush(h264_stream_); - h264_stream_ = nullptr; - } - - if (SAVE_NV12_STREAM && nv12_stream_) { +#ifdef SAVE_DECODED_NV12_STREAM + if (nv12_stream_) { fflush(nv12_stream_); nv12_stream_ = nullptr; } +#endif + +#ifdef SAVE_RECEIVED_H264_STREAM + if (h264_stream_) { + fflush(h264_stream_); + h264_stream_ = nullptr; + } +#endif } int OpenH264Decoder::Init() { - if (SAVE_NV12_STREAM) { - nv12_stream_ = fopen("nv12_receive_.yuv", "w+b"); - if (!nv12_stream_) { - LOG_WARN("Fail to open nv12_receive_.yuv"); - } +#ifdef SAVE_DECODED_NV12_STREAM + nv12_stream_ = fopen("nv12_receive_.yuv", "w+b"); + if (!nv12_stream_) { + LOG_WARN("Fail to open nv12_receive_.yuv"); } +#endif - if (SAVE_NV12_STREAM) { - h264_stream_ = fopen("h264_receive.h264", "w+b"); - if (!h264_stream_) { - LOG_WARN("Fail to open h264_receive.h264"); - } +#ifdef SAVE_RECEIVED_H264_STREAM + h264_stream_ = fopen("h264_receive.h264", "w+b"); + if (!h264_stream_) { + LOG_WARN("Fail to open h264_receive.h264"); } +#endif frame_width_ = 1280; frame_height_ = 720; @@ -115,15 +119,15 @@ int OpenH264Decoder::Init() { } int OpenH264Decoder::Decode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_receive_decoded_frame) { if (!openh264_decoder_) { return -1; } - if (SAVE_H264_STREAM) { - fwrite((unsigned char *)data, 1, size, h264_stream_); - } +#ifdef SAVE_RECEIVED_H264_STREAM + fwrite((unsigned char *)data, 1, size, h264_stream_); +#endif if ((*(data + 4) & 0x1f) == 0x07) { // LOG_WARN("Receive key frame"); @@ -132,7 +136,7 @@ int OpenH264Decoder::Decode( SBufferInfo sDstBufInfo; memset(&sDstBufInfo, 0, sizeof(SBufferInfo)); - openh264_decoder_->DecodeFrameNoDelay(data, size, yuv420p_planes_, + openh264_decoder_->DecodeFrameNoDelay(data, (int)size, yuv420p_planes_, &sDstBufInfo); frame_width_ = sDstBufInfo.UsrData.sSystemBuffer.iWidth; @@ -200,10 +204,10 @@ int OpenH264Decoder::Decode( on_receive_decoded_frame(*nv12_frame_); - if (SAVE_NV12_STREAM) { - fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), - nv12_stream_); - } +#ifdef SAVE_DECODED_NV12_STREAM + fwrite((unsigned char *)nv12_frame_->Buffer(), 1, nv12_frame_->Size(), + nv12_stream_); +#endif } } diff --git a/src/media/video/decode/openh264/openh264_decoder.h b/src/media/video/decode/openh264/openh264_decoder.h index a979113..4eaf94d 100644 --- a/src/media/video/decode/openh264/openh264_decoder.h +++ b/src/media/video/decode/openh264/openh264_decoder.h @@ -24,7 +24,7 @@ class OpenH264Decoder : public VideoDecoder { public: int Init(); - int Decode(const uint8_t* data, int size, + int Decode(const uint8_t* data, size_t size, std::function on_receive_decoded_frame); std::string GetDecoderName() { return "OpenH264"; } @@ -37,8 +37,8 @@ class OpenH264Decoder : public VideoDecoder { FILE* h264_stream_ = nullptr; uint8_t* decoded_frame_ = nullptr; int decoded_frame_size_ = 0; - int frame_width_ = 1280; - int frame_height_ = 720; + uint32_t frame_width_ = 1280; + uint32_t frame_height_ = 720; unsigned char* yuv420p_planes_[3] = {nullptr, nullptr, nullptr}; unsigned char* yuv420p_frame_ = nullptr; diff --git a/src/media/video/decode/video_decoder.h b/src/media/video/decode/video_decoder.h index 7b5fce6..1a0165a 100644 --- a/src/media/video/decode/video_decoder.h +++ b/src/media/video/decode/video_decoder.h @@ -20,7 +20,7 @@ class VideoDecoder { virtual int Init() = 0; virtual int Decode( - const uint8_t *data, int size, + const uint8_t *data, size_t size, std::function on_receive_decoded_frame) = 0; virtual std::string GetDecoderName() = 0; diff --git a/src/media/video/encode/aom/aom_av1_encoder.cpp b/src/media/video/encode/aom/aom_av1_encoder.cpp index c990fef..42cea10 100644 --- a/src/media/video/encode/aom/aom_av1_encoder.cpp +++ b/src/media/video/encode/aom/aom_av1_encoder.cpp @@ -5,8 +5,8 @@ #include "log.h" -#define SAVE_RECEIVED_NV12_STREAM 0 -#define SAVE_ENCODED_AV1_STREAM 0 +// #define SAVE_RECEIVED_NV12_STREAM +// #define SAVE_ENCODED_AV1_STREAM #define SET_ENCODER_PARAM_OR_RETURN_ERROR(param_id, param_value) \ do { \ @@ -104,17 +104,21 @@ int AomAv1Encoder::ResetEncodeResolution(unsigned int width, AomAv1Encoder::AomAv1Encoder() {} AomAv1Encoder::~AomAv1Encoder() { - if (SAVE_RECEIVED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_RECEIVED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif - if (SAVE_ENCODED_AV1_STREAM && file_av1_) { +#ifdef SAVE_ENCODED_AV1_STREAM + if (file_av1_) { fflush(file_av1_); fclose(file_av1_); file_av1_ = nullptr; } +#endif delete[] encoded_frame_; encoded_frame_ = nullptr; @@ -245,19 +249,19 @@ int AomAv1Encoder::Init() { frame_for_encode_ = aom_img_wrap(nullptr, AOM_IMG_FMT_NV12, frame_width_, frame_height_, 1, nullptr); - if (SAVE_RECEIVED_NV12_STREAM) { - file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_ERROR("Fail to open received_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_ERROR("Fail to open received_nv12_stream.yuv"); } +#endif - if (SAVE_ENCODED_AV1_STREAM) { - file_av1_ = fopen("encoded_av1_stream.ivf", "w+b"); - if (!file_av1_) { - LOG_ERROR("Fail to open encoded_av1_stream.ivf"); - } +#ifdef SAVE_ENCODED_AV1_STREAM + file_av1_ = fopen("encoded_av1_stream.ivf", "w+b"); + if (!file_av1_) { + LOG_ERROR("Fail to open encoded_av1_stream.ivf"); } +#endif return 0; } @@ -266,9 +270,9 @@ int AomAv1Encoder::Encode(const XVideoFrame *video_frame, std::function on_encoded_image) { - if (SAVE_RECEIVED_NV12_STREAM) { - fwrite(video_frame->data, 1, video_frame->size, file_nv12_); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + fwrite(video_frame->data, 1, video_frame->size, file_nv12_); +#endif aom_codec_err_t ret = AOM_CODEC_OK; @@ -293,7 +297,7 @@ int AomAv1Encoder::Encode(const XVideoFrame *video_frame, } const uint32_t duration = - kRtpTicksPerSecond / static_cast(max_frame_rate_); + (uint32_t)(kRtpTicksPerSecond / static_cast(max_frame_rate_)); timestamp_ += duration; frame_for_encode_->planes[AOM_PLANE_Y] = (unsigned char *)(video_frame->data); @@ -327,7 +331,6 @@ int AomAv1Encoder::Encode(const XVideoFrame *video_frame, } aom_codec_iter_t iter = nullptr; - int data_pkt_count = 0; while (const aom_codec_cx_pkt_t *pkt = aom_codec_get_cx_data(&aom_av1_encoder_ctx_, &iter)) { if (pkt->kind == AOM_CODEC_CX_FRAME_PKT && pkt->data.frame.sz > 0) { @@ -341,11 +344,9 @@ int AomAv1Encoder::Encode(const XVideoFrame *video_frame, if (on_encoded_image) { on_encoded_image((char *)encoded_frame_, encoded_frame_size_, frame_type); - if (SAVE_ENCODED_AV1_STREAM) { - fwrite(encoded_frame_, 1, encoded_frame_size_, file_av1_); - } - } else { - OnEncodedImage((char *)encoded_frame_, encoded_frame_size_); +#ifdef SAVE_ENCODED_AV1_STREAM + fwrite(encoded_frame_, 1, encoded_frame_size_, file_av1_); +#endif } } } @@ -353,11 +354,6 @@ int AomAv1Encoder::Encode(const XVideoFrame *video_frame, return 0; } -int AomAv1Encoder::OnEncodedImage(char *encoded_packets, size_t size) { - LOG_INFO("OnEncodedImage not implemented"); - return 0; -} - int AomAv1Encoder::ForceIdr() { force_i_frame_flags_ = AOM_EFLAG_FORCE_KF; return 0; diff --git a/src/media/video/encode/aom/aom_av1_encoder.h b/src/media/video/encode/aom/aom_av1_encoder.h index 226114d..191d320 100644 --- a/src/media/video/encode/aom/aom_av1_encoder.h +++ b/src/media/video/encode/aom/aom_av1_encoder.h @@ -36,20 +36,12 @@ class AomAv1Encoder : public VideoEncoder { public: int Init(); - int Encode(const uint8_t* pData, int nSize, - std::function - on_encoded_image) { - return 0; - } int Encode(const XVideoFrame* video_frame, std::function on_encoded_image); - int OnEncodedImage(char* encoded_packets, size_t size); - int ForceIdr(); std::string GetEncoderName() { return "AomAV1"; } @@ -65,8 +57,8 @@ class AomAv1Encoder : public VideoEncoder { int Release(); private: - int frame_width_ = 1280; - int frame_height_ = 720; + uint32_t frame_width_ = 1280; + uint32_t frame_height_ = 720; int key_frame_interval_ = 300; int target_bitrate_ = 1000; int max_bitrate_ = 2500000; @@ -91,7 +83,7 @@ class AomAv1Encoder : public VideoEncoder { aom_enc_frame_flags_t force_i_frame_flags_ = 0; uint8_t* encoded_frame_ = nullptr; size_t encoded_frame_capacity_ = 0; - int encoded_frame_size_ = 0; + size_t encoded_frame_size_ = 0; }; #endif \ No newline at end of file diff --git a/src/media/video/encode/nvcodec/nvidia_video_encoder.cpp b/src/media/video/encode/nvcodec/nvidia_video_encoder.cpp index 3887051..987e351 100644 --- a/src/media/video/encode/nvcodec/nvidia_video_encoder.cpp +++ b/src/media/video/encode/nvcodec/nvidia_video_encoder.cpp @@ -6,22 +6,26 @@ #include "nvcodec_api.h" #include "nvcodec_common.h" -#define SAVE_RECEIVED_NV12_STREAM 0 -#define SAVE_ENCODED_H264_STREAM 0 +// #define SAVE_RECEIVED_NV12_STREAM +// #define SAVE_ENCODED_H264_STREAM NvidiaVideoEncoder::NvidiaVideoEncoder() {} NvidiaVideoEncoder::~NvidiaVideoEncoder() { - if (SAVE_RECEIVED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_RECEIVED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif - if (SAVE_ENCODED_H264_STREAM && file_h264_) { +#ifdef SAVE_ENCODED_H264_STREAM + if (file_h264_) { fflush(file_h264_); fclose(file_h264_); file_h264_ = nullptr; } +#endif if (nv12_data_) { free(nv12_data_); @@ -106,19 +110,20 @@ int NvidiaVideoEncoder::Init() { encoder_->CreateEncoder(&init_params); - if (SAVE_RECEIVED_NV12_STREAM) { - file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_WARN("Fail to open received_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_WARN("Fail to open received_nv12_stream.yuv"); } - if (SAVE_ENCODED_H264_STREAM) { - file_h264_ = fopen("encoded_h264_stream.h264", "w+b"); - if (!file_h264_) { - LOG_WARN("Fail to open encoded_h264_stream.h264"); - } +#endif + +#ifdef SAVE_ENCODED_H264_STREAM + file_h264_ = fopen("encoded_h264_stream.h264", "w+b"); + if (!file_h264_) { + LOG_WARN("Fail to open encoded_h264_stream.h264"); } +#endif return 0; } @@ -133,9 +138,9 @@ int NvidiaVideoEncoder::Encode( return -1; } - if (SAVE_RECEIVED_NV12_STREAM) { - fwrite(video_frame->data, 1, video_frame->size, file_nv12_); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + fwrite(video_frame->data, 1, video_frame->size, file_nv12_); +#endif if (video_frame->width != frame_width_ || video_frame->height != frame_height_) { @@ -178,11 +183,9 @@ int NvidiaVideoEncoder::Encode( for (const auto &packet : encoded_packets_) { if (on_encoded_image) { on_encoded_image((char *)packet.data(), packet.size(), frame_type); - if (SAVE_ENCODED_H264_STREAM) { - fwrite((unsigned char *)packet.data(), 1, packet.size(), file_h264_); - } - } else { - OnEncodedImage((char *)packet.data(), packet.size()); +#ifdef SAVE_ENCODED_H264_STREAM + fwrite((unsigned char *)packet.data(), 1, packet.size(), file_h264_); +#endif } } @@ -196,11 +199,6 @@ int NvidiaVideoEncoder::Encode( return 0; } -int NvidiaVideoEncoder::OnEncodedImage(char *encoded_packets, size_t size) { - LOG_INFO("OnEncodedImage not implemented"); - return 0; -} - int NvidiaVideoEncoder::ForceIdr() { if (!encoder_) { return -1; diff --git a/src/media/video/encode/nvcodec/nvidia_video_encoder.h b/src/media/video/encode/nvcodec/nvidia_video_encoder.h index cc9e82f..8827d01 100644 --- a/src/media/video/encode/nvcodec/nvidia_video_encoder.h +++ b/src/media/video/encode/nvcodec/nvidia_video_encoder.h @@ -12,20 +12,12 @@ class NvidiaVideoEncoder : public VideoEncoder { virtual ~NvidiaVideoEncoder(); int Init(); - int Encode(const uint8_t* pData, int nSize, - std::function - on_encoded_image) { - return 0; - } int Encode(const XVideoFrame* video_frame, std::function on_encoded_image); - virtual int OnEncodedImage(char* encoded_packets, size_t size); - int ForceIdr(); std::string GetEncoderName() { return "NvidiaH264"; } diff --git a/src/media/video/encode/openh264/openh264_encoder.cpp b/src/media/video/encode/openh264/openh264_encoder.cpp index b34b708..eb8f425 100644 --- a/src/media/video/encode/openh264/openh264_encoder.cpp +++ b/src/media/video/encode/openh264/openh264_encoder.cpp @@ -5,17 +5,15 @@ #include "libyuv.h" #include "log.h" -#define SAVE_RECEIVED_NV12_STREAM 0 -#define SAVE_ENCODED_H264_STREAM 0 +// #define SAVE_RECEIVED_NV12_STREAM +// #define SAVE_ENCODED_H264_STREAM void Nv12ToI420(unsigned char *Src_data, int src_width, int src_height, unsigned char *Dst_data) { - // NV12 video size - int NV12_Size = src_width * src_height * 3 / 2; + // NV12 int NV12_Y_Size = src_width * src_height; - // YUV420 video size - int I420_Size = src_width * src_height * 3 / 2; + // YUV420 int I420_Y_Size = src_width * src_height; int I420_U_Size = (src_width >> 1) * (src_height >> 1); int I420_V_Size = I420_U_Size; @@ -29,7 +27,7 @@ void Nv12ToI420(unsigned char *Src_data, int src_width, int src_height, // dst: buffer address of Y channel、U channel and V channel unsigned char *Y_data_Dst = Dst_data; unsigned char *U_data_Dst = Dst_data + I420_Y_Size; - unsigned char *V_data_Dst = Dst_data + I420_Y_Size + I420_U_Size; + unsigned char *V_data_Dst = Dst_data + I420_Y_Size + I420_V_Size; int Dst_Stride_Y = src_width; int Dst_Stride_U = src_width >> 1; int Dst_Stride_V = Dst_Stride_U; @@ -43,17 +41,21 @@ void Nv12ToI420(unsigned char *Src_data, int src_width, int src_height, OpenH264Encoder::OpenH264Encoder() {} OpenH264Encoder::~OpenH264Encoder() { - if (SAVE_RECEIVED_NV12_STREAM && file_nv12_) { +#ifdef SAVE_RECEIVED_NV12_STREAM + if (file_nv12_) { fflush(file_nv12_); fclose(file_nv12_); file_nv12_ = nullptr; } +#endif - if (SAVE_ENCODED_H264_STREAM && file_h264_) { +#ifdef SAVE_ENCODED_H264_STREAM + if (file_h264_) { fflush(file_h264_); fclose(file_h264_); file_h264_ = nullptr; } +#endif if (yuv420p_frame_) { delete[] yuv420p_frame_; @@ -160,19 +162,19 @@ int OpenH264Encoder::Init() { video_format_ = EVideoFormatType::videoFormatI420; openh264_encoder_->SetOption(ENCODER_OPTION_DATAFORMAT, &video_format_); - if (SAVE_RECEIVED_NV12_STREAM) { - file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); - if (!file_nv12_) { - LOG_WARN("Fail to open received_nv12_stream.yuv"); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + file_nv12_ = fopen("received_nv12_stream.yuv", "w+b"); + if (!file_nv12_) { + LOG_WARN("Fail to open received_nv12_stream.yuv"); } +#endif - if (SAVE_ENCODED_H264_STREAM) { - file_h264_ = fopen("encoded_h264_stream.h264", "w+b"); - if (!file_h264_) { - LOG_WARN("Fail to open encoded_h264_stream.h264"); - } +#ifdef SAVE_ENCODED_H264_STREAM + file_h264_ = fopen("encoded_h264_stream.h264", "w+b"); + if (!file_h264_) { + LOG_WARN("Fail to open encoded_h264_stream.h264"); } +#endif return 0; } @@ -187,9 +189,9 @@ int OpenH264Encoder::Encode( return -1; } - if (SAVE_RECEIVED_NV12_STREAM) { - fwrite(video_frame->data, 1, video_frame->size, file_nv12_); - } +#ifdef SAVE_RECEIVED_NV12_STREAM + fwrite(video_frame->data, 1, video_frame->size, file_nv12_); +#endif if (!yuv420p_frame_) { yuv420p_frame_capacity_ = video_frame->size; @@ -267,7 +269,7 @@ int OpenH264Encoder::Encode( } size_t frag = 0; - int encoded_frame_size = 0; + size_t encoded_frame_size = 0; for (int layer = 0; layer < info.iLayerNum; ++layer) { const SLayerBSInfo &layerInfo = info.sLayerInfo[layer]; size_t layer_len = 0; @@ -281,11 +283,9 @@ int OpenH264Encoder::Encode( if (on_encoded_image) { on_encoded_image((char *)encoded_frame_, encoded_frame_size_, frame_type); - if (SAVE_ENCODED_H264_STREAM) { - fwrite(encoded_frame_, 1, encoded_frame_size_, file_h264_); - } - } else { - OnEncodedImage((char *)encoded_frame_, encoded_frame_size_); +#ifdef SAVE_ENCODED_H264_STREAM + fwrite(encoded_frame_, 1, encoded_frame_size_, file_h264_); +#endif } #else if (info.eFrameType == videoFrameTypeInvalid) { @@ -327,11 +327,9 @@ int OpenH264Encoder::Encode( if (on_encoded_image) { on_encoded_image((char *)encoded_frame_, frame_type); - if (SAVE_ENCODED_H264_STREAM) { - fwrite(encoded_frame_, 1, encoded_frame_size_, file_h264_); - } - } else { - OnEncodedImage((char *)encoded_frame_, encoded_frame_size_); +#ifdef SAVE_ENCODED_H264_STREAM + fwrite(encoded_frame_, 1, encoded_frame_size_, file_h264_); +#endif } EVideoFrameType ft_temp = info.eFrameType; @@ -353,11 +351,6 @@ int OpenH264Encoder::Encode( return 0; } -int OpenH264Encoder::OnEncodedImage(char *encoded_packets, size_t size) { - LOG_INFO("OnEncodedImage not implemented"); - return 0; -} - int OpenH264Encoder::ForceIdr() { if (openh264_encoder_) { return openh264_encoder_->ForceIntraFrame(true); diff --git a/src/media/video/encode/openh264/openh264_encoder.h b/src/media/video/encode/openh264/openh264_encoder.h index 72a52d4..9e61b73 100644 --- a/src/media/video/encode/openh264/openh264_encoder.h +++ b/src/media/video/encode/openh264/openh264_encoder.h @@ -23,20 +23,12 @@ class OpenH264Encoder : public VideoEncoder { virtual ~OpenH264Encoder(); int Init(); - int Encode(const uint8_t* pData, int nSize, - std::function - on_encoded_image) { - return 0; - } int Encode(const XVideoFrame* video_frame, std::function on_encoded_image); - int OnEncodedImage(char* encoded_packets, size_t size); - int ForceIdr(); std::string GetEncoderName() { return "OpenH264"; } @@ -48,8 +40,8 @@ class OpenH264Encoder : public VideoEncoder { int Release(); private: - int frame_width_ = 1280; - int frame_height_ = 720; + uint32_t frame_width_ = 1280; + uint32_t frame_height_ = 720; int key_frame_interval_ = 300; int target_bitrate_ = 10000000; int max_bitrate_ = 10000000; @@ -68,10 +60,10 @@ class OpenH264Encoder : public VideoEncoder { int video_format_; SSourcePicture raw_frame_; unsigned char* yuv420p_frame_ = nullptr; - int yuv420p_frame_capacity_ = 0; + size_t yuv420p_frame_capacity_ = 0; uint8_t* encoded_frame_ = nullptr; - int encoded_frame_capacity_ = 0; - int encoded_frame_size_ = 0; + size_t encoded_frame_capacity_ = 0; + size_t encoded_frame_size_ = 0; bool got_output = false; bool is_keyframe = false; int temporal_ = 1; diff --git a/src/media/video/encode/video_encoder.h b/src/media/video/encode/video_encoder.h index 8d451ff..6041b93 100644 --- a/src/media/video/encode/video_encoder.h +++ b/src/media/video/encode/video_encoder.h @@ -20,18 +20,11 @@ class VideoEncoder { public: virtual int Init() = 0; - virtual int Encode(const uint8_t* pData, int nSize, - std::function - on_encoded_image) = 0; - virtual int Encode(const XVideoFrame* video_frame, std::function on_encoded_image) = 0; - virtual int OnEncodedImage(char* encoded_packets, size_t size) = 0; - virtual int ForceIdr() = 0; virtual std::string GetEncoderName() = 0; diff --git a/src/qos/kcp/ikcp.c b/src/qos/kcp/ikcp.c deleted file mode 100644 index 6e14bc9..0000000 --- a/src/qos/kcp/ikcp.c +++ /dev/null @@ -1,1306 +0,0 @@ -//===================================================================== -// -// KCP - A Better ARQ Protocol Implementation -// skywind3000 (at) gmail.com, 2010-2011 -// -// Features: -// + Average RTT reduce 30% - 40% vs traditional ARQ like tcp. -// + Maximum RTT reduce three times vs tcp. -// + Lightweight, distributed as a single source file. -// -//===================================================================== -#include "ikcp.h" - -#include -#include -#include -#include -#include - - - -//===================================================================== -// KCP BASIC -//===================================================================== -const IUINT32 IKCP_RTO_NDL = 30; // no delay min rto -const IUINT32 IKCP_RTO_MIN = 100; // normal min rto -const IUINT32 IKCP_RTO_DEF = 200; -const IUINT32 IKCP_RTO_MAX = 60000; -const IUINT32 IKCP_CMD_PUSH = 81; // cmd: push data -const IUINT32 IKCP_CMD_ACK = 82; // cmd: ack -const IUINT32 IKCP_CMD_WASK = 83; // cmd: window probe (ask) -const IUINT32 IKCP_CMD_WINS = 84; // cmd: window size (tell) -const IUINT32 IKCP_ASK_SEND = 1; // need to send IKCP_CMD_WASK -const IUINT32 IKCP_ASK_TELL = 2; // need to send IKCP_CMD_WINS -const IUINT32 IKCP_WND_SND = 32; -const IUINT32 IKCP_WND_RCV = 128; // must >= max fragment size -const IUINT32 IKCP_MTU_DEF = 1400; -const IUINT32 IKCP_ACK_FAST = 3; -const IUINT32 IKCP_INTERVAL = 100; -const IUINT32 IKCP_OVERHEAD = 24; -const IUINT32 IKCP_DEADLINK = 20; -const IUINT32 IKCP_THRESH_INIT = 2; -const IUINT32 IKCP_THRESH_MIN = 2; -const IUINT32 IKCP_PROBE_INIT = 7000; // 7 secs to probe window size -const IUINT32 IKCP_PROBE_LIMIT = 120000; // up to 120 secs to probe window -const IUINT32 IKCP_FASTACK_LIMIT = 5; // max times to trigger fastack - - -//--------------------------------------------------------------------- -// encode / decode -//--------------------------------------------------------------------- - -/* encode 8 bits unsigned int */ -static inline char *ikcp_encode8u(char *p, unsigned char c) -{ - *(unsigned char*)p++ = c; - return p; -} - -/* decode 8 bits unsigned int */ -static inline const char *ikcp_decode8u(const char *p, unsigned char *c) -{ - *c = *(unsigned char*)p++; - return p; -} - -/* encode 16 bits unsigned int (lsb) */ -static inline char *ikcp_encode16u(char *p, unsigned short w) -{ -#if IWORDS_BIG_ENDIAN || IWORDS_MUST_ALIGN - *(unsigned char*)(p + 0) = (w & 255); - *(unsigned char*)(p + 1) = (w >> 8); -#else - memcpy(p, &w, 2); -#endif - p += 2; - return p; -} - -/* decode 16 bits unsigned int (lsb) */ -static inline const char *ikcp_decode16u(const char *p, unsigned short *w) -{ -#if IWORDS_BIG_ENDIAN || IWORDS_MUST_ALIGN - *w = *(const unsigned char*)(p + 1); - *w = *(const unsigned char*)(p + 0) + (*w << 8); -#else - memcpy(w, p, 2); -#endif - p += 2; - return p; -} - -/* encode 32 bits unsigned int (lsb) */ -static inline char *ikcp_encode32u(char *p, IUINT32 l) -{ -#if IWORDS_BIG_ENDIAN || IWORDS_MUST_ALIGN - *(unsigned char*)(p + 0) = (unsigned char)((l >> 0) & 0xff); - *(unsigned char*)(p + 1) = (unsigned char)((l >> 8) & 0xff); - *(unsigned char*)(p + 2) = (unsigned char)((l >> 16) & 0xff); - *(unsigned char*)(p + 3) = (unsigned char)((l >> 24) & 0xff); -#else - memcpy(p, &l, 4); -#endif - p += 4; - return p; -} - -/* decode 32 bits unsigned int (lsb) */ -static inline const char *ikcp_decode32u(const char *p, IUINT32 *l) -{ -#if IWORDS_BIG_ENDIAN || IWORDS_MUST_ALIGN - *l = *(const unsigned char*)(p + 3); - *l = *(const unsigned char*)(p + 2) + (*l << 8); - *l = *(const unsigned char*)(p + 1) + (*l << 8); - *l = *(const unsigned char*)(p + 0) + (*l << 8); -#else - memcpy(l, p, 4); -#endif - p += 4; - return p; -} - -static inline IUINT32 _imin_(IUINT32 a, IUINT32 b) { - return a <= b ? a : b; -} - -static inline IUINT32 _imax_(IUINT32 a, IUINT32 b) { - return a >= b ? a : b; -} - -static inline IUINT32 _ibound_(IUINT32 lower, IUINT32 middle, IUINT32 upper) -{ - return _imin_(_imax_(lower, middle), upper); -} - -static inline long _itimediff(IUINT32 later, IUINT32 earlier) -{ - return ((IINT32)(later - earlier)); -} - -//--------------------------------------------------------------------- -// manage segment -//--------------------------------------------------------------------- -typedef struct IKCPSEG IKCPSEG; - -static void* (*ikcp_malloc_hook)(size_t) = NULL; -static void (*ikcp_free_hook)(void *) = NULL; - -// internal malloc -static void* ikcp_malloc(size_t size) { - if (ikcp_malloc_hook) - return ikcp_malloc_hook(size); - return malloc(size); -} - -// internal free -static void ikcp_free(void *ptr) { - if (ikcp_free_hook) { - ikcp_free_hook(ptr); - } else { - free(ptr); - } -} - -// redefine allocator -void ikcp_allocator(void* (*new_malloc)(size_t), void (*new_free)(void*)) -{ - ikcp_malloc_hook = new_malloc; - ikcp_free_hook = new_free; -} - -// allocate a new kcp segment -static IKCPSEG* ikcp_segment_new(ikcpcb *kcp, int size) -{ - return (IKCPSEG*)ikcp_malloc(sizeof(IKCPSEG) + size); -} - -// delete a segment -static void ikcp_segment_delete(ikcpcb *kcp, IKCPSEG *seg) -{ - ikcp_free(seg); -} - -// write log -void ikcp_log(ikcpcb *kcp, int mask, const char *fmt, ...) -{ - char buffer[1024]; - va_list argptr; - if ((mask & kcp->logmask) == 0 || kcp->writelog == 0) return; - va_start(argptr, fmt); - vsprintf(buffer, fmt, argptr); - va_end(argptr); - kcp->writelog(buffer, kcp, kcp->user); -} - -// check log mask -static int ikcp_canlog(const ikcpcb *kcp, int mask) -{ - if ((mask & kcp->logmask) == 0 || kcp->writelog == NULL) return 0; - return 1; -} - -// output segment -static int ikcp_output(ikcpcb *kcp, const void *data, int size) -{ - assert(kcp); - assert(kcp->output); - if (ikcp_canlog(kcp, IKCP_LOG_OUTPUT)) { - ikcp_log(kcp, IKCP_LOG_OUTPUT, "[RO] %ld bytes", (long)size); - } - if (size == 0) return 0; - return kcp->output((const char*)data, size, kcp, kcp->user); -} - -// output queue -void ikcp_qprint(const char *name, const struct IQUEUEHEAD *head) -{ -#if 0 - const struct IQUEUEHEAD *p; - printf("<%s>: [", name); - for (p = head->next; p != head; p = p->next) { - const IKCPSEG *seg = iqueue_entry(p, const IKCPSEG, node); - printf("(%lu %d)", (unsigned long)seg->sn, (int)(seg->ts % 10000)); - if (p->next != head) printf(","); - } - printf("]\n"); -#endif -} - - -//--------------------------------------------------------------------- -// create a new kcpcb -//--------------------------------------------------------------------- -ikcpcb* ikcp_create(IUINT32 conv, void *user) -{ - ikcpcb *kcp = (ikcpcb*)ikcp_malloc(sizeof(struct IKCPCB)); - if (kcp == NULL) return NULL; - kcp->conv = conv; - kcp->user = user; - kcp->snd_una = 0; - kcp->snd_nxt = 0; - kcp->rcv_nxt = 0; - kcp->ts_recent = 0; - kcp->ts_lastack = 0; - kcp->ts_probe = 0; - kcp->probe_wait = 0; - kcp->snd_wnd = IKCP_WND_SND; - kcp->rcv_wnd = IKCP_WND_RCV; - kcp->rmt_wnd = IKCP_WND_RCV; - kcp->cwnd = 0; - kcp->incr = 0; - kcp->probe = 0; - kcp->mtu = IKCP_MTU_DEF; - kcp->mss = kcp->mtu - IKCP_OVERHEAD; - kcp->stream = 0; - - kcp->buffer = (char*)ikcp_malloc((kcp->mtu + IKCP_OVERHEAD) * 3); - if (kcp->buffer == NULL) { - ikcp_free(kcp); - return NULL; - } - - iqueue_init(&kcp->snd_queue); - iqueue_init(&kcp->rcv_queue); - iqueue_init(&kcp->snd_buf); - iqueue_init(&kcp->rcv_buf); - kcp->nrcv_buf = 0; - kcp->nsnd_buf = 0; - kcp->nrcv_que = 0; - kcp->nsnd_que = 0; - kcp->state = 0; - kcp->acklist = NULL; - kcp->ackblock = 0; - kcp->ackcount = 0; - kcp->rx_srtt = 0; - kcp->rx_rttval = 0; - kcp->rx_rto = IKCP_RTO_DEF; - kcp->rx_minrto = IKCP_RTO_MIN; - kcp->current = 0; - kcp->interval = IKCP_INTERVAL; - kcp->ts_flush = IKCP_INTERVAL; - kcp->nodelay = 0; - kcp->updated = 0; - kcp->logmask = 0; - kcp->ssthresh = IKCP_THRESH_INIT; - kcp->fastresend = 0; - kcp->fastlimit = IKCP_FASTACK_LIMIT; - kcp->nocwnd = 0; - kcp->xmit = 0; - kcp->dead_link = IKCP_DEADLINK; - kcp->output = NULL; - kcp->writelog = NULL; - - return kcp; -} - - -//--------------------------------------------------------------------- -// release a new kcpcb -//--------------------------------------------------------------------- -void ikcp_release(ikcpcb *kcp) -{ - assert(kcp); - if (kcp) { - IKCPSEG *seg; - while (!iqueue_is_empty(&kcp->snd_buf)) { - seg = iqueue_entry(kcp->snd_buf.next, IKCPSEG, node); - iqueue_del(&seg->node); - ikcp_segment_delete(kcp, seg); - } - while (!iqueue_is_empty(&kcp->rcv_buf)) { - seg = iqueue_entry(kcp->rcv_buf.next, IKCPSEG, node); - iqueue_del(&seg->node); - ikcp_segment_delete(kcp, seg); - } - while (!iqueue_is_empty(&kcp->snd_queue)) { - seg = iqueue_entry(kcp->snd_queue.next, IKCPSEG, node); - iqueue_del(&seg->node); - ikcp_segment_delete(kcp, seg); - } - while (!iqueue_is_empty(&kcp->rcv_queue)) { - seg = iqueue_entry(kcp->rcv_queue.next, IKCPSEG, node); - iqueue_del(&seg->node); - ikcp_segment_delete(kcp, seg); - } - if (kcp->buffer) { - ikcp_free(kcp->buffer); - } - if (kcp->acklist) { - ikcp_free(kcp->acklist); - } - - kcp->nrcv_buf = 0; - kcp->nsnd_buf = 0; - kcp->nrcv_que = 0; - kcp->nsnd_que = 0; - kcp->ackcount = 0; - kcp->buffer = NULL; - kcp->acklist = NULL; - ikcp_free(kcp); - } -} - - -//--------------------------------------------------------------------- -// set output callback, which will be invoked by kcp -//--------------------------------------------------------------------- -void ikcp_setoutput(ikcpcb *kcp, int (*output)(const char *buf, int len, - ikcpcb *kcp, void *user)) -{ - kcp->output = output; -} - - -//--------------------------------------------------------------------- -// user/upper level recv: returns size, returns below zero for EAGAIN -//--------------------------------------------------------------------- -int ikcp_recv(ikcpcb *kcp, char *buffer, int len) -{ - struct IQUEUEHEAD *p; - int ispeek = (len < 0)? 1 : 0; - int peeksize; - int recover = 0; - IKCPSEG *seg; - assert(kcp); - - if (iqueue_is_empty(&kcp->rcv_queue)) - return -1; - - if (len < 0) len = -len; - - peeksize = ikcp_peeksize(kcp); - - if (peeksize < 0) - return -2; - - if (peeksize > len) - return -3; - - if (kcp->nrcv_que >= kcp->rcv_wnd) - recover = 1; - - // merge fragment - for (len = 0, p = kcp->rcv_queue.next; p != &kcp->rcv_queue; ) { - int fragment; - seg = iqueue_entry(p, IKCPSEG, node); - p = p->next; - - if (buffer) { - memcpy(buffer, seg->data, seg->len); - buffer += seg->len; - } - - len += seg->len; - fragment = seg->frg; - - if (ikcp_canlog(kcp, IKCP_LOG_RECV)) { - ikcp_log(kcp, IKCP_LOG_RECV, "recv sn=%lu", (unsigned long)seg->sn); - } - - if (ispeek == 0) { - iqueue_del(&seg->node); - ikcp_segment_delete(kcp, seg); - kcp->nrcv_que--; - } - - if (fragment == 0) - break; - } - - assert(len == peeksize); - - // move available data from rcv_buf -> rcv_queue - while (! iqueue_is_empty(&kcp->rcv_buf)) { - seg = iqueue_entry(kcp->rcv_buf.next, IKCPSEG, node); - if (seg->sn == kcp->rcv_nxt && kcp->nrcv_que < kcp->rcv_wnd) { - iqueue_del(&seg->node); - kcp->nrcv_buf--; - iqueue_add_tail(&seg->node, &kcp->rcv_queue); - kcp->nrcv_que++; - kcp->rcv_nxt++; - } else { - break; - } - } - - // fast recover - if (kcp->nrcv_que < kcp->rcv_wnd && recover) { - // ready to send back IKCP_CMD_WINS in ikcp_flush - // tell remote my window size - kcp->probe |= IKCP_ASK_TELL; - } - - return len; -} - - -//--------------------------------------------------------------------- -// peek data size -//--------------------------------------------------------------------- -int ikcp_peeksize(const ikcpcb *kcp) -{ - struct IQUEUEHEAD *p; - IKCPSEG *seg; - int length = 0; - - assert(kcp); - - if (iqueue_is_empty(&kcp->rcv_queue)) return -1; - - seg = iqueue_entry(kcp->rcv_queue.next, IKCPSEG, node); - if (seg->frg == 0) return seg->len; - - if (kcp->nrcv_que < seg->frg + 1) return -1; - - for (p = kcp->rcv_queue.next; p != &kcp->rcv_queue; p = p->next) { - seg = iqueue_entry(p, IKCPSEG, node); - length += seg->len; - if (seg->frg == 0) break; - } - - return length; -} - - -//--------------------------------------------------------------------- -// user/upper level send, returns below zero for error -//--------------------------------------------------------------------- -int ikcp_send(ikcpcb *kcp, const char *buffer, int len) -{ - IKCPSEG *seg; - int count, i; - int sent = 0; - - assert(kcp->mss > 0); - if (len < 0) return -1; - - // append to previous segment in streaming mode (if possible) - if (kcp->stream != 0) { - if (!iqueue_is_empty(&kcp->snd_queue)) { - IKCPSEG *old = iqueue_entry(kcp->snd_queue.prev, IKCPSEG, node); - if (old->len < kcp->mss) { - int capacity = kcp->mss - old->len; - int extend = (len < capacity)? len : capacity; - seg = ikcp_segment_new(kcp, old->len + extend); - assert(seg); - if (seg == NULL) { - return -2; - } - iqueue_add_tail(&seg->node, &kcp->snd_queue); - memcpy(seg->data, old->data, old->len); - if (buffer) { - memcpy(seg->data + old->len, buffer, extend); - buffer += extend; - } - seg->len = old->len + extend; - seg->frg = 0; - len -= extend; - iqueue_del_init(&old->node); - ikcp_segment_delete(kcp, old); - sent = extend; - } - } - if (len <= 0) { - return sent; - } - } - - if (len <= (int)kcp->mss) count = 1; - else count = (len + kcp->mss - 1) / kcp->mss; - - if (count >= (int)IKCP_WND_RCV) { - if (kcp->stream != 0 && sent > 0) - return sent; - return -2; - } - - if (count == 0) count = 1; - - // fragment - for (i = 0; i < count; i++) { - int size = len > (int)kcp->mss ? (int)kcp->mss : len; - seg = ikcp_segment_new(kcp, size); - assert(seg); - if (seg == NULL) { - return -2; - } - if (buffer && len > 0) { - memcpy(seg->data, buffer, size); - } - seg->len = size; - seg->frg = (kcp->stream == 0)? (count - i - 1) : 0; - iqueue_init(&seg->node); - iqueue_add_tail(&seg->node, &kcp->snd_queue); - kcp->nsnd_que++; - if (buffer) { - buffer += size; - } - len -= size; - sent += size; - } - - return sent; -} - - -//--------------------------------------------------------------------- -// parse ack -//--------------------------------------------------------------------- -static void ikcp_update_ack(ikcpcb *kcp, IINT32 rtt) -{ - IINT32 rto = 0; - if (kcp->rx_srtt == 0) { - kcp->rx_srtt = rtt; - kcp->rx_rttval = rtt / 2; - } else { - long delta = rtt - kcp->rx_srtt; - if (delta < 0) delta = -delta; - kcp->rx_rttval = (3 * kcp->rx_rttval + delta) / 4; - kcp->rx_srtt = (7 * kcp->rx_srtt + rtt) / 8; - if (kcp->rx_srtt < 1) kcp->rx_srtt = 1; - } - rto = kcp->rx_srtt + _imax_(kcp->interval, 4 * kcp->rx_rttval); - kcp->rx_rto = _ibound_(kcp->rx_minrto, rto, IKCP_RTO_MAX); -} - -static void ikcp_shrink_buf(ikcpcb *kcp) -{ - struct IQUEUEHEAD *p = kcp->snd_buf.next; - if (p != &kcp->snd_buf) { - IKCPSEG *seg = iqueue_entry(p, IKCPSEG, node); - kcp->snd_una = seg->sn; - } else { - kcp->snd_una = kcp->snd_nxt; - } -} - -static void ikcp_parse_ack(ikcpcb *kcp, IUINT32 sn) -{ - struct IQUEUEHEAD *p, *next; - - if (_itimediff(sn, kcp->snd_una) < 0 || _itimediff(sn, kcp->snd_nxt) >= 0) - return; - - for (p = kcp->snd_buf.next; p != &kcp->snd_buf; p = next) { - IKCPSEG *seg = iqueue_entry(p, IKCPSEG, node); - next = p->next; - if (sn == seg->sn) { - iqueue_del(p); - ikcp_segment_delete(kcp, seg); - kcp->nsnd_buf--; - break; - } - if (_itimediff(sn, seg->sn) < 0) { - break; - } - } -} - -static void ikcp_parse_una(ikcpcb *kcp, IUINT32 una) -{ - struct IQUEUEHEAD *p, *next; - for (p = kcp->snd_buf.next; p != &kcp->snd_buf; p = next) { - IKCPSEG *seg = iqueue_entry(p, IKCPSEG, node); - next = p->next; - if (_itimediff(una, seg->sn) > 0) { - iqueue_del(p); - ikcp_segment_delete(kcp, seg); - kcp->nsnd_buf--; - } else { - break; - } - } -} - -static void ikcp_parse_fastack(ikcpcb *kcp, IUINT32 sn, IUINT32 ts) -{ - struct IQUEUEHEAD *p, *next; - - if (_itimediff(sn, kcp->snd_una) < 0 || _itimediff(sn, kcp->snd_nxt) >= 0) - return; - - for (p = kcp->snd_buf.next; p != &kcp->snd_buf; p = next) { - IKCPSEG *seg = iqueue_entry(p, IKCPSEG, node); - next = p->next; - if (_itimediff(sn, seg->sn) < 0) { - break; - } - else if (sn != seg->sn) { - #ifndef IKCP_FASTACK_CONSERVE - seg->fastack++; - #else - if (_itimediff(ts, seg->ts) >= 0) - seg->fastack++; - #endif - } - } -} - - -//--------------------------------------------------------------------- -// ack append -//--------------------------------------------------------------------- -static void ikcp_ack_push(ikcpcb *kcp, IUINT32 sn, IUINT32 ts) -{ - IUINT32 newsize = kcp->ackcount + 1; - IUINT32 *ptr; - - if (newsize > kcp->ackblock) { - IUINT32 *acklist; - IUINT32 newblock; - - for (newblock = 8; newblock < newsize; newblock <<= 1); - acklist = (IUINT32*)ikcp_malloc(newblock * sizeof(IUINT32) * 2); - - if (acklist == NULL) { - assert(acklist != NULL); - abort(); - } - - if (kcp->acklist != NULL) { - IUINT32 x; - for (x = 0; x < kcp->ackcount; x++) { - acklist[x * 2 + 0] = kcp->acklist[x * 2 + 0]; - acklist[x * 2 + 1] = kcp->acklist[x * 2 + 1]; - } - ikcp_free(kcp->acklist); - } - - kcp->acklist = acklist; - kcp->ackblock = newblock; - } - - ptr = &kcp->acklist[kcp->ackcount * 2]; - ptr[0] = sn; - ptr[1] = ts; - kcp->ackcount++; -} - -static void ikcp_ack_get(const ikcpcb *kcp, int p, IUINT32 *sn, IUINT32 *ts) -{ - if (sn) sn[0] = kcp->acklist[p * 2 + 0]; - if (ts) ts[0] = kcp->acklist[p * 2 + 1]; -} - - -//--------------------------------------------------------------------- -// parse data -//--------------------------------------------------------------------- -void ikcp_parse_data(ikcpcb *kcp, IKCPSEG *newseg) -{ - struct IQUEUEHEAD *p, *prev; - IUINT32 sn = newseg->sn; - int repeat = 0; - - if (_itimediff(sn, kcp->rcv_nxt + kcp->rcv_wnd) >= 0 || - _itimediff(sn, kcp->rcv_nxt) < 0) { - ikcp_segment_delete(kcp, newseg); - return; - } - - for (p = kcp->rcv_buf.prev; p != &kcp->rcv_buf; p = prev) { - IKCPSEG *seg = iqueue_entry(p, IKCPSEG, node); - prev = p->prev; - if (seg->sn == sn) { - repeat = 1; - break; - } - if (_itimediff(sn, seg->sn) > 0) { - break; - } - } - - if (repeat == 0) { - iqueue_init(&newseg->node); - iqueue_add(&newseg->node, p); - kcp->nrcv_buf++; - } else { - ikcp_segment_delete(kcp, newseg); - } - -#if 0 - ikcp_qprint("rcvbuf", &kcp->rcv_buf); - printf("rcv_nxt=%lu\n", kcp->rcv_nxt); -#endif - - // move available data from rcv_buf -> rcv_queue - while (! iqueue_is_empty(&kcp->rcv_buf)) { - IKCPSEG *seg = iqueue_entry(kcp->rcv_buf.next, IKCPSEG, node); - if (seg->sn == kcp->rcv_nxt && kcp->nrcv_que < kcp->rcv_wnd) { - iqueue_del(&seg->node); - kcp->nrcv_buf--; - iqueue_add_tail(&seg->node, &kcp->rcv_queue); - kcp->nrcv_que++; - kcp->rcv_nxt++; - } else { - break; - } - } - -#if 0 - ikcp_qprint("queue", &kcp->rcv_queue); - printf("rcv_nxt=%lu\n", kcp->rcv_nxt); -#endif - -#if 1 -// printf("snd(buf=%d, queue=%d)\n", kcp->nsnd_buf, kcp->nsnd_que); -// printf("rcv(buf=%d, queue=%d)\n", kcp->nrcv_buf, kcp->nrcv_que); -#endif -} - - -//--------------------------------------------------------------------- -// input data -//--------------------------------------------------------------------- -int ikcp_input(ikcpcb *kcp, const char *data, long size) -{ - IUINT32 prev_una = kcp->snd_una; - IUINT32 maxack = 0, latest_ts = 0; - int flag = 0; - - if (ikcp_canlog(kcp, IKCP_LOG_INPUT)) { - ikcp_log(kcp, IKCP_LOG_INPUT, "[RI] %d bytes", (int)size); - } - - if (data == NULL || (int)size < (int)IKCP_OVERHEAD) return -1; - - while (1) { - IUINT32 ts, sn, len, una, conv; - IUINT16 wnd; - IUINT8 cmd, frg; - IKCPSEG *seg; - - if (size < (int)IKCP_OVERHEAD) break; - - data = ikcp_decode32u(data, &conv); - if (conv != kcp->conv) return -1; - - data = ikcp_decode8u(data, &cmd); - data = ikcp_decode8u(data, &frg); - data = ikcp_decode16u(data, &wnd); - data = ikcp_decode32u(data, &ts); - data = ikcp_decode32u(data, &sn); - data = ikcp_decode32u(data, &una); - data = ikcp_decode32u(data, &len); - - size -= IKCP_OVERHEAD; - - if ((long)size < (long)len || (int)len < 0) return -2; - - if (cmd != IKCP_CMD_PUSH && cmd != IKCP_CMD_ACK && - cmd != IKCP_CMD_WASK && cmd != IKCP_CMD_WINS) - return -3; - - kcp->rmt_wnd = wnd; - ikcp_parse_una(kcp, una); - ikcp_shrink_buf(kcp); - - if (cmd == IKCP_CMD_ACK) { - if (_itimediff(kcp->current, ts) >= 0) { - ikcp_update_ack(kcp, _itimediff(kcp->current, ts)); - } - ikcp_parse_ack(kcp, sn); - ikcp_shrink_buf(kcp); - if (flag == 0) { - flag = 1; - maxack = sn; - latest_ts = ts; - } else { - if (_itimediff(sn, maxack) > 0) { - #ifndef IKCP_FASTACK_CONSERVE - maxack = sn; - latest_ts = ts; - #else - if (_itimediff(ts, latest_ts) > 0) { - maxack = sn; - latest_ts = ts; - } - #endif - } - } - if (ikcp_canlog(kcp, IKCP_LOG_IN_ACK)) { - ikcp_log(kcp, IKCP_LOG_IN_ACK, - "input ack: sn=%lu rtt=%ld rto=%ld", (unsigned long)sn, - (long)_itimediff(kcp->current, ts), - (long)kcp->rx_rto); - } - } - else if (cmd == IKCP_CMD_PUSH) { - if (ikcp_canlog(kcp, IKCP_LOG_IN_DATA)) { - ikcp_log(kcp, IKCP_LOG_IN_DATA, - "input psh: sn=%lu ts=%lu", (unsigned long)sn, (unsigned long)ts); - } - if (_itimediff(sn, kcp->rcv_nxt + kcp->rcv_wnd) < 0) { - ikcp_ack_push(kcp, sn, ts); - if (_itimediff(sn, kcp->rcv_nxt) >= 0) { - seg = ikcp_segment_new(kcp, len); - seg->conv = conv; - seg->cmd = cmd; - seg->frg = frg; - seg->wnd = wnd; - seg->ts = ts; - seg->sn = sn; - seg->una = una; - seg->len = len; - - if (len > 0) { - memcpy(seg->data, data, len); - } - - ikcp_parse_data(kcp, seg); - } - } - } - else if (cmd == IKCP_CMD_WASK) { - // ready to send back IKCP_CMD_WINS in ikcp_flush - // tell remote my window size - kcp->probe |= IKCP_ASK_TELL; - if (ikcp_canlog(kcp, IKCP_LOG_IN_PROBE)) { - ikcp_log(kcp, IKCP_LOG_IN_PROBE, "input probe"); - } - } - else if (cmd == IKCP_CMD_WINS) { - // do nothing - if (ikcp_canlog(kcp, IKCP_LOG_IN_WINS)) { - ikcp_log(kcp, IKCP_LOG_IN_WINS, - "input wins: %lu", (unsigned long)(wnd)); - } - } - else { - return -3; - } - - data += len; - size -= len; - } - - if (flag != 0) { - ikcp_parse_fastack(kcp, maxack, latest_ts); - } - - if (_itimediff(kcp->snd_una, prev_una) > 0) { - if (kcp->cwnd < kcp->rmt_wnd) { - IUINT32 mss = kcp->mss; - if (kcp->cwnd < kcp->ssthresh) { - kcp->cwnd++; - kcp->incr += mss; - } else { - if (kcp->incr < mss) kcp->incr = mss; - kcp->incr += (mss * mss) / kcp->incr + (mss / 16); - if ((kcp->cwnd + 1) * mss <= kcp->incr) { - #if 1 - kcp->cwnd = (kcp->incr + mss - 1) / ((mss > 0)? mss : 1); - #else - kcp->cwnd++; - #endif - } - } - if (kcp->cwnd > kcp->rmt_wnd) { - kcp->cwnd = kcp->rmt_wnd; - kcp->incr = kcp->rmt_wnd * mss; - } - } - } - - return 0; -} - - -//--------------------------------------------------------------------- -// ikcp_encode_seg -//--------------------------------------------------------------------- -static char *ikcp_encode_seg(char *ptr, const IKCPSEG *seg) -{ - ptr = ikcp_encode32u(ptr, seg->conv); - ptr = ikcp_encode8u(ptr, (IUINT8)seg->cmd); - ptr = ikcp_encode8u(ptr, (IUINT8)seg->frg); - ptr = ikcp_encode16u(ptr, (IUINT16)seg->wnd); - ptr = ikcp_encode32u(ptr, seg->ts); - ptr = ikcp_encode32u(ptr, seg->sn); - ptr = ikcp_encode32u(ptr, seg->una); - ptr = ikcp_encode32u(ptr, seg->len); - return ptr; -} - -static int ikcp_wnd_unused(const ikcpcb *kcp) -{ - if (kcp->nrcv_que < kcp->rcv_wnd) { - return kcp->rcv_wnd - kcp->nrcv_que; - } - return 0; -} - - -//--------------------------------------------------------------------- -// ikcp_flush -//--------------------------------------------------------------------- -void ikcp_flush(ikcpcb *kcp) -{ - IUINT32 current = kcp->current; - char *buffer = kcp->buffer; - char *ptr = buffer; - int count, size, i; - IUINT32 resent, cwnd; - IUINT32 rtomin; - struct IQUEUEHEAD *p; - int change = 0; - int lost = 0; - IKCPSEG seg; - - // 'ikcp_update' haven't been called. - if (kcp->updated == 0) return; - - seg.conv = kcp->conv; - seg.cmd = IKCP_CMD_ACK; - seg.frg = 0; - seg.wnd = ikcp_wnd_unused(kcp); - seg.una = kcp->rcv_nxt; - seg.len = 0; - seg.sn = 0; - seg.ts = 0; - - // flush acknowledges - count = kcp->ackcount; - for (i = 0; i < count; i++) { - size = (int)(ptr - buffer); - if (size + (int)IKCP_OVERHEAD > (int)kcp->mtu) { - ikcp_output(kcp, buffer, size); - ptr = buffer; - } - ikcp_ack_get(kcp, i, &seg.sn, &seg.ts); - ptr = ikcp_encode_seg(ptr, &seg); - } - - kcp->ackcount = 0; - - // probe window size (if remote window size equals zero) - if (kcp->rmt_wnd == 0) { - if (kcp->probe_wait == 0) { - kcp->probe_wait = IKCP_PROBE_INIT; - kcp->ts_probe = kcp->current + kcp->probe_wait; - } - else { - if (_itimediff(kcp->current, kcp->ts_probe) >= 0) { - if (kcp->probe_wait < IKCP_PROBE_INIT) - kcp->probe_wait = IKCP_PROBE_INIT; - kcp->probe_wait += kcp->probe_wait / 2; - if (kcp->probe_wait > IKCP_PROBE_LIMIT) - kcp->probe_wait = IKCP_PROBE_LIMIT; - kcp->ts_probe = kcp->current + kcp->probe_wait; - kcp->probe |= IKCP_ASK_SEND; - } - } - } else { - kcp->ts_probe = 0; - kcp->probe_wait = 0; - } - - // flush window probing commands - if (kcp->probe & IKCP_ASK_SEND) { - seg.cmd = IKCP_CMD_WASK; - size = (int)(ptr - buffer); - if (size + (int)IKCP_OVERHEAD > (int)kcp->mtu) { - ikcp_output(kcp, buffer, size); - ptr = buffer; - } - ptr = ikcp_encode_seg(ptr, &seg); - } - - // flush window probing commands - if (kcp->probe & IKCP_ASK_TELL) { - seg.cmd = IKCP_CMD_WINS; - size = (int)(ptr - buffer); - if (size + (int)IKCP_OVERHEAD > (int)kcp->mtu) { - ikcp_output(kcp, buffer, size); - ptr = buffer; - } - ptr = ikcp_encode_seg(ptr, &seg); - } - - kcp->probe = 0; - - // calculate window size - cwnd = _imin_(kcp->snd_wnd, kcp->rmt_wnd); - if (kcp->nocwnd == 0) cwnd = _imin_(kcp->cwnd, cwnd); - - // move data from snd_queue to snd_buf - while (_itimediff(kcp->snd_nxt, kcp->snd_una + cwnd) < 0) { - IKCPSEG *newseg; - if (iqueue_is_empty(&kcp->snd_queue)) break; - - newseg = iqueue_entry(kcp->snd_queue.next, IKCPSEG, node); - - iqueue_del(&newseg->node); - iqueue_add_tail(&newseg->node, &kcp->snd_buf); - kcp->nsnd_que--; - kcp->nsnd_buf++; - - newseg->conv = kcp->conv; - newseg->cmd = IKCP_CMD_PUSH; - newseg->wnd = seg.wnd; - newseg->ts = current; - newseg->sn = kcp->snd_nxt++; - newseg->una = kcp->rcv_nxt; - newseg->resendts = current; - newseg->rto = kcp->rx_rto; - newseg->fastack = 0; - newseg->xmit = 0; - } - - // calculate resent - resent = (kcp->fastresend > 0)? (IUINT32)kcp->fastresend : 0xffffffff; - rtomin = (kcp->nodelay == 0)? (kcp->rx_rto >> 3) : 0; - - // flush data segments - for (p = kcp->snd_buf.next; p != &kcp->snd_buf; p = p->next) { - IKCPSEG *segment = iqueue_entry(p, IKCPSEG, node); - int needsend = 0; - if (segment->xmit == 0) { - needsend = 1; - segment->xmit++; - segment->rto = kcp->rx_rto; - segment->resendts = current + segment->rto + rtomin; - } - else if (_itimediff(current, segment->resendts) >= 0) { - needsend = 1; - segment->xmit++; - kcp->xmit++; - if (kcp->nodelay == 0) { - segment->rto += _imax_(segment->rto, (IUINT32)kcp->rx_rto); - } else { - IINT32 step = (kcp->nodelay < 2)? - ((IINT32)(segment->rto)) : kcp->rx_rto; - segment->rto += step / 2; - } - segment->resendts = current + segment->rto; - lost = 1; - } - else if (segment->fastack >= resent) { - if ((int)segment->xmit <= kcp->fastlimit || - kcp->fastlimit <= 0) { - needsend = 1; - segment->xmit++; - segment->fastack = 0; - segment->resendts = current + segment->rto; - change++; - } - } - - if (needsend) { - int need; - segment->ts = current; - segment->wnd = seg.wnd; - segment->una = kcp->rcv_nxt; - - size = (int)(ptr - buffer); - need = IKCP_OVERHEAD + segment->len; - - if (size + need > (int)kcp->mtu) { - ikcp_output(kcp, buffer, size); - ptr = buffer; - } - - ptr = ikcp_encode_seg(ptr, segment); - - if (segment->len > 0) { - memcpy(ptr, segment->data, segment->len); - ptr += segment->len; - } - - if (segment->xmit >= kcp->dead_link) { - kcp->state = (IUINT32)-1; - } - } - } - - // flash remain segments - size = (int)(ptr - buffer); - if (size > 0) { - ikcp_output(kcp, buffer, size); - } - - // update ssthresh - if (change) { - IUINT32 inflight = kcp->snd_nxt - kcp->snd_una; - kcp->ssthresh = inflight / 2; - if (kcp->ssthresh < IKCP_THRESH_MIN) - kcp->ssthresh = IKCP_THRESH_MIN; - kcp->cwnd = kcp->ssthresh + resent; - kcp->incr = kcp->cwnd * kcp->mss; - } - - if (lost) { - kcp->ssthresh = cwnd / 2; - if (kcp->ssthresh < IKCP_THRESH_MIN) - kcp->ssthresh = IKCP_THRESH_MIN; - kcp->cwnd = 1; - kcp->incr = kcp->mss; - } - - if (kcp->cwnd < 1) { - kcp->cwnd = 1; - kcp->incr = kcp->mss; - } -} - - -//--------------------------------------------------------------------- -// update state (call it repeatedly, every 10ms-100ms), or you can ask -// ikcp_check when to call it again (without ikcp_input/_send calling). -// 'current' - current timestamp in millisec. -//--------------------------------------------------------------------- -void ikcp_update(ikcpcb *kcp, IUINT32 current) -{ - IINT32 slap; - - kcp->current = current; - - if (kcp->updated == 0) { - kcp->updated = 1; - kcp->ts_flush = kcp->current; - } - - slap = _itimediff(kcp->current, kcp->ts_flush); - - if (slap >= 10000 || slap < -10000) { - kcp->ts_flush = kcp->current; - slap = 0; - } - - if (slap >= 0) { - kcp->ts_flush += kcp->interval; - if (_itimediff(kcp->current, kcp->ts_flush) >= 0) { - kcp->ts_flush = kcp->current + kcp->interval; - } - ikcp_flush(kcp); - } -} - - -//--------------------------------------------------------------------- -// Determine when should you invoke ikcp_update: -// returns when you should invoke ikcp_update in millisec, if there -// is no ikcp_input/_send calling. you can call ikcp_update in that -// time, instead of call update repeatly. -// Important to reduce unnacessary ikcp_update invoking. use it to -// schedule ikcp_update (eg. implementing an epoll-like mechanism, -// or optimize ikcp_update when handling massive kcp connections) -//--------------------------------------------------------------------- -IUINT32 ikcp_check(const ikcpcb *kcp, IUINT32 current) -{ - IUINT32 ts_flush = kcp->ts_flush; - IINT32 tm_flush = 0x7fffffff; - IINT32 tm_packet = 0x7fffffff; - IUINT32 minimal = 0; - struct IQUEUEHEAD *p; - - if (kcp->updated == 0) { - return current; - } - - if (_itimediff(current, ts_flush) >= 10000 || - _itimediff(current, ts_flush) < -10000) { - ts_flush = current; - } - - if (_itimediff(current, ts_flush) >= 0) { - return current; - } - - tm_flush = _itimediff(ts_flush, current); - - for (p = kcp->snd_buf.next; p != &kcp->snd_buf; p = p->next) { - const IKCPSEG *seg = iqueue_entry(p, const IKCPSEG, node); - IINT32 diff = _itimediff(seg->resendts, current); - if (diff <= 0) { - return current; - } - if (diff < tm_packet) tm_packet = diff; - } - - minimal = (IUINT32)(tm_packet < tm_flush ? tm_packet : tm_flush); - if (minimal >= kcp->interval) minimal = kcp->interval; - - return current + minimal; -} - - - -int ikcp_setmtu(ikcpcb *kcp, int mtu) -{ - char *buffer; - if (mtu < 50 || mtu < (int)IKCP_OVERHEAD) - return -1; - buffer = (char*)ikcp_malloc((mtu + IKCP_OVERHEAD) * 3); - if (buffer == NULL) - return -2; - kcp->mtu = mtu; - kcp->mss = kcp->mtu - IKCP_OVERHEAD; - ikcp_free(kcp->buffer); - kcp->buffer = buffer; - return 0; -} - -int ikcp_interval(ikcpcb *kcp, int interval) -{ - if (interval > 5000) interval = 5000; - else if (interval < 10) interval = 10; - kcp->interval = interval; - return 0; -} - -int ikcp_nodelay(ikcpcb *kcp, int nodelay, int interval, int resend, int nc) -{ - if (nodelay >= 0) { - kcp->nodelay = nodelay; - if (nodelay) { - kcp->rx_minrto = IKCP_RTO_NDL; - } - else { - kcp->rx_minrto = IKCP_RTO_MIN; - } - } - if (interval >= 0) { - if (interval > 5000) interval = 5000; - else if (interval < 10) interval = 10; - kcp->interval = interval; - } - if (resend >= 0) { - kcp->fastresend = resend; - } - if (nc >= 0) { - kcp->nocwnd = nc; - } - return 0; -} - - -int ikcp_wndsize(ikcpcb *kcp, int sndwnd, int rcvwnd) -{ - if (kcp) { - if (sndwnd > 0) { - kcp->snd_wnd = sndwnd; - } - if (rcvwnd > 0) { // must >= max fragment size - kcp->rcv_wnd = _imax_(rcvwnd, IKCP_WND_RCV); - } - } - return 0; -} - -int ikcp_waitsnd(const ikcpcb *kcp) -{ - return kcp->nsnd_buf + kcp->nsnd_que; -} - - -// read conv -IUINT32 ikcp_getconv(const void *ptr) -{ - IUINT32 conv; - ikcp_decode32u((const char*)ptr, &conv); - return conv; -} - - diff --git a/src/qos/kcp/ikcp.h b/src/qos/kcp/ikcp.h deleted file mode 100644 index e525105..0000000 --- a/src/qos/kcp/ikcp.h +++ /dev/null @@ -1,416 +0,0 @@ -//===================================================================== -// -// KCP - A Better ARQ Protocol Implementation -// skywind3000 (at) gmail.com, 2010-2011 -// -// Features: -// + Average RTT reduce 30% - 40% vs traditional ARQ like tcp. -// + Maximum RTT reduce three times vs tcp. -// + Lightweight, distributed as a single source file. -// -//===================================================================== -#ifndef __IKCP_H__ -#define __IKCP_H__ - -#include -#include -#include - - -//===================================================================== -// 32BIT INTEGER DEFINITION -//===================================================================== -#ifndef __INTEGER_32_BITS__ -#define __INTEGER_32_BITS__ -#if defined(_WIN64) || defined(WIN64) || defined(__amd64__) || \ - defined(__x86_64) || defined(__x86_64__) || defined(_M_IA64) || \ - defined(_M_AMD64) - typedef unsigned int ISTDUINT32; - typedef int ISTDINT32; -#elif defined(_WIN32) || defined(WIN32) || defined(__i386__) || \ - defined(__i386) || defined(_M_X86) - typedef unsigned long ISTDUINT32; - typedef long ISTDINT32; -#elif defined(__MACOS__) - typedef UInt32 ISTDUINT32; - typedef SInt32 ISTDINT32; -#elif defined(__APPLE__) && defined(__MACH__) - #include - typedef u_int32_t ISTDUINT32; - typedef int32_t ISTDINT32; -#elif defined(__BEOS__) - #include - typedef u_int32_t ISTDUINT32; - typedef int32_t ISTDINT32; -#elif (defined(_MSC_VER) || defined(__BORLANDC__)) && (!defined(__MSDOS__)) - typedef unsigned __int32 ISTDUINT32; - typedef __int32 ISTDINT32; -#elif defined(__GNUC__) - #include - typedef uint32_t ISTDUINT32; - typedef int32_t ISTDINT32; -#else - typedef unsigned long ISTDUINT32; - typedef long ISTDINT32; -#endif -#endif - - -//===================================================================== -// Integer Definition -//===================================================================== -#ifndef __IINT8_DEFINED -#define __IINT8_DEFINED -typedef char IINT8; -#endif - -#ifndef __IUINT8_DEFINED -#define __IUINT8_DEFINED -typedef unsigned char IUINT8; -#endif - -#ifndef __IUINT16_DEFINED -#define __IUINT16_DEFINED -typedef unsigned short IUINT16; -#endif - -#ifndef __IINT16_DEFINED -#define __IINT16_DEFINED -typedef short IINT16; -#endif - -#ifndef __IINT32_DEFINED -#define __IINT32_DEFINED -typedef ISTDINT32 IINT32; -#endif - -#ifndef __IUINT32_DEFINED -#define __IUINT32_DEFINED -typedef ISTDUINT32 IUINT32; -#endif - -#ifndef __IINT64_DEFINED -#define __IINT64_DEFINED -#if defined(_MSC_VER) || defined(__BORLANDC__) -typedef __int64 IINT64; -#else -typedef long long IINT64; -#endif -#endif - -#ifndef __IUINT64_DEFINED -#define __IUINT64_DEFINED -#if defined(_MSC_VER) || defined(__BORLANDC__) -typedef unsigned __int64 IUINT64; -#else -typedef unsigned long long IUINT64; -#endif -#endif - -#ifndef INLINE -#if defined(__GNUC__) - -#if (__GNUC__ > 3) || ((__GNUC__ == 3) && (__GNUC_MINOR__ >= 1)) -#define INLINE __inline__ __attribute__((always_inline)) -#else -#define INLINE __inline__ -#endif - -#elif (defined(_MSC_VER) || defined(__BORLANDC__) || defined(__WATCOMC__)) -#define INLINE __inline -#else -#define INLINE -#endif -#endif - -#if (!defined(__cplusplus)) && (!defined(inline)) -#define inline INLINE -#endif - - -//===================================================================== -// QUEUE DEFINITION -//===================================================================== -#ifndef __IQUEUE_DEF__ -#define __IQUEUE_DEF__ - -struct IQUEUEHEAD { - struct IQUEUEHEAD *next, *prev; -}; - -typedef struct IQUEUEHEAD iqueue_head; - - -//--------------------------------------------------------------------- -// queue init -//--------------------------------------------------------------------- -#define IQUEUE_HEAD_INIT(name) { &(name), &(name) } -#define IQUEUE_HEAD(name) \ - struct IQUEUEHEAD name = IQUEUE_HEAD_INIT(name) - -#define IQUEUE_INIT(ptr) ( \ - (ptr)->next = (ptr), (ptr)->prev = (ptr)) - -#define IOFFSETOF(TYPE, MEMBER) ((size_t) &((TYPE *)0)->MEMBER) - -#define ICONTAINEROF(ptr, type, member) ( \ - (type*)( ((char*)((type*)ptr)) - IOFFSETOF(type, member)) ) - -#define IQUEUE_ENTRY(ptr, type, member) ICONTAINEROF(ptr, type, member) - - -//--------------------------------------------------------------------- -// queue operation -//--------------------------------------------------------------------- -#define IQUEUE_ADD(node, head) ( \ - (node)->prev = (head), (node)->next = (head)->next, \ - (head)->next->prev = (node), (head)->next = (node)) - -#define IQUEUE_ADD_TAIL(node, head) ( \ - (node)->prev = (head)->prev, (node)->next = (head), \ - (head)->prev->next = (node), (head)->prev = (node)) - -#define IQUEUE_DEL_BETWEEN(p, n) ((n)->prev = (p), (p)->next = (n)) - -#define IQUEUE_DEL(entry) (\ - (entry)->next->prev = (entry)->prev, \ - (entry)->prev->next = (entry)->next, \ - (entry)->next = 0, (entry)->prev = 0) - -#define IQUEUE_DEL_INIT(entry) do { \ - IQUEUE_DEL(entry); IQUEUE_INIT(entry); } while (0) - -#define IQUEUE_IS_EMPTY(entry) ((entry) == (entry)->next) - -#define iqueue_init IQUEUE_INIT -#define iqueue_entry IQUEUE_ENTRY -#define iqueue_add IQUEUE_ADD -#define iqueue_add_tail IQUEUE_ADD_TAIL -#define iqueue_del IQUEUE_DEL -#define iqueue_del_init IQUEUE_DEL_INIT -#define iqueue_is_empty IQUEUE_IS_EMPTY - -#define IQUEUE_FOREACH(iterator, head, TYPE, MEMBER) \ - for ((iterator) = iqueue_entry((head)->next, TYPE, MEMBER); \ - &((iterator)->MEMBER) != (head); \ - (iterator) = iqueue_entry((iterator)->MEMBER.next, TYPE, MEMBER)) - -#define iqueue_foreach(iterator, head, TYPE, MEMBER) \ - IQUEUE_FOREACH(iterator, head, TYPE, MEMBER) - -#define iqueue_foreach_entry(pos, head) \ - for( (pos) = (head)->next; (pos) != (head) ; (pos) = (pos)->next ) - - -#define __iqueue_splice(list, head) do { \ - iqueue_head *first = (list)->next, *last = (list)->prev; \ - iqueue_head *at = (head)->next; \ - (first)->prev = (head), (head)->next = (first); \ - (last)->next = (at), (at)->prev = (last); } while (0) - -#define iqueue_splice(list, head) do { \ - if (!iqueue_is_empty(list)) __iqueue_splice(list, head); } while (0) - -#define iqueue_splice_init(list, head) do { \ - iqueue_splice(list, head); iqueue_init(list); } while (0) - - -#ifdef _MSC_VER -#pragma warning(disable:4311) -#pragma warning(disable:4312) -#pragma warning(disable:4996) -#endif - -#endif - - -//--------------------------------------------------------------------- -// BYTE ORDER & ALIGNMENT -//--------------------------------------------------------------------- -#ifndef IWORDS_BIG_ENDIAN - #ifdef _BIG_ENDIAN_ - #if _BIG_ENDIAN_ - #define IWORDS_BIG_ENDIAN 1 - #endif - #endif - #ifndef IWORDS_BIG_ENDIAN - #if defined(__hppa__) || \ - defined(__m68k__) || defined(mc68000) || defined(_M_M68K) || \ - (defined(__MIPS__) && defined(__MIPSEB__)) || \ - defined(__ppc__) || defined(__POWERPC__) || defined(_M_PPC) || \ - defined(__sparc__) || defined(__powerpc__) || \ - defined(__mc68000__) || defined(__s390x__) || defined(__s390__) - #define IWORDS_BIG_ENDIAN 1 - #endif - #endif - #ifndef IWORDS_BIG_ENDIAN - #define IWORDS_BIG_ENDIAN 0 - #endif -#endif - -#ifndef IWORDS_MUST_ALIGN - #if defined(__i386__) || defined(__i386) || defined(_i386_) - #define IWORDS_MUST_ALIGN 0 - #elif defined(_M_IX86) || defined(_X86_) || defined(__x86_64__) - #define IWORDS_MUST_ALIGN 0 - #elif defined(__amd64) || defined(__amd64__) - #define IWORDS_MUST_ALIGN 0 - #else - #define IWORDS_MUST_ALIGN 1 - #endif -#endif - - -//===================================================================== -// SEGMENT -//===================================================================== -struct IKCPSEG -{ - struct IQUEUEHEAD node; - IUINT32 conv; - IUINT32 cmd; - IUINT32 frg; - IUINT32 wnd; - IUINT32 ts; - IUINT32 sn; - IUINT32 una; - IUINT32 len; - IUINT32 resendts; - IUINT32 rto; - IUINT32 fastack; - IUINT32 xmit; - char data[1]; -}; - - -//--------------------------------------------------------------------- -// IKCPCB -//--------------------------------------------------------------------- -struct IKCPCB -{ - IUINT32 conv, mtu, mss, state; - IUINT32 snd_una, snd_nxt, rcv_nxt; - IUINT32 ts_recent, ts_lastack, ssthresh; - IINT32 rx_rttval, rx_srtt, rx_rto, rx_minrto; - IUINT32 snd_wnd, rcv_wnd, rmt_wnd, cwnd, probe; - IUINT32 current, interval, ts_flush, xmit; - IUINT32 nrcv_buf, nsnd_buf; - IUINT32 nrcv_que, nsnd_que; - IUINT32 nodelay, updated; - IUINT32 ts_probe, probe_wait; - IUINT32 dead_link, incr; - struct IQUEUEHEAD snd_queue; - struct IQUEUEHEAD rcv_queue; - struct IQUEUEHEAD snd_buf; - struct IQUEUEHEAD rcv_buf; - IUINT32 *acklist; - IUINT32 ackcount; - IUINT32 ackblock; - void *user; - char *buffer; - int fastresend; - int fastlimit; - int nocwnd, stream; - int logmask; - int (*output)(const char *buf, int len, struct IKCPCB *kcp, void *user); - void (*writelog)(const char *log, struct IKCPCB *kcp, void *user); -}; - - -typedef struct IKCPCB ikcpcb; - -#define IKCP_LOG_OUTPUT 1 -#define IKCP_LOG_INPUT 2 -#define IKCP_LOG_SEND 4 -#define IKCP_LOG_RECV 8 -#define IKCP_LOG_IN_DATA 16 -#define IKCP_LOG_IN_ACK 32 -#define IKCP_LOG_IN_PROBE 64 -#define IKCP_LOG_IN_WINS 128 -#define IKCP_LOG_OUT_DATA 256 -#define IKCP_LOG_OUT_ACK 512 -#define IKCP_LOG_OUT_PROBE 1024 -#define IKCP_LOG_OUT_WINS 2048 - -#ifdef __cplusplus -extern "C" { -#endif - -//--------------------------------------------------------------------- -// interface -//--------------------------------------------------------------------- - -// create a new kcp control object, 'conv' must equal in two endpoint -// from the same connection. 'user' will be passed to the output callback -// output callback can be setup like this: 'kcp->output = my_udp_output' -ikcpcb* ikcp_create(IUINT32 conv, void *user); - -// release kcp control object -void ikcp_release(ikcpcb *kcp); - -// set output callback, which will be invoked by kcp -void ikcp_setoutput(ikcpcb *kcp, int (*output)(const char *buf, int len, - ikcpcb *kcp, void *user)); - -// user/upper level recv: returns size, returns below zero for EAGAIN -int ikcp_recv(ikcpcb *kcp, char *buffer, int len); - -// user/upper level send, returns below zero for error -int ikcp_send(ikcpcb *kcp, const char *buffer, int len); - -// update state (call it repeatedly, every 10ms-100ms), or you can ask -// ikcp_check when to call it again (without ikcp_input/_send calling). -// 'current' - current timestamp in millisec. -void ikcp_update(ikcpcb *kcp, IUINT32 current); - -// Determine when should you invoke ikcp_update: -// returns when you should invoke ikcp_update in millisec, if there -// is no ikcp_input/_send calling. you can call ikcp_update in that -// time, instead of call update repeatly. -// Important to reduce unnacessary ikcp_update invoking. use it to -// schedule ikcp_update (eg. implementing an epoll-like mechanism, -// or optimize ikcp_update when handling massive kcp connections) -IUINT32 ikcp_check(const ikcpcb *kcp, IUINT32 current); - -// when you received a low level packet (eg. UDP packet), call it -int ikcp_input(ikcpcb *kcp, const char *data, long size); - -// flush pending data -void ikcp_flush(ikcpcb *kcp); - -// check the size of next message in the recv queue -int ikcp_peeksize(const ikcpcb *kcp); - -// change MTU size, default is 1400 -int ikcp_setmtu(ikcpcb *kcp, int mtu); - -// set maximum window size: sndwnd=32, rcvwnd=32 by default -int ikcp_wndsize(ikcpcb *kcp, int sndwnd, int rcvwnd); - -// get how many packet is waiting to be sent -int ikcp_waitsnd(const ikcpcb *kcp); - -// fastest: ikcp_nodelay(kcp, 1, 20, 2, 1) -// nodelay: 0:disable(default), 1:enable -// interval: internal update timer interval in millisec, default is 100ms -// resend: 0:disable fast resend(default), 1:enable fast resend -// nc: 0:normal congestion control(default), 1:disable congestion control -int ikcp_nodelay(ikcpcb *kcp, int nodelay, int interval, int resend, int nc); - - -void ikcp_log(ikcpcb *kcp, int mask, const char *fmt, ...); - -// setup allocator -void ikcp_allocator(void* (*new_malloc)(size_t), void (*new_free)(void*)); - -// read conv -IUINT32 ikcp_getconv(const void *ptr); - - -#ifdef __cplusplus -} -#endif - -#endif - - diff --git a/src/transmission/ice_transmission.cpp b/src/transmission/ice_transmission.cpp index 8bfa972..9187ab3 100644 --- a/src/transmission/ice_transmission.cpp +++ b/src/transmission/ice_transmission.cpp @@ -6,7 +6,6 @@ #include #include "common.h" -#include "ikcp.h" #include "log.h" #if __APPLE__ #else @@ -49,7 +48,7 @@ int IceTransmission::SetLocalCapabilities( hardware_acceleration_ = hardware_acceleration; use_trickle_ice_ = use_trickle_ice; use_reliable_ice_ = use_reliable_ice; - enable_turn_ = force_turn; + enable_turn_ = enable_turn; force_turn_ = force_turn; support_video_payload_types_ = video_payload_types; support_audio_payload_types_ = audio_payload_types; @@ -105,10 +104,9 @@ int IceTransmission::InitIceTransmission( }); rtp_video_receiver_->SetOnReceiveCompleteFrame( [this](VideoFrame &video_frame) -> void { - // LOG_ERROR("OnReceiveCompleteFrame {}", video_frame.Size()); - ice_io_statistics_->UpdateVideoInboundBytes(video_frame.Size()); - - int num_frame_returned = video_decoder_->Decode( + ice_io_statistics_->UpdateVideoInboundBytes( + (uint32_t)video_frame.Size()); + [[maybe_unused]] int num_frame_returned = video_decoder_->Decode( (uint8_t *)video_frame.Buffer(), video_frame.Size(), [this](VideoFrame video_frame) { if (on_receive_video_) { @@ -140,7 +138,7 @@ int IceTransmission::InitIceTransmission( return -2; } - ice_io_statistics_->UpdateVideoOutboundBytes(size); + ice_io_statistics_->UpdateVideoOutboundBytes((uint32_t)size); return ice_agent_->Send(data, size); }); @@ -166,9 +164,9 @@ int IceTransmission::InitIceTransmission( }); rtp_audio_receiver_->SetOnReceiveData([this](const char *data, size_t size) -> void { - ice_io_statistics_->UpdateAudioInboundBytes(size); + ice_io_statistics_->UpdateAudioInboundBytes((uint32_t)size); - int num_frame_returned = audio_decoder_->Decode( + [[maybe_unused]] int num_frame_returned = audio_decoder_->Decode( (uint8_t *)data, size, [this](uint8_t *data, int size) { if (on_receive_audio_) { on_receive_audio_((const char *)data, size, remote_user_id_.data(), @@ -192,7 +190,7 @@ int IceTransmission::InitIceTransmission( return -2; } - ice_io_statistics_->UpdateAudioOutboundBytes(size); + ice_io_statistics_->UpdateAudioOutboundBytes((uint32_t)size); return ice_agent_->Send(data, size); }); @@ -218,7 +216,7 @@ int IceTransmission::InitIceTransmission( }); rtp_data_receiver_->SetOnReceiveData( [this](const char *data, size_t size) -> void { - ice_io_statistics_->UpdateDataInboundBytes(size); + ice_io_statistics_->UpdateDataInboundBytes((uint32_t)size); if (on_receive_data_) { on_receive_data_(data, size, remote_user_id_.data(), @@ -241,7 +239,7 @@ int IceTransmission::InitIceTransmission( return -2; } - ice_io_statistics_->UpdateDataOutboundBytes(size); + ice_io_statistics_->UpdateDataOutboundBytes((uint32_t)size); return ice_agent_->Send(data, size); }); @@ -253,8 +251,9 @@ int IceTransmission::InitIceTransmission( turn_password); ice_agent_->CreateIceAgent( - [](NiceAgent *agent, guint stream_id, guint component_id, - NiceComponentState state, gpointer user_ptr) { + []([[maybe_unused]] NiceAgent *agent, [[maybe_unused]] guint stream_id, + [[maybe_unused]] guint component_id, NiceComponentState state, + gpointer user_ptr) { if (user_ptr) { IceTransmission *ice_transmission_obj = static_cast(user_ptr); @@ -313,7 +312,8 @@ int IceTransmission::InitIceTransmission( } } }, - [](NiceAgent *agent, guint stream_id, gpointer user_ptr) { + []([[maybe_unused]] NiceAgent *agent, [[maybe_unused]] guint stream_id, + gpointer user_ptr) { // non-trickle if (user_ptr) { IceTransmission *ice_transmission_obj = @@ -365,8 +365,9 @@ int IceTransmission::InitIceTransmission( &net_traffic_stats, ice_transmission_obj->user_data_); } }, - [](NiceAgent *agent, guint stream_id, guint component_id, guint size, - gchar *buffer, gpointer user_ptr) { + []([[maybe_unused]] NiceAgent *agent, [[maybe_unused]] guint stream_id, + [[maybe_unused]] guint component_id, guint size, gchar *buffer, + gpointer user_ptr) { if (user_ptr) { IceTransmission *ice_transmission_obj = static_cast(user_ptr); @@ -977,7 +978,7 @@ int IceTransmission::SendVideoFrame(const XVideoFrame *video_frame) { if (video_rtp_codec_) { video_rtp_codec_->Encode( static_cast(frame_type), - (uint8_t *)encoded_frame, size, packets); + (uint8_t *)encoded_frame, (uint32_t)size, packets); } rtp_video_sender_->Enqueue(packets); } @@ -1007,15 +1008,15 @@ int IceTransmission::SendAudioFrame(const char *data, size_t size) { if (rtp_audio_sender_) { if (audio_rtp_codec_) { std::vector packets; - audio_rtp_codec_->Encode((uint8_t *)encoded_audio_buffer, size, - packets); + audio_rtp_codec_->Encode((uint8_t *)encoded_audio_buffer, + (uint32_t)size, packets); rtp_audio_sender_->Enqueue(packets); } } return 0; }); - return 0; + return ret; } int IceTransmission::SendDataFrame(const char *data, size_t size) { @@ -1030,7 +1031,7 @@ int IceTransmission::SendDataFrame(const char *data, size_t size) { if (rtp_data_sender_) { if (data_rtp_codec_) { - data_rtp_codec_->Encode((uint8_t *)data, size, packets); + data_rtp_codec_->Encode((uint8_t *)data, (uint32_t)size, packets); rtp_data_sender_->Enqueue(packets); } } diff --git a/src/ws/ws_client.cpp b/src/ws/ws_client.cpp index e3e5968..e6ab56f 100644 --- a/src/ws/ws_client.cpp +++ b/src/ws/ws_client.cpp @@ -136,7 +136,8 @@ void WsClient::Ping(websocketpp::connection_hdl hdl) { WsStatus WsClient::GetStatus() { return ws_status_; } -void WsClient::OnOpen(client *c, websocketpp::connection_hdl hdl) { +void WsClient::OnOpen([[maybe_unused]] client *c, + websocketpp::connection_hdl hdl) { ws_status_ = WsStatus::WsOpened; on_ws_status_(WsStatus::WsOpened); @@ -155,13 +156,15 @@ void WsClient::OnOpen(client *c, websocketpp::connection_hdl hdl) { } } -void WsClient::OnFail(client *c, websocketpp::connection_hdl hdl) { +void WsClient::OnFail([[maybe_unused]] client *c, + websocketpp::connection_hdl hdl) { ws_status_ = WsStatus::WsFailed; on_ws_status_(WsStatus::WsFailed); Connect(uri_); } -void WsClient::OnClose(client *c, websocketpp::connection_hdl hdl) { +void WsClient::OnClose([[maybe_unused]] client *c, + websocketpp::connection_hdl hdl) { ws_status_ = WsStatus::WsServerClosed; on_ws_status_(WsStatus::WsServerClosed); diff --git a/thirdparty/websocketpp/include/websocketpp/common/md5.hpp b/thirdparty/websocketpp/include/websocketpp/common/md5.hpp index 279725f..0dfc43e 100644 --- a/thirdparty/websocketpp/include/websocketpp/common/md5.hpp +++ b/thirdparty/websocketpp/include/websocketpp/common/md5.hpp @@ -53,6 +53,8 @@ Purschke . 1999-05-03 lpd Original version. */ +#pragma warning(push) +#pragma warning(disable : 4267) #ifndef WEBSOCKETPP_COMMON_MD5_HPP #define WEBSOCKETPP_COMMON_MD5_HPP @@ -68,120 +70,121 @@ */ #include -#include + #include +#include namespace websocketpp { /// Provides MD5 hashing functionality namespace md5 { typedef unsigned char md5_byte_t; /* 8-bit byte */ -typedef unsigned int md5_word_t; /* 32-bit word */ +typedef unsigned int md5_word_t; /* 32-bit word */ /* Define the state of the MD5 Algorithm. */ typedef struct md5_state_s { - md5_word_t count[2]; /* message length in bits, lsw first */ - md5_word_t abcd[4]; /* digest buffer */ - md5_byte_t buf[64]; /* accumulate block */ + md5_word_t count[2]; /* message length in bits, lsw first */ + md5_word_t abcd[4]; /* digest buffer */ + md5_byte_t buf[64]; /* accumulate block */ } md5_state_t; /* Initialize the algorithm. */ inline void md5_init(md5_state_t *pms); /* Append a string to the message. */ -inline void md5_append(md5_state_t *pms, md5_byte_t const * data, size_t nbytes); +inline void md5_append(md5_state_t *pms, md5_byte_t const *data, size_t nbytes); /* Finish the message and return the digest. */ inline void md5_finish(md5_state_t *pms, md5_byte_t digest[16]); -#undef ZSW_MD5_BYTE_ORDER /* 1 = big-endian, -1 = little-endian, 0 = unknown */ +#undef ZSW_MD5_BYTE_ORDER /* 1 = big-endian, -1 = little-endian, 0 = unknown \ + */ #ifdef ARCH_IS_BIG_ENDIAN -# define ZSW_MD5_BYTE_ORDER (ARCH_IS_BIG_ENDIAN ? 1 : -1) +#define ZSW_MD5_BYTE_ORDER (ARCH_IS_BIG_ENDIAN ? 1 : -1) #else -# define ZSW_MD5_BYTE_ORDER 0 +#define ZSW_MD5_BYTE_ORDER 0 #endif #define ZSW_MD5_T_MASK ((md5_word_t)~0) #define ZSW_MD5_T1 /* 0xd76aa478 */ (ZSW_MD5_T_MASK ^ 0x28955b87) #define ZSW_MD5_T2 /* 0xe8c7b756 */ (ZSW_MD5_T_MASK ^ 0x173848a9) -#define ZSW_MD5_T3 0x242070db +#define ZSW_MD5_T3 0x242070db #define ZSW_MD5_T4 /* 0xc1bdceee */ (ZSW_MD5_T_MASK ^ 0x3e423111) #define ZSW_MD5_T5 /* 0xf57c0faf */ (ZSW_MD5_T_MASK ^ 0x0a83f050) -#define ZSW_MD5_T6 0x4787c62a +#define ZSW_MD5_T6 0x4787c62a #define ZSW_MD5_T7 /* 0xa8304613 */ (ZSW_MD5_T_MASK ^ 0x57cfb9ec) #define ZSW_MD5_T8 /* 0xfd469501 */ (ZSW_MD5_T_MASK ^ 0x02b96afe) -#define ZSW_MD5_T9 0x698098d8 +#define ZSW_MD5_T9 0x698098d8 #define ZSW_MD5_T10 /* 0x8b44f7af */ (ZSW_MD5_T_MASK ^ 0x74bb0850) #define ZSW_MD5_T11 /* 0xffff5bb1 */ (ZSW_MD5_T_MASK ^ 0x0000a44e) #define ZSW_MD5_T12 /* 0x895cd7be */ (ZSW_MD5_T_MASK ^ 0x76a32841) -#define ZSW_MD5_T13 0x6b901122 +#define ZSW_MD5_T13 0x6b901122 #define ZSW_MD5_T14 /* 0xfd987193 */ (ZSW_MD5_T_MASK ^ 0x02678e6c) #define ZSW_MD5_T15 /* 0xa679438e */ (ZSW_MD5_T_MASK ^ 0x5986bc71) -#define ZSW_MD5_T16 0x49b40821 +#define ZSW_MD5_T16 0x49b40821 #define ZSW_MD5_T17 /* 0xf61e2562 */ (ZSW_MD5_T_MASK ^ 0x09e1da9d) #define ZSW_MD5_T18 /* 0xc040b340 */ (ZSW_MD5_T_MASK ^ 0x3fbf4cbf) -#define ZSW_MD5_T19 0x265e5a51 +#define ZSW_MD5_T19 0x265e5a51 #define ZSW_MD5_T20 /* 0xe9b6c7aa */ (ZSW_MD5_T_MASK ^ 0x16493855) #define ZSW_MD5_T21 /* 0xd62f105d */ (ZSW_MD5_T_MASK ^ 0x29d0efa2) -#define ZSW_MD5_T22 0x02441453 +#define ZSW_MD5_T22 0x02441453 #define ZSW_MD5_T23 /* 0xd8a1e681 */ (ZSW_MD5_T_MASK ^ 0x275e197e) #define ZSW_MD5_T24 /* 0xe7d3fbc8 */ (ZSW_MD5_T_MASK ^ 0x182c0437) -#define ZSW_MD5_T25 0x21e1cde6 +#define ZSW_MD5_T25 0x21e1cde6 #define ZSW_MD5_T26 /* 0xc33707d6 */ (ZSW_MD5_T_MASK ^ 0x3cc8f829) #define ZSW_MD5_T27 /* 0xf4d50d87 */ (ZSW_MD5_T_MASK ^ 0x0b2af278) -#define ZSW_MD5_T28 0x455a14ed +#define ZSW_MD5_T28 0x455a14ed #define ZSW_MD5_T29 /* 0xa9e3e905 */ (ZSW_MD5_T_MASK ^ 0x561c16fa) #define ZSW_MD5_T30 /* 0xfcefa3f8 */ (ZSW_MD5_T_MASK ^ 0x03105c07) -#define ZSW_MD5_T31 0x676f02d9 +#define ZSW_MD5_T31 0x676f02d9 #define ZSW_MD5_T32 /* 0x8d2a4c8a */ (ZSW_MD5_T_MASK ^ 0x72d5b375) #define ZSW_MD5_T33 /* 0xfffa3942 */ (ZSW_MD5_T_MASK ^ 0x0005c6bd) #define ZSW_MD5_T34 /* 0x8771f681 */ (ZSW_MD5_T_MASK ^ 0x788e097e) -#define ZSW_MD5_T35 0x6d9d6122 +#define ZSW_MD5_T35 0x6d9d6122 #define ZSW_MD5_T36 /* 0xfde5380c */ (ZSW_MD5_T_MASK ^ 0x021ac7f3) #define ZSW_MD5_T37 /* 0xa4beea44 */ (ZSW_MD5_T_MASK ^ 0x5b4115bb) -#define ZSW_MD5_T38 0x4bdecfa9 +#define ZSW_MD5_T38 0x4bdecfa9 #define ZSW_MD5_T39 /* 0xf6bb4b60 */ (ZSW_MD5_T_MASK ^ 0x0944b49f) #define ZSW_MD5_T40 /* 0xbebfbc70 */ (ZSW_MD5_T_MASK ^ 0x4140438f) -#define ZSW_MD5_T41 0x289b7ec6 +#define ZSW_MD5_T41 0x289b7ec6 #define ZSW_MD5_T42 /* 0xeaa127fa */ (ZSW_MD5_T_MASK ^ 0x155ed805) #define ZSW_MD5_T43 /* 0xd4ef3085 */ (ZSW_MD5_T_MASK ^ 0x2b10cf7a) -#define ZSW_MD5_T44 0x04881d05 +#define ZSW_MD5_T44 0x04881d05 #define ZSW_MD5_T45 /* 0xd9d4d039 */ (ZSW_MD5_T_MASK ^ 0x262b2fc6) #define ZSW_MD5_T46 /* 0xe6db99e5 */ (ZSW_MD5_T_MASK ^ 0x1924661a) -#define ZSW_MD5_T47 0x1fa27cf8 +#define ZSW_MD5_T47 0x1fa27cf8 #define ZSW_MD5_T48 /* 0xc4ac5665 */ (ZSW_MD5_T_MASK ^ 0x3b53a99a) #define ZSW_MD5_T49 /* 0xf4292244 */ (ZSW_MD5_T_MASK ^ 0x0bd6ddbb) -#define ZSW_MD5_T50 0x432aff97 +#define ZSW_MD5_T50 0x432aff97 #define ZSW_MD5_T51 /* 0xab9423a7 */ (ZSW_MD5_T_MASK ^ 0x546bdc58) #define ZSW_MD5_T52 /* 0xfc93a039 */ (ZSW_MD5_T_MASK ^ 0x036c5fc6) -#define ZSW_MD5_T53 0x655b59c3 +#define ZSW_MD5_T53 0x655b59c3 #define ZSW_MD5_T54 /* 0x8f0ccc92 */ (ZSW_MD5_T_MASK ^ 0x70f3336d) #define ZSW_MD5_T55 /* 0xffeff47d */ (ZSW_MD5_T_MASK ^ 0x00100b82) #define ZSW_MD5_T56 /* 0x85845dd1 */ (ZSW_MD5_T_MASK ^ 0x7a7ba22e) -#define ZSW_MD5_T57 0x6fa87e4f +#define ZSW_MD5_T57 0x6fa87e4f #define ZSW_MD5_T58 /* 0xfe2ce6e0 */ (ZSW_MD5_T_MASK ^ 0x01d3191f) #define ZSW_MD5_T59 /* 0xa3014314 */ (ZSW_MD5_T_MASK ^ 0x5cfebceb) -#define ZSW_MD5_T60 0x4e0811a1 +#define ZSW_MD5_T60 0x4e0811a1 #define ZSW_MD5_T61 /* 0xf7537e82 */ (ZSW_MD5_T_MASK ^ 0x08ac817d) #define ZSW_MD5_T62 /* 0xbd3af235 */ (ZSW_MD5_T_MASK ^ 0x42c50dca) -#define ZSW_MD5_T63 0x2ad7d2bb +#define ZSW_MD5_T63 0x2ad7d2bb #define ZSW_MD5_T64 /* 0xeb86d391 */ (ZSW_MD5_T_MASK ^ 0x14792c6e) -static void md5_process(md5_state_t *pms, md5_byte_t const * data /*[64]*/) { - md5_word_t - a = pms->abcd[0], b = pms->abcd[1], - c = pms->abcd[2], d = pms->abcd[3]; - md5_word_t t; +static void md5_process(md5_state_t *pms, md5_byte_t const *data /*[64]*/) { + md5_word_t a = pms->abcd[0], b = pms->abcd[1], c = pms->abcd[2], + d = pms->abcd[3]; + md5_word_t t; #if ZSW_MD5_BYTE_ORDER > 0 - /* Define storage only for big-endian CPUs. */ - md5_word_t X[16]; + /* Define storage only for big-endian CPUs. */ + md5_word_t X[16]; #else - /* Define storage for little-endian or both types of CPUs. */ - md5_word_t xbuf[16]; - md5_word_t const * X; + /* Define storage for little-endian or both types of CPUs. */ + md5_word_t xbuf[16]; + md5_word_t const *X; #endif - { + { #if ZSW_MD5_BYTE_ORDER == 0 /* * Determine dynamically whether this is a big-endian or @@ -192,257 +195,252 @@ static void md5_process(md5_state_t *pms, md5_byte_t const * data /*[64]*/) { if (*((md5_byte_t const *)&w)) /* dynamic little-endian */ #endif -#if ZSW_MD5_BYTE_ORDER <= 0 /* little-endian */ +#if ZSW_MD5_BYTE_ORDER <= 0 /* little-endian */ { - /* - * On little-endian machines, we can process properly aligned - * data without copying it. - */ - if (!((data - (md5_byte_t const *)0) & 3)) { + /* + * On little-endian machines, we can process properly aligned + * data without copying it. + */ + if (!((data - (md5_byte_t const *)0) & 3)) { /* data are properly aligned */ X = (md5_word_t const *)data; - } else { + } else { /* not aligned */ std::memcpy(xbuf, data, 64); X = xbuf; - } + } } #endif #if ZSW_MD5_BYTE_ORDER == 0 - else /* dynamic big-endian */ + else /* dynamic big-endian */ #endif -#if ZSW_MD5_BYTE_ORDER >= 0 /* big-endian */ +#if ZSW_MD5_BYTE_ORDER >= 0 /* big-endian */ { - /* - * On big-endian machines, we must arrange the bytes in the - * right order. - */ - const md5_byte_t *xp = data; - int i; + /* + * On big-endian machines, we must arrange the bytes in the + * right order. + */ + const md5_byte_t *xp = data; + int i; -# if ZSW_MD5_BYTE_ORDER == 0 - X = xbuf; /* (dynamic only) */ -# else -# define xbuf X /* (static only) */ -# endif - for (i = 0; i < 16; ++i, xp += 4) +#if ZSW_MD5_BYTE_ORDER == 0 + X = xbuf; /* (dynamic only) */ +#else +#define xbuf X /* (static only) */ +#endif + for (i = 0; i < 16; ++i, xp += 4) xbuf[i] = xp[0] + (xp[1] << 8) + (xp[2] << 16) + (xp[3] << 24); } #endif - } + } #define ZSW_MD5_ROTATE_LEFT(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) - /* Round 1. */ - /* Let [abcd k s i] denote the operation - a = b + ((a + F(b,c,d) + X[k] + T[i]) <<< s). */ + /* Round 1. */ + /* Let [abcd k s i] denote the operation + a = b + ((a + F(b,c,d) + X[k] + T[i]) <<< s). */ #define ZSW_MD5_F(x, y, z) (((x) & (y)) | (~(x) & (z))) -#define SET(a, b, c, d, k, s, Ti)\ - t = a + ZSW_MD5_F(b,c,d) + X[k] + Ti;\ +#define SET(a, b, c, d, k, s, Ti) \ + t = a + ZSW_MD5_F(b, c, d) + X[k] + Ti; \ a = ZSW_MD5_ROTATE_LEFT(t, s) + b - /* Do the following 16 operations. */ - SET(a, b, c, d, 0, 7, ZSW_MD5_T1); - SET(d, a, b, c, 1, 12, ZSW_MD5_T2); - SET(c, d, a, b, 2, 17, ZSW_MD5_T3); - SET(b, c, d, a, 3, 22, ZSW_MD5_T4); - SET(a, b, c, d, 4, 7, ZSW_MD5_T5); - SET(d, a, b, c, 5, 12, ZSW_MD5_T6); - SET(c, d, a, b, 6, 17, ZSW_MD5_T7); - SET(b, c, d, a, 7, 22, ZSW_MD5_T8); - SET(a, b, c, d, 8, 7, ZSW_MD5_T9); - SET(d, a, b, c, 9, 12, ZSW_MD5_T10); - SET(c, d, a, b, 10, 17, ZSW_MD5_T11); - SET(b, c, d, a, 11, 22, ZSW_MD5_T12); - SET(a, b, c, d, 12, 7, ZSW_MD5_T13); - SET(d, a, b, c, 13, 12, ZSW_MD5_T14); - SET(c, d, a, b, 14, 17, ZSW_MD5_T15); - SET(b, c, d, a, 15, 22, ZSW_MD5_T16); + /* Do the following 16 operations. */ + SET(a, b, c, d, 0, 7, ZSW_MD5_T1); + SET(d, a, b, c, 1, 12, ZSW_MD5_T2); + SET(c, d, a, b, 2, 17, ZSW_MD5_T3); + SET(b, c, d, a, 3, 22, ZSW_MD5_T4); + SET(a, b, c, d, 4, 7, ZSW_MD5_T5); + SET(d, a, b, c, 5, 12, ZSW_MD5_T6); + SET(c, d, a, b, 6, 17, ZSW_MD5_T7); + SET(b, c, d, a, 7, 22, ZSW_MD5_T8); + SET(a, b, c, d, 8, 7, ZSW_MD5_T9); + SET(d, a, b, c, 9, 12, ZSW_MD5_T10); + SET(c, d, a, b, 10, 17, ZSW_MD5_T11); + SET(b, c, d, a, 11, 22, ZSW_MD5_T12); + SET(a, b, c, d, 12, 7, ZSW_MD5_T13); + SET(d, a, b, c, 13, 12, ZSW_MD5_T14); + SET(c, d, a, b, 14, 17, ZSW_MD5_T15); + SET(b, c, d, a, 15, 22, ZSW_MD5_T16); #undef SET - /* Round 2. */ - /* Let [abcd k s i] denote the operation - a = b + ((a + G(b,c,d) + X[k] + T[i]) <<< s). */ + /* Round 2. */ + /* Let [abcd k s i] denote the operation + a = b + ((a + G(b,c,d) + X[k] + T[i]) <<< s). */ #define ZSW_MD5_G(x, y, z) (((x) & (z)) | ((y) & ~(z))) -#define SET(a, b, c, d, k, s, Ti)\ - t = a + ZSW_MD5_G(b,c,d) + X[k] + Ti;\ +#define SET(a, b, c, d, k, s, Ti) \ + t = a + ZSW_MD5_G(b, c, d) + X[k] + Ti; \ a = ZSW_MD5_ROTATE_LEFT(t, s) + b - /* Do the following 16 operations. */ - SET(a, b, c, d, 1, 5, ZSW_MD5_T17); - SET(d, a, b, c, 6, 9, ZSW_MD5_T18); - SET(c, d, a, b, 11, 14, ZSW_MD5_T19); - SET(b, c, d, a, 0, 20, ZSW_MD5_T20); - SET(a, b, c, d, 5, 5, ZSW_MD5_T21); - SET(d, a, b, c, 10, 9, ZSW_MD5_T22); - SET(c, d, a, b, 15, 14, ZSW_MD5_T23); - SET(b, c, d, a, 4, 20, ZSW_MD5_T24); - SET(a, b, c, d, 9, 5, ZSW_MD5_T25); - SET(d, a, b, c, 14, 9, ZSW_MD5_T26); - SET(c, d, a, b, 3, 14, ZSW_MD5_T27); - SET(b, c, d, a, 8, 20, ZSW_MD5_T28); - SET(a, b, c, d, 13, 5, ZSW_MD5_T29); - SET(d, a, b, c, 2, 9, ZSW_MD5_T30); - SET(c, d, a, b, 7, 14, ZSW_MD5_T31); - SET(b, c, d, a, 12, 20, ZSW_MD5_T32); + /* Do the following 16 operations. */ + SET(a, b, c, d, 1, 5, ZSW_MD5_T17); + SET(d, a, b, c, 6, 9, ZSW_MD5_T18); + SET(c, d, a, b, 11, 14, ZSW_MD5_T19); + SET(b, c, d, a, 0, 20, ZSW_MD5_T20); + SET(a, b, c, d, 5, 5, ZSW_MD5_T21); + SET(d, a, b, c, 10, 9, ZSW_MD5_T22); + SET(c, d, a, b, 15, 14, ZSW_MD5_T23); + SET(b, c, d, a, 4, 20, ZSW_MD5_T24); + SET(a, b, c, d, 9, 5, ZSW_MD5_T25); + SET(d, a, b, c, 14, 9, ZSW_MD5_T26); + SET(c, d, a, b, 3, 14, ZSW_MD5_T27); + SET(b, c, d, a, 8, 20, ZSW_MD5_T28); + SET(a, b, c, d, 13, 5, ZSW_MD5_T29); + SET(d, a, b, c, 2, 9, ZSW_MD5_T30); + SET(c, d, a, b, 7, 14, ZSW_MD5_T31); + SET(b, c, d, a, 12, 20, ZSW_MD5_T32); #undef SET - /* Round 3. */ - /* Let [abcd k s t] denote the operation - a = b + ((a + H(b,c,d) + X[k] + T[i]) <<< s). */ + /* Round 3. */ + /* Let [abcd k s t] denote the operation + a = b + ((a + H(b,c,d) + X[k] + T[i]) <<< s). */ #define ZSW_MD5_H(x, y, z) ((x) ^ (y) ^ (z)) -#define SET(a, b, c, d, k, s, Ti)\ - t = a + ZSW_MD5_H(b,c,d) + X[k] + Ti;\ +#define SET(a, b, c, d, k, s, Ti) \ + t = a + ZSW_MD5_H(b, c, d) + X[k] + Ti; \ a = ZSW_MD5_ROTATE_LEFT(t, s) + b - /* Do the following 16 operations. */ - SET(a, b, c, d, 5, 4, ZSW_MD5_T33); - SET(d, a, b, c, 8, 11, ZSW_MD5_T34); - SET(c, d, a, b, 11, 16, ZSW_MD5_T35); - SET(b, c, d, a, 14, 23, ZSW_MD5_T36); - SET(a, b, c, d, 1, 4, ZSW_MD5_T37); - SET(d, a, b, c, 4, 11, ZSW_MD5_T38); - SET(c, d, a, b, 7, 16, ZSW_MD5_T39); - SET(b, c, d, a, 10, 23, ZSW_MD5_T40); - SET(a, b, c, d, 13, 4, ZSW_MD5_T41); - SET(d, a, b, c, 0, 11, ZSW_MD5_T42); - SET(c, d, a, b, 3, 16, ZSW_MD5_T43); - SET(b, c, d, a, 6, 23, ZSW_MD5_T44); - SET(a, b, c, d, 9, 4, ZSW_MD5_T45); - SET(d, a, b, c, 12, 11, ZSW_MD5_T46); - SET(c, d, a, b, 15, 16, ZSW_MD5_T47); - SET(b, c, d, a, 2, 23, ZSW_MD5_T48); + /* Do the following 16 operations. */ + SET(a, b, c, d, 5, 4, ZSW_MD5_T33); + SET(d, a, b, c, 8, 11, ZSW_MD5_T34); + SET(c, d, a, b, 11, 16, ZSW_MD5_T35); + SET(b, c, d, a, 14, 23, ZSW_MD5_T36); + SET(a, b, c, d, 1, 4, ZSW_MD5_T37); + SET(d, a, b, c, 4, 11, ZSW_MD5_T38); + SET(c, d, a, b, 7, 16, ZSW_MD5_T39); + SET(b, c, d, a, 10, 23, ZSW_MD5_T40); + SET(a, b, c, d, 13, 4, ZSW_MD5_T41); + SET(d, a, b, c, 0, 11, ZSW_MD5_T42); + SET(c, d, a, b, 3, 16, ZSW_MD5_T43); + SET(b, c, d, a, 6, 23, ZSW_MD5_T44); + SET(a, b, c, d, 9, 4, ZSW_MD5_T45); + SET(d, a, b, c, 12, 11, ZSW_MD5_T46); + SET(c, d, a, b, 15, 16, ZSW_MD5_T47); + SET(b, c, d, a, 2, 23, ZSW_MD5_T48); #undef SET - /* Round 4. */ - /* Let [abcd k s t] denote the operation - a = b + ((a + I(b,c,d) + X[k] + T[i]) <<< s). */ + /* Round 4. */ + /* Let [abcd k s t] denote the operation + a = b + ((a + I(b,c,d) + X[k] + T[i]) <<< s). */ #define ZSW_MD5_I(x, y, z) ((y) ^ ((x) | ~(z))) -#define SET(a, b, c, d, k, s, Ti)\ - t = a + ZSW_MD5_I(b,c,d) + X[k] + Ti;\ +#define SET(a, b, c, d, k, s, Ti) \ + t = a + ZSW_MD5_I(b, c, d) + X[k] + Ti; \ a = ZSW_MD5_ROTATE_LEFT(t, s) + b - /* Do the following 16 operations. */ - SET(a, b, c, d, 0, 6, ZSW_MD5_T49); - SET(d, a, b, c, 7, 10, ZSW_MD5_T50); - SET(c, d, a, b, 14, 15, ZSW_MD5_T51); - SET(b, c, d, a, 5, 21, ZSW_MD5_T52); - SET(a, b, c, d, 12, 6, ZSW_MD5_T53); - SET(d, a, b, c, 3, 10, ZSW_MD5_T54); - SET(c, d, a, b, 10, 15, ZSW_MD5_T55); - SET(b, c, d, a, 1, 21, ZSW_MD5_T56); - SET(a, b, c, d, 8, 6, ZSW_MD5_T57); - SET(d, a, b, c, 15, 10, ZSW_MD5_T58); - SET(c, d, a, b, 6, 15, ZSW_MD5_T59); - SET(b, c, d, a, 13, 21, ZSW_MD5_T60); - SET(a, b, c, d, 4, 6, ZSW_MD5_T61); - SET(d, a, b, c, 11, 10, ZSW_MD5_T62); - SET(c, d, a, b, 2, 15, ZSW_MD5_T63); - SET(b, c, d, a, 9, 21, ZSW_MD5_T64); + /* Do the following 16 operations. */ + SET(a, b, c, d, 0, 6, ZSW_MD5_T49); + SET(d, a, b, c, 7, 10, ZSW_MD5_T50); + SET(c, d, a, b, 14, 15, ZSW_MD5_T51); + SET(b, c, d, a, 5, 21, ZSW_MD5_T52); + SET(a, b, c, d, 12, 6, ZSW_MD5_T53); + SET(d, a, b, c, 3, 10, ZSW_MD5_T54); + SET(c, d, a, b, 10, 15, ZSW_MD5_T55); + SET(b, c, d, a, 1, 21, ZSW_MD5_T56); + SET(a, b, c, d, 8, 6, ZSW_MD5_T57); + SET(d, a, b, c, 15, 10, ZSW_MD5_T58); + SET(c, d, a, b, 6, 15, ZSW_MD5_T59); + SET(b, c, d, a, 13, 21, ZSW_MD5_T60); + SET(a, b, c, d, 4, 6, ZSW_MD5_T61); + SET(d, a, b, c, 11, 10, ZSW_MD5_T62); + SET(c, d, a, b, 2, 15, ZSW_MD5_T63); + SET(b, c, d, a, 9, 21, ZSW_MD5_T64); #undef SET - /* Then perform the following additions. (That is increment each - of the four registers by the value it had before this block - was started.) */ - pms->abcd[0] += a; - pms->abcd[1] += b; - pms->abcd[2] += c; - pms->abcd[3] += d; + /* Then perform the following additions. (That is increment each + of the four registers by the value it had before this block + was started.) */ + pms->abcd[0] += a; + pms->abcd[1] += b; + pms->abcd[2] += c; + pms->abcd[3] += d; } void md5_init(md5_state_t *pms) { - pms->count[0] = pms->count[1] = 0; - pms->abcd[0] = 0x67452301; - pms->abcd[1] = /*0xefcdab89*/ ZSW_MD5_T_MASK ^ 0x10325476; - pms->abcd[2] = /*0x98badcfe*/ ZSW_MD5_T_MASK ^ 0x67452301; - pms->abcd[3] = 0x10325476; + pms->count[0] = pms->count[1] = 0; + pms->abcd[0] = 0x67452301; + pms->abcd[1] = /*0xefcdab89*/ ZSW_MD5_T_MASK ^ 0x10325476; + pms->abcd[2] = /*0x98badcfe*/ ZSW_MD5_T_MASK ^ 0x67452301; + pms->abcd[3] = 0x10325476; } -void md5_append(md5_state_t *pms, md5_byte_t const * data, size_t nbytes) { - md5_byte_t const * p = data; - size_t left = nbytes; - int offset = (pms->count[0] >> 3) & 63; - md5_word_t nbits = (md5_word_t)(nbytes << 3); +void md5_append(md5_state_t *pms, md5_byte_t const *data, size_t nbytes) { + md5_byte_t const *p = data; + size_t left = nbytes; + int offset = (pms->count[0] >> 3) & 63; + md5_word_t nbits = (md5_word_t)(nbytes << 3); - if (nbytes <= 0) - return; + if (nbytes <= 0) return; - /* Update the message length. */ - pms->count[1] += nbytes >> 29; - pms->count[0] += nbits; - if (pms->count[0] < nbits) - pms->count[1]++; + /* Update the message length. */ + pms->count[1] += nbytes >> 29; + pms->count[0] += nbits; + if (pms->count[0] < nbits) pms->count[1]++; - /* Process an initial partial block. */ - if (offset) { + /* Process an initial partial block. */ + if (offset) { int copy = (offset + nbytes > 64 ? 64 - offset : static_cast(nbytes)); std::memcpy(pms->buf + offset, p, copy); - if (offset + copy < 64) - return; + if (offset + copy < 64) return; p += copy; left -= copy; md5_process(pms, pms->buf); - } + } - /* Process full blocks. */ - for (; left >= 64; p += 64, left -= 64) - md5_process(pms, p); + /* Process full blocks. */ + for (; left >= 64; p += 64, left -= 64) md5_process(pms, p); - /* Process a final partial block. */ - if (left) - std::memcpy(pms->buf, p, left); + /* Process a final partial block. */ + if (left) std::memcpy(pms->buf, p, left); } void md5_finish(md5_state_t *pms, md5_byte_t digest[16]) { - static md5_byte_t const pad[64] = { - 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - }; - md5_byte_t data[8]; - int i; + static md5_byte_t const pad[64] = { + 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + md5_byte_t data[8]; + int i; - /* Save the length before padding. */ - for (i = 0; i < 8; ++i) + /* Save the length before padding. */ + for (i = 0; i < 8; ++i) data[i] = (md5_byte_t)(pms->count[i >> 2] >> ((i & 3) << 3)); - /* Pad to 56 bytes mod 64. */ - md5_append(pms, pad, ((55 - (pms->count[0] >> 3)) & 63) + 1); - /* Append the length. */ - md5_append(pms, data, 8); - for (i = 0; i < 16; ++i) + /* Pad to 56 bytes mod 64. */ + md5_append(pms, pad, ((55 - (pms->count[0] >> 3)) & 63) + 1); + /* Append the length. */ + md5_append(pms, data, 8); + for (i = 0; i < 16; ++i) digest[i] = (md5_byte_t)(pms->abcd[i >> 2] >> ((i & 3) << 3)); } // some convenience c++ functions -inline std::string md5_hash_string(std::string const & s) { - char digest[16]; +inline std::string md5_hash_string(std::string const &s) { + char digest[16]; - md5_state_t state; + md5_state_t state; - md5_init(&state); - md5_append(&state, (md5_byte_t const *)s.c_str(), s.size()); - md5_finish(&state, (md5_byte_t *)digest); + md5_init(&state); + md5_append(&state, (md5_byte_t const *)s.c_str(), s.size()); + md5_finish(&state, (md5_byte_t *)digest); - std::string ret; - ret.resize(16); - std::copy(digest,digest+16,ret.begin()); + std::string ret; + ret.resize(16); + std::copy(digest, digest + 16, ret.begin()); - return ret; + return ret; } -const char hexval[16] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; +const char hexval[16] = {'0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; -inline std::string md5_hash_hex(std::string const & input) { - std::string hash = md5_hash_string(input); - std::string hex; +inline std::string md5_hash_hex(std::string const &input) { + std::string hash = md5_hash_string(input); + std::string hex; - for (size_t i = 0; i < hash.size(); i++) { - hex.push_back(hexval[((hash[i] >> 4) & 0xF)]); - hex.push_back(hexval[(hash[i]) & 0x0F]); - } + for (size_t i = 0; i < hash.size(); i++) { + hex.push_back(hexval[((hash[i] >> 4) & 0xF)]); + hex.push_back(hexval[(hash[i]) & 0x0F]); + } - return hex; + return hex; } -} // md5 -} // websocketpp +} // namespace md5 +} // namespace websocketpp -#endif // WEBSOCKETPP_COMMON_MD5_HPP +#endif // WEBSOCKETPP_COMMON_MD5_HPP +#pragma warning(pop) \ No newline at end of file diff --git a/thirdparty/websocketpp/include/websocketpp/frame.hpp b/thirdparty/websocketpp/include/websocketpp/frame.hpp index 18a990b..2b1e560 100644 --- a/thirdparty/websocketpp/include/websocketpp/frame.hpp +++ b/thirdparty/websocketpp/include/websocketpp/frame.hpp @@ -24,16 +24,17 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ +#pragma warning(push) +#pragma warning(disable : 4127) +#pragma warning(disable : 4267) #ifndef WEBSOCKETPP_FRAME_HPP #define WEBSOCKETPP_FRAME_HPP #include #include - -#include #include - +#include #include namespace websocketpp { @@ -53,20 +54,20 @@ static unsigned int const MAX_EXTENDED_HEADER_LENGTH = 12; /// Two byte conversion union union uint16_converter { - uint16_t i; - uint8_t c[2]; + uint16_t i; + uint8_t c[2]; }; /// Four byte conversion union union uint32_converter { - uint32_t i; - uint8_t c[4]; + uint32_t i; + uint8_t c[4]; }; /// Eight byte conversion union union uint64_converter { - uint64_t i; - uint8_t c[8]; + uint64_t i; + uint8_t c[8]; }; /// Constants and utility functions related to WebSocket opcodes @@ -74,101 +75,95 @@ union uint64_converter { * WebSocket Opcodes are 4 bits. See RFC6455 section 5.2. */ namespace opcode { - enum value { - continuation = 0x0, - text = 0x1, - binary = 0x2, - rsv3 = 0x3, - rsv4 = 0x4, - rsv5 = 0x5, - rsv6 = 0x6, - rsv7 = 0x7, - close = 0x8, - ping = 0x9, - pong = 0xA, - control_rsvb = 0xB, - control_rsvc = 0xC, - control_rsvd = 0xD, - control_rsve = 0xE, - control_rsvf = 0xF, +enum value { + continuation = 0x0, + text = 0x1, + binary = 0x2, + rsv3 = 0x3, + rsv4 = 0x4, + rsv5 = 0x5, + rsv6 = 0x6, + rsv7 = 0x7, + close = 0x8, + ping = 0x9, + pong = 0xA, + control_rsvb = 0xB, + control_rsvc = 0xC, + control_rsvd = 0xD, + control_rsve = 0xE, + control_rsvf = 0xF, - CONTINUATION = 0x0, - TEXT = 0x1, - BINARY = 0x2, - RSV3 = 0x3, - RSV4 = 0x4, - RSV5 = 0x5, - RSV6 = 0x6, - RSV7 = 0x7, - CLOSE = 0x8, - PING = 0x9, - PONG = 0xA, - CONTROL_RSVB = 0xB, - CONTROL_RSVC = 0xC, - CONTROL_RSVD = 0xD, - CONTROL_RSVE = 0xE, - CONTROL_RSVF = 0xF - }; + CONTINUATION = 0x0, + TEXT = 0x1, + BINARY = 0x2, + RSV3 = 0x3, + RSV4 = 0x4, + RSV5 = 0x5, + RSV6 = 0x6, + RSV7 = 0x7, + CLOSE = 0x8, + PING = 0x9, + PONG = 0xA, + CONTROL_RSVB = 0xB, + CONTROL_RSVC = 0xC, + CONTROL_RSVD = 0xD, + CONTROL_RSVE = 0xE, + CONTROL_RSVF = 0xF +}; - /// Check if an opcode is reserved - /** - * @param v The opcode to test. - * @return Whether or not the opcode is reserved. - */ - inline bool reserved(value v) { - return (v >= rsv3 && v <= rsv7) || - (v >= control_rsvb && v <= control_rsvf); - } - - /// Check if an opcode is invalid - /** - * Invalid opcodes are negative or require greater than 4 bits to store. - * - * @param v The opcode to test. - * @return Whether or not the opcode is invalid. - */ - inline bool invalid(value v) { - return (v > 0xF || v < 0); - } - - /// Check if an opcode is for a control frame - /** - * @param v The opcode to test. - * @return Whether or not the opcode is a control opcode. - */ - inline bool is_control(value v) { - return v >= 0x8; - } +/// Check if an opcode is reserved +/** + * @param v The opcode to test. + * @return Whether or not the opcode is reserved. + */ +inline bool reserved(value v) { + return (v >= rsv3 && v <= rsv7) || (v >= control_rsvb && v <= control_rsvf); } +/// Check if an opcode is invalid +/** + * Invalid opcodes are negative or require greater than 4 bits to store. + * + * @param v The opcode to test. + * @return Whether or not the opcode is invalid. + */ +inline bool invalid(value v) { return (v > 0xF || v < 0); } + +/// Check if an opcode is for a control frame +/** + * @param v The opcode to test. + * @return Whether or not the opcode is a control opcode. + */ +inline bool is_control(value v) { return v >= 0x8; } +} // namespace opcode + /// Constants related to frame and payload limits namespace limits { - /// Minimum length of a WebSocket frame header. - static unsigned int const basic_header_length = 2; +/// Minimum length of a WebSocket frame header. +static unsigned int const basic_header_length = 2; - /// Maximum length of a WebSocket header - static unsigned int const max_header_length = 14; +/// Maximum length of a WebSocket header +static unsigned int const max_header_length = 14; - /// Maximum length of the variable portion of the WebSocket header - static unsigned int const max_extended_header_length = 12; +/// Maximum length of the variable portion of the WebSocket header +static unsigned int const max_extended_header_length = 12; - /// Maximum size of a basic WebSocket payload - static uint8_t const payload_size_basic = 125; +/// Maximum size of a basic WebSocket payload +static uint8_t const payload_size_basic = 125; - /// Maximum size of an extended WebSocket payload (basic payload = 126) - static uint16_t const payload_size_extended = 0xFFFF; // 2^16, 65535 +/// Maximum size of an extended WebSocket payload (basic payload = 126) +static uint16_t const payload_size_extended = 0xFFFF; // 2^16, 65535 - /// Maximum size of a jumbo WebSocket payload (basic payload = 127) - static uint64_t const payload_size_jumbo = 0x7FFFFFFFFFFFFFFFLL;//2^63 - - /// Maximum size of close frame reason - /** - * This is payload_size_basic - 2 bytes (as first two bytes are used for - * the close code - */ - static uint8_t const close_reason_size = 123; -} +/// Maximum size of a jumbo WebSocket payload (basic payload = 127) +static uint64_t const payload_size_jumbo = 0x7FFFFFFFFFFFFFFFLL; // 2^63 +/// Maximum size of close frame reason +/** + * This is payload_size_basic - 2 bytes (as first two bytes are used for + * the close code + */ +static uint8_t const close_reason_size = 123; +} // namespace limits // masks for fields in the basic header static uint8_t const BHB0_OPCODE = 0x0F; @@ -180,98 +175,97 @@ static uint8_t const BHB0_FIN = 0x80; static uint8_t const BHB1_PAYLOAD = 0x7F; static uint8_t const BHB1_MASK = 0x80; -static uint8_t const payload_size_code_16bit = 0x7E; // 126 -static uint8_t const payload_size_code_64bit = 0x7F; // 127 +static uint8_t const payload_size_code_16bit = 0x7E; // 126 +static uint8_t const payload_size_code_64bit = 0x7F; // 127 typedef uint32_converter masking_key_type; /// The constant size component of a WebSocket frame header struct basic_header { - basic_header() : b0(0x00),b1(0x00) {} + basic_header() : b0(0x00), b1(0x00) {} - basic_header(uint8_t p0, uint8_t p1) : b0(p0), b1(p1) {} + basic_header(uint8_t p0, uint8_t p1) : b0(p0), b1(p1) {} - basic_header(opcode::value op, uint64_t size, bool fin, bool mask, - bool rsv1 = false, bool rsv2 = false, bool rsv3 = false) : b0(0x00), - b1(0x00) - { - if (fin) { - b0 |= BHB0_FIN; - } - if (rsv1) { - b0 |= BHB0_RSV1; - } - if (rsv2) { - b0 |= BHB0_RSV2; - } - if (rsv3) { - b0 |= BHB0_RSV3; - } - b0 |= (op & BHB0_OPCODE); + basic_header(opcode::value op, uint64_t size, bool fin, bool mask, + bool rsv1 = false, bool rsv2 = false, bool rsv3 = false) + : b0(0x00), b1(0x00) { + if (fin) { + b0 |= BHB0_FIN; + } + if (rsv1) { + b0 |= BHB0_RSV1; + } + if (rsv2) { + b0 |= BHB0_RSV2; + } + if (rsv3) { + b0 |= BHB0_RSV3; + } + b0 |= (op & BHB0_OPCODE); - if (mask) { - b1 |= BHB1_MASK; - } - - uint8_t basic_value; - - if (size <= limits::payload_size_basic) { - basic_value = static_cast(size); - } else if (size <= limits::payload_size_extended) { - basic_value = payload_size_code_16bit; - } else { - basic_value = payload_size_code_64bit; - } - - - b1 |= basic_value; + if (mask) { + b1 |= BHB1_MASK; } - uint8_t b0; - uint8_t b1; + uint8_t basic_value; + + if (size <= limits::payload_size_basic) { + basic_value = static_cast(size); + } else if (size <= limits::payload_size_extended) { + basic_value = payload_size_code_16bit; + } else { + basic_value = payload_size_code_64bit; + } + + b1 |= basic_value; + } + + uint8_t b0; + uint8_t b1; }; /// The variable size component of a WebSocket frame header struct extended_header { - extended_header() { - std::fill_n(this->bytes,MAX_EXTENDED_HEADER_LENGTH,0x00); + extended_header() { + std::fill_n(this->bytes, MAX_EXTENDED_HEADER_LENGTH, 0x00); + } + + extended_header(uint64_t payload_size) { + std::fill_n(this->bytes, MAX_EXTENDED_HEADER_LENGTH, 0x00); + + copy_payload(payload_size); + } + + extended_header(uint64_t payload_size, uint32_t masking_key) { + std::fill_n(this->bytes, MAX_EXTENDED_HEADER_LENGTH, 0x00); + + // Copy payload size + int offset = copy_payload(payload_size); + + // Copy Masking Key + uint32_converter temp32; + temp32.i = masking_key; + std::copy(temp32.c, temp32.c + 4, bytes + offset); + } + + uint8_t bytes[MAX_EXTENDED_HEADER_LENGTH]; + + private: + int copy_payload(uint64_t payload_size) { + int payload_offset = 0; + + if (payload_size <= limits::payload_size_basic) { + payload_offset = 8; + } else if (payload_size <= limits::payload_size_extended) { + payload_offset = 6; } - extended_header(uint64_t payload_size) { - std::fill_n(this->bytes,MAX_EXTENDED_HEADER_LENGTH,0x00); + uint64_converter temp64; + temp64.i = lib::net::_htonll(payload_size); + std::copy(temp64.c + payload_offset, temp64.c + 8, bytes); - copy_payload(payload_size); - } - - extended_header(uint64_t payload_size, uint32_t masking_key) { - std::fill_n(this->bytes,MAX_EXTENDED_HEADER_LENGTH,0x00); - - // Copy payload size - int offset = copy_payload(payload_size); - - // Copy Masking Key - uint32_converter temp32; - temp32.i = masking_key; - std::copy(temp32.c,temp32.c+4,bytes+offset); - } - - uint8_t bytes[MAX_EXTENDED_HEADER_LENGTH]; -private: - int copy_payload(uint64_t payload_size) { - int payload_offset = 0; - - if (payload_size <= limits::payload_size_basic) { - payload_offset = 8; - } else if (payload_size <= limits::payload_size_extended) { - payload_offset = 6; - } - - uint64_converter temp64; - temp64.i = lib::net::_htonll(payload_size); - std::copy(temp64.c+payload_offset,temp64.c+8,bytes); - - return 8-payload_offset; - } + return 8 - payload_offset; + } }; bool get_fin(basic_header const &h); @@ -295,31 +289,30 @@ uint16_t get_extended_size(extended_header const &); uint64_t get_jumbo_size(extended_header const &); uint64_t get_payload_size(basic_header const &, extended_header const &); -size_t prepare_masking_key(masking_key_type const & key); +size_t prepare_masking_key(masking_key_type const &key); size_t circshift_prepared_key(size_t prepared_key, size_t offset); // Functions for performing xor based masking and unmasking template -void byte_mask(input_iter b, input_iter e, output_iter o, masking_key_type - const & key, size_t key_offset = 0); +void byte_mask(input_iter b, input_iter e, output_iter o, + masking_key_type const &key, size_t key_offset = 0); template -void byte_mask(iter_type b, iter_type e, masking_key_type const & key, - size_t key_offset = 0); -void word_mask_exact(uint8_t * input, uint8_t * output, size_t length, - masking_key_type const & key); -void word_mask_exact(uint8_t * data, size_t length, masking_key_type const & - key); -size_t word_mask_circ(uint8_t * input, uint8_t * output, size_t length, - size_t prepared_key); -size_t word_mask_circ(uint8_t * data, size_t length, size_t prepared_key); +void byte_mask(iter_type b, iter_type e, masking_key_type const &key, + size_t key_offset = 0); +void word_mask_exact(uint8_t *input, uint8_t *output, size_t length, + masking_key_type const &key); +void word_mask_exact(uint8_t *data, size_t length, masking_key_type const &key); +size_t word_mask_circ(uint8_t *input, uint8_t *output, size_t length, + size_t prepared_key); +size_t word_mask_circ(uint8_t *data, size_t length, size_t prepared_key); /// Check whether the frame's FIN bit is set. /** * @param [in] h The basic header to extract from. * @return True if the header's fin bit is set. */ -inline bool get_fin(basic_header const & h) { - return ((h.b0 & BHB0_FIN) == BHB0_FIN); +inline bool get_fin(basic_header const &h) { + return ((h.b0 & BHB0_FIN) == BHB0_FIN); } /// Set the frame's FIN bit @@ -327,8 +320,8 @@ inline bool get_fin(basic_header const & h) { * @param [out] h Header to set. * @param [in] value Value to set it to. */ -inline void set_fin(basic_header & h, bool value) { - h.b0 = (value ? h.b0 | BHB0_FIN : h.b0 & ~BHB0_FIN); +inline void set_fin(basic_header &h, bool value) { + h.b0 = (value ? h.b0 | BHB0_FIN : h.b0 & ~BHB0_FIN); } /// check whether the frame's RSV1 bit is set @@ -337,7 +330,7 @@ inline void set_fin(basic_header & h, bool value) { * @return True if the header's RSV1 bit is set. */ inline bool get_rsv1(const basic_header &h) { - return ((h.b0 & BHB0_RSV1) == BHB0_RSV1); + return ((h.b0 & BHB0_RSV1) == BHB0_RSV1); } /// Set the frame's RSV1 bit @@ -346,7 +339,7 @@ inline bool get_rsv1(const basic_header &h) { * @param [in] value Value to set it to. */ inline void set_rsv1(basic_header &h, bool value) { - h.b0 = (value ? h.b0 | BHB0_RSV1 : h.b0 & ~BHB0_RSV1); + h.b0 = (value ? h.b0 | BHB0_RSV1 : h.b0 & ~BHB0_RSV1); } /// check whether the frame's RSV2 bit is set @@ -355,7 +348,7 @@ inline void set_rsv1(basic_header &h, bool value) { * @return True if the header's RSV2 bit is set. */ inline bool get_rsv2(const basic_header &h) { - return ((h.b0 & BHB0_RSV2) == BHB0_RSV2); + return ((h.b0 & BHB0_RSV2) == BHB0_RSV2); } /// Set the frame's RSV2 bit @@ -364,7 +357,7 @@ inline bool get_rsv2(const basic_header &h) { * @param [in] value Value to set it to. */ inline void set_rsv2(basic_header &h, bool value) { - h.b0 = (value ? h.b0 | BHB0_RSV2 : h.b0 & ~BHB0_RSV2); + h.b0 = (value ? h.b0 | BHB0_RSV2 : h.b0 & ~BHB0_RSV2); } /// check whether the frame's RSV3 bit is set @@ -373,7 +366,7 @@ inline void set_rsv2(basic_header &h, bool value) { * @return True if the header's RSV3 bit is set. */ inline bool get_rsv3(const basic_header &h) { - return ((h.b0 & BHB0_RSV3) == BHB0_RSV3); + return ((h.b0 & BHB0_RSV3) == BHB0_RSV3); } /// Set the frame's RSV3 bit @@ -382,7 +375,7 @@ inline bool get_rsv3(const basic_header &h) { * @param [in] value Value to set it to. */ inline void set_rsv3(basic_header &h, bool value) { - h.b0 = (value ? h.b0 | BHB0_RSV3 : h.b0 & ~BHB0_RSV3); + h.b0 = (value ? h.b0 | BHB0_RSV3 : h.b0 & ~BHB0_RSV3); } /// Extract opcode from basic header @@ -391,7 +384,7 @@ inline void set_rsv3(basic_header &h, bool value) { * @return The opcode value of the header. */ inline opcode::value get_opcode(const basic_header &h) { - return opcode::value(h.b0 & BHB0_OPCODE); + return opcode::value(h.b0 & BHB0_OPCODE); } /// check whether the frame is masked @@ -399,8 +392,8 @@ inline opcode::value get_opcode(const basic_header &h) { * @param [in] h The basic header to extract from. * @return True if the header mask bit is set. */ -inline bool get_masked(basic_header const & h) { - return ((h.b1 & BHB1_MASK) == BHB1_MASK); +inline bool get_masked(basic_header const &h) { + return ((h.b1 & BHB1_MASK) == BHB1_MASK); } /// Set the frame's MASK bit @@ -408,8 +401,8 @@ inline bool get_masked(basic_header const & h) { * @param [out] h Header to set. * @param value Value to set it to. */ -inline void set_masked(basic_header & h, bool value) { - h.b1 = (value ? h.b1 | BHB1_MASK : h.b1 & ~BHB1_MASK); +inline void set_masked(basic_header &h, bool value) { + h.b1 = (value ? h.b1 | BHB1_MASK : h.b1 & ~BHB1_MASK); } /// Extracts the raw payload length specified in the basic header @@ -429,7 +422,7 @@ inline void set_masked(basic_header & h, bool value) { * @return The exact size encoded in h. */ inline uint8_t get_basic_size(const basic_header &h) { - return h.b1 & BHB1_PAYLOAD; + return h.b1 & BHB1_PAYLOAD; } /// Calculates the full length of the header based on the first bytes. @@ -442,19 +435,19 @@ inline uint8_t get_basic_size(const basic_header &h) { * @param h Basic frame header to extract size from. * @return Full length of the extended header. */ -inline size_t get_header_len(basic_header const & h) { - // TODO: check extensions? +inline size_t get_header_len(basic_header const &h) { + // TODO: check extensions? - // masking key offset represents the space used for the extended length - // fields - size_t size = BASIC_HEADER_LENGTH + get_masking_key_offset(h); + // masking key offset represents the space used for the extended length + // fields + size_t size = BASIC_HEADER_LENGTH + get_masking_key_offset(h); - // If the header is masked there is a 4 byte masking key - if (get_masked(h)) { - size += 4; - } + // If the header is masked there is a 4 byte masking key + if (get_masked(h)) { + size += 4; + } - return size; + return size; } /// Calculate the offset location of the masking key within the extended header @@ -467,13 +460,13 @@ inline size_t get_header_len(basic_header const & h) { * @return byte offset of the first byte of the masking key */ inline unsigned int get_masking_key_offset(const basic_header &h) { - if (get_basic_size(h) == payload_size_code_16bit) { - return 2; - } else if (get_basic_size(h) == payload_size_code_64bit) { - return 8; - } else { - return 0; - } + if (get_basic_size(h) == payload_size_code_16bit) { + return 2; + } else if (get_basic_size(h) == payload_size_code_64bit) { + return 8; + } else { + return 0; + } } /// Generate a properly sized contiguous string that encodes a full frame header @@ -486,19 +479,16 @@ inline unsigned int get_masking_key_offset(const basic_header &h) { * * @return A contiguous string containing h and e */ -inline std::string prepare_header(const basic_header &h, const - extended_header &e) -{ - std::string ret; +inline std::string prepare_header(const basic_header &h, + const extended_header &e) { + std::string ret; - ret.push_back(char(h.b0)); - ret.push_back(char(h.b1)); - ret.append( - reinterpret_cast(e.bytes), - get_header_len(h)-BASIC_HEADER_LENGTH - ); + ret.push_back(char(h.b0)); + ret.push_back(char(h.b1)); + ret.append(reinterpret_cast(e.bytes), + get_header_len(h) - BASIC_HEADER_LENGTH); - return ret; + return ret; } /// Extract the masking key from a frame header @@ -513,19 +503,18 @@ inline std::string prepare_header(const basic_header &h, const * * @return The masking key as an integer. */ -inline masking_key_type get_masking_key(const basic_header &h, const - extended_header &e) -{ - masking_key_type temp32; +inline masking_key_type get_masking_key(const basic_header &h, + const extended_header &e) { + masking_key_type temp32; - if (!get_masked(h)) { - temp32.i = 0; - } else { - unsigned int offset = get_masking_key_offset(h); - std::copy(e.bytes+offset,e.bytes+offset+4,temp32.c); - } + if (!get_masked(h)) { + temp32.i = 0; + } else { + unsigned int offset = get_masking_key_offset(h); + std::copy(e.bytes + offset, e.bytes + offset + 4, temp32.c); + } - return temp32; + return temp32; } /// Extract the extended size field from an extended header @@ -538,9 +527,9 @@ inline masking_key_type get_masking_key(const basic_header &h, const * @return The size encoded in the extended header in host byte order */ inline uint16_t get_extended_size(const extended_header &e) { - uint16_converter temp16; - std::copy(e.bytes,e.bytes+2,temp16.c); - return ntohs(temp16.i); + uint16_converter temp16; + std::copy(e.bytes, e.bytes + 2, temp16.c); + return ntohs(temp16.i); } /// Extract the jumbo size field from an extended header @@ -553,9 +542,9 @@ inline uint16_t get_extended_size(const extended_header &e) { * @return The size encoded in the extended header in host byte order */ inline uint64_t get_jumbo_size(const extended_header &e) { - uint64_converter temp64; - std::copy(e.bytes,e.bytes+8,temp64.c); - return lib::net::_ntohll(temp64.i); + uint64_converter temp64; + std::copy(e.bytes, e.bytes + 8, temp64.c); + return lib::net::_ntohll(temp64.i); } /// Extract the full payload size field from a WebSocket header @@ -570,18 +559,17 @@ inline uint64_t get_jumbo_size(const extended_header &e) { * * @return The size encoded in the combined header in host byte order. */ -inline uint64_t get_payload_size(const basic_header &h, const - extended_header &e) -{ - uint8_t val = get_basic_size(h); +inline uint64_t get_payload_size(const basic_header &h, + const extended_header &e) { + uint8_t val = get_basic_size(h); - if (val <= limits::payload_size_basic) { - return val; - } else if (val == payload_size_code_16bit) { - return get_extended_size(e); - } else { - return get_jumbo_size(e); - } + if (val <= limits::payload_size_basic) { + return val; + } else if (val == payload_size_code_16bit) { + return get_extended_size(e); + } else { + return get_jumbo_size(e); + } } /// Extract a masking key into a value the size of a machine word. @@ -592,15 +580,15 @@ inline uint64_t get_payload_size(const basic_header &h, const * * @return prepared key as a machine word */ -inline size_t prepare_masking_key(const masking_key_type& key) { - size_t low_bits = static_cast(key.i); +inline size_t prepare_masking_key(const masking_key_type &key) { + size_t low_bits = static_cast(key.i); - if (sizeof(size_t) == 8) { - uint64_t high_bits = static_cast(key.i); - return static_cast((high_bits << 32) | low_bits); - } else { - return low_bits; - } + if (sizeof(size_t) == 8) { + uint64_t high_bits = static_cast(key.i); + return static_cast((high_bits << 32) | low_bits); + } else { + return low_bits; + } } /// circularly shifts the supplied prepared masking key by offset bytes @@ -610,16 +598,16 @@ inline size_t prepare_masking_key(const masking_key_type& key) { * to zero and less than sizeof(size_t). */ inline size_t circshift_prepared_key(size_t prepared_key, size_t offset) { - if (offset == 0) { - return prepared_key; - } - if (lib::net::is_little_endian()) { - size_t temp = prepared_key << (sizeof(size_t)-offset)*8; - return (prepared_key >> offset*8) | temp; - } else { - size_t temp = prepared_key >> (sizeof(size_t)-offset)*8; - return (prepared_key << offset*8) | temp; - } + if (offset == 0) { + return prepared_key; + } + if (lib::net::is_little_endian()) { + size_t temp = prepared_key << (sizeof(size_t) - offset) * 8; + return (prepared_key >> offset * 8) | temp; + } else { + size_t temp = prepared_key >> (sizeof(size_t) - offset) * 8; + return (prepared_key << offset * 8) | temp; + } } /// Byte by byte mask/unmask @@ -643,15 +631,14 @@ inline size_t circshift_prepared_key(size_t prepared_key, size_t offset) { */ template void byte_mask(input_iter first, input_iter last, output_iter result, - masking_key_type const & key, size_t key_offset) -{ - size_t key_index = key_offset%4; - while (first != last) { - *result = *first ^ key.c[key_index++]; - key_index %= 4; - ++result; - ++first; - } + masking_key_type const &key, size_t key_offset) { + size_t key_index = key_offset % 4; + while (first != last) { + *result = *first ^ key.c[key_index++]; + key_index %= 4; + ++result; + ++first; + } } /// Byte by byte mask/unmask (in place) @@ -672,10 +659,9 @@ void byte_mask(input_iter first, input_iter last, output_iter result, * @param key_offset offset value to start masking at. */ template -void byte_mask(iter_type b, iter_type e, masking_key_type const & key, - size_t key_offset) -{ - byte_mask(b,e,b,key,key_offset); +void byte_mask(iter_type b, iter_type e, masking_key_type const &key, + size_t key_offset) { + byte_mask(b, e, b, key, key_offset); } /// Exact word aligned mask/unmask @@ -699,21 +685,20 @@ void byte_mask(iter_type b, iter_type e, masking_key_type const & key, * * @param key Masking key to use */ -inline void word_mask_exact(uint8_t* input, uint8_t* output, size_t length, - const masking_key_type& key) -{ - size_t prepared_key = prepare_masking_key(key); - size_t n = length/sizeof(size_t); - size_t* input_word = reinterpret_cast(input); - size_t* output_word = reinterpret_cast(output); +inline void word_mask_exact(uint8_t *input, uint8_t *output, size_t length, + const masking_key_type &key) { + size_t prepared_key = prepare_masking_key(key); + size_t n = length / sizeof(size_t); + size_t *input_word = reinterpret_cast(input); + size_t *output_word = reinterpret_cast(output); - for (size_t i = 0; i < n; i++) { - output_word[i] = input_word[i] ^ prepared_key; - } + for (size_t i = 0; i < n; i++) { + output_word[i] = input_word[i] ^ prepared_key; + } - for (size_t i = n*sizeof(size_t); i < length; i++) { - output[i] = input[i] ^ key.c[i%4]; - } + for (size_t i = n * sizeof(size_t); i < length; i++) { + output[i] = input[i] ^ key.c[i % 4]; + } } /// Exact word aligned mask/unmask (in place) @@ -728,10 +713,9 @@ inline void word_mask_exact(uint8_t* input, uint8_t* output, size_t length, * * @param key Masking key to use */ -inline void word_mask_exact(uint8_t* data, size_t length, const - masking_key_type& key) -{ - word_mask_exact(data,data,length,key); +inline void word_mask_exact(uint8_t *data, size_t length, + const masking_key_type &key) { + word_mask_exact(data, data, length, key); } /// Circular word aligned mask/unmask @@ -765,27 +749,26 @@ inline void word_mask_exact(uint8_t* data, size_t length, const * * @return the prepared_key shifted to account for the input length */ -inline size_t word_mask_circ(uint8_t * input, uint8_t * output, size_t length, - size_t prepared_key) -{ - size_t n = length / sizeof(size_t); // whole words - size_t l = length - (n * sizeof(size_t)); // remaining bytes - size_t * input_word = reinterpret_cast(input); - size_t * output_word = reinterpret_cast(output); +inline size_t word_mask_circ(uint8_t *input, uint8_t *output, size_t length, + size_t prepared_key) { + size_t n = length / sizeof(size_t); // whole words + size_t l = length - (n * sizeof(size_t)); // remaining bytes + size_t *input_word = reinterpret_cast(input); + size_t *output_word = reinterpret_cast(output); - // mask word by word - for (size_t i = 0; i < n; i++) { - output_word[i] = input_word[i] ^ prepared_key; - } + // mask word by word + for (size_t i = 0; i < n; i++) { + output_word[i] = input_word[i] ^ prepared_key; + } - // mask partial word at the end - size_t start = length - l; - uint8_t * byte_key = reinterpret_cast(&prepared_key); - for (size_t i = 0; i < l; ++i) { - output[start+i] = input[start+i] ^ byte_key[i]; - } + // mask partial word at the end + size_t start = length - l; + uint8_t *byte_key = reinterpret_cast(&prepared_key); + for (size_t i = 0; i < l; ++i) { + output[start + i] = input[start + i] ^ byte_key[i]; + } - return circshift_prepared_key(prepared_key,l); + return circshift_prepared_key(prepared_key, l); } /// Circular word aligned mask/unmask (in place) @@ -802,8 +785,9 @@ inline size_t word_mask_circ(uint8_t * input, uint8_t * output, size_t length, * * @return the prepared_key shifted to account for the input length */ -inline size_t word_mask_circ(uint8_t* data, size_t length, size_t prepared_key){ - return word_mask_circ(data,data,length,prepared_key); +inline size_t word_mask_circ(uint8_t *data, size_t length, + size_t prepared_key) { + return word_mask_circ(data, data, length, prepared_key); } /// Circular byte aligned mask/unmask @@ -827,17 +811,16 @@ inline size_t word_mask_circ(uint8_t* data, size_t length, size_t prepared_key){ * * @return the prepared_key shifted to account for the input length */ -inline size_t byte_mask_circ(uint8_t * input, uint8_t * output, size_t length, - size_t prepared_key) -{ - uint32_converter key; - key.i = prepared_key; +inline size_t byte_mask_circ(uint8_t *input, uint8_t *output, size_t length, + size_t prepared_key) { + uint32_converter key; + key.i = prepared_key; - for (size_t i = 0; i < length; ++i) { - output[i] = input[i] ^ key.c[i % 4]; - } + for (size_t i = 0; i < length; ++i) { + output[i] = input[i] ^ key.c[i % 4]; + } - return circshift_prepared_key(prepared_key,length % 4); + return circshift_prepared_key(prepared_key, length % 4); } /// Circular byte aligned mask/unmask (in place) @@ -854,11 +837,13 @@ inline size_t byte_mask_circ(uint8_t * input, uint8_t * output, size_t length, * * @return the prepared_key shifted to account for the input length */ -inline size_t byte_mask_circ(uint8_t* data, size_t length, size_t prepared_key){ - return byte_mask_circ(data,data,length,prepared_key); +inline size_t byte_mask_circ(uint8_t *data, size_t length, + size_t prepared_key) { + return byte_mask_circ(data, data, length, prepared_key); } -} // namespace frame -} // namespace websocketpp +} // namespace frame +} // namespace websocketpp -#endif //WEBSOCKETPP_FRAME_HPP +#endif // WEBSOCKETPP_FRAME_HPP +#pragma warning(pop) \ No newline at end of file diff --git a/thirdparty/websocketpp/include/websocketpp/processors/hybi13.hpp b/thirdparty/websocketpp/include/websocketpp/processors/hybi13.hpp index ca12439..3c8d87b 100644 --- a/thirdparty/websocketpp/include/websocketpp/processors/hybi13.hpp +++ b/thirdparty/websocketpp/include/websocketpp/processors/hybi13.hpp @@ -24,27 +24,25 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ +#pragma warning(push) +#pragma warning(disable : 4127) #ifndef WEBSOCKETPP_PROCESSOR_HYBI13_HPP #define WEBSOCKETPP_PROCESSOR_HYBI13_HPP -#include - -#include -#include - -#include -#include -#include - -#include -#include - #include #include #include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace websocketpp { namespace processor { @@ -52,1027 +50,1002 @@ namespace processor { /// Processor for Hybi version 13 (RFC6455) template class hybi13 : public processor { -public: - typedef processor base; + public: + typedef processor base; - typedef typename config::request_type request_type; - typedef typename config::response_type response_type; + typedef typename config::request_type request_type; + typedef typename config::response_type response_type; - typedef typename config::message_type message_type; - typedef typename message_type::ptr message_ptr; + typedef typename config::message_type message_type; + typedef typename message_type::ptr message_ptr; - typedef typename config::con_msg_manager_type msg_manager_type; - typedef typename msg_manager_type::ptr msg_manager_ptr; - typedef typename config::rng_type rng_type; + typedef typename config::con_msg_manager_type msg_manager_type; + typedef typename msg_manager_type::ptr msg_manager_ptr; + typedef typename config::rng_type rng_type; - typedef typename config::permessage_deflate_type permessage_deflate_type; + typedef typename config::permessage_deflate_type permessage_deflate_type; - typedef std::pair err_str_pair; + typedef std::pair err_str_pair; - explicit hybi13(bool secure, bool p_is_server, msg_manager_ptr manager, rng_type& rng) - : processor(secure, p_is_server) - , m_msg_manager(manager) - , m_rng(rng) - { - reset_headers(); + explicit hybi13(bool secure, bool p_is_server, msg_manager_ptr manager, + rng_type& rng) + : processor(secure, p_is_server), + m_msg_manager(manager), + m_rng(rng) { + reset_headers(); + } + + int get_version() const { return 13; } + + bool has_permessage_deflate() const { + return m_permessage_deflate.is_implemented(); + } + + err_str_pair negotiate_extensions(request_type const& request) { + return negotiate_extensions_helper(request); + } + + err_str_pair negotiate_extensions(response_type const& response) { + return negotiate_extensions_helper(response); + } + + /// Extension negotiation helper function + /** + * This exists mostly because the code for requests and responses is + * identical and I can't have virtual template methods. + */ + template + err_str_pair negotiate_extensions_helper(header_type const& header) { + err_str_pair ret; + + // Respect blanket disabling of all extensions and don't even parse + // the extension header + if (!config::enable_extensions) { + ret.first = make_error_code(error::extensions_disabled); + return ret; } - int get_version() const { - return 13; + http::parameter_list p; + + bool error = header.get_header_as_plist("Sec-WebSocket-Extensions", p); + + if (error) { + ret.first = make_error_code(error::extension_parse_error); + return ret; } - bool has_permessage_deflate() const { - return m_permessage_deflate.is_implemented(); + // If there are no extensions parsed then we are done! + if (p.size() == 0) { + return ret; } - err_str_pair negotiate_extensions(request_type const & request) { - return negotiate_extensions_helper(request); - } - - err_str_pair negotiate_extensions(response_type const & response) { - return negotiate_extensions_helper(response); - } - - /// Extension negotiation helper function - /** - * This exists mostly because the code for requests and responses is - * identical and I can't have virtual template methods. - */ - template - err_str_pair negotiate_extensions_helper(header_type const & header) { - err_str_pair ret; + http::parameter_list::const_iterator it; - // Respect blanket disabling of all extensions and don't even parse - // the extension header - if (!config::enable_extensions) { - ret.first = make_error_code(error::extensions_disabled); - return ret; + // look through the list of extension requests to find the first + // one that we can accept. + if (m_permessage_deflate.is_implemented()) { + err_str_pair neg_ret; + for (it = p.begin(); it != p.end(); ++it) { + // not a permessage-deflate extension request, ignore + if (it->first != "permessage-deflate") { + continue; } - http::parameter_list p; - - bool error = header.get_header_as_plist("Sec-WebSocket-Extensions",p); - - if (error) { - ret.first = make_error_code(error::extension_parse_error); - return ret; + // if we have already successfully negotiated this extension + // then skip any other requests to negotiate the same one + // with different parameters + if (m_permessage_deflate.is_enabled()) { + continue; } - // If there are no extensions parsed then we are done! - if (p.size() == 0) { - return ret; + // attempt to negotiate this offer + neg_ret = m_permessage_deflate.negotiate(it->second); + + if (neg_ret.first) { + // negotiation offer failed. Do nothing. We will continue + // searching for a permessage-deflate config that succeeds + continue; } - http::parameter_list::const_iterator it; + // Negotiation tentatively succeeded - // look through the list of extension requests to find the first - // one that we can accept. - if (m_permessage_deflate.is_implemented()) { - err_str_pair neg_ret; - for (it = p.begin(); it != p.end(); ++it) { - // not a permessage-deflate extension request, ignore - if (it->first != "permessage-deflate") { - continue; - } - - // if we have already successfully negotiated this extension - // then skip any other requests to negotiate the same one - // with different parameters - if (m_permessage_deflate.is_enabled()) { - continue; - } - - // attempt to negotiate this offer - neg_ret = m_permessage_deflate.negotiate(it->second); - - if (neg_ret.first) { - // negotiation offer failed. Do nothing. We will continue - // searching for a permessage-deflate config that succeeds - continue; - } - - // Negotiation tentatively succeeded - - // Actually try to initialize the extension before we - // deem negotiation complete - lib::error_code ec = m_permessage_deflate.init(base::m_server); - - if (ec) { - // Negotiation succeeded but initialization failed this is - // an error that should stop negotiation of permessage - // deflate. Return the reason for the init failure - - ret.first = ec; - break; - } else { - // Successfully initialized, push the negotiated response into - // the reply and stop looking for additional permessage-deflate - // extensions - ret.second += neg_ret.second; - break; - } - } - } - - // support for future extensions would go here. Should check the value of - // ret.first before continuing. Might need to consider whether failure of - // negotiation of an earlier extension should stop negotiation of subsequent - // ones - - return ret; - } - - lib::error_code validate_handshake(request_type const & r) const { - if (r.get_method() != "GET") { - return make_error_code(error::invalid_http_method); - } - - if (r.get_version() != "HTTP/1.1") { - return make_error_code(error::invalid_http_version); - } - - // required headers - // Host is required by HTTP/1.1 - // Connection is required by is_websocket_handshake - // Upgrade is required by is_websocket_handshake - if (r.get_header("Sec-WebSocket-Key").empty()) { - return make_error_code(error::missing_required_header); - } - - return lib::error_code(); - } - - /* TODO: the 'subprotocol' parameter may need to be expanded into a more - * generic struct if other user input parameters to the processed handshake - * are found. - */ - lib::error_code process_handshake(request_type const & request, - std::string const & subprotocol, response_type & response) const - { - std::string server_key = request.get_header("Sec-WebSocket-Key"); - - lib::error_code ec = process_handshake_key(server_key); + // Actually try to initialize the extension before we + // deem negotiation complete + lib::error_code ec = m_permessage_deflate.init(base::m_server); if (ec) { - return ec; - } + // Negotiation succeeded but initialization failed this is + // an error that should stop negotiation of permessage + // deflate. Return the reason for the init failure - response.replace_header("Sec-WebSocket-Accept",server_key); - response.append_header("Upgrade",constants::upgrade_token); - response.append_header("Connection",constants::connection_token); - - if (!subprotocol.empty()) { - response.replace_header("Sec-WebSocket-Protocol",subprotocol); - } - - return lib::error_code(); - } - - /// Fill in a set of request headers for a client connection request - /** - * @param [out] req Set of headers to fill in - * @param [in] uri The uri being connected to - * @param [in] subprotocols The list of subprotocols to request - */ - lib::error_code client_handshake_request(request_type & req, uri_ptr - uri, std::vector const & subprotocols) const - { - req.set_method("GET"); - req.set_uri(uri->get_resource()); - req.set_version("HTTP/1.1"); - - req.append_header("Upgrade","websocket"); - req.append_header("Connection","Upgrade"); - req.replace_header("Sec-WebSocket-Version","13"); - req.replace_header("Host",uri->get_host_port()); - - if (!subprotocols.empty()) { - std::ostringstream result; - std::vector::const_iterator it = subprotocols.begin(); - result << *it++; - while (it != subprotocols.end()) { - result << ", " << *it++; - } - - req.replace_header("Sec-WebSocket-Protocol",result.str()); - } - - // Generate handshake key - frame::uint32_converter conv; - unsigned char raw_key[16]; - - for (int i = 0; i < 4; i++) { - conv.i = m_rng(); - std::copy(conv.c,conv.c+4,&raw_key[i*4]); - } - - req.replace_header("Sec-WebSocket-Key",base64_encode(raw_key, 16)); - - if (m_permessage_deflate.is_implemented()) { - std::string offer = m_permessage_deflate.generate_offer(); - if (!offer.empty()) { - req.replace_header("Sec-WebSocket-Extensions",offer); - } - } - - return lib::error_code(); - } - - /// Validate the server's response to an outgoing handshake request - /** - * @param req The original request sent - * @param res The reponse to generate - * @return An error code, 0 on success, non-zero for other errors - */ - lib::error_code validate_server_handshake_response(request_type const & req, - response_type& res) const - { - // A valid response has an HTTP 101 switching protocols code - if (res.get_status_code() != http::status_code::switching_protocols) { - return error::make_error_code(error::invalid_http_status); - } - - // And the upgrade token in an upgrade header - std::string const & upgrade_header = res.get_header("Upgrade"); - if (utility::ci_find_substr(upgrade_header, constants::upgrade_token, - sizeof(constants::upgrade_token)-1) == upgrade_header.end()) - { - return error::make_error_code(error::missing_required_header); - } - - // And the websocket token in the connection header - std::string const & con_header = res.get_header("Connection"); - if (utility::ci_find_substr(con_header, constants::connection_token, - sizeof(constants::connection_token)-1) == con_header.end()) - { - return error::make_error_code(error::missing_required_header); - } - - // And has a valid Sec-WebSocket-Accept value - std::string key = req.get_header("Sec-WebSocket-Key"); - lib::error_code ec = process_handshake_key(key); - - if (ec || key != res.get_header("Sec-WebSocket-Accept")) { - return error::make_error_code(error::missing_required_header); - } - - // check extensions - - return lib::error_code(); - } - - std::string get_raw(response_type const & res) const { - return res.raw(); - } - - std::string const & get_origin(request_type const & r) const { - return r.get_header("Origin"); - } - - lib::error_code extract_subprotocols(request_type const & req, - std::vector & subprotocol_list) - { - if (!req.get_header("Sec-WebSocket-Protocol").empty()) { - http::parameter_list p; - - if (!req.get_header_as_plist("Sec-WebSocket-Protocol",p)) { - http::parameter_list::const_iterator it; - - for (it = p.begin(); it != p.end(); ++it) { - subprotocol_list.push_back(it->first); - } - } else { - return error::make_error_code(error::subprotocol_parse_error); - } - } - return lib::error_code(); - } - - uri_ptr get_uri(request_type const & request) const { - return get_uri_from_host(request,(base::m_secure ? "wss" : "ws")); - } - - /// Process new websocket connection bytes - /** - * - * Hybi 13 data streams represent a series of variable length frames. Each - * frame is made up of a series of fixed length fields. The lengths of later - * fields are contained in earlier fields. The first field length is fixed - * by the spec. - * - * This processor represents a state machine that keeps track of what field - * is presently being read and how many more bytes are needed to complete it - * - * - * - * - * Read two header bytes - * Extract full frame length. - * Read extra header bytes - * Validate frame header (including extension validate) - * Read extension data into extension message state object - * Read payload data into payload - * - * @param buf Input buffer - * - * @param len Length of input buffer - * - * @return Number of bytes processed or zero on error - */ - size_t consume(uint8_t * buf, size_t len, lib::error_code & ec) { - size_t p = 0; - - ec = lib::error_code(); - - //std::cout << "consume: " << utility::to_hex(buf,len) << std::endl; - - // Loop while we don't have a message ready and we still have bytes - // left to process. - while (m_state != READY && m_state != FATAL_ERROR && - (p < len || m_bytes_needed == 0)) - { - if (m_state == HEADER_BASIC) { - p += this->copy_basic_header_bytes(buf+p,len-p); - - if (m_bytes_needed > 0) { - continue; - } - - ec = this->validate_incoming_basic_header( - m_basic_header, base::m_server, !m_data_msg.msg_ptr - ); - if (ec) {break;} - - // extract full header size and adjust consume state accordingly - m_state = HEADER_EXTENDED; - m_cursor = 0; - m_bytes_needed = frame::get_header_len(m_basic_header) - - frame::BASIC_HEADER_LENGTH; - } else if (m_state == HEADER_EXTENDED) { - p += this->copy_extended_header_bytes(buf+p,len-p); - - if (m_bytes_needed > 0) { - continue; - } - - ec = validate_incoming_extended_header(m_basic_header,m_extended_header); - if (ec){break;} - - m_state = APPLICATION; - m_bytes_needed = static_cast(get_payload_size(m_basic_header,m_extended_header)); - - // check if this frame is the start of a new message and set up - // the appropriate message metadata. - frame::opcode::value op = frame::get_opcode(m_basic_header); - - // TODO: get_message failure conditions - - if (frame::opcode::is_control(op)) { - m_control_msg = msg_metadata( - m_msg_manager->get_message(op,m_bytes_needed), - frame::get_masking_key(m_basic_header,m_extended_header) - ); - - m_current_msg = &m_control_msg; - } else { - if (!m_data_msg.msg_ptr) { - if (m_bytes_needed > base::m_max_message_size) { - ec = make_error_code(error::message_too_big); - break; - } - - m_data_msg = msg_metadata( - m_msg_manager->get_message(op,m_bytes_needed), - frame::get_masking_key(m_basic_header,m_extended_header) - ); - - if (m_permessage_deflate.is_enabled()) { - m_data_msg.msg_ptr->set_compressed(frame::get_rsv1(m_basic_header)); - } - } else { - // Fetch the underlying payload buffer from the data message we - // are writing into. - std::string & out = m_data_msg.msg_ptr->get_raw_payload(); - - if (out.size() + m_bytes_needed > base::m_max_message_size) { - ec = make_error_code(error::message_too_big); - break; - } - - // Each frame starts a new masking key. All other state - // remains between frames. - m_data_msg.prepared_key = prepare_masking_key( - frame::get_masking_key( - m_basic_header, - m_extended_header - ) - ); - - out.reserve(out.size() + m_bytes_needed); - } - m_current_msg = &m_data_msg; - } - } else if (m_state == EXTENSION) { - m_state = APPLICATION; - } else if (m_state == APPLICATION) { - size_t bytes_to_process = (std::min)(m_bytes_needed,len-p); - - if (bytes_to_process > 0) { - p += this->process_payload_bytes(buf+p,bytes_to_process,ec); - - if (ec) {break;} - } - - if (m_bytes_needed > 0) { - continue; - } - - // If this was the last frame in the message set the ready flag. - // Otherwise, reset processor state to read additional frames. - if (frame::get_fin(m_basic_header)) { - ec = finalize_message(); - if (ec) { - break; - } - } else { - this->reset_headers(); - } - } else { - // shouldn't be here - ec = make_error_code(error::general); - return 0; - } - } - - return p; - } - - /// Perform any finalization actions on an incoming message - /** - * Called after the full message is received. Provides the opportunity for - * extensions to complete any data post processing as well as final UTF8 - * validation checks for text messages. - * - * @return A code indicating errors, if any - */ - lib::error_code finalize_message() { - std::string & out = m_current_msg->msg_ptr->get_raw_payload(); - - // if the frame is compressed, append the compression - // trailer and flush the compression buffer. - if (m_permessage_deflate.is_enabled() - && m_current_msg->msg_ptr->get_compressed()) - { - uint8_t trailer[4] = {0x00, 0x00, 0xff, 0xff}; - - // Decompress current buffer into the message buffer - lib::error_code ec; - ec = m_permessage_deflate.decompress(trailer,4,out); - if (ec) { - return ec; - } - } - - // ensure that text messages end on a valid UTF8 code point - if (frame::get_opcode(m_basic_header) == frame::opcode::TEXT) { - if (!m_current_msg->validator.complete()) { - return make_error_code(error::invalid_utf8); - } - } - - m_state = READY; - - return lib::error_code(); - } - - void reset_headers() { - m_state = HEADER_BASIC; - m_bytes_needed = frame::BASIC_HEADER_LENGTH; - - m_basic_header.b0 = 0x00; - m_basic_header.b1 = 0x00; - - std::fill_n( - m_extended_header.bytes, - frame::MAX_EXTENDED_HEADER_LENGTH, - 0x00 - ); - } - - /// Test whether or not the processor has a message ready - bool ready() const { - return (m_state == READY); - } - - message_ptr get_message() { - if (!ready()) { - return message_ptr(); - } - message_ptr ret = m_current_msg->msg_ptr; - m_current_msg->msg_ptr.reset(); - - if (frame::opcode::is_control(ret->get_opcode())) { - m_control_msg.msg_ptr.reset(); + ret.first = ec; + break; } else { - m_data_msg.msg_ptr.reset(); + // Successfully initialized, push the negotiated response into + // the reply and stop looking for additional permessage-deflate + // extensions + ret.second += neg_ret.second; + break; + } + } + } + + // support for future extensions would go here. Should check the value of + // ret.first before continuing. Might need to consider whether failure of + // negotiation of an earlier extension should stop negotiation of subsequent + // ones + + return ret; + } + + lib::error_code validate_handshake(request_type const& r) const { + if (r.get_method() != "GET") { + return make_error_code(error::invalid_http_method); + } + + if (r.get_version() != "HTTP/1.1") { + return make_error_code(error::invalid_http_version); + } + + // required headers + // Host is required by HTTP/1.1 + // Connection is required by is_websocket_handshake + // Upgrade is required by is_websocket_handshake + if (r.get_header("Sec-WebSocket-Key").empty()) { + return make_error_code(error::missing_required_header); + } + + return lib::error_code(); + } + + /* TODO: the 'subprotocol' parameter may need to be expanded into a more + * generic struct if other user input parameters to the processed handshake + * are found. + */ + lib::error_code process_handshake(request_type const& request, + std::string const& subprotocol, + response_type& response) const { + std::string server_key = request.get_header("Sec-WebSocket-Key"); + + lib::error_code ec = process_handshake_key(server_key); + + if (ec) { + return ec; + } + + response.replace_header("Sec-WebSocket-Accept", server_key); + response.append_header("Upgrade", constants::upgrade_token); + response.append_header("Connection", constants::connection_token); + + if (!subprotocol.empty()) { + response.replace_header("Sec-WebSocket-Protocol", subprotocol); + } + + return lib::error_code(); + } + + /// Fill in a set of request headers for a client connection request + /** + * @param [out] req Set of headers to fill in + * @param [in] uri The uri being connected to + * @param [in] subprotocols The list of subprotocols to request + */ + lib::error_code client_handshake_request( + request_type& req, uri_ptr uri, + std::vector const& subprotocols) const { + req.set_method("GET"); + req.set_uri(uri->get_resource()); + req.set_version("HTTP/1.1"); + + req.append_header("Upgrade", "websocket"); + req.append_header("Connection", "Upgrade"); + req.replace_header("Sec-WebSocket-Version", "13"); + req.replace_header("Host", uri->get_host_port()); + + if (!subprotocols.empty()) { + std::ostringstream result; + std::vector::const_iterator it = subprotocols.begin(); + result << *it++; + while (it != subprotocols.end()) { + result << ", " << *it++; + } + + req.replace_header("Sec-WebSocket-Protocol", result.str()); + } + + // Generate handshake key + frame::uint32_converter conv; + unsigned char raw_key[16]; + + for (int i = 0; i < 4; i++) { + conv.i = m_rng(); + std::copy(conv.c, conv.c + 4, &raw_key[i * 4]); + } + + req.replace_header("Sec-WebSocket-Key", base64_encode(raw_key, 16)); + + if (m_permessage_deflate.is_implemented()) { + std::string offer = m_permessage_deflate.generate_offer(); + if (!offer.empty()) { + req.replace_header("Sec-WebSocket-Extensions", offer); + } + } + + return lib::error_code(); + } + + /// Validate the server's response to an outgoing handshake request + /** + * @param req The original request sent + * @param res The reponse to generate + * @return An error code, 0 on success, non-zero for other errors + */ + lib::error_code validate_server_handshake_response(request_type const& req, + response_type& res) const { + // A valid response has an HTTP 101 switching protocols code + if (res.get_status_code() != http::status_code::switching_protocols) { + return error::make_error_code(error::invalid_http_status); + } + + // And the upgrade token in an upgrade header + std::string const& upgrade_header = res.get_header("Upgrade"); + if (utility::ci_find_substr(upgrade_header, constants::upgrade_token, + sizeof(constants::upgrade_token) - 1) == + upgrade_header.end()) { + return error::make_error_code(error::missing_required_header); + } + + // And the websocket token in the connection header + std::string const& con_header = res.get_header("Connection"); + if (utility::ci_find_substr(con_header, constants::connection_token, + sizeof(constants::connection_token) - 1) == + con_header.end()) { + return error::make_error_code(error::missing_required_header); + } + + // And has a valid Sec-WebSocket-Accept value + std::string key = req.get_header("Sec-WebSocket-Key"); + lib::error_code ec = process_handshake_key(key); + + if (ec || key != res.get_header("Sec-WebSocket-Accept")) { + return error::make_error_code(error::missing_required_header); + } + + // check extensions + + return lib::error_code(); + } + + std::string get_raw(response_type const& res) const { return res.raw(); } + + std::string const& get_origin(request_type const& r) const { + return r.get_header("Origin"); + } + + lib::error_code extract_subprotocols( + request_type const& req, std::vector& subprotocol_list) { + if (!req.get_header("Sec-WebSocket-Protocol").empty()) { + http::parameter_list p; + + if (!req.get_header_as_plist("Sec-WebSocket-Protocol", p)) { + http::parameter_list::const_iterator it; + + for (it = p.begin(); it != p.end(); ++it) { + subprotocol_list.push_back(it->first); + } + } else { + return error::make_error_code(error::subprotocol_parse_error); + } + } + return lib::error_code(); + } + + uri_ptr get_uri(request_type const& request) const { + return get_uri_from_host(request, (base::m_secure ? "wss" : "ws")); + } + + /// Process new websocket connection bytes + /** + * + * Hybi 13 data streams represent a series of variable length frames. Each + * frame is made up of a series of fixed length fields. The lengths of later + * fields are contained in earlier fields. The first field length is fixed + * by the spec. + * + * This processor represents a state machine that keeps track of what field + * is presently being read and how many more bytes are needed to complete it + * + * + * + * + * Read two header bytes + * Extract full frame length. + * Read extra header bytes + * Validate frame header (including extension validate) + * Read extension data into extension message state object + * Read payload data into payload + * + * @param buf Input buffer + * + * @param len Length of input buffer + * + * @return Number of bytes processed or zero on error + */ + size_t consume(uint8_t* buf, size_t len, lib::error_code& ec) { + size_t p = 0; + + ec = lib::error_code(); + + // std::cout << "consume: " << utility::to_hex(buf,len) << std::endl; + + // Loop while we don't have a message ready and we still have bytes + // left to process. + while (m_state != READY && m_state != FATAL_ERROR && + (p < len || m_bytes_needed == 0)) { + if (m_state == HEADER_BASIC) { + p += this->copy_basic_header_bytes(buf + p, len - p); + + if (m_bytes_needed > 0) { + continue; } - this->reset_headers(); - - return ret; - } - - /// Test whether or not the processor is in a fatal error state. - bool get_error() const { - return m_state == FATAL_ERROR; - } - - size_t get_bytes_needed() const { - return m_bytes_needed; - } - - /// Prepare a user data message for writing - /** - * Performs validation, masking, compression, etc. will return an error if - * there was an error, otherwise msg will be ready to be written - * - * TODO: tests - * - * @param in An unprepared message to prepare - * @param out A message to be overwritten with the prepared message - * @return error code - */ - virtual lib::error_code prepare_data_frame(message_ptr in, message_ptr out) - { - if (!in || !out) { - return make_error_code(error::invalid_arguments); + ec = this->validate_incoming_basic_header( + m_basic_header, base::m_server, !m_data_msg.msg_ptr); + if (ec) { + break; } - frame::opcode::value op = in->get_opcode(); + // extract full header size and adjust consume state accordingly + m_state = HEADER_EXTENDED; + m_cursor = 0; + m_bytes_needed = + frame::get_header_len(m_basic_header) - frame::BASIC_HEADER_LENGTH; + } else if (m_state == HEADER_EXTENDED) { + p += this->copy_extended_header_bytes(buf + p, len - p); + + if (m_bytes_needed > 0) { + continue; + } + + ec = validate_incoming_extended_header(m_basic_header, + m_extended_header); + if (ec) { + break; + } + + m_state = APPLICATION; + m_bytes_needed = static_cast( + get_payload_size(m_basic_header, m_extended_header)); + + // check if this frame is the start of a new message and set up + // the appropriate message metadata. + frame::opcode::value op = frame::get_opcode(m_basic_header); + + // TODO: get_message failure conditions - // validate opcode: only regular data frames if (frame::opcode::is_control(op)) { - return make_error_code(error::invalid_opcode); - } + m_control_msg = msg_metadata( + m_msg_manager->get_message(op, m_bytes_needed), + frame::get_masking_key(m_basic_header, m_extended_header)); - std::string& i = in->get_raw_payload(); - std::string& o = out->get_raw_payload(); - - // validate payload utf8 - if (op == frame::opcode::TEXT && !utf8_validator::validate(i)) { - return make_error_code(error::invalid_payload); - } - - frame::masking_key_type key; - bool masked = !base::m_server; - bool compressed = m_permessage_deflate.is_enabled() - && in->get_compressed(); - bool fin = in->get_fin(); - - if (masked) { - // Generate masking key. - key.i = m_rng(); + m_current_msg = &m_control_msg; } else { - key.i = 0; + if (!m_data_msg.msg_ptr) { + if (m_bytes_needed > base::m_max_message_size) { + ec = make_error_code(error::message_too_big); + break; + } + + m_data_msg = msg_metadata( + m_msg_manager->get_message(op, m_bytes_needed), + frame::get_masking_key(m_basic_header, m_extended_header)); + + if (m_permessage_deflate.is_enabled()) { + m_data_msg.msg_ptr->set_compressed( + frame::get_rsv1(m_basic_header)); + } + } else { + // Fetch the underlying payload buffer from the data message we + // are writing into. + std::string& out = m_data_msg.msg_ptr->get_raw_payload(); + + if (out.size() + m_bytes_needed > base::m_max_message_size) { + ec = make_error_code(error::message_too_big); + break; + } + + // Each frame starts a new masking key. All other state + // remains between frames. + m_data_msg.prepared_key = prepare_masking_key( + frame::get_masking_key(m_basic_header, m_extended_header)); + + out.reserve(out.size() + m_bytes_needed); + } + m_current_msg = &m_data_msg; + } + } else if (m_state == EXTENSION) { + m_state = APPLICATION; + } else if (m_state == APPLICATION) { + size_t bytes_to_process = (std::min)(m_bytes_needed, len - p); + + if (bytes_to_process > 0) { + p += this->process_payload_bytes(buf + p, bytes_to_process, ec); + + if (ec) { + break; + } } - // prepare payload - if (compressed) { - // compress and store in o after header. - m_permessage_deflate.compress(i,o); + if (m_bytes_needed > 0) { + continue; + } - if (o.size() < 4) { - return make_error_code(error::general); - } - - // Strip trailing 4 0x00 0x00 0xff 0xff bytes before writing to the - // wire - o.resize(o.size()-4); - - // mask in place if necessary - if (masked) { - this->masked_copy(o,o,key); - } + // If this was the last frame in the message set the ready flag. + // Otherwise, reset processor state to read additional frames. + if (frame::get_fin(m_basic_header)) { + ec = finalize_message(); + if (ec) { + break; + } } else { - // no compression, just copy data into the output buffer - o.resize(i.size()); - - // if we are masked, have the masking function write to the output - // buffer directly to avoid another copy. If not masked, copy - // directly without masking. - if (masked) { - this->masked_copy(i,o,key); - } else { - std::copy(i.begin(),i.end(),o.begin()); - } + this->reset_headers(); } - - // generate header - frame::basic_header h(op,o.size(),fin,masked,compressed); - - if (masked) { - frame::extended_header e(o.size(),key.i); - out->set_header(frame::prepare_header(h,e)); - } else { - frame::extended_header e(o.size()); - out->set_header(frame::prepare_header(h,e)); - } - - out->set_prepared(true); - out->set_opcode(op); - - return lib::error_code(); + } else { + // shouldn't be here + ec = make_error_code(error::general); + return 0; + } } - /// Get URI - lib::error_code prepare_ping(std::string const & in, message_ptr out) const { - return this->prepare_control(frame::opcode::PING,in,out); + return p; + } + + /// Perform any finalization actions on an incoming message + /** + * Called after the full message is received. Provides the opportunity for + * extensions to complete any data post processing as well as final UTF8 + * validation checks for text messages. + * + * @return A code indicating errors, if any + */ + lib::error_code finalize_message() { + std::string& out = m_current_msg->msg_ptr->get_raw_payload(); + + // if the frame is compressed, append the compression + // trailer and flush the compression buffer. + if (m_permessage_deflate.is_enabled() && + m_current_msg->msg_ptr->get_compressed()) { + uint8_t trailer[4] = {0x00, 0x00, 0xff, 0xff}; + + // Decompress current buffer into the message buffer + lib::error_code ec; + ec = m_permessage_deflate.decompress(trailer, 4, out); + if (ec) { + return ec; + } } - lib::error_code prepare_pong(std::string const & in, message_ptr out) const { - return this->prepare_control(frame::opcode::PONG,in,out); + // ensure that text messages end on a valid UTF8 code point + if (frame::get_opcode(m_basic_header) == frame::opcode::TEXT) { + if (!m_current_msg->validator.complete()) { + return make_error_code(error::invalid_utf8); + } } - virtual lib::error_code prepare_close(close::status::value code, - std::string const & reason, message_ptr out) const - { - if (close::status::reserved(code)) { - return make_error_code(error::reserved_close_code); - } + m_state = READY; - if (close::status::invalid(code) && code != close::status::no_status) { - return make_error_code(error::invalid_close_code); - } + return lib::error_code(); + } - if (code == close::status::no_status && reason.size() > 0) { - return make_error_code(error::reason_requires_code); - } + void reset_headers() { + m_state = HEADER_BASIC; + m_bytes_needed = frame::BASIC_HEADER_LENGTH; - if (reason.size() > frame:: limits::payload_size_basic-2) { - return make_error_code(error::control_too_big); - } + m_basic_header.b0 = 0x00; + m_basic_header.b1 = 0x00; - std::string payload; + std::fill_n(m_extended_header.bytes, frame::MAX_EXTENDED_HEADER_LENGTH, + 0x00); + } - if (code != close::status::no_status) { - close::code_converter val; - val.i = htons(code); + /// Test whether or not the processor has a message ready + bool ready() const { return (m_state == READY); } - payload.resize(reason.size()+2); - - payload[0] = val.c[0]; - payload[1] = val.c[1]; - - std::copy(reason.begin(),reason.end(),payload.begin()+2); - } - - return this->prepare_control(frame::opcode::CLOSE,payload,out); + message_ptr get_message() { + if (!ready()) { + return message_ptr(); } -protected: - /// Convert a client handshake key into a server response key in place - lib::error_code process_handshake_key(std::string & key) const { - key.append(constants::handshake_guid); + message_ptr ret = m_current_msg->msg_ptr; + m_current_msg->msg_ptr.reset(); - unsigned char message_digest[20]; - sha1::calc(key.c_str(),key.length(),message_digest); - key = base64_encode(message_digest,20); - - return lib::error_code(); + if (frame::opcode::is_control(ret->get_opcode())) { + m_control_msg.msg_ptr.reset(); + } else { + m_data_msg.msg_ptr.reset(); } - /// Reads bytes from buf into m_basic_header - size_t copy_basic_header_bytes(uint8_t const * buf, size_t len) { - if (len == 0 || m_bytes_needed == 0) { - return 0; - } + this->reset_headers(); - if (len > 1) { - // have at least two bytes - if (m_bytes_needed == 2) { - m_basic_header.b0 = buf[0]; - m_basic_header.b1 = buf[1]; - m_bytes_needed -= 2; - return 2; - } else { - m_basic_header.b1 = buf[0]; - m_bytes_needed--; - return 1; - } - } else { - // have exactly one byte - if (m_bytes_needed == 2) { - m_basic_header.b0 = buf[0]; - m_bytes_needed--; - return 1; - } else { - m_basic_header.b1 = buf[0]; - m_bytes_needed--; - return 1; - } - } + return ret; + } + + /// Test whether or not the processor is in a fatal error state. + bool get_error() const { return m_state == FATAL_ERROR; } + + size_t get_bytes_needed() const { return m_bytes_needed; } + + /// Prepare a user data message for writing + /** + * Performs validation, masking, compression, etc. will return an error if + * there was an error, otherwise msg will be ready to be written + * + * TODO: tests + * + * @param in An unprepared message to prepare + * @param out A message to be overwritten with the prepared message + * @return error code + */ + virtual lib::error_code prepare_data_frame(message_ptr in, message_ptr out) { + if (!in || !out) { + return make_error_code(error::invalid_arguments); } - /// Reads bytes from buf into m_extended_header - size_t copy_extended_header_bytes(uint8_t const * buf, size_t len) { - size_t bytes_to_read = (std::min)(m_bytes_needed,len); + frame::opcode::value op = in->get_opcode(); - std::copy(buf,buf+bytes_to_read,m_extended_header.bytes+m_cursor); - m_cursor += bytes_to_read; - m_bytes_needed -= bytes_to_read; - - return bytes_to_read; + // validate opcode: only regular data frames + if (frame::opcode::is_control(op)) { + return make_error_code(error::invalid_opcode); } - /// Reads bytes from buf into message payload - /** - * This function performs unmasking and uncompression, validates the - * decoded bytes, and writes them to the appropriate message buffer. - * - * This member function will use the input buffer as stratch space for its - * work. The raw input bytes will not be preserved. This applies only to the - * bytes actually needed. At most min(m_bytes_needed,len) will be processed. - * - * @param buf Input/working buffer - * @param len Length of buf - * @return Number of bytes processed or zero in case of an error - */ - size_t process_payload_bytes(uint8_t * buf, size_t len, lib::error_code& ec) - { - // unmask if masked - if (frame::get_masked(m_basic_header)) { - m_current_msg->prepared_key = frame::byte_mask_circ( - buf, len, m_current_msg->prepared_key); - // TODO: SIMD masking - } + std::string& i = in->get_raw_payload(); + std::string& o = out->get_raw_payload(); - std::string & out = m_current_msg->msg_ptr->get_raw_payload(); - size_t offset = out.size(); - - // decompress message if needed. - if (m_permessage_deflate.is_enabled() - && m_current_msg->msg_ptr->get_compressed()) - { - // Decompress current buffer into the message buffer - ec = m_permessage_deflate.decompress(buf,len,out); - if (ec) { - return 0; - } - } else { - // No compression, straight copy - out.append(reinterpret_cast(buf),len); - } - - // validate unmasked, decompressed values - if (m_current_msg->msg_ptr->get_opcode() == frame::opcode::TEXT) { - if (!m_current_msg->validator.decode(out.begin()+offset,out.end())) { - ec = make_error_code(error::invalid_utf8); - return 0; - } - } - - m_bytes_needed -= len; - - return len; + // validate payload utf8 + if (op == frame::opcode::TEXT && !utf8_validator::validate(i)) { + return make_error_code(error::invalid_payload); } - /// Validate an incoming basic header - /** - * Validates an incoming hybi13 basic header. - * - * @param h The basic header to validate - * @param is_server Whether or not the endpoint that received this frame - * is a server. - * @param new_msg Whether or not this is the first frame of the message - * @return 0 on success or a non-zero error code on failure - */ - lib::error_code validate_incoming_basic_header(frame::basic_header const & h, - bool is_server, bool new_msg) const - { - frame::opcode::value op = frame::get_opcode(h); + frame::masking_key_type key; + bool masked = !base::m_server; + bool compressed = m_permessage_deflate.is_enabled() && in->get_compressed(); + bool fin = in->get_fin(); - // Check control frame size limit - if (frame::opcode::is_control(op) && - frame::get_basic_size(h) > frame::limits::payload_size_basic) - { - return make_error_code(error::control_too_big); - } - - // Check that RSV bits are clear - // The only RSV bits allowed are rsv1 if the permessage_compress - // extension is enabled for this connection and the message is not - // a control message. - // - // TODO: unit tests for this - if (frame::get_rsv1(h) && (!m_permessage_deflate.is_enabled() - || frame::opcode::is_control(op))) - { - return make_error_code(error::invalid_rsv_bit); - } - - if (frame::get_rsv2(h) || frame::get_rsv3(h)) { - return make_error_code(error::invalid_rsv_bit); - } - - // Check for reserved opcodes - if (frame::opcode::reserved(op)) { - return make_error_code(error::invalid_opcode); - } - - // Check for invalid opcodes - // TODO: unit tests for this? - if (frame::opcode::invalid(op)) { - return make_error_code(error::invalid_opcode); - } - - // Check for fragmented control message - if (frame::opcode::is_control(op) && !frame::get_fin(h)) { - return make_error_code(error::fragmented_control); - } - - // Check for continuation without an active message - if (new_msg && op == frame::opcode::CONTINUATION) { - return make_error_code(error::invalid_continuation); - } - - // Check for new data frame when expecting continuation - if (!new_msg && !frame::opcode::is_control(op) && - op != frame::opcode::CONTINUATION) - { - return make_error_code(error::invalid_continuation); - } - - // Servers should reject any unmasked frames from clients. - // Clients should reject any masked frames from servers. - if (is_server && !frame::get_masked(h)) { - return make_error_code(error::masking_required); - } else if (!is_server && frame::get_masked(h)) { - return make_error_code(error::masking_forbidden); - } - - return lib::error_code(); + if (masked) { + // Generate masking key. + key.i = m_rng(); + } else { + key.i = 0; } - /// Validate an incoming extended header - /** - * Validates an incoming hybi13 full header. - * - * @todo unit test for the >32 bit frames on 32 bit systems case - * - * @param h The basic header to validate - * @param e The extended header to validate - * @return An error_code, non-zero values indicate why the validation - * failed - */ - lib::error_code validate_incoming_extended_header(frame::basic_header h, - frame::extended_header e) const - { - uint8_t basic_size = frame::get_basic_size(h); - uint64_t payload_size = frame::get_payload_size(h,e); + // prepare payload + if (compressed) { + // compress and store in o after header. + m_permessage_deflate.compress(i, o); - // Check for non-minimally encoded payloads - if (basic_size == frame::payload_size_code_16bit && - payload_size <= frame::limits::payload_size_basic) - { - return make_error_code(error::non_minimal_encoding); - } + if (o.size() < 4) { + return make_error_code(error::general); + } - if (basic_size == frame::payload_size_code_64bit && - payload_size <= frame::limits::payload_size_extended) - { - return make_error_code(error::non_minimal_encoding); - } + // Strip trailing 4 0x00 0x00 0xff 0xff bytes before writing to the + // wire + o.resize(o.size() - 4); - // Check for >32bit frames on 32 bit systems - if (sizeof(size_t) == 4 && (payload_size >> 32)) { - return make_error_code(error::requires_64bit); - } + // mask in place if necessary + if (masked) { + this->masked_copy(o, o, key); + } + } else { + // no compression, just copy data into the output buffer + o.resize(i.size()); - return lib::error_code(); + // if we are masked, have the masking function write to the output + // buffer directly to avoid another copy. If not masked, copy + // directly without masking. + if (masked) { + this->masked_copy(i, o, key); + } else { + std::copy(i.begin(), i.end(), o.begin()); + } } - /// Copy and mask/unmask in one operation - /** - * Reads input from one string and writes unmasked output to another. - * - * @param [in] i The input string. - * @param [out] o The output string. - * @param [in] key The masking key to use for masking/unmasking - */ - void masked_copy (std::string const & i, std::string & o, - frame::masking_key_type key) const - { - frame::byte_mask(i.begin(),i.end(),o.begin(),key); - // TODO: SIMD masking + // generate header + frame::basic_header h(op, o.size(), fin, masked, compressed); + + if (masked) { + frame::extended_header e(o.size(), key.i); + out->set_header(frame::prepare_header(h, e)); + } else { + frame::extended_header e(o.size()); + out->set_header(frame::prepare_header(h, e)); } - /// Generic prepare control frame with opcode and payload. - /** - * Internal control frame building method. Handles validation, masking, etc - * - * @param op The control opcode to use - * @param payload The payload to use - * @param out The message buffer to store the prepared frame in - * @return Status code, zero on success, non-zero on error - */ - lib::error_code prepare_control(frame::opcode::value op, - std::string const & payload, message_ptr out) const - { - if (!out) { - return make_error_code(error::invalid_arguments); - } + out->set_prepared(true); + out->set_opcode(op); - if (!frame::opcode::is_control(op)) { - return make_error_code(error::invalid_opcode); - } + return lib::error_code(); + } - if (payload.size() > frame::limits::payload_size_basic) { - return make_error_code(error::control_too_big); - } + /// Get URI + lib::error_code prepare_ping(std::string const& in, message_ptr out) const { + return this->prepare_control(frame::opcode::PING, in, out); + } - frame::masking_key_type key; - bool masked = !base::m_server; + lib::error_code prepare_pong(std::string const& in, message_ptr out) const { + return this->prepare_control(frame::opcode::PONG, in, out); + } - frame::basic_header h(op,payload.size(),true,masked); - - std::string & o = out->get_raw_payload(); - o.resize(payload.size()); - - if (masked) { - // Generate masking key. - key.i = m_rng(); - - frame::extended_header e(payload.size(),key.i); - out->set_header(frame::prepare_header(h,e)); - this->masked_copy(payload,o,key); - } else { - frame::extended_header e(payload.size()); - out->set_header(frame::prepare_header(h,e)); - std::copy(payload.begin(),payload.end(),o.begin()); - } - - out->set_opcode(op); - out->set_prepared(true); - - return lib::error_code(); + virtual lib::error_code prepare_close(close::status::value code, + std::string const& reason, + message_ptr out) const { + if (close::status::reserved(code)) { + return make_error_code(error::reserved_close_code); } - enum state { - HEADER_BASIC = 0, - HEADER_EXTENDED = 1, - EXTENSION = 2, - APPLICATION = 3, - READY = 4, - FATAL_ERROR = 5 - }; + if (close::status::invalid(code) && code != close::status::no_status) { + return make_error_code(error::invalid_close_code); + } - /// This data structure holds data related to processing a message, such as - /// the buffer it is being written to, its masking key, its UTF8 validation - /// state, and sometimes its compression state. - struct msg_metadata { - msg_metadata() {} - msg_metadata(message_ptr m, size_t p) : msg_ptr(m),prepared_key(p) {} - msg_metadata(message_ptr m, frame::masking_key_type p) - : msg_ptr(m) - , prepared_key(prepare_masking_key(p)) {} + if (code == close::status::no_status && reason.size() > 0) { + return make_error_code(error::reason_requires_code); + } - message_ptr msg_ptr; // pointer to the message data buffer - size_t prepared_key; // prepared masking key - utf8_validator::validator validator; // utf8 validation state - }; + if (reason.size() > frame::limits::payload_size_basic - 2) { + return make_error_code(error::control_too_big); + } - // Basic header of the frame being read - frame::basic_header m_basic_header; + std::string payload; - // Pointer to a manager that can create message buffers for us. - msg_manager_ptr m_msg_manager; + if (code != close::status::no_status) { + close::code_converter val; + val.i = htons(code); - // Number of bytes needed to complete the current operation - size_t m_bytes_needed; + payload.resize(reason.size() + 2); - // Number of extended header bytes read - size_t m_cursor; + payload[0] = val.c[0]; + payload[1] = val.c[1]; - // Metadata for the current data msg - msg_metadata m_data_msg; - // Metadata for the current control msg - msg_metadata m_control_msg; + std::copy(reason.begin(), reason.end(), payload.begin() + 2); + } - // Pointer to the metadata associated with the frame being read - msg_metadata * m_current_msg; + return this->prepare_control(frame::opcode::CLOSE, payload, out); + } - // Extended header of current frame - frame::extended_header m_extended_header; + protected: + /// Convert a client handshake key into a server response key in place + lib::error_code process_handshake_key(std::string& key) const { + key.append(constants::handshake_guid); - rng_type & m_rng; + unsigned char message_digest[20]; + sha1::calc(key.c_str(), key.length(), message_digest); + key = base64_encode(message_digest, 20); - // Overall state of the processor - state m_state; + return lib::error_code(); + } - // Extensions - permessage_deflate_type m_permessage_deflate; + /// Reads bytes from buf into m_basic_header + size_t copy_basic_header_bytes(uint8_t const* buf, size_t len) { + if (len == 0 || m_bytes_needed == 0) { + return 0; + } + + if (len > 1) { + // have at least two bytes + if (m_bytes_needed == 2) { + m_basic_header.b0 = buf[0]; + m_basic_header.b1 = buf[1]; + m_bytes_needed -= 2; + return 2; + } else { + m_basic_header.b1 = buf[0]; + m_bytes_needed--; + return 1; + } + } else { + // have exactly one byte + if (m_bytes_needed == 2) { + m_basic_header.b0 = buf[0]; + m_bytes_needed--; + return 1; + } else { + m_basic_header.b1 = buf[0]; + m_bytes_needed--; + return 1; + } + } + } + + /// Reads bytes from buf into m_extended_header + size_t copy_extended_header_bytes(uint8_t const* buf, size_t len) { + size_t bytes_to_read = (std::min)(m_bytes_needed, len); + + std::copy(buf, buf + bytes_to_read, m_extended_header.bytes + m_cursor); + m_cursor += bytes_to_read; + m_bytes_needed -= bytes_to_read; + + return bytes_to_read; + } + + /// Reads bytes from buf into message payload + /** + * This function performs unmasking and uncompression, validates the + * decoded bytes, and writes them to the appropriate message buffer. + * + * This member function will use the input buffer as stratch space for its + * work. The raw input bytes will not be preserved. This applies only to the + * bytes actually needed. At most min(m_bytes_needed,len) will be processed. + * + * @param buf Input/working buffer + * @param len Length of buf + * @return Number of bytes processed or zero in case of an error + */ + size_t process_payload_bytes(uint8_t* buf, size_t len, lib::error_code& ec) { + // unmask if masked + if (frame::get_masked(m_basic_header)) { + m_current_msg->prepared_key = + frame::byte_mask_circ(buf, len, m_current_msg->prepared_key); + // TODO: SIMD masking + } + + std::string& out = m_current_msg->msg_ptr->get_raw_payload(); + size_t offset = out.size(); + + // decompress message if needed. + if (m_permessage_deflate.is_enabled() && + m_current_msg->msg_ptr->get_compressed()) { + // Decompress current buffer into the message buffer + ec = m_permessage_deflate.decompress(buf, len, out); + if (ec) { + return 0; + } + } else { + // No compression, straight copy + out.append(reinterpret_cast(buf), len); + } + + // validate unmasked, decompressed values + if (m_current_msg->msg_ptr->get_opcode() == frame::opcode::TEXT) { + if (!m_current_msg->validator.decode(out.begin() + offset, out.end())) { + ec = make_error_code(error::invalid_utf8); + return 0; + } + } + + m_bytes_needed -= len; + + return len; + } + + /// Validate an incoming basic header + /** + * Validates an incoming hybi13 basic header. + * + * @param h The basic header to validate + * @param is_server Whether or not the endpoint that received this frame + * is a server. + * @param new_msg Whether or not this is the first frame of the message + * @return 0 on success or a non-zero error code on failure + */ + lib::error_code validate_incoming_basic_header(frame::basic_header const& h, + bool is_server, + bool new_msg) const { + frame::opcode::value op = frame::get_opcode(h); + + // Check control frame size limit + if (frame::opcode::is_control(op) && + frame::get_basic_size(h) > frame::limits::payload_size_basic) { + return make_error_code(error::control_too_big); + } + + // Check that RSV bits are clear + // The only RSV bits allowed are rsv1 if the permessage_compress + // extension is enabled for this connection and the message is not + // a control message. + // + // TODO: unit tests for this + if (frame::get_rsv1(h) && + (!m_permessage_deflate.is_enabled() || frame::opcode::is_control(op))) { + return make_error_code(error::invalid_rsv_bit); + } + + if (frame::get_rsv2(h) || frame::get_rsv3(h)) { + return make_error_code(error::invalid_rsv_bit); + } + + // Check for reserved opcodes + if (frame::opcode::reserved(op)) { + return make_error_code(error::invalid_opcode); + } + + // Check for invalid opcodes + // TODO: unit tests for this? + if (frame::opcode::invalid(op)) { + return make_error_code(error::invalid_opcode); + } + + // Check for fragmented control message + if (frame::opcode::is_control(op) && !frame::get_fin(h)) { + return make_error_code(error::fragmented_control); + } + + // Check for continuation without an active message + if (new_msg && op == frame::opcode::CONTINUATION) { + return make_error_code(error::invalid_continuation); + } + + // Check for new data frame when expecting continuation + if (!new_msg && !frame::opcode::is_control(op) && + op != frame::opcode::CONTINUATION) { + return make_error_code(error::invalid_continuation); + } + + // Servers should reject any unmasked frames from clients. + // Clients should reject any masked frames from servers. + if (is_server && !frame::get_masked(h)) { + return make_error_code(error::masking_required); + } else if (!is_server && frame::get_masked(h)) { + return make_error_code(error::masking_forbidden); + } + + return lib::error_code(); + } + + /// Validate an incoming extended header + /** + * Validates an incoming hybi13 full header. + * + * @todo unit test for the >32 bit frames on 32 bit systems case + * + * @param h The basic header to validate + * @param e The extended header to validate + * @return An error_code, non-zero values indicate why the validation + * failed + */ + lib::error_code validate_incoming_extended_header( + frame::basic_header h, frame::extended_header e) const { + uint8_t basic_size = frame::get_basic_size(h); + uint64_t payload_size = frame::get_payload_size(h, e); + + // Check for non-minimally encoded payloads + if (basic_size == frame::payload_size_code_16bit && + payload_size <= frame::limits::payload_size_basic) { + return make_error_code(error::non_minimal_encoding); + } + + if (basic_size == frame::payload_size_code_64bit && + payload_size <= frame::limits::payload_size_extended) { + return make_error_code(error::non_minimal_encoding); + } + + // Check for >32bit frames on 32 bit systems + if (sizeof(size_t) == 4 && (payload_size >> 32)) { + return make_error_code(error::requires_64bit); + } + + return lib::error_code(); + } + + /// Copy and mask/unmask in one operation + /** + * Reads input from one string and writes unmasked output to another. + * + * @param [in] i The input string. + * @param [out] o The output string. + * @param [in] key The masking key to use for masking/unmasking + */ + void masked_copy(std::string const& i, std::string& o, + frame::masking_key_type key) const { + frame::byte_mask(i.begin(), i.end(), o.begin(), key); + // TODO: SIMD masking + } + + /// Generic prepare control frame with opcode and payload. + /** + * Internal control frame building method. Handles validation, masking, etc + * + * @param op The control opcode to use + * @param payload The payload to use + * @param out The message buffer to store the prepared frame in + * @return Status code, zero on success, non-zero on error + */ + lib::error_code prepare_control(frame::opcode::value op, + std::string const& payload, + message_ptr out) const { + if (!out) { + return make_error_code(error::invalid_arguments); + } + + if (!frame::opcode::is_control(op)) { + return make_error_code(error::invalid_opcode); + } + + if (payload.size() > frame::limits::payload_size_basic) { + return make_error_code(error::control_too_big); + } + + frame::masking_key_type key; + bool masked = !base::m_server; + + frame::basic_header h(op, payload.size(), true, masked); + + std::string& o = out->get_raw_payload(); + o.resize(payload.size()); + + if (masked) { + // Generate masking key. + key.i = m_rng(); + + frame::extended_header e(payload.size(), key.i); + out->set_header(frame::prepare_header(h, e)); + this->masked_copy(payload, o, key); + } else { + frame::extended_header e(payload.size()); + out->set_header(frame::prepare_header(h, e)); + std::copy(payload.begin(), payload.end(), o.begin()); + } + + out->set_opcode(op); + out->set_prepared(true); + + return lib::error_code(); + } + + enum state { + HEADER_BASIC = 0, + HEADER_EXTENDED = 1, + EXTENSION = 2, + APPLICATION = 3, + READY = 4, + FATAL_ERROR = 5 + }; + + /// This data structure holds data related to processing a message, such as + /// the buffer it is being written to, its masking key, its UTF8 validation + /// state, and sometimes its compression state. + struct msg_metadata { + msg_metadata() {} + msg_metadata(message_ptr m, size_t p) : msg_ptr(m), prepared_key(p) {} + msg_metadata(message_ptr m, frame::masking_key_type p) + : msg_ptr(m), prepared_key(prepare_masking_key(p)) {} + + message_ptr msg_ptr; // pointer to the message data buffer + size_t prepared_key; // prepared masking key + utf8_validator::validator validator; // utf8 validation state + }; + + // Basic header of the frame being read + frame::basic_header m_basic_header; + + // Pointer to a manager that can create message buffers for us. + msg_manager_ptr m_msg_manager; + + // Number of bytes needed to complete the current operation + size_t m_bytes_needed; + + // Number of extended header bytes read + size_t m_cursor; + + // Metadata for the current data msg + msg_metadata m_data_msg; + // Metadata for the current control msg + msg_metadata m_control_msg; + + // Pointer to the metadata associated with the frame being read + msg_metadata* m_current_msg; + + // Extended header of current frame + frame::extended_header m_extended_header; + + rng_type& m_rng; + + // Overall state of the processor + state m_state; + + // Extensions + permessage_deflate_type m_permessage_deflate; }; -} // namespace processor -} // namespace websocketpp +} // namespace processor +} // namespace websocketpp -#endif //WEBSOCKETPP_PROCESSOR_HYBI13_HPP +#endif // WEBSOCKETPP_PROCESSOR_HYBI13_HPP +#pragma warning(pop) \ No newline at end of file diff --git a/thirdparty/websocketpp/include/websocketpp/sha1/sha1.hpp b/thirdparty/websocketpp/include/websocketpp/sha1/sha1.hpp index 6b48d95..c486ac7 100644 --- a/thirdparty/websocketpp/include/websocketpp/sha1/sha1.hpp +++ b/thirdparty/websocketpp/include/websocketpp/sha1/sha1.hpp @@ -32,6 +32,8 @@ under the same license as the original, which is listed below. (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +#pragma warning(push) +#pragma warning(disable : 4267) #ifndef SHA1_DEFINED #define SHA1_DEFINED @@ -39,83 +41,74 @@ under the same license as the original, which is listed below. namespace websocketpp { namespace sha1 { -namespace { // local +namespace { // local // Rotate an integer value to left. inline unsigned int rol(unsigned int value, unsigned int steps) { - return ((value << steps) | (value >> (32 - steps))); + return ((value << steps) | (value >> (32 - steps))); } // Sets the first 16 integers in the buffert to zero. // Used for clearing the W buffert. -inline void clearWBuffert(unsigned int * buffert) -{ - for (int pos = 16; --pos >= 0;) - { - buffert[pos] = 0; - } +inline void clearWBuffert(unsigned int *buffert) { + for (int pos = 16; --pos >= 0;) { + buffert[pos] = 0; + } } -inline void innerHash(unsigned int * result, unsigned int * w) -{ - unsigned int a = result[0]; - unsigned int b = result[1]; - unsigned int c = result[2]; - unsigned int d = result[3]; - unsigned int e = result[4]; +inline void innerHash(unsigned int *result, unsigned int *w) { + unsigned int a = result[0]; + unsigned int b = result[1]; + unsigned int c = result[2]; + unsigned int d = result[3]; + unsigned int e = result[4]; - int round = 0; + int round = 0; - #define sha1macro(func,val) \ - { \ - const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \ - e = d; \ - d = c; \ - c = rol(b, 30); \ - b = a; \ - a = t; \ - } +#define sha1macro(func, val) \ + { \ + const unsigned int t = rol(a, 5) + (func) + e + val + w[round]; \ + e = d; \ + d = c; \ + c = rol(b, 30); \ + b = a; \ + a = t; \ + } - while (round < 16) - { - sha1macro((b & c) | (~b & d), 0x5a827999) - ++round; - } - while (round < 20) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro((b & c) | (~b & d), 0x5a827999) - ++round; - } - while (round < 40) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro(b ^ c ^ d, 0x6ed9eba1) - ++round; - } - while (round < 60) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc) - ++round; - } - while (round < 80) - { - w[round] = rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); - sha1macro(b ^ c ^ d, 0xca62c1d6) - ++round; - } + while (round < 16) { + sha1macro((b & c) | (~b & d), 0x5a827999)++ round; + } + while (round < 20) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro((b & c) | (~b & d), 0x5a827999)++ round; + } + while (round < 40) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro(b ^ c ^ d, 0x6ed9eba1)++ round; + } + while (round < 60) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro((b & c) | (b & d) | (c & d), 0x8f1bbcdc)++ round; + } + while (round < 80) { + w[round] = + rol((w[round - 3] ^ w[round - 8] ^ w[round - 14] ^ w[round - 16]), 1); + sha1macro(b ^ c ^ d, 0xca62c1d6)++ round; + } - #undef sha1macro +#undef sha1macro - result[0] += a; - result[1] += b; - result[2] += c; - result[3] += d; - result[4] += e; + result[0] += a; + result[1] += b; + result[2] += c; + result[3] += d; + result[4] += e; } -} // namespace +} // namespace /// Calculate a SHA1 hash /** @@ -124,66 +117,70 @@ inline void innerHash(unsigned int * result, unsigned int * w) * @param hash should point to a buffer of at least 20 bytes of size for storing * the sha1 result in. */ -inline void calc(void const * src, size_t bytelength, unsigned char * hash) { - // Init the result array. - unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, - 0x10325476, 0xc3d2e1f0 }; +inline void calc(void const *src, size_t bytelength, unsigned char *hash) { + // Init the result array. + unsigned int result[5] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, + 0xc3d2e1f0}; - // Cast the void src pointer to be the byte array we can work with. - unsigned char const * sarray = (unsigned char const *) src; + // Cast the void src pointer to be the byte array we can work with. + unsigned char const *sarray = (unsigned char const *)src; - // The reusable round buffer - unsigned int w[80]; + // The reusable round buffer + unsigned int w[80]; - // Loop through all complete 64byte blocks. + // Loop through all complete 64byte blocks. - size_t endCurrentBlock; - size_t currentBlock = 0; + size_t endCurrentBlock; + size_t currentBlock = 0; - if (bytelength >= 64) { - size_t const endOfFullBlocks = bytelength - 64; + if (bytelength >= 64) { + size_t const endOfFullBlocks = bytelength - 64; - while (currentBlock <= endOfFullBlocks) { - endCurrentBlock = currentBlock + 64; + while (currentBlock <= endOfFullBlocks) { + endCurrentBlock = currentBlock + 64; - // Init the round buffer with the 64 byte block data. - for (int roundPos = 0; currentBlock < endCurrentBlock; currentBlock += 4) - { - // This line will swap endian on big endian and keep endian on - // little endian. - w[roundPos++] = (unsigned int) sarray[currentBlock + 3] - | (((unsigned int) sarray[currentBlock + 2]) << 8) - | (((unsigned int) sarray[currentBlock + 1]) << 16) - | (((unsigned int) sarray[currentBlock]) << 24); - } - innerHash(result, w); - } + // Init the round buffer with the 64 byte block data. + for (int roundPos = 0; currentBlock < endCurrentBlock; + currentBlock += 4) { + // This line will swap endian on big endian and keep endian on + // little endian. + w[roundPos++] = (unsigned int)sarray[currentBlock + 3] | + (((unsigned int)sarray[currentBlock + 2]) << 8) | + (((unsigned int)sarray[currentBlock + 1]) << 16) | + (((unsigned int)sarray[currentBlock]) << 24); + } + innerHash(result, w); } + } - // Handle the last and not full 64 byte block if existing. - endCurrentBlock = bytelength - currentBlock; - clearWBuffert(w); - size_t lastBlockBytes = 0; - for (;lastBlockBytes < endCurrentBlock; ++lastBlockBytes) { - w[lastBlockBytes >> 2] |= (unsigned int) sarray[lastBlockBytes + currentBlock] << ((3 - (lastBlockBytes & 3)) << 3); - } + // Handle the last and not full 64 byte block if existing. + endCurrentBlock = bytelength - currentBlock; + clearWBuffert(w); + size_t lastBlockBytes = 0; + for (; lastBlockBytes < endCurrentBlock; ++lastBlockBytes) { + w[lastBlockBytes >> 2] |= + (unsigned int)sarray[lastBlockBytes + currentBlock] + << ((3 - (lastBlockBytes & 3)) << 3); + } - w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3); - if (endCurrentBlock >= 56) { - innerHash(result, w); - clearWBuffert(w); - } - w[15] = bytelength << 3; + w[lastBlockBytes >> 2] |= 0x80 << ((3 - (lastBlockBytes & 3)) << 3); + if (endCurrentBlock >= 56) { innerHash(result, w); + clearWBuffert(w); + } + w[15] = bytelength << 3; + innerHash(result, w); - // Store hash in result pointer, and make sure we get in in the correct - // order on both endian models. - for (int hashByte = 20; --hashByte >= 0;) { - hash[hashByte] = (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff; - } + // Store hash in result pointer, and make sure we get in in the correct + // order on both endian models. + for (int hashByte = 20; --hashByte >= 0;) { + hash[hashByte] = + (result[hashByte >> 2] >> (((3 - hashByte) & 0x3) << 3)) & 0xff; + } } -} // namespace sha1 -} // namespace websocketpp +} // namespace sha1 +} // namespace websocketpp -#endif // SHA1_DEFINED +#endif // SHA1_DEFINED +#pragma warning(pop) \ No newline at end of file diff --git a/thirdparty/websocketpp/include/websocketpp/transport/asio/connection.hpp b/thirdparty/websocketpp/include/websocketpp/transport/asio/connection.hpp index 57dda74..15f9d6b 100644 --- a/thirdparty/websocketpp/include/websocketpp/transport/asio/connection.hpp +++ b/thirdparty/websocketpp/include/websocketpp/transport/asio/connection.hpp @@ -24,32 +24,29 @@ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ +#pragma warning(push) +#pragma warning(disable : 4127) #ifndef WEBSOCKETPP_TRANSPORT_ASIO_CON_HPP #define WEBSOCKETPP_TRANSPORT_ASIO_CON_HPP -#include - -#include - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace websocketpp { namespace transport { @@ -65,1133 +62,1001 @@ typedef lib::function tcp_init_handler; */ template class connection : public config::socket_type::socket_con_type { -public: - /// Type of this connection transport component - typedef connection type; - /// Type of a shared pointer to this connection transport component - typedef lib::shared_ptr ptr; + public: + /// Type of this connection transport component + typedef connection type; + /// Type of a shared pointer to this connection transport component + typedef lib::shared_ptr ptr; - /// Type of the socket connection component - typedef typename config::socket_type::socket_con_type socket_con_type; - /// Type of a shared pointer to the socket connection component - typedef typename socket_con_type::ptr socket_con_ptr; - /// Type of this transport's access logging policy - typedef typename config::alog_type alog_type; - /// Type of this transport's error logging policy - typedef typename config::elog_type elog_type; + /// Type of the socket connection component + typedef typename config::socket_type::socket_con_type socket_con_type; + /// Type of a shared pointer to the socket connection component + typedef typename socket_con_type::ptr socket_con_ptr; + /// Type of this transport's access logging policy + typedef typename config::alog_type alog_type; + /// Type of this transport's error logging policy + typedef typename config::elog_type elog_type; - typedef typename config::request_type request_type; - typedef typename request_type::ptr request_ptr; - typedef typename config::response_type response_type; - typedef typename response_type::ptr response_ptr; + typedef typename config::request_type request_type; + typedef typename request_type::ptr request_ptr; + typedef typename config::response_type response_type; + typedef typename response_type::ptr response_ptr; - /// Type of a pointer to the Asio io_service being used - typedef lib::asio::io_service * io_service_ptr; - /// Type of a pointer to the Asio io_service::strand being used - typedef lib::shared_ptr strand_ptr; - /// Type of a pointer to the Asio timer class - typedef lib::shared_ptr timer_ptr; + /// Type of a pointer to the Asio io_service being used + typedef lib::asio::io_service* io_service_ptr; + /// Type of a pointer to the Asio io_service::strand being used + typedef lib::shared_ptr strand_ptr; + /// Type of a pointer to the Asio timer class + typedef lib::shared_ptr timer_ptr; - // connection is friends with its associated endpoint to allow the endpoint - // to call private/protected utility methods that we don't want to expose - // to the public api. - friend class endpoint; + // connection is friends with its associated endpoint to allow the endpoint + // to call private/protected utility methods that we don't want to expose + // to the public api. + friend class endpoint; - // generate and manage our own io_service - explicit connection(bool is_server, const lib::shared_ptr & alog, const lib::shared_ptr & elog) - : m_is_server(is_server) - , m_alog(alog) - , m_elog(elog) - { - m_alog->write(log::alevel::devel,"asio con transport constructor"); + // generate and manage our own io_service + explicit connection(bool is_server, const lib::shared_ptr& alog, + const lib::shared_ptr& elog) + : m_is_server(is_server), m_alog(alog), m_elog(elog) { + m_alog->write(log::alevel::devel, "asio con transport constructor"); + } + + /// Get a shared pointer to this component + ptr get_shared() { + return lib::static_pointer_cast(socket_con_type::get_shared()); + } + + bool is_secure() const { return socket_con_type::is_secure(); } + + /// Set uri hook + /** + * Called by the endpoint as a connection is being established to provide + * the uri being connected to to the transport layer. + * + * This transport policy doesn't use the uri except to forward it to the + * socket layer. + * + * @since 0.6.0 + * + * @param u The uri to set + */ + void set_uri(uri_ptr u) { socket_con_type::set_uri(u); } + + /// Sets the tcp pre init handler + /** + * The tcp pre init handler is called after the raw tcp connection has been + * established but before any additional wrappers (proxy connects, TLS + * handshakes, etc) have been performed. + * + * @since 0.3.0 + * + * @param h The handler to call on tcp pre init. + */ + void set_tcp_pre_init_handler(tcp_init_handler h) { + m_tcp_pre_init_handler = h; + } + + /// Sets the tcp pre init handler (deprecated) + /** + * The tcp pre init handler is called after the raw tcp connection has been + * established but before any additional wrappers (proxy connects, TLS + * handshakes, etc) have been performed. + * + * @deprecated Use set_tcp_pre_init_handler instead + * + * @param h The handler to call on tcp pre init. + */ + void set_tcp_init_handler(tcp_init_handler h) { set_tcp_pre_init_handler(h); } + + /// Sets the tcp post init handler + /** + * The tcp post init handler is called after the tcp connection has been + * established and all additional wrappers (proxy connects, TLS handshakes, + * etc have been performed. This is fired before any bytes are read or any + * WebSocket specific handshake logic has been performed. + * + * @since 0.3.0 + * + * @param h The handler to call on tcp post init. + */ + void set_tcp_post_init_handler(tcp_init_handler h) { + m_tcp_post_init_handler = h; + } + + /// Set the proxy to connect through (exception free) + /** + * The URI passed should be a complete URI including scheme. For example: + * http://proxy.example.com:8080/ + * + * The proxy must be set up as an explicit (CONNECT) proxy allowed to + * connect to the port you specify. Traffic to the proxy is not encrypted. + * + * @param uri The full URI of the proxy to connect to. + * + * @param ec A status value + */ + void set_proxy(std::string const& uri, lib::error_code& ec) { + // TODO: return errors for illegal URIs here? + // TODO: should https urls be illegal for the moment? + m_proxy = uri; + m_proxy_data = lib::make_shared(); + ec = lib::error_code(); + } + + /// Set the proxy to connect through (exception) + void set_proxy(std::string const& uri) { + lib::error_code ec; + set_proxy(uri, ec); + if (ec) { + throw exception(ec); + } + } + + /// Set the basic auth credentials to use (exception free) + /** + * The URI passed should be a complete URI including scheme. For example: + * http://proxy.example.com:8080/ + * + * The proxy must be set up as an explicit proxy + * + * @param username The username to send + * + * @param password The password to send + * + * @param ec A status value + */ + void set_proxy_basic_auth(std::string const& username, + std::string const& password, lib::error_code& ec) { + if (!m_proxy_data) { + ec = make_error_code(websocketpp::error::invalid_state); + return; } - /// Get a shared pointer to this component - ptr get_shared() { - return lib::static_pointer_cast(socket_con_type::get_shared()); + // TODO: username can't contain ':' + std::string val = "Basic " + base64_encode(username + ":" + password); + m_proxy_data->req.replace_header("Proxy-Authorization", val); + ec = lib::error_code(); + } + + /// Set the basic auth credentials to use (exception) + void set_proxy_basic_auth(std::string const& username, + std::string const& password) { + lib::error_code ec; + set_proxy_basic_auth(username, password, ec); + if (ec) { + throw exception(ec); + } + } + + /// Set the proxy timeout duration (exception free) + /** + * Duration is in milliseconds. Default value is based on the transport + * config + * + * @param duration The number of milliseconds to wait before aborting the + * proxy connection. + * + * @param ec A status value + */ + void set_proxy_timeout(long duration, lib::error_code& ec) { + if (!m_proxy_data) { + ec = make_error_code(websocketpp::error::invalid_state); + return; } - bool is_secure() const { - return socket_con_type::is_secure(); + m_proxy_data->timeout_proxy = duration; + ec = lib::error_code(); + } + + /// Set the proxy timeout duration (exception) + void set_proxy_timeout(long duration) { + lib::error_code ec; + set_proxy_timeout(duration, ec); + if (ec) { + throw exception(ec); + } + } + + std::string const& get_proxy() const { return m_proxy; } + + /// Get the remote endpoint address + /** + * The iostream transport has no information about the ultimate remote + * endpoint. It will return the string "iostream transport". To indicate + * this. + * + * TODO: allow user settable remote endpoint addresses if this seems useful + * + * @return A string identifying the address of the remote endpoint + */ + std::string get_remote_endpoint() const { + lib::error_code ec; + + std::string ret = socket_con_type::get_remote_endpoint(ec); + + if (ec) { + m_elog->write(log::elevel::info, ret); + return "Unknown"; + } else { + return ret; + } + } + + /// Get the connection handle + connection_hdl get_handle() const { return m_connection_hdl; } + + /// Call back a function after a period of time. + /** + * Sets a timer that calls back a function after the specified period of + * milliseconds. Returns a handle that can be used to cancel the timer. + * A cancelled timer will return the error code error::operation_aborted + * A timer that expired will return no error. + * + * @param duration Length of time to wait in milliseconds + * + * @param callback The function to call back when the timer has expired + * + * @return A handle that can be used to cancel the timer if it is no longer + * needed. + */ + timer_ptr set_timer(long duration, timer_handler callback) { + timer_ptr new_timer(new lib::asio::steady_timer( + *m_io_service, lib::asio::milliseconds(duration))); + + if (config::enable_multithreading) { + new_timer->async_wait( + m_strand->wrap(lib::bind(&type::handle_timer, get_shared(), new_timer, + callback, lib::placeholders::_1))); + } else { + new_timer->async_wait(lib::bind(&type::handle_timer, get_shared(), + new_timer, callback, + lib::placeholders::_1)); } - /// Set uri hook - /** - * Called by the endpoint as a connection is being established to provide - * the uri being connected to to the transport layer. - * - * This transport policy doesn't use the uri except to forward it to the - * socket layer. - * - * @since 0.6.0 - * - * @param u The uri to set - */ - void set_uri(uri_ptr u) { - socket_con_type::set_uri(u); + return new_timer; + } + + /// Timer callback + /** + * The timer pointer is included to ensure the timer isn't destroyed until + * after it has expired. + * + * TODO: candidate for protected status + * + * @param post_timer Pointer to the timer in question + * @param callback The function to call back + * @param ec The status code + */ + void handle_timer(timer_ptr, timer_handler callback, + lib::asio::error_code const& ec) { + if (ec) { + if (ec == lib::asio::error::operation_aborted) { + callback(make_error_code(transport::error::operation_aborted)); + } else { + log_err(log::elevel::info, "asio handle_timer", ec); + callback(make_error_code(error::pass_through)); + } + } else { + callback(lib::error_code()); + } + } + + /// Get a pointer to this connection's strand + strand_ptr get_strand() { return m_strand; } + + /// Get the internal transport error code for a closed/failed connection + /** + * Retrieves a machine readable detailed error code indicating the reason + * that the connection was closed or failed. Valid only after the close or + * fail handler is called. + * + * Primarily used if you are using mismatched asio / system_error + * implementations such as `boost::asio` with `std::system_error`. In these + * cases the transport error type is different than the library error type + * and some WebSocket++ functions that return transport errors via the + * library error code type will be coerced into a catch all `pass_through` + * or `tls_error` error. This method will return the original machine + * readable transport error in the native type. + * + * @since 0.7.0 + * + * @return Error code indicating the reason the connection was closed or + * failed + */ + lib::asio::error_code get_transport_ec() const { return m_tec; } + + /// Initialize transport for reading + /** + * init_asio is called once immediately after construction to initialize + * Asio components to the io_service + * + * The transport initialization sequence consists of the following steps: + * - Pre-init: the underlying socket is initialized to the point where + * bytes may be written. No bytes are actually written in this stage + * - Proxy negotiation: if a proxy is set, a request is made to it to start + * a tunnel to the final destination. This stage ends when the proxy is + * ready to forward the + * next byte to the remote endpoint. + * - Post-init: Perform any i/o with the remote endpoint, such as setting up + * tunnels for encryption. This stage ends when the connection is ready to + * read or write the WebSocket handshakes. At this point the original + * callback function is called. + */ + protected: + void init(init_handler callback) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection init"); } - /// Sets the tcp pre init handler - /** - * The tcp pre init handler is called after the raw tcp connection has been - * established but before any additional wrappers (proxy connects, TLS - * handshakes, etc) have been performed. - * - * @since 0.3.0 - * - * @param h The handler to call on tcp pre init. - */ - void set_tcp_pre_init_handler(tcp_init_handler h) { - m_tcp_pre_init_handler = h; + // TODO: pre-init timeout. Right now no implemented socket policies + // actually have an asyncronous pre-init + + socket_con_type::pre_init(lib::bind(&type::handle_pre_init, get_shared(), + callback, lib::placeholders::_1)); + } + + /// initialize the proxy buffers and http parsers + /** + * + * @param authority The address of the server we want the proxy to tunnel to + * in the format of a URI authority (host:port) + * + * @return Status code indicating what errors occurred, if any + */ + lib::error_code proxy_init(std::string const& authority) { + if (!m_proxy_data) { + return websocketpp::error::make_error_code( + websocketpp::error::invalid_state); + } + m_proxy_data->req.set_version("HTTP/1.1"); + m_proxy_data->req.set_method("CONNECT"); + + m_proxy_data->req.set_uri(authority); + m_proxy_data->req.replace_header("Host", authority); + + return lib::error_code(); + } + + /// Finish constructing the transport + /** + * init_asio is called once immediately after construction to initialize + * Asio components to the io_service. + * + * @param io_service A pointer to the io_service to register with this + * connection + * + * @return Status code for the success or failure of the initialization + */ + lib::error_code init_asio(io_service_ptr io_service) { + m_io_service = io_service; + + if (config::enable_multithreading) { + m_strand.reset(new lib::asio::io_service::strand(*io_service)); } - /// Sets the tcp pre init handler (deprecated) - /** - * The tcp pre init handler is called after the raw tcp connection has been - * established but before any additional wrappers (proxy connects, TLS - * handshakes, etc) have been performed. - * - * @deprecated Use set_tcp_pre_init_handler instead - * - * @param h The handler to call on tcp pre init. - */ - void set_tcp_init_handler(tcp_init_handler h) { - set_tcp_pre_init_handler(h); + lib::error_code ec = + socket_con_type::init_asio(io_service, m_strand, m_is_server); + + return ec; + } + + void handle_pre_init(init_handler callback, lib::error_code const& ec) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection handle pre_init"); } - /// Sets the tcp post init handler - /** - * The tcp post init handler is called after the tcp connection has been - * established and all additional wrappers (proxy connects, TLS handshakes, - * etc have been performed. This is fired before any bytes are read or any - * WebSocket specific handshake logic has been performed. - * - * @since 0.3.0 - * - * @param h The handler to call on tcp post init. - */ - void set_tcp_post_init_handler(tcp_init_handler h) { - m_tcp_post_init_handler = h; + if (m_tcp_pre_init_handler) { + m_tcp_pre_init_handler(m_connection_hdl); } - /// Set the proxy to connect through (exception free) - /** - * The URI passed should be a complete URI including scheme. For example: - * http://proxy.example.com:8080/ - * - * The proxy must be set up as an explicit (CONNECT) proxy allowed to - * connect to the port you specify. Traffic to the proxy is not encrypted. - * - * @param uri The full URI of the proxy to connect to. - * - * @param ec A status value - */ - void set_proxy(std::string const & uri, lib::error_code & ec) { - // TODO: return errors for illegal URIs here? - // TODO: should https urls be illegal for the moment? - m_proxy = uri; - m_proxy_data = lib::make_shared(); - ec = lib::error_code(); + if (ec) { + callback(ec); } - /// Set the proxy to connect through (exception) - void set_proxy(std::string const & uri) { - lib::error_code ec; - set_proxy(uri,ec); - if (ec) { throw exception(ec); } + // If we have a proxy set issue a proxy connect, otherwise skip to + // post_init + if (!m_proxy.empty()) { + proxy_write(callback); + } else { + post_init(callback); + } + } + + void post_init(init_handler callback) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection post_init"); } - /// Set the basic auth credentials to use (exception free) - /** - * The URI passed should be a complete URI including scheme. For example: - * http://proxy.example.com:8080/ - * - * The proxy must be set up as an explicit proxy - * - * @param username The username to send - * - * @param password The password to send - * - * @param ec A status value - */ - void set_proxy_basic_auth(std::string const & username, std::string const & - password, lib::error_code & ec) - { - if (!m_proxy_data) { - ec = make_error_code(websocketpp::error::invalid_state); - return; - } + timer_ptr post_timer; - // TODO: username can't contain ':' - std::string val = "Basic "+base64_encode(username + ":" + password); - m_proxy_data->req.replace_header("Proxy-Authorization",val); - ec = lib::error_code(); + if (config::timeout_socket_post_init > 0) { + post_timer = + set_timer(config::timeout_socket_post_init, + lib::bind(&type::handle_post_init_timeout, get_shared(), + post_timer, callback, lib::placeholders::_1)); } - /// Set the basic auth credentials to use (exception) - void set_proxy_basic_auth(std::string const & username, std::string const & - password) - { - lib::error_code ec; - set_proxy_basic_auth(username,password,ec); - if (ec) { throw exception(ec); } + socket_con_type::post_init(lib::bind(&type::handle_post_init, get_shared(), + post_timer, callback, + lib::placeholders::_1)); + } + + /// Post init timeout callback + /** + * The timer pointer is included to ensure the timer isn't destroyed until + * after it has expired. + * + * @param post_timer Pointer to the timer in question + * @param callback The function to call back + * @param ec The status code + */ + void handle_post_init_timeout(timer_ptr, init_handler callback, + lib::error_code const& ec) { + lib::error_code ret_ec; + + if (ec) { + if (ec == transport::error::operation_aborted) { + m_alog->write(log::alevel::devel, "asio post init timer cancelled"); + return; + } + + log_err(log::elevel::devel, "asio handle_post_init_timeout", ec); + ret_ec = ec; + } else { + if (socket_con_type::get_ec()) { + ret_ec = socket_con_type::get_ec(); + } else { + ret_ec = make_error_code(transport::error::timeout); + } } - /// Set the proxy timeout duration (exception free) - /** - * Duration is in milliseconds. Default value is based on the transport - * config - * - * @param duration The number of milliseconds to wait before aborting the - * proxy connection. - * - * @param ec A status value - */ - void set_proxy_timeout(long duration, lib::error_code & ec) { - if (!m_proxy_data) { - ec = make_error_code(websocketpp::error::invalid_state); - return; - } + m_alog->write(log::alevel::devel, "Asio transport post-init timed out"); + cancel_socket_checked(); + callback(ret_ec); + } - m_proxy_data->timeout_proxy = duration; - ec = lib::error_code(); + /// Post init timeout callback + /** + * The timer pointer is included to ensure the timer isn't destroyed until + * after it has expired. + * + * @param post_timer Pointer to the timer in question + * @param callback The function to call back + * @param ec The status code + */ + void handle_post_init(timer_ptr post_timer, init_handler callback, + lib::error_code const& ec) { + if (ec == transport::error::operation_aborted || + (post_timer && lib::asio::is_neg(post_timer->expires_from_now()))) { + m_alog->write(log::alevel::devel, "post_init cancelled"); + return; } - /// Set the proxy timeout duration (exception) - void set_proxy_timeout(long duration) { - lib::error_code ec; - set_proxy_timeout(duration,ec); - if (ec) { throw exception(ec); } + if (post_timer) { + post_timer->cancel(); } - std::string const & get_proxy() const { - return m_proxy; + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection handle_post_init"); } - /// Get the remote endpoint address - /** - * The iostream transport has no information about the ultimate remote - * endpoint. It will return the string "iostream transport". To indicate - * this. - * - * TODO: allow user settable remote endpoint addresses if this seems useful - * - * @return A string identifying the address of the remote endpoint - */ - std::string get_remote_endpoint() const { - lib::error_code ec; - - std::string ret = socket_con_type::get_remote_endpoint(ec); - - if (ec) { - m_elog->write(log::elevel::info,ret); - return "Unknown"; - } else { - return ret; - } + if (m_tcp_post_init_handler) { + m_tcp_post_init_handler(m_connection_hdl); } - /// Get the connection handle - connection_hdl get_handle() const { - return m_connection_hdl; + callback(ec); + } + + void proxy_write(init_handler callback) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection proxy_write"); } - /// Call back a function after a period of time. - /** - * Sets a timer that calls back a function after the specified period of - * milliseconds. Returns a handle that can be used to cancel the timer. - * A cancelled timer will return the error code error::operation_aborted - * A timer that expired will return no error. - * - * @param duration Length of time to wait in milliseconds - * - * @param callback The function to call back when the timer has expired - * - * @return A handle that can be used to cancel the timer if it is no longer - * needed. - */ - timer_ptr set_timer(long duration, timer_handler callback) { - timer_ptr new_timer( - new lib::asio::steady_timer( - *m_io_service, - lib::asio::milliseconds(duration)) - ); - - if (config::enable_multithreading) { - new_timer->async_wait(m_strand->wrap(lib::bind( - &type::handle_timer, get_shared(), - new_timer, - callback, - lib::placeholders::_1 - ))); - } else { - new_timer->async_wait(lib::bind( - &type::handle_timer, get_shared(), - new_timer, - callback, - lib::placeholders::_1 - )); - } - - return new_timer; + if (!m_proxy_data) { + m_elog->write( + log::elevel::library, + "assertion failed: !m_proxy_data in asio::connection::proxy_write"); + callback(make_error_code(error::general)); + return; } - /// Timer callback - /** - * The timer pointer is included to ensure the timer isn't destroyed until - * after it has expired. - * - * TODO: candidate for protected status - * - * @param post_timer Pointer to the timer in question - * @param callback The function to call back - * @param ec The status code - */ - void handle_timer(timer_ptr, timer_handler callback, - lib::asio::error_code const & ec) - { - if (ec) { - if (ec == lib::asio::error::operation_aborted) { - callback(make_error_code(transport::error::operation_aborted)); - } else { - log_err(log::elevel::info,"asio handle_timer",ec); - callback(make_error_code(error::pass_through)); - } - } else { - callback(lib::error_code()); - } + m_proxy_data->write_buf = m_proxy_data->req.raw(); + + m_bufs.push_back(lib::asio::buffer(m_proxy_data->write_buf.data(), + m_proxy_data->write_buf.size())); + + m_alog->write(log::alevel::devel, m_proxy_data->write_buf); + + // Set a timer so we don't wait forever for the proxy to respond + m_proxy_data->timer = + this->set_timer(m_proxy_data->timeout_proxy, + lib::bind(&type::handle_proxy_timeout, get_shared(), + callback, lib::placeholders::_1)); + + // Send proxy request + if (config::enable_multithreading) { + lib::asio::async_write( + socket_con_type::get_next_layer(), m_bufs, + m_strand->wrap(lib::bind(&type::handle_proxy_write, get_shared(), + callback, lib::placeholders::_1))); + } else { + lib::asio::async_write(socket_con_type::get_next_layer(), m_bufs, + lib::bind(&type::handle_proxy_write, get_shared(), + callback, lib::placeholders::_1)); + } + } + + void handle_proxy_timeout(init_handler callback, lib::error_code const& ec) { + if (ec == transport::error::operation_aborted) { + m_alog->write(log::alevel::devel, + "asio handle_proxy_write timer cancelled"); + return; + } else if (ec) { + log_err(log::elevel::devel, "asio handle_proxy_write", ec); + callback(ec); + } else { + m_alog->write(log::alevel::devel, + "asio handle_proxy_write timer expired"); + cancel_socket_checked(); + callback(make_error_code(transport::error::timeout)); + } + } + + void handle_proxy_write(init_handler callback, + lib::asio::error_code const& ec) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection handle_proxy_write"); } - /// Get a pointer to this connection's strand - strand_ptr get_strand() { - return m_strand; + m_bufs.clear(); + + // Timer expired or the operation was aborted for some reason. + // Whatever aborted it will be issuing the callback so we are safe to + // return + if (ec == lib::asio::error::operation_aborted || + lib::asio::is_neg(m_proxy_data->timer->expires_from_now())) { + m_elog->write(log::elevel::devel, "write operation aborted"); + return; } - /// Get the internal transport error code for a closed/failed connection - /** - * Retrieves a machine readable detailed error code indicating the reason - * that the connection was closed or failed. Valid only after the close or - * fail handler is called. - * - * Primarily used if you are using mismatched asio / system_error - * implementations such as `boost::asio` with `std::system_error`. In these - * cases the transport error type is different than the library error type - * and some WebSocket++ functions that return transport errors via the - * library error code type will be coerced into a catch all `pass_through` - * or `tls_error` error. This method will return the original machine - * readable transport error in the native type. - * - * @since 0.7.0 - * - * @return Error code indicating the reason the connection was closed or - * failed - */ - lib::asio::error_code get_transport_ec() const { - return m_tec; + if (ec) { + log_err(log::elevel::info, "asio handle_proxy_write", ec); + m_proxy_data->timer->cancel(); + callback(make_error_code(error::pass_through)); + return; } - /// Initialize transport for reading - /** - * init_asio is called once immediately after construction to initialize - * Asio components to the io_service - * - * The transport initialization sequence consists of the following steps: - * - Pre-init: the underlying socket is initialized to the point where - * bytes may be written. No bytes are actually written in this stage - * - Proxy negotiation: if a proxy is set, a request is made to it to start - * a tunnel to the final destination. This stage ends when the proxy is - * ready to forward the - * next byte to the remote endpoint. - * - Post-init: Perform any i/o with the remote endpoint, such as setting up - * tunnels for encryption. This stage ends when the connection is ready to - * read or write the WebSocket handshakes. At this point the original - * callback function is called. - */ -protected: - void init(init_handler callback) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection init"); - } + proxy_read(callback); + } - // TODO: pre-init timeout. Right now no implemented socket policies - // actually have an asyncronous pre-init - - socket_con_type::pre_init( - lib::bind( - &type::handle_pre_init, - get_shared(), - callback, - lib::placeholders::_1 - ) - ); + void proxy_read(init_handler callback) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection proxy_read"); } - /// initialize the proxy buffers and http parsers - /** - * - * @param authority The address of the server we want the proxy to tunnel to - * in the format of a URI authority (host:port) - * - * @return Status code indicating what errors occurred, if any - */ - lib::error_code proxy_init(std::string const & authority) { - if (!m_proxy_data) { - return websocketpp::error::make_error_code( - websocketpp::error::invalid_state); - } - m_proxy_data->req.set_version("HTTP/1.1"); - m_proxy_data->req.set_method("CONNECT"); - - m_proxy_data->req.set_uri(authority); - m_proxy_data->req.replace_header("Host",authority); - - return lib::error_code(); + if (!m_proxy_data) { + m_elog->write( + log::elevel::library, + "assertion failed: !m_proxy_data in asio::connection::proxy_read"); + m_proxy_data->timer->cancel(); + callback(make_error_code(error::general)); + return; } - /// Finish constructing the transport - /** - * init_asio is called once immediately after construction to initialize - * Asio components to the io_service. - * - * @param io_service A pointer to the io_service to register with this - * connection - * - * @return Status code for the success or failure of the initialization - */ - lib::error_code init_asio (io_service_ptr io_service) { - m_io_service = io_service; + if (config::enable_multithreading) { + lib::asio::async_read_until( + socket_con_type::get_next_layer(), m_proxy_data->read_buf, "\r\n\r\n", + m_strand->wrap(lib::bind(&type::handle_proxy_read, get_shared(), + callback, lib::placeholders::_1, + lib::placeholders::_2))); + } else { + lib::asio::async_read_until( + socket_con_type::get_next_layer(), m_proxy_data->read_buf, "\r\n\r\n", + lib::bind(&type::handle_proxy_read, get_shared(), callback, + lib::placeholders::_1, lib::placeholders::_2)); + } + } - if (config::enable_multithreading) { - m_strand.reset(new lib::asio::io_service::strand(*io_service)); - } - - lib::error_code ec = socket_con_type::init_asio(io_service, m_strand, - m_is_server); - - return ec; + /// Proxy read callback + /** + * @param init_handler The function to call back + * @param ec The status code + * @param bytes_transferred The number of bytes read + */ + void handle_proxy_read(init_handler callback, lib::asio::error_code const& ec, + size_t) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection handle_proxy_read"); } - void handle_pre_init(init_handler callback, lib::error_code const & ec) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection handle pre_init"); - } - - if (m_tcp_pre_init_handler) { - m_tcp_pre_init_handler(m_connection_hdl); - } - - if (ec) { - callback(ec); - } - - // If we have a proxy set issue a proxy connect, otherwise skip to - // post_init - if (!m_proxy.empty()) { - proxy_write(callback); - } else { - post_init(callback); - } + // Timer expired or the operation was aborted for some reason. + // Whatever aborted it will be issuing the callback so we are safe to + // return + if (ec == lib::asio::error::operation_aborted || + lib::asio::is_neg(m_proxy_data->timer->expires_from_now())) { + m_elog->write(log::elevel::devel, "read operation aborted"); + return; } - void post_init(init_handler callback) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection post_init"); - } + // At this point there is no need to wait for the timer anymore + m_proxy_data->timer->cancel(); - timer_ptr post_timer; - - if (config::timeout_socket_post_init > 0) { - post_timer = set_timer( - config::timeout_socket_post_init, - lib::bind( - &type::handle_post_init_timeout, - get_shared(), - post_timer, - callback, - lib::placeholders::_1 - ) - ); - } + if (ec) { + m_elog->write(log::elevel::info, + "asio handle_proxy_read error: " + ec.message()); + callback(make_error_code(error::pass_through)); + } else { + if (!m_proxy_data) { + m_elog->write(log::elevel::library, + "assertion failed: !m_proxy_data in " + "asio::connection::handle_proxy_read"); + callback(make_error_code(error::general)); + return; + } - socket_con_type::post_init( - lib::bind( - &type::handle_post_init, - get_shared(), - post_timer, - callback, - lib::placeholders::_1 - ) - ); + std::istream input(&m_proxy_data->read_buf); + + m_proxy_data->res.consume(input); + + if (!m_proxy_data->res.headers_ready()) { + // we read until the headers were done in theory but apparently + // they aren't. Internal endpoint error. + callback(make_error_code(error::general)); + return; + } + + m_alog->write(log::alevel::devel, m_proxy_data->res.raw()); + + if (m_proxy_data->res.get_status_code() != http::status_code::ok) { + // got an error response back + // TODO: expose this error in a programmatically accessible way? + // if so, see below for an option on how to do this. + std::stringstream s; + s << "Proxy connection error: " << m_proxy_data->res.get_status_code() + << " (" << m_proxy_data->res.get_status_msg() << ")"; + m_elog->write(log::elevel::info, s.str()); + callback(make_error_code(error::proxy_failed)); + return; + } + + // we have successfully established a connection to the proxy, now + // we can continue and the proxy will transparently forward the + // WebSocket connection. + + // TODO: decide if we want an on_proxy callback that would allow + // access to the proxy response. + + // free the proxy buffers and req/res objects as they aren't needed + // anymore + m_proxy_data.reset(); + + // Continue with post proxy initialization + post_init(callback); + } + } + + /// read at least num_bytes bytes into buf and then call handler. + void async_read_at_least(size_t num_bytes, char* buf, size_t len, + read_handler handler) { + if (m_alog->static_test(log::alevel::devel)) { + std::stringstream s; + s << "asio async_read_at_least: " << num_bytes; + m_alog->write(log::alevel::devel, s.str()); } - /// Post init timeout callback - /** - * The timer pointer is included to ensure the timer isn't destroyed until - * after it has expired. - * - * @param post_timer Pointer to the timer in question - * @param callback The function to call back - * @param ec The status code - */ - void handle_post_init_timeout(timer_ptr, init_handler callback, - lib::error_code const & ec) - { - lib::error_code ret_ec; - - if (ec) { - if (ec == transport::error::operation_aborted) { - m_alog->write(log::alevel::devel, - "asio post init timer cancelled"); - return; - } - - log_err(log::elevel::devel,"asio handle_post_init_timeout",ec); - ret_ec = ec; - } else { - if (socket_con_type::get_ec()) { - ret_ec = socket_con_type::get_ec(); - } else { - ret_ec = make_error_code(transport::error::timeout); - } - } - - m_alog->write(log::alevel::devel, "Asio transport post-init timed out"); - cancel_socket_checked(); - callback(ret_ec); - } - - /// Post init timeout callback - /** - * The timer pointer is included to ensure the timer isn't destroyed until - * after it has expired. - * - * @param post_timer Pointer to the timer in question - * @param callback The function to call back - * @param ec The status code - */ - void handle_post_init(timer_ptr post_timer, init_handler callback, - lib::error_code const & ec) - { - if (ec == transport::error::operation_aborted || - (post_timer && lib::asio::is_neg(post_timer->expires_from_now()))) - { - m_alog->write(log::alevel::devel,"post_init cancelled"); - return; - } - - if (post_timer) { - post_timer->cancel(); - } - - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection handle_post_init"); - } - - if (m_tcp_post_init_handler) { - m_tcp_post_init_handler(m_connection_hdl); - } - - callback(ec); - } - - void proxy_write(init_handler callback) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection proxy_write"); - } - - if (!m_proxy_data) { - m_elog->write(log::elevel::library, - "assertion failed: !m_proxy_data in asio::connection::proxy_write"); - callback(make_error_code(error::general)); - return; - } - - m_proxy_data->write_buf = m_proxy_data->req.raw(); - - m_bufs.push_back(lib::asio::buffer(m_proxy_data->write_buf.data(), - m_proxy_data->write_buf.size())); - - m_alog->write(log::alevel::devel,m_proxy_data->write_buf); - - // Set a timer so we don't wait forever for the proxy to respond - m_proxy_data->timer = this->set_timer( - m_proxy_data->timeout_proxy, - lib::bind( - &type::handle_proxy_timeout, - get_shared(), - callback, - lib::placeholders::_1 - ) - ); - - // Send proxy request - if (config::enable_multithreading) { - lib::asio::async_write( - socket_con_type::get_next_layer(), - m_bufs, - m_strand->wrap(lib::bind( - &type::handle_proxy_write, get_shared(), - callback, - lib::placeholders::_1 - )) - ); - } else { - lib::asio::async_write( - socket_con_type::get_next_layer(), - m_bufs, - lib::bind( - &type::handle_proxy_write, get_shared(), - callback, - lib::placeholders::_1 - ) - ); - } - } - - void handle_proxy_timeout(init_handler callback, lib::error_code const & ec) - { - if (ec == transport::error::operation_aborted) { - m_alog->write(log::alevel::devel, - "asio handle_proxy_write timer cancelled"); - return; - } else if (ec) { - log_err(log::elevel::devel,"asio handle_proxy_write",ec); - callback(ec); - } else { - m_alog->write(log::alevel::devel, - "asio handle_proxy_write timer expired"); - cancel_socket_checked(); - callback(make_error_code(transport::error::timeout)); - } - } - - void handle_proxy_write(init_handler callback, - lib::asio::error_code const & ec) - { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel, - "asio connection handle_proxy_write"); - } - - m_bufs.clear(); - - // Timer expired or the operation was aborted for some reason. - // Whatever aborted it will be issuing the callback so we are safe to - // return - if (ec == lib::asio::error::operation_aborted || - lib::asio::is_neg(m_proxy_data->timer->expires_from_now())) - { - m_elog->write(log::elevel::devel,"write operation aborted"); - return; - } - - if (ec) { - log_err(log::elevel::info,"asio handle_proxy_write",ec); - m_proxy_data->timer->cancel(); - callback(make_error_code(error::pass_through)); - return; - } - - proxy_read(callback); - } - - void proxy_read(init_handler callback) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection proxy_read"); - } - - if (!m_proxy_data) { - m_elog->write(log::elevel::library, - "assertion failed: !m_proxy_data in asio::connection::proxy_read"); - m_proxy_data->timer->cancel(); - callback(make_error_code(error::general)); - return; - } - - if (config::enable_multithreading) { - lib::asio::async_read_until( - socket_con_type::get_next_layer(), - m_proxy_data->read_buf, - "\r\n\r\n", - m_strand->wrap(lib::bind( - &type::handle_proxy_read, get_shared(), - callback, - lib::placeholders::_1, lib::placeholders::_2 - )) - ); - } else { - lib::asio::async_read_until( - socket_con_type::get_next_layer(), - m_proxy_data->read_buf, - "\r\n\r\n", - lib::bind( - &type::handle_proxy_read, get_shared(), - callback, - lib::placeholders::_1, lib::placeholders::_2 - ) - ); - } - } - - /// Proxy read callback - /** - * @param init_handler The function to call back - * @param ec The status code - * @param bytes_transferred The number of bytes read - */ - void handle_proxy_read(init_handler callback, - lib::asio::error_code const & ec, size_t) - { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel, - "asio connection handle_proxy_read"); - } - - // Timer expired or the operation was aborted for some reason. - // Whatever aborted it will be issuing the callback so we are safe to - // return - if (ec == lib::asio::error::operation_aborted || - lib::asio::is_neg(m_proxy_data->timer->expires_from_now())) - { - m_elog->write(log::elevel::devel,"read operation aborted"); - return; - } - - // At this point there is no need to wait for the timer anymore - m_proxy_data->timer->cancel(); - - if (ec) { - m_elog->write(log::elevel::info, - "asio handle_proxy_read error: "+ec.message()); - callback(make_error_code(error::pass_through)); - } else { - if (!m_proxy_data) { - m_elog->write(log::elevel::library, - "assertion failed: !m_proxy_data in asio::connection::handle_proxy_read"); - callback(make_error_code(error::general)); - return; - } - - std::istream input(&m_proxy_data->read_buf); - - m_proxy_data->res.consume(input); - - if (!m_proxy_data->res.headers_ready()) { - // we read until the headers were done in theory but apparently - // they aren't. Internal endpoint error. - callback(make_error_code(error::general)); - return; - } - - m_alog->write(log::alevel::devel,m_proxy_data->res.raw()); - - if (m_proxy_data->res.get_status_code() != http::status_code::ok) { - // got an error response back - // TODO: expose this error in a programmatically accessible way? - // if so, see below for an option on how to do this. - std::stringstream s; - s << "Proxy connection error: " - << m_proxy_data->res.get_status_code() - << " (" - << m_proxy_data->res.get_status_msg() - << ")"; - m_elog->write(log::elevel::info,s.str()); - callback(make_error_code(error::proxy_failed)); - return; - } - - // we have successfully established a connection to the proxy, now - // we can continue and the proxy will transparently forward the - // WebSocket connection. - - // TODO: decide if we want an on_proxy callback that would allow - // access to the proxy response. - - // free the proxy buffers and req/res objects as they aren't needed - // anymore - m_proxy_data.reset(); - - // Continue with post proxy initialization - post_init(callback); - } - } - - /// read at least num_bytes bytes into buf and then call handler. - void async_read_at_least(size_t num_bytes, char *buf, size_t len, - read_handler handler) - { - if (m_alog->static_test(log::alevel::devel)) { - std::stringstream s; - s << "asio async_read_at_least: " << num_bytes; - m_alog->write(log::alevel::devel,s.str()); - } - - // TODO: safety vs speed ? - // maybe move into an if devel block - /*if (num_bytes > len) { - m_elog->write(log::elevel::devel, - "asio async_read_at_least error::invalid_num_bytes"); - handler(make_error_code(transport::error::invalid_num_bytes), - size_t(0)); - return; - }*/ - - if (config::enable_multithreading) { - lib::asio::async_read( - socket_con_type::get_socket(), - lib::asio::buffer(buf,len), - lib::asio::transfer_at_least(num_bytes), - m_strand->wrap(make_custom_alloc_handler( - m_read_handler_allocator, - lib::bind( - &type::handle_async_read, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - )) - ); - } else { - lib::asio::async_read( - socket_con_type::get_socket(), - lib::asio::buffer(buf,len), - lib::asio::transfer_at_least(num_bytes), - make_custom_alloc_handler( - m_read_handler_allocator, - lib::bind( - &type::handle_async_read, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - ) - ); - } - - } - - void handle_async_read(read_handler handler, lib::asio::error_code const & ec, - size_t bytes_transferred) - { - m_alog->write(log::alevel::devel, "asio con handle_async_read"); - - // translate asio error codes into more lib::error_codes - lib::error_code tec; - if (ec == lib::asio::error::eof) { - tec = make_error_code(transport::error::eof); - } else if (ec) { - // We don't know much more about the error at this point. As our - // socket/security policy if it knows more: - tec = socket_con_type::translate_ec(ec); - m_tec = ec; - - if (tec == transport::error::tls_error || - tec == transport::error::pass_through) - { - // These are aggregate/catch all errors. Log some human readable - // information to the info channel to give library users some - // more details about why the upstream method may have failed. - log_err(log::elevel::info,"asio async_read_at_least",ec); - } - } - if (handler) { - handler(tec,bytes_transferred); - } else { - // This can happen in cases where the connection is terminated while - // the transport is waiting on a read. - m_alog->write(log::alevel::devel, - "handle_async_read called with null read handler"); - } - } - - /// Initiate a potentially asyncronous write of the given buffer - void async_write(const char* buf, size_t len, write_handler handler) { - m_bufs.push_back(lib::asio::buffer(buf,len)); - - if (config::enable_multithreading) { - lib::asio::async_write( - socket_con_type::get_socket(), - m_bufs, - m_strand->wrap(make_custom_alloc_handler( - m_write_handler_allocator, - lib::bind( - &type::handle_async_write, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - )) - ); - } else { - lib::asio::async_write( - socket_con_type::get_socket(), - m_bufs, - make_custom_alloc_handler( - m_write_handler_allocator, - lib::bind( - &type::handle_async_write, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - ) - ); - } - } - - /// Initiate a potentially asyncronous write of the given buffers - void async_write(std::vector const & bufs, write_handler handler) { - std::vector::const_iterator it; - - for (it = bufs.begin(); it != bufs.end(); ++it) { - m_bufs.push_back(lib::asio::buffer((*it).buf,(*it).len)); - } - - if (config::enable_multithreading) { - lib::asio::async_write( - socket_con_type::get_socket(), - m_bufs, - m_strand->wrap(make_custom_alloc_handler( - m_write_handler_allocator, - lib::bind( - &type::handle_async_write, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - )) - ); - } else { - lib::asio::async_write( - socket_con_type::get_socket(), - m_bufs, - make_custom_alloc_handler( - m_write_handler_allocator, - lib::bind( - &type::handle_async_write, get_shared(), - handler, - lib::placeholders::_1, lib::placeholders::_2 - ) - ) - ); - } - } - - /// Async write callback - /** - * @param ec The status code - * @param bytes_transferred The number of bytes read - */ - void handle_async_write(write_handler handler, lib::asio::error_code const & ec, size_t) { - m_bufs.clear(); - lib::error_code tec; - if (ec) { - log_err(log::elevel::info,"asio async_write",ec); - tec = make_error_code(transport::error::pass_through); - } - if (handler) { - handler(tec); - } else { - // This can happen in cases where the connection is terminated while - // the transport is waiting on a read. - m_alog->write(log::alevel::devel, - "handle_async_write called with null write handler"); - } - } - - /// Set Connection Handle - /** - * See common/connection_hdl.hpp for information - * - * @param hdl A connection_hdl that the transport will use to refer - * to itself - */ - void set_handle(connection_hdl hdl) { - m_connection_hdl = hdl; - socket_con_type::set_handle(hdl); - } - - /// Trigger the on_interrupt handler - /** - * This needs to be thread safe - */ - lib::error_code interrupt(interrupt_handler handler) { - if (config::enable_multithreading) { - m_io_service->post(m_strand->wrap(handler)); - } else { - m_io_service->post(handler); - } - return lib::error_code(); - } - - lib::error_code dispatch(dispatch_handler handler) { - if (config::enable_multithreading) { - m_io_service->post(m_strand->wrap(handler)); - } else { - m_io_service->post(handler); - } - return lib::error_code(); - } - - /*void handle_interrupt(interrupt_handler handler) { - handler(); + // TODO: safety vs speed ? + // maybe move into an if devel block + /*if (num_bytes > len) { + m_elog->write(log::elevel::devel, + "asio async_read_at_least error::invalid_num_bytes"); + handler(make_error_code(transport::error::invalid_num_bytes), + size_t(0)); + return; }*/ - /// close and clean up the underlying socket - void async_shutdown(shutdown_handler callback) { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel,"asio connection async_shutdown"); - } + if (config::enable_multithreading) { + lib::asio::async_read( + socket_con_type::get_socket(), lib::asio::buffer(buf, len), + lib::asio::transfer_at_least(num_bytes), + m_strand->wrap(make_custom_alloc_handler( + m_read_handler_allocator, + lib::bind(&type::handle_async_read, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2)))); + } else { + lib::asio::async_read( + socket_con_type::get_socket(), lib::asio::buffer(buf, len), + lib::asio::transfer_at_least(num_bytes), + make_custom_alloc_handler( + m_read_handler_allocator, + lib::bind(&type::handle_async_read, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2))); + } + } - timer_ptr shutdown_timer; - shutdown_timer = set_timer( - config::timeout_socket_shutdown, - lib::bind( - &type::handle_async_shutdown_timeout, - get_shared(), - shutdown_timer, - callback, - lib::placeholders::_1 - ) - ); + void handle_async_read(read_handler handler, lib::asio::error_code const& ec, + size_t bytes_transferred) { + m_alog->write(log::alevel::devel, "asio con handle_async_read"); - socket_con_type::async_shutdown( - lib::bind( - &type::handle_async_shutdown, - get_shared(), - shutdown_timer, - callback, - lib::placeholders::_1 - ) - ); + // translate asio error codes into more lib::error_codes + lib::error_code tec; + if (ec == lib::asio::error::eof) { + tec = make_error_code(transport::error::eof); + } else if (ec) { + // We don't know much more about the error at this point. As our + // socket/security policy if it knows more: + tec = socket_con_type::translate_ec(ec); + m_tec = ec; + + if (tec == transport::error::tls_error || + tec == transport::error::pass_through) { + // These are aggregate/catch all errors. Log some human readable + // information to the info channel to give library users some + // more details about why the upstream method may have failed. + log_err(log::elevel::info, "asio async_read_at_least", ec); + } + } + if (handler) { + handler(tec, bytes_transferred); + } else { + // This can happen in cases where the connection is terminated while + // the transport is waiting on a read. + m_alog->write(log::alevel::devel, + "handle_async_read called with null read handler"); + } + } + + /// Initiate a potentially asyncronous write of the given buffer + void async_write(const char* buf, size_t len, write_handler handler) { + m_bufs.push_back(lib::asio::buffer(buf, len)); + + if (config::enable_multithreading) { + lib::asio::async_write( + socket_con_type::get_socket(), m_bufs, + m_strand->wrap(make_custom_alloc_handler( + m_write_handler_allocator, + lib::bind(&type::handle_async_write, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2)))); + } else { + lib::asio::async_write( + socket_con_type::get_socket(), m_bufs, + make_custom_alloc_handler( + m_write_handler_allocator, + lib::bind(&type::handle_async_write, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2))); + } + } + + /// Initiate a potentially asyncronous write of the given buffers + void async_write(std::vector const& bufs, write_handler handler) { + std::vector::const_iterator it; + + for (it = bufs.begin(); it != bufs.end(); ++it) { + m_bufs.push_back(lib::asio::buffer((*it).buf, (*it).len)); } - /// Async shutdown timeout handler - /** - * @param shutdown_timer A pointer to the timer to keep it in scope - * @param callback The function to call back - * @param ec The status code - */ - void handle_async_shutdown_timeout(timer_ptr, init_handler callback, - lib::error_code const & ec) - { - lib::error_code ret_ec; + if (config::enable_multithreading) { + lib::asio::async_write( + socket_con_type::get_socket(), m_bufs, + m_strand->wrap(make_custom_alloc_handler( + m_write_handler_allocator, + lib::bind(&type::handle_async_write, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2)))); + } else { + lib::asio::async_write( + socket_con_type::get_socket(), m_bufs, + make_custom_alloc_handler( + m_write_handler_allocator, + lib::bind(&type::handle_async_write, get_shared(), handler, + lib::placeholders::_1, lib::placeholders::_2))); + } + } - if (ec) { - if (ec == transport::error::operation_aborted) { - m_alog->write(log::alevel::devel, - "asio socket shutdown timer cancelled"); - return; - } + /// Async write callback + /** + * @param ec The status code + * @param bytes_transferred The number of bytes read + */ + void handle_async_write(write_handler handler, + lib::asio::error_code const& ec, size_t) { + m_bufs.clear(); + lib::error_code tec; + if (ec) { + log_err(log::elevel::info, "asio async_write", ec); + tec = make_error_code(transport::error::pass_through); + } + if (handler) { + handler(tec); + } else { + // This can happen in cases where the connection is terminated while + // the transport is waiting on a read. + m_alog->write(log::alevel::devel, + "handle_async_write called with null write handler"); + } + } - log_err(log::elevel::devel,"asio handle_async_shutdown_timeout",ec); - ret_ec = ec; - } else { - ret_ec = make_error_code(transport::error::timeout); - } + /// Set Connection Handle + /** + * See common/connection_hdl.hpp for information + * + * @param hdl A connection_hdl that the transport will use to refer + * to itself + */ + void set_handle(connection_hdl hdl) { + m_connection_hdl = hdl; + socket_con_type::set_handle(hdl); + } + /// Trigger the on_interrupt handler + /** + * This needs to be thread safe + */ + lib::error_code interrupt(interrupt_handler handler) { + if (config::enable_multithreading) { + m_io_service->post(m_strand->wrap(handler)); + } else { + m_io_service->post(handler); + } + return lib::error_code(); + } + + lib::error_code dispatch(dispatch_handler handler) { + if (config::enable_multithreading) { + m_io_service->post(m_strand->wrap(handler)); + } else { + m_io_service->post(handler); + } + return lib::error_code(); + } + + /*void handle_interrupt(interrupt_handler handler) { + handler(); + }*/ + + /// close and clean up the underlying socket + void async_shutdown(shutdown_handler callback) { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio connection async_shutdown"); + } + + timer_ptr shutdown_timer; + shutdown_timer = + set_timer(config::timeout_socket_shutdown, + lib::bind(&type::handle_async_shutdown_timeout, get_shared(), + shutdown_timer, callback, lib::placeholders::_1)); + + socket_con_type::async_shutdown(lib::bind(&type::handle_async_shutdown, + get_shared(), shutdown_timer, + callback, lib::placeholders::_1)); + } + + /// Async shutdown timeout handler + /** + * @param shutdown_timer A pointer to the timer to keep it in scope + * @param callback The function to call back + * @param ec The status code + */ + void handle_async_shutdown_timeout(timer_ptr, init_handler callback, + lib::error_code const& ec) { + lib::error_code ret_ec; + + if (ec) { + if (ec == transport::error::operation_aborted) { m_alog->write(log::alevel::devel, - "Asio transport socket shutdown timed out"); - cancel_socket_checked(); - callback(ret_ec); + "asio socket shutdown timer cancelled"); + return; + } + + log_err(log::elevel::devel, "asio handle_async_shutdown_timeout", ec); + ret_ec = ec; + } else { + ret_ec = make_error_code(transport::error::timeout); } - void handle_async_shutdown(timer_ptr shutdown_timer, shutdown_handler - callback, lib::asio::error_code const & ec) - { - if (ec == lib::asio::error::operation_aborted || - lib::asio::is_neg(shutdown_timer->expires_from_now())) - { - m_alog->write(log::alevel::devel,"async_shutdown cancelled"); - return; - } + m_alog->write(log::alevel::devel, + "Asio transport socket shutdown timed out"); + cancel_socket_checked(); + callback(ret_ec); + } - shutdown_timer->cancel(); - - lib::error_code tec; - if (ec) { - if (ec == lib::asio::error::not_connected) { - // The socket was already closed when we tried to close it. This - // happens periodically (usually if a read or write fails - // earlier and if it is a real error will be caught at another - // level of the stack. - } else { - // We don't know anything more about this error, give our - // socket/security policy a crack at it. - tec = socket_con_type::translate_ec(ec); - m_tec = ec; - - // all other errors are effectively pass through errors of - // some sort so print some detail on the info channel for - // library users to look up if needed. - log_err(log::elevel::info,"asio async_shutdown",ec); - } - } else { - if (m_alog->static_test(log::alevel::devel)) { - m_alog->write(log::alevel::devel, - "asio con handle_async_shutdown"); - } - } - callback(tec); + void handle_async_shutdown(timer_ptr shutdown_timer, + shutdown_handler callback, + lib::asio::error_code const& ec) { + if (ec == lib::asio::error::operation_aborted || + lib::asio::is_neg(shutdown_timer->expires_from_now())) { + m_alog->write(log::alevel::devel, "async_shutdown cancelled"); + return; } - /// Cancel the underlying socket and log any errors - void cancel_socket_checked() { - lib::asio::error_code cec = socket_con_type::cancel_socket(); - if (cec) { - if (cec == lib::asio::error::operation_not_supported) { - // cancel not supported on this OS, ignore and log at dev level - m_alog->write(log::alevel::devel, "socket cancel not supported"); - } else { - log_err(log::elevel::warn, "socket cancel failed", cec); - } - } + shutdown_timer->cancel(); + + lib::error_code tec; + if (ec) { + if (ec == lib::asio::error::not_connected) { + // The socket was already closed when we tried to close it. This + // happens periodically (usually if a read or write fails + // earlier and if it is a real error will be caught at another + // level of the stack. + } else { + // We don't know anything more about this error, give our + // socket/security policy a crack at it. + tec = socket_con_type::translate_ec(ec); + m_tec = ec; + + // all other errors are effectively pass through errors of + // some sort so print some detail on the info channel for + // library users to look up if needed. + log_err(log::elevel::info, "asio async_shutdown", ec); + } + } else { + if (m_alog->static_test(log::alevel::devel)) { + m_alog->write(log::alevel::devel, "asio con handle_async_shutdown"); + } } + callback(tec); + } -private: - /// Convenience method for logging the code and message for an error_code - template - void log_err(log::level l, const char * msg, const error_type & ec) { - std::stringstream s; - s << msg << " error: " << ec << " (" << ec.message() << ")"; - m_elog->write(l,s.str()); + /// Cancel the underlying socket and log any errors + void cancel_socket_checked() { + lib::asio::error_code cec = socket_con_type::cancel_socket(); + if (cec) { + if (cec == lib::asio::error::operation_not_supported) { + // cancel not supported on this OS, ignore and log at dev level + m_alog->write(log::alevel::devel, "socket cancel not supported"); + } else { + log_err(log::elevel::warn, "socket cancel failed", cec); + } } + } - // static settings - const bool m_is_server; - lib::shared_ptr m_alog; - lib::shared_ptr m_elog; + private: + /// Convenience method for logging the code and message for an error_code + template + void log_err(log::level l, const char* msg, const error_type& ec) { + std::stringstream s; + s << msg << " error: " << ec << " (" << ec.message() << ")"; + m_elog->write(l, s.str()); + } - struct proxy_data { - proxy_data() : timeout_proxy(config::timeout_proxy) {} + // static settings + const bool m_is_server; + lib::shared_ptr m_alog; + lib::shared_ptr m_elog; - request_type req; - response_type res; - std::string write_buf; - lib::asio::streambuf read_buf; - long timeout_proxy; - timer_ptr timer; - }; + struct proxy_data { + proxy_data() : timeout_proxy(config::timeout_proxy) {} - std::string m_proxy; - lib::shared_ptr m_proxy_data; + request_type req; + response_type res; + std::string write_buf; + lib::asio::streambuf read_buf; + long timeout_proxy; + timer_ptr timer; + }; - // transport resources - io_service_ptr m_io_service; - strand_ptr m_strand; - connection_hdl m_connection_hdl; + std::string m_proxy; + lib::shared_ptr m_proxy_data; - std::vector m_bufs; + // transport resources + io_service_ptr m_io_service; + strand_ptr m_strand; + connection_hdl m_connection_hdl; - /// Detailed internal error code - lib::asio::error_code m_tec; + std::vector m_bufs; - // Handlers - tcp_init_handler m_tcp_pre_init_handler; - tcp_init_handler m_tcp_post_init_handler; + /// Detailed internal error code + lib::asio::error_code m_tec; - handler_allocator m_read_handler_allocator; - handler_allocator m_write_handler_allocator; + // Handlers + tcp_init_handler m_tcp_pre_init_handler; + tcp_init_handler m_tcp_post_init_handler; + + handler_allocator m_read_handler_allocator; + handler_allocator m_write_handler_allocator; }; +} // namespace asio +} // namespace transport +} // namespace websocketpp -} // namespace asio -} // namespace transport -} // namespace websocketpp - -#endif // WEBSOCKETPP_TRANSPORT_ASIO_CON_HPP +#endif // WEBSOCKETPP_TRANSPORT_ASIO_CON_HPP +#pragma warning(pop) \ No newline at end of file diff --git a/xmake.lua b/xmake.lua index 31ce80a..041667e 100644 --- a/xmake.lua +++ b/xmake.lua @@ -22,8 +22,7 @@ includes("thirdparty") if is_os("windows") then add_defines("_WEBSOCKETPP_CPP11_INTERNAL_") - -- add_cxflags("/W4", "/WX") - add_cxflags("/W4") + add_cxflags("/W4", "/WX") elseif is_os("linux") then add_requires("glib", {system = true}) add_packages("glib") @@ -176,8 +175,8 @@ target("media") target("qos") set_kind("object") add_deps("log") - add_files("src/qos/kcp/*.c") - add_includedirs("src/qos/kcp", {public = true}) + add_files("src/qos/*.cpp") + add_includedirs("src/qos", {public = true}) target("statistics") set_kind("object")