﻿#include "winssl.h"
#include <winsock2.h>
#define SECURITY_WIN32
#include <wincrypt.h>
#include <wintrust.h>
#include <schannel.h>
#include <security.h>
//#include <sspi.h>
#include <math.h> 

#pragma comment(lib, "Crypt32.Lib")


/**
 * Creates a TLS client wrapper around an already-connected socket.
 *
 * Loads the SSPI provider DLL and obtains its ANSI function table. If any
 * step of that fails, pSSPI_ is left NULL and Error_ describes the failure
 * instead of calling through a NULL function pointer.
 *
 * @param Socket      socket used by Handshake()/HandshakeLoop() for send/recv.
 * @param ServerName  target host name passed to InitializeSecurityContextA.
 *                    Only the pointer is stored, so the string must outlive
 *                    this object.
 */
WinCryptSock::WinCryptSock(int Socket, const char* ServerName)
{
  Socket_ = Socket;
  ServerName_ = ServerName;
  pbIoBufferOut = NULL;
  BufferSize = 0;
  BeginSize = 0;
  EndSize = 0;
  Error_ = NULL;
  hMyCertStore_ = NULL;
  pSSPI_ = NULL;

  // Mark both SSPI handles as "not created yet". The destructor may run before
  // CreateCredentials()/Handshake(), and it must not free uninitialised handles.
  SecInvalidateHandle((PCredHandle)&hCreds_);
  SecInvalidateHandle((PCtxtHandle)&hContext_);

  secureLib_ = LoadLibrary(TEXT("Security.dll"));
  if (!secureLib_)
  {
    SetError("WinCrypt security library loading error");
    return;
  }

  INIT_SECURITY_INTERFACE pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress(
    (HMODULE)secureLib_,
    "InitSecurityInterfaceA");
  if (!pInitSecurityInterface)
  {
    SetError("WinCrypt security interface loading error");
    return;
  }

  pSSPI_ = pInitSecurityInterface();
  if (!pSSPI_)
  {
    SetError("WinCrypt security interface initialization error");
  }
}

/**
 * Releases every resource the object may own:
 * - the "MY" certificate store opened by CreateCredentials();
 * - the Schannel security context and the credentials handle;
 * - the encryption buffer allocated by PrepareBuffers();
 * - the security DLL loaded by the constructor.
 *
 * The context is deleted before its credentials are freed. Handles marked
 * invalid (never created, or already deleted on an error path) are skipped.
 */
WinCryptSock::~WinCryptSock()
{
  if (hMyCertStore_)
  {
    CertCloseStore(hMyCertStore_, 0);
  }

  if (pSSPI_)
  {
    PSecurityFunctionTableA sspi = (PSecurityFunctionTableA)pSSPI_;

    if (SecIsValidHandle((PCtxtHandle)&hContext_))
    {
      sspi->DeleteSecurityContext((PCtxtHandle)&hContext_);
    }

    // Previously never released: every instance leaked its Schannel credentials.
    if (SecIsValidHandle((PCredHandle)&hCreds_))
    {
      sspi->FreeCredentialsHandle((PCredHandle)&hCreds_);
    }
  }

  delete[] pbIoBufferOut;

  if (secureLib_)
  {
    FreeLibrary((HMODULE)secureLib_);
  }
}

