diff --git a/Makefile b/Makefile index e62cb9c..694b1c3 100644 --- a/Makefile +++ b/Makefile @@ -47,11 +47,15 @@ format: check-format: poetry run ruff format --check awslambdaric/ tests/ +.PHONY: check-docstr +check-docstr: + python3 scripts/dev.py check-docstr + .PHONY: dev dev: init test .PHONY: pr -pr: init check-format check-security dev +pr: init check-format check-annotations check-types check-type-usage check-security dev .PHONY: codebuild codebuild: setup-codebuild-agent @@ -70,17 +74,18 @@ define HELP_MESSAGE Usage: $ make [TARGETS] TARGETS - check-security Run bandit to find security issues. - format Run black to automatically update your code to match formatting. - build Build the package using scripts/dev.py. - clean Cleans the working directory using scripts/dev.py. - dev Run all development tests using scripts/dev.py. - init Install dependencies via scripts/dev.py. - build-container Build awslambdaric wheel in isolated container. - test-rie Test with RIE using pre-built wheel (run build-container first). - pr Perform all checks before submitting a Pull Request. - test Run unit tests using scripts/dev.py. - lint Run all linters via scripts/dev.py. - test-smoke Run smoke tests inside Docker. - test-integ Run all integration tests. + check-security Run bandit to find security issues. + check-docstr Check docstrings in project using ruff format check. + format Run ruff to automatically format your code. + build Build the package using scripts/dev.py. + clean Cleans the working directory using scripts/dev.py. + dev Run all development tests using scripts/dev.py. + init Install dependencies via scripts/dev.py. + build-container Build awslambdaric wheel in isolated container. + test-rie Test with RIE using pre-built wheel (run build-container first). + pr Perform all checks before submitting a Pull Request. + test Run unit tests using scripts/dev.py. + lint Run all linters via scripts/dev.py. + test-smoke Run smoke tests inside Docker. + test-integ Run all integration tests. endef \ No newline at end of file diff --git a/awslambdaric/__init__.py b/awslambdaric/__init__.py index 5605903..0f94db0 100644 --- a/awslambdaric/__init__.py +++ b/awslambdaric/__init__.py @@ -1,5 +1,3 @@ -""" -Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" __version__ = "3.1.1" diff --git a/awslambdaric/__main__.py b/awslambdaric/__main__.py index 5cbbaab..3cd1ad4 100644 --- a/awslambdaric/__main__.py +++ b/awslambdaric/__main__.py @@ -1,6 +1,4 @@ -""" -Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import os import sys @@ -9,6 +7,7 @@ def main(args): + """Run the Lambda runtime main entry point.""" app_root = os.getcwd() try: diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index f63e765..9bf1382 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -1,6 +1,4 @@ -""" -Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import importlib import json @@ -36,6 +34,7 @@ def _get_handler(handler): + """Get handler function from module.""" try: (modname, fname) = handler.rsplit(".", 1) except ValueError as e: @@ -82,6 +81,7 @@ def make_error( stack_trace, invoke_id=None, ): + """Create error response.""" result = { "errorMessage": error_message if error_message else "", "errorType": error_type if error_type else "", @@ -92,6 +92,7 @@ def make_error( def replace_line_indentation(line, indent_char, new_indent_char): + """Replace line indentation characters.""" ident_chars_count = 0 for c in line: if c != indent_char: @@ -104,6 +105,7 @@ def replace_line_indentation(line, indent_char, new_indent_char): _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR] def log_error(error_result, log_sink): + """Log error in JSON format.""" error_result = { "timestamp": time.strftime( _DATETIME_FORMAT, logging.Formatter.converter(time.time()) @@ -119,6 +121,7 @@ def log_error(error_result, log_sink): _ERROR_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.ERROR] def log_error(error_result, log_sink): + """Log error in text format.""" error_description = "[ERROR]" error_result_type = error_result.get("errorType") @@ -161,6 +164,7 @@ def handle_event_request( tenant_id, log_sink, ): + """Handle Lambda event request.""" error_result = None try: lambda_context = create_lambda_context( @@ -212,6 +216,7 @@ def handle_event_request( def parse_json_header(header, name): + """Parse JSON header.""" try: return json.loads(header) except Exception as e: @@ -230,6 +235,7 @@ def create_lambda_context( invoked_function_arn, tenant_id, ): + """Create Lambda context object.""" client_context = None if client_context_json: client_context = parse_json_header(client_context_json, "Client Context") @@ -248,6 +254,7 @@ def create_lambda_context( def build_fault_result(exc_info, msg): + """Build fault result from exception info.""" etype, value, tb = exc_info tb_tuples = extract_traceback(tb) for i in range(len(tb_tuples)): @@ -263,6 +270,7 @@ def build_fault_result(exc_info, msg): def make_xray_fault(ex_type, ex_msg, working_dir, tb_tuples): + """Create X-Ray fault object.""" stack = [] files = set() for t in tb_tuples: @@ -281,6 +289,7 @@ def make_xray_fault(ex_type, ex_msg, working_dir, tb_tuples): def extract_traceback(tb): + """Extract traceback information.""" return [ (frame.filename, frame.lineno, frame.name, frame.line) for frame in traceback.extract_tb(tb) @@ -288,6 +297,7 @@ def extract_traceback(tb): def on_init_complete(lambda_runtime_client, log_sink): + """Handle initialization completion.""" from . import lambda_runtime_hooks_runner try: @@ -311,21 +321,29 @@ def on_init_complete(lambda_runtime_client, log_sink): class LambdaLoggerHandler(logging.Handler): + """Lambda logger handler.""" + def __init__(self, log_sink): + """Initialize logger handler.""" logging.Handler.__init__(self) self.log_sink = log_sink def emit(self, record): + """Emit log record.""" msg = self.format(record) self.log_sink.log(msg) class LambdaLoggerHandlerWithFrameType(logging.Handler): + """Lambda logger handler with frame type.""" + def __init__(self, log_sink): + """Initialize logger handler.""" super().__init__() self.log_sink = log_sink def emit(self, record): + """Emit log record with frame type.""" self.log_sink.log( self.format(record), frame_type=( @@ -336,14 +354,20 @@ def emit(self, record): class LambdaLoggerFilter(logging.Filter): + """Lambda logger filter.""" + def filter(self, record): + """Filter log record.""" record.aws_request_id = _GLOBAL_AWS_REQUEST_ID or "" record.tenant_id = _GLOBAL_TENANT_ID return True class Unbuffered(object): + """Unbuffered stream wrapper.""" + def __init__(self, stream): + """Initialize unbuffered stream.""" self.stream = stream def __enter__(self): @@ -356,16 +380,21 @@ def __getattr__(self, attr): return getattr(self.stream, attr) def write(self, msg): + """Write message to stream.""" self.stream.write(msg) self.stream.flush() def writelines(self, msgs): + """Write multiple lines to stream.""" self.stream.writelines(msgs) self.stream.flush() class StandardLogSink(object): + """Standard log sink.""" + def __init__(self): + """Initialize standard log sink.""" pass def __enter__(self): @@ -375,17 +404,19 @@ def __exit__(self, exc_type, exc_value, exc_tb): pass def log(self, msg, frame_type=None): + """Log message to stdout.""" sys.stdout.write(msg) def log_error(self, message_lines): + """Log error message to stdout.""" error_message = ERROR_LOG_LINE_TERMINATE.join(message_lines) + "\n" sys.stdout.write(error_message) class FramedTelemetryLogSink(object): - """ - FramedTelemetryLogSink implements the logging contract between runtimes and the platform. It implements a simple - framing protocol so message boundaries can be determined. Each frame can be visualized as follows: + """FramedTelemetryLogSink implements the logging contract between runtimes and the platform. + + It implements a simple framing protocol so message boundaries can be determined. Each frame can be visualized as follows:
     {@code
     +----------------------+------------------------+---------------------+-----------------------+
@@ -399,6 +430,7 @@ class FramedTelemetryLogSink(object):
     """
 
     def __init__(self, fd):
