use openmls_traits::crypto::OpenMlsCrypto;
use openmls_traits::types::{Ciphersuite, HpkeCiphertext};
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::*;
use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
use super::encryption_keys::{EncryptionKey, EncryptionKeyPair};
use crate::{
binary_tree::array_representation::{LeafNodeIndex, ParentNodeIndex},
ciphersuite::HpkePublicKey,
error::LibraryError,
messages::PathSecret,
schedule::CommitSecret,
treesync::{hashes::ParentHashInput, treekem::UpdatePathNode},
};
/// A parent node of the ratchet tree.
///
/// Stores the node's encryption key, the parent hash recorded for it, and the
/// list of leaf indices recorded as unmerged on this node.
#[derive(
    Debug,
    Eq,
    PartialEq,
    Clone,
    Serialize,
    Deserialize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
    TlsSize,
)]
pub struct ParentNode {
    /// Public encryption key of this node.
    pub(super) encryption_key: EncryptionKey,
    /// Parent hash stored for this node (empty when built via `From`).
    pub(super) parent_hash: VLBytes,
    /// Leaf indices recorded as unmerged on this node.
    pub(super) unmerged_leaves: UnmergedLeaves,
}

impl From<EncryptionKey> for ParentNode {
    /// Wraps `encryption_key` in a fresh parent node that has an empty parent
    /// hash and no unmerged leaves.
    fn from(encryption_key: EncryptionKey) -> Self {
        let parent_hash: VLBytes = Vec::<u8>::new().into();
        Self {
            encryption_key,
            parent_hash,
            unmerged_leaves: UnmergedLeaves::new(),
        }
    }
}
/// An update-path node before encryption: a public key paired with the path
/// secret it was derived from.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) struct PlainUpdatePathNode {
    public_key: EncryptionKey,
    path_secret: PathSecret,
}

impl PlainUpdatePathNode {
    /// Encrypts this node's path secret once for every key in `public_keys`.
    ///
    /// `group_context` is forwarded unchanged to each encryption call. The
    /// resulting ciphertexts, in the same order as `public_keys`, are bundled
    /// with a clone of this node's public key into an [`UpdatePathNode`].
    /// On non-wasm32 targets the encryptions are spread across rayon's thread
    /// pool; on wasm32 they run sequentially.
    ///
    /// # Errors
    ///
    /// Returns a [`LibraryError`] if any individual encryption fails.
    pub(in crate::treesync) fn encrypt(
        &self,
        crypto: &impl OpenMlsCrypto,
        ciphersuite: Ciphersuite,
        public_keys: &[EncryptionKey],
        group_context: &[u8],
    ) -> Result<UpdatePathNode, LibraryError> {
        let encrypt_to = |recipient: &EncryptionKey| {
            self.path_secret
                .encrypt(crypto, ciphersuite, recipient, group_context)
        };

        #[cfg(not(target_arch = "wasm32"))]
        let ciphertexts = public_keys
            .par_iter()
            .map(encrypt_to)
            .collect::<Result<Vec<HpkeCiphertext>, LibraryError>>()?;
        #[cfg(target_arch = "wasm32")]
        let ciphertexts = public_keys
            .iter()
            .map(encrypt_to)
            .collect::<Result<Vec<HpkeCiphertext>, LibraryError>>()?;

        Ok(UpdatePathNode {
            public_key: self.public_key.clone(),
            encrypted_path_secrets: ciphertexts,
        })
    }

    /// Borrows the unencrypted path secret of this node.
    pub(in crate::treesync) fn path_secret(&self) -> &PathSecret {
        &self.path_secret
    }