bool WinCryptSock::CreateCredentials()
{
  DWORD   dwProtocol = 0;
  SCHANNEL_CRED   SchannelCred;

  TimeStamp       tsExpiry;
  SECURITY_STATUS Status;

  DWORD           cSupportedAlgs = 0;
  ALG_ID          rgbSupportedAlgs[16];

  PCCERT_CONTEXT  pCertContext = NULL;

  hMyCertStore_ = CertOpenSystemStoreA(0, "MY");

  if (!hMyCertStore_)
  {
    SetError("WinCrypt system sertificate store opening error");
    return SUCCEEDED(SEC_E_NO_CREDENTIALS);
  }


  ZeroMemory(&SchannelCred, sizeof(SchannelCred));

  SchannelCred.dwVersion = SCHANNEL_CRED_VERSION;
  if (pCertContext)
  {
    SchannelCred.cCreds = 1;
    SchannelCred.paCred = &pCertContext;
  }

  SchannelCred.grbitEnabledProtocols = dwProtocol;

  if (cSupportedAlgs)
  {
    SchannelCred.cSupportedAlgs = cSupportedAlgs;
    SchannelCred.palgSupportedAlgs = rgbSupportedAlgs;
  }

  SchannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;

  SchannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;

  Status = ((PSecurityFunctionTableA)pSSPI_)->AcquireCredentialsHandleA(
    NULL,                   // Name of principal    
    UNISP_NAME_A,           // Name of package
    SECPKG_CRED_OUTBOUND,   // Flags indicating use
    NULL,                   // Pointer to logon ID
    &SchannelCred,          // Package specific data
    NULL,                   // Pointer to GetKey() func
    NULL,                   // Value to pass to GetKey()
    (PCredHandle)&hCreds_,  // (out) Cred Handle
    &tsExpiry);             // (out) Lifetime (optional)
  if (Status != SEC_E_OK)
  {
    SetError("WinCrypt system sertificate acquire error");
    goto cleanup;
  }

cleanup:

  if (pCertContext)
  {
    CertFreeCertificateContext(pCertContext);
  }

  return SUCCEEDED(Status);
}

bool WinCryptSock::Handshake()
{
  CtxtHandle      &hContext = *(CtxtHandle*)&hContext_;
  SecBufferDesc   OutBuffer;
  SecBuffer       OutBuffers[1];
  DWORD           dwSSPIFlags;
  unsigned long   dwSSPIOutFlags;
  TimeStamp       tsExpiry;
  SECURITY_STATUS scRet;
  DWORD           cbData;

  dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
    ISC_REQ_REPLAY_DETECT |
    ISC_REQ_CONFIDENTIALITY |
    ISC_RET_EXTENDED_ERROR |
    ISC_REQ_ALLOCATE_MEMORY |
    ISC_REQ_STREAM;

  OutBuffers[0].pvBuffer = NULL;
  OutBuffers[0].BufferType = SECBUFFER_TOKEN;
  OutBuffers[0].cbBuffer = 0;

  OutBuffer.cBuffers = 1;
  OutBuffer.pBuffers = OutBuffers;
  OutBuffer.ulVersion = SECBUFFER_VERSION;

  scRet = ((PSecurityFunctionTableA)pSSPI_)->InitializeSecurityContextA(
    (PCredHandle)&hCreds_,
    NULL,
    (char*)ServerName_,
    dwSSPIFlags,
    0,
    SECURITY_NATIVE_DREP,
    NULL,
    0,
    &hContext,
    &OutBuffer,
    &dwSSPIOutFlags,
    &tsExpiry);

  if (scRet != SEC_I_CONTINUE_NEEDED)
  {
    SetError("WinCrypt context initialization error");
    return false;
  }

  if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
  {
    cbData = send(Socket_,
      (const char*)OutBuffers[0].pvBuffer,
      OutBuffers[0].cbBuffer,
      0);

    if (cbData <= 0)
    {
      ((PSecurityFunctionTableA)pSSPI_)->FreeContextBuffer(OutBuffers[0].pvBuffer);
      ((PSecurityFunctionTableA)pSSPI_)->DeleteSecurityContext(&hContext);
      SetError("WinCrypt network error");
      return false;
    }

    ((PSecurityFunctionTableA)pSSPI_)->FreeContextBuffer(OutBuffers[0].pvBuffer);
    OutBuffers[0].pvBuffer = NULL;
  }

  return SUCCEEDED(HandshakeLoop());
}


