Match scheme for "ws" or "wss" and choose whether to connect with TLS connector for tungstenite

This commit is contained in:
bitfl0wer 2024-08-26 12:53:32 +02:00
parent 5bee907733
commit 98c842fd25
No known key found for this signature in database
GPG Key ID: 8D90CA11485CD14D
1 changed files with 56 additions and 32 deletions

View File

@ -9,8 +9,10 @@ use futures_util::{
};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
connect_async_tls_with_config, connect_async_with_config, tungstenite, Connector,
MaybeTlsStream, WebSocketStream,
};
use url::Url;
use crate::gateway::{GatewayMessage, RawGatewayMessage};
@ -32,38 +34,60 @@ impl TungsteniteBackend {
pub async fn connect(
websocket_url: &str,
) -> Result<(TungsteniteSink, TungsteniteStream), TungsteniteBackendError> {
let certs = webpki_roots::TLS_SERVER_ROOTS;
let roots = rustls::RootCertStore {
roots: certs
.iter()
.map(|cert| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
cert.subject.to_vec(),
cert.subject_public_key_info.to_vec(),
cert.name_constraints.as_ref().map(|der| der.to_vec()),
)
})
.collect(),
};
let (websocket_stream, _) = match connect_async_tls_with_config(
websocket_url,
None,
false,
Some(Connector::Rustls(
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)),
)
.await
{
Ok(websocket_stream) => websocket_stream,
Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }),
};
let websocket_url_parsed =
Url::parse(websocket_url).map_err(|_| TungsteniteBackendError::TungsteniteError {
error: tungstenite::error::Error::Url(
tungstenite::error::UrlError::UnsupportedUrlScheme,
),
})?;
if websocket_url_parsed.scheme() == "ws" {
let (websocket_stream, _) =
match connect_async_with_config(websocket_url, None, false).await {
Ok(websocket_stream) => websocket_stream,
Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }),
};
Ok(websocket_stream.split())
Ok(websocket_stream.split())
} else if websocket_url_parsed.scheme() == "wss" {
let certs = webpki_roots::TLS_SERVER_ROOTS;
let roots = rustls::RootCertStore {
roots: certs
.iter()
.map(|cert| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
cert.subject.to_vec(),
cert.subject_public_key_info.to_vec(),
cert.name_constraints.as_ref().map(|der| der.to_vec()),
)
})
.collect(),
};
let (websocket_stream, _) = match connect_async_tls_with_config(
websocket_url,
None,
false,
Some(Connector::Rustls(
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)),
)
.await
{
Ok(websocket_stream) => websocket_stream,
Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }),
};
Ok(websocket_stream.split())
} else {
Err(TungsteniteBackendError::TungsteniteError {
error: tungstenite::error::Error::Url(
tungstenite::error::UrlError::UnsupportedUrlScheme,
),
})
}
}
}