+        """Initialize framed telemetry log sink."""
         self.fd = int(fd)
 
     def __enter__(self):
@@ -409,6 +441,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):
         self.file.close()
 
     def log(self, msg, frame_type=None):
+        """Log message with frame type."""
         encoded_msg = msg.encode("utf8")
 
         timestamp = int(time.time_ns() / 1000)  # UNIX timestamp in microseconds
@@ -421,6 +454,7 @@ def log(self, msg, frame_type=None):
         self.file.write(log_msg)
 
     def log_error(self, message_lines):
+        """Log error message."""
         error_message = "\n".join(message_lines)
         self.log(
             error_message,
@@ -429,6 +463,7 @@ def log_error(self, message_lines):
 
 
 def update_xray_env_variable(xray_trace_id):
+    """Update X-Ray trace ID environment variable."""
     if xray_trace_id is not None:
         os.environ["_X_AMZN_TRACE_ID"] = xray_trace_id
     else:
@@ -437,6 +472,7 @@ def update_xray_env_variable(xray_trace_id):
 
 
 def create_log_sink():
+    """Create appropriate log sink."""
     if "_LAMBDA_TELEMETRY_LOG_FD" in os.environ:
         fd = os.environ["_LAMBDA_TELEMETRY_LOG_FD"]
         del os.environ["_LAMBDA_TELEMETRY_LOG_FD"]
@@ -451,6 +487,7 @@ def create_log_sink():
 
 
 def _setup_logging(log_format, log_level, log_sink):
+    """Set up logging configuration."""
     logging.Formatter.converter = time.gmtime
     logger = logging.getLogger()
 
@@ -477,6 +514,7 @@ def _setup_logging(log_format, log_level, log_sink):
 
 
 def run(app_root, handler, lambda_runtime_api_addr):
+    """Run Lambda runtime."""
     sys.stdout = Unbuffered(sys.stdout)
     sys.stderr = Unbuffered(sys.stderr)
 
diff --git a/awslambdaric/lambda_context.py b/awslambdaric/lambda_context.py
index e0a3363..e827993 100644
--- a/awslambdaric/lambda_context.py
+++ b/awslambdaric/lambda_context.py
@@ -1,6 +1,4 @@
-"""
-Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 import logging
 import os
@@ -9,6 +7,8 @@
 
 
 class LambdaContext(object):
+    """Lambda context object."""
+
     def __init__(
         self,
         invoke_id,
@@ -18,6 +18,7 @@ def __init__(
         invoked_function_arn=None,
         tenant_id=None,
     ):
+        """Initialize Lambda context."""
         self.aws_request_id = invoke_id
         self.log_group_name = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME")
         self.log_stream_name = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")
@@ -45,11 +46,13 @@ def __init__(
         self._epoch_deadline_time_in_ms = epoch_deadline_time_in_ms
 
     def get_remaining_time_in_millis(self):
+        """Get remaining time in milliseconds."""
         epoch_now_in_ms = int(time.time() * 1000)
         delta_ms = self._epoch_deadline_time_in_ms - epoch_now_in_ms
         return delta_ms if delta_ms > 0 else 0
 
     def log(self, msg):
+        """Log a message."""
         for handler in logging.getLogger().handlers:
             if hasattr(handler, "log_sink"):
                 handler.log_sink.log(str(msg))
@@ -74,6 +77,8 @@ def __repr__(self):
 
 
 class CognitoIdentity(object):
+    """Cognito identity information."""
+
     __slots__ = ["cognito_identity_id", "cognito_identity_pool_id"]
 
     def __repr__(self):
@@ -86,6 +91,8 @@ def __repr__(self):
 
 
 class Client(object):
+    """Client information."""
+
     __slots__ = [
         "installation_id",
         "app_title",
@@ -107,6 +114,8 @@ def __repr__(self):
 
 
 class ClientContext(object):
+    """Client context information."""
+
     __slots__ = ["custom", "env", "client"]
 
     def __repr__(self):
@@ -120,6 +129,7 @@ def __repr__(self):
 
 
 def make_obj_from_dict(_class, _dict, fields=None):
+    """Create object from dictionary."""
     if _dict is None:
         return None
     obj = _class()
@@ -128,6 +138,7 @@ def make_obj_from_dict(_class, _dict, fields=None):
 
 
 def set_obj_from_dict(obj, _dict, fields=None):
+    """Set object attributes from dictionary."""
     if fields is None:
         fields = obj.__class__.__slots__
     for field in fields:
diff --git a/awslambdaric/lambda_literals.py b/awslambdaric/lambda_literals.py
index 2585b89..a2a0746 100644
--- a/awslambdaric/lambda_literals.py
+++ b/awslambdaric/lambda_literals.py
@@ -1,6 +1,4 @@
-"""
-Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 lambda_warning = "LAMBDA_WARNING"
 
diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py
index ba4ad92..fd5fdf1 100644
--- a/awslambdaric/lambda_runtime_client.py
+++ b/awslambdaric/lambda_runtime_client.py
@@ -1,6 +1,4 @@
-"""
-Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 import sys
 from awslambdaric import __version__
@@ -29,15 +27,22 @@ def _user_agent():
 
 
 class InvocationRequest(object):
+    """Lambda invocation request."""
+
     def __init__(self, **kwds):
+        """Initialize invocation request."""
         self.__dict__.update(kwds)
 
     def __eq__(self, other):
+        """Check equality."""
         return self.__dict__ == other.__dict__
 
 
 class LambdaRuntimeClientError(Exception):
+    """Lambda runtime client error."""
+
     def __init__(self, endpoint, response_code, response_body):
+        """Initialize runtime client error."""
         self.endpoint = endpoint
         self.response_code = response_code
         self.response_body = response_body
@@ -47,12 +52,15 @@ def __init__(self, endpoint, response_code, response_body):
 
 
 class LambdaRuntimeClient(object):
+    """Lambda runtime client."""
+
     marshaller = LambdaMarshaller()
     """marshaller is a class attribute that determines the unmarshalling and marshalling logic of a function's event
     and response. It allows for function authors to override the the default implementation, LambdaMarshaller which
     unmarshals and marshals JSON, to an instance of a class that implements the same interface."""
 
     def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
+        """Initialize runtime client."""
         self.lambda_runtime_address = lambda_runtime_address
         self.use_thread_for_polling_next = use_thread_for_polling_next
         if self.use_thread_for_polling_next:
@@ -65,6 +73,7 @@ def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
     def call_rapid(
         self, http_method, endpoint, expected_http_code, payload=None, headers=None
     ):
+        """Call RAPID endpoint."""
         # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
         # Importing them lazily to speed up critical path of a common case.
         import http.client
@@ -84,6 +93,7 @@ def call_rapid(
             raise LambdaRuntimeClientError(endpoint, response.code, response_body)
 
     def post_init_error(self, error_response_data, error_type_override=None):
+        """Post initialization error."""
         import http
 
         endpoint = "/2018-06-01/runtime/init/error"
@@ -99,12 +109,14 @@ def post_init_error(self, error_response_data, error_type_override=None):
         )
 
     def restore_next(self):
+        """Restore next invocation."""
         import http
 
         endpoint = "/2018-06-01/runtime/restore/next"
         self.call_rapid("GET", endpoint, http.HTTPStatus.OK)
 
     def report_restore_error(self, restore_error_data):
+        """Report restore error."""
         import http
 
         endpoint = "/2018-06-01/runtime/restore/error"
@@ -114,6 +126,7 @@ def report_restore_error(self, restore_error_data):
         )
 
     def wait_next_invocation(self):
+        """Wait for next invocation."""
         # Calling runtime_client.next() from a separate thread unblocks the main thread,
         # which can then process signals.
         if self.use_thread_for_polling_next:
@@ -145,6 +158,7 @@ def wait_next_invocation(self):
     def post_invocation_result(
         self, invoke_id, result_data, content_type="application/json"
     ):
+        """Post invocation result."""
         runtime_client.post_invocation_result(
             invoke_id,
             (
@@ -156,6 +170,7 @@ def post_invocation_result(
         )
 
     def post_invocation_error(self, invoke_id, error_response_data, xray_fault):
+        """Post invocation error."""
         max_header_size = 1024 * 1024  # 1MiB
         xray_fault = xray_fault if len(xray_fault.encode()) < max_header_size else ""
         runtime_client.post_error(invoke_id, error_response_data, xray_fault)
diff --git a/awslambdaric/lambda_runtime_exception.py b/awslambdaric/lambda_runtime_exception.py
index 3ea5b29..3c2e41a 100644
--- a/awslambdaric/lambda_runtime_exception.py
+++ b/awslambdaric/lambda_runtime_exception.py
@@ -1,9 +1,9 @@
-"""
-Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 
 class FaultException(Exception):
+    """Exception class for Lambda runtime faults."""
+
     MARSHAL_ERROR = "Runtime.MarshalError"
     UNMARSHAL_ERROR = "Runtime.UnmarshalError"
     USER_CODE_SYNTAX_ERROR = "Runtime.UserCodeSyntaxError"
@@ -17,6 +17,7 @@ class FaultException(Exception):
     LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"
 
     def __init__(self, exception_type, msg, trace=None):
+        """Initialize FaultException."""
         self.msg = msg
         self.exception_type = exception_type
         self.trace = trace
diff --git a/awslambdaric/lambda_runtime_hooks_runner.py b/awslambdaric/lambda_runtime_hooks_runner.py
index 8aee181..ab67e63 100644
--- a/awslambdaric/lambda_runtime_hooks_runner.py
+++ b/awslambdaric/lambda_runtime_hooks_runner.py
@@ -1,10 +1,14 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
+"""Lambda runtime hooks runner.
+
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+SPDX-License-Identifier: Apache-2.0
+"""
 
 from snapshot_restore_py import get_before_snapshot, get_after_restore
 
 
 def run_before_snapshot():
+    """Run before snapshot hooks."""
     before_snapshot_callables = get_before_snapshot()
     while before_snapshot_callables:
         # Using pop as before checkpoint callables are executed in the reverse order of their registration
@@ -13,6 +17,7 @@ def run_before_snapshot():
 
 
 def run_after_restore():
+    """Run after restore hooks."""
     after_restore_callables = get_after_restore()
     for func, args, kwargs in after_restore_callables:
         func(*args, **kwargs)
diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py
index 9ddbcfb..93a1d63 100644
--- a/awslambdaric/lambda_runtime_log_utils.py
+++ b/awslambdaric/lambda_runtime_log_utils.py
@@ -1,6 +1,4 @@
-"""
-Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 import json
 import logging
@@ -36,11 +34,14 @@
 
 
 class LogFormat(IntEnum):
+    """Log format enumeration."""
+
     JSON = 0b0
     TEXT = 0b1
 
     @classmethod
     def from_str(cls, value: str):
+        """Convert string to LogFormat."""
         if value and value.upper() == "JSON":
             return cls.JSON.value
         return cls.TEXT.value
@@ -77,7 +78,10 @@ def _format_log_level(record: logging.LogRecord) -> int:
 
 
 class JsonFormatter(logging.Formatter):
+    """JSON formatter for Lambda logs."""
+
     def __init__(self):
+        """Initialize the JSON formatter."""
         super().__init__(datefmt=_DATETIME_FORMAT)
 
     @staticmethod
@@ -108,6 +112,7 @@ def __format_location(record: logging.LogRecord):
         return f"{record.pathname}:{record.funcName}:{record.lineno}"
 
     def format(self, record: logging.LogRecord) -> str:
+        """Format log record as JSON."""
         record.levelno = _format_log_level(record)
         record.levelname = logging.getLevelName(record.levelno)
         record._frame_type = _JSON_FRAME_TYPES.get(
diff --git a/awslambdaric/lambda_runtime_marshaller.py b/awslambdaric/lambda_runtime_marshaller.py
index 4256066..fe0dd8f 100644
--- a/awslambdaric/lambda_runtime_marshaller.py
+++ b/awslambdaric/lambda_runtime_marshaller.py
@@ -1,6 +1,4 @@
-"""
-Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-"""
+"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved."""
 
 import decimal
 import math
@@ -14,7 +12,10 @@
 # to get the good parts of Decimal support, we'll special-case NaN decimals and otherwise duplicate the encoding for decimals the same way simplejson does
 # We also set 'ensure_ascii=False' so that the encoded json contains unicode characters instead of unicode escape sequences
 class Encoder(json.JSONEncoder):
+    """Custom JSON encoder for Lambda responses."""
+
     def __init__(self):
+        """Initialize the encoder."""
         if os.environ.get("AWS_EXECUTION_ENV") in {
             "AWS_Lambda_python3.12",
             "AWS_Lambda_python3.13",
@@ -24,6 +25,7 @@ def __init__(self):
             super().__init__(use_decimal=False, allow_nan=True)
 
     def default(self, obj):
+        """Handle special object types during encoding."""
         if isinstance(obj, decimal.Decimal):
             if obj.is_nan():
                 return math.nan
@@ -32,14 +34,19 @@ def default(self, obj):
 
 
 def to_json(obj):
+    """Convert object to JSON string."""
     return Encoder().encode(obj)
 
 
 class LambdaMarshaller:
+    """Marshaller for Lambda requests and responses."""
+
     def __init__(self):
+        """Initialize the marshaller."""
         self.jsonEncoder = Encoder()
 
     def unmarshal_request(self, request, content_type="application/json"):
+        """Unmarshal incoming request."""
         if content_type != "application/json":
             return request
         try:
@@ -52,6 +59,7 @@ def unmarshal_request(self, request, content_type="application/json"):
             )
 
     def marshal_response(self, response):
+        """Marshal response for Lambda."""
         if isinstance(response, bytes):
             return response, "application/unknown"
 
diff --git a/pyproject.toml b/pyproject.toml
index 7c0c05b..0007d8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,15 +51,5 @@ select = [
     "C90", # mccabe complexity
 ]
 
-# Ignore rules that are too strict for existing codebases
-ignore = [
-    "E501",    # Line too long (handled by formatter)
-    "PLR0913", # Too many arguments
-    "E722",    # Bare except 
-    "PLW0603", # Global statement 
-    "UP031",   # % formatting vs f-strings
-    "E402",    # Module import not at top
-]
-
 [tool.ruff.format]
 quote-style = "double"
diff --git a/scripts/dev.py b/scripts/dev.py
index 636d4b8..e5e91b6 100644
--- a/scripts/dev.py
+++ b/scripts/dev.py
@@ -71,11 +71,16 @@ def test_rie():
     print("Testing with RIE using pre-built wheel")
     run(["./scripts/test-rie.sh"])
 
+def check_docstr():
+    print("Checking docstrings")
+    run(["poetry", "run", "ruff", "check", "--select", "D", "--ignore", "D105", "awslambdaric/"])
+
 
 def main():
     parser = argparse.ArgumentParser(description="Development scripts")
     parser.add_argument("command", choices=[
-        "init", "test", "lint", "format", "clean", "build", "build-container", "test-rie"
+        "init", "test", "lint", "format", "clean", "build", "build-container", "test-rie",
+        "check-docstr"
     ])
     
     args = parser.parse_args()
@@ -89,6 +94,7 @@ def main():
         "build": build,
         "build-container": build_container,
         "test-rie": test_rie,
+        "check-docstr": check_docstr,
     }
     
     command_map[args.command]()