int WinCryptSock::HandshakeLoop()
{
  SecBuffer       ExtraData;

  SecBufferDesc   InBuffer;
  SecBuffer       InBuffers[2];
  SecBufferDesc   OutBuffer;
  SecBuffer       OutBuffers[1];
  DWORD           dwSSPIFlags;
  unsigned long   dwSSPIOutFlags;
  TimeStamp       tsExpiry;
  SECURITY_STATUS scRet;
  DWORD           cbData;


#define IO_BUFFER_SIZE  0x8000

  char            IoBuffer[IO_BUFFER_SIZE];
  DWORD           cbIoBuffer;
  BOOL            fDoRead;


  dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
    ISC_REQ_REPLAY_DETECT |
    ISC_REQ_CONFIDENTIALITY |
    ISC_RET_EXTENDED_ERROR |
    ISC_REQ_ALLOCATE_MEMORY |
    ISC_REQ_STREAM;

  cbIoBuffer = 0;

  fDoRead = TRUE;

  scRet = SEC_I_CONTINUE_NEEDED;

  while (scRet == SEC_I_CONTINUE_NEEDED ||
    scRet == SEC_E_INCOMPLETE_MESSAGE ||
    scRet == SEC_I_INCOMPLETE_CREDENTIALS)
  {

    if (0 == cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE)
    {
      if (fDoRead)
      {
        cbData = recv(Socket_,
          IoBuffer + cbIoBuffer,
          IO_BUFFER_SIZE - cbIoBuffer,
          0);
        if (int(cbData) <= 0)
        {
          scRet = SEC_E_INTERNAL_ERROR;
          SetError("WinCrypt network error");
          break;
        }

        cbIoBuffer += cbData;
      }
      else
      {
        fDoRead = TRUE;
      }
    }


    InBuffers[0].pvBuffer = IoBuffer;
    InBuffers[0].cbBuffer = cbIoBuffer;
    InBuffers[0].BufferType = SECBUFFER_TOKEN;

    InBuffers[1].pvBuffer = NULL;
    InBuffers[1].cbBuffer = 0;
    InBuffers[1].BufferType = SECBUFFER_EMPTY;

    InBuffer.cBuffers = 2;
    InBuffer.pBuffers = InBuffers;
    InBuffer.ulVersion = SECBUFFER_VERSION;

    OutBuffers[0].pvBuffer = NULL;
    OutBuffers[0].BufferType = SECBUFFER_TOKEN;
    OutBuffers[0].cbBuffer = 0;

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    scRet = ((PSecurityFunctionTableA)pSSPI_)->InitializeSecurityContextA(
      (PCredHandle)&hCreds_,
      (PCtxtHandle)&hContext_,
      NULL,
      dwSSPIFlags,
      0,
      SECURITY_NATIVE_DREP,
      &InBuffer,
      0,
      NULL,
      &OutBuffer,
      &dwSSPIOutFlags,
      &tsExpiry);


    if (scRet == SEC_E_OK ||
      scRet == SEC_I_CONTINUE_NEEDED ||
      FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR))
    {
      if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
      {
        cbData = send(Socket_,
          (const char*)OutBuffers[0].pvBuffer,
          OutBuffers[0].cbBuffer,
          0);
        if (int(cbData) <= 0)
        {
          ((PSecurityFunctionTableA)pSSPI_)->FreeContextBuffer(OutBuffers[0].pvBuffer);
          ((PSecurityFunctionTableA)pSSPI_)->DeleteSecurityContext((PCtxtHandle)&hContext_);
          SetError("WinCrypt network error");
          return SEC_E_INTERNAL_ERROR;
        }

        ((PSecurityFunctionTableA)pSSPI_)->FreeContextBuffer(OutBuffers[0].pvBuffer);
        OutBuffers[0].pvBuffer = NULL;
      }
    }

    if (scRet == SEC_E_INCOMPLETE_MESSAGE)
    {
      continue;
    }

    if (scRet == SEC_E_OK)
    {
      if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
      {
        ExtraData.pvBuffer = LocalAlloc(LMEM_FIXED,
          InBuffers[1].cbBuffer);
        if (ExtraData.pvBuffer == NULL)
        {
          SetError("WinCrypt memory initialization error");
          return SEC_E_INTERNAL_ERROR;
        }

        MoveMemory(ExtraData.pvBuffer,
          IoBuffer + (cbIoBuffer - InBuffers[1].cbBuffer),
          InBuffers[1].cbBuffer);

        ExtraData.cbBuffer = InBuffers[1].cbBuffer;
        ExtraData.BufferType = SECBUFFER_TOKEN;
      }
      else
      {
        ExtraData.pvBuffer = NULL;
        ExtraData.cbBuffer = 0;
        ExtraData.BufferType = SECBUFFER_EMPTY;
      }

      {
        // Sample for usage QueryContextAttributes() & native TimeStamp
        SecPkgContext_Lifespan ls;
        double hi, lo;
        time_t clock_1970;

        scRet = ((PSecurityFunctionTableA)pSSPI_)->QueryContextAttributesA((PCtxtHandle)&hContext_,
          SECPKG_ATTR_LIFESPAN, &ls);
        hi = ls.tsExpiry.HighPart;
        lo = ls.tsExpiry.LowPart;
        // Convert 100-ns interval since January 1, 1601 (UTC) 
        // to 1-sec interval science January 1, 1970, UTC
        clock_1970 = (time_t)(
          ((ldexp(hi, 32) + lo)*100.e-9)
          - 11644473600. //SystemTimeToFileTime({.wYear = 1970, .wMonth = 1, .wDay = 1}... 
          );
      }
      {
        // Sample for usage QueryContextAttributes() & FileTimeToSystemTime()
        FILETIME   ft;
        SYSTEMTIME st;
        unsigned hi;
        unsigned lo;
        SecPkgContext_Lifespan ls;

        scRet = ((PSecurityFunctionTableA)pSSPI_)->QueryContextAttributesA((PCtxtHandle)&hContext_,
          SECPKG_ATTR_LIFESPAN, &ls);
        hi = ls.tsStart.HighPart;
        lo = ls.tsStart.LowPart;
        memcpy(&ft, &ls.tsStart, sizeof(ft));
        FileTimeToSystemTime(&ft, &st);

        hi = ls.tsExpiry.HighPart;
        lo = ls.tsExpiry.LowPart;
        memcpy(&ft, &ls.tsExpiry, sizeof(ft));
        FileTimeToSystemTime(&ft, &st);
      }
      break;
    }

    if (FAILED(scRet))
    {
      break;
    }

    if (scRet == SEC_I_INCOMPLETE_CREDENTIALS)
    {

      ///GetNewClientCredentials(phCreds, phContext);

      fDoRead = FALSE;
      scRet = SEC_I_CONTINUE_NEEDED;

      cbIoBuffer = 0;
      continue;
    }


    if (InBuffers[1].BufferType == SECBUFFER_EXTRA)
    {
      MoveMemory(IoBuffer,
        IoBuffer + (cbIoBuffer - InBuffers[1].cbBuffer),
        InBuffers[1].cbBuffer);

      cbIoBuffer = InBuffers[1].cbBuffer;
    }
    else
    {
      cbIoBuffer = 0;
    }
  }

  if (FAILED(scRet))
  {
    ((PSecurityFunctionTableA)pSSPI_)->DeleteSecurityContext((PCtxtHandle)&hContext_);
    if (!Error_)
    {
      SetError("WinCrypt general error");
    }
  }

  return scRet;
}

