The Beginning Of Filters

Also Flake8 cleanup and other stuff
This commit is contained in:
Casey 2022-08-27 14:27:42 +03:00
parent 6a9750b33b
commit ae8a1ddf34
Signed by: hkc
GPG Key ID: F0F6CFE11CDB0960
12 changed files with 126 additions and 85 deletions

View File

@ -10,24 +10,18 @@ from mastoposter.types import Status
def load_integrations_from(config: ConfigParser) -> List[BaseIntegration]: def load_integrations_from(config: ConfigParser) -> List[BaseIntegration]:
modules: List[BaseIntegration] = [] modules: List[BaseIntegration] = []
for module_name in config.get("main", "modules").split(): for module_name in config.get("main", "modules").split():
module = config[f"module/{module_name}"] mod = config[f"module/{module_name}"]
if module["type"] == "telegram": if mod["type"] == "telegram":
modules.append( modules.append(TelegramIntegration(mod))
TelegramIntegration( elif mod["type"] == "discord":
token=module["token"], modules.append(DiscordIntegration(mod))
chat_id=module["chat"],
show_post_link=module.getboolean("show_post_link", fallback=True),
show_boost_from=module.getboolean("show_boost_from", fallback=True),
)
)
elif module["type"] == "discord":
modules.append(DiscordIntegration(webhook=module["webhook"]))
else: else:
raise ValueError("Invalid module type %r" % module["type"]) raise ValueError("Invalid module type %r" % mod["type"])
return modules return modules
async def execute_integrations( async def execute_integrations(
status: Status, sinks: List[BaseIntegration] status: Status, sinks: List[BaseIntegration]
) -> List[Optional[str]]: ) -> List[Optional[str]]:
return await gather(*[sink.post(status) for sink in sinks], return_exceptions=True) coros = [sink.post(status) for sink in sinks]
return await gather(*coros, return_exceptions=True)

View File

@ -57,7 +57,7 @@ def main(config_path: str):
modules, modules,
conf["main"]["user"], conf["main"]["user"],
url=url, url=url,
reconnect=conf["main"].getboolean("auto_reconnect", fallback=False), reconnect=conf["main"].getboolean("auto_reconnect", False),
list=conf["main"]["list"], list=conf["main"]["list"],
access_token=conf["main"]["token"], access_token=conf["main"]["token"],
) )

View File

@ -0,0 +1,2 @@
from .base import BaseFilter # NOQA
from mastoposter.filters.boost_filter import BoostFilter # NOQA

View File

@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, Type
from mastoposter.types import Status
from re import Pattern, compile as regexp
class BaseFilter(ABC):
FILTER_REGISTRY: ClassVar[Dict[str, Type["BaseFilter"]]] = {}
FILTER_NAME_REGEX: Pattern = regexp(r"^([a-z_]+)$")
def __init__(self):
pass
@abstractmethod
def __call__(self, status: Status) -> bool:
raise NotImplementedError
def __init_subclass__(cls, filter_name: str, **kwargs):
super().__init_subclass__(**kwargs)
if not cls.FILTER_NAME_REGEX.match(filter_name):
raise ValueError(f"invalid {filter_name=!r}")
if filter_name in cls.FILTER_REGISTRY:
raise KeyError(f"{filter_name=!r} is already registered")
cls.FILTER_REGISTRY[filter_name] = cls

View File

@ -0,0 +1,7 @@
from mastoposter.filters.base import BaseFilter
from mastoposter.types import Status
class BoostFilter(BaseFilter, filter_name="boost"):
def __call__(self, status: Status) -> bool:
return status.reblog is not None

View File

@ -1,2 +1,2 @@
from .telegram import TelegramIntegration from .telegram import TelegramIntegration # NOQA
from .discord import DiscordIntegration from .discord import DiscordIntegration # NOQA

View File

@ -1,13 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from configparser import SectionProxy
from typing import Optional from typing import Optional
from mastoposter.types import Status from mastoposter.types import Status
class BaseIntegration(ABC): class BaseIntegration(ABC):
def __init__(self): def __init__(self, section: SectionProxy):
pass pass
@abstractmethod @abstractmethod
async def post(self, status: Status) -> Optional[str]: async def post(self, status: Status) -> Optional[str]:
raise NotImplemented raise NotImplementedError

View File

