diff --git a/src/commands/level.rs b/src/commands/level.rs index 83ccc2a..0727e79 100644 --- a/src/commands/level.rs +++ b/src/commands/level.rs @@ -43,6 +43,50 @@ impl TypeMapKey for DbKey { type Value = Surreal; } +/// Helper struct for deserializing u64 from numbers or strings +#[derive(Debug, Clone, Copy, serde::Serialize)] +pub struct FluxU64(pub u64); + +impl<'de> Deserialize<'de> for FluxU64 { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct FluxU64Visitor; + + impl<'de> serde::de::Visitor<'de> for FluxU64Visitor { + type Value = FluxU64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a u64, i64, or string representing a u64") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(FluxU64(value)) + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(FluxU64(value as u64)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + value.parse::().map(FluxU64).map_err(serde::de::Error::custom) + } + } + + deserializer.deserialize_any(FluxU64Visitor) + } +} + #[derive(Deserialize, serde::Serialize, Debug)] pub struct UserLevel { pub xp: u64, @@ -105,7 +149,8 @@ impl<'de> Deserialize<'de> for TrackLevelRole { if role_id.is_some() { return Err(serde::de::Error::duplicate_field("role_id")); } - role_id = Some(map.next_value()?); + let val: FluxU64 = map.next_value()?; + role_id = Some(val.0); } "level" => { if level.is_some() { @@ -133,9 +178,9 @@ impl<'de> Deserialize<'de> for TrackLevelRole { #[derive(Deserialize)] struct GuildRecord { level_role_stack: Option>>, - levelup_channel: Option, + levelup_channel: Option, levelup_message: Option, - level_up_role_mapper: Option>, + level_up_role_mapper: Option>, } #[derive(Deserialize, Debug)] @@ -330,8 +375,8 @@ pub async fn process_message( // Determine target channel let target_channel_id = if let Some(record) = &guild_record { - if let Some(channel_id) = record.levelup_channel { - serenity::ChannelId::new(channel_id) + if let Some(channel_id) = &record.levelup_channel { + serenity::ChannelId::new(channel_id.0) } else { msg.channel_id } @@ -820,7 +865,7 @@ pub async fn on_guild_member_update( for (_, other_out_role_id) in &mapper { if new .roles - .contains(&serenity::RoleId::new(*other_out_role_id)) + .contains(&serenity::RoleId::new(other_out_role_id.0)) { has_any_out_role = true; break; @@ -842,7 +887,7 @@ pub async fn on_guild_member_update( .add_member_role( guild_id, new.user.id, - serenity::RoleId::new(*out_role_id), + serenity::RoleId::new(out_role_id.0), Some("Role Bridge"), ) .await @@ -891,3 +936,42 @@ pub async fn on_guild_member_update( Ok(()) } + +// Add tests to verify fix +#[cfg(test)] +mod tests { + use super::*; + use serde_json::from_str; + + #[test] + fn test_deserialize_flux_u64() { + let json_num = "12345"; + let val_num: FluxU64 = from_str(json_num).unwrap(); + assert_eq!(val_num.0, 12345); + + let json_str = "\"12345\""; + let val_str: FluxU64 = from_str(json_str).unwrap(); + assert_eq!(val_str.0, 12345); + } + + #[test] + fn test_deserialize_track_level_role() { + let json = r#"{"role_id": "123456789", "level": 5}"#; + let role: TrackLevelRole = from_str(json).expect("Should deserialize with string ID"); + assert_eq!(role.role_id, 123456789); + assert_eq!(role.level, 5); + } + + #[test] + fn test_deserialize_guild_record() { + let json = r#"{ + "levelup_channel": "987654321", + "level_up_role_mapper": { + "track1": "111222333" + } + }"#; + let record: GuildRecord = from_str(json).expect("Should deserialize GuildRecord"); + assert_eq!(record.levelup_channel.unwrap().0, 987654321); + assert_eq!(record.level_up_role_mapper.unwrap().get("track1").unwrap().0, 111222333); + } +}