﻿#include "wrapclient.h"
#include "rtspwrapper.h"
#include "winssl.h"
#include "Base64.hh"
#include <GroupsockHelper.hh>
#include "CryptoSockets.h"

#include <sstream>
#include <map>
#include <wincrypt.h>

// Builds the underlying RTSPClient from the wrapper's settings (environment,
// RTSP address, verbosity level, tunnel port) and then points the client's
// socket URL at the wrapper's UriAddress.
CWrapRTSPClient::CWrapRTSPClient(CLive555Wrapper *pWrapper)
  : RTSPClient(*pWrapper->Environment,
               pWrapper->RtspAddress.c_str(),
               pWrapper->VerbosityLevel,
               APPLICATION_NAME,
               pWrapper->TunnelPort,
               -1),
    pWrapper(pWrapper)
{
  const char* socketUrl = pWrapper->UriAddress.c_str();
  setSocketURL(socketUrl);
}

// Passes Data (text this client sends, e.g. the WebSocket handshake request)
// to the wrapper's SendNotify callback together with NotifyData and
// CurrentStep. Does nothing when no callback is registered.
void CWrapRTSPClient::ReflectSend(const char* Data)
{
  if (!pWrapper->SendNotify) {
    return;
  }
  pWrapper->SendNotify(pWrapper->NotifyData, pWrapper->CurrentStep, Data);
}

// Passes Data (text this client received, e.g. the WebSocket handshake
// response) to the wrapper's ReceiveNotify callback together with NotifyData
// and CurrentStep. Does nothing when no callback is registered.
void CWrapRTSPClient::ReflectReceive(const char* Data)
{
  if (!pWrapper->ReceiveNotify) {
    return;
  }
  pWrapper->ReceiveNotify(pWrapper->NotifyData, pWrapper->CurrentStep, Data);
}

// NOTE(review): these two-argument free-function declarations have no
// definition in the code visible here. The unqualified calls in
// ReflectSocketHandshake below resolve to the one-argument member functions
// CWrapRTSPClient::WebsocketHandshake/HttpsHandshake instead. They look
// unused; confirm against the rest of the project before removing.
bool WebsocketHandshake(int Socket, const std::string &key);
bool HttpsHandshake(int Socket, const std::string &key);

// Runs the transport-specific handshake on a freshly connected socket:
// WebSocket upgrade for ELT_WEBSOCKET, TLS handshake for ELT_HTTPS. Any other
// transport needs no extra handshake and is reported as successful.
bool CWrapRTSPClient::ReflectSocketHandshake(int Socket)
{
  bool ok = true;
  if (pWrapper->Transport == ELT_WEBSOCKET)
    ok = WebsocketHandshake(Socket);
  else if (pWrapper->Transport == ELT_HTTPS)
    ok = HttpsHandshake(Socket);
  return ok;
}

// Parsed HTTP response header: maps field name -> value. Names are stored
// exactly as received (case-sensitive keys). Also keeps the status line and
// the numeric status code.
struct HttpHeader : public std::map<std::string, std::string>
{
  HttpHeader() : Code(-1) {}

  std::string Firstline;  // status line as received (may keep a trailing '\r')
  int Code;               // status code; -1 until a status line is parsed
};

// Parses the NUL-terminated HTTP/1.1 response header in Buffer into header.
// Lines may end in CRLF or a bare LF, and an empty line ends the header block.
// Leading/trailing spaces and tabs around a field value are removed.
// Returns false when the status line is not "HTTP/1.1 <code>" (Code is then
// -1) or when a header line has no ':'. Returns true otherwise, including
// when the input ends before the blank line.
bool GetHeader(HttpHeader& header, char *Buffer)
{
  std::stringstream ss;
  ss.str(Buffer);
  std::getline(ss, header.Firstline, '\n');

  if (1 != sscanf(header.Firstline.c_str(), "HTTP/1.1 %i", &header.Code)) {
    header.Code = -1;
    return false;
  }

  static const char kWhitespace[] = " \t";
  std::string item;
  while (std::getline(ss, item, '\n')) {
    // getline() already dropped the '\n'; also drop the '\r' of CRLF.
    if (!item.empty() && item[item.size() - 1] == '\r') {
      item.erase(item.size() - 1);
    }
    if (item.empty()) { // blank line: end of header
      return true;
    }

    const std::string::size_type colon = item.find(':');
    if (colon == std::string::npos) {
      return false;
    }

    // Trim optional whitespace around the value rather than assuming exactly
    // one space after ':' ("Name:value" and a bare "Name:" are both valid).
    std::string value;
    const std::string::size_type first = item.find_first_not_of(kWhitespace, colon + 1);
    if (first != std::string::npos) {
      const std::string::size_type last = item.find_last_not_of(kWhitespace);
      value = item.substr(first, last - first + 1);
    }
    header[item.substr(0, colon)] = value;
  }
  return true;
}