@ -1,3 +1,4 @@
from configparser import SectionProxy
from typing import List, Optional from typing import List, Optional
from bs4 import BeautifulSoup, PageElement, Tag from bs4 import BeautifulSoup, PageElement, Tag
from httpx import AsyncClient from httpx import AsyncClient
@ -12,8 +13,8 @@ from mastoposter.types import Status
class DiscordIntegration(BaseIntegration): class DiscordIntegration(BaseIntegration):
def __init__(self, webhook: str): def __init__(self, section: SectionProxy):
self.webhook = webhook self.webhook = section.get("webhook", "")
@staticmethod @staticmethod
def md_escape(text: str) -> str: def md_escape(text: str) -> str:
@ -33,11 +34,15 @@ class DiscordIntegration(BaseIntegration):
if isinstance(el, Tag): if isinstance(el, Tag):
if el.name == "a": if el.name == "a":
return "[%s](%s)" % ( return "[%s](%s)" % (
cls.md_escape(str.join("", map(cls.node_to_text, el.children))), cls.md_escape(
str.join("", map(cls.node_to_text, el.children))
),
el.attrs["href"], el.attrs["href"],
) )
elif el.name == "p": elif el.name == "p":
return str.join("", map(cls.node_to_text, el.children)) + "\n\n" return (
str.join("", map(cls.node_to_text, el.children)) + "\n\n"
)
elif el.name == "br": elif el.name == "br":
return "\n" return "\n"
return str.join("", map(cls.node_to_text, el.children)) return str.join("", map(cls.node_to_text, el.children))
@ -70,12 +75,16 @@ class DiscordIntegration(BaseIntegration):
source = status.reblog or status source = status.reblog or status
embeds: List[DiscordEmbed] = [] embeds: List[DiscordEmbed] = []
text = self.node_to_text(BeautifulSoup(source.content, features="lxml")) text = self.node_to_text(
BeautifulSoup(source.content, features="lxml")
)
if source.spoiler_text: if source.spoiler_text:
text = f"{source.spoiler_text}\n||{text}||" text = f"{source.spoiler_text}\n||{text}||"
if status.reblog is not None: if status.reblog is not None:
title = f"@{status.account.acct} boosted from @{source.account.acct}" title = (
f"@{status.account.acct} boosted from @{source.account.acct}"
)
else: else:
title = f"@{status.account.acct} posted" title = f"@{status.account.acct} posted"

View File

@ -68,7 +68,9 @@ class DiscordEmbed:
"title": self.title, "title": self.title,
"description": self.description, "description": self.description,
"url": self.url, "url": self.url,
"timestamp": _f(datetime.isoformat, self.timestamp, "T", "seconds"), "timestamp": _f(
datetime.isoformat, self.timestamp, "T", "seconds"
),
"color": self.color, "color": self.color,
"footer": _f(asdict, self.footer), "footer": _f(asdict, self.footer),
"image": _f(asdict, self.image), "image": _f(asdict, self.image),

View File