/**
 * Lazily sizes and allocates pbIoBufferOut, the scratch buffer Encrypt()
 * uses for one TLS record.
 *
 * The record layout is header | message | trailer. The sizes come from the
 * established context's SECPKG_ATTR_STREAM_SIZES.
 *
 * @return true if the buffer is (already) prepared; false if the stream sizes
 *         could not be queried (Error_ is set).
 */
bool WinCryptSock::PrepareBuffers()
{
  if (BufferSize)
  {
    return true;
  }

  SecPkgContext_StreamSizes Sizes;
  SECURITY_STATUS scRet = ((PSecurityFunctionTableA)pSSPI_)->QueryContextAttributesA((PCtxtHandle)&hContext_,
    SECPKG_ATTR_STREAM_SIZES,
    &Sizes);
  if (FAILED(scRet))
  {
    SetError("WinCrypt context attributes error");
    return false;
  }

  const unsigned long totalSize = Sizes.cbHeader +
    Sizes.cbMaximumMessage +
    Sizes.cbTrailer;

  // Allocate before publishing BufferSize. If new[] throws, BufferSize stays 0
  // and the next call retries, instead of returning early with no buffer.
  pbIoBufferOut = new char[totalSize];
  BufferSize = totalSize;
  BeginSize = Sizes.cbHeader;
  EndSize = Sizes.cbTrailer;
  return true;
}