// Replaces str with the base64 encoding of the SHA-1 digest of its bytes,
// computed with the WinCrypt API. WebsocketHandshake uses this to derive the
// expected Sec-WebSocket-Accept value.
// Returns true on success. Returns false if any WinCrypt or encoding step
// fails; str is then left unchanged.
bool DoSha1(std::string &str)
{
  HCRYPTPROV hProv = 0;
  HCRYPTHASH hHash = 0;
  bool success = false;
  do {
    if (!CryptAcquireContext(&hProv,
      NULL,
      NULL,
      PROV_RSA_FULL,
      CRYPT_VERIFYCONTEXT))
      break;

    if (!CryptCreateHash(hProv, CALG_SHA1, 0, 0, &hHash))
      break;

    if (!CryptHashData(hHash, (const BYTE *)str.c_str(), (DWORD)str.size(), 0))
      break;

    BYTE rgbHash[20];  // SHA-1 digest size
    DWORD cbHash = sizeof(rgbHash);
    if (!CryptGetHashParam(hHash, HP_HASHVAL, rgbHash, &cbHash, 0))
      break;

    // live555's base64Encode() returns a new[]-allocated string that the
    // caller must delete[]; copy it into str and free it.
    char* encoded = base64Encode((const char*)rgbHash, cbHash);
    if (encoded == NULL)
      break;
    str = encoded;
    delete[] encoded;
    success = true;
  } while (false);

  // Release WinCrypt handles on every path.
  if (hHash) CryptDestroyHash(hHash);
  if (hProv) CryptReleaseContext(hProv, 0);
  return success;
}

struct EncryptedSocketWS : public EncryptedSocket {
  int poolIndex;
  int remainingData;
  std::vector<char> pool_;
  EncryptedSocketWS(int Socket) : EncryptedSocket(Socket), poolIndex(0), remainingData(0)
  {

  };
  ~EncryptedSocketWS()
  {
  };
  int send(int s, const char * buf, int len, int flags)
  {
    //create WS frame according to RFC 6455 https://tools.ietf.org/html/rfc6455
    char wsFrameBuf[8] = { 0 };

    wsFrameBuf[0] = (char)0x82;

    if (len > 125) {
      wsFrameBuf[1] = (char)0xfe;
      *(short*)(&wsFrameBuf[2]) = htons(len);
    }
    else {
      wsFrameBuf[1] = 0x80 + (len & 0x7f);
    }
    char* mask = (wsFrameBuf + (len > 125 ? 4 : 2));

    for (int i = 0; i < 4; i++) {
      *(mask + i) = rand() & 0xff;
    }

    ::send(Socket_, wsFrameBuf, (len > 125 ? 8 : 6), flags);

    std::vector<char> maskedData(len);

    for (int i = 0; i < len; i++) {
      maskedData[i] = buf[i] ^ *(mask + i % 4);
    }

    return ::send(Socket_, &maskedData.front(), maskedData.size(), flags);
  };
  int recv(int s, char * buf, int len, int flags)
  {
    char sab[40];
    int size = sizeof(sab);
    return recvfrom(s, buf, len, flags, sab, &size);
  };
  int recvfrom(int s, char * buf, int len, int flags, void* adr, int*adrlen)
  {
    while (true) {
      int result = 0;
      if (!pool_.empty())
      {
        len = (int)pool_.size() > len ? len : pool_.size();

        memcpy(buf, (char*)&pool_.front(), len);
        pool_.erase(pool_.begin(), pool_.begin() + len);

        return len;
      }

      if (0 == remainingData) { //parse WS frame structure
        std::vector<char> tmp(2);
        int size = receive(&tmp.front(), tmp.size(), flags, (sockaddr *)adr, adrlen);

        if (tmp.size() != size) {
          return size;
        }

        if ((tmp[1] & 0x7f) > 125) {
          size = receive(&tmp.front(), tmp.size(), flags, (sockaddr *)adr, adrlen);
          if (tmp.size() != size) {
            return size;
          }
          remainingData = htons(*((short*)&tmp[0]));
        }
        else {
          remainingData = tmp[1];
        }

        if (0 == remainingData && (tmp[0] & 0x08)) { //WS Close message
          continue;
        }
      }

      std::vector<char> tmp(len < remainingData ? len : remainingData);

      int received = receive(&tmp.front(), tmp.size(), flags, (sockaddr *)adr, adrlen);

      if (received <= 0) {
        remainingData = 0;
        return received;
      }

      pool_.insert(pool_.end(), tmp.begin(), tmp.begin() + received);
      remainingData -= received;
    }
  }

