diff --git a/lib/src/client.dart b/lib/src/client.dart index 4aa54fa..d464121 100644 --- a/lib/src/client.dart +++ b/lib/src/client.dart @@ -97,8 +97,8 @@ class Client { /// enableE2eeRecovery: Enable additional logic to try to recover from bad e2ee sessions Client(this.clientName, {this.debug = false, this.database, this.enableE2eeRecovery = false}) { - keyManager = KeyManager(this); ssss = SSSS(this); + keyManager = KeyManager(this); crossSigning = CrossSigning(this); onLoginStateChanged.stream.listen((loginState) { print('LoginState: ${loginState.toString()}'); diff --git a/lib/src/cross_signing.dart b/lib/src/cross_signing.dart index 8307a96..a884a37 100644 --- a/lib/src/cross_signing.dart +++ b/lib/src/cross_signing.dart @@ -12,7 +12,28 @@ const MASTER_KEY = 'm.cross_signing.master'; class CrossSigning { final Client client; - CrossSigning(this.client); + CrossSigning(this.client) { + client.ssss.setValidator(SELF_SIGNING_KEY, (String secret) async { + final keyObj = olm.PkSigning(); + try { + return keyObj.init_with_seed(base64.decode(secret)) == client.userDeviceKeys[client.userID].selfSigningKey.ed25519Key; + } catch (_) { + return false; + } finally { + keyObj.free(); + } + }); + client.ssss.setValidator(USER_SIGNING_KEY, (String secret) async { + final keyObj = olm.PkSigning(); + try { + return keyObj.init_with_seed(base64.decode(secret)) == client.userDeviceKeys[client.userID].userSigningKey.ed25519Key; + } catch (_) { + return false; + } finally { + keyObj.free(); + } + }); + } bool get enabled => client.accountData[SELF_SIGNING_KEY] != null && diff --git a/lib/src/key_manager.dart b/lib/src/key_manager.dart index abe5b37..6e9d710 100644 --- a/lib/src/key_manager.dart +++ b/lib/src/key_manager.dart @@ -15,7 +15,19 @@ class KeyManager { final outgoingShareRequests = {}; final incomingShareRequests = {}; - KeyManager(this.client); + KeyManager(this.client) { + client.ssss.setValidator(MEGOLM_KEY, (String secret) async { + final keyObj = olm.PkDecryption(); + try { + final info = await getRoomKeysInfo(); + return keyObj.init_with_private_key(base64.decode(secret)) == info['auth_data']['public_key']; + } catch (_) { + return false; + } finally { + keyObj.free(); + } + }); + } bool get enabled => client.accountData[MEGOLM_KEY] != null; @@ -42,50 +54,51 @@ class KeyManager { } final privateKey = base64.decode(await client.ssss.getCached(MEGOLM_KEY)); final decryption = olm.PkDecryption(); + final info = await getRoomKeysInfo(); String backupPubKey; try { backupPubKey = decryption.init_with_private_key(privateKey); - } catch (_) { - decryption.free(); - rethrow; - } - if (backupPubKey == null) { - decryption.free(); - return; - } - // TODO: check if pubkey is valid - for (final roomEntries in payload['rooms'].entries) { - final roomId = roomEntries.key; - if (!(roomEntries.value is Map) || !(roomEntries.value['sessions'] is Map)) { - continue; + + if (backupPubKey == null || !info.containsKey('auth_data') || !(info['auth_data'] is Map) || info['auth_data']['public_key'] != backupPubKey) { + + return; } - for (final sessionEntries in roomEntries.value['sessions'].entries) { - final sessionId = sessionEntries.key; - final rawEncryptedSession = sessionEntries.value; - if (!(rawEncryptedSession is Map)) { + // TODO: check if pubkey is valid + for (final roomEntries in payload['rooms'].entries) { + final roomId = roomEntries.key; + if (!(roomEntries.value is Map) || !(roomEntries.value['sessions'] is Map)) { continue; } - final firstMessageIndex = rawEncryptedSession['first_message_index'] is int ? rawEncryptedSession['first_message_index'] : null; - final forwardedCount = rawEncryptedSession['forwarded_count'] is int ? rawEncryptedSession['forwarded_count'] : null; - final isVerified = rawEncryptedSession['is_verified'] is bool ? rawEncryptedSession['is_verified'] : null; - final sessionData = rawEncryptedSession['session_data']; - if (firstMessageIndex == null || forwardedCount == null || isVerified == null || !(sessionData is Map)) { - continue; - } - final senderKey = sessionData['sender_key']; - Map decrypted; - try { - decrypted = json.decode(decryption.decrypt(sessionData['ephemeral'], sessionData['mac'], sessionData['ciphertext'])); - } catch (err) { - print('[LibOlm] Error decrypting room key: ' + err.toString()); - } - if (decrypted != null) { - decrypted['session_id'] = sessionId; - decrypted['room_id'] = roomId; - final room = client.getRoomById(roomId) ?? Room(id: roomId, client: client); - room.setInboundGroupSession(sessionId, decrypted, forwarded: true); + for (final sessionEntries in roomEntries.value['sessions'].entries) { + final sessionId = sessionEntries.key; + final rawEncryptedSession = sessionEntries.value; + if (!(rawEncryptedSession is Map)) { + continue; + } + final firstMessageIndex = rawEncryptedSession['first_message_index'] is int ? rawEncryptedSession['first_message_index'] : null; + final forwardedCount = rawEncryptedSession['forwarded_count'] is int ? rawEncryptedSession['forwarded_count'] : null; + final isVerified = rawEncryptedSession['is_verified'] is bool ? rawEncryptedSession['is_verified'] : null; + final sessionData = rawEncryptedSession['session_data']; + if (firstMessageIndex == null || forwardedCount == null || isVerified == null || !(sessionData is Map)) { + continue; + } + final senderKey = sessionData['sender_key']; + Map decrypted; + try { + decrypted = json.decode(decryption.decrypt(sessionData['ephemeral'], sessionData['mac'], sessionData['ciphertext'])); + } catch (err) { + print('[LibOlm] Error decrypting room key: ' + err.toString()); + } + if (decrypted != null) { + decrypted['session_id'] = sessionId; + decrypted['room_id'] = roomId; + final room = client.getRoomById(roomId) ?? Room(id: roomId, client: client); + room.setInboundGroupSession(sessionId, decrypted, forwarded: true); + } } } + } finally { + decryption.free(); } } diff --git a/lib/src/ssss.dart b/lib/src/ssss.dart index cb2b3ff..b639af9 100644 --- a/lib/src/ssss.dart +++ b/lib/src/ssss.dart @@ -27,6 +27,7 @@ const OLM_PRIVATE_KEY_LENGTH = 32; // TODO: fetch from dart-olm class SSSS { final Client client; final pendingShareRequests = {}; + final _validators = Function(String)>{}; SSSS(this.client); static _DerivedKeys deriveKeys(Uint8List key, String name) { @@ -119,6 +120,10 @@ class SSSS { info.iterations, info.bits != null ? info.bits / 8 : 32)); } + void setValidator(String type, Future Function(String) validator) { + _validators[type] = validator; + } + String get defaultKeyId { final keyData = client.accountData['m.secret_storage.default_key']; if (keyData == null || !(keyData.content['key'] is String)) { @@ -312,11 +317,17 @@ class SSSS { print('[SSSS] Someone else replied?'); return; // someone replied whom we didn't send the share request to } - pendingShareRequests.remove(request.requestId); + final secret = event.content['secret']; if (!(event.content['secret'] is String)) { print('[SSSS] Secret wasn\'t a string'); return; // the secret wasn't a string....wut? } + // let's validate if the secret is, well, valid + if (_validators.containsKey(request.type) && !(await _validators[request.type](secret))) { + print('[SSSS] The received secret was invalid'); + return; // didn't pass the validator + } + pendingShareRequests.remove(request.requestId); if (request.start.add(Duration(minutes: 15)).isBefore(DateTime.now())) { print('[SSSS] Request is too far in the past'); return; // our request is more than 15min in the past...better not trust it anymore @@ -326,7 +337,7 @@ class SSSS { final keyId = keyIdFromType(request.type); if (keyId != null) { await client.database.storeSSSSCache( - client.id, request.type, keyId, event.content['secret']); + client.id, request.type, keyId, secret); } } }