from typing import Any, ClassVar, Optional, Type, Union import gzip from bta_proxy.entitydata import EntityData from bta_proxy.itemstack import ItemStack from ..datainputstream import AsyncDataInputStream from logging import getLogger logger = getLogger(__name__) def try_int(v: str) -> Union[str, int]: try: return int(v) except ValueError: return v class Packet: REGISTRY: ClassVar[dict[int, Type["Packet"]]] = {} FIELDS: ClassVar[list[tuple[str, Any]]] = [] packet_id: int def __init__(self, **params): for k, v in params.items(): setattr(self, k, v) @classmethod async def read_data_from(cls, stream: AsyncDataInputStream) -> "Packet": logger.debug("Packet.read_data_from(%r)", stream) fields: dict = {} for key, datatype in cls.FIELDS: if "?" in key: key, cond = key.split("?", 1) if "==" in cond: k, v = cond.split("==") if fields[k] != try_int(v): continue elif not fields[cond]: continue try: logger.debug(f"reading {key=} of type {datatype!r} ({fields=})") fields[key] = await cls.read_field(stream, datatype, fields) except Exception as e: raise ValueError(f"Failed getting key {key} on {cls}") from e return cls(**fields) @staticmethod async def read_field( stream: AsyncDataInputStream, datatype: Any, fields: dict[str, Any] = {}, ): logger.debug(f"Packet.read_field(_, {datatype=}, {fields=})") match datatype: case "list", sizekey, *args: args = args[0] if len(args) == 1 else tuple(args) size = try_int(sizekey) length = size if isinstance(size, int) else fields[sizekey] return [ await Packet.read_field(stream, args, fields) for _ in range(length) ] case "tuple", *tuples: out = [] for tup in tuples: out.append(await Packet.read_field(stream, tup, fields)) return tuple(out) case "uint": return await stream.read_uint() case "int": return await stream.read_int() case "str": return await stream.read_string() case "str", length: return (await stream.read_string())[:length] case "string": return await stream.read_string() case "string", length: return (await stream.read_string())[:length] case "ulong": return await stream.read_ulong() case "long": return await stream.read_long() case "ushort": return await stream.read_ushort() case "short": return await stream.read_short() case "byte": return await stream.read_byte() case "ubyte": return await stream.read_ubyte() case "float": return await stream.read_float() case "double": return await stream.read_double() case "bool": return await stream.read_boolean() case "bytes", length_or_key: if isinstance(length_or_key, int): return await stream.read_bytes(length_or_key) elif isinstance(length_or_key, str): if length_or_key == ".rest": return stream.read_rest() if length_or_key not in fields: raise KeyError( f"failed to find {length_or_key} in {fields} to read bytes length" ) return await stream.read_bytes(fields[length_or_key]) raise ValueError( f"invalid type for bytes length_or_key: {length_or_key!r}" ) case "itemstack": return await ItemStack.read_from(stream) case "itemstack", length_or_key: items: list[Optional[ItemStack]] = [] if isinstance(length_or_key, int): for _ in range(length_or_key): if (item_id := await stream.read_short()) >= 0: count = await stream.read() data = await stream.read_short() items.append(ItemStack(item_id, count, data)) else: items.append(None) return items elif isinstance(length_or_key, str): if fields[length_or_key] <= 0: return [] if length_or_key not in fields: raise KeyError( f"failed to find {length_or_key} in {fields} to read number of itemstacks" ) for _ in range(fields[length_or_key]): if (item_id := await stream.read_short()) >= 0: count = await stream.read() data = await stream.read_short() items.append(ItemStack(item_id, count, data)) else: items.append(None) return items raise ValueError( f"invalid type for itemstack length_or_key: {length_or_key!r}" ) case "itemstack_optional": if (item_id := await stream.read_short()) >= 0: count = await stream.read() data = await stream.read_short() return ItemStack(item_id, count, data) return None case "extendeditemstack_optional": if (item_id := await stream.read_short()) >= 0: count = await stream.read() data = await stream.read_short() tag_size = await stream.read_short() tag = await stream.read_bytes(tag_size) return ItemStack(item_id, count, data, tag) return None case "entitydata": return await EntityData.read_from(stream) case "nbt": size = await stream.read_short() if size < 0: raise ValueError("Received tag length is less than zero! Weird tag!") if size == 0: return None return gzip.decompress(await stream.read_bytes(size)) case _: raise ValueError(f"unknown type {datatype}") def __init_subclass__(cls, packet_id: int, **kwargs) -> None: logger.debug(f"registered packet {cls} with id = {packet_id}") Packet.REGISTRY[packet_id] = cls cls.packet_id = packet_id super().__init_subclass__(**kwargs) def post_creation(self): pass @classmethod async def read_packet(cls, stream: AsyncDataInputStream) -> "Packet": packet_id: int = await stream.read() logger.debug(f"incoming {packet_id=}") if packet_id not in cls.REGISTRY: raise ValueError( f"invalid packet 0x{packet_id:02x} ({packet_id}) (rest: {stream.peek_rest()[:16]}...)" ) pkt = await cls.REGISTRY[packet_id].read_data_from(stream) pkt.packet_id = packet_id pkt.post_creation() logger.debug(f"received {pkt}") return pkt def __repr__(self): pkt_name = self.REGISTRY[self.packet_id].__name__ fields = [] for key, _ in self.FIELDS: if "?" in key: key, cond = key.split("?", 1) fields.append(f"{key}={getattr(self, key, None)!r} depending on {cond}") else: fields.append(f"{key}={getattr(self, key)!r}") return f'<{pkt_name} {str.join(", ", fields)}>'