@ -1,6 +1,7 @@
from configparser import SectionProxy
from dataclasses import dataclass from dataclasses import dataclass
from html import escape from html import escape
from typing import Any, List, Mapping, Optional, Union from typing import Any, List, Mapping, Optional
from bs4 import BeautifulSoup, Tag, PageElement from bs4 import BeautifulSoup, Tag, PageElement
from httpx import AsyncClient from httpx import AsyncClient
from mastoposter.integrations.base import BaseIntegration from mastoposter.integrations.base import BaseIntegration
@ -16,7 +17,12 @@ class TGResponse:
@classmethod @classmethod
def from_dict(cls, data: dict, params: dict) -> "TGResponse": def from_dict(cls, data: dict, params: dict) -> "TGResponse":
return cls(data["ok"], params, data.get("result"), data.get("description")) return cls(
ok=data["ok"],
params=params,
result=data.get("result"),
error=data.get("description"),
)
class TelegramIntegration(BaseIntegration): class TelegramIntegration(BaseIntegration):
@ -36,19 +42,12 @@ class TelegramIntegration(BaseIntegration):
"unknown": "document", "unknown": "document",
} }
def __init__( def __init__(self, sect: SectionProxy):
self, self.token = sect.get("token", "")
token: str, self.chat_id = sect.get("chat", "")
chat_id: Union[str, int], self.show_post_link = sect.getboolean("show_post_link", True)
show_post_link: bool = True, self.show_boost_from = sect.getboolean("show_boost_from", True)
show_boost_from: bool = True, self.silent = sect.getboolean("silent", True)
silent: bool = True,
):
self.token = token
self.chat_id = chat_id
self.show_post_link = show_post_link
self.show_boost_from = show_boost_from
self.silent = silent
async def _tg_request(self, method: str, **kwargs) -> TGResponse: async def _tg_request(self, method: str, **kwargs) -> TGResponse:
url = self.API_URL.format(self.token, method) url = self.API_URL.format(self.token, method)
@ -82,7 +81,9 @@ class TelegramIntegration(BaseIntegration):
**{self.MEDIA_MAPPING[media.type]: media.url}, **{self.MEDIA_MAPPING[media.type]: media.url},
) )
async def _post_mediagroup(self, text: str, media: List[Attachment]) -> TGResponse: async def _post_mediagroup(
self, text: str, media: List[Attachment]
) -> TGResponse:
media_list: List[dict] = [] media_list: List[dict] = []
allowed_medias = {"image", "gifv", "video", "audio", "unknown"} allowed_medias = {"image", "gifv", "video", "audio", "unknown"}
for attachment in media: for attachment in media:
@ -136,7 +137,9 @@ class TelegramIntegration(BaseIntegration):
str.join("", map(cls.node_to_text, el.children)), str.join("", map(cls.node_to_text, el.children)),
) )
elif el.name == "p": elif el.name == "p":
return str.join("", map(cls.node_to_text, el.children)) + "\n\n" return (
str.join("", map(cls.node_to_text, el.children)) + "\n\n"
)
elif el.name == "br": elif el.name == "br":
return "\n" return "\n"
return str.join("", map(cls.node_to_text, el.children)) return str.join("", map(cls.node_to_text, el.children))
@ -144,7 +147,9 @@ class TelegramIntegration(BaseIntegration):
async def post(self, status: Status) -> Optional[str]: async def post(self, status: Status) -> Optional[str]:
source = status.reblog or status source = status.reblog or status
text = self.node_to_text(BeautifulSoup(source.content, features="lxml")) text = self.node_to_text(
BeautifulSoup(source.content, features="lxml")
)
text = text.rstrip() text = text.rstrip()
if source.spoiler_text: if source.spoiler_text:
@ -173,12 +178,16 @@ class TelegramIntegration(BaseIntegration):
elif len(source.media_attachments) == 1: elif len(source.media_attachments) == 1:
if ( if (
res := await self._post_media(text, source.media_attachments[0]) res := await self._post_media(
text, source.media_attachments[0]
)
).ok and res.result is not None: ).ok and res.result is not None:
ids.append(res.result["message_id"]) ids.append(res.result["message_id"])
else: else:
if ( if (
res := await self._post_mediagroup(text, source.media_attachments) res := await self._post_mediagroup(
text, source.media_attachments
)
).ok and res.result is not None: ).ok and res.result is not None:
ids.append(res.result["message_id"]) ids.append(res.result["message_id"])
@ -203,5 +212,5 @@ class TelegramIntegration(BaseIntegration):
chat=self.chat_id, chat=self.chat_id,
show_post_link=self.show_post_link, show_post_link=self.show_post_link,
show_boost_from=self.show_boost_from, show_boost_from=self.show_boost_from,
silent=self.silent silent=self.silent,
) )

View File