    /// Test-only constructor from raw parts.
    #[cfg(test)]
    pub(crate) fn new(public_key: EncryptionKey, path_secret: PathSecret) -> Self {
        Self {
            public_key,
            path_secret,
        }
    }
}
/// Output of [`ParentNode::derive_path`], in this order:
/// the derived `(index, parent node)` pairs, the matching
/// [`PlainUpdatePathNode`]s, the derived [`EncryptionKeyPair`]s, and the
/// [`CommitSecret`] obtained from one derivation past the last path secret.
pub(in crate::treesync) type PathDerivationResult = (
Vec<(ParentNodeIndex, ParentNode)>,
Vec<PlainUpdatePathNode>,
Vec<EncryptionKeyPair>,
CommitSecret,
);
impl ParentNode {
pub(crate) fn derive_path(
crypto: &impl OpenMlsCrypto,
ciphersuite: Ciphersuite,
path_secret: PathSecret,
path_indices: Vec<ParentNodeIndex>,
) -> Result<PathDerivationResult, LibraryError> {
let mut next_path_secret = path_secret;
let mut path_secrets = Vec::with_capacity(path_indices.len());
for _ in 0..path_indices.len() {
let path_secret = next_path_secret;
next_path_secret = path_secret.derive_path_secret(crypto, ciphersuite)?;
path_secrets.push(path_secret);
}
type PathDerivationResults = (
Vec<((ParentNodeIndex, ParentNode), EncryptionKeyPair)>,
Vec<PlainUpdatePathNode>,
);
#[cfg(not(target_arch = "wasm32"))]
let path_secrets = path_secrets.into_par_iter();
#[cfg(target_arch = "wasm32")]
let path_secrets = path_secrets.into_iter();
let (path_with_keypairs, update_path_nodes): PathDerivationResults = path_secrets
.zip(path_indices)
.map(|(path_secret, index)| {
let keypair = path_secret.derive_key_pair(crypto, ciphersuite)?;
let parent_node = ParentNode::from(keypair.public_key().clone());
let update_path_node = PlainUpdatePathNode {
public_key: keypair.public_key().clone(),
path_secret,
};
Ok((((index, parent_node), keypair), update_path_node))
})
.collect::<Result<
Vec<(
((ParentNodeIndex, ParentNode), EncryptionKeyPair),
PlainUpdatePathNode,
)>,
LibraryError,
>>()?
.into_iter()
.unzip();
let (path, keypairs) = path_with_keypairs.into_iter().unzip();
let commit_secret = next_path_secret.into();
Ok((path, update_path_nodes, keypairs, commit_secret))
}
pub(crate) fn public_key(&self) -> &HpkePublicKey {
self.encryption_key.key()
}
pub(crate) fn encryption_key(&self) -> &EncryptionKey {
&self.encryption_key
}
pub(crate) fn unmerged_leaves(&self) -> &[LeafNodeIndex] {
self.unmerged_leaves.list()
}
pub(in crate::treesync) fn set_unmerged_leaves(&mut self, unmerged_leaves: Vec<LeafNodeIndex>) {
self.unmerged_leaves.set_list(unmerged_leaves);
}
pub(in crate::treesync) fn add_unmerged_leaf(&mut self, leaf_index: LeafNodeIndex) {
self.unmerged_leaves.add(leaf_index);
}
pub(in crate::treesync) fn compute_parent_hash(
&self,
crypto: &impl OpenMlsCrypto,
ciphersuite: Ciphersuite,
original_child_resolution: &[u8],
) -> Result<Vec<u8>, LibraryError> {
let parent_hash_input = ParentHashInput::new(
self.encryption_key.key(),
self.parent_hash(),
original_child_resolution,
);
parent_hash_input.hash(crypto, ciphersuite)
}
pub(in crate::treesync) fn set_parent_hash(&mut self, parent_hash: Vec<u8>) {
self.parent_hash = parent_hash.into()
}
pub(crate) fn parent_hash(&self) -> &[u8] {
self.parent_hash.as_slice()
}
}
/// Leaf indices recorded as unmerged on a parent node.
///
/// The list is meant to stay in strictly increasing order. That is the only
/// form accepted by the `TryFrom<Vec<LeafNodeIndex>>` conversion for this type.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize, TlsSize, TlsSerialize)]
pub(in crate::treesync) struct UnmergedLeaves {
    list: Vec<LeafNodeIndex>,
}

impl UnmergedLeaves {
    /// Creates an empty list.
    pub(in crate::treesync) fn new() -> Self {
        Self { list: Vec::new() }
    }

    /// Inserts `leaf_index` at its sorted position.
    ///
    /// If the index is already present, this is a no-op. Inserting a
    /// duplicate would break strict ordering, and `TryFrom` would then reject
    /// the list with `NotSorted`.
    pub(in crate::treesync) fn add(&mut self, leaf_index: LeafNodeIndex) {
        // `Err(pos)` means "not found, insert at `pos` to stay sorted";
        // `Ok(_)` means it is already recorded.
        if let Err(position) = self.list.binary_search(&leaf_index) {
            self.list.insert(position, leaf_index);
        }
    }

    /// Borrows the recorded leaf indices.
    pub(in crate::treesync) fn list(&self) -> &[LeafNodeIndex] {
        self.list.as_slice()
    }

    /// Replaces the whole list.
    ///
    /// NOTE(review): no ordering check is done here; callers are presumably
    /// expected to pass a strictly increasing list — confirm.
    pub(in crate::treesync) fn set_list(&mut self, list: Vec<LeafNodeIndex>) {
        self.list = list;
    }
}
/// Error returned when a list of leaf indices cannot become [`UnmergedLeaves`].
#[derive(Error, Debug)]
pub(in crate::treesync) enum UnmergedLeavesError {
    /// Some element is not strictly greater than the element before it
    /// (this includes duplicates).
    #[error("The list of leaves is not sorted.")]
    NotSorted,
}

impl TryFrom<Vec<LeafNodeIndex>> for UnmergedLeaves {
    type Error = UnmergedLeavesError;

    /// Accepts `list` only if each element is strictly less than its
    /// successor. Empty and single-element lists are always accepted.
    fn try_from(list: Vec<LeafNodeIndex>) -> Result<Self, Self::Error> {
        let strictly_increasing = list
            .iter()
            .zip(list.iter().skip(1))
            .all(|(earlier, later)| earlier < later);

        if strictly_increasing {
            Ok(Self { list })
        } else {
            Err(UnmergedLeavesError::NotSorted)
        }
    }
}