From 68fac1e1123e12da0eb5f298591545b7c642d220 Mon Sep 17 00:00:00 2001 From: Sorunome Date: Tue, 23 Jun 2020 10:14:19 +0200 Subject: [PATCH] pick the correct session when encrypting to_device events --- lib/encryption/olm_manager.dart | 87 ++++++++++++++++++--------- lib/encryption/utils/olm_session.dart | 59 ++++++++++++++++++ lib/src/database/database.dart | 22 ++----- lib/src/database/database.g.dart | 70 +++++++++++++++++---- lib/src/database/database.moor | 3 +- test_driver/famedlysdk_test.dart | 50 ++++++--------- 6 files changed, 203 insertions(+), 88 deletions(-) create mode 100644 lib/encryption/utils/olm_session.dart diff --git a/lib/encryption/olm_manager.dart b/lib/encryption/olm_manager.dart index b08ab1b..d5b2fc1 100644 --- a/lib/encryption/olm_manager.dart +++ b/lib/encryption/olm_manager.dart @@ -23,6 +23,7 @@ import 'package:famedlysdk/famedlysdk.dart'; import 'package:famedlysdk/matrix_api.dart'; import 'package:olm/olm.dart' as olm; import './encryption.dart'; +import './utils/olm_session.dart'; class OlmManager { final Encryption encryption; @@ -43,8 +44,8 @@ class OlmManager { OlmManager(this.encryption); /// A map from Curve25519 identity keys to existing olm sessions. - Map> get olmSessions => _olmSessions; - final Map> _olmSessions = {}; + Map> get olmSessions => _olmSessions; + final Map> _olmSessions = {}; Future init(String olmAccount) async { if (olmAccount == null) { @@ -204,25 +205,24 @@ class OlmManager { } } - void storeOlmSession(String curve25519IdentityKey, olm.Session session) { + void storeOlmSession(OlmSession session) { if (client.database == null) { return; } - if (!_olmSessions.containsKey(curve25519IdentityKey)) { - _olmSessions[curve25519IdentityKey] = []; + if (!_olmSessions.containsKey(session.identityKey)) { + _olmSessions[session.identityKey] = []; } - final ix = _olmSessions[curve25519IdentityKey] - .indexWhere((s) => s.session_id() == session.session_id()); + final ix = _olmSessions[session.identityKey] + .indexWhere((s) => s.sessionId == session.sessionId); if (ix == -1) { // add a new session - _olmSessions[curve25519IdentityKey].add(session); + _olmSessions[session.identityKey].add(session); } else { // update an existing session - _olmSessions[curve25519IdentityKey][ix] = session; + _olmSessions[session.identityKey][ix] = session; } - final pickle = session.pickle(client.userID); - client.database.storeOlmSession( - client.id, curve25519IdentityKey, session.session_id(), pickle); + client.database.storeOlmSession(client.id, session.identityKey, + session.sessionId, session.pickledSession, session.lastReceived); } ToDeviceEvent _decryptToDeviceEvent(ToDeviceEvent event) { @@ -245,14 +245,16 @@ class OlmManager { var existingSessions = olmSessions[senderKey]; if (existingSessions != null) { for (var session in existingSessions) { - if (type == 0 && session.matches_inbound(body) == true) { - plaintext = session.decrypt(type, body); - storeOlmSession(senderKey, session); + if (type == 0 && session.session.matches_inbound(body) == true) { + plaintext = session.session.decrypt(type, body); + session.lastReceived = DateTime.now(); + storeOlmSession(session); break; } else if (type == 1) { try { - plaintext = session.decrypt(type, body); - storeOlmSession(senderKey, session); + plaintext = session.session.decrypt(type, body); + session.lastReceived = DateTime.now(); + storeOlmSession(session); break; } catch (_) { plaintext = null; @@ -271,7 +273,13 @@ class OlmManager { _olmAccount.remove_one_time_keys(newSession); client.database?.updateClientKeys(pickledOlmAccount, client.id); plaintext = newSession.decrypt(type, body); - storeOlmSession(senderKey, newSession); + storeOlmSession(OlmSession( + key: client.userID, + identityKey: identityKey, + sessionId: newSession.session_id(), + session: newSession, + lastReceived: DateTime.now(), + )); } catch (_) { newSession?.free(); rethrow; @@ -299,6 +307,22 @@ class OlmManager { ); } + Future> getOlmSessionsFromDatabase(String senderKey) async { + if (client.database == null) { + return []; + } + final rows = + await client.database.dbGetOlmSessions(client.id, senderKey).get(); + final res = []; + for (final row in rows) { + final sess = OlmSession.fromDb(row, client.userID); + if (sess.isValid) { + res.add(sess); + } + } + return res; + } + Future decryptToDeviceEvent(ToDeviceEvent event) async { if (event.type != EventTypes.Encrypted) { return event; @@ -308,8 +332,7 @@ class OlmManager { if (client.database == null) { return false; } - final sessions = await client.database - .getSingleOlmSessions(client.id, senderKey, client.userID); + final sessions = await getOlmSessionsFromDatabase(senderKey); if (sessions.isEmpty) { return false; // okay, can't do anything } @@ -352,11 +375,18 @@ class OlmManager { fingerprintKey, deviceKey, userId, deviceId)) { continue; } + var session = olm.Session(); try { - var session = olm.Session(); session.create_outbound(_olmAccount, identityKey, deviceKey['key']); - await storeOlmSession(identityKey, session); + await storeOlmSession(OlmSession( + key: client.userID, + identityKey: identityKey, + sessionId: session.session_id(), + session: session, + lastReceived: null, + )); } catch (e) { + session.free(); print('[LibOlm] Could not create new outbound olm session: ' + e.toString()); } @@ -369,14 +399,15 @@ class OlmManager { DeviceKeys device, String type, Map payload) async { var sess = olmSessions[device.curve25519Key]; if (sess == null || sess.isEmpty) { - final sessions = await client.database - .getSingleOlmSessions(client.id, device.curve25519Key, client.userID); + final sessions = await getOlmSessionsFromDatabase(device.curve25519Key); if (sessions.isEmpty) { throw ('No olm session found'); } sess = _olmSessions[device.curve25519Key] = sessions; } - sess.sort((a, b) => a.session_id().compareTo(b.session_id())); + sess.sort((a, b) => a.lastReceived == b.lastReceived + ? a.sessionId.compareTo(b.sessionId) + : b.lastReceived.compareTo(a.lastReceived)); final fullPayload = { 'type': type, 'content': payload, @@ -385,8 +416,8 @@ class OlmManager { 'recipient': device.userId, 'recipient_keys': {'ed25519': device.ed25519Key}, }; - final encryptResult = sess.first.encrypt(json.encode(fullPayload)); - storeOlmSession(device.curve25519Key, sess.first); + final encryptResult = sess.first.session.encrypt(json.encode(fullPayload)); + storeOlmSession(sess.first); final encryptedBody = { 'algorithm': 'm.olm.v1.curve25519-aes-sha2', 'sender_key': identityKey, @@ -440,7 +471,7 @@ class OlmManager { void dispose() { for (final sessions in olmSessions.values) { for (final sess in sessions) { - sess.free(); + sess.dispose(); } } _olmAccount?.free(); diff --git a/lib/encryption/utils/olm_session.dart b/lib/encryption/utils/olm_session.dart new file mode 100644 index 0000000..57e5655 --- /dev/null +++ b/lib/encryption/utils/olm_session.dart @@ -0,0 +1,59 @@ +/* + * Famedly Matrix SDK + * Copyright (C) 2020 Famedly GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import 'package:olm/olm.dart' as olm; +import '../../src/database/database.dart' show DbOlmSessions; + +class OlmSession { + String identityKey; + String sessionId; + olm.Session session; + DateTime lastReceived; + final String key; + String get pickledSession => session.pickle(key); + + bool get isValid => session != null; + + OlmSession({ + this.key, + this.identityKey, + this.sessionId, + this.session, + this.lastReceived, + }); + + OlmSession.fromDb(DbOlmSessions dbEntry, String key) : key = key { + session = olm.Session(); + try { + session.unpickle(key, dbEntry.pickle); + identityKey = dbEntry.identityKey; + sessionId = dbEntry.identityKey; + lastReceived = + dbEntry.lastReceived ?? DateTime.fromMillisecondsSinceEpoch(0); + assert(sessionId == session.session_id()); + } catch (e) { + print('[LibOlm] Could not unpickle olm session: ' + e.toString()); + dispose(); + } + } + + void dispose() { + session?.free(); + session = null; + } +} diff --git a/lib/src/database/database.dart b/lib/src/database/database.dart index 58ffac5..e19bb56 100644 --- a/lib/src/database/database.dart +++ b/lib/src/database/database.dart @@ -16,7 +16,7 @@ class Database extends _$Database { Database(QueryExecutor e) : super(e); @override - int get schemaVersion => 4; + int get schemaVersion => 5; int get maxFileSize => 1 * 1024 * 1024; @@ -55,6 +55,10 @@ class Database extends _$Database { 'UPDATE user_device_keys SET outdated = true'); from++; } + if (from == 4) { + await m.addColumn(olmSessions, olmSessions.lastReceived); + from++; + } }, beforeOpen: (_) async { if (executor.dialect == SqlDialect.sqlite) { @@ -114,22 +118,6 @@ class Database extends _$Database { return res; } - Future> getSingleOlmSessions( - int clientId, String identityKey, String userId) async { - final rows = await dbGetOlmSessions(clientId, identityKey).get(); - final res = []; - for (final row in rows) { - try { - var session = olm.Session(); - session.unpickle(userId, row.pickle); - res.add(session); - } catch (e) { - print('[LibOlm] Could not unpickle olm session: ' + e.toString()); - } - } - return res; - } - Future getDbOutboundGroupSession( int clientId, String roomId) async { final res = await dbGetOutboundGroupSession(clientId, roomId).get(); diff --git a/lib/src/database/database.g.dart b/lib/src/database/database.g.dart index 41485b6..d51eec5 100644 --- a/lib/src/database/database.g.dart +++ b/lib/src/database/database.g.dart @@ -1391,17 +1391,20 @@ class DbOlmSessions extends DataClass implements Insertable { final String identityKey; final String sessionId; final String pickle; + final DateTime lastReceived; DbOlmSessions( {@required this.clientId, @required this.identityKey, @required this.sessionId, - @required this.pickle}); + @required this.pickle, + this.lastReceived}); factory DbOlmSessions.fromData( Map data, GeneratedDatabase db, {String prefix}) { final effectivePrefix = prefix ?? ''; final intType = db.typeSystem.forDartType(); final stringType = db.typeSystem.forDartType(); + final dateTimeType = db.typeSystem.forDartType(); return DbOlmSessions( clientId: intType.mapFromDatabaseResponse(data['${effectivePrefix}client_id']), @@ -1411,6 +1414,8 @@ class DbOlmSessions extends DataClass implements Insertable { .mapFromDatabaseResponse(data['${effectivePrefix}session_id']), pickle: stringType.mapFromDatabaseResponse(data['${effectivePrefix}pickle']), + lastReceived: dateTimeType + .mapFromDatabaseResponse(data['${effectivePrefix}last_received']), ); } @override @@ -1428,6 +1433,9 @@ class DbOlmSessions extends DataClass implements Insertable { if (!nullToAbsent || pickle != null) { map['pickle'] = Variable(pickle); } + if (!nullToAbsent || lastReceived != null) { + map['last_received'] = Variable(lastReceived); + } return map; } @@ -1439,6 +1447,7 @@ class DbOlmSessions extends DataClass implements Insertable { identityKey: serializer.fromJson(json['identity_key']), sessionId: serializer.fromJson(json['session_id']), pickle: serializer.fromJson(json['pickle']), + lastReceived: serializer.fromJson(json['last_received']), ); } @override @@ -1449,6 +1458,7 @@ class DbOlmSessions extends DataClass implements Insertable { 'identity_key': serializer.toJson(identityKey), 'session_id': serializer.toJson(sessionId), 'pickle': serializer.toJson(pickle), + 'last_received': serializer.toJson(lastReceived), }; } @@ -1456,12 +1466,14 @@ class DbOlmSessions extends DataClass implements Insertable { {int clientId, String identityKey, String sessionId, - String pickle}) => + String pickle, + DateTime lastReceived}) => DbOlmSessions( clientId: clientId ?? this.clientId, identityKey: identityKey ?? this.identityKey, sessionId: sessionId ?? this.sessionId, pickle: pickle ?? this.pickle, + lastReceived: lastReceived ?? this.lastReceived, ); @override String toString() { @@ -1469,14 +1481,19 @@ class DbOlmSessions extends DataClass implements Insertable { ..write('clientId: $clientId, ') ..write('identityKey: $identityKey, ') ..write('sessionId: $sessionId, ') - ..write('pickle: $pickle') + ..write('pickle: $pickle, ') + ..write('lastReceived: $lastReceived') ..write(')')) .toString(); } @override - int get hashCode => $mrjf($mrjc(clientId.hashCode, - $mrjc(identityKey.hashCode, $mrjc(sessionId.hashCode, pickle.hashCode)))); + int get hashCode => $mrjf($mrjc( + clientId.hashCode, + $mrjc( + identityKey.hashCode, + $mrjc(sessionId.hashCode, + $mrjc(pickle.hashCode, lastReceived.hashCode))))); @override bool operator ==(dynamic other) => identical(this, other) || @@ -1484,7 +1501,8 @@ class DbOlmSessions extends DataClass implements Insertable { other.clientId == this.clientId && other.identityKey == this.identityKey && other.sessionId == this.sessionId && - other.pickle == this.pickle); + other.pickle == this.pickle && + other.lastReceived == this.lastReceived); } class OlmSessionsCompanion extends UpdateCompanion { @@ -1492,17 +1510,20 @@ class OlmSessionsCompanion extends UpdateCompanion { final Value identityKey; final Value sessionId; final Value pickle; + final Value lastReceived; const OlmSessionsCompanion({ this.clientId = const Value.absent(), this.identityKey = const Value.absent(), this.sessionId = const Value.absent(), this.pickle = const Value.absent(), + this.lastReceived = const Value.absent(), }); OlmSessionsCompanion.insert({ @required int clientId, @required String identityKey, @required String sessionId, @required String pickle, + this.lastReceived = const Value.absent(), }) : clientId = Value(clientId), identityKey = Value(identityKey), sessionId = Value(sessionId), @@ -1512,12 +1533,14 @@ class OlmSessionsCompanion extends UpdateCompanion { Expression identityKey, Expression sessionId, Expression pickle, + Expression lastReceived, }) { return RawValuesInsertable({ if (clientId != null) 'client_id': clientId, if (identityKey != null) 'identity_key': identityKey, if (sessionId != null) 'session_id': sessionId, if (pickle != null) 'pickle': pickle, + if (lastReceived != null) 'last_received': lastReceived, }); } @@ -1525,12 +1548,14 @@ class OlmSessionsCompanion extends UpdateCompanion { {Value clientId, Value identityKey, Value sessionId, - Value pickle}) { + Value pickle, + Value lastReceived}) { return OlmSessionsCompanion( clientId: clientId ?? this.clientId, identityKey: identityKey ?? this.identityKey, sessionId: sessionId ?? this.sessionId, pickle: pickle ?? this.pickle, + lastReceived: lastReceived ?? this.lastReceived, ); } @@ -1549,6 +1574,9 @@ class OlmSessionsCompanion extends UpdateCompanion { if (pickle.present) { map['pickle'] = Variable(pickle.value); } + if (lastReceived.present) { + map['last_received'] = Variable(lastReceived.value); + } return map; } } @@ -1591,9 +1619,19 @@ class OlmSessions extends Table with TableInfo { $customConstraints: 'NOT NULL'); } + final VerificationMeta _lastReceivedMeta = + const VerificationMeta('lastReceived'); + GeneratedDateTimeColumn _lastReceived; + GeneratedDateTimeColumn get lastReceived => + _lastReceived ??= _constructLastReceived(); + GeneratedDateTimeColumn _constructLastReceived() { + return GeneratedDateTimeColumn('last_received', $tableName, true, + $customConstraints: ''); + } + @override List get $columns => - [clientId, identityKey, sessionId, pickle]; + [clientId, identityKey, sessionId, pickle, lastReceived]; @override OlmSessions get asDslTable => this; @override @@ -1631,6 +1669,12 @@ class OlmSessions extends Table with TableInfo { } else if (isInserting) { context.missing(_pickleMeta); } + if (data.containsKey('last_received')) { + context.handle( + _lastReceivedMeta, + lastReceived.isAcceptableOrUnknown( + data['last_received'], _lastReceivedMeta)); + } return context; } @@ -5520,6 +5564,7 @@ abstract class _$Database extends GeneratedDatabase { identityKey: row.readString('identity_key'), sessionId: row.readString('session_id'), pickle: row.readString('pickle'), + lastReceived: row.readDateTime('last_received'), ); } @@ -5543,15 +5588,16 @@ abstract class _$Database extends GeneratedDatabase { }).map(_rowToDbOlmSessions); } - Future storeOlmSession( - int client_id, String identitiy_key, String session_id, String pickle) { + Future storeOlmSession(int client_id, String identitiy_key, + String session_id, String pickle, DateTime last_received) { return customInsert( - 'INSERT OR REPLACE INTO olm_sessions (client_id, identity_key, session_id, pickle) VALUES (:client_id, :identitiy_key, :session_id, :pickle)', + 'INSERT OR REPLACE INTO olm_sessions (client_id, identity_key, session_id, pickle, last_received) VALUES (:client_id, :identitiy_key, :session_id, :pickle, :last_received)', variables: [ Variable.withInt(client_id), Variable.withString(identitiy_key), Variable.withString(session_id), - Variable.withString(pickle) + Variable.withString(pickle), + Variable.withDateTime(last_received) ], updates: {olmSessions}, ); diff --git a/lib/src/database/database.moor b/lib/src/database/database.moor index 49a108d..c2e5ff0 100644 --- a/lib/src/database/database.moor +++ b/lib/src/database/database.moor @@ -48,6 +48,7 @@ CREATE TABLE olm_sessions ( identity_key TEXT NOT NULL, session_id TEXT NOT NULL, pickle TEXT NOT NULL, + last_received DATETIME, UNIQUE(client_id, identity_key, session_id) ) AS DbOlmSessions; CREATE INDEX olm_sessions_index ON olm_sessions(client_id); @@ -177,7 +178,7 @@ getAllUserDeviceKeysKeys: SELECT * FROM user_device_keys_key WHERE client_id = : getAllUserCrossSigningKeys: SELECT * FROM user_cross_signing_keys WHERE client_id = :client_id; getAllOlmSessions: SELECT * FROM olm_sessions WHERE client_id = :client_id; dbGetOlmSessions: SELECT * FROM olm_sessions WHERE client_id = :client_id AND identity_key = :identity_key; -storeOlmSession: INSERT OR REPLACE INTO olm_sessions (client_id, identity_key, session_id, pickle) VALUES (:client_id, :identitiy_key, :session_id, :pickle); +storeOlmSession: INSERT OR REPLACE INTO olm_sessions (client_id, identity_key, session_id, pickle, last_received) VALUES (:client_id, :identitiy_key, :session_id, :pickle, :last_received); getAllOutboundGroupSessions: SELECT * FROM outbound_group_sessions WHERE client_id = :client_id; dbGetOutboundGroupSession: SELECT * FROM outbound_group_sessions WHERE client_id = :client_id AND room_id = :room_id; storeOutboundGroupSession: INSERT OR REPLACE INTO outbound_group_sessions (client_id, room_id, pickle, device_ids, creation_time, sent_messages) VALUES (:client_id, :room_id, :pickle, :device_ids, :creation_time, :sent_messages); diff --git a/test_driver/famedlysdk_test.dart b/test_driver/famedlysdk_test.dart index 1dc8f9d..79cabfb 100644 --- a/test_driver/famedlysdk_test.dart +++ b/test_driver/famedlysdk_test.dart @@ -139,12 +139,10 @@ void test() async { assert(testClientB .encryption.olmManager.olmSessions[testClientA.identityKey].length == 1); - assert(testClientA - .encryption.olmManager.olmSessions[testClientB.identityKey].first - .session_id() == - testClientB - .encryption.olmManager.olmSessions[testClientA.identityKey].first - .session_id()); + assert(testClientA.encryption.olmManager.olmSessions[testClientB.identityKey] + .first.sessionId == + testClientB.encryption.olmManager.olmSessions[testClientA.identityKey] + .first.sessionId); assert(inviteRoom.client.encryption.keyManager .getInboundGroupSession(inviteRoom.id, currentSessionIdA, '') != null); @@ -162,12 +160,10 @@ void test() async { assert(testClientB .encryption.olmManager.olmSessions[testClientA.identityKey].length == 1); - assert(testClientA - .encryption.olmManager.olmSessions[testClientB.identityKey].first - .session_id() == - testClientB - .encryption.olmManager.olmSessions[testClientA.identityKey].first - .session_id()); + assert(testClientA.encryption.olmManager.olmSessions[testClientB.identityKey] + .first.sessionId == + testClientB.encryption.olmManager.olmSessions[testClientA.identityKey] + .first.sessionId); assert(room.client.encryption.keyManager .getOutboundGroupSession(room.id) @@ -231,24 +227,20 @@ void test() async { assert(testClientB .encryption.olmManager.olmSessions[testClientA.identityKey].length == 1); - assert(testClientA - .encryption.olmManager.olmSessions[testClientB.identityKey].first - .session_id() == - testClientB - .encryption.olmManager.olmSessions[testClientA.identityKey].first - .session_id()); + assert(testClientA.encryption.olmManager.olmSessions[testClientB.identityKey] + .first.sessionId == + testClientB.encryption.olmManager.olmSessions[testClientA.identityKey] + .first.sessionId); assert(testClientA .encryption.olmManager.olmSessions[testClientC.identityKey].length == 1); assert(testClientC .encryption.olmManager.olmSessions[testClientA.identityKey].length == 1); - assert(testClientA - .encryption.olmManager.olmSessions[testClientC.identityKey].first - .session_id() == - testClientC - .encryption.olmManager.olmSessions[testClientA.identityKey].first - .session_id()); + assert(testClientA.encryption.olmManager.olmSessions[testClientC.identityKey] + .first.sessionId == + testClientC.encryption.olmManager.olmSessions[testClientA.identityKey] + .first.sessionId); assert(room.client.encryption.keyManager .getOutboundGroupSession(room.id) .outboundGroupSession @@ -281,12 +273,10 @@ void test() async { assert(testClientB .encryption.olmManager.olmSessions[testClientA.identityKey].length == 1); - assert(testClientA - .encryption.olmManager.olmSessions[testClientB.identityKey].first - .session_id() == - testClientB - .encryption.olmManager.olmSessions[testClientA.identityKey].first - .session_id()); + assert(testClientA.encryption.olmManager.olmSessions[testClientB.identityKey] + .first.sessionId == + testClientB.encryption.olmManager.olmSessions[testClientA.identityKey] + .first.sessionId); assert(room.client.encryption.keyManager .getOutboundGroupSession(room.id) .outboundGroupSession