From de25182868313256d393e2952c106c2d1943e331 Mon Sep 17 00:00:00 2001 From: hkc Date: Fri, 25 Aug 2023 21:49:10 +0300 Subject: [PATCH] NOW IT'S ALL ASYNC --- bta_proxy/__init__.py | 0 bta_proxy/__main__.py | 50 +++-------- bta_proxy/datainputstream.py | 84 ++++++++----------- bta_proxy/debug.py | 52 ------------ bta_proxy/dpi.py | 18 ++++ bta_proxy/entitydata.py | 22 ++--- bta_proxy/itemstack.py | 10 +-- bta_proxy/packets/base.py | 47 ++++++----- bta_proxy/proxy.py | 39 +++++++++ .../__main__.py => tools/mkpacketfile.py | 0 10 files changed, 146 insertions(+), 176 deletions(-) create mode 100644 bta_proxy/__init__.py delete mode 100644 bta_proxy/debug.py create mode 100644 bta_proxy/dpi.py create mode 100644 bta_proxy/proxy.py rename bta_proxy/cli/mkpacketfile/__main__.py => tools/mkpacketfile.py (100%) diff --git a/bta_proxy/__init__.py b/bta_proxy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bta_proxy/__main__.py b/bta_proxy/__main__.py index 87c841c..9d6ea3c 100644 --- a/bta_proxy/__main__.py +++ b/bta_proxy/__main__.py @@ -1,47 +1,21 @@ +# x-run: cd .. && python -m bta_proxy '201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9' import asyncio -import socket -from asyncio.streams import StreamReader, StreamWriter +from argparse import ArgumentParser +from sys import argv -from .debug import debug_client, debug_server +from bta_proxy.proxy import BTAProxy MAX_SIZE = 0x400000 -async def handle_server(writer: StreamWriter, server: socket.socket, fp): - try: - while (packet := await loop.sock_recv(server, MAX_SIZE)): - try: - debug_server(packet, fp) - except Exception as e: - print(f'[S] error: {e}') - writer.write(packet) - await writer.drain() - except Exception as e: - print(f'handle_server(): {e}') - -async def handle_client(reader: StreamReader, writer: StreamWriter): - print(reader, writer) - try: - server = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - server.connect(('201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9', 25565)) - server.setblocking(False) - - with open("packets.txt", "w") as fp: - loop.create_task(handle_server(writer, server, fp)) - - while (packet := await reader.read(MAX_SIZE)): - try: - debug_client(packet, fp) - except Exception as e: - print(f'[C] error: {e}') - await loop.sock_sendall(server, packet) - except Exception as e: - print(f'handle_client(): {e}') -loop = asyncio.get_event_loop() -def main(): - loop.create_task(asyncio.start_server(handle_client, 'localhost', 25565)) - loop.run_forever() +async def main(args): + loop = asyncio.get_running_loop() + server = await asyncio.start_server(BTAProxy(args[0], 25565, loop).handle_client, "localhost", 25565) + + async with server: + await server.serve_forever() + if __name__ == '__main__': - main() + asyncio.run(main(argv[1:])) diff --git a/bta_proxy/datainputstream.py b/bta_proxy/datainputstream.py index d760f5e..be357e3 100644 --- a/bta_proxy/datainputstream.py +++ b/bta_proxy/datainputstream.py @@ -1,65 +1,55 @@ +from asyncio.queues import Queue import struct -class DataInputStream: - def __init__(self, buffer: bytes): - self._buffer = buffer - self._cursor = 0 +class AsyncDataInputStream: + def __init__(self, queue: Queue): + self._queue = queue + self._last = b'' - def read_bytes(self, n: int) -> bytes: - if self._cursor + n > len(self._buffer): - raise EOFError('stream overread') - blob = self._buffer[self._cursor : self._cursor + n] - self._cursor += n - return blob + async def read_bytes(self, n: int) -> bytes: + if len(self._last) < n: + self._last += await self._queue.get() + out, self._last = self._last[:n], self._last[n:] + return out - def empty(self): - return self._cursor >= len(self._buffer) - 1 + async def read(self) -> int: + return (await self.read_bytes(1))[0] - def end(self) -> bytes: - return self.read_bytes(len(self._buffer) - self._cursor) + async def read_boolean(self) -> bool: + return (await self.read()) != 0 - def read_byte(self) -> int: - if self._cursor >= len(self._buffer): - print(f'\033[91mstream overread in {self._buffer}\033[0m') - raise EOFError('stream overread') - self._cursor += 1 - return self._buffer[self._cursor - 1] + async def read_short(self) -> int: + return struct.unpack('>h', await self.read_bytes(2))[0] - def read_boolean(self) -> bool: - return self.read_byte() != 0 + async def read_ushort(self) -> int: + return struct.unpack('>H', await self.read_bytes(2))[0] - def read_short(self) -> int: - return struct.unpack('>h', self.read_bytes(2))[0] + async def read_int(self) -> int: + return struct.unpack('>i', await self.read_bytes(4))[0] - def read_ushort(self) -> int: - return struct.unpack('>H', self.read_bytes(2))[0] + async def read_uint(self) -> int: + return struct.unpack('>I', await self.read_bytes(4))[0] - def read_int(self) -> int: - return struct.unpack('>i', self.read_bytes(4))[0] + async def read_long(self) -> int: + return struct.unpack('>q', await self.read_bytes(8))[0] - def read_uint(self) -> int: - return struct.unpack('>I', self.read_bytes(4))[0] + async def read_ulong(self) -> int: + return struct.unpack('>Q', await self.read_bytes(8))[0] - def read_long(self) -> int: - return struct.unpack('>q', self.read_bytes(8))[0] + async def read_float(self) -> float: + return struct.unpack('>f', await self.read_bytes(4))[0] - def read_ulong(self) -> int: - return struct.unpack('>Q', self.read_bytes(8))[0] + async def read_double(self) -> float: + return struct.unpack('>d', await self.read_bytes(8))[0] - def read_float(self) -> float: - return struct.unpack('>f', self.read_bytes(4))[0] + async def read_char(self) -> str: + return chr(await self.read_ushort()) - def read_double(self) -> float: - return struct.unpack('>d', self.read_bytes(8))[0] - - def read_char(self) -> str: - return chr(self.read_ushort()) - - def read_varint(self, bits: int = 32) -> int: + async def read_varint(self, bits: int = 32) -> int: value: int = 0 position: int = 0 while True: - byte = self.read_byte() + byte = await self.read() value |= (byte & 0x7F) << position if ((byte & 0x80) == 0): break @@ -68,7 +58,7 @@ class DataInputStream: raise ValueError('varint too big') return value - def read_string(self) -> str: - size = self.read_short() - return self.read_bytes(size).decode('utf-8') + async def read_string(self) -> str: + size = await self.read_short() + return (await self.read_bytes(size)).decode('utf-8') diff --git a/bta_proxy/debug.py b/bta_proxy/debug.py deleted file mode 100644 index 1d3cda3..0000000 --- a/bta_proxy/debug.py +++ /dev/null @@ -1,52 +0,0 @@ -from collections.abc import Iterable - -from bta_proxy.packets.base import Packet -from bta_proxy.packets import * -from .datainputstream import DataInputStream -from typing import Generator, TextIO, TypeVar - -T = TypeVar('T') - -def chunks(gen: Iterable[T], size: int) -> Generator[list[T], None, None]: - bucket: list[T] = [] - for item in gen: - bucket.append(item) - if len(bucket) >= size: - yield bucket - bucket.clear() - if bucket: - yield bucket - -def debug_client(buffer: bytes, tmpfile: TextIO): - stream = DataInputStream(buffer) - - while not stream.empty(): - try: - packet = Packet.parse_packet(stream) - match packet.packet_id: - case _: - print('[C]', packet) - except ValueError: - # print(f'[C:rest] {stream.end()}') - buf = stream.end() - print(f"[C] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile) - - -def debug_server(buffer: bytes, tmpfile: TextIO): - stream = DataInputStream(buffer) - - while not stream.empty(): - try: - packet = Packet.parse_packet(stream) - match packet.packet_id: - case Packet50PreChunk.packet_id: - continue - case Packet38EntityStatus.packet_id: - continue - case _: - print('[S]', packet) - except ValueError: - # print(f'[S:rest] {stream.end()}') - buf = stream.end() - print(f"[S] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile) - diff --git a/bta_proxy/dpi.py b/bta_proxy/dpi.py new file mode 100644 index 0000000..9d9a7cc --- /dev/null +++ b/bta_proxy/dpi.py @@ -0,0 +1,18 @@ + +from asyncio.queues import Queue + +from bta_proxy.datainputstream import AsyncDataInputStream +from bta_proxy.packets.base import Packet + + +async def inspect_client(queue: Queue, addr: tuple[str, int]): + dis = AsyncDataInputStream(queue) + while True: + pkt = await Packet.read_packet(dis) + print("C", pkt) + +async def inspect_server(queue: Queue, addr: tuple[str, int]): + dis = AsyncDataInputStream(queue) + while True: + pkt = await Packet.read_packet(dis) + print("S", pkt) diff --git a/bta_proxy/entitydata.py b/bta_proxy/entitydata.py index f840ecc..aa51c96 100644 --- a/bta_proxy/entitydata.py +++ b/bta_proxy/entitydata.py @@ -1,6 +1,6 @@ from typing import Any -from bta_proxy.datainputstream import DataInputStream +from bta_proxy.datainputstream import AsyncDataInputStream from enum import Enum from dataclasses import dataclass @@ -23,26 +23,26 @@ class DataItem: class EntityData: @classmethod - def read_from(cls, dis: DataInputStream) -> list[DataItem]: + async def read_from(cls, dis: AsyncDataInputStream) -> list[DataItem]: items = [] - while (data := dis.read_byte()) != 0x7F: + while (data := await dis.read()) != 0x7F: item_type = DataItemType((data & 0xE0) >> 5) item_id: int = data & 0x1F match item_type: case DataItemType.BYTE: - items.append(DataItem(item_type, item_id, dis.read_byte())) + items.append(DataItem(item_type, item_id, await dis.read())) case DataItemType.SHORT: - items.append(DataItem(item_type, item_id, dis.read_short())) + items.append(DataItem(item_type, item_id, await dis.read_short())) case DataItemType.FLOAT: - items.append(DataItem(item_type, item_id, dis.read_float())) + items.append(DataItem(item_type, item_id, await dis.read_float())) case DataItemType.STRING: - items.append(DataItem(item_type, item_id, dis.read_string())) + items.append(DataItem(item_type, item_id, await dis.read_string())) case DataItemType.ITEMSTACK: - items.append(DataItem(item_type, item_id, ItemStack.read_from(dis))) + items.append(DataItem(item_type, item_id, await ItemStack.read_from(dis))) case DataItemType.CHUNK_COORDINATES: - x = dis.read_float() - y = dis.read_float() - z = dis.read_float() + x = await dis.read_float() + y = await dis.read_float() + z = await dis.read_float() items.append(DataItem(item_type, item_id, (x, y, z))) return items diff --git a/bta_proxy/itemstack.py b/bta_proxy/itemstack.py index ba426ac..97edb2a 100644 --- a/bta_proxy/itemstack.py +++ b/bta_proxy/itemstack.py @@ -1,5 +1,5 @@ -from bta_proxy.datainputstream import DataInputStream +from bta_proxy.datainputstream import AsyncDataInputStream class ItemStack: @@ -10,8 +10,8 @@ class ItemStack: self.data = data @classmethod - def read_from(cls, stream: DataInputStream) -> 'ItemStack': - item_id = stream.read_short() - count = stream.read_byte() - data = stream.read_ushort() + async def read_from(cls, stream: AsyncDataInputStream) -> 'ItemStack': + item_id = await stream.read_short() + count = await stream.read() + data = await stream.read_ushort() return cls(item_id, count, data) diff --git a/bta_proxy/packets/base.py b/bta_proxy/packets/base.py index 1406d98..fbdc364 100644 --- a/bta_proxy/packets/base.py +++ b/bta_proxy/packets/base.py @@ -1,7 +1,7 @@ from typing import Any, ClassVar, Type from bta_proxy.entitydata import EntityData -from ..datainputstream import DataInputStream +from ..datainputstream import AsyncDataInputStream class Packet: REGISTRY: ClassVar[dict[int, Type['Packet']]] = {} @@ -13,43 +13,43 @@ class Packet: setattr(self, k, v) @classmethod - def read_from(cls, stream: DataInputStream) -> 'Packet': + async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet': fields: dict = {} for key, datatype in cls.FIELDS: - fields[key] = cls.read_field(stream, datatype) + fields[key] = await cls.read_field(stream, datatype) return cls(**fields) @staticmethod - def read_field(stream: DataInputStream, datatype: Any): + async def read_field(stream: AsyncDataInputStream, datatype: Any): match datatype: case 'uint': - return stream.read_uint() + return await stream.read_uint() case 'int': - return stream.read_int() + return await stream.read_int() case 'str': - return stream.read_string() + return await stream.read_string() case 'str', length: - return stream.read_string()[:length] + return (await stream.read_string())[:length] case 'ulong': - return stream.read_ulong() + return await stream.read_ulong() case 'long': - return stream.read_long() + return await stream.read_long() case 'ushort': - return stream.read_ushort() + return await stream.read_ushort() case 'short': - return stream.read_short() + return await stream.read_short() case 'byte': - return stream.read_byte() + return await stream.read() case 'float': - return stream.read_float() + return await stream.read_float() case 'double': - return stream.read_double() + return await stream.read_double() case 'bool': - return stream.read_boolean() + return await stream.read_boolean() case 'bytes', length: - return stream.read_bytes(length) + return await stream.read_bytes(length) case 'entitydata': - return EntityData.read_from(stream) + return await EntityData.read_from(stream) case _: raise ValueError(f'unknown type {datatype}') @@ -59,16 +59,17 @@ class Packet: super().__init_subclass__(**kwargs) @classmethod - def parse_packet(cls, stream: DataInputStream) -> 'Packet': - packet_id: int = stream.read_byte() + async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet': + packet_id: int = await stream.read() if packet_id not in cls.REGISTRY: - stream._cursor -= 1 raise ValueError(f'invalid packet 0x{packet_id:02x}') - return cls.REGISTRY[packet_id].read_from(stream) + pkt = await cls.REGISTRY[packet_id].read_data_from(stream) + pkt.packet_id = packet_id + return pkt def __repr__(self): pkt_name = self.REGISTRY[self.packet_id].__name__ fields = [] for name, _ in self.FIELDS: fields.append(f'{name}={getattr(self, name)!r}') - return f'<{pkt_name} {str.join(", ", fields)}' + return f'<{pkt_name} {str.join(", ", fields)}>' diff --git a/bta_proxy/proxy.py b/bta_proxy/proxy.py new file mode 100644 index 0000000..734b289 --- /dev/null +++ b/bta_proxy/proxy.py @@ -0,0 +1,39 @@ +from asyncio.protocols import Protocol +from asyncio.queues import Queue +from asyncio import AbstractEventLoop, get_event_loop +from asyncio.streams import StreamReader, StreamWriter, open_connection +from typing import Optional + +from bta_proxy.dpi import inspect_client, inspect_server + + +class BTAProxy: + def __init__(self, host: str, port: int, loop: Optional[AbstractEventLoop] = None): + self.host = host + self.port = port + self.loop = loop or get_event_loop() + + @staticmethod + async def pipe(reader: StreamReader, writer: StreamWriter, queue: Queue): + try: + while not reader.at_eof(): + packet = await reader.read(0x400000) + queue.put_nowait(packet) + writer.write(packet) + finally: + writer.close() + + async def handle_client(self, cli_reader: StreamReader, cli_writer: StreamWriter): + try: + peername = cli_writer.get_extra_info("peername") + srv_reader, srv_writer = await open_connection(self.host, self.port) + + queue_srv: Queue = Queue() + queue_cli: Queue = Queue() + + self.loop.create_task(inspect_client(queue_cli, peername)) + self.loop.create_task(inspect_server(queue_srv, peername)) + self.loop.create_task(self.pipe(cli_reader, srv_writer, queue_cli)) + self.loop.create_task(self.pipe(srv_reader, cli_writer, queue_srv)) + except Exception as e: + print(f"oopsie whoopsie {e}") diff --git a/bta_proxy/cli/mkpacketfile/__main__.py b/tools/mkpacketfile.py similarity index 100% rename from bta_proxy/cli/mkpacketfile/__main__.py rename to tools/mkpacketfile.py