  int receive(char* buf, int len, int flags, void* adr, int*adrlen) {
    while (true) {
      int size = ::recvfrom(Socket_, buf, len, flags, (sockaddr *)adr, adrlen);
      if (size == SOCKET_ERROR && WSAGetLastError() == EAGAIN) {
        Sleep(5);
        continue;
      }
      return size;
    }
  }
};

bool CWrapRTSPClient::WebsocketHandshake(int Socket)
{
  WriteLog("Performing WebSocket handshake...");
  // see https://tools.ietf.org/html/rfc6455#section-4

  const char HandshakeTemplate[] =
    "GET %s HTTP/1.1\r\n"                             // local uri
    "Host: %s\r\n"                                    // host ip
    "Upgrade: websocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key: %s\r\n"                       // base64 hash key
    "Origin: %s\r\n"                                  // origin uri
    "Sec-WebSocket-Protocol: rtsp.onvif.org\r\n"
    "Sec-WebSocket-Version: 13\r\n\r\n";

  char intKey[16];
  for (int i = 0; i < 16; ++i) intKey[i] = rand() & 0xff;

  char* username;
  char* password;
  const char* cmdURL;
  NetAddress destAddress;
  portNumBits urlPortNum;

  if (!parseRTSPURL(envir(), pWrapper->UriAddress.c_str(), username, password, destAddress, urlPortNum, &cmdURL))
  {
    WriteLog("Can't parse Uri address [%s]", pWrapper->UriAddress.c_str());
    return false;
  }

  char Buffer[2048];
  std::string uri, ip, origin, key, rep;
  uri = cmdURL;
  if (uri.empty()) uri = "/";
  ip = inet_ntop(destAddress.af(), (PVOID)destAddress.data(), Buffer, sizeof(Buffer));
  origin = "http://" + ip;
  key = base64Encode(intKey, 16);
  rep = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  if (!DoSha1(rep))
  {
    WriteLog("Error with WinCrypt SHA1");
    return false;
  }

  int size;

  size = _snprintf(Buffer, sizeof(Buffer), HandshakeTemplate,
    uri.c_str(),
    ip.c_str(),
    key.c_str(),
    origin.c_str());

  send(Socket, Buffer, size, 0);
  ReflectSend(Buffer);
  memset(Buffer, 0, sizeof(Buffer));



  while (true) {
    size = ::recv(Socket, Buffer, sizeof(Buffer), 0);
    if (size == SOCKET_ERROR && WSAGetLastError() == EAGAIN) {
      Sleep(5);
      continue;
    }
    break;
  }

  if (size <= 0) {
    WriteLog("ERROR while receiving HTTP response");
    return false;
  }

  Buffer[size] = '\n';
  ReflectReceive(Buffer);

  HttpHeader header;
  if (!GetHeader(header, Buffer)) {
    WriteLog("ERROR while parsing HTTP response header");
    return false;
  }

  if (header.Code == 401) {
    WriteLog("ERROR: received HTTP 401 Unauthorized. An ONVIF compliant device should authenticate an RTSP request at the RTSP level. If HTTP is used to tunnel the RTSP request the device shall not authenticate on the HTTP level.");
    return false;
  }

  if (header.Code != 101) {
    WriteLog("HTTP response code is skipped or not equal to 101");
    return false;
  }

  if (header["Upgrade"] != "websocket") {
    WriteLog("HTTP response header Upgrade is skipped or not equal to websocket");
    return false;
  }
  if (header["Connection"] != "Upgrade") {
    WriteLog("HTTP response header Connection is skipped or not equal to Upgrade");
    return false;
  }
  if (header["Sec-WebSocket-Accept"] != rep) {
    WriteLog("HTTP response header Sec-WebSocket-Accept is skipped or wrong");
    return false;
  }
  if (header["Sec-WebSocket-Protocol"] != "rtsp.onvif.org") {
    WriteLog("HTTP response header Sec-WebSocket-Protocol is skipped or not equal to rtsp.onvif.org");
    return false;
  }

  WriteLog("Performing WebSocket handshake done");
  EncryptedSocketWS* w = new EncryptedSocketWS(Socket);
  _sockets.push_back(w);
  return true;
}

