Added tracing and better error handling

This commit is contained in:
2023-03-20 05:26:42 +01:00
parent 2e7ca98a89
commit cb0ca281c8
3 changed files with 187 additions and 37 deletions

View File

@@ -7,40 +7,53 @@ use axum::{
Json, Router,
};
use base64::{alphabet, engine, Engine};
use commands::{SetFactCommand, GetFactCommand, set_fact, get_fact};
use commands::{get_fact, set_fact, GetFactCommand, SetFactCommand};
use ed25519_dalek::{Signature, VerifyingKey};
use serde::Deserialize;
use sqlx::{postgres::PgPoolOptions, PgPool};
use tower_http::trace::TraceLayer;
use tracing::{Instrument, debug_span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use twilight_http::Client;
use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand};
use twilight_model::{
application::interaction::{Interaction, InteractionData, InteractionType},
http::interaction::{InteractionResponse, InteractionResponseType},
id::Id
id::Id,
};
mod commands;
mod database;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let port = 4635;
dotenvy::dotenv().ok();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "god_replacement_product=debug,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let port = listen_port()?;
let pg_pool = PgPoolOptions::new()
.max_connections(5)
.connect(database_url().as_str())
.connect(database_url()?.as_str())
.await?;
sqlx::migrate!().run(&pg_pool).await?;
let app = Router::new()
.route("/", post(post_interaction))
.with_state(pg_pool);
.with_state(pg_pool)
.layer(TraceLayer::new_for_http());
register_commands().await?;
let addr = SocketAddr::from(([127, 0, 0, 1], port));
register_commands().await;
tracing::debug!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
@@ -94,6 +107,7 @@ async fn post_interaction(
kind: InteractionResponseType::Pong,
data: None,
};
Ok((StatusCode::OK, Json(pong)))
}
InteractionType::ApplicationCommand => {
@@ -102,7 +116,11 @@ async fn post_interaction(
return not_found();
};
let command_input_data = CommandInputData::from(*data.clone());
match &*data.name {
let slash_command_span = debug_span!("discord_slash_command", name=data.name.to_owned());
slash_command_span.in_scope(|| {
tracing::debug!("started processing command");
});
let result = match &*data.name {
SetFactCommand::NAME => {
let Ok(command_data) = SetFactCommand::from_interaction(command_input_data) else {
return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", SetFactCommand::NAME)));
@@ -110,11 +128,20 @@ async fn post_interaction(
let Some(author_id) = author_id else {
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", SetFactCommand::NAME)));
};
match set_fact(interaction.id, interaction.channel_id, author_id, command_data, &pg_pool).await {
match set_fact(
interaction.id,
interaction.channel_id,
author_id,
command_data,
&pg_pool,
)
.instrument(slash_command_span.clone())
.await
{
Ok(response) => Ok((StatusCode::OK, Json(response))),
Err(err) => Err(err),
}
},
}
GetFactCommand::NAME => {
let Ok(command_data) = GetFactCommand::from_interaction(command_input_data) else {
return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", GetFactCommand::NAME)));
@@ -122,14 +149,21 @@ async fn post_interaction(
let Some(author_id) = author_id else {
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME)));
};
match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await {
match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).instrument(slash_command_span.clone()).await
{
Ok(response) => Ok((StatusCode::OK, Json(response))),
Err(err) => Err(err),
}
},
}
_ => not_found(),
}
};
slash_command_span.in_scope(|| {
tracing::debug!("finished processing command");
});
result
}
_ => not_found(),
}
@@ -151,15 +185,15 @@ fn discord_pub_key() -> VerifyingKey {
VerifyingKey::from_bytes(&pub_key_bytes).unwrap()
}
async fn register_commands() {
discord_client()
.interaction(Id::from_str(&discord_client_id()).unwrap())
async fn register_commands() -> anyhow::Result<()> {
discord_client()?
.interaction(Id::from_str(&discord_client_id()?)?)
.set_global_commands(&[
GetFactCommand::create_command().into(),
SetFactCommand::create_command().into(),
])
.await
.unwrap();
])
.await?;
Ok(())
}
#[derive(Deserialize)]
@@ -167,37 +201,42 @@ struct ClientCredentialsResponse {
access_token: String,
}
fn authorization() -> String {
fn authorization() -> anyhow::Result<String> {
let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD);
let auth = format!("{}:{}", discord_client_id(), discord_client_secret(),);
engine.encode(auth)
let auth = format!("{}:{}", discord_client_id()?, discord_client_secret()?);
Ok(engine.encode(auth))
}
fn client_credentials_grant() -> ClientCredentialsResponse {
ureq::post("https://discord.com/api/v10/oauth2/token")
.set("Authorization", &format!("Basic {}", authorization()))
fn client_credentials_grant() -> anyhow::Result<ClientCredentialsResponse> {
Ok(ureq::post("https://discord.com/api/v10/oauth2/token")
.set("Authorization", &format!("Basic {}", authorization()?))
.send_form(&[
("grant_type", "client_credentials"),
("scope", "applications.commands.update"),
])
.unwrap()
.into_json()
.unwrap()
.into_json()?)
}
fn discord_client_id() -> String {
std::env::var("DISCORD_CLIENT_ID").unwrap()
fn discord_client_id() -> anyhow::Result<String> {
std::env::var("DISCORD_CLIENT_ID").map_err(Into::into)
}
fn discord_client_secret() -> String {
std::env::var("DISCORD_CLIENT_SECRET").unwrap()
fn discord_client_secret() -> anyhow::Result<String> {
std::env::var("DISCORD_CLIENT_SECRET").map_err(Into::into)
}
fn discord_client() -> Client {
let token = client_credentials_grant().access_token;
Client::new(format!("Bearer {token}"))
fn discord_client() -> anyhow::Result<Client> {
let token = client_credentials_grant()?.access_token;
Ok(Client::new(format!("Bearer {token}")))
}
fn database_url() -> String {
std::env::var("DATABASE_URL").unwrap()
fn database_url() -> anyhow::Result<String> {
std::env::var("DATABASE_URL").map_err(Into::into)
}
fn listen_port() -> anyhow::Result<u16> {
std::env::var("LISTEN_PORT")
.map_err(Into::into)
.and_then(|v| v.parse::<u16>().map_err(Into::into))
}