@ -15,7 +15,7 @@ async def websocket_source(
while True: while True:
try: try:
async with connect(url) as ws: async with connect(url) as ws:
while (msg := await ws.recv()) != None: while (msg := await ws.recv()) is not None:
event = loads(msg) event = loads(msg)
if "error" in event: if "error" in event:
raise Exception(event["error"]) raise Exception(event["error"])

View File

@ -1,6 +1,25 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Optional, List, Literal from typing import Any, Callable, Optional, List, Literal, TypeVar
def _date(val: str) -> datetime:
return datetime.fromisoformat(val.rstrip("Z"))
T = TypeVar("T")
def _fnil(fn: Callable[[Any], T], val: Optional[Any]) -> Optional[T]:
return None if val is None else fn(val)
def _date_or_none(val: Optional[str]) -> Optional[datetime]:
return _fnil(_date, val)
def _int_or_none(val: Optional[str]) -> Optional[int]:
return _fnil(int, val)
@dataclass @dataclass
@ -14,11 +33,7 @@ class Field:
return cls( return cls(
name=data["name"], name=data["name"],
value=data["value"], value=data["value"],
verified_at=( verified_at=_date_or_none(data.get("verified_at")),
datetime.fromisoformat(data["verified_at"].rstrip("Z"))
if data.get("verified_at") is not None
else None
),
) )
@ -75,16 +90,12 @@ class Account:
locked=data["locked"], locked=data["locked"],
emojis=list(map(Emoji.from_dict, data["emojis"])), emojis=list(map(Emoji.from_dict, data["emojis"])),
discoverable=data["discoverable"], discoverable=data["discoverable"],
created_at=datetime.fromisoformat(data["created_at"].rstrip("Z")), created_at=_date(data["created_at"]),
last_status_at=datetime.fromisoformat(data["last_status_at"].rstrip("Z")), last_status_at=_date(data["last_status_at"]),
statuses_count=data["statuses_count"], statuses_count=data["statuses_count"],
followers_count=data["followers_count"], followers_count=data["followers_count"],
following_count=data["following_count"], following_count=data["following_count"],
moved=( moved=_fnil(Account.from_dict, data.get("moved")),
Account.from_dict(data["moved"])
if data.get("moved") is not None
else None
),
fields=list(map(Field.from_dict, data.get("fields", []))), fields=list(map(Field.from_dict, data.get("fields", []))),
bot=bool(data.get("bot")), bot=bool(data.get("bot")),
) )
@ -228,19 +239,11 @@ class Poll:
def from_dict(cls, data: dict) -> "Poll": def from_dict(cls, data: dict) -> "Poll":
return cls( return cls(
id=data["id"], id=data["id"],
expires_at=( expires_at=_date_or_none(data.get("expires_at")),
datetime.fromisoformat(data["expires_at"].rstrip("Z"))
if data.get("expires_at") is not None
else None
),
expired=data["expired"], expired=data["expired"],
multiple=data["multiple"], multiple=data["multiple"],
votes_count=data["votes_count"], votes_count=data["votes_count"],
voters_count=( voters_count=_int_or_none(data.get("voters_count")),
int(data["voters_count"])
if data.get("voters_count") is not None
else None
),
options=[cls.PollOption(**opt) for opt in data["options"]], options=[cls.PollOption(**opt) for opt in data["options"]],
) )
@ -274,7 +277,7 @@ class Status:
return cls( return cls(
id=data["id"], id=data["id"],
uri=data["uri"], uri=data["uri"],
created_at=datetime.fromisoformat(data["created_at"].rstrip("Z")), created_at=_date(data["created_at"]),
account=Account.from_dict(data["account"]), account=Account.from_dict(data["account"]),
content=data["content"], content=data["content"],
visibility=data["visibility"], visibility=data["visibility"],
@ -283,25 +286,15 @@ class Status:
media_attachments=list( media_attachments=list(
map(Attachment.from_dict, data["media_attachments"]) map(Attachment.from_dict, data["media_attachments"])
), ),
application=( application=_fnil(Application.from_dict, data.get("application")),
Application.from_dict(data["application"])
if data.get("application") is not None
else None
),
reblogs_count=data["reblogs_count"], reblogs_count=data["reblogs_count"],
favourites_count=data["favourites_count"], favourites_count=data["favourites_count"],
replies_count=data["replies_count"], replies_count=data["replies_count"],
url=data.get("url"), url=data.get("url"),
in_reply_to_id=data.get("in_reply_to_id"), in_reply_to_id=data.get("in_reply_to_id"),
in_reply_to_account_id=data.get("in_reply_to_account_id"), in_reply_to_account_id=data.get("in_reply_to_account_id"),
reblog=( reblog=_fnil(Status.from_dict, data.get("reblog")),
Status.from_dict(data["reblog"]) poll=_fnil(Poll.from_dict, data.get("poll")),
if data.get("reblog") is not None
else None
),
poll=(
Poll.from_dict(data["poll"]) if data.get("poll") is not None else None
),
card=data.get("card"), card=data.get("card"),
language=data.get("language"), language=data.get("language"),
text=data.get("text"), text=data.get("text"),