// SSL/TLS transport wrapper built on WinCryptSock.
// - The constructor creates credentials and runs the handshake; callers check
//   the outcome with hasError()/getError().
// - send() encrypts before writing to the socket.
// - recv()/recvfrom() decrypt into an internal pool and serve reads from it.
struct EncryptedSocketSSL : public EncryptedSocket {
  WinCryptSock w_;          // crypto session bound to the socket
  std::vector<char> pool_;  // decrypted bytes not yet returned to the caller
  int poolIndex_;           // read offset into pool_

  bool hasError() const { return w_.hasError(); };
  const char* getError() const { return w_.getError(); };

  EncryptedSocketSSL(int Socket, const char* ServerName)
    : EncryptedSocket(Socket), w_(Socket, ServerName), poolIndex_(0)
  {
    Socket_ = Socket;
    // The handshake is only attempted once credentials were created.
    if (!w_.CreateCredentials())
      return;
    w_.Handshake();
  };

  ~EncryptedSocketSSL()
  {
  };

  // Encrypts buf and writes the ciphertext in one ::send call.
  // Returns len if the whole ciphertext was written, otherwise -1.
  int send(int s, const char * buf, int len, int flags)
  {
    std::vector<char> cipher;
    const int cipherLen = w_.Encrypt(cipher, buf, len);
    if (cipherLen <= 0)
      return -1;
    if (::send(Socket_, &cipher[0], cipherLen, flags) != cipherLen)
      return -1;
    return len;
  };

  int recv(int s, char * buf, int len, int flags)
  {
    char addrStorage[40];
    int addrLen = sizeof(addrStorage);
    return recvfrom(s, buf, len, flags, addrStorage, &addrLen);
  };

  // Hands out buffered plaintext first. When the pool is empty, reads up to
  // len + 2000 raw bytes, decrypts them into the pool and tries again.
  // Socket results <= 0 and negative Decrypt() results are returned as-is.
  int recvfrom(int s, char * buf, int len, int flags, void* adr, int*adrlen)
  {
    for (;;) {
      if (!pool_.empty())
      {
        const int available = (int)pool_.size() - poolIndex_;
        const int count = len < available ? len : available;
        memcpy(buf, &pool_[0] + poolIndex_, count);
        poolIndex_ += count;
        if (poolIndex_ >= (int)pool_.size())
        {
          pool_.clear();
          poolIndex_ = 0;
        }
        return count;
      }

      std::vector<char> raw(len + 2000);
      const int got = ::recvfrom(Socket_, &raw[0], (int)raw.size(), flags, (sockaddr *)adr, adrlen);
      if (got <= 0)
        return got;

      const int status = w_.Decrypt(pool_, &raw[0], got);
      if (status < 0)
        return status;
      poolIndex_ = 0;
    }
  }

