added optimization for leaderboard

This commit is contained in:
2026-01-01 04:13:55 +05:30
parent 0cb8a279f6
commit d7e28b6d91
7 changed files with 967 additions and 67 deletions

355
src/api.rs Normal file
View File

@@ -0,0 +1,355 @@
use actix_web::{web, App, HttpServer, HttpResponse, post, get, error, body::EitherBody, dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}};
use futures::future::LocalBoxFuture;
use futures::FutureExt;
use serde::{Deserialize, Serialize};
use serenity::cache::Cache;
use std::sync::Arc;
use surrealdb::Surreal;
use surrealdb::engine::remote::ws::Client;
use tracing::{info, warn};
#[derive(Deserialize)]
pub struct IsBotThereRequest {
pub guild_ids: Vec<u64>,
}
#[derive(Serialize)]
pub struct IsBotThereResponse {
pub results: Vec<bool>,
}
pub struct ApiState {
pub cache: Arc<Cache>,
pub db: Surreal<Client>,
pub api_key: String,
}
#[derive(Serialize)]
pub struct LeaderboardMember {
pub user_id: String,
pub username: String,
pub avatar: String,
pub level: u64,
pub xp: u64,
pub rank: usize,
}
#[derive(Serialize)]
pub struct Role {
pub role_id: String,
pub role_name: String,
}
#[derive(Serialize)]
pub struct LevelRole {
pub id: String,
pub name: String,
}
#[derive(Deserialize)]
struct GuildRecord {
level_role_stack: Option<std::collections::HashMap<String, Vec<u64>>>,
}
#[derive(Deserialize, Debug)]
struct LeaderboardEntry {
id: surrealdb::sql::Thing,
xp: u64,
level: u64,
}
/// API Key Authentication Middleware
pub struct ApiKeyMiddleware {
api_key: String,
}
impl ApiKeyMiddleware {
pub fn new(api_key: String) -> Self {
ApiKeyMiddleware { api_key }
}
}
impl<S, B> Transform<S, ServiceRequest> for ApiKeyMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = error::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = error::Error;
type InitError = ();
type Transform = ApiKeyMiddlewareService<S>;
type Future = futures::future::Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
futures::future::ok(ApiKeyMiddlewareService {
service,
api_key: self.api_key.clone(),
})
}
}
pub struct ApiKeyMiddlewareService<S> {
service: S,
api_key: String,
}
impl<S, B> Service<ServiceRequest> for ApiKeyMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = error::Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = error::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let api_key = self.api_key.clone();
let path = req.path().to_string();
// Check API key from header
let header_key = req
.headers()
.get("X-API-Key")
.and_then(|h| h.to_str().ok())
.unwrap_or("")
.to_string();
if header_key.is_empty() || header_key != api_key {
warn!("Unauthorized API request to {} - missing or invalid API key", path);
return Box::pin(async move {
Err(error::ErrorUnauthorized("Missing or invalid API key"))
});
}
info!("Authorized API request to {}", path);
Box::pin(
self.service
.call(req)
.then(|res: Result<ServiceResponse<B>, error::Error>| async move {
Ok(res?.map_into_left_body())
}),
)
}
}
#[post("/api/is_bot_there")]
async fn is_bot_there(
request_body: web::Json<IsBotThereRequest>,
data: web::Data<ApiState>,
) -> Result<HttpResponse, actix_web::Error> {
info!(
"Processing /api/is_bot_there request with {} guild IDs",
request_body.guild_ids.len()
);
let results: Vec<bool> = request_body
.guild_ids
.iter()
.map(|guild_id| {
let guild_exists = data.cache.guild(*guild_id).is_some();
guild_exists
})
.collect();
let found_count = results.iter().filter(|&&b| b).count();
info!(
"Bot found in {}/{} requested guilds",
found_count,
results.len()
);
Ok(HttpResponse::Ok().json(IsBotThereResponse { results }))
}
#[get("/api/{guild_id}/leaderboard")]
async fn get_leaderboard(
guild_id: web::Path<u64>,
data: web::Data<ApiState>,
) -> Result<HttpResponse, actix_web::Error> {
let guild_id_value = guild_id.into_inner();
info!("Processing /api/{}/leaderboard request", guild_id_value);
// Check if bot is in the guild
if data.cache.guild(guild_id_value).is_none() {
warn!("Bot is not in guild {}", guild_id_value);
return Err(error::ErrorNotFound("Bot is not in this guild"));
}
// Query all users for this guild, ordered by level and xp
let sql = "SELECT * FROM levels WHERE string::starts_with(record::id(id), $prefix) ORDER BY level DESC, xp DESC";
let prefix = format!("{}:", guild_id_value);
let mut response = data.db.query(sql).bind(("prefix", prefix)).await
.map_err(|e| {
warn!("Database query error: {}", e);
error::ErrorInternalServerError("Database query failed")
})?;
let entries: Vec<LeaderboardEntry> = response.take(0)
.map_err(|e| {
warn!("Failed to parse database response: {}", e);
error::ErrorInternalServerError("Failed to parse database response")
})?;
if entries.is_empty() {
info!("No leaderboard data found for guild {}", guild_id_value);
return Ok(HttpResponse::Ok().json(Vec::<LeaderboardMember>::new()));
}
info!("Found {} members in leaderboard for guild {}", entries.len(), guild_id_value);
// Fetch user data for all entries
let mut leaderboard: Vec<LeaderboardMember> = Vec::new();
for (i, entry) in entries.iter().enumerate() {
// Extract user id from Surreal Thing
let id_value = &entry.id.id;
let clean_id_str = match id_value {
surrealdb::sql::Id::String(s) => s.as_str().to_string(),
_ => entry.id.id.to_string(),
};
let parts: Vec<&str> = clean_id_str.split(':').collect();
let user_id_str = parts.last().unwrap_or(&"0");
let user_id_u64 = user_id_str.parse::<u64>().unwrap_or(0);
if user_id_u64 == 0 {
continue;
}
let mut username = String::from("Unknown User");
let mut avatar_url = String::new();
let user_id = serenity::all::UserId::new(user_id_u64);
// Try to get member info from cache first
if let Some(guild) = data.cache.guild(guild_id_value) {
if let Some(member) = guild.members.get(&user_id) {
username = member.display_name().to_string();
avatar_url = member.user.face();
}
}
// If not in cache, we'll just use Unknown User
if avatar_url.is_empty() {
info!("User {} not in cache for guild {}", user_id_u64, guild_id_value);
}
leaderboard.push(LeaderboardMember {
user_id: user_id_u64.to_string(),
username,
avatar: avatar_url,
level: entry.level,
xp: entry.xp,
rank: i + 1,
});
}
info!("Returning {} members in leaderboard", leaderboard.len());
Ok(HttpResponse::Ok().json(leaderboard))
}
#[get("/api/{guild_id}/roles")]
async fn get_roles(
guild_id: web::Path<u64>,
data: web::Data<ApiState>,
) -> Result<HttpResponse, actix_web::Error> {
let guild_id_value = guild_id.into_inner();
info!("Processing /api/{}/roles request", guild_id_value);
// Check if bot is in the guild
let guild = data.cache.guild(guild_id_value)
.ok_or_else(|| {
warn!("Bot is not in guild {}", guild_id_value);
error::ErrorNotFound("Bot is not in this guild")
})?;
// Get all roles from the guild
let roles: Vec<Role> = guild.roles
.iter()
.map(|(role_id, role)| Role {
role_id: role_id.to_string(),
role_name: role.name.clone(),
})
.collect();
info!("Returning {} roles for guild {}", roles.len(), guild_id_value);
Ok(HttpResponse::Ok().json(roles))
}
#[get("/api/{guild_id}/level/track")]
async fn get_level_tracks(
guild_id: web::Path<u64>,
data: web::Data<ApiState>,
) -> Result<HttpResponse, actix_web::Error> {
let guild_id_value = guild_id.into_inner();
info!("Processing /api/{}/level/track request", guild_id_value);
// Check if bot is in the guild
let guild = data.cache.guild(guild_id_value)
.ok_or_else(|| {
warn!("Bot is not in guild {}", guild_id_value);
error::ErrorNotFound("Bot is not in this guild")
})?;
// Query the guild record from database
let record: Option<GuildRecord> = data.db.select(("guilds", guild_id_value.to_string())).await
.map_err(|e| {
warn!("Database query error: {}", e);
error::ErrorInternalServerError("Database query failed")
})?;
let mut result: std::collections::HashMap<String, Vec<LevelRole>> = std::collections::HashMap::new();
if let Some(record) = record {
if let Some(level_role_stack) = record.level_role_stack {
for (track_name, role_ids) in level_role_stack {
let mut roles: Vec<LevelRole> = Vec::new();
for role_id in role_ids {
let role_id_obj = serenity::all::RoleId::new(role_id);
let role_name = guild.roles.get(&role_id_obj)
.map(|r| r.name.clone())
.unwrap_or_else(|| "Unknown Role".to_string());
roles.push(LevelRole {
id: role_id.to_string(),
name: role_name,
});
}
result.insert(track_name, roles);
}
}
}
info!("Returning {} level tracks for guild {}", result.len(), guild_id_value);
Ok(HttpResponse::Ok().json(result))
}
pub async fn start_api_server(
cache: Arc<Cache>,
db: Surreal<Client>,
api_key: String,
port: u16,
) -> std::io::Result<()> {
info!("Starting API server on port {}", port);
let state = web::Data::new(ApiState { cache, db, api_key: api_key.clone() });
HttpServer::new(move || {
App::new()
.app_data(state.clone())
.wrap(ApiKeyMiddleware::new(api_key.clone()))
.service(is_bot_there)
.service(get_leaderboard)
.service(get_roles)
.service(get_level_tracks)
})
.bind(("0.0.0.0", port))?
.run()
.await
}

