Skip to content

Commit a256752

Browse files
authored
fix: stop / start stream after filter mismatch (#502)
~~Based on branch for PR #500 -- I will rebase after that PR merges.~~ Closes #367. Supersedes PR #497.
1 parent 74d8171 commit a256752

File tree

6 files changed

+45
-65
lines changed

6 files changed

+45
-65
lines changed

google/cloud/firestore_v1/collection.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,5 @@ def on_snapshot(collection_snapshot, changes, read_time):
237237
# Terminate this watch
238238
collection_watch.unsubscribe()
239239
"""
240-
return Watch.for_query(
241-
self._query(),
242-
callback,
243-
document.DocumentSnapshot,
244-
document.DocumentReference,
245-
)
240+
query = self._query()
241+
return Watch.for_query(query, callback, document.DocumentSnapshot)

google/cloud/firestore_v1/document.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,4 +489,4 @@ def on_snapshot(document_snapshot, changes, read_time):
489489
# Terminate this watch
490490
doc_watch.unsubscribe()
491491
"""
492-
return Watch.for_document(self, callback, DocumentSnapshot, DocumentReference)
492+
return Watch.for_document(self, callback, DocumentSnapshot)

google/cloud/firestore_v1/query.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,7 @@ def on_snapshot(docs, changes, read_time):
329329
# Terminate this watch
330330
query_watch.unsubscribe()
331331
"""
332-
return Watch.for_query(
333-
self, callback, document.DocumentSnapshot, document.DocumentReference
334-
)
332+
return Watch.for_query(self, callback, document.DocumentSnapshot)
335333

336334
@staticmethod
337335
def _get_collection_reference_class() -> Type[

google/cloud/firestore_v1/watch.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ def __init__(
175175
comparator,
176176
snapshot_callback,
177177
document_snapshot_cls,
178-
document_reference_cls,
179178
):
180179
"""
181180
Args:
@@ -192,35 +191,21 @@ def __init__(
192191
read_time (string): The ISO 8601 time at which this
193192
snapshot was obtained.
194193
195-
document_snapshot_cls: instance of DocumentSnapshot
196-
document_reference_cls: instance of DocumentReference
194+
document_snapshot_cls: factory for instances of DocumentSnapshot
197195
"""
198196
self._document_reference = document_reference
199197
self._firestore = firestore
200-
self._api = firestore._firestore_api
201198
self._targets = target
202199
self._comparator = comparator
203-
self.DocumentSnapshot = document_snapshot_cls
204-
self.DocumentReference = document_reference_cls
200+
self._document_snapshot_cls = document_snapshot_cls
205201
self._snapshot_callback = snapshot_callback
202+
self._api = firestore._firestore_api
206203
self._closing = threading.Lock()
207204
self._closed = False
208205
self._set_documents_pfx(firestore._database_string)
209206

210207
self.resume_token = None
211208

212-
rpc_request = self._get_rpc_request
213-
214-
self._rpc = ResumableBidiRpc(
215-
start_rpc=self._api._transport.listen,
216-
should_recover=_should_recover,
217-
should_terminate=_should_terminate,
218-
initial_request=rpc_request,
219-
metadata=self._firestore._rpc_metadata,
220-
)
221-
222-
self._rpc.add_done_callback(self._on_rpc_done)
223-
224209
# Initialize state for on_snapshot
225210
# The sorted tree of QueryDocumentSnapshots as sent in the last
226211
# snapshot. We only look at the keys.
@@ -242,17 +227,29 @@ def __init__(
242227
# aren't docs.
243228
self.has_pushed = False
244229

230+
self._init_stream()
231+
232+
def _init_stream(self):
233+
234+
rpc_request = self._get_rpc_request
235+
236+
self._rpc = ResumableBidiRpc(
237+
start_rpc=self._api._transport.listen,
238+
should_recover=_should_recover,
239+
should_terminate=_should_terminate,
240+
initial_request=rpc_request,
241+
metadata=self._firestore._rpc_metadata,
242+
)
243+
244+
self._rpc.add_done_callback(self._on_rpc_done)
245+
245246
# The server assigns and updates the resume token.
246247
self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
247248
self._consumer.start()
248249

249250
@classmethod
250251
def for_document(
251-
cls,
252-
document_ref,
253-
snapshot_callback,
254-
document_snapshot_cls,
255-
document_reference_cls,
252+
cls, document_ref, snapshot_callback, document_snapshot_cls,
256253
):
257254
"""
258255
Creates a watch snapshot listener for a document. snapshot_callback
@@ -276,13 +273,10 @@ def for_document(
276273
document_watch_comparator,
277274
snapshot_callback,
278275
document_snapshot_cls,
279-
document_reference_cls,
280276
)
281277

282278
@classmethod
283-
def for_query(
284-
cls, query, snapshot_callback, document_snapshot_cls, document_reference_cls,
285-
):
279+
def for_query(cls, query, snapshot_callback, document_snapshot_cls):
286280
parent_path, _ = query._parent._parent_info()
287281
query_target = Target.QueryTarget(
288282
parent=parent_path, structured_query=query._to_protobuf()
@@ -295,12 +289,13 @@ def for_query(
295289
query._comparator,
296290
snapshot_callback,
297291
document_snapshot_cls,
298-
document_reference_cls,
299292
)
300293

301294
def _get_rpc_request(self):
302295
if self.resume_token is not None:
303296
self._targets["resume_token"] = self.resume_token
297+
else:
298+
self._targets.pop("resume_token", None)
304299

305300
return ListenRequest(
306301
database=self._firestore._database_string, add_target=self._targets
@@ -490,7 +485,7 @@ def on_snapshot(self, proto):
490485
document_name = self._strip_document_pfx(document.name)
491486
document_ref = self._firestore.document(document_name)
492487

493-
snapshot = self.DocumentSnapshot(
488+
snapshot = self._document_snapshot_cls(
494489
reference=document_ref,
495490
data=data,
496491
exists=True,
@@ -520,11 +515,17 @@ def on_snapshot(self, proto):
520515
elif which == "filter":
521516
_LOGGER.debug("on_snapshot: filter update")
522517
if pb.filter.count != self._current_size():
523-
# We need to remove all the current results.
518+
# First, shut down current stream
519+
_LOGGER.info("Filter mismatch -- restarting stream.")
520+
thread = threading.Thread(
521+
name=_RPC_ERROR_THREAD_NAME, target=self.close,
522+
)
523+
thread.start()
524+
thread.join() # wait for shutdown to complete
525+
# Then, remove all the current results.
524526
self._reset_docs()
525-
# The filter didn't match, so re-issue the query.
526-
# TODO: reset stream method?
527-
# self._reset_stream();
527+
# Finally, restart stream.
528+
self._init_stream()
528529

529530
else:
530531
_LOGGER.debug("UNKNOWN TYPE. UHOH")

tests/unit/v1/test_cross_language.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER
216216
# 'docs' (list of 'google.firestore_v1.Document'),
217217
# 'changes' (list lof local 'DocChange', and 'read_time' timestamp.
218218
from google.cloud.firestore_v1 import Client
219-
from google.cloud.firestore_v1 import DocumentReference
220219
from google.cloud.firestore_v1 import DocumentSnapshot
221220
from google.cloud.firestore_v1 import Watch
222221
import google.auth.credentials
@@ -226,6 +225,9 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER
226225

227226
credentials = mock.Mock(spec=google.auth.credentials.Credentials)
228227
client = Client(project="project", credentials=credentials)
228+
# conformance data has db string as this
229+
db_str = "projects/projectID/databases/(default)"
230+
client._database_string_internal = db_str
229231
with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
230232
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
231233
# conformance data sets WATCH_TARGET_ID to 1
@@ -237,12 +239,7 @@ def callback(keys, applied_changes, read_time):
237239

238240
collection = DummyCollection(client=client)
239241
query = DummyQuery(parent=collection)
240-
watch = Watch.for_query(
241-
query, callback, DocumentSnapshot, DocumentReference
242-
)
243-
# conformance data has db string as this
244-
db_str = "projects/projectID/databases/(default)"
245-
watch._firestore._database_string_internal = db_str
242+
watch = Watch.for_query(query, callback, DocumentSnapshot)
246243

247244
wrapped_responses = [
248245
firestore.ListenResponse.wrap(proto) for proto in testcase.responses

tests/unit/v1/test_watch.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ def snapshot_callback(*args):
183183
comparator=comparator,
184184
snapshot_callback=snapshot_callback,
185185
document_snapshot_cls=DummyDocumentSnapshot,
186-
document_reference_cls=DummyDocumentReference,
187186
)
188187

189188

@@ -224,16 +223,11 @@ def snapshot_callback(*args): # pragma: NO COVER
224223
snapshots.append(args)
225224

226225
docref = DummyDocumentReference()
227-
snapshot_class_instance = DummyDocumentSnapshot
228-
document_reference_class_instance = DummyDocumentReference
229226

230227
with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
231228
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
232229
inst = Watch.for_document(
233-
docref,
234-
snapshot_callback,
235-
snapshot_class_instance,
236-
document_reference_class_instance,
230+
docref, snapshot_callback, document_snapshot_cls=DummyDocumentSnapshot,
237231
)
238232

239233
inst._consumer.start.assert_called_once_with()
@@ -246,8 +240,6 @@ def test_watch_for_query(snapshots):
246240
def snapshot_callback(*args): # pragma: NO COVER
247241
snapshots.append(args)
248242

249-
snapshot_class_instance = DummyDocumentSnapshot
250-
document_reference_class_instance = DummyDocumentReference
251243
client = DummyFirestore()
252244
parent = DummyCollection(client)
253245
query = DummyQuery(parent=parent)
@@ -258,8 +250,7 @@ def snapshot_callback(*args): # pragma: NO COVER
258250
inst = Watch.for_query(
259251
query,
260252
snapshot_callback,
261-
snapshot_class_instance,
262-
document_reference_class_instance,
253+
document_snapshot_cls=DummyDocumentSnapshot,
263254
)
264255

265256
inst._consumer.start.assert_called_once_with()
@@ -278,8 +269,6 @@ def test_watch_for_query_nested(snapshots):
278269
def snapshot_callback(*args): # pragma: NO COVER
279270
snapshots.append(args)
280271

281-
snapshot_class_instance = DummyDocumentSnapshot
282-
document_reference_class_instance = DummyDocumentReference
283272
client = DummyFirestore()
284273
root = DummyCollection(client)
285274
grandparent = DummyDocument("document", parent=root)
@@ -292,8 +281,7 @@ def snapshot_callback(*args): # pragma: NO COVER
292281
inst = Watch.for_query(
293282
query,
294283
snapshot_callback,
295-
snapshot_class_instance,
296-
document_reference_class_instance,
284+
document_snapshot_cls=DummyDocumentSnapshot,
297285
)
298286

299287
inst._consumer.start.assert_called_once_with()

0 commit comments

Comments
 (0)