Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ exclude = [
default = ["compat-3-0-0", "rustls-tls", "dns-resolver"]
compat-3-0-0 = []
sync = []
rustls-tls = ["dep:rustls", "dep:rustls-pemfile", "dep:tokio-rustls"]
rustls-tls = ["dep:rustls", "dep:tokio-rustls"]
openssl-tls = ["dep:openssl", "dep:openssl-probe", "dep:tokio-openssl"]
dns-resolver = ["dep:hickory-resolver", "dep:hickory-proto"]
cert-key-password = ["dep:pem", "dep:pkcs8"]
Expand Down Expand Up @@ -102,7 +102,6 @@ pkcs8 = { version = "0.10.2", features = ["encryption", "pkcs5"], optional = tru
rand = { version = "0.8.3", features = ["small_rng"] }
rayon = { version = "1.5.3", optional = true }
rustc_version_runtime = "0.3.0"
rustls-pemfile = { version = "1.0.1", optional = true }
serde_with = "3.8.1"
sha-1 = "0.10.0"
sha2 = "0.10.2"
Expand All @@ -115,7 +114,7 @@ thiserror = "1.0.24"
tokio-openssl = { version = "0.6.3", optional = true }
tracing = { version = "0.1.36", optional = true }
typed-builder = "0.10.0"
webpki-roots = "0.25.2"
webpki-roots = "0.26"
zstd = { version = "0.11.2", optional = true }
macro_magic = "0.5.1"

Expand All @@ -130,9 +129,10 @@ default-features = false
features = ["json", "rustls-tls"]

[dependencies.rustls]
version = "0.21.6"
version = "0.23.20"
optional = true
features = ["dangerous_configuration"]
default-features = false
features = ["logging", "ring", "std", "tls12"]

[dependencies.serde]
version = "1.0.125"
Expand All @@ -146,9 +146,10 @@ version = "1.17.0"
features = ["io-util", "sync", "macros", "net", "process", "rt", "time", "fs"]

[dependencies.tokio-rustls]
version = "0.24.1"
version = "0.26"
optional = true
features = ["dangerous_configuration"]
default-features = false
features = ["logging", "ring", "tls12"]

[dependencies.tokio-util]
version = "0.7.0"
Expand Down
180 changes: 108 additions & 72 deletions src/runtime/tls_rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@ use std::{
fs::File,
io::{BufReader, Seek},
sync::Arc,
time::SystemTime,
};

use rustls::{
client::{ClientConfig, ServerCertVerified, ServerCertVerifier, ServerName},
Certificate,
client::ClientConfig,
crypto::ring as provider,
pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer, ServerName},
Error as TlsError,
OwnedTrustAnchor,
RootCertStore,
};
use rustls_pemfile::{certs, read_one, Item};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use webpki_roots::TLS_SERVER_ROOTS;
Expand Down Expand Up @@ -49,9 +47,11 @@ pub(super) async fn tls_connect(
tcp_stream: TcpStream,
cfg: &TlsConfig,
) -> Result<TlsStream> {
let name = ServerName::try_from(host).map_err(|e| ErrorKind::DnsResolve {
message: format!("could not resolve {:?}: {}", host, e),
})?;
let name = ServerName::try_from(host)
.map_err(|e| ErrorKind::DnsResolve {
message: format!("could not resolve {:?}: {}", host, e),
})?
.to_owned();

let conn = cfg
.connector
Expand All @@ -66,110 +66,146 @@ pub(super) async fn tls_connect(
fn make_rustls_config(cfg: TlsOptions) -> Result<rustls::ClientConfig> {
let mut store = RootCertStore::empty();
if let Some(path) = cfg.ca_file_path {
let ders = certs(&mut BufReader::new(File::open(&path)?)).map_err(|_| {
ErrorKind::InvalidTlsConfig {
let ders = CertificateDer::pem_file_iter(&path)
.map_err(|err| ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded root certificate from {}",
"Unable to parse PEM-encoded root certificate from {}: {err}",
path.display()
),
}
})?;
store.add_parsable_certificates(&ders);
})?
.flatten();
store.add_parsable_certificates(ders);
} else {
let trust_anchors = TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
});
store.add_trust_anchors(trust_anchors);
store.extend(TLS_SERVER_ROOTS.iter().cloned());
}

