
Add tests to verify the behavior of basic COM methods. #126384

Closed
@junkmd

Description


Feature or enhancement

Proposal:

I am one of the maintainers of comtypes, which builds on ctypes to implement IUnknown and other COM features.

Previously, I reported in gh-124520 that projects depending on ctypes were broken by changes in Python 3.13.

I am now looking into whether there are other effective ways to prevent such regressions proactively, beyond what I attempted in gh-125783.

I noticed that the CPython repository does not seem to contain tests for basic COM methods such as QueryInterface, AddRef, and Release.
Many projects besides comtypes define COM interfaces and call COM methods, so I think it is important to test them.
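For context, a COM interface pointer points at an object whose first field is a pointer to a table of function pointers (the vtable); QueryInterface, AddRef and Release occupy slots 0, 1 and 2, and each method receives the interface pointer itself as its first argument. The sketch below imitates that layout with a fake object built entirely in ctypes, so it runs without Windows or a real COM server. It is an illustration under stated assumptions, not CPython or comtypes code: `FakeObject`, `IUnknownVtbl`, `vtbl_of` and the other helpers are hypothetical names, and it uses `CFUNCTYPE` (cdecl) where real COM uses `WINFUNCTYPE` (stdcall).

```python
import ctypes

S_OK = 0
# 0x80004002 reinterpreted as a signed 32-bit value (-2147467262)
E_NOINTERFACE = ctypes.c_int32(0x80004002).value


class GUID(ctypes.Structure):
    _fields_ = [
        ("Data1", ctypes.c_uint32),
        ("Data2", ctypes.c_uint16),
        ("Data3", ctypes.c_uint16),
        ("Data4", ctypes.c_ubyte * 8),
    ]


# {00000000-0000-0000-C000-000000000046} and {0000000C-0000-0000-C000-000000000046}
IID_IUnknown = GUID(0x0, 0, 0, (ctypes.c_ubyte * 8)(0xC0, 0, 0, 0, 0, 0, 0, 0x46))
IID_IStream = GUID(0xC, 0, 0, (ctypes.c_ubyte * 8)(0xC0, 0, 0, 0, 0, 0, 0, 0x46))

# Method prototypes: the interface pointer ("this") is always the first argument.
# Assumption: CFUNCTYPE instead of the stdcall WINFUNCTYPE, to stay portable.
QI_PROTO = ctypes.CFUNCTYPE(
    ctypes.c_int32, ctypes.c_void_p, ctypes.POINTER(GUID), ctypes.POINTER(ctypes.c_void_p)
)
REF_PROTO = ctypes.CFUNCTYPE(ctypes.c_uint32, ctypes.c_void_p)


class IUnknownVtbl(ctypes.Structure):
    # Slot order matters: QueryInterface = 0, AddRef = 1, Release = 2
    _fields_ = [
        ("QueryInterface", QI_PROTO),
        ("AddRef", REF_PROTO),
        ("Release", REF_PROTO),
    ]


class FakeObject(ctypes.Structure):
    # A COM object starts with a pointer to its vtable; the rest is private state.
    _fields_ = [("lpVtbl", ctypes.POINTER(IUnknownVtbl)), ("refcount", ctypes.c_uint32)]


def _obj(this):
    # View (not a copy) of the FakeObject that `this` points at
    return ctypes.cast(this, ctypes.POINTER(FakeObject)).contents


def _add_ref(this):
    obj = _obj(this)
    obj.refcount += 1
    return obj.refcount


def _release(this):
    obj = _obj(this)
    obj.refcount -= 1
    return obj.refcount


def _query_interface(this, riid, ppv):
    # Only IUnknown is supported; any other IID fails the way a real object would.
    if bytes(riid.contents) == bytes(IID_IUnknown):
        ppv[0] = this
        _add_ref(this)
        return S_OK
    ppv[0] = None
    return E_NOINTERFACE


# Keep the callback objects alive for as long as the vtable is in use.
_CALLBACKS = (QI_PROTO(_query_interface), REF_PROTO(_add_ref), REF_PROTO(_release))
VTBL = IUnknownVtbl(*_CALLBACKS)


def vtbl_of(punk):
    # interface pointer -> object -> lpVtbl -> vtable (the double indirection COM relies on)
    return ctypes.cast(punk, ctypes.POINTER(ctypes.POINTER(IUnknownVtbl)))[0][0]


if __name__ == "__main__":
    obj = FakeObject(ctypes.pointer(VTBL), 1)  # created with one reference
    punk = ctypes.c_void_p(ctypes.addressof(obj))
    vtbl = vtbl_of(punk)
    print(vtbl.AddRef(punk), vtbl.AddRef(punk))  # refcount 1 -> 2 -> 3
    print(vtbl.Release(punk), vtbl.Release(punk))  # refcount 3 -> 2 -> 1
```

Calling through `vtbl_of(punk)` spells out by hand the slot lookup that a prototype created with `proto(index, name)`, as in the `COM_METHOD` helper of the test, performs at call time.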

I think simple tests like the one below, which implement a very basic COM interface and call its methods, could help prevent such regressions.
(I removed the complex parts of comtypes, such as defining methods, registering pointer types, and using __del__ to call Release through metaclass magic.)
(I have confirmed that this test passes with Python 3.11.2 in a virtual environment that I could quickly set up.)

```python
import sys
import unittest

# `ctypes.HRESULT` and `ctypes.oledll` exist only on Windows, so skip the whole
# module here, before the imports below would fail on other platforms.
if sys.platform != "win32":
    raise unittest.SkipTest("Win32 only")


import ctypes
import gc
from ctypes import HRESULT, POINTER, byref
from ctypes.wintypes import BOOL, BYTE, DWORD, HGLOBAL, WORD

ole32 = ctypes.oledll.ole32
oleaut32 = ctypes.oledll.oleaut32


class GUID(ctypes.Structure):
    _fields_ = [
        ("Data1", DWORD),
        ("Data2", WORD),
        ("Data3", WORD),
        ("Data4", BYTE * 8),
    ]


def CLSIDFromString(name):
    guid = GUID()
    ole32.CLSIDFromString(name, byref(guid))
    return guid


PyInstanceMethod_New = ctypes.pythonapi.PyInstanceMethod_New
PyInstanceMethod_New.argtypes = [ctypes.py_object]
PyInstanceMethod_New.restype = ctypes.py_object
PyInstanceMethod_Type = type(PyInstanceMethod_New(id))


class COM_METHOD:
    def __init__(self, index, restype, *argtypes):
        self.index = index
        self.proto = ctypes.WINFUNCTYPE(restype, *argtypes)

    def __set_name__(self, owner, name):
        # `proto(index, name)` creates a function that calls vtable slot `index`
        self.mth = PyInstanceMethod_Type(self.proto(self.index, name))

    def __call__(self, *args, **kwargs):
        return self.mth(*args, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            return self
        return self.mth.__get__(instance)


class IUnknown(ctypes.c_void_p):
    IID = CLSIDFromString("{00000000-0000-0000-C000-000000000046}")
    QueryInterface = COM_METHOD(0, HRESULT, POINTER(GUID), POINTER(ctypes.c_void_p))
    AddRef = COM_METHOD(1, ctypes.c_long)
    Release = COM_METHOD(2, ctypes.c_long)


class ICreateTypeLib(IUnknown):
    IID = CLSIDFromString("{00020406-0000-0000-C000-000000000046}")
    # `CreateTypeInfo` and more methods should be implemented


class ICreateTypeLib2(ICreateTypeLib):
    IID = CLSIDFromString("{0002040F-0000-0000-C000-000000000046}")
    # `DeleteTypeInfo` and more methods should be implemented


class ISequentialStream(IUnknown):
    IID = CLSIDFromString("{0C733A30-2A1C-11CE-ADE5-00AA0044773D}")
    # `Read` and `Write` methods should be implemented


class IStream(ISequentialStream):
    IID = CLSIDFromString("{0000000C-0000-0000-C000-000000000046}")
    # `Seek` and more methods should be implemented


CreateTypeLib2 = oleaut32.CreateTypeLib2
CreateTypeLib2.argtypes = (ctypes.c_int, ctypes.c_wchar_p, POINTER(ICreateTypeLib2))

CreateStreamOnHGlobal = ole32.CreateStreamOnHGlobal
# fDeleteOnRelease is a Win32 BOOL (int-sized), not a 1-byte bool
CreateStreamOnHGlobal.argtypes = (HGLOBAL, BOOL, POINTER(IStream))


COINIT_APARTMENTTHREADED = 0x2
S_OK = 0
E_NOINTERFACE = -2147467262


class Test(unittest.TestCase):
    def setUp(self):
        ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

    def tearDown(self):
        ole32.CoUninitialize()
        gc.collect()

    def test_create_typelib_2(self):
        pctlib = ICreateTypeLib2()
        hr = CreateTypeLib2(0, "sample.tlb", pctlib)
        self.assertEqual(S_OK, hr)

        self.assertEqual(2, pctlib.AddRef())
        self.assertEqual(3, pctlib.AddRef())

        self.assertEqual(2, pctlib.Release())
        self.assertEqual(1, pctlib.Release())
        self.assertEqual(0, pctlib.Release())

    def test_stream(self):
        pstm = IStream()
        hr = CreateStreamOnHGlobal(None, True, pstm)
        self.assertEqual(S_OK, hr)

        self.assertEqual(2, pstm.AddRef())
        self.assertEqual(3, pstm.AddRef())

        self.assertEqual(2, pstm.Release())
        self.assertEqual(1, pstm.Release())
        self.assertEqual(0, pstm.Release())

    def test_query_interface(self):
        pctlib2 = ICreateTypeLib2()
        CreateTypeLib2(0, "sample.tlb", pctlib2)

        pctlib = ICreateTypeLib()
        hr1 = pctlib2.QueryInterface(byref(ICreateTypeLib.IID), byref(pctlib))
        self.assertEqual(S_OK, hr1)
        self.assertEqual(1, pctlib.Release())

        punk = IUnknown()
        hr2 = pctlib.QueryInterface(byref(IUnknown.IID), byref(punk))
        self.assertEqual(S_OK, hr2)
        self.assertEqual(1, punk.Release())

        pstm = IStream()
        with self.assertRaises(WindowsError) as e:  # Why not `COMError`?
            punk.QueryInterface(byref(IStream.IID), byref(pstm))
        self.assertEqual(E_NOINTERFACE, e.exception.winerror)

        self.assertEqual(0, punk.Release())
```

  • I am not sure why a WindowsError is raised instead of a COMError when QueryInterface fails.
  • Perhaps an interface with even fewer methods should be used in the test.
  • At this stage, I think creating custom COM type libraries and interfaces might be excessive.
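On the first point: the ctypes documentation states that functions loaded through `oledll` are assumed to return an HRESULT and that an `OSError` is raised automatically when the value signals failure. `COM_METHOD` above also uses `HRESULT` as its restype, which appears to trigger the same check, whereas comtypes imports `COMError` from `_ctypes`, a different code path. The portable sketch below imitates that "negative HRESULT becomes an exception" behaviour with the documented `errcheck` hook on a ctypes function pointer. It is only an illustration: `HResultError`, `check_hresult` and `fake_query_interface` are hypothetical names, and a Python callback stands in for a real native method.

```python
import ctypes

S_OK = 0
E_NOINTERFACE = ctypes.c_int32(0x80004002).value  # -2147467262


class HResultError(OSError):
    """Hypothetical stand-in for the OSError/WindowsError raised on Windows."""

    def __init__(self, hresult):
        super().__init__(f"HRESULT 0x{hresult & 0xFFFFFFFF:08X}")
        self.hresult = hresult


def check_hresult(result, func, args):
    # FAILED(hr) in C is simply hr < 0; success codes pass through unchanged.
    if result < 0:
        raise HResultError(result)
    return result


# Assumption: CFUNCTYPE instead of WINFUNCTYPE so the sketch runs on any platform.
PROTO = ctypes.CFUNCTYPE(ctypes.c_int32, ctypes.c_int)


def _fake_query_interface(supported):
    # Pretend native method: succeed or report E_NOINTERFACE
    return S_OK if supported else E_NOINTERFACE


fake_query_interface = PROTO(_fake_query_interface)
# errcheck receives (result, func, args) after every call through this pointer
fake_query_interface.errcheck = check_hresult
```

With this hook, `fake_query_interface(1)` returns `S_OK`, while `fake_query_interface(0)` raises `HResultError` carrying `E_NOINTERFACE`, analogous to the `winerror` checked in the test.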

I welcome any feedback.

Has this already been discussed elsewhere?

No response given

Links to previous discussion of this feature:

No response