  // True while decrypted bytes remain unread in the pool.
  bool hasMoreData()
  {
    return poolIndex_ < (int)pool_.size();
  }
};

// Runs the TLS handshake for the HTTPS transport on Socket.
// - The server address from UriAddress (as text) is passed to
//   EncryptedSocketSSL as its ServerName.
// - The socket is temporarily made blocking while the wrapper is constructed.
// - The wrapper is always registered in _sockets, so the destructor frees it.
// Returns false (after logging) if the URL cannot be parsed, the address
// cannot be formatted, or the handshake reports an error.
bool CWrapRTSPClient::HttpsHandshake(int Socket)
{
  WriteLog("Performing HTTPS handshake...");
  // live555's parseRTSPURL() hands back username/password as NULL or as
  // new[]-allocated strings the caller must delete[]; they are not needed
  // here, so free them right away.
  char* username = NULL;
  char* password = NULL;
  const char* cmdURL = NULL;
  NetAddress destAddress;
  portNumBits urlPortNum;
  const bool parsed = parseRTSPURL(envir(), pWrapper->UriAddress.c_str(), username, password, destAddress, urlPortNum, &cmdURL) ? true : false;
  delete[] username;
  delete[] password;
  if (!parsed)
  {
    WriteLog("Can't parse Uri address [%s]", pWrapper->UriAddress.c_str());
    return false;
  }

  char Buffer[256];
  const char* host = inet_ntop(destAddress.af(), (PVOID)destAddress.data(), Buffer, sizeof(Buffer));
  if (host == NULL)
  {
    WriteLog("Can't convert server address of [%s] to text", pWrapper->UriAddress.c_str());
    return false;
  }
  std::string ip = host;

  makeSocketBlocking(Socket);
  EncryptedSocketSSL* w = new EncryptedSocketSSL(Socket, ip.c_str());
  _sockets.push_back(w);
  makeSocketNonBlocking(Socket);
  if (w->hasError())
  {
    // Log the error text as an argument, never as the format string: it
    // comes from the TLS layer and may contain '%'.
    const char* error = w->getError();
    WriteLog("%s", error != NULL ? error : "HTTPS handshake failed");
    return false;
  }
  WriteLog("Performing HTTPS handshake done");
  return true;
}

// Deletes every EncryptedSocket wrapper held in _sockets. In this file they
// are allocated by WebsocketHandshake and HttpsHandshake. The list itself is
// left as is.
CWrapRTSPClient::~CWrapRTSPClient()
{
  std::list<EncryptedSocket*>::const_iterator cur = _sockets.begin();
  const std::list<EncryptedSocket*>::const_iterator stop = _sockets.end();
  while (cur != stop)
  {
    delete *cur;
    ++cur;
  }
}

// Calls reset() and then re-applies the wrapper's RtspAddress as the base URL.
void CWrapRTSPClient::Reset()
{
  reset();
  setBaseURL(pWrapper->RtspAddress.c_str());
}

// Returns a mutable reference to the client's fCurrentAuthenticator.
Authenticator& CWrapRTSPClient::Authenticator()
{
  return this->fCurrentAuthenticator;
}

// printf-style logging. Formats Text and its arguments into a 2048-byte
// buffer (truncating longer messages), stores the result as the
// environment's result message and forwards it to the wrapper's WriteLog.
void CWrapRTSPClient::WriteLog(const char* Text, ...)
{
  va_list ap;
  char Buffer[2048];
  va_start(ap, Text);
  _vsnprintf(Buffer, sizeof(Buffer) - 1, Text, ap);
  va_end(ap);
  // _vsnprintf does not NUL-terminate when the message fills or exceeds the
  // count; the last byte was reserved above, so terminate explicitly.
  Buffer[sizeof(Buffer) - 1] = '\0';

  envir().setResultMsg(Buffer);
  pWrapper->WriteLog(Buffer);
}
