Last weekend, I played ISITDTU Quals for an onsite final slot. The crypto challenges were pretty nice, and I managed to solve 3 of the 5: ShareMixer1, ShareMixer2 and Sign.

ShareMixer1#

Source Code#

import random   # TODO: heard that this is unsafe but nvm
from Crypto.Util.number import getPrime, bytes_to_long

# The flag, read from disk as an integer. It becomes one of the polynomial's
# coefficients, so it must be smaller than the modulus p.
flag = bytes_to_long(open("flag.txt", "rb").read())
p = getPrime(256)  # fresh 256-bit prime modulus, printed to the client
assert flag < p
l = 32  # number of polynomial coefficients (polynomial degree l - 1 = 31)


def share_mixer(xs):
    """Evaluate a freshly randomised polynomial at the caller's points.

    The polynomial has ``l`` coefficients: ``l - 1`` uniform random values in
    [1, p - 1] plus the flag. Both the points and the coefficients are
    shuffled, so the caller can tell neither which share belongs to which
    point nor which coefficient is the flag.

    Args:
        xs: query points (already reduced mod p). Shuffled IN PLACE.

    Returns:
        ``f(x) mod p`` for every x, in the shuffled order of ``xs``, where
        ``f(x) = sum(cs[i] * x**i)``.
    """
    cs = [random.randint(1, p - 1) for _ in range(l - 1)]
    cs.append(flag)

    # mixy mix
    random.shuffle(xs)
    random.shuffle(cs)

    # cs[i] is the coefficient of x**i; evaluate f at every (shuffled) x.
    shares = [sum((c * pow(x, i, p)) %
                  p for i, c in enumerate(cs)) % p for x in xs]
    return shares


if __name__ == "__main__":
    try:
        print(f"{p = }")
        queries = input("Gib me the queries: ")
        # Whitespace-separated integers, each reduced mod p.
        xs = list(map(lambda x: int(x) % p, queries.split()))

        # Reject x = 0 (f(0) would be the constant coefficient) and more than
        # 256 points. Repeated points are NOT rejected.
        if 0 in xs or len(xs) > 256:
            print("GUH")
            exit(1)

        shares = share_mixer(xs)
        print(f"{shares = }")
    # Bare except: any error (e.g. a non-integer token) ends the session with
    # status 1. It also catches the SystemExit raised by exit(1) above.
    except:
        exit(1)

Challenge Analysis#

The challenge can be summarized as follows. When connecting to the server, we can submit at most 256 numbers $x$ and obtain the values $f(x) = \sum_{i = 0}^{31}{a_ix^{i}}\pmod{p}$, returned in shuffled order. The flag is one of the 32 coefficients $a_i$, and its position among them is shuffled as well.

Solution#

  • This is a Shamir Secret Sharing challenge, and we could recover the original polynomial using Lagrange interpolation. However, we do not know which $x$ corresponds to which $f(x)$, since the xs are shuffled. Here, I noticed two things: the server allows us to send 256 queries, which is more than needed (32 shares are sufficient to recover the coefficients of a polynomial of degree 31), and it does not reject duplicate xs. After a while, I realized that we can tell the shares apart by their multiplicity: send the same $x$ multiple times and observe how often each value occurs in the response. After that, some bruteforce + Lagrange interpolation is enough to recover the flag.
  • Here is my approach:
    • Generate $xs = [1..32]$
    • Send each $xs_i$ once for $i = 0..3$
    • Send each $xs_i$ twice for $i = 4..7$
    • Send each $xs_i$ 3 times for $i = 8..11$
    • Send each $xs_i$ 4 times for $i = 12..15$
    • Send each $xs_i$ $i - 11$ times for $i = 16..31$

=> I will need a total of $(1 + 2 + 3 + 4) \times 4 + \sum_{i=16}^{31}{(i - 11)} = 40 + 200 = 240 \le 256$ queries and a bruteforce over $(4!)^4 = 331776$ cases to recover the flag

Solve Script#

from pwnlib.tubes.remote import remote
from collections import Counter
from Crypto.Util.number import long_to_bytes
from itertools import permutations
from sage.all import *
import hashlib

# Query plan: x = 1..32. The first 16 x-values come in four groups of four,
# sent 1, 2, 3 and 4 times respectively. xs[i] for i in 16..31 is sent i - 11
# (5..20) times, so each of those shares is identified by its multiplicity.
xs = [i + 1 for i in range(32)]


def _repeat_count(idx):
    """Return how many times xs[idx] is sent to the server."""
    if idx < 16:
        return idx // 4 + 1
    return idx - 11


# Space-separated, no trailing separator (join instead of += then strip).
xs_str = ' '.join(
    str(x) for idx, x in enumerate(xs) for _ in range(_repeat_count(idx))
)
# pos[0..3] collect the share values seen 1..4 times (one bucket per group).
# pos[16..31] will hold the uniquely identified shares.
pos = [[] for _ in range(4)] + [-1] * 28
import ast

CONN = remote("35.187.238.100", 5001)
# The deployed service asks for a proof of work first (it is not part of the
# challenge source above). We need an i such that sha256(prefix + str(i))
# starts with six hex zeros, where prefix is the quoted string on line 4.
CONN.recvlines(3)
prefix = CONN.recvline().decode().split('"')[1]
print(f'{prefix = }')
i = 0
while True:
    i += 1
    s = prefix + str(i)
    if hashlib.sha256(s.encode()).hexdigest().startswith("000000"):
        CONN.sendlineafter(b': ', str(i).encode())
        break

p = int(CONN.recvline().decode().split(" = ")[-1])
print(f'{p = }')
CONN.sendlineafter(b': ', xs_str.encode())
# The reply is "shares = [...]". literal_eval parses the list without
# executing anything the remote end sends (unlike eval).
shares = ast.literal_eval(CONN.recvline().decode().split(" = ")[-1])
CONN.close()

from itertools import product

K = GF(p)
R = K['x']

# Bucket the share values by how often they occur.
# - Values seen 1-4 times belong to the groups x = 1..4, 5..8, 9..12, 13..16
#   (order inside a group unknown).
# - x = i + 1 for i in 16..31 was sent i - 11 times, so a multiplicity v >= 5
#   identifies index i = v + 11 uniquely.
for k, v in Counter(shares).items():
    if v not in [1, 2, 3, 4]:
        pos[v + 11] = k
    else:
        pos[v - 1].append(k)
# A collision between shares of two different x would break the bucketing.
assert all(len(pos[g]) == 4 for g in range(4)), "unexpected share multiplicities"

shares = [-1 for _ in range(32)]
for i in range(16, 32):
    shares[i] = (i + 1, pos[i])

# Brute-force the 4! orderings inside each of the four ambiguous groups,
# (4!)^4 = 331776 candidates in total. product() varies its last factor
# fastest, i.e. the same order as four nested loops.
groups = [range(4 * g, 4 * g + 4) for g in range(4)]
for perms in product(*(permutations(pos[g]) for g in range(4))):
    candidate = shares.copy()
    for idxs, perm in zip(groups, perms):
        for i, v in zip(idxs, perm):
            candidate[i] = (i + 1, v)
    for coeff in R.lagrange_polynomial(candidate).coefficients():
        val = long_to_bytes(int(coeff))
        if b'ISITDTU' in val:
            print(f'Found flag: {val}')
            exit()
print('Flag not found in any candidate polynomial')

ShareMixer2#

Source Code#

import random   # TODO: heard that this is unsafe but nvm
from Crypto.Util.number import getPrime, bytes_to_long

# The flag, read from disk as an integer. It becomes one of the polynomial's
# coefficients, so it must be smaller than the modulus p.
flag = bytes_to_long(open("flag.txt", "rb").read())
p = getPrime(256)  # fresh 256-bit prime modulus, printed to the client
assert flag < p
l = 32  # number of polynomial coefficients (polynomial degree l - 1 = 31)


def share_mixer(xs):
    """Evaluate a freshly randomised polynomial at the caller's points.

    Identical to ShareMixer1: ``l - 1`` uniform random coefficients in
    [1, p - 1] plus the flag. Points and coefficients are both shuffled.

    Args:
        xs: query points (already reduced mod p). Shuffled IN PLACE.

    Returns:
        ``f(x) mod p`` for every x, in the shuffled order of ``xs``, where
        ``f(x) = sum(cs[i] * x**i)``.
    """
    cs = [random.randint(1, p - 1) for _ in range(l - 1)]
    cs.append(flag)

    # mixy mix
    random.shuffle(xs)
    random.shuffle(cs)

    # cs[i] is the coefficient of x**i; evaluate f at every (shuffled) x.
    shares = [sum((c * pow(x, i, p)) %
                  p for i, c in enumerate(cs)) % p for x in xs]
    return shares


if __name__ == "__main__":
    try:
        print(f"{p = }")
        queries = input("Gib me the queries: ")
        # Whitespace-separated integers, each reduced mod p.
        xs = list(map(lambda x: int(x) % p, queries.split()))

        # Only difference from ShareMixer1: at most 32 points (instead of
        # 256). x = 0 is still rejected; repeated points are still allowed.
        if 0 in xs or len(xs) > 32:
            print("GUH")
            exit(1)

        shares = share_mixer(xs)
        print(f"{shares = }")
    # Bare except: any error ends the session with status 1. It also catches
    # the SystemExit raised by exit(1) above.
    except:
        exit(1)

Solution#

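This challenge is the same as ShareMixer1, except that I can send at most 32 values of $x$, which breaks the approach of ShareMixer1. However, I can choose xs to be the 32nd roots of unity modulo $p$, because their powers sum to 0 modulo $p$. Let $x$ be a *primitive* 32nd root of unity modulo $p$, i.e. $x^{32} = 1$ but $x^{k} \neq 1$ for $0 < k < 32$. For any $0 < i < 32$, $y = x^i$ satisfies $y^{32} = 1$ and $y \neq 1$, so: $$y^{32} - 1 = 0 \pmod{p}$$ $$\Rightarrow (y - 1)\sum_{j=0}^{31}{y^j} = 0\pmod{p}$$ $$\Rightarrow \sum_{j=0}^{31}{x^{ij}} = 0\pmod{p}$$ The last step divides by $y - 1 \neq 0 \pmod{p}$, which is exactly why $x$ must be primitive: if $x$ had a smaller order $d$, then $x^{d} = 1$ and the coefficients at multiples of $d$ would survive. After choosing such an $x$ and sending $x^0, x^1, …, x^{31}$ to the server, I can sum all of the shares: $$\sum_{j=0}^{31}{f(x^j)} = \sum_{i=0}^{31}{a_i\sum_{j=0}^{31}{x^{ij}}} = 32a_0 \pmod{p}$$ This eliminates every coefficient except the constant term $a_0$. For this approach to work, 2 more conditions need to hold: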

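  1. $p \equiv 1 \pmod{32}$ (i.e. $32 \mid p - 1$), so that a primitive 32nd root of unity exists modulo $p$.
  2. The flag is shuffled into the constant term $a_0$, which happens with probability $1/32$ per connection.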

I can satisfy both conditions by reconnecting until they hold, since each connection uses a fresh prime $p$ and a fresh shuffle. The rest should be trivial.

Solve Script#

from pwnlib.tubes.remote import remote
from pwnlib.tubes.process import process
from random import randint
from Crypto.Util.number import long_to_bytes
import ast

while True:
    CONN = remote("35.187.238.100", 5002)
    p = int(CONN.recvline().strip().split(b" = ")[1])
    # 32nd roots of unity exist in GF(p) only when 32 divides p - 1.
    if (p - 1) % 32 != 0:
        CONN.close()
        continue
    print(f'{p = }')
    # gk = g^((p-1)/32) is always a 32nd root of unity, but it must be
    # *primitive* (order exactly 32). That holds iff gk^16 != 1, which also
    # excludes gk == 1. A lower-order gk repeats xs and leaves the
    # coefficients at multiples of its order in the sum.
    g = randint(2, p - 1)
    gk = pow(g, (p - 1) // 32, p)
    while pow(gk, 16, p) == 1:
        g = randint(2, p - 1)
        gk = pow(g, (p - 1) // 32, p)
    print(f'{gk = }')
    xs = [pow(gk, i, p) for i in range(32)]
    xs_str = " ".join(map(str, xs))
    CONN.sendlineafter(b": ", xs_str.encode())

    # The reply is "shares = [...]". literal_eval parses it without running
    # remote-supplied code (unlike eval).
    shares = ast.literal_eval(CONN.recvline().strip().split(b" = ")[1].decode())
    CONN.close()
    # sum_j f(gk^j) = 32 * c_0 (mod p). That is the flag only when the shuffle
    # put it at index 0; otherwise reconnect and try again.
    flag = sum(shares) * pow(32, -1, p) % p
    flag = long_to_bytes(int(flag))
    if b'ISITDTU' in flag:
        print(flag)
        exit()

Sign#

Source Code#

#!/usr/bin/env python3

import os

from Crypto.Util.number import *
from Crypto.Signature import PKCS1_v1_5
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256

# Looks like a placeholder flag in the published source (TODO confirm).
# It is left-padded with random bytes to exactly 255 bytes (at most 2040
# bits), so as an integer it is smaller than n = p*q with two 1024-bit primes.
flag = b'ISITDTU{aaaaaaaaaaaaaaaaaaaaaaaaaa}'
flag = os.urandom(255 - len(flag)) + flag


def genkey(e=11):
    """Generate an RSA key with public exponent ``e`` (default 11).

    Draws two 1024-bit primes, retrying until ``e`` is coprime to both
    ``p - 1`` and ``q - 1`` so that ``e`` is invertible mod (p-1)(q-1).
    The modulus ``n`` is never printed to the client by this program.
    """
    while True:
        p = getPrime(1024)
        q = getPrime(1024)
        if GCD(p-1, e) == 1 and GCD(q-1, e) == 1:
            break
    n = p*q
    d = pow(e, -1, (p-1)*(q-1))
    return RSA.construct((n, e, d))


def gensig(key: RSA.RsaKey) -> bytes:
    """Return a PKCS#1 v1.5 / SHA-256 signature of 256 fresh random bytes.

    The random message is discarded, so the caller only ever sees the
    signature, never the message or its digest.
    """
    m = os.urandom(256)
    h = SHA256.new(m)
    s = PKCS1_v1_5.new(key).sign(h)
    return s


def getflagsig(key: RSA.RsaKey) -> bytes:
    """Return the raw, unpadded RSA signature ``flag^d mod n`` as big-endian bytes."""
    return long_to_bytes(pow(bytes_to_long(flag), key.d, key.n))


key = genkey()  # generated once at startup, with the default e = 11

while True:
    print(
        """=================
1. Generate random signature
2. Get flag signature
================="""
    )

    try:
        choice = int(input('> '))
        if choice == 1:
            sig = gensig(key)
            print('sig =', sig.hex())
        elif choice == 2:
            sig = getflagsig(key)
            print('sig =', sig.hex())
        # Any other integer is ignored and the menu is shown again.
    # Any error (e.g. a non-integer choice or EOF on input) prints 'huh' and
    # exits with status -1.
    except Exception as e:
        print('huh')
        exit(-1)

Challenge Analysis#

Upon connection, the server generates an RSA key and offers 2 options:

  1. Get PKCS#1 v1.5 signature of random message
  2. Get the signature of the flag: $s_{flag} = flag^{d} \pmod{n}$

Since we already know the public exponent $e = 11$, we just need to recover the modulus $n$. The flag then follows from RSA signature verification: $flag = s_{flag}^e \pmod{n}$.

I won’t dive deep into the scheme (I don’t know much about it either 😔), but here is a quick summary:

  1. The message (usually a digest) is padded to match modulus $n$ bit length
  2. The padding data is deterministic and consistent between messages.

In this challenge, $n$ is a 2048-bit modulus (so the padded message is 256 bytes) and the hash algorithm is SHA-256, so the digest is 256 bits. Therefore, all padded digests share the same $2048 - 256 = 1792$ most significant bits. Denote the unknown SHA-256 digest of the random message as $m$ and its PKCS#1 v1.5 signature as $s_{PKCS}$. The constant padding prefix is $known$; in the script below, `known` already includes the $2^{256}$ factor. The two are related as follows: $$m_{padded} = known \cdot 2^{256} + m$$ $$s_{PKCS} = m_{padded}^{d} \pmod{n}$$ $$\Rightarrow s_{PKCS}^{e} \equiv m_{padded} = known \cdot 2^{256} + m \pmod{n}$$ $$\Rightarrow s_{PKCS}^{11} = known \cdot 2^{256} + m + k \cdot n \quad \text{for some integer } k$$ $$\Rightarrow s_{PKCS}^{11} - known \cdot 2^{256} = m + k \cdot n$$

The error term $m$ is small (256 bits) compared to $k \cdot n$, and we can request as many signatures as we like. Each value $s_{PKCS}^{11} - known \cdot 2^{256}$ is therefore a sample of the Approximate Common Divisor (ACD) problem with common divisor $n$, and we can use LLL to recover $n$. Here I used this repo for the attack’s implementation. After some local testing, I figured out that around 20 signatures are sufficient to recover the public modulus. Ez peasy!!!

Solve Script#

from pwnlib.tubes.remote import remote
from Crypto.Util.number import bytes_to_long, long_to_bytes
import sys
import hashlib
from sage.all import *

def symmetric_mod(x, m):
    """Reduce x modulo m into a window centred on zero.

    For positive m the result lies in [-(m // 2), m - m // 2).
    """
    offset = m // 2
    residue = int((x + offset + m) % m)
    return residue - int(offset)

def attack(x, rho):
    """Recover the common divisor of ACD samples with an orthogonal-lattice attack.

    Each sample is assumed to be ``x_i = p * q_i + r_i`` with ``|r_i| < 2**rho``.
    Here the samples are ``sig**11 - known = k * n + digest``, so ``p`` is the
    RSA modulus.

    :param x: ACD samples; at least two are required.
    :param rho: bit-length bound on the noise terms ``r_i``.
    :return: ``(p, r)`` with the recovered divisor and the list of noise terms,
        or ``None`` when no divisor consistent with the bound is found.
    """
    assert len(x) >= 2, "At least two x values are required."
    R = 2 ** rho
    # Row i is (x_i, 0, ..., R, ..., 0) with R in column i + 1.
    B = matrix(ZZ, len(x), len(x) + 1)
    for i, xi in enumerate(x):
        B[i, 0] = xi
        B[i, i + 1] = R
    B = B.LLL()
    # The first len(x) - 1 reduced rows, minus column 0, are expected to be
    # orthogonal to (q_1, ..., q_n). Their right kernel should then be
    # spanned by that vector, up to sign.
    K = B.submatrix(row=0, col=1, nrows=len(x) - 1, ncols=len(x)).right_kernel()
    q = K.an_element()
    if q[0] == 0:
        # The division by q[0] below would fail; no usable candidate.
        return None
    r0 = symmetric_mod(x[0], q[0])
    p = abs((x[0] - r0) // q[0])
    if p == 0:
        return None
    r = [symmetric_mod(xi, p) for xi in x]
    # Accept the candidate only if every sample is within 2**rho of a multiple.
    if all(-R < ri < R for ri in r):
        return int(p), r
    return None
	
sys.set_int_max_str_digits(0)

E = 11  # public exponent used by the server's genkey()

# EMSA-PKCS1-v1_5 encoding (RFC 8017, section 9.2) of a SHA-256 digest for a
# 256-byte modulus:
#     00 01 || FF * 202 || 00 || DigestInfo(SHA-256) || digest
# The digest is unknown, so it is left as 32 zero bytes. Every padded digest
# is then `known + m` with 0 <= m < 2**256.
SHA256_DIGEST_INFO = bytes.fromhex('3031300d060960864801650304020105000420')
em_prefix = b'\x00\x01' + b'\xff' * 202 + b'\x00' + SHA256_DIGEST_INFO
assert len(em_prefix) + 32 == 256
known = bytes_to_long(em_prefix + b'\x00' * 32)

CONN = remote("35.187.238.100", 5003)
# Proof of work required by the deployed service: find i such that
# sha256(prefix + str(i)) starts with six hex zeros.
CONN.recvlines(3)
prefix = CONN.recvline().decode().split('"')[1]
print(f'{prefix = }')
i = 0
while True:
    i += 1
    s = prefix + str(i)
    if hashlib.sha256(s.encode()).hexdigest().startswith("000000"):
        CONN.sendlineafter(b': ', str(i).encode())
        break

sigs = []
SIZE = 20
for i in range(SIZE):
    print(f'{i = }')
    CONN.sendlineafter(b'> ', b'1')
    sig = int(CONN.recvline().decode().strip().split()[-1], 16)
    # sig^e = known + m (mod n)  =>  sig^e - known = m + k*n over the integers.
    sigs.append(sig ** E - known)
CONN.sendlineafter(b'> ', b'2')
enc_flag = int(CONN.recvline().decode().strip().split()[-1], 16)
CONN.close()

# rho = 257 bounds the 256-bit digests with one bit of slack.
result = attack(sigs, 257)
if result is None:
    sys.exit('ACD attack failed; retry, possibly with a larger SIZE')
N, _ = result
print(f'{N = }')

# The flag "signature" is textbook RSA: flag = sig^e mod n.
flag = pow(enc_flag, E, N)
print(long_to_bytes(int(flag)))