Coverage for functions \ flipdare \ firestore \ core \ db_query.py: 86%
180 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-05-08 12:22 +1000
« prev ^ index » next coverage.py v7.13.0, created at 2026-05-08 12:22 +1000
1#!/usr/bin/env python
2# Copyright (c) 2026 Flipdare Pty Ltd. All rights reserved.
3#
4# This file is part of Flipdare's proprietary software and contains
5# confidential and copyrighted material. Unauthorised copying,
6# modification, distribution, or use of this file is strictly
7# prohibited without prior written permission from Flipdare Pty Ltd.
8#
9# This software includes third-party components licensed under MIT,
10# BSD, and Apache 2.0 licences. See THIRD_PARTY_NOTICES for details.
11#
13from __future__ import annotations
15from datetime import datetime
16from enum import Enum
17from typing import Any, Literal, TypeVar, override
19from google.cloud.firestore import And, DocumentSnapshot
20from google.cloud.firestore import Client as FirestoreClient
21from google.cloud.firestore import CollectionReference, Or
22from google.cloud.firestore_v1 import Query
23from google.cloud.firestore_v1.base_document import BaseDocumentReference
24from google.cloud.firestore_v1.base_query import BaseQuery, FieldFilter
26__all__ = ["IdQuery", "DbQuery", "DbSubQuery", "WhereField", "FieldOp", "OrderByField"]
28type _TimeKey = Literal["created_at", "updated_at"]
30K = TypeVar("K", bound=str)
class FieldOp(Enum):
    """Comparison operators usable in Firestore field filters.

    Each member's value is the operator string that is handed to
    ``FieldFilter`` when a filter is built, so the values must match the
    spellings Firestore recognises exactly.
    """

    EQUAL = "=="
    NOT_EQUAL = "!="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL = "<="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL = ">="
    ARRAY_CONTAINS = "array_contains"
    ARRAY_CONTAINS_ANY = "array_contains_any"
    IN = "in"
    # Firestore spells this operator with a hyphen ("not-in"), unlike the
    # underscore-separated array operators above; "not_in" is not recognised.
    NOT_IN = "not-in"
58class DateFilter:
59 def __init__(
60 self, from_date: datetime, to_date: datetime, restrict_by: _TimeKey = "created_at"
61 ) -> None:
62 self.from_date = from_date
63 self.to_date = to_date
64 self.restrict_by = restrict_by
66 @classmethod
67 def created_at(cls, from_date: datetime, to_date: datetime) -> DateFilter:
68 return cls(from_date=from_date, to_date=to_date, restrict_by="created_at")
70 @classmethod
71 def updated_at(cls, from_date: datetime, to_date: datetime) -> DateFilter:
72 return cls(from_date=from_date, to_date=to_date, restrict_by="updated_at")
74 @property
75 def filters(self) -> list[FieldFilter]:
76 if self.from_date > self.to_date:
77 msg = f"The 'from_date' ({self.from_date}) must be before the 'to_date' ({self.to_date})."
78 raise ValueError(msg)
80 time_field = self.restrict_by
81 return [
82 FieldFilter(time_field, FieldOp.GREATER_THAN.value, self.from_date),
83 FieldFilter(time_field, FieldOp.LESS_THAN.value, self.to_date),
84 ]
86 @override
87 def __repr__(self) -> str:
88 return f"DateFilter({self.from_date} -> {self.to_date}, restrict_by={self.restrict_by})"
90 @override
91 def __str__(self) -> str:
92 return self.__repr__()
95class WhereField[K: str]:
96 def __init__(self, key: K, op: FieldOp, value: Any) -> None:
97 self.key = key
98 self.op = op
99 self.value = value
101 @property
102 def filter(self) -> FieldFilter:
103 return FieldFilter(self.key, self.op.value, self.value)
105 @override
106 def __repr__(self) -> str:
107 return f"WhereField(key={self.key}, op={self.op}, value={self.value})"
109 @override
110 def __str__(self) -> str:
111 return self.__repr__()
114class OrderByField[K: str]:
115 def __init__(self, key: K, descending: bool = False) -> None:
116 self.key = key
117 self.descending = descending
119 @classmethod
120 def created_at(cls, descending: bool = False) -> OrderByField[_TimeKey]:
121 return OrderByField(key="created_at", descending=descending)
123 @classmethod
124 def updated_at(cls, descending: bool = False) -> OrderByField[_TimeKey]:
125 return OrderByField(key="updated_at", descending=descending)
127 @classmethod
128 def asc(cls, key: K) -> OrderByField[Any]:
129 return cls(key=key, descending=False)
131 @classmethod
132 def desc(cls, key: K) -> OrderByField[Any]:
133 return cls(key=key, descending=True)
135 @property
136 def direction(self) -> str:
137 if self.descending:
138 return str(BaseQuery.DESCENDING)
139 return str(BaseQuery.ASCENDING)
141 @override
142 def __repr__(self) -> str:
143 return f"OrderByField(key={self.key}, descending={self.descending})"
145 @override
146 def __str__(self) -> str:
147 return self.__repr__()
class IdQuery:
    """Addresses a single document of a collection by its document id."""

    def __init__(self, doc_id: str) -> None:
        self.doc_id = doc_id

    def doc(
        self,
        client: FirestoreClient,
        collection_name: str,
    ) -> BaseDocumentReference:
        """Return the reference to this document inside ``collection_name``."""
        collection = client.collection(collection_name)
        return collection.document(self.doc_id)

    def snap(
        self,
        client: FirestoreClient,
        collection_name: str,
    ) -> DocumentSnapshot:
        """Fetch this document from ``collection_name`` and return its snapshot."""
        doc_ref = client.collection(collection_name).document(self.doc_id)
        fetched = doc_ref.get()
        # Type narrowing only: get() is also typed as possibly returning an
        # awaitable, so assert the concrete snapshot type before returning.
        assert isinstance(fetched, DocumentSnapshot)
        return fetched
