From ccf44a53757485de4f8242d105671047ead5e913 Mon Sep 17 00:00:00 2001 From: Flori <39242991+bitfl0wer@users.noreply.github.com> Date: Fri, 21 Jul 2023 15:35:31 +0200 Subject: [PATCH] Auto updating structs (#163) * Gateway fields don't need to be pub * Add store to Gateway * Add UpdateMessage trait * Proof of concept: impl UpdateMessage for Channel * Start working on auto updating structs * Send entity updates over watch channel * Add id to UpdateMessage * Create trait Updateable * Add documentation * add gateway test * Complete test * Impl UpdateMessage::update() for ChannelUpdate * Impl UpdateMessage::update() for ChannelUpdate Co-authored by: SpecificProtagonist * make channel::modify no longer mutate Channel * change modify call * remove unused imports * Allow dead code with TODO to remove it * fix channel::modify test * Update src/gateway.rs Co-authored-by: SpecificProtagonist --------- Co-authored-by: SpecificProtagonist --- src/api/channels/channels.rs | 8 ++-- src/gateway.rs | 73 ++++++++++++++++++++++++++++------- src/types/entities/channel.rs | 7 ++++ src/types/events/channel.rs | 11 ++++++ src/types/events/mod.rs | 24 ++++++++++++ tests/channels.rs | 7 ++-- tests/gateway.rs | 26 ++++++++++++- 7 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/api/channels/channels.rs b/src/api/channels/channels.rs index de23b4a..df8e290 100644 --- a/src/api/channels/channels.rs +++ b/src/api/channels/channels.rs @@ -64,11 +64,11 @@ impl Channel { /// /// A `Result` that contains a `Channel` object if the request was successful, or an `ChorusLibError` if an error occurred during the request. pub async fn modify( - &mut self, + &self, modify_data: ChannelModifySchema, channel_id: Snowflake, user: &mut UserMeta, - ) -> ChorusResult<()> { + ) -> ChorusResult { let chorus_request = ChorusRequest { request: Client::new() .patch(format!( @@ -80,9 +80,7 @@ impl Channel { .body(to_string(&modify_data).unwrap()), limit_type: LimitType::Channel(channel_id), }; - let new_channel = chorus_request.deserialize_response::(user).await?; - let _ = std::mem::replace(self, new_channel); - Ok(()) + chorus_request.deserialize_response::(user).await } pub async fn messages( diff --git a/src/gateway.rs b/src/gateway.rs index 9e4eb9c..a9a3749 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,10 +1,14 @@ use crate::errors::GatewayError; use crate::gateway::events::Events; -use crate::types; -use crate::types::WebSocketEvent; +use crate::types::{self, Channel, ChannelUpdate, Snowflake}; +use crate::types::{UpdateMessage, WebSocketEvent}; use async_trait::async_trait; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; use std::sync::Arc; use std::time::Duration; +use tokio::sync::watch; use tokio::time::sleep_until; use futures_util::stream::SplitSink; @@ -163,6 +167,12 @@ pub struct GatewayHandle { pub handle: JoinHandle<()>, /// Tells gateway tasks to close kill_send: tokio::sync::broadcast::Sender<()>, + store: Arc>>>, +} + +/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. +pub trait Updateable: 'static + Send + Sync { + fn id(&self) -> Snowflake; } impl GatewayHandle { @@ -186,6 +196,27 @@ impl GatewayHandle { .unwrap(); } + pub async fn observe(&self, object: T) -> watch::Receiver { + let mut store = self.store.lock().await; + if let Some(channel) = store.get(&object.id()) { + let (_, rx) = channel + .downcast_ref::<(watch::Sender, watch::Receiver)>() + .unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + object.id() + ) + }); + rx.clone() + } else { + let id = object.id(); + let channel = watch::channel(object); + let receiver = channel.1.clone(); + store.insert(id, Box::new(channel)); + receiver + } + } + /// Sends an identify event to the gateway pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { let to_send_value = serde_json::to_value(&to_send).unwrap(); @@ -263,9 +294,9 @@ impl GatewayHandle { } pub struct Gateway { - pub events: Arc>, + events: Arc>, heartbeat_handler: HeartbeatHandler, - pub websocket_send: Arc< + websocket_send: Arc< Mutex< SplitSink< WebSocketStream>, @@ -273,8 +304,9 @@ pub struct Gateway { >, >, >, - pub websocket_receive: SplitStream>>, + websocket_receive: SplitStream>>, kill_send: tokio::sync::broadcast::Sender<()>, + store: Arc>>>, } impl Gateway { @@ -325,6 +357,8 @@ impl Gateway { let events = Events::default(); let shared_events = Arc::new(Mutex::new(events)); + let store = Arc::new(Mutex::new(HashMap::new())); + let mut gateway = Gateway { events: shared_events.clone(), heartbeat_handler: HeartbeatHandler::new( @@ -335,6 +369,7 @@ impl Gateway { websocket_send: shared_websocket_send.clone(), websocket_receive, kill_send: kill_send.clone(), + store: store.clone(), }; // Now we can continuously check for messages in a different task, since we aren't going to receive another hello @@ -348,6 +383,7 @@ impl Gateway { websocket_send: shared_websocket_send.clone(), handle, kill_send: kill_send.clone(), + store, }) } @@ -379,6 +415,7 @@ impl Gateway { /// Deserializes and updates a dispatched event, when we already know its type; /// (Called for every event in handle_message) + #[allow(dead_code)] // TODO: Remove this allow annotation async fn handle_event<'a, T: WebSocketEvent + serde::Deserialize<'a>>( data: &'a str, event: &mut GatewayEvent, @@ -431,17 +468,25 @@ impl Gateway { trace!("Gateway: Received {event_name}"); macro_rules! handle { - ($($name:literal => $($path:ident).+),*) => { + ($($name:literal => $($path:ident).+ $( $message_type:ty: $update_type:ty)?),*) => { match event_name.as_str() { $($name => { let event = &mut self.events.lock().await.$($path).+; - - let result = - Gateway::handle_event(gateway_payload.event_data.unwrap().get(), event) - .await; - - if let Err(err) = result { - warn!("Failed to parse gateway event {event_name} ({err})"); + match serde_json::from_str(gateway_payload.event_data.unwrap().get()) { + Err(err) => warn!("Failed to parse gateway event {event_name} ({err})"), + Ok(message) => { + $( + let message: $message_type = message; + if let Some(to_update) = self.store.lock().await.get(&message.id()) { + if let Some((tx, _)) = to_update.downcast_ref::<(watch::Sender<$update_type>, watch::Receiver<$update_type>)>() { + tx.send_modify(|object| message.update(object)); + } else { + warn!("Received {} for {}, but it has been observed to be a different type!", $name, message.id()) + } + } + )? + event.notify(message).await; + } } },)* "RESUMED" => (), @@ -482,7 +527,7 @@ impl Gateway { "AUTO_MODERATION_RULE_DELETE" => auto_moderation.rule_delete, "AUTO_MODERATION_ACTION_EXECUTION" => auto_moderation.action_execution, "CHANNEL_CREATE" => channel.create, - "CHANNEL_UPDATE" => channel.update, + "CHANNEL_UPDATE" => channel.update ChannelUpdate: Channel, "CHANNEL_UNREAD_UPDATE" => channel.unread_update, "CHANNEL_DELETE" => channel.delete, "CHANNEL_PINS_UPDATE" => channel.pins_update, diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 01e9752..aac57b6 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; use serde_aux::prelude::deserialize_string_from_number; use serde_repr::{Deserialize_repr, Serialize_repr}; +use crate::gateway::Updateable; use crate::types::{ entities::{GuildMember, User}, utils::Snowflake, @@ -65,6 +66,12 @@ pub struct Channel { pub video_quality_mode: Option, } +impl Updateable for Channel { + fn id(&self) -> Snowflake { + self.id + } +} + #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct Tag { pub id: Snowflake, diff --git a/src/types/events/channel.rs b/src/types/events/channel.rs index 99c7640..b595d57 100644 --- a/src/types/events/channel.rs +++ b/src/types/events/channel.rs @@ -3,6 +3,8 @@ use crate::types::{entities::Channel, Snowflake}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use super::UpdateMessage; + #[derive(Debug, Default, Deserialize, Serialize)] /// See https://discord.com/developers/docs/topics/gateway-events#channel-pins-update pub struct ChannelPinsUpdate { @@ -31,6 +33,15 @@ pub struct ChannelUpdate { impl WebSocketEvent for ChannelUpdate {} +impl UpdateMessage for ChannelUpdate { + fn update(&self, object_to_update: &mut Channel) { + *object_to_update = self.channel.clone(); + } + fn id(&self) -> Snowflake { + self.channel.id + } +} + #[derive(Debug, Default, Deserialize, Serialize, Clone)] /// Officially undocumented. /// Sends updates to client about a new message with its id diff --git a/src/types/events/mod.rs b/src/types/events/mod.rs index 6333544..42e3912 100644 --- a/src/types/events/mod.rs +++ b/src/types/events/mod.rs @@ -26,6 +26,10 @@ pub use user::*; pub use voice::*; pub use webhooks::*; +use crate::gateway::Updateable; + +use super::Snowflake; + mod application; mod auto_moderation; mod call; @@ -95,3 +99,23 @@ pub struct GatewayReceivePayload<'a> { } impl<'a> WebSocketEvent for GatewayReceivePayload<'a> {} + +/// An [`UpdateMessage`] represents a received Gateway Message which contains updated +/// information for an [`Updateable`] of Type T. +/// # Example: +/// ```rs +/// impl UpdateMessage for ChannelUpdate { +/// fn update(...) {...} +/// fn id(...) {...} +/// } +/// ``` +/// This would imply, that the [`WebSocketEvent`] "[`ChannelUpdate`]" contains new/updated information +/// about a [`Channel`]. The update method describes how this new information will be turned into +/// a [`Channel`] object. +pub(crate) trait UpdateMessage: Clone +where + T: Updateable, +{ + fn update(&self, object_to_update: &mut T); + fn id(&self) -> Snowflake; +} diff --git a/tests/channels.rs b/tests/channels.rs index e2838be..c8564d7 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -28,10 +28,11 @@ async fn delete_channel() { #[tokio::test] async fn modify_channel() { + const CHANNEL_NAME: &str = "beepboop"; let mut bundle = common::setup().await; let channel = &mut bundle.channel; let modify_data: types::ChannelModifySchema = types::ChannelModifySchema { - name: Some("beepboop".to_string()), + name: Some(CHANNEL_NAME.to_string()), channel_type: None, topic: None, icon: None, @@ -49,10 +50,10 @@ async fn modify_channel() { default_thread_rate_limit_per_user: None, video_quality_mode: None, }; - Channel::modify(channel, modify_data, channel.id, &mut bundle.user) + let modified_channel = Channel::modify(channel, modify_data, channel.id, &mut bundle.user) .await .unwrap(); - assert_eq!(channel.name, Some("beepboop".to_string())); + assert_eq!(modified_channel.name, Some(CHANNEL_NAME.to_string())); let permission_override = PermissionFlags::from_vec(Vec::from([ PermissionFlags::MANAGE_CHANNELS, diff --git a/tests/gateway.rs b/tests/gateway.rs index c6f46dd..199e14b 100644 --- a/tests/gateway.rs +++ b/tests/gateway.rs @@ -1,6 +1,7 @@ mod common; + use chorus::gateway::*; -use chorus::types; +use chorus::types::{self, Channel}; #[tokio::test] /// Tests establishing a connection (hello and heartbeats) on the local gateway; @@ -22,3 +23,26 @@ async fn test_gateway_authenticate() { gateway.send_identify(identify).await; } + +#[tokio::test] +async fn test_self_updating_structs() { + let mut bundle = common::setup().await; + let gateway = Gateway::new(bundle.urls.wss).await.unwrap(); + let mut identify = types::GatewayIdentifyPayload::common(); + identify.token = bundle.user.token.clone(); + gateway.send_identify(identify).await; + let channel_receiver = gateway.observe(bundle.channel.clone()).await; + let received_channel = channel_receiver.borrow(); + assert_eq!(*received_channel, bundle.channel); + drop(received_channel); + let channel = &mut bundle.channel; + let modify_data = types::ChannelModifySchema { + name: Some("beepboop".to_string()), + ..Default::default() + }; + Channel::modify(channel, modify_data, channel.id, &mut bundle.user) + .await + .unwrap(); + let received_channel = channel_receiver.borrow(); + assert_eq!(received_channel.name.as_ref().unwrap(), "beepboop"); +}