0%

2023-HITCON-CTF-Crypto-WP

HITCON2023 Crypto 部分题解

Careless Padding

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/local/bin/python
import random
import os
from secret import flag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import signal
import json
import socketserver

N = 16  # block size in bytes (AES block size)


def count_blocks(length):
    """Return how many N-byte blocks are needed to hold ``length`` bytes.

    Examples: 0 -> 0, 1..N -> 1, (N+1)..(2N) -> 2, ...
    """
    # Ceiling division written with floor division on negated operands.
    return -(-length // N)

def find_repeat_tail(message):
    """Locate the run of identical bytes at the end of ``message``.

    Returns ``(length, X, Y)``:
    - ``Y`` is the final byte.
    - ``X`` is the last byte that differs from ``Y``.
    - ``length`` is the number of bytes up to and including ``X``.

    An empty message raises IndexError. A message made entirely of one
    byte value leaves ``X`` unbound and raises UnboundLocalError.
    """
    tail_byte = message[-1]
    length = len(message)
    idx = length - 1
    while idx >= 0:
        if message[idx] != tail_byte:
            last_byte = message[idx]
            length = idx + 1
            break
        idx -= 1
    return length, last_byte, tail_byte

def my_padding(message):
    """Pad ``message`` to a whole number of N-byte blocks with a repeated byte Y.

    X is the last message byte. Y is read from the block before the
    message's last block, at position ``X % N``. When the message fits in
    one block, that index is negative and wraps to the end of the message.
    If Y equals X, its lowest bit is flipped so the pad run differs from the
    final message byte. A message that is already a multiple of N bytes gets
    a full extra block of padding.
    """
    length = len(message)
    blocks = count_blocks(length)
    target_len = blocks * N + (N if length % N == 0 else 0)
    last = message[-1]
    pad_byte = message[(blocks - 2) * N + last % N]
    # ???????? XXXXXXXC  (layout sketch from the original author)
    if pad_byte == last:
        pad_byte ^= 1  # avoid a collision between the pad byte and X
    return message.ljust(target_len, bytes([pad_byte]))

# CBC decryption: plaintext_block = AES_decrypt(ciphertext_block) ^ previous_block
# (the IV for the first block). This line was a bare `C = D ^ IV` statement,
# which raised NameError as soon as the module was loaded.


def my_unpad(message):
    """Strip the padding produced by ``my_padding``.

    The trailing run of identical bytes is removed. The pad byte Y is checked
    against the byte ``_Y`` that ``my_padding`` would have chosen, and either
    ``_Y`` or ``_Y ^ 1`` is accepted. That one-bit slack is why the exploit
    recovers the top 7 bits and the low bit of each byte separately.

    Raises ValueError("Incorrect Padding") on a mismatch. Other malformed
    inputs raise other exceptions from ``find_repeat_tail`` or indexing.
    """
    message_len, X, Y = find_repeat_tail(message)
    block_count = count_blocks(message_len)
    _Y = message[(block_count-2)*N+(X%N)]
    if Y not in (_Y, _Y ^ 1):
        raise ValueError("Incorrect Padding")
    return message[:message_len]

# One-time setup when the script starts: a random 16-byte (AES-128) key and
# the flag wrapped as JSON. The JSON wrapper gives the plaintext a
# predictable first block ('{"key": "...').
k = os.urandom(16)
m = json.dumps({'key':flag}).encode()
# Encrypt the padded flag once with AES-CBC under a random IV; the resulting
# IV || ciphertext is what the banner hands out to every client.
iv = os.urandom(16)
cipher = AES.new(k, AES.MODE_CBC, iv)
padded = my_padding(m)
enc = cipher.encrypt(padded)

# Greeting sent at the start of each connection; embeds hex(IV || ciphertext).
band = (f"""
*********************************************************
You are put into the careless prison and trying to escape.
Thanksfully, someone forged a key for you, but seems like it's encrypted...
Fortunately they also leave you a copied (and apparently alive) prison door.
The replica pairs with this encrypted key. Wait, how are this suppose to help?
Anyway, here's your encrypted key: {(iv+enc).hex()}
*********************************************************
""").encode()

class Task(socketserver.BaseRequestHandler):
    """Per-connection handler: hands out the encrypted key, then acts as a padding oracle.

    Each round, the client sends hex-encoded ``IV || ciphertext``. The reply
    line reveals whether ``my_unpad`` accepted the decrypted padding. A
    ValueError (bad padding, or presumably a ciphertext that is not a whole
    number of blocks) produces the "weirdo" reply.
    """

    def _recvall(self):
        """Read one client message in 2048-byte chunks until a short chunk arrives.

        Returns the data with surrounding whitespace stripped. An empty result
        means the peer closed the connection or sent only whitespace.
        """
        BUFF_SIZE = 2048
        data = b''
        while True:
            part = self.request.recv(BUFF_SIZE)
            data += part
            # A short (or empty) read is taken as the end of the message.
            if len(part) < BUFF_SIZE:
                break
        return data.strip()

    def send(self, msg, newline=True):
        """Send bytes ``msg`` to the client, adding a trailing newline unless ``newline`` is False.

        Best-effort: socket errors (e.g. the client already hung up) are
        ignored. The TimeoutError raised by the session alarm is re-raised so
        the time limit still ends the session.
        """
        if newline:
            msg += b'\n'
        try:
            self.request.sendall(msg)
        except TimeoutError:
            raise
        except OSError:
            pass

    def recv(self, prompt=b''):
        """Send ``prompt`` (no newline appended) and return the client's reply."""
        self.send(prompt, newline=False)
        return self._recvall()

    def timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the session by raising TimeoutError."""
        raise TimeoutError

    def handle(self):
        """Serve oracle queries until the peer disconnects or the alarm fires."""
        signal.signal(signal.SIGALRM, self.timeout_handler)
        signal.alarm(300)  # the whole session is limited to 300 seconds
        self.send(band)
        lock_pick = b"What? Are you trying to unlock me with a lock pick?"
        while True:
            # The prompt has no trailing newline, so each reply lands on the
            # same line as its prompt. A client can then read exactly one
            # line per query (as the solver's p.recvline() does) without
            # drifting out of sync.
            line = self.recv(b"Try unlock:")
            if not line:
                # Empty read: the client closed the connection.
                return
            try:
                enc = bytes.fromhex(line.decode())
            except ValueError:  # also covers UnicodeDecodeError
                self.send(lock_pick)
                continue
            iv = enc[:16]
            if len(iv) != 16:
                # AES.new would reject the IV; answer instead of crashing.
                self.send(lock_pick)
                continue
            cipher = AES.new(k, AES.MODE_CBC, iv)
            try:
                message = my_unpad(cipher.decrypt(enc[16:]))
                if message == m:
                    self.send(str(m).encode())
                    self.send(b"Hey you unlock me! At least you know how to use the key")
                else:
                    self.send(b"Bad key... do you even try?")
            except ValueError:
                self.send(b"Don't put that weirdo in me!")
            except Exception:
                self.send(lock_pick)


class ThreadedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that handles each connection in its own thread.

    Not used by the entry point. NOTE(review): ``Task.handle`` calls
    ``signal.signal``, which Python only allows from the main thread, so this
    variant would presumably fail with ``Task``.
    """

class ForkedServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    """TCP server that forks a child process per connection.

    Each session therefore runs in its own process with its own SIGALRM timer.
    """

if __name__ == "__main__":
    HOST, PORT = '0.0.0.0', 10002
    # allow_reuse_address is read in server_bind(), which the TCPServer
    # constructor already calls. It must therefore be set before
    # construction; assigning it on the instance afterwards has no effect.
    ForkedServer.allow_reuse_address = True
    with ForkedServer((HOST, PORT), Task) as server:
        print(HOST, PORT)
        server.serve_forever()

(这里稍微改了一下题目的部署方式)

exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from pwn import *
from Crypto.Util.number import long_to_bytes
import time
#context(os='linux', arch='amd64', log_level='debug')
N = 16  # AES block size in bytes

# Alternative targets (local process / other port) kept for reference.
#p = process("./chal.py")
#p = remote("127.0.0.1", 11111)
p = remote("0.0.0.0", 10002)

# Banner parsing: "...encrypted key: <hex of IV||ciphertext>", then two
# more banner lines (the closing asterisk line and an empty line) are skipped.
p.recvuntil(b"key: ")
cipher = p.recvline()
cipher = bytes.fromhex(cipher.decode())
p.recvline()
p.recvline()

IV = cipher[:16]
FB = cipher[16:32]  # first ciphertext block
Z = b"\x00"*16  # not used anywhere below
I = [255]*16  # all-0xff template used by get_offset_db's probes
# Known first plaintext block: the server encrypts json.dumps({'key': flag}),
# assuming the flag starts with "hitcon{".
known = b'{"key": "hitcon{'[:16]
# First intermediate value: since known == D(FB) ^ IV, this equals D(FB).
IV0 = xor(IV, known)

oracle_count = 0  # running total of oracle queries sent (for statistics)


def oracle(m):
    """Send one ciphertext (hex-encoded) and report whether its padding passed.

    Returns False when the reply line contains "weirdo" (the server's
    bad-padding message), True otherwise.
    """
    global oracle_count
    oracle_count += 1
    p.sendline(m.hex().encode())
    reply = p.recvline()
    return b"weirdo" not in reply

def oracle_multi(ms):
    """Query the padding oracle for each ciphertext in ``ms``, in order.

    Returns a list of booleans (True = padding accepted, i.e. the reply did
    not contain "weirdo"). Queries are sent one at a time, each waiting for
    its reply, so this is not pipelined. The server hex-decodes whatever one
    recv() returns as a single ciphertext, so back-to-back lines could be
    merged.
    """
    global oracle_count
    oracle_count += len(ms)
    verdicts = []
    for payload in ms:
        p.sendline(payload.hex().encode())
        verdicts.append(b"weirdo" not in p.recvline())
    return verdicts

# offset_db[v] = a byte i found by get_offset_db; -1 marks entries not found yet.
# NOTE(review): by hand-derivation, v is presumably D(IV0 ^ i)[-1] % 16 -- confirm.
offset_db = [-1] * 16


def get_offset_db():
    """Fill the global ``offset_db`` by probing the oracle and return it.

    For each candidate byte i (0..255), send 16 ciphertexts of the form
        (IV0 ^ check ^ i) || FB || (IV0 ^ i) || FB
    where ``check`` is all 0xff except for a 0 at one offset. Because
    D(FB) == IV0, the first plaintext block is ``check ^ i`` and the last is
    ``i`` repeated. When exactly one of the 16 probes is accepted, i is
    stored at index ``(offset ^ FB[-1]) % 16``. Probing stops once every
    entry is filled. Entries may stay -1 if 256 candidates do not suffice.
    """
    for i in range(256):
        if -1 not in offset_db:
            break
        probes = []
        for offset in range(16):
            check = I[:]
            check[offset] = 0
            probes.append(xor(IV0, check, i) + FB + xor(IV0, i) + FB)
        res = oracle_multi(probes)
        # A single accepted probe pins down which offset byte i corresponds to.
        if res.count(True) == 1:
            offset = res.index(True)
            offset_db[(offset ^ FB[-1]) % 16] = i
    return offset_db


def oracle_block_top(BIV, BC):
    """Recover the top 7 bits of every plaintext byte of ciphertext block BC.

    BIV is the ciphertext block preceding BC. For each of the 16 offset_db
    entries, 128 probes of the form ``delta || BC || IVL || FB`` are sent:
    - ``delta`` is zero except for an even value ``diff`` at ``real_offset``.
    - ``IVL = IV0 ^ offset_db[offset]``, so the last plaintext block is
      ``offset_db[offset]`` repeated (the pad byte Y).
    The padding check compares Y with one byte of ``D(BC) ^ delta``. Padding
    passes when they agree in all but the lowest bit, which yields the top 7
    bits of that plaintext byte.
    NOTE(review): that the compared byte is exactly ``real_offset`` follows
    from how offset_db was built -- derived by hand, confirm.

    Returns a list of 16 ints whose low bit is always 0.
    Raises ValueError("Padding not found") if no probe is accepted.
    """
    res = [0] * 16

    for offset in range(16):
        # get top 7 bit
        real_offset = (offset ^ BC[-1]) % 16
        IVL = xor(IV0, offset_db[offset])
        top_7 = -1
        ciphers = []
        for diff in range(0, 256, 2):
            check = list(BIV[:])
            check[real_offset] ^= diff
            # xor(BIV, check) is zero except for `diff` at real_offset.
            cipher = xor(BIV, check) + BC + IVL + FB
            ciphers.append(cipher)

        res2 = oracle_multi(ciphers)
        result = list(zip(res2, range(0, 256, 2)))
        for ora, diff in result:
            if ora:  # accepted padding reveals the top 7 bits of this byte
                print(offset, diff)
                top_7 = (diff ^ BIV[real_offset] ^ offset_db[offset]) & 0xfe
                res[real_offset] = top_7
                break
        else:
            # honestly I don't know what happened here
            # Sometime things just fall through for some reason...
            # NOTE(review): possibly when the garbage middle block itself
            # ends in the pad byte, shifting the tail-strip -- unconfirmed.
            raise ValueError("Padding not found")

    return res

def oracle_block_lower(BIV, BC, Mtop):
    """Recover the low bit of every plaintext byte of ciphertext block BC.

    BIV is the ciphertext block preceding BC. Mtop holds the top-7-bit values
    from oracle_block_top (low bit 0). Every probe is ``IV1 || IV2 || BC``:
    - ``IV2 = BIV ^ Mtop`` (XORed further with one large perturbation and,
      later, the low bits already found), so ``D(BC) ^ IV2`` is a block of
      0/1 bytes except at the perturbed position.
    - ``IV1`` cycles its first byte through the 128 even values.
    Each bit is decided by whether exactly one of the 128 probes is accepted.

    Returns the 16 plaintext byte values. Mtop's low bit is 0, so
    ``Mtop + lowers`` equals ``Mtop | lowers``.
    """
    # 14th byte first, use as anchor
    # make sure Mtop[-1] != Mtop[-2]
    # cipher: control IV1 | control IV2 | BC
    # IV2 -> BIV + offset to control Mtop decrypt result -> partial known X, Y
    # IV1 -> use to bruteforce all permutation
    lowers = [0] * 16

    baseIV = xor(BIV, Mtop) # so decrypt(BC, iv = baseIV) will only contain 0 or 1
    diff = [0] * 16
    diff[-2] = 0xf0 # make sure it don't propagate
    IV2 = xor(baseIV, diff)

    # we check if some value in the first location match
    # yes -> last bit of Mtop[-2] is 0
    # no -> last bit of Mtop[-2] is 1
    ciphers = []
    for brute in range(0, 256, 2):
        IV1 = [brute] + [0] * 15
        ciphers.append(bytes(IV1) + IV2 + BC)

    if oracle_multi(ciphers).count(True) == 1:
        lowers[-2] = 0
    else:
        lowers[-2] = 1

    # now we check if the last bit is the same as Mtop[-2]
    diff = [0] * 16
    diff[-3] = 0xf8 # make sure it don't propagate
    IV2 = xor(baseIV, diff, lowers)
    # we check if some value in the first location match
    # yes -> Mtop[-2] is X -> last bit of Mtop[-1] is 1
    # no -> Mtop[-2] is not X -> last bit of Mtop[-1] is 0
    ciphers = []
    for brute in range(0, 256, 2):
        IV1 = [brute] + [0] * 15
        ciphers.append(bytes(IV1) + IV2 + BC)

    if oracle_multi(ciphers).count(True) == 1:
        lowers[-1] = 1
    else:
        lowers[-1] = 0

    # now we can consistantly form repeating tail
    # fill the rest of the lower bits
    for loc in range(13, -1, -1):
        diff = [0] * 16
        diff[loc] = 0xf0 # make sure it don't propagate
        IV2 = xor(baseIV, diff, lowers)
        ciphers = []
        for brute in range(0, 256, 2):
            IV1 = [brute] + [0] * 15
            ciphers.append(bytes(IV1) + IV2 + BC)

        if oracle_multi(ciphers).count(True) == 1:
            lowers[loc] = 0
        else:
            lowers[loc] = 1

    return [i+j for i, j in zip(Mtop, lowers)]

def oracle_block(BIV, BC):
    """Recover all 16 plaintext bytes of ciphertext block BC (preceded by BIV).

    The top 7 bits of each byte are found first, then the low bits.
    """
    return oracle_block_lower(BIV, BC, oracle_block_top(BIV, BC))


def attack():
    """Decrypt the banner ciphertext block by block using the padding oracle.

    Prints the offset table and query count, the running plaintext after each
    recovered block, and finally the full plaintext and total query count.
    """
    table = get_offset_db()
    print(table, oracle_count)
    plaintext = known  # the first block is the known JSON/flag prefix
    for start in range(32, len(cipher), 16):
        previous, current = cipher[start - 16:start], cipher[start:start + 16]
        plaintext += bytes(oracle_block(previous, current))
        print(plaintext)

    print(plaintext)
    print(oracle_count)


if __name__ == "__main__":
    attack()