int WinCryptSock::Encrypt(std::vector<char>& out, const char * inBuf, int bufLen)
{
  if (!PrepareBuffers())
  {
    return -1;
  }
  int toEncodeSize = bufLen;

  while (toEncodeSize > 0) {
    int freeBufferSize = BufferSize - BeginSize - EndSize;

    if (freeBufferSize == 0) {
      return -1;
    }

    if (toEncodeSize > freeBufferSize) {
      bufLen = freeBufferSize;
      toEncodeSize -= freeBufferSize;
    }
    else {
      bufLen = toEncodeSize;
      toEncodeSize = 0;
    }

    memcpy(pbIoBufferOut + BeginSize, inBuf, bufLen);
    inBuf += bufLen;

    SECURITY_STATUS scRet;
    SecBufferDesc   Message;
    SecBuffer       Buffers[4];


    Buffers[0].pvBuffer = pbIoBufferOut;
    Buffers[0].cbBuffer = BeginSize;
    Buffers[0].BufferType = SECBUFFER_STREAM_HEADER;

    Buffers[1].pvBuffer = pbIoBufferOut + BeginSize;
    Buffers[1].cbBuffer = bufLen;
    Buffers[1].BufferType = SECBUFFER_DATA;

    Buffers[2].pvBuffer = pbIoBufferOut + BeginSize + bufLen;
    Buffers[2].cbBuffer = EndSize;
    Buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;

    Buffers[3].BufferType = SECBUFFER_EMPTY;

    Message.ulVersion = SECBUFFER_VERSION;
    Message.cBuffers = 4;
    Message.pBuffers = Buffers;

    scRet = ((PSecurityFunctionTableA)pSSPI_)->EncryptMessage((PCtxtHandle)&hContext_, 0, &Message, 0);
    if (FAILED(scRet))
    {
      SetError("WinCrypt EncryptMessage error");
      return -1;
    }
    int size = Buffers[0].cbBuffer + Buffers[1].cbBuffer + Buffers[2].cbBuffer;
    out.insert(out.end(), pbIoBufferOut, pbIoBufferOut + size);
  }
  return out.size();
}

int WinCryptSock::Decrypt(std::vector<char>& out, const char * inBuf, int bufLen)
{
  SecBuffer       Buffers[4];
  SecBufferDesc   Message;
  SECURITY_STATUS scRet;
  SecBuffer *     pDataBuffer;
  SecBuffer *     pExtraBuffer;

  decoderBuffer.insert(decoderBuffer.end(), inBuf, inBuf + bufLen);

  while (decoderBuffer.size()) {
    Buffers[0].pvBuffer = &decoderBuffer.front();
    Buffers[0].cbBuffer = decoderBuffer.size();
    Buffers[0].BufferType = SECBUFFER_DATA;

    Buffers[1].BufferType = SECBUFFER_EMPTY;
    Buffers[2].BufferType = SECBUFFER_EMPTY;
    Buffers[3].BufferType = SECBUFFER_EMPTY;

    Message.ulVersion = SECBUFFER_VERSION;
    Message.cBuffers = 4;
    Message.pBuffers = Buffers;

    scRet = ((PSecurityFunctionTableA)pSSPI_)->DecryptMessage((PCtxtHandle)&hContext_, &Message, 0, NULL);

    if (scRet == SEC_E_INCOMPLETE_MESSAGE) { //need more data
      break;
    }

    if (FAILED(scRet))
    {
      SetError("WinCrypt DecryptMessage error");
      return -1;
    }

    pDataBuffer = NULL;
    pExtraBuffer = NULL;
    for (int i = 1; i < 4; i++)
    {
      if (pDataBuffer == NULL && Buffers[i].BufferType == SECBUFFER_DATA)
      {
        pDataBuffer = &Buffers[i];
      }
      if (pExtraBuffer == NULL && Buffers[i].BufferType == SECBUFFER_EXTRA)
      {
        pExtraBuffer = &Buffers[i];
      }
    }
    if (pDataBuffer)
    {
      out.insert(out.end(), (char*)pDataBuffer->pvBuffer, (char*)pDataBuffer->pvBuffer + pDataBuffer->cbBuffer);
      decoderBuffer.resize(0);
    }
    if (pExtraBuffer)
    {
      decoderBuffer.assign((char*)pExtraBuffer->pvBuffer, (char*)pExtraBuffer->pvBuffer + pExtraBuffer->cbBuffer);
    }
  }
  return out.size();
}

void WinCryptSock::SetError(const char* Error)
{
  Error_ = Error;
}
