Skip to content

Commit ce7e4d4

Browse files
committed
Let the exceptions propagate
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 78b5e0b commit ce7e4d4

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

Lib/asyncio/taskgroups.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def create_task(self, coro, *, name=None, context=None):
159159
self._tasks.add(task)
160160
return task
161161

162-
def _eager_eval(self, coro):
162+
def enqueue(self, coro, no_future=True):
163+
if not self._entered:
164+
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
165+
163166
try:
164167
fut = coro.send(None)
165168
task = self.create_task(coro)
@@ -168,23 +171,8 @@ def _eager_eval(self, coro):
168171
except StopIteration as e:
169172
# The co-routine has completed synchronously and we've got
170173
# our result.
171-
return PyCoroEagerResult(e.args[0] if e.args else None)
172-
except Exception as e:
173174
res = Future(loop=self._loop)
174-
res.set_exception(e)
175-
return res
176-
177-
def enqueue(self, coro, no_future=True):
178-
if not self._entered:
179-
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
180-
181-
res = self._eager_eval(coro)
182-
if isinstance(res, PyCoroEagerResult):
183-
if not no_future:
184-
fut = Future(loop=self._loop)
185-
fut.set_result(res.value)
186-
return fut
187-
else:
175+
res.set_result(e.args[0] if e.args else None)
188176
return res
189177

190178
# Since Python 3.8 Tasks propagate all exceptions correctly,

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,16 @@ async def eager2():
753753
self.assertEqual(t1.result(), 42)
754754
self.assertEqual(t2.result(), 11)
755755

756+
async def test_taskgroup_enqueue_exception(self):
757+
async def foo1():
758+
1 / 0
759+
760+
with self.assertRaises(ExceptionGroup) as cm:
761+
async with taskgroups.TaskGroup() as g:
762+
g.enqueue(foo1())
763+
764+
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
765+
756766
async def test_taskgroup_fanout_task(self):
757767
async def step(i):
758768
if i == 0:

0 commit comments

Comments
 (0)