NOW IT'S ALL ASYNC
This commit is contained in:
parent
6116030f66
commit
de25182868
|
@ -1,47 +1,21 @@
|
||||||
|
# x-run: cd .. && python -m bta_proxy '201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9'
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
from argparse import ArgumentParser
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from sys import argv
|
||||||
|
|
||||||
from .debug import debug_client, debug_server
|
from bta_proxy.proxy import BTAProxy
|
||||||
|
|
||||||
MAX_SIZE = 0x400000
|
MAX_SIZE = 0x400000
|
||||||
|
|
||||||
async def handle_server(writer: StreamWriter, server: socket.socket, fp):
|
|
||||||
try:
|
|
||||||
while (packet := await loop.sock_recv(server, MAX_SIZE)):
|
|
||||||
try:
|
|
||||||
debug_server(packet, fp)
|
|
||||||
except Exception as e:
|
|
||||||
print(f'[S] error: {e}')
|
|
||||||
writer.write(packet)
|
|
||||||
await writer.drain()
|
|
||||||
except Exception as e:
|
|
||||||
print(f'handle_server(): {e}')
|
|
||||||
|
|
||||||
async def handle_client(reader: StreamReader, writer: StreamWriter):
|
|
||||||
print(reader, writer)
|
|
||||||
try:
|
|
||||||
server = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
|
||||||
server.connect(('201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9', 25565))
|
|
||||||
server.setblocking(False)
|
|
||||||
|
|
||||||
with open("packets.txt", "w") as fp:
|
|
||||||
loop.create_task(handle_server(writer, server, fp))
|
|
||||||
|
|
||||||
while (packet := await reader.read(MAX_SIZE)):
|
|
||||||
try:
|
|
||||||
debug_client(packet, fp)
|
|
||||||
except Exception as e:
|
|
||||||
print(f'[C] error: {e}')
|
|
||||||
await loop.sock_sendall(server, packet)
|
|
||||||
except Exception as e:
|
|
||||||
print(f'handle_client(): {e}')
|
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
async def main(args):
|
||||||
def main():
|
loop = asyncio.get_running_loop()
|
||||||
loop.create_task(asyncio.start_server(handle_client, 'localhost', 25565))
|
server = await asyncio.start_server(BTAProxy(args[0], 25565, loop).handle_client, "localhost", 25565)
|
||||||
loop.run_forever()
|
|
||||||
|
async with server:
|
||||||
|
await server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
asyncio.run(main(argv[1:]))
|
||||||
|
|
|
@ -1,65 +1,55 @@
|
||||||
|
from asyncio.queues import Queue
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
class DataInputStream:
|
class AsyncDataInputStream:
|
||||||
def __init__(self, buffer: bytes):
|
def __init__(self, queue: Queue):
|
||||||
self._buffer = buffer
|
self._queue = queue
|
||||||
self._cursor = 0
|
self._last = b''
|
||||||
|
|
||||||
def read_bytes(self, n: int) -> bytes:
|
async def read_bytes(self, n: int) -> bytes:
|
||||||
if self._cursor + n > len(self._buffer):
|
if len(self._last) < n:
|
||||||
raise EOFError('stream overread')
|
self._last += await self._queue.get()
|
||||||
blob = self._buffer[self._cursor : self._cursor + n]
|
out, self._last = self._last[:n], self._last[n:]
|
||||||
self._cursor += n
|
return out
|
||||||
return blob
|
|
||||||
|
|
||||||
def empty(self):
|
async def read(self) -> int:
|
||||||
return self._cursor >= len(self._buffer) - 1
|
return (await self.read_bytes(1))[0]
|
||||||
|
|
||||||
def end(self) -> bytes:
|
async def read_boolean(self) -> bool:
|
||||||
return self.read_bytes(len(self._buffer) - self._cursor)
|
return (await self.read()) != 0
|
||||||
|
|
||||||
def read_byte(self) -> int:
|
async def read_short(self) -> int:
|
||||||
if self._cursor >= len(self._buffer):
|
return struct.unpack('>h', await self.read_bytes(2))[0]
|
||||||
print(f'\033[91mstream overread in {self._buffer}\033[0m')
|
|
||||||
raise EOFError('stream overread')
|
|
||||||
self._cursor += 1
|
|
||||||
return self._buffer[self._cursor - 1]
|
|
||||||
|
|
||||||
def read_boolean(self) -> bool:
|
async def read_ushort(self) -> int:
|
||||||
return self.read_byte() != 0
|
return struct.unpack('>H', await self.read_bytes(2))[0]
|
||||||
|
|
||||||
def read_short(self) -> int:
|
async def read_int(self) -> int:
|
||||||
return struct.unpack('>h', self.read_bytes(2))[0]
|
return struct.unpack('>i', await self.read_bytes(4))[0]
|
||||||
|
|
||||||
def read_ushort(self) -> int:
|
async def read_uint(self) -> int:
|
||||||
return struct.unpack('>H', self.read_bytes(2))[0]
|
return struct.unpack('>I', await self.read_bytes(4))[0]
|
||||||
|
|
||||||
def read_int(self) -> int:
|
async def read_long(self) -> int:
|
||||||
return struct.unpack('>i', self.read_bytes(4))[0]
|
return struct.unpack('>q', await self.read_bytes(8))[0]
|
||||||
|
|
||||||
def read_uint(self) -> int:
|
async def read_ulong(self) -> int:
|
||||||
return struct.unpack('>I', self.read_bytes(4))[0]
|
return struct.unpack('>Q', await self.read_bytes(8))[0]
|
||||||
|
|
||||||
def read_long(self) -> int:
|
async def read_float(self) -> float:
|
||||||
return struct.unpack('>q', self.read_bytes(8))[0]
|
return struct.unpack('>f', await self.read_bytes(4))[0]
|
||||||
|
|
||||||
def read_ulong(self) -> int:
|
async def read_double(self) -> float:
|
||||||
return struct.unpack('>Q', self.read_bytes(8))[0]
|
return struct.unpack('>d', await self.read_bytes(8))[0]
|
||||||
|
|
||||||
def read_float(self) -> float:
|
async def read_char(self) -> str:
|
||||||
return struct.unpack('>f', self.read_bytes(4))[0]
|
return chr(await self.read_ushort())
|
||||||
|
|
||||||
def read_double(self) -> float:
|
async def read_varint(self, bits: int = 32) -> int:
|
||||||
return struct.unpack('>d', self.read_bytes(8))[0]
|
|
||||||
|
|
||||||
def read_char(self) -> str:
|
|
||||||
return chr(self.read_ushort())
|
|
||||||
|
|
||||||
def read_varint(self, bits: int = 32) -> int:
|
|
||||||
value: int = 0
|
value: int = 0
|
||||||
position: int = 0
|
position: int = 0
|
||||||
while True:
|
while True:
|
||||||
byte = self.read_byte()
|
byte = await self.read()
|
||||||
value |= (byte & 0x7F) << position
|
value |= (byte & 0x7F) << position
|
||||||
if ((byte & 0x80) == 0):
|
if ((byte & 0x80) == 0):
|
||||||
break
|
break
|
||||||
|
@ -68,7 +58,7 @@ class DataInputStream:
|
||||||
raise ValueError('varint too big')
|
raise ValueError('varint too big')
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def read_string(self) -> str:
|
async def read_string(self) -> str:
|
||||||
size = self.read_short()
|
size = await self.read_short()
|
||||||
return self.read_bytes(size).decode('utf-8')
|
return (await self.read_bytes(size)).decode('utf-8')
|
||||||
|
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
from collections.abc import Iterable
|
|
||||||
|
|
||||||
from bta_proxy.packets.base import Packet
|
|
||||||
from bta_proxy.packets import *
|
|
||||||
from .datainputstream import DataInputStream
|
|
||||||
from typing import Generator, TextIO, TypeVar
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
|
||||||
|
|
||||||
def chunks(gen: Iterable[T], size: int) -> Generator[list[T], None, None]:
|
|
||||||
bucket: list[T] = []
|
|
||||||
for item in gen:
|
|
||||||
bucket.append(item)
|
|
||||||
if len(bucket) >= size:
|
|
||||||
yield bucket
|
|
||||||
bucket.clear()
|
|
||||||
if bucket:
|
|
||||||
yield bucket
|
|
||||||
|
|
||||||
def debug_client(buffer: bytes, tmpfile: TextIO):
|
|
||||||
stream = DataInputStream(buffer)
|
|
||||||
|
|
||||||
while not stream.empty():
|
|
||||||
try:
|
|
||||||
packet = Packet.parse_packet(stream)
|
|
||||||
match packet.packet_id:
|
|
||||||
case _:
|
|
||||||
print('[C]', packet)
|
|
||||||
except ValueError:
|
|
||||||
# print(f'[C:rest] {stream.end()}')
|
|
||||||
buf = stream.end()
|
|
||||||
print(f"[C] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)
|
|
||||||
|
|
||||||
|
|
||||||
def debug_server(buffer: bytes, tmpfile: TextIO):
|
|
||||||
stream = DataInputStream(buffer)
|
|
||||||
|
|
||||||
while not stream.empty():
|
|
||||||
try:
|
|
||||||
packet = Packet.parse_packet(stream)
|
|
||||||
match packet.packet_id:
|
|
||||||
case Packet50PreChunk.packet_id:
|
|
||||||
continue
|
|
||||||
case Packet38EntityStatus.packet_id:
|
|
||||||
continue
|
|
||||||
case _:
|
|
||||||
print('[S]', packet)
|
|
||||||
except ValueError:
|
|
||||||
# print(f'[S:rest] {stream.end()}')
|
|
||||||
buf = stream.end()
|
|
||||||
print(f"[S] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
|
||||||
|
from asyncio.queues import Queue
|
||||||
|
|
||||||
|
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||||
|
from bta_proxy.packets.base import Packet
|
||||||
|
|
||||||
|
|
||||||
|
async def inspect_client(queue: Queue, addr: tuple[str, int]):
|
||||||
|
dis = AsyncDataInputStream(queue)
|
||||||
|
while True:
|
||||||
|
pkt = await Packet.read_packet(dis)
|
||||||
|
print("C", pkt)
|
||||||
|
|
||||||
|
async def inspect_server(queue: Queue, addr: tuple[str, int]):
|
||||||
|
dis = AsyncDataInputStream(queue)
|
||||||
|
while True:
|
||||||
|
pkt = await Packet.read_packet(dis)
|
||||||
|
print("S", pkt)
|
|
@ -1,6 +1,6 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from bta_proxy.datainputstream import DataInputStream
|
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@ -23,26 +23,26 @@ class DataItem:
|
||||||
|
|
||||||
class EntityData:
|
class EntityData:
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_from(cls, dis: DataInputStream) -> list[DataItem]:
|
async def read_from(cls, dis: AsyncDataInputStream) -> list[DataItem]:
|
||||||
items = []
|
items = []
|
||||||
while (data := dis.read_byte()) != 0x7F:
|
while (data := await dis.read()) != 0x7F:
|
||||||
item_type = DataItemType((data & 0xE0) >> 5)
|
item_type = DataItemType((data & 0xE0) >> 5)
|
||||||
item_id: int = data & 0x1F
|
item_id: int = data & 0x1F
|
||||||
match item_type:
|
match item_type:
|
||||||
case DataItemType.BYTE:
|
case DataItemType.BYTE:
|
||||||
items.append(DataItem(item_type, item_id, dis.read_byte()))
|
items.append(DataItem(item_type, item_id, await dis.read()))
|
||||||
case DataItemType.SHORT:
|
case DataItemType.SHORT:
|
||||||
items.append(DataItem(item_type, item_id, dis.read_short()))
|
items.append(DataItem(item_type, item_id, await dis.read_short()))
|
||||||
case DataItemType.FLOAT:
|
case DataItemType.FLOAT:
|
||||||
items.append(DataItem(item_type, item_id, dis.read_float()))
|
items.append(DataItem(item_type, item_id, await dis.read_float()))
|
||||||
case DataItemType.STRING:
|
case DataItemType.STRING:
|
||||||
items.append(DataItem(item_type, item_id, dis.read_string()))
|
items.append(DataItem(item_type, item_id, await dis.read_string()))
|
||||||
case DataItemType.ITEMSTACK:
|
case DataItemType.ITEMSTACK:
|
||||||
items.append(DataItem(item_type, item_id, ItemStack.read_from(dis)))
|
items.append(DataItem(item_type, item_id, await ItemStack.read_from(dis)))
|
||||||
case DataItemType.CHUNK_COORDINATES:
|
case DataItemType.CHUNK_COORDINATES:
|
||||||
x = dis.read_float()
|
x = await dis.read_float()
|
||||||
y = dis.read_float()
|
y = await dis.read_float()
|
||||||
z = dis.read_float()
|
z = await dis.read_float()
|
||||||
items.append(DataItem(item_type, item_id, (x, y, z)))
|
items.append(DataItem(item_type, item_id, (x, y, z)))
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
|
||||||
from bta_proxy.datainputstream import DataInputStream
|
from bta_proxy.datainputstream import AsyncDataInputStream
|
||||||
|
|
||||||
|
|
||||||
class ItemStack:
|
class ItemStack:
|
||||||
|
@ -10,8 +10,8 @@ class ItemStack:
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_from(cls, stream: DataInputStream) -> 'ItemStack':
|
async def read_from(cls, stream: AsyncDataInputStream) -> 'ItemStack':
|
||||||
item_id = stream.read_short()
|
item_id = await stream.read_short()
|
||||||
count = stream.read_byte()
|
count = await stream.read()
|
||||||
data = stream.read_ushort()
|
data = await stream.read_ushort()
|
||||||
return cls(item_id, count, data)
|
return cls(item_id, count, data)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, ClassVar, Type
|
from typing import Any, ClassVar, Type
|
||||||
|
|
||||||
from bta_proxy.entitydata import EntityData
|
from bta_proxy.entitydata import EntityData
|
||||||
from ..datainputstream import DataInputStream
|
from ..datainputstream import AsyncDataInputStream
|
||||||
|
|
||||||
class Packet:
|
class Packet:
|
||||||
REGISTRY: ClassVar[dict[int, Type['Packet']]] = {}
|
REGISTRY: ClassVar[dict[int, Type['Packet']]] = {}
|
||||||
|
@ -13,43 +13,43 @@ class Packet:
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_from(cls, stream: DataInputStream) -> 'Packet':
|
async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet':
|
||||||
fields: dict = {}
|
fields: dict = {}
|
||||||
for key, datatype in cls.FIELDS:
|
for key, datatype in cls.FIELDS:
|
||||||
fields[key] = cls.read_field(stream, datatype)
|
fields[key] = await cls.read_field(stream, datatype)
|
||||||
return cls(**fields)
|
return cls(**fields)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_field(stream: DataInputStream, datatype: Any):
|
async def read_field(stream: AsyncDataInputStream, datatype: Any):
|
||||||
match datatype:
|
match datatype:
|
||||||
case 'uint':
|
case 'uint':
|
||||||
return stream.read_uint()
|
return await stream.read_uint()
|
||||||
case 'int':
|
case 'int':
|
||||||
return stream.read_int()
|
return await stream.read_int()
|
||||||
case 'str':
|
case 'str':
|
||||||
return stream.read_string()
|
return await stream.read_string()
|
||||||
case 'str', length:
|
case 'str', length:
|
||||||
return stream.read_string()[:length]
|
return (await stream.read_string())[:length]
|
||||||
case 'ulong':
|
case 'ulong':
|
||||||
return stream.read_ulong()
|
return await stream.read_ulong()
|
||||||
case 'long':
|
case 'long':
|
||||||
return stream.read_long()
|
return await stream.read_long()
|
||||||
case 'ushort':
|
case 'ushort':
|
||||||
return stream.read_ushort()
|
return await stream.read_ushort()
|
||||||
case 'short':
|
case 'short':
|
||||||
return stream.read_short()
|
return await stream.read_short()
|
||||||
case 'byte':
|
case 'byte':
|
||||||
return stream.read_byte()
|
return await stream.read()
|
||||||
case 'float':
|
case 'float':
|
||||||
return stream.read_float()
|
return await stream.read_float()
|
||||||
case 'double':
|
case 'double':
|
||||||
return stream.read_double()
|
return await stream.read_double()
|
||||||
case 'bool':
|
case 'bool':
|
||||||
return stream.read_boolean()
|
return await stream.read_boolean()
|
||||||
case 'bytes', length:
|
case 'bytes', length:
|
||||||
return stream.read_bytes(length)
|
return await stream.read_bytes(length)
|
||||||
case 'entitydata':
|
case 'entitydata':
|
||||||
return EntityData.read_from(stream)
|
return await EntityData.read_from(stream)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f'unknown type {datatype}')
|
raise ValueError(f'unknown type {datatype}')
|
||||||
|
|
||||||
|
@ -59,16 +59,17 @@ class Packet:
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_packet(cls, stream: DataInputStream) -> 'Packet':
|
async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet':
|
||||||
packet_id: int = stream.read_byte()
|
packet_id: int = await stream.read()
|
||||||
if packet_id not in cls.REGISTRY:
|
if packet_id not in cls.REGISTRY:
|
||||||
stream._cursor -= 1
|
|
||||||
raise ValueError(f'invalid packet 0x{packet_id:02x}')
|
raise ValueError(f'invalid packet 0x{packet_id:02x}')
|
||||||
return cls.REGISTRY[packet_id].read_from(stream)
|
pkt = await cls.REGISTRY[packet_id].read_data_from(stream)
|
||||||
|
pkt.packet_id = packet_id
|
||||||
|
return pkt
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
pkt_name = self.REGISTRY[self.packet_id].__name__
|
pkt_name = self.REGISTRY[self.packet_id].__name__
|
||||||
fields = []
|
fields = []
|
||||||
for name, _ in self.FIELDS:
|
for name, _ in self.FIELDS:
|
||||||
fields.append(f'{name}={getattr(self, name)!r}')
|
fields.append(f'{name}={getattr(self, name)!r}')
|
||||||
return f'<{pkt_name} {str.join(", ", fields)}'
|
return f'<{pkt_name} {str.join(", ", fields)}>'
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
from asyncio.protocols import Protocol
|
||||||
|
from asyncio.queues import Queue
|
||||||
|
from asyncio import AbstractEventLoop, get_event_loop
|
||||||
|
from asyncio.streams import StreamReader, StreamWriter, open_connection
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from bta_proxy.dpi import inspect_client, inspect_server
|
||||||
|
|
||||||
|
|
||||||
|
class BTAProxy:
|
||||||
|
def __init__(self, host: str, port: int, loop: Optional[AbstractEventLoop] = None):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.loop = loop or get_event_loop()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def pipe(reader: StreamReader, writer: StreamWriter, queue: Queue):
|
||||||
|
try:
|
||||||
|
while not reader.at_eof():
|
||||||
|
packet = await reader.read(0x400000)
|
||||||
|
queue.put_nowait(packet)
|
||||||
|
writer.write(packet)
|
||||||
|
finally:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
async def handle_client(self, cli_reader: StreamReader, cli_writer: StreamWriter):
|
||||||
|
try:
|
||||||
|
peername = cli_writer.get_extra_info("peername")
|
||||||
|
srv_reader, srv_writer = await open_connection(self.host, self.port)
|
||||||
|
|
||||||
|
queue_srv: Queue = Queue()
|
||||||
|
queue_cli: Queue = Queue()
|
||||||
|
|
||||||
|
self.loop.create_task(inspect_client(queue_cli, peername))
|
||||||
|
self.loop.create_task(inspect_server(queue_srv, peername))
|
||||||
|
self.loop.create_task(self.pipe(cli_reader, srv_writer, queue_cli))
|
||||||
|
self.loop.create_task(self.pipe(srv_reader, cli_writer, queue_srv))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"oopsie whoopsie {e}")
|
Loading…
Reference in New Issue