View File

@@ -1,18 +1,42 @@
use crate::{Context, Error};
use ab_glyph::{FontRef, PxScale};
use image::{ImageBuffer, Rgba};
use imageproc::drawing::{draw_filled_rect_mut, draw_text_mut};
use image::{ImageBuffer, Rgba, DynamicImage};
use imageproc::drawing::{draw_filled_rect_mut, draw_filled_circle_mut, draw_text_mut};
use imageproc::rect::Rect;
use poise::serenity_prelude as serenity;
use serde::Deserialize;
use serenity::Mentionable;
use std::collections::HashMap;
use std::io::Cursor;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use serenity::prelude::TypeMapKey;
use surrealdb::Surreal;
use surrealdb::engine::remote::ws::Client;
use futures::future::join_all;
use tokio::sync::RwLock;
/// Global avatar cache - stores decoded images by URL
/// TTL is handled by periodic cleanup or could use moka crate
static AVATAR_CACHE: OnceLock<Arc<RwLock<HashMap<String, Arc<DynamicImage>>>>> = OnceLock::new();
fn get_avatar_cache() -> &'static Arc<RwLock<HashMap<String, Arc<DynamicImage>>>> {
AVATAR_CACHE.get_or_init(|| Arc::new(RwLock::new(HashMap::new())))
}
/// Global HTTP client for avatar fetches - connection pooling
static AVATAR_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
fn get_avatar_client() -> &'static reqwest::Client {
AVATAR_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(Duration::from_millis(800))
.pool_max_idle_per_host(10)
.build()
.unwrap_or_default()
})
}
pub struct DbKey;
@@ -516,7 +540,9 @@ struct LeaderboardRenderEntry {
level: u64,
xp: u64,
next_level_xp: u64,
avatar: Option<image::DynamicImage>,
avatar: Option<Arc<DynamicImage>>,
/// Used to generate a colored placeholder if avatar is missing
user_id: u64,
}
#[poise::command(slash_command, prefix_command, guild_only)]
@@ -542,13 +568,22 @@ pub async fn leaderboard(ctx: Context<'_>) -> Result<(), Error> {
}
// 1. Fetch all user data and avatars in parallel
let avatar_client = get_avatar_client();
let avatar_cache = get_avatar_cache();
let mut tasks = Vec::new();
// Try to get guild from cache first (much faster than API calls)
let cached_guild = guild_id.to_guild_cached(&ctx.serenity_context().cache).map(|g| g.clone());
for (i, entry) in entries.iter().enumerate() {
let ctx = ctx.clone(); // Clone for async
let http = ctx.serenity_context().http.clone();
let entry_level = entry.level;
let entry_xp = entry.xp;
let guild_id = guild_id; // capture
let guild_id = guild_id;
let avatar_client = avatar_client.clone();
let avatar_cache = avatar_cache.clone();
let cached_guild = cached_guild.clone();
// Extract user id robustly from Surreal `Thing`
let id_value = &entry.id.id;
@@ -567,38 +602,78 @@ pub async fn leaderboard(ctx: Context<'_>) -> Result<(), Error> {
if user_id_u64 != 0 {
let user_id = serenity::UserId::new(user_id_u64);
// Prefer guild display name (nickname) when available
if let Ok(member) = ctx.http().get_member(guild_id, user_id).await {
// Try cache first (instant), then fall back to API
let member_from_cache = cached_guild.as_ref()
.and_then(|g| g.members.get(&user_id).cloned());
if let Some(member) = member_from_cache {
user_name = member.display_name().to_string();
let face = member.user.face();
avatar_url = normalize_avatar_url(&face);
} else {
// Fallback to user object
if let Ok(user) = ctx.http().get_user(user_id).await {
user_name = user
.global_name
.as_ref()
.unwrap_or(&user.name)
.to_string();
let face = user.face();
avatar_url = normalize_avatar_url(&face);
}
avatar_url = get_small_avatar_url(&member.user);
} else if let Ok(member) = http.get_member(guild_id, user_id).await {
user_name = member.display_name().to_string();
avatar_url = get_small_avatar_url(&member.user);
} else if let Ok(user) = http.get_user(user_id).await {
user_name = user.global_name.as_ref().unwrap_or(&user.name).to_string();
avatar_url = get_small_avatar_url(&user);
}
} else {
// keep defaults
}
// Fetch avatar image if we have a URL
let mut avatar_img = None;
if !avatar_url.is_empty() {
if let Ok(response) = reqwest::get(&avatar_url).await {
if let Ok(bytes) = response.bytes().await {
if let Ok(img) = image::load_from_memory(&bytes) {
avatar_img = Some(img);
// Check avatar cache first
let avatar_img: Option<Arc<DynamicImage>> = if !avatar_url.is_empty() {
// Fast path: check cache
{
let cache = avatar_cache.read().await;
if let Some(cached) = cache.get(&avatar_url) {
Some(cached.clone())
} else {
None
}
}.or_else(|| None).map(Some).unwrap_or_else(|| {
// Cache miss - need to fetch (will be done below)
None
})
} else {
None
};
// If not in cache, fetch with short timeout
let avatar_img = if avatar_img.is_some() {
avatar_img
} else if !avatar_url.is_empty() {
// Use tokio timeout for precise control
let fetch_result = tokio::time::timeout(
Duration::from_millis(500),
async {
match avatar_client.get(&avatar_url).send().await {
Ok(response) => response.bytes().await.ok(),
Err(_) => None,
}
}
}
).await;
match fetch_result {
Ok(Some(bytes)) => {
if let Ok(img) = image::load_from_memory(&bytes) {
let arc_img = Arc::new(img);
// Store in cache for next time
{
let mut cache = avatar_cache.write().await;
// Limit cache size to prevent memory issues
if cache.len() < 500 {
cache.insert(avatar_url, arc_img.clone());
}
}
Some(arc_img)
} else {
None
}
}
_ => None,
}
}
} else {
None
};
LeaderboardRenderEntry {
username: user_name,
@@ -607,6 +682,7 @@ pub async fn leaderboard(ctx: Context<'_>) -> Result<(), Error> {
xp: entry_xp,
next_level_xp: (entry_level + 1) * 100,
avatar: avatar_img,
user_id: user_id_u64,
}
});
}
@@ -622,7 +698,7 @@ pub async fn leaderboard(ctx: Context<'_>) -> Result<(), Error> {
ctx.send(
poise::CreateReply::default().attachment(serenity::CreateAttachment::bytes(
image_data,
"leaderboard.png",
"leaderboard.jpg",
)),
)
.await?;
@@ -630,30 +706,54 @@ pub async fn leaderboard(ctx: Context<'_>) -> Result<(), Error> {
Ok(())
}
fn normalize_avatar_url(url: &str) -> String {
if url.is_empty() { return String::new(); }
// Prefer PNG to ensure decoder compatibility; preserve size if present
// Replace extension .webp -> .png, and enforce format=png when query exists
// Simple approach: if it contains ".webp", swap to ".png"; also add "?size=128" if none
let mut out = url.replace(".webp", ".png");
if !out.contains("format=") && out.contains("cdn.discordapp.com") {
if out.contains('?') { out.push_str("&format=png"); } else { out.push_str("?format=png"); }
/// Get a small avatar URL (48px) - webp is faster to download & decode
fn get_small_avatar_url(user: &serenity::User) -> String {
// Use webp format (smaller file size, faster download) and small size
// The image crate has webp support enabled
match &user.avatar {
Some(hash) => {
let ext = if hash.is_animated() { "gif" } else { "webp" };
format!(
"https://cdn.discordapp.com/avatars/{}/{}.{}?size=48",
user.id, hash, ext
)
}
None => {
// Default avatar - use small size
let index = if let Some(discrim) = user.discriminator {
discrim.get() % 5
} else {
((user.id.get() >> 22) % 6) as u16
};
format!(
"https://cdn.discordapp.com/embed/avatars/{}.png?size=48",
index
)
}
}
if !out.contains("size=") {
if out.contains('?') { out.push_str("&size=128"); } else { out.push_str("?size=128"); }
}
out
}
/// Pre-defined colors for avatar placeholders based on user ID
const PLACEHOLDER_COLORS: [Rgba<u8>; 8] = [
Rgba([114, 137, 218, 255]), // Blurple
Rgba([67, 181, 129, 255]), // Green
Rgba([250, 166, 26, 255]), // Yellow
Rgba([240, 71, 71, 255]), // Red
Rgba([255, 115, 250, 255]), // Pink
Rgba([26, 188, 156, 255]), // Teal
Rgba([230, 126, 34, 255]), // Orange
Rgba([155, 89, 182, 255]), // Purple
];
fn generate_leaderboard_image(
entries: Vec<LeaderboardRenderEntry>,
) -> Result<Vec<u8>, Error> {
// Image dimensions
let width = 800;
let width = 800u32;
let height = 100 + (entries.len() as u32 * 80); // Header + rows
let mut image = ImageBuffer::from_pixel(width, height, Rgba([40, 44, 52, 255])); // Dark background
// Load font
// Load font once (compiled into binary)
let font_bytes = include_bytes!("../assets/Roboto-Regular.ttf");
let font = FontRef::try_from_slice(font_bytes)?;
@@ -683,11 +783,33 @@ fn generate_leaderboard_image(
&format!("#{}", entry.rank),
);
// Draw Avatar
// Draw Avatar or colored circle placeholder
let avatar_x = 80i64;
let avatar_y = y_offset as i64 + 16;
if let Some(avatar_img) = &entry.avatar {
let avatar_resized =
avatar_img.resize(60, 60, image::imageops::FilterType::Lanczos3);
image::imageops::overlay(&mut image, &avatar_resized, 80, y_offset as i64 + 10);
// Use Nearest filter - fastest possible, good enough for small avatars
let avatar_resized = avatar_img.resize_exact(48, 48, image::imageops::FilterType::Nearest);
image::imageops::overlay(&mut image, &avatar_resized, avatar_x, avatar_y);
} else {
// Draw colored circle placeholder based on user ID
let color_idx = (entry.user_id % 8) as usize;
let color = PLACEHOLDER_COLORS[color_idx];
let center_x = (avatar_x + 24) as i32;
let center_y = (avatar_y + 24) as i32;
draw_filled_circle_mut(&mut image, (center_x, center_y), 24, color);
// Draw first letter of username
let first_char = entry.username.chars().next().unwrap_or('?').to_uppercase().to_string();
draw_text_mut(
&mut image,
white,
center_x - 8,
center_y - 12,
scale_text,
&font,
&first_char,
);
}
// Draw Username
@@ -713,14 +835,14 @@ fn generate_leaderboard_image(
);
// Draw XP Bar
let bar_width = 300;
let bar_height = 20;
let bar_x = 450;
let bar_y = y_offset + 30;
let bar_width = 300u32;
let bar_height = 20u32;
let bar_x = 450i32;
let bar_y = y_offset as i32 + 30;
draw_filled_rect_mut(
&mut image,
Rect::at(bar_x as i32, bar_y as i32).of_size(bar_width, bar_height),
Rect::at(bar_x, bar_y).of_size(bar_width, bar_height),
bar_bg,
);
@@ -734,7 +856,7 @@ fn generate_leaderboard_image(
if fill_width > 0 {
draw_filled_rect_mut(
&mut image,
Rect::at(bar_x as i32, bar_y as i32).of_size(fill_width, bar_height),
Rect::at(bar_x, bar_y).of_size(fill_width, bar_height),
bar_fill,
);
}
@@ -744,16 +866,20 @@ fn generate_leaderboard_image(
draw_text_mut(
&mut image,
white,
bar_x as i32 + 5,
bar_y as i32 + 2, // Centering vertically roughly
bar_x + 5,
bar_y + 2,
PxScale::from(14.0),
&font,
&xp_text,
);
}
let mut bytes: Vec<u8> = Vec::new();
image.write_to(&mut Cursor::new(&mut bytes), image::ImageFormat::Png)?;
// Use JPEG for much faster encoding (PNG is slow)
// Quality 85 is a good balance of size vs quality
let mut bytes: Vec<u8> = Vec::with_capacity(width as usize * height as usize);
let rgb_image = DynamicImage::ImageRgba8(image).into_rgb8();
let mut encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut bytes, 85);
encoder.encode_image(&rgb_image)?;
Ok(bytes)
}

View File

@@ -35,6 +35,9 @@ impl EventHandler for Handler {
if let Err(e) = crate::commands::utility::process_auto_response(&ctx, &msg).await {
tracing::error!("Error processing message for auto-response: {}", e);
}
if let Err(e) = crate::commands::fun::handle_ai_chat(&ctx, &msg).await {
tracing::error!("Error processing message for AI chat: {}", e);
}
}
async fn guild_member_update(

View File

@@ -1,10 +1,11 @@
mod api;
mod commands;
mod listener;
use ::serenity::all::GatewayIntents;
use ::serenity::all::{GatewayIntents, UserId};
use dotenvy::dotenv;
use poise::{Framework, FrameworkOptions, serenity_prelude as serenity};
use std::{env, sync::Arc};
use std::{collections::HashSet, env, sync::Arc};
use surrealdb::Surreal;
use surrealdb::engine::remote::ws::{Client, Wss};
use surrealdb::opt::auth::Root;
@@ -16,6 +17,7 @@ type Context<'a> = poise::Context<'a, Data, Error>;
struct Data {
db: Surreal<Client>,
ai_chat: Arc<commands::fun::AiChatManager>,
}
#[tokio::main()]
@@ -32,6 +34,8 @@ async fn main() -> Result<(), Error> {
let token = env::var("DISCORD_TOKEN")?;
let api_key = env::var("API_KEY").expect("Expected API_KEY in environment");
let surreal_address =
env::var("SURREAL_ADDRESS").expect("Expected SURREAL_ADDRESS in environment");
let surreal_user = env::var("SURREAL_USER").expect("Expected SURREAL_USER in environment");
@@ -39,6 +43,28 @@ async fn main() -> Result<(), Error> {
let surreal_ns = env::var("SURREAL_NS").expect("Expected SURREAL_NS in environment");
let surreal_db = env::var("SURREAL_DB").expect("Expected SURREAL_DB in environment");
let ollama_url = env::var("OLLAMA_SERVER_URL").expect("Expected OLLAMA_SERVER_URL in environment");
let ollama_model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| "llama3".to_string());
let ai_chat_cooldown_ms = env::var("AI_CHAT_COOLDOWN_MS")
.ok()
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(1500);
let ignore_rude: Vec<String> = env::var("IGNORE_RUDE")
.unwrap_or_default()
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string())
.collect();
let ai_chat_manager = Arc::new(commands::fun::AiChatManager::new(
ollama_url,
ollama_model,
std::time::Duration::from_millis(ai_chat_cooldown_ms),
ignore_rude,
));
let db = Surreal::new::<Wss>(&surreal_address).await?;
db.signin(Root {
@@ -51,8 +77,17 @@ async fn main() -> Result<(), Error> {
db.use_ns(&surreal_ns).use_db(&surreal_db).await?;
let db_clone = db.clone();
let ai_chat_clone = ai_chat_manager.clone();
let owner_id = env::var("BOT_OWNER_ID")
.expect("Expected BOT_OWNER_ID in environment")
.parse::<u64>()?;
let mut owners = HashSet::new();
owners.insert(UserId::new(owner_id));
let framework = Framework::builder()
.options(FrameworkOptions::<Data, Error> {
owners,
commands: vec![
commands::level::set_level_roles(),
commands::level::get_level_roles(),
@@ -62,10 +97,12 @@ async fn main() -> Result<(), Error> {
commands::level::levelup_role_bridger(),
commands::fun::say(),
commands::fun::urban(),
commands::fun::ai_chat(),
commands::utility::auto_response(),
commands::utility::view_auto_responses(),
commands::utility::delete_auto_response(),
commands::utility::edit_auto_response(),
commands::utility::summary(),
],
prefix_options: poise::PrefixFrameworkOptions {
prefix: Some("!".into()),
@@ -80,7 +117,10 @@ async fn main() -> Result<(), Error> {
.setup(move |context, _ready, framework| {
Box::pin(async move {
poise::builtins::register_globally(context, &framework.options().commands).await?;
Ok(Data { db: db_clone })
Ok(Data {
db: db_clone,
ai_chat: ai_chat_clone,
})
})
})
.build();
@@ -94,10 +134,37 @@ async fn main() -> Result<(), Error> {
let mut data = client.data.write().await;
data.insert::<commands::level::DbKey>(db.clone());
data.insert::<commands::utility::DbKey>(db.clone());
data.insert::<commands::fun::AiChatKey>(ai_chat_manager.clone());
}
if let Err(why) = client.start_autosharded().await {
eprintln!("An error occurred while running the client: {why}");
// Get API port from environment, default to 8080
let api_port = env::var("API_PORT")
.ok()
.and_then(|v| v.parse::<u16>().ok())
.unwrap_or(8080);
// Run bot and API server concurrently
let cache = client.cache.clone();
let db_for_api = db.clone();
let bot_task = tokio::spawn(async move {
if let Err(why) = client.start_autosharded().await {
eprintln!("An error occurred while running the client: {why}");
}
});
let api_task = api::start_api_server(cache, db_for_api, api_key, api_port);
// Wait for API to start, then run bot
tokio::select! {
_ = bot_task => {
info!("Bot task finished");
}
result = api_task => {
match result {
Ok(_) => info!("API server finished"),
Err(e) => eprintln!("API server error: {}", e),
}
}
}
Ok(())