2023-03-03 16:36:21 +01:00
|
|
|
use std::{net::SocketAddr, str::FromStr};
|
|
|
|
|
|
|
|
|
|
use axum::{
|
2023-03-19 22:59:11 +01:00
|
|
|
extract::State,
|
2023-03-03 16:36:21 +01:00
|
|
|
http::{HeaderMap, StatusCode},
|
|
|
|
|
routing::post,
|
|
|
|
|
Json, Router,
|
|
|
|
|
};
|
2023-03-19 22:59:11 +01:00
|
|
|
use base64::{alphabet, engine, Engine};
|
2023-03-03 16:36:21 +01:00
|
|
|
use ed25519_dalek::{Signature, VerifyingKey};
|
|
|
|
|
use serde::Deserialize;
|
2023-03-19 22:59:11 +01:00
|
|
|
use sqlx::{postgres::PgPoolOptions, PgPool};
|
2023-03-03 16:36:21 +01:00
|
|
|
use twilight_http::Client;
|
2023-03-19 22:59:11 +01:00
|
|
|
use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand};
|
|
|
|
|
use twilight_mention::{timestamp::{Timestamp, TimestampStyle}, Mention};
|
2023-03-03 16:36:21 +01:00
|
|
|
use twilight_model::{
|
2023-03-19 22:59:11 +01:00
|
|
|
application::interaction::{Interaction, InteractionData, InteractionType},
|
|
|
|
|
http::interaction::{InteractionResponse, InteractionResponseData, InteractionResponseType},
|
|
|
|
|
id::{Id, marker::{UserMarker, ChannelMarker, InteractionMarker}}, channel::message::MessageFlags,
|
2023-03-03 16:36:21 +01:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Arguments of the `/set_fact` slash command; parsed in `post_interaction`
// and handled by `set_fact`.
#[derive(CommandModel, CreateCommand)]
#[command(name = "set_fact", desc = "Quietly save a fact")]
struct SetFactCommand {
    // Exposed to Discord as the `name` option.
    #[command(rename = "name", desc = "Fact name")]
    fact_name: String,
    // Exposed to Discord as the `value` option.
    #[command(rename = "value", desc = "Fact value")]
    fact_value: String,
}
|
|
|
|
|
|
2023-03-19 22:59:11 +01:00
|
|
|
#[derive(CommandModel, CreateCommand)]
|
|
|
|
|
#[command(name = "get_fact", desc = "Retrieve and display the value of a fact")]
|
|
|
|
|
struct GetFactCommand {
|
|
|
|
|
#[command(rename = "name", desc = "Fact name")]
|
|
|
|
|
fact_name: String,
|
|
|
|
|
#[command(desc = "Should it be displayed publically, by default it won't be")]
|
|
|
|
|
public: Option<bool>,
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-03 16:36:21 +01:00
|
|
|
#[tokio::main]
|
2023-03-19 22:59:11 +01:00
|
|
|
async fn main() -> anyhow::Result<()> {
|
2023-03-03 16:36:21 +01:00
|
|
|
let port = 4635;
|
2023-03-19 22:59:11 +01:00
|
|
|
dotenvy::dotenv().ok();
|
|
|
|
|
|
|
|
|
|
let pg_pool = PgPoolOptions::new()
|
|
|
|
|
.max_connections(5)
|
|
|
|
|
.connect(database_url().as_str())
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
sqlx::migrate!().run(&pg_pool).await?;
|
|
|
|
|
|
|
|
|
|
let app = Router::new()
|
|
|
|
|
.route("/", post(post_interaction))
|
|
|
|
|
.with_state(pg_pool);
|
2023-03-03 16:36:21 +01:00
|
|
|
|
|
|
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
2023-03-19 22:59:11 +01:00
|
|
|
|
|
|
|
|
register_commands().await;
|
2023-03-03 16:36:21 +01:00
|
|
|
|
|
|
|
|
axum::Server::bind(&addr)
|
|
|
|
|
.serve(app.into_make_service())
|
2023-03-19 22:59:11 +01:00
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
Ok(())
|
2023-03-03 16:36:21 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Return type of the axum interaction handler: a status plus JSON interaction
// response on success, or a status code plus message string on failure.
type InteractionResult = Result<(StatusCode, Json<InteractionResponse>), (StatusCode, String)>;
|
|
|
|
|
|
2023-03-19 22:59:11 +01:00
|
|
|
fn validate_request(headers: HeaderMap, body: String) -> Result<Interaction, (StatusCode, String)> {
|
2023-03-03 16:36:21 +01:00
|
|
|
let Ok(interaction): Result<Interaction, _> = serde_json::from_str(&body) else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
|
|
|
|
|
};
|
|
|
|
|
let Some(sig) = headers.get("x-signature-ed25519") else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "requrest did not include signature header".to_string()))
|
|
|
|
|
};
|
|
|
|
|
let Ok(sig) = hex::decode(sig) else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string()))
|
|
|
|
|
};
|
|
|
|
|
let Ok(sig) = Signature::from_slice(&sig) else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string()))
|
|
|
|
|
};
|
|
|
|
|
let Some(signed_buf) = headers.get("x-signature-timestamp") else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string()))
|
|
|
|
|
};
|
|
|
|
|
let mut signed_buf = signed_buf.as_bytes().to_owned();
|
|
|
|
|
signed_buf.extend(body.as_bytes());
|
|
|
|
|
|
|
|
|
|
let pub_key = discord_pub_key();
|
|
|
|
|
let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else {
|
|
|
|
|
return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string()))
|
|
|
|
|
};
|
|
|
|
|
|
2023-03-19 22:59:11 +01:00
|
|
|
return Ok(interaction);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn set_fact(
|
|
|
|
|
interaction_id: Id<InteractionMarker>,
|
|
|
|
|
channel_id: Option<Id<ChannelMarker>>,
|
|
|
|
|
author_id: Id<UserMarker>,
|
|
|
|
|
command_data: SetFactCommand,
|
|
|
|
|
pg_pool: &PgPool,
|
|
|
|
|
) -> Result<InteractionResponse, (StatusCode, String)> {
|
|
|
|
|
let Ok(rows) = sqlx::query!("
|
|
|
|
|
INSERT INTO facts (\"last_interaction_id\", \"channel_id\", \"author_id\", \"name\", \"value\")
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
|
|
|
ON CONFLICT ON CONSTRAINT facts_origin_key DO UPDATE SET value = $5, version = facts.version + 1
|
|
|
|
|
",
|
|
|
|
|
interaction_id.to_string(),
|
|
|
|
|
channel_id.map(|cid| cid.to_string()),
|
|
|
|
|
author_id.to_string(),
|
|
|
|
|
command_data.fact_name,
|
|
|
|
|
command_data.fact_value,
|
|
|
|
|
).execute(pg_pool).await.and_then(|rows| Ok(rows.rows_affected())) else {
|
|
|
|
|
return Err((StatusCode::INTERNAL_SERVER_ERROR, "Error saving fact.".to_string()));
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if rows != 1 {
|
|
|
|
|
return Err((StatusCode::INTERNAL_SERVER_ERROR, "Error saving fact".to_string()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(InteractionResponse {
|
|
|
|
|
kind: InteractionResponseType::ChannelMessageWithSource,
|
|
|
|
|
data: Some(InteractionResponseData {
|
|
|
|
|
content: Some(format!(
|
|
|
|
|
"Set {0} to {1}",
|
|
|
|
|
command_data.fact_name, command_data.fact_value
|
|
|
|
|
)),
|
|
|
|
|
flags: Some(MessageFlags::EPHEMERAL),
|
|
|
|
|
..Default::default()
|
|
|
|
|
}),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Row shape produced by the `get_fact` query (`sqlx::query_as!`); field names
// match the selected columns `value`, `version`, `created_at`, `updated_at`.
struct FactResponse {
    // Stored fact value.
    value: String,
    // Incremented by `set_fact`'s upsert each time an existing fact is overwritten.
    version: i32,
    // NOTE(review): presumably set by the database on insert; not written anywhere in this file.
    created_at: time::OffsetDateTime,
    // NOTE(review): presumably maintained by a DB default/trigger; `set_fact` does not set it explicitly.
    updated_at: time::OffsetDateTime,
}
|
|
|
|
|
|
|
|
|
|
async fn get_fact(
|
|
|
|
|
channel_id: Option<Id<ChannelMarker>>,
|
|
|
|
|
author_id: Id<UserMarker>,
|
|
|
|
|
command_data: GetFactCommand,
|
|
|
|
|
pg_pool: &PgPool,
|
|
|
|
|
) -> Result<InteractionResponse, (StatusCode, String)> {
|
|
|
|
|
let Ok(facts) = sqlx::query_as!(FactResponse,
|
|
|
|
|
"
|
|
|
|
|
SELECT \"value\", \"version\", \"created_at\", \"updated_at\"
|
|
|
|
|
FROM facts
|
|
|
|
|
WHERE
|
|
|
|
|
channel_id IS NOT DISTINCT FROM $1 AND
|
|
|
|
|
author_id = $2 AND
|
|
|
|
|
name = $3
|
|
|
|
|
", channel_id.map(|cid| cid.to_string()), author_id.to_string(), command_data.fact_name).fetch_all(pg_pool).await else {
|
|
|
|
|
return Err((StatusCode::INTERNAL_SERVER_ERROR, "Querying facts failed".to_string()));
|
|
|
|
|
};
|
|
|
|
|
if facts.len() == 0 {
|
|
|
|
|
return Ok(InteractionResponse {
|
|
|
|
|
kind: InteractionResponseType::ChannelMessageWithSource,
|
|
|
|
|
data: Some(InteractionResponseData {
|
|
|
|
|
content: Some(format!("Fact {0} in channel {1} by you, {2}, was not found.",
|
|
|
|
|
command_data.fact_name,
|
|
|
|
|
channel_id.map_or("<none>".to_string(), |cid| cid.mention().to_string()),
|
|
|
|
|
author_id.mention().to_string(),
|
|
|
|
|
)),
|
|
|
|
|
flags: match command_data.public { Some(true) => None, _ => Some(MessageFlags::EPHEMERAL) },
|
|
|
|
|
..Default::default()
|
|
|
|
|
}),
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
if facts.len() > 1 {
|
|
|
|
|
return Err((StatusCode::INTERNAL_SERVER_ERROR, "Too many facts found, wtf, impossible".to_string()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let fact = &facts[0];
|
|
|
|
|
|
|
|
|
|
Ok(InteractionResponse {
|
|
|
|
|
kind: InteractionResponseType::ChannelMessageWithSource,
|
|
|
|
|
data: Some(InteractionResponseData {
|
|
|
|
|
content: Some(format!(
|
|
|
|
|
"Fact **{0}** was set to **{1}** by {2} at {3}, and was reset {4} times in total since {5}.",
|
|
|
|
|
command_data.fact_name,
|
|
|
|
|
fact.value,
|
|
|
|
|
author_id.mention().to_string(),
|
|
|
|
|
Timestamp::new(fact.updated_at.unix_timestamp().try_into().unwrap(), Some(TimestampStyle::RelativeTime)).mention(),
|
|
|
|
|
fact.version,
|
|
|
|
|
Timestamp::new(fact.created_at.unix_timestamp().try_into().unwrap(), Some(TimestampStyle::ShortDateTime)).mention(),
|
|
|
|
|
)),
|
|
|
|
|
flags: match command_data.public { Some(true) => None, _ => Some(MessageFlags::EPHEMERAL) },
|
|
|
|
|
..Default::default()
|
|
|
|
|
}),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn post_interaction(
|
|
|
|
|
headers: HeaderMap,
|
|
|
|
|
State(pg_pool): State<PgPool>,
|
|
|
|
|
body: String,
|
|
|
|
|
) -> InteractionResult {
|
|
|
|
|
let interaction = match validate_request(headers, body) {
|
|
|
|
|
Ok(interaction) => interaction,
|
|
|
|
|
Err(error) => return Err(error),
|
|
|
|
|
};
|
|
|
|
|
|
2023-03-03 16:36:21 +01:00
|
|
|
match interaction.kind {
|
|
|
|
|
InteractionType::Ping => {
|
|
|
|
|
let pong = InteractionResponse {
|
|
|
|
|
kind: InteractionResponseType::Pong,
|
|
|
|
|
data: None,
|
|
|
|
|
};
|
|
|
|
|
Ok((StatusCode::OK, Json(pong)))
|
|
|
|
|
}
|
|
|
|
|
InteractionType::ApplicationCommand => {
|
2023-03-19 22:59:11 +01:00
|
|
|
let author_id = interaction.author_id();
|
2023-03-03 16:36:21 +01:00
|
|
|
let Some(InteractionData::ApplicationCommand(data)) = interaction.data else {
|
|
|
|
|
return not_found();
|
|
|
|
|
};
|
|
|
|
|
let command_input_data = CommandInputData::from(*data.clone());
|
|
|
|
|
match &*data.name {
|
2023-03-19 22:59:11 +01:00
|
|
|
"set_fact" => {
|
|
|
|
|
let Ok(command_data) = SetFactCommand::from_interaction(command_input_data) else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "invalid set fact command".to_string()));
|
2023-03-03 16:36:21 +01:00
|
|
|
};
|
2023-03-19 22:59:11 +01:00
|
|
|
let Some(author_id) = author_id else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "save_fact requires a user".to_string()));
|
2023-03-03 16:36:21 +01:00
|
|
|
};
|
2023-03-19 22:59:11 +01:00
|
|
|
match set_fact(interaction.id, interaction.channel_id, author_id, command_data, &pg_pool).await {
|
|
|
|
|
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
|
|
|
|
Err(err) => Err(err),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"get_fact" => {
|
|
|
|
|
let Ok(command_data) = GetFactCommand::from_interaction(command_input_data) else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "invalid get fact command".to_string()));
|
|
|
|
|
};
|
|
|
|
|
let Some(author_id) = author_id else {
|
|
|
|
|
return Err((StatusCode::BAD_REQUEST, "get_fact requires a user".to_string()));
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await {
|
|
|
|
|
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
|
|
|
|
Err(err) => Err(err),
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
_ => not_found(),
|
2023-03-03 16:36:21 +01:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
_ => not_found(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn not_found() -> InteractionResult {
|
|
|
|
|
Err((
|
|
|
|
|
StatusCode::NOT_FOUND,
|
|
|
|
|
"requested interaction not found".to_string(),
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn discord_pub_key_bytes() -> Vec<u8> {
|
|
|
|
|
hex::decode(std::env::var("DISCORD_PUB_KEY").unwrap()).unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn discord_pub_key() -> VerifyingKey {
|
|
|
|
|
let pub_key_bytes: [u8; 32] = discord_pub_key_bytes().try_into().unwrap();
|
|
|
|
|
VerifyingKey::from_bytes(&pub_key_bytes).unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-19 22:59:11 +01:00
|
|
|
async fn register_commands() {
|
2023-03-03 16:36:21 +01:00
|
|
|
discord_client()
|
|
|
|
|
.interaction(Id::from_str(&discord_client_id()).unwrap())
|
2023-03-19 22:59:11 +01:00
|
|
|
.set_global_commands(&[
|
|
|
|
|
GetFactCommand::create_command().into(),
|
|
|
|
|
SetFactCommand::create_command().into(),
|
|
|
|
|
])
|
2023-03-03 16:36:21 +01:00
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// JSON body returned by the OAuth2 token endpoint in `client_credentials_grant`;
// only the access token is deserialized.
#[derive(Deserialize)]
struct ClientCredentialsResponse {
    // Sent as `Bearer <token>` by `discord_client`.
    access_token: String,
}
|
|
|
|
|
|
|
|
|
|
fn authorization() -> String {
|
|
|
|
|
let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD);
|
2023-03-19 22:59:11 +01:00
|
|
|
let auth = format!("{}:{}", discord_client_id(), discord_client_secret(),);
|
2023-03-03 16:36:21 +01:00
|
|
|
engine.encode(auth)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn client_credentials_grant() -> ClientCredentialsResponse {
|
|
|
|
|
ureq::post("https://discord.com/api/v10/oauth2/token")
|
|
|
|
|
.set("Authorization", &format!("Basic {}", authorization()))
|
|
|
|
|
.send_form(&[
|
|
|
|
|
("grant_type", "client_credentials"),
|
|
|
|
|
("scope", "applications.commands.update"),
|
2023-03-19 22:59:11 +01:00
|
|
|
])
|
|
|
|
|
.unwrap()
|
|
|
|
|
.into_json()
|
|
|
|
|
.unwrap()
|
2023-03-03 16:36:21 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Returns the Discord application (client) id from `DISCORD_CLIENT_ID`.
///
/// # Panics
/// If the variable is unset or not valid Unicode; the message names the variable.
fn discord_client_id() -> String {
    std::env::var("DISCORD_CLIENT_ID")
        .expect("DISCORD_CLIENT_ID environment variable must be set")
}
|
|
|
|
|
|
|
|
|
|
/// Returns the Discord application's OAuth2 client secret from `DISCORD_CLIENT_SECRET`.
///
/// # Panics
/// If the variable is unset or not valid Unicode; the message names the variable.
fn discord_client_secret() -> String {
    std::env::var("DISCORD_CLIENT_SECRET")
        .expect("DISCORD_CLIENT_SECRET environment variable must be set")
}
|
|
|
|
|
|
|
|
|
|
fn discord_client() -> Client {
|
|
|
|
|
let token = client_credentials_grant().access_token;
|
|
|
|
|
Client::new(format!("Bearer {token}"))
|
|
|
|
|
}
|
2023-03-19 22:59:11 +01:00
|
|
|
|
|
|
|
|
/// Returns the Postgres connection string from `DATABASE_URL`.
///
/// # Panics
/// If the variable is unset or not valid Unicode; the message names the variable.
fn database_url() -> String {
    std::env::var("DATABASE_URL").expect("DATABASE_URL environment variable must be set")
}
|