NOW IT'S ALL ASYNC
This commit is contained in:
parent
6116030f66
commit
de25182868
|
@ -1,47 +1,21 @@
|
|||
# x-run: cd .. && python -m bta_proxy '201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9'
|
||||
import asyncio
|
||||
import socket
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
from argparse import ArgumentParser
|
||||
from sys import argv
|
||||
|
||||
from .debug import debug_client, debug_server
|
||||
from bta_proxy.proxy import BTAProxy
|
||||
|
||||
MAX_SIZE = 0x400000
|
||||
|
||||
async def handle_server(writer: StreamWriter, server: socket.socket, fp):
|
||||
try:
|
||||
while (packet := await loop.sock_recv(server, MAX_SIZE)):
|
||||
try:
|
||||
debug_server(packet, fp)
|
||||
except Exception as e:
|
||||
print(f'[S] error: {e}')
|
||||
writer.write(packet)
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
print(f'handle_server(): {e}')
|
||||
|
||||
async def handle_client(reader: StreamReader, writer: StreamWriter):
|
||||
print(reader, writer)
|
||||
try:
|
||||
server = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
server.connect(('201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9', 25565))
|
||||
server.setblocking(False)
|
||||
|
||||
with open("packets.txt", "w") as fp:
|
||||
loop.create_task(handle_server(writer, server, fp))
|
||||
|
||||
while (packet := await reader.read(MAX_SIZE)):
|
||||
try:
|
||||
debug_client(packet, fp)
|
||||
except Exception as e:
|
||||
print(f'[C] error: {e}')
|
||||
await loop.sock_sendall(server, packet)
|
||||
except Exception as e:
|
||||
print(f'handle_client(): {e}')
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
def main():
|
||||
loop.create_task(asyncio.start_server(handle_client, 'localhost', 25565))
|
||||
loop.run_forever()
|
||||
async def main(args):
|
||||
loop = asyncio.get_running_loop()
|
||||
server = await asyncio.start_server(BTAProxy(args[0], 25565, loop).handle_client, "localhost", 25565)
|
||||
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
asyncio.run(main(argv[1:]))
|
||||
|
|
|
@ -1,65 +1,55 @@
|
|||
from asyncio.queues import Queue
|
||||
import struct
|
||||
|
||||
class DataInputStream:
|
||||
def __init__(self, buffer: bytes):
|
||||
self._buffer = buffer
|
||||
self._cursor = 0
|
||||
class AsyncDataInputStream:
|
||||
def __init__(self, queue: Queue):
|
||||
self._queue = queue
|
||||
self._last = b''
|
||||
|
||||
def read_bytes(self, n: int) -> bytes:
|
||||
if self._cursor + n > len(self._buffer):
|
||||
raise EOFError('stream overread')
|
||||
blob = self._buffer[self._cursor : self._cursor + n]
|
||||
self._cursor += n
|
||||
return blob
|
||||
async def read_bytes(self, n: int) -> bytes:
|
||||
if len(self._last) < n:
|
||||
self._last += await self._queue.get()
|
||||
out, self._last = self._last[:n], self._last[n:]
|
||||
return out
|
||||
|
||||
def empty(self):
|
||||
return self._cursor >= len(self._buffer) - 1
|
||||
async def read(self) -> int:
|
||||
return (await self.read_bytes(1))[0]
|
||||
|
||||
def end(self) -> bytes:
|
||||
return self.read_bytes(len(self._buffer) - self._cursor)
|
||||
async def read_boolean(self) -> bool:
|
||||
return (await self.read()) != 0
|
||||
|
||||
def read_byte(self) -> int:
|
||||
if self._cursor >= len(self._buffer):
|
||||
print(f'\033[91mstream overread in {self._buffer}\033[0m')
|
||||
raise EOFError('stream overread')
|
||||
self._cursor += 1
|
||||
return self._buffer[self._cursor - 1]
|
||||
async def read_short(self) -> int:
|
||||
return struct.unpack('>h', await self.read_bytes(2))[0]
|
||||
|
||||
def read_boolean(self) -> bool:
|
||||
return self.read_byte() != 0
|
||||
async def read_ushort(self) -> int:
|
||||
return struct.unpack('>H', await self.read_bytes(2))[0]
|
||||
|
||||
def read_short(self) -> int:
|
||||
return struct.unpack('>h', self.read_bytes(2))[0]
|
||||
async def read_int(self) -> int:
|
||||
return struct.unpack('>i', await self.read_bytes(4))[0]
|
||||
|
||||
def read_ushort(self) -> int:
|
||||
return struct.unpack('>H', self.read_bytes(2))[0]
|
||||
async def read_uint(self) -> int:
|
||||
return struct.unpack('>I', await self.read_bytes(4))[0]
|
||||
|
||||
def read_int(self) -> int:
|
||||
return struct.unpack('>i', self.read_bytes(4))[0]
|
||||
async def read_long(self) -> int:
|
||||
return struct.unpack('>q', await self.read_bytes(8))[0]
|
||||
|
||||
def read_uint(self) -> int:
|
||||
return struct.unpack('>I', self.read_bytes(4))[0]
|
||||
async def read_ulong(self) -> int:
|
||||
return struct.unpack('>Q', await self.read_bytes(8))[0]
|
||||
|
||||
def read_long(self) -> int:
|
||||
return struct.unpack('>q', self.read_bytes(8))[0]
|
||||
async def read_float(self) -> float:
|
||||
return struct.unpack('>f', await self.read_bytes(4))[0]
|
||||
|
||||
def read_ulong(self) -> int:
|
||||
return struct.unpack('>Q', self.read_bytes(8))[0]
|
||||
async def read_double(self) -> float:
|
||||
return struct.unpack('>d', await self.read_bytes(8))[0]
|
||||
|
||||
def read_float(self) -> float:
|
||||
return struct.unpack('>f', self.read_bytes(4))[0]
|
||||
async def read_char(self) -> str:
|
||||
return chr(await self.read_ushort())
|
||||
|
||||
def read_double(self) -> float:
|
||||
return struct.unpack('>d', self.read_bytes(8))[0]
|
||||
|
||||
def read_char(self) -> str:
|
||||
return chr(self.read_ushort())
|
||||
|
||||
def read_varint(self, bits: int = 32) -> int:
|
||||
async def read_varint(self, bits: int = 32) -> int:
|
||||
value: int = 0
|
||||
position: int = 0
|
||||
while True:
|
||||
byte = self.read_byte()
|
||||
byte = await self.read()
|
||||
value |= (byte & 0x7F) << position
|
||||
if ((byte & 0x80) == 0):
|
||||
break
|
||||
|
@ -68,7 +58,7 @@ class DataInputStream:
|
|||
raise ValueError('varint too big')
|
||||
return value
|
||||
|
||||
def read_string(self) -> str:
|
||||
size = self.read_short()
|
||||
return self.read_bytes(size).decode('utf-8')
|
||||
async def read_string(self) -> str:
|
||||
size = await self.read_short()
|
||||
return (await self.read_bytes(size)).decode('utf-8')
|
||||
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
from bta_proxy.packets.base import Packet
|
||||
from bta_proxy.packets import *
|
||||
from .datainputstream import DataInputStream
|
||||
from typing import Generator, TextIO, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def chunks(gen: Iterable[T], size: int) -> Generator[list[T], None, None]:
|
||||
bucket: list[T] = []
|
||||
for item in gen:
|
||||
bucket.append(item)
|
||||
if len(bucket) >= size:
|
||||
yield bucket
|
||||
bucket.clear()
|
||||
if bucket:
|
||||
yield bucket
|
||||
|
||||
def debug_client(buffer: bytes, tmpfile: TextIO):
|
||||
stream = DataInputStream(buffer)
|
||||
|
||||
while not stream.empty():
|
||||
try:
|
||||
packet = Packet.parse_packet(stream)
|
||||
match packet.packet_id:
|
||||
case _:
|
||||
print('[C]', packet)
|
||||
except ValueError:
|
||||
# print(f'[C:rest] {stream.end()}')
|
||||
buf = stream.end()
|
||||
print(f"[C] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)
|
||||
|
||||
|
||||
def debug_server(buffer: bytes, tmpfile: TextIO):
|
||||
stream = DataInputStream(buffer)
|
||||
|
||||
while not stream.empty():
|
||||
try:
|
||||
packet = Packet.parse_packet(stream)
|
||||
match packet.packet_id:
|
||||
case Packet50PreChunk.packet_id:
|
||||
continue
|
||||
case Packet38EntityStatus.packet_id:
|
||||
continue
|
||||
case _:
|
||||
print('[S]', packet)
|
||||
except ValueError:
|
||||
# print(f'[S:rest] {stream.end()}')
|
||||
buf = stream.end()
|
||||
print(f"[S] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
from asyncio.queues import Queue
|
||||
|
||||
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||
from bta_proxy.packets.base import Packet
|
||||
|
||||
|
||||
async def inspect_client(queue: Queue, addr: tuple[str, int]):
|
||||
dis = AsyncDataInputStream(queue)
|
||||
while True:
|
||||
pkt = await Packet.read_packet(dis)
|
||||
print("C", pkt)
|
||||
|
||||
async def inspect_server(queue: Queue, addr: tuple[str, int]):
|
||||
dis = AsyncDataInputStream(queue)
|
||||
while True:
|
||||
pkt = await Packet.read_packet(dis)
|
||||
print("S", pkt)
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
from typing import Any
|
||||
from bta_proxy.datainputstream import DataInputStream
|
||||
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -23,26 +23,26 @@ class DataItem:
|
|||
|
||||
class EntityData:
|
||||
@classmethod
|
||||
def read_from(cls, dis: DataInputStream) -> list[DataItem]:
|
||||
async def read_from(cls, dis: AsyncDataInputStream) -> list[DataItem]:
|
||||
items = []
|
||||
while (data := dis.read_byte()) != 0x7F:
|
||||
while (data := await dis.read()) != 0x7F:
|
||||
item_type = DataItemType((data & 0xE0) >> 5)
|
||||
item_id: int = data & 0x1F
|
||||
match item_type:
|
||||
case DataItemType.BYTE:
|
||||
items.append(DataItem(item_type, item_id, dis.read_byte()))
|
||||
items.append(DataItem(item_type, item_id, await dis.read()))
|
||||
case DataItemType.SHORT:
|
||||
items.append(DataItem(item_type, item_id, dis.read_short()))
|
||||
items.append(DataItem(item_type, item_id, await dis.read_short()))
|
||||
case DataItemType.FLOAT:
|
||||
items.append(DataItem(item_type, item_id, dis.read_float()))
|
||||
items.append(DataItem(item_type, item_id, await dis.read_float()))
|
||||
case DataItemType.STRING:
|
||||
items.append(DataItem(item_type, item_id, dis.read_string()))
|
||||
items.append(DataItem(item_type, item_id, await dis.read_string()))
|
||||
case DataItemType.ITEMSTACK:
|
||||
items.append(DataItem(item_type, item_id, ItemStack.read_from(dis)))
|
||||
items.append(DataItem(item_type, item_id, await ItemStack.read_from(dis)))
|
||||
case DataItemType.CHUNK_COORDINATES:
|
||||
x = dis.read_float()
|
||||
y = dis.read_float()
|
||||
z = dis.read_float()
|
||||
x = await dis.read_float()
|
||||
y = await dis.read_float()
|
||||
z = await dis.read_float()
|
||||
items.append(DataItem(item_type, item_id, (x, y, z)))
|
||||
return items
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from bta_proxy.datainputstream import DataInputStream
|
||||
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||
|
||||
|
||||
class ItemStack:
|
||||
|
@ -10,8 +10,8 @@ class ItemStack:
|
|||
self.data = data
|
||||
|
||||
@classmethod
|
||||
def read_from(cls, stream: DataInputStream) -> 'ItemStack':
|
||||
item_id = stream.read_short()
|
||||
count = stream.read_byte()
|
||||
data = stream.read_ushort()
|
||||
async def read_from(cls, stream: AsyncDataInputStream) -> 'ItemStack':
|
||||
item_id = await stream.read_short()
|
||||
count = await stream.read()
|
||||
data = await stream.read_ushort()
|
||||
return cls(item_id, count, data)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, ClassVar, Type
|
||||
|
||||
from bta_proxy.entitydata import EntityData
|
||||
from ..datainputstream import DataInputStream
|
||||
from ..datainputstream import AsyncDataInputStream
|
||||
|
||||
class Packet:
|
||||
REGISTRY: ClassVar[dict[int, Type['Packet']]] = {}
|
||||
|
@ -13,43 +13,43 @@ class Packet:
|
|||
setattr(self, k, v)
|
||||
|
||||
@classmethod
|
||||
def read_from(cls, stream: DataInputStream) -> 'Packet':
|
||||
async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet':
|
||||
fields: dict = {}
|
||||
for key, datatype in cls.FIELDS:
|
||||
fields[key] = cls.read_field(stream, datatype)
|
||||
fields[key] = await cls.read_field(stream, datatype)
|
||||
return cls(**fields)
|
||||
|
||||
@staticmethod
|
||||
def read_field(stream: DataInputStream, datatype: Any):
|
||||
async def read_field(stream: AsyncDataInputStream, datatype: Any):
|
||||
match datatype:
|
||||
case 'uint':
|
||||
return stream.read_uint()
|
||||
return await stream.read_uint()
|
||||
case 'int':
|
||||
return stream.read_int()
|
||||
return await stream.read_int()
|
||||
case 'str':
|
||||
return stream.read_string()
|
||||
return await stream.read_string()
|
||||
case 'str', length:
|
||||
return stream.read_string()[:length]
|
||||
return (await stream.read_string())[:length]
|
||||
case 'ulong':
|
||||
return stream.read_ulong()
|
||||
return await stream.read_ulong()
|
||||
case 'long':
|
||||
return stream.read_long()
|
||||
return await stream.read_long()
|
||||
case 'ushort':
|
||||
return stream.read_ushort()
|
||||
return await stream.read_ushort()
|
||||
case 'short':
|
||||
return stream.read_short()
|
||||
return await stream.read_short()
|
||||
case 'byte':
|
||||
return stream.read_byte()
|
||||
return await stream.read()
|
||||
case 'float':
|
||||
return stream.read_float()
|
||||
return await stream.read_float()
|
||||
case 'double':
|
||||
return stream.read_double()
|
||||
return await stream.read_double()
|
||||
case 'bool':
|
||||
return stream.read_boolean()
|
||||
return await stream.read_boolean()
|
||||
case 'bytes', length:
|
||||
return stream.read_bytes(length)
|
||||
return await stream.read_bytes(length)
|
||||
case 'entitydata':
|
||||
return EntityData.read_from(stream)
|
||||
return await EntityData.read_from(stream)
|
||||
case _:
|
||||
raise ValueError(f'unknown type {datatype}')
|
||||
|
||||
|
@ -59,16 +59,17 @@ class Packet:
|
|||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def parse_packet(cls, stream: DataInputStream) -> 'Packet':
|
||||
packet_id: int = stream.read_byte()
|
||||
async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet':
|
||||
packet_id: int = await stream.read()
|
||||
if packet_id not in cls.REGISTRY:
|
||||
stream._cursor -= 1
|
||||
raise ValueError(f'invalid packet 0x{packet_id:02x}')
|
||||
return cls.REGISTRY[packet_id].read_from(stream)
|
||||
pkt = await cls.REGISTRY[packet_id].read_data_from(stream)
|
||||
pkt.packet_id = packet_id
|
||||
return pkt
|
||||
|
||||
def __repr__(self):
|
||||
pkt_name = self.REGISTRY[self.packet_id].__name__
|
||||
fields = []
|
||||
for name, _ in self.FIELDS:
|
||||
fields.append(f'{name}={getattr(self, name)!r}')
|
||||
return f'<{pkt_name} {str.join(", ", fields)}'
|
||||
return f'<{pkt_name} {str.join(", ", fields)}>'
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
from asyncio.protocols import Protocol
|
||||
from asyncio.queues import Queue
|
||||
from asyncio import AbstractEventLoop, get_event_loop
|
||||
from asyncio.streams import StreamReader, StreamWriter, open_connection
|
||||
from typing import Optional
|
||||
|
||||
from bta_proxy.dpi import inspect_client, inspect_server
|
||||
|
||||
|
||||
class BTAProxy:
|
||||
def __init__(self, host: str, port: int, loop: Optional[AbstractEventLoop] = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.loop = loop or get_event_loop()
|
||||
|
||||
@staticmethod
|
||||
async def pipe(reader: StreamReader, writer: StreamWriter, queue: Queue):
|
||||
try:
|
||||
while not reader.at_eof():
|
||||
packet = await reader.read(0x400000)
|
||||
queue.put_nowait(packet)
|
||||
writer.write(packet)
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
async def handle_client(self, cli_reader: StreamReader, cli_writer: StreamWriter):
|
||||
try:
|
||||
peername = cli_writer.get_extra_info("peername")
|
||||
srv_reader, srv_writer = await open_connection(self.host, self.port)
|
||||
|
||||
queue_srv: Queue = Queue()
|
||||
queue_cli: Queue = Queue()
|
||||
|
||||
self.loop.create_task(inspect_client(queue_cli, peername))
|
||||
self.loop.create_task(inspect_server(queue_srv, peername))
|
||||
self.loop.create_task(self.pipe(cli_reader, srv_writer, queue_cli))
|
||||
self.loop.create_task(self.pipe(srv_reader, cli_writer, queue_srv))
|
||||
except Exception as e:
|
||||
print(f"oopsie whoopsie {e}")
|
Loading…
Reference in New Issue