Coverage for functions \ flipdare \ firestore \ core \ db_query.py: 86%

180 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-05-08 12:22 +1000

1#!/usr/bin/env python 

2# Copyright (c) 2026 Flipdare Pty Ltd. All rights reserved. 

3# 

4# This file is part of Flipdare's proprietary software and contains 

5# confidential and copyrighted material. Unauthorised copying, 

6# modification, distribution, or use of this file is strictly 

7# prohibited without prior written permission from Flipdare Pty Ltd. 

8# 

9# This software includes third-party components licensed under MIT, 

10# BSD, and Apache 2.0 licences. See THIRD_PARTY_NOTICES for details. 

11# 

12 

13from __future__ import annotations 

14 

15from datetime import datetime 

16from enum import Enum 

17from typing import Any, Literal, TypeVar, override 

18 

19from google.cloud.firestore import And, DocumentSnapshot 

20from google.cloud.firestore import Client as FirestoreClient 

21from google.cloud.firestore import CollectionReference, Or 

22from google.cloud.firestore_v1 import Query 

23from google.cloud.firestore_v1.base_document import BaseDocumentReference 

24from google.cloud.firestore_v1.base_query import BaseQuery, FieldFilter 

25 

26__all__ = ["IdQuery", "DbQuery", "DbSubQuery", "WhereField", "FieldOp", "OrderByField"] 

27 

28type _TimeKey = Literal["created_at", "updated_at"] 

29 

30K = TypeVar("K", bound=str) 

31 

32 

33class FieldOp(Enum): 

34 # FieldOp is an Enum that maps to Firestore's supported query operators. 

35 # "<": _operator_enum.LESS_THAN, 

36 # "<=": _operator_enum.LESS_THAN_OR_EQUAL, 

37 # _EQ_OP: _operator_enum.EQUAL, 

38 # _NEQ_OP: _operator_enum.NOT_EQUAL, 

39 # ">=": _operator_enum.GREATER_THAN_OR_EQUAL, 

40 # ">": _operator_enum.GREATER_THAN, 

41 # "array_contains": _operator_enum.ARRAY_CONTAINS, 

42 # "in": _operator_enum.IN, 

43 # "not-in": _operator_enum.NOT_IN, 

44 # "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, 

45 # } 

46 EQUAL = "==" 

47 NOT_EQUAL = "!=" 

48 LESS_THAN = "<" 

49 LESS_THAN_OR_EQUAL = "<=" 

50 GREATER_THAN = ">" 

51 GREATER_THAN_OR_EQUAL = ">=" 

52 ARRAY_CONTAINS = "array_contains" 

53 ARRAY_CONTAINS_ANY = "array_contains_any" 

54 IN = "in" 

55 NOT_IN = "not_in" 

56 

57 

58class DateFilter: 

59 def __init__( 

60 self, from_date: datetime, to_date: datetime, restrict_by: _TimeKey = "created_at" 

61 ) -> None: 

62 self.from_date = from_date 

63 self.to_date = to_date 

64 self.restrict_by = restrict_by 

65 

66 @classmethod 

67 def created_at(cls, from_date: datetime, to_date: datetime) -> DateFilter: 

68 return cls(from_date=from_date, to_date=to_date, restrict_by="created_at") 

69 

70 @classmethod 

71 def updated_at(cls, from_date: datetime, to_date: datetime) -> DateFilter: 

72 return cls(from_date=from_date, to_date=to_date, restrict_by="updated_at") 

73 

74 @property 

75 def filters(self) -> list[FieldFilter]: 

76 if self.from_date > self.to_date: 

77 msg = f"The 'from_date' ({self.from_date}) must be before the 'to_date' ({self.to_date})." 

78 raise ValueError(msg) 

79 

80 time_field = self.restrict_by 

81 return [ 

82 FieldFilter(time_field, FieldOp.GREATER_THAN.value, self.from_date), 

83 FieldFilter(time_field, FieldOp.LESS_THAN.value, self.to_date), 

84 ] 

85 

86 @override 

87 def __repr__(self) -> str: 

88 return f"DateFilter({self.from_date} -> {self.to_date}, restrict_by={self.restrict_by})" 

89 

90 @override 

91 def __str__(self) -> str: 

92 return self.__repr__() 

93 

94 

95class WhereField[K: str]: 

96 def __init__(self, key: K, op: FieldOp, value: Any) -> None: 

97 self.key = key 

98 self.op = op 

99 self.value = value 

100 

101 @property 

102 def filter(self) -> FieldFilter: 

103 return FieldFilter(self.key, self.op.value, self.value) 

104 

105 @override 

106 def __repr__(self) -> str: 

