mas_keystore/
encrypter.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::Arc;
8
9use aead::Aead;
10use base64ct::{Base64, Encoding};
11use chacha20poly1305::{ChaCha20Poly1305, KeyInit};
12use generic_array::GenericArray;
13use thiserror::Error;
14
15/// Helps encrypting and decrypting data
16#[derive(Clone)]
17pub struct Encrypter {
18    aead: Arc<ChaCha20Poly1305>,
19}
20
21#[derive(Debug, Error)]
22#[error("Decryption error")]
23pub enum DecryptError {
24    Aead(#[from] aead::Error),
25    Base64(#[from] base64ct::Error),
26    Shape,
27}
28
29impl Encrypter {
30    /// Creates an [`Encrypter`] out of an encryption key
31    #[must_use]
32    pub fn new(key: &[u8; 32]) -> Self {
33        let key = GenericArray::from_slice(key);
34        let aead = ChaCha20Poly1305::new(key);
35        let aead = Arc::new(aead);
36        Self { aead }
37    }
38
39    /// Encrypt a payload
40    ///
41    /// # Errors
42    ///
43    /// Will return `Err` when the payload failed to encrypt
44    pub fn encrypt(&self, nonce: &[u8; 12], decrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
45        let nonce = GenericArray::from_slice(&nonce[..]);
46        let encrypted = self.aead.encrypt(nonce, decrypted)?;
47        Ok(encrypted)
48    }
49
50    /// Decrypts a payload
51    ///
52    /// # Errors
53    ///
54    /// Will return `Err` when the payload failed to decrypt
55    pub fn decrypt(&self, nonce: &[u8; 12], encrypted: &[u8]) -> Result<Vec<u8>, aead::Error> {
56        let nonce = GenericArray::from_slice(&nonce[..]);
57        let encrypted = self.aead.decrypt(nonce, encrypted)?;
58        Ok(encrypted)
59    }
60
61    /// Encrypt a payload to a self-contained base64-encoded string
62    ///
63    /// # Errors
64    ///
65    /// Will return `Err` when the payload failed to encrypt
66    pub fn encrypt_to_string(&self, decrypted: &[u8]) -> Result<String, aead::Error> {
67        let nonce = rand::random();
68        let encrypted = self.encrypt(&nonce, decrypted)?;
69        let encrypted = [&nonce[..], &encrypted].concat();
70        let encrypted = Base64::encode_string(&encrypted);
71        Ok(encrypted)
72    }
73
74    /// Decrypt a payload from a self-contained base64-encoded string
75    ///
76    /// # Errors
77    ///
78    /// Will return `Err` when the payload failed to decrypt
79    pub fn decrypt_string(&self, encrypted: &str) -> Result<Vec<u8>, DecryptError> {
80        let encrypted = Base64::decode_vec(encrypted)?;
81
82        let nonce: &[u8; 12] = encrypted
83            .get(0..12)
84            .ok_or(DecryptError::Shape)?
85            .try_into()
86            .map_err(|_| DecryptError::Shape)?;
87
88        let payload = encrypted.get(12..).ok_or(DecryptError::Shape)?;
89
90        let decrypted_client_secret = self.decrypt(nonce, payload)?;
91
92        Ok(decrypted_client_secret)
93    }
94}