let config_builder = ClientConfig::builder_with_provider(provider::default_provider().into())
.with_safe_default_protocol_versions()
.map_err(|e| ErrorKind::InvalidTlsConfig {
message: format!(
"built-in provider should support default protocol versions: {}",
e
),
})?
.with_root_certificates(store);

let mut config = if let Some(path) = cfg.cert_key_file_path {
let mut file = BufReader::new(File::open(&path)?);
let certs = match certs(&mut file) {
Ok(certs) => certs.into_iter().map(Certificate).collect(),
Err(error) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded client certificate from {}: {}",
path.display(),
error,
),
}
.into())
}
};
let mut certs = vec![];

for cert in CertificateDer::pem_reader_iter(&mut file) {
let cert = cert.map_err(|error| ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded client certificate from {}: {error}",
path.display(),
),
})?;
certs.push(cert);
}

file.rewind()?;
let key = loop {
let key = 'key: {
#[cfg(feature = "cert-key-password")]
if let Some(key_pw) = cfg.tls_certificate_key_file_password.as_deref() {
use rustls::pki_types::PrivatePkcs8KeyDer;
use std::io::Read;
let mut contents = vec![];
file.read_to_end(&mut contents)?;
break rustls::PrivateKey(super::pem::decrypt_private_key(&contents, key_pw)?);
break 'key PrivatePkcs8KeyDer::from(super::pem::decrypt_private_key(
&contents, key_pw,
)?)
.into();
}
match read_one(&mut file) {
Ok(Some(Item::PKCS8Key(bytes))) | Ok(Some(Item::RSAKey(bytes))) => {
break rustls::PrivateKey(bytes)
}
Ok(Some(_)) => continue,
Ok(None) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!("No PEM-encoded keys in {}", path.display()),
}
.into())
}
Err(_) => {
match PrivateKeyDer::from_pem_reader(&mut file) {
Ok(key) => break 'key key,
Err(err) => {
return Err(ErrorKind::InvalidTlsConfig {
message: format!(
"Unable to parse PEM-encoded item from {}",
path.display()
"Unable to parse PEM-encoded item from {}: {err}",
path.display(),
),
}
.into())
}
}
};

ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(store)
config_builder
.with_client_auth_cert(certs, key)
.map_err(|error| ErrorKind::InvalidTlsConfig {
message: error.to_string(),
})?
} else {
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(store)
.with_no_client_auth()
config_builder.with_no_client_auth()
};

if let Some(true) = cfg.allow_invalid_certificates {
// nosemgrep: rustls-dangerous
config // mongodb rating: No Fix Needed
.dangerous()
.set_certificate_verifier(Arc::new(NoCertVerifier {}));
.set_certificate_verifier(Arc::new(danger::NoCertVerifier(
provider::default_provider(),
)));
}

Ok(config)
}

struct NoCertVerifier {}

impl ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_: &Certificate,
_: &[Certificate],
_: &ServerName,
_: &mut dyn Iterator<Item = &[u8]>,
_: &[u8],
_: SystemTime,
) -> std::result::Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
mod danger {
use super::*;
use rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider},
pki_types::UnixTime,
DigitallySignedStruct,
SignatureScheme,
};

#[derive(Debug)]
pub(super) struct NoCertVerifier(pub(super) CryptoProvider);

impl ServerCertVerifier for NoCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp: &[u8],
_now: UnixTime,
) -> std::result::Result<ServerCertVerified, TlsError> {
Ok(ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, TlsError> {
verify_tls12_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}

fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, TlsError> {
verify_tls13_signature(
message,
cert,
dss,
&self.0.signature_verification_algorithms,
)
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}
}