107 return f"WhereField(key={self.key}, op={self.op}, value={self.value})" 

108 

109 @override 

110 def __str__(self) -> str: 

111 return self.__repr__() 

112 

113 

114class OrderByField[K: str]: 

115 def __init__(self, key: K, descending: bool = False) -> None: 

116 self.key = key 

117 self.descending = descending 

118 

119 @classmethod 

120 def created_at(cls, descending: bool = False) -> OrderByField[_TimeKey]: 

121 return OrderByField(key="created_at", descending=descending) 

122 

123 @classmethod 

124 def updated_at(cls, descending: bool = False) -> OrderByField[_TimeKey]: 

125 return OrderByField(key="updated_at", descending=descending) 

126 

127 @classmethod 

128 def asc(cls, key: K) -> OrderByField[Any]: 

129 return cls(key=key, descending=False) 

130 

131 @classmethod 

132 def desc(cls, key: K) -> OrderByField[Any]: 

133 return cls(key=key, descending=True) 

134 

135 @property 

136 def direction(self) -> str: 

137 if self.descending: 

138 return str(BaseQuery.DESCENDING) 

139 return str(BaseQuery.ASCENDING) 

140 

141 @override 

142 def __repr__(self) -> str: 

143 return f"OrderByField(key={self.key}, descending={self.descending})" 

144 

145 @override 

146 def __str__(self) -> str: 

147 return self.__repr__() 

148 

149 

150class IdQuery: 

151 def __init__(self, doc_id: str) -> None: 

152 self.doc_id = doc_id 

153 

154 def doc( 

155 self, 

156 client: FirestoreClient, 

157 collection_name: str, 

158 ) -> BaseDocumentReference: 

159 return client.collection(collection_name).document(self.doc_id) 

160 

161 def snap( 

162 self, 

163 client: FirestoreClient, 

164 collection_name: str, 

165 ) -> DocumentSnapshot: 

166 snap = client.collection(collection_name).document(self.doc_id).get() 

167 # cosmetic, since can return Awaitable .. 

168 assert isinstance(snap, DocumentSnapshot) 

169 return snap 

170 

171 

172class DbQuery: 

173 

174 def __init__( 

175 self, 

176 filter_value: Any | None = None, 

177 limit: int | None = None, 

178 order_by: OrderByField[Any] | None = None, 

179 ) -> None: 

180 self.filter_value = filter_value 

181 self.limit = limit 

182 self.order_by = order_by 

183 

184 @classmethod 

185 def equal( 

186 cls, 

187 key: str, 

188 value: Any, 

189 ) -> DbQuery: 

190 filter_value = FieldFilter(key, FieldOp.EQUAL.value, value) 

191 return cls(filter_value=filter_value) 

192 

193 @classmethod 

194 def where( 

195 cls, 

196 where: WhereField[Any], 

197 limit: int | None = None, 

198 order_by: OrderByField[Any] | None = None, 

199 ) -> DbQuery: 

200 filter_value = FieldFilter(where.key, where.op.value, where.value) 

201 return cls(filter_value=filter_value, limit=limit, order_by=order_by) 

202 

203 @classmethod 

204 def and_( 

205 cls, 

206 where_fields: list[WhereField[Any]], 

207 limit: int | None = None, 

208 order_by: OrderByField[Any] | None = None, 

209 date_filter: DateFilter | None = None, 

210 ) -> DbQuery: 

211 filter_values = [field.filter for field in where_fields] 

212 if date_filter is not None: 

213 filter_values.extend(date_filter.filters) 

214 

215 filter_value = And(filter_values) # type: ignore[arg-type] 

216 return cls(filter_value=filter_value, limit=limit, order_by=order_by) 

217 

218 @classmethod 

219 def or_( 

220 cls, 

221 where_fields: list[WhereField[Any]], 

222 limit: int | None = None, 

223 order_by: OrderByField[Any] | None = None, 

224 ) -> DbQuery: 

225 filter_values = [field.filter for field in where_fields] 

226 filter_value = Or(filter_values) # type: ignore[arg-type] 

227 return cls(filter_value=filter_value, limit=limit, order_by=order_by) 

228 

229 @classmethod 

230 def and_or_( 

231 cls, 

232 and_where: WhereField[Any], 

233 or_where_fields: list[WhereField[Any]], 

234 limit: int | None = None, 

235 order_by: OrderByField[Any] | None = None, 

236 ) -> DbQuery: 

237 """ 

238 e.g. 

239 and_field = WhereField("city", FieldOp.EQUAL, "New York") 

240 or_fields = [ 

241 WhereField("age", FieldOp.EQUAL, 30), 

242 WhereField("age", FieldOp.EQUAL, 28) 

243 ] 

244 """ 

245 and_filter = and_where.filter 

246 or_filter_values = [field.filter for field in or_where_fields] 

247 

248 filter_value = And([and_filter, Or(or_filter_values)]) # type: ignore[arg-type] 

249 return cls(filter_value=filter_value, limit=limit, order_by=order_by) 

250 

251 @classmethod 

252 def complex_and_or_( 

253 cls, 

254 and_where_fields: list[WhereField[Any]], 

255 or_where_fields: list[WhereField[Any]], 

256 limit: int | None = None, 

257 order_by: OrderByField[Any] | None = None, 

258 date_filter: DateFilter | None = None, 

259 ) -> DbQuery: 

260 """ 

261 e.g. 

262 and_field = WhereField("city", FieldOp.EQUAL, "New York") 

263 or_fields = [ 

264 WhereField("age", FieldOp.EQUAL, 30), 

265 WhereField("age", FieldOp.EQUAL, 28) 

266 ] 

267 """ 

268 and_filters = [field.filter for field in and_where_fields] 

269 if date_filter is not None: 

270 and_filters.extend(date_filter.filters) 

271 

272 or_filter_values = [field.filter for field in or_where_fields] 

273 

274 filter_value = And([*and_filters, Or(or_filter_values)]) # type: ignore[arg-type] 

275 return cls(filter_value=filter_value, limit=limit, order_by=order_by) 

276 

277 def get_query( 

278 self, 

279 client: FirestoreClient, 

280 collection_name: str, 

281 ) -> Query | CollectionReference: 

282 query: Query | CollectionReference = client.collection(collection_name) 

283 

284 if self.order_by is not None: 

285 dirn = self.order_by.direction 

286 query = query.order_by(self.order_by.key, direction=dirn) 

287 if self.filter_value is not None: 

288 query = query.where(filter=self.filter_value) 

289 

290 if self.limit is not None: 

291 query = query.limit(self.limit) 

292 

293 return query 

294 

295 @override 

296 def __str__(self) -> str: 

297 filter_str = f"filter_value={self.filter_value}" if self.filter_value else "no filter" 

298 return f"AppQuery({filter_str}, limit={self.limit}, order_by={self.order_by})" 

299 

300 @override 

301 def __repr__(self) -> str: 

302 return self.__str__() 

303 

304 

305class DbSubQuery: 

306 

307 def __init__( 

308 self, 

309 parent_doc_id: str, 

310 filter_value: Any | None = None, 

311 limit: int | None = None, 

312 order_by: OrderByField[Any] | None = None, 

313 ) -> None: 

314 self.filter_value = filter_value 

315 self.limit = limit 

316 self.order_by = order_by 

317 self.parent_doc_id = parent_doc_id 

318 

319 @classmethod 

320 def where( 

321 cls, 

322 parent_doc_id: str, 

323 where_field: WhereField[Any], 

324 limit: int | None = None, 

325 order_by: OrderByField[Any] | None = None, 

326 ) -> DbSubQuery: 

327 filter_value = FieldFilter(where_field.key, where_field.op.value, where_field.value) 

328 return cls( 

329 parent_doc_id=parent_doc_id, 

330 filter_value=filter_value, 

331 limit=limit, 

332 order_by=order_by, 

333 ) 

334 

335 @classmethod 

336 def and_( 

337 cls, 

338 parent_doc_id: str, 

339 where_fields: list[WhereField[Any]], 

340 limit: int | None = None, 

341 order_by: OrderByField[Any] | None = None, 

342 date_filter: DateFilter | None = None, 

343 ) -> DbSubQuery: 

344 filter_values = [field.filter for field in where_fields] 

345 if date_filter is not None: 

346 filter_values.extend(date_filter.filters) 

347 

348 filter_value = And(filter_values) # type: ignore[arg-type] 

349 return cls( 

350 parent_doc_id=parent_doc_id, 

351 filter_value=filter_value, 

352 limit=limit, 

353 order_by=order_by, 

354 ) 

355 

356 def get_query( 

357 self, 

358 client: FirestoreClient, 

359 collection_name: str, 

360 sub_collection_name: str, 

361 ) -> Query: 

362 query: Query = ( 

363 client.collection(collection_name) 

364 .document(self.parent_doc_id) 

365 .collection(sub_collection_name) 

366 ) 

367 

368 if self.order_by is not None: 

369 query = query.order_by(self.order_by.key, direction=self.order_by.direction) 

370 

371 if self.filter_value is not None: 

372 query = query.where(filter=self.filter_value) 

373 

374 if self.limit is not None: 

375 query = query.limit(self.limit) 

376 

377 return query