172class DbQuery:
174 def __init__(
175 self,
176 filter_value: Any | None = None,
177 limit: int | None = None,
178 order_by: OrderByField[Any] | None = None,
179 ) -> None:
180 self.filter_value = filter_value
181 self.limit = limit
182 self.order_by = order_by
184 @classmethod
185 def equal(
186 cls,
187 key: str,
188 value: Any,
189 ) -> DbQuery:
190 filter_value = FieldFilter(key, FieldOp.EQUAL.value, value)
191 return cls(filter_value=filter_value)
193 @classmethod
194 def where(
195 cls,
196 where: WhereField[Any],
197 limit: int | None = None,
198 order_by: OrderByField[Any] | None = None,
199 ) -> DbQuery:
200 filter_value = FieldFilter(where.key, where.op.value, where.value)
201 return cls(filter_value=filter_value, limit=limit, order_by=order_by)
203 @classmethod
204 def and_(
205 cls,
206 where_fields: list[WhereField[Any]],
207 limit: int | None = None,
208 order_by: OrderByField[Any] | None = None,
209 date_filter: DateFilter | None = None,
210 ) -> DbQuery:
211 filter_values = [field.filter for field in where_fields]
212 if date_filter is not None:
213 filter_values.extend(date_filter.filters)
215 filter_value = And(filter_values) # type: ignore[arg-type]
216 return cls(filter_value=filter_value, limit=limit, order_by=order_by)
218 @classmethod
219 def or_(
220 cls,
221 where_fields: list[WhereField[Any]],
222 limit: int | None = None,
223 order_by: OrderByField[Any] | None = None,
224 ) -> DbQuery:
225 filter_values = [field.filter for field in where_fields]
226 filter_value = Or(filter_values) # type: ignore[arg-type]
227 return cls(filter_value=filter_value, limit=limit, order_by=order_by)
229 @classmethod
230 def and_or_(
231 cls,
232 and_where: WhereField[Any],
233 or_where_fields: list[WhereField[Any]],
234 limit: int | None = None,
235 order_by: OrderByField[Any] | None = None,
236 ) -> DbQuery:
237 """
238 e.g.
239 and_field = WhereField("city", FieldOp.EQUAL, "New York")
240 or_fields = [
241 WhereField("age", FieldOp.EQUAL, 30),
242 WhereField("age", FieldOp.EQUAL, 28)
243 ]
244 """
245 and_filter = and_where.filter
246 or_filter_values = [field.filter for field in or_where_fields]
248 filter_value = And([and_filter, Or(or_filter_values)]) # type: ignore[arg-type]
249 return cls(filter_value=filter_value, limit=limit, order_by=order_by)
251 @classmethod
252 def complex_and_or_(
253 cls,
254 and_where_fields: list[WhereField[Any]],
255 or_where_fields: list[WhereField[Any]],
256 limit: int | None = None,
257 order_by: OrderByField[Any] | None = None,
258 date_filter: DateFilter | None = None,
259 ) -> DbQuery:
260 """
261 e.g.
262 and_field = WhereField("city", FieldOp.EQUAL, "New York")
263 or_fields = [
264 WhereField("age", FieldOp.EQUAL, 30),
265 WhereField("age", FieldOp.EQUAL, 28)
266 ]
267 """
268 and_filters = [field.filter for field in and_where_fields]
269 if date_filter is not None:
270 and_filters.extend(date_filter.filters)
272 or_filter_values = [field.filter for field in or_where_fields]
274 filter_value = And([*and_filters, Or(or_filter_values)]) # type: ignore[arg-type]
275 return cls(filter_value=filter_value, limit=limit, order_by=order_by)
277 def get_query(
278 self,
279 client: FirestoreClient,
280 collection_name: str,
281 ) -> Query | CollectionReference:
282 query: Query | CollectionReference = client.collection(collection_name)
284 if self.order_by is not None:
285 dirn = self.order_by.direction
286 query = query.order_by(self.order_by.key, direction=dirn)
287 if self.filter_value is not None:
288 query = query.where(filter=self.filter_value)
290 if self.limit is not None:
291 query = query.limit(self.limit)
293 return query
295 @override
296 def __str__(self) -> str:
297 filter_str = f"filter_value={self.filter_value}" if self.filter_value else "no filter"
298 return f"AppQuery({filter_str}, limit={self.limit}, order_by={self.order_by})"
300 @override
301 def __repr__(self) -> str:
302 return self.__str__()
305class DbSubQuery:
307 def __init__(
308 self,
309 parent_doc_id: str,
310 filter_value: Any | None = None,
311 limit: int | None = None,
312 order_by: OrderByField[Any] | None = None,
313 ) -> None:
314 self.filter_value = filter_value
315 self.limit = limit
316 self.order_by = order_by
317 self.parent_doc_id = parent_doc_id
319 @classmethod
320 def where(
321 cls,
322 parent_doc_id: str,
323 where_field: WhereField[Any],
324 limit: int | None = None,
325 order_by: OrderByField[Any] | None = None,
326 ) -> DbSubQuery:
327 filter_value = FieldFilter(where_field.key, where_field.op.value, where_field.value)
328 return cls(
329 parent_doc_id=parent_doc_id,
330 filter_value=filter_value,
331 limit=limit,
332 order_by=order_by,
333 )
335 @classmethod
336 def and_(
337 cls,
338 parent_doc_id: str,
339 where_fields: list[WhereField[Any]],
340 limit: int | None = None,
341 order_by: OrderByField[Any] | None = None,
342 date_filter: DateFilter | None = None,
343 ) -> DbSubQuery:
344 filter_values = [field.filter for field in where_fields]
345 if date_filter is not None:
346 filter_values.extend(date_filter.filters)
348 filter_value = And(filter_values) # type: ignore[arg-type]
349 return cls(
350 parent_doc_id=parent_doc_id,
351 filter_value=filter_value,
352 limit=limit,
353 order_by=order_by,
354 )
356 def get_query(
357 self,
358 client: FirestoreClient,
359 collection_name: str,
360 sub_collection_name: str,
361 ) -> Query:
362 query: Query = (
363 client.collection(collection_name)
364 .document(self.parent_doc_id)
365 .collection(sub_collection_name)
366 )
368 if self.order_by is not None:
369 query = query.order_by(self.order_by.key, direction=self.order_by.direction)
371 if self.filter_value is not None:
372 query = query.where(filter=self.filter_value)
374 if self.limit is not None:
375 query = query.limit(self.limit)
377 return query