summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHampusM <hampus@hampusmat.com>2022-09-09 20:03:02 +0200
committerHampusM <hampus@hampusmat.com>2022-09-09 20:04:54 +0200
commitdb42316544c65951c7781357ec5aaba1b9abb8ab (patch)
tree13fd19efb80cf29f054c1760e8580f4379ce6a6a
parent8a02d3386d4ce0b58de943fcf42bd072af1e0b42 (diff)
refactor: make AuthPromptHandler use the Hyper web server
-rw-r--r--Cargo.toml4
-rw-r--r--src/auth.rs (renamed from src/auth/mod.rs)92
-rw-r--r--src/auth/service.rs40
-rw-r--r--src/errors/auth.rs4
4 files changed, 76 insertions, 64 deletions
diff --git a/Cargo.toml b/Cargo.toml
index d86a21c..25bf7a7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,12 +4,12 @@ version = "0.1.0"
edition = "2021"
[dependencies]
-actix-web = { version = "4.1.0", default-features = false, features = ["macros"] }
tokio = { version = "1.21.0", features = ["macros", "rt-multi-thread"] }
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
thiserror = "1.0.33"
-hyper = { version = "0.14.20", features = ["client", "http2", "tcp"] }
+hyper = { version = "0.14.20", features = ["client", "http2", "tcp", "server", "http1"] }
+serde_urlencoded = "0.7.1"
[dev_dependencies]
config = "0.13.2"
diff --git a/src/auth/mod.rs b/src/auth.rs
index ce550e1..4af4997 100644
--- a/src/auth/mod.rs
+++ b/src/auth.rs
@@ -1,22 +1,32 @@
//! Deezer API authentication.
+use std::convert::Infallible;
use std::error::Error;
use std::fmt::Display;
+use std::net::ToSocketAddrs;
use std::time::Duration;
-use actix_web::web::Data;
-use actix_web::{App, HttpServer};
+use hyper::service::{make_service_fn, service_fn};
+use hyper::{Body, Request, Response, Server};
use serde::Deserialize;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::{select, spawn};
-use crate::auth::service::retrieve_auth_code;
use crate::errors::auth::AuthPromptHandlerError;
-mod service;
-
const AUTH_URL: &str = "https://connect.deezer.com/oauth/auth.php";
+/// A Deezer access token.
+#[derive(Debug, Clone)]
+pub struct AccessToken
+{
+ /// The access token.
+ pub access_token: String,
+
+ /// The duration until the access token expires.
+ pub expires: Duration,
+}
+
/// A Deezer authentication code.
#[derive(Debug, Deserialize, Clone)]
pub struct AuthCode(String);
@@ -59,21 +69,26 @@ impl AuthPromptHandler
{
let (done_tx, mut done_rx) = mpsc::channel::<AuthCode>(1);
- let done_tx_data = Data::new(done_tx);
+ let addr = format!("{}:{}", address, port)
+ .to_socket_addrs()
+ .map_err(|_| AuthPromptHandlerError::InvalidAddress)?
+ .next()
+ .map_or_else(|| Err(AuthPromptHandlerError::InvalidAddress), Ok)?;
- let server = HttpServer::new(move || {
- App::new()
- .app_data(done_tx_data.clone())
- .service(retrieve_auth_code)
- })
- .bind((address.clone(), port))
- .map_err(|_| AuthPromptHandlerError::BindAddressFailed)?;
+ let make_service = make_service_fn(move |_| {
+ let done_tx_clone = done_tx.clone();
- let server_future = server.run();
+ let service =
+ service_fn(move |req| handle_auth_code(done_tx_clone.clone(), req));
+
+ async move { Ok::<_, Infallible>(service) }
+ });
+
+ let server = Server::bind(&addr).serve(make_service);
let handle = spawn(async move {
let opt_auth_code = select! {
- result = server_future => {
+ result = server => {
result.map(|_| None)
},
auth_code = async {
@@ -94,13 +109,46 @@ impl AuthPromptHandler
}
}
-/// A Deezer access token.
-#[derive(Debug, Clone)]
-pub struct AccessToken
+#[derive(Debug, Deserialize)]
+struct AuthCodeQuery
{
- /// The access token.
- pub access_token: String,
+ error_reason: Option<String>,
+ code: Option<AuthCode>,
+}
- /// The duration until the access token expires.
- pub expires: Duration,
+async fn handle_auth_code(
+ done_tx: mpsc::Sender<AuthCode>,
+ request: Request<Body>,
+) -> Result<Response<Body>, String>
+{
+ let query = serde_urlencoded::from_str::<AuthCodeQuery>(
+ request.uri().query().map_or_else(|| "", |query| query),
+ )
+ .map_err(|err| format!("Invalid query. {}", err))?;
+
+ if let Some(error_reason) = &query.error_reason {
+ return Ok(Response::builder().status(401).body(Body::from(format!(
+ "Error: No authentication code was retrieved. Reason: {}\n\nYou can close this tab",
+ error_reason
+ ))).unwrap());
+ }
+
+ let auth_code = match &query.code {
+ Some(auth_code) => auth_code,
+ None => {
+ return Ok(Response::builder()
+ .status(400)
+ .body(Body::from("Error: No authentication code was retrieved."))
+ .unwrap());
+ }
+ };
+
+ done_tx.send(auth_code.clone()).await.unwrap();
+
+ Ok(Response::builder()
+ .status(200)
+ .body(Body::from(
+ "Authentication code was successfully retrieved.\n\nYou can close this tab",
+ ))
+ .unwrap())
}
diff --git a/src/auth/service.rs b/src/auth/service.rs
deleted file mode 100644
index b9b44d4..0000000
--- a/src/auth/service.rs
+++ /dev/null
@@ -1,40 +0,0 @@
-use actix_web::web::{Data, Query};
-use actix_web::{get, HttpResponse};
-use serde::Deserialize;
-use tokio::sync::mpsc;
-
-use crate::auth::AuthCode;
-
-#[derive(Debug, Deserialize)]
-struct AuthCodeQuery
-{
- error_reason: Option<String>,
- code: Option<AuthCode>,
-}
-
-#[get("/")]
-async fn retrieve_auth_code(
- query: Query<AuthCodeQuery>,
- done_tx: Data<mpsc::Sender<AuthCode>>,
-) -> HttpResponse
-{
- if let Some(error_reason) = &query.error_reason {
- return HttpResponse::Unauthorized().body(format!(
- "Error: No authentication code was retrieved. Reason: {}\n\nYou can close this tab",
- error_reason
- ));
- }
-
- let auth_code = match &query.code {
- Some(auth_code) => auth_code,
- None => {
- return HttpResponse::BadRequest()
- .body("Error: No authentication code was retrieved");
- }
- };
-
- done_tx.send(auth_code.clone()).await.unwrap();
-
- HttpResponse::Ok()
- .body("Authentication code was successfully retrieved.\n\nYou can close this tab")
-}
diff --git a/src/errors/auth.rs b/src/errors/auth.rs
index 656673e..a165e05 100644
--- a/src/errors/auth.rs
+++ b/src/errors/auth.rs
@@ -7,4 +7,8 @@ pub enum AuthPromptHandlerError
/// HTTP server failed to bind to a address.
#[error("HTTP server failed to bind to address")]
BindAddressFailed,
+
+ /// Invalid address.
+ #[error("Invalid address")]
+ InvalidAddress,
}