diff --git a/Makefile b/Makefile index e62cb9c..694b1c3 100644 --- a/Makefile +++ b/Makefile @@ -47,11 +47,15 @@ format: check-format: poetry run ruff format --check awslambdaric/ tests/ +.PHONY: check-docstr +check-docstr: + python3 scripts/dev.py check-docstr + .PHONY: dev dev: init test .PHONY: pr -pr: init check-format check-security dev +pr: init check-format check-annotations check-types check-type-usage check-security dev .PHONY: codebuild codebuild: setup-codebuild-agent @@ -70,17 +74,18 @@ define HELP_MESSAGE Usage: $ make [TARGETS] TARGETS - check-security Run bandit to find security issues. - format Run black to automatically update your code to match formatting. - build Build the package using scripts/dev.py. - clean Cleans the working directory using scripts/dev.py. - dev Run all development tests using scripts/dev.py. - init Install dependencies via scripts/dev.py. - build-container Build awslambdaric wheel in isolated container. - test-rie Test with RIE using pre-built wheel (run build-container first). - pr Perform all checks before submitting a Pull Request. - test Run unit tests using scripts/dev.py. - lint Run all linters via scripts/dev.py. - test-smoke Run smoke tests inside Docker. - test-integ Run all integration tests. + check-security Run bandit to find security issues. + check-docstr Check docstrings in project using ruff format check. + format Run ruff to automatically format your code. + build Build the package using scripts/dev.py. + clean Cleans the working directory using scripts/dev.py. + dev Run all development tests using scripts/dev.py. + init Install dependencies via scripts/dev.py. + build-container Build awslambdaric wheel in isolated container. + test-rie Test with RIE using pre-built wheel (run build-container first). + pr Perform all checks before submitting a Pull Request. + test Run unit tests using scripts/dev.py. + lint Run all linters via scripts/dev.py. + test-smoke Run smoke tests inside Docker. + test-integ Run all integration tests. endef \ No newline at end of file diff --git a/awslambdaric/__init__.py b/awslambdaric/__init__.py index 5605903..0f94db0 100644 --- a/awslambdaric/__init__.py +++ b/awslambdaric/__init__.py @@ -1,5 +1,3 @@ -""" -Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" __version__ = "3.1.1" diff --git a/awslambdaric/__main__.py b/awslambdaric/__main__.py index 5cbbaab..3cd1ad4 100644 --- a/awslambdaric/__main__.py +++ b/awslambdaric/__main__.py @@ -1,6 +1,4 @@ -""" -Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import os import sys @@ -9,6 +7,7 @@ def main(args): + """Run the Lambda runtime main entry point.""" app_root = os.getcwd() try: diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index f63e765..9bf1382 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -1,6 +1,4 @@ -""" -Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import importlib import json @@ -36,6 +34,7 @@ def _get_handler(handler): + """Get handler function from module.""" try: (modname, fname) = handler.rsplit(".", 1) except ValueError as e: @@ -82,6 +81,7 @@ def make_error( stack_trace, invoke_id=None, ): + """Create error response.""" result = { "errorMessage": error_message if error_message else "", "errorType": error_type if error_type else "", @@ -92,6 +92,7 @@ def make_error( def replace_line_indentation(line, indent_char, new_indent_char): + """Replace line indentation characters.""" ident_chars_count = 0 for c in line: if c != indent_char: @@ -104,6 +105,7 @@ def replace_line_indentation(line, indent_char, new_indent_char): _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR] def log_error(error_result, log_sink): + """Log error in JSON format.""" error_result = { "timestamp": time.strftime( _DATETIME_FORMAT, logging.Formatter.converter(time.time()) @@ -119,6 +121,7 @@ def log_error(error_result, log_sink): _ERROR_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.ERROR] def log_error(error_result, log_sink): + """Log error in text format.""" error_description = "[ERROR]" error_result_type = error_result.get("errorType") @@ -161,6 +164,7 @@ def handle_event_request( tenant_id, log_sink, ): + """Handle Lambda event request.""" error_result = None try: lambda_context = create_lambda_context( @@ -212,6 +216,7 @@ def handle_event_request( def parse_json_header(header, name): + """Parse JSON header.""" try: return json.loads(header) except Exception as e: @@ -230,6 +235,7 @@ def create_lambda_context( invoked_function_arn, tenant_id, ): + """Create Lambda context object.""" client_context = None if client_context_json: client_context = parse_json_header(client_context_json, "Client Context") @@ -248,6 +254,7 @@ def create_lambda_context( def build_fault_result(exc_info, msg): + """Build fault result from exception info.""" etype, value, tb = exc_info tb_tuples = extract_traceback(tb) for i in range(len(tb_tuples)): @@ -263,6 +270,7 @@ def build_fault_result(exc_info, msg): def make_xray_fault(ex_type, ex_msg, working_dir, tb_tuples): + """Create X-Ray fault object.""" stack = [] files = set() for t in tb_tuples: @@ -281,6 +289,7 @@ def make_xray_fault(ex_type, ex_msg, working_dir, tb_tuples): def extract_traceback(tb): + """Extract traceback information.""" return [ (frame.filename, frame.lineno, frame.name, frame.line) for frame in traceback.extract_tb(tb) @@ -288,6 +297,7 @@ def extract_traceback(tb): def on_init_complete(lambda_runtime_client, log_sink): + """Handle initialization completion.""" from . import lambda_runtime_hooks_runner try: @@ -311,21 +321,29 @@ def on_init_complete(lambda_runtime_client, log_sink): class LambdaLoggerHandler(logging.Handler): + """Lambda logger handler.""" + def __init__(self, log_sink): + """Initialize logger handler.""" logging.Handler.__init__(self) self.log_sink = log_sink def emit(self, record): + """Emit log record.""" msg = self.format(record) self.log_sink.log(msg) class LambdaLoggerHandlerWithFrameType(logging.Handler): + """Lambda logger handler with frame type.""" + def __init__(self, log_sink): + """Initialize logger handler.""" super().__init__() self.log_sink = log_sink def emit(self, record): + """Emit log record with frame type.""" self.log_sink.log( self.format(record), frame_type=( @@ -336,14 +354,20 @@ def emit(self, record): class LambdaLoggerFilter(logging.Filter): + """Lambda logger filter.""" + def filter(self, record): + """Filter log record.""" record.aws_request_id = _GLOBAL_AWS_REQUEST_ID or "" record.tenant_id = _GLOBAL_TENANT_ID return True class Unbuffered(object): + """Unbuffered stream wrapper.""" + def __init__(self, stream): + """Initialize unbuffered stream.""" self.stream = stream def __enter__(self): @@ -356,16 +380,21 @@ def __getattr__(self, attr): return getattr(self.stream, attr) def write(self, msg): + """Write message to stream.""" self.stream.write(msg) self.stream.flush() def writelines(self, msgs): + """Write multiple lines to stream.""" self.stream.writelines(msgs) self.stream.flush() class StandardLogSink(object): + """Standard log sink.""" + def __init__(self): + """Initialize standard log sink.""" pass def __enter__(self): @@ -375,17 +404,19 @@ def __exit__(self, exc_type, exc_value, exc_tb): pass def log(self, msg, frame_type=None): + """Log message to stdout.""" sys.stdout.write(msg) def log_error(self, message_lines): + """Log error message to stdout.""" error_message = ERROR_LOG_LINE_TERMINATE.join(message_lines) + "\n" sys.stdout.write(error_message) class FramedTelemetryLogSink(object): - """ - FramedTelemetryLogSink implements the logging contract between runtimes and the platform. It implements a simple - framing protocol so message boundaries can be determined. Each frame can be visualized as follows: + """FramedTelemetryLogSink implements the logging contract between runtimes and the platform. + + It implements a simple framing protocol so message boundaries can be determined. Each frame can be visualized as follows:
{@code +----------------------+------------------------+---------------------+-----------------------+ @@ -399,6 +430,7 @@ class FramedTelemetryLogSink(object): """ def __init__(self, fd): + """Initialize framed telemetry log sink.""" self.fd = int(fd) def __enter__(self): @@ -409,6 +441,7 @@ def __exit__(self, exc_type, exc_value, exc_tb): self.file.close() def log(self, msg, frame_type=None): + """Log message with frame type.""" encoded_msg = msg.encode("utf8") timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds @@ -421,6 +454,7 @@ def log(self, msg, frame_type=None): self.file.write(log_msg) def log_error(self, message_lines): + """Log error message.""" error_message = "\n".join(message_lines) self.log( error_message, @@ -429,6 +463,7 @@ def log_error(self, message_lines): def update_xray_env_variable(xray_trace_id): + """Update X-Ray trace ID environment variable.""" if xray_trace_id is not None: os.environ["_X_AMZN_TRACE_ID"] = xray_trace_id else: @@ -437,6 +472,7 @@ def update_xray_env_variable(xray_trace_id): def create_log_sink(): + """Create appropriate log sink.""" if "_LAMBDA_TELEMETRY_LOG_FD" in os.environ: fd = os.environ["_LAMBDA_TELEMETRY_LOG_FD"] del os.environ["_LAMBDA_TELEMETRY_LOG_FD"] @@ -451,6 +487,7 @@ def create_log_sink(): def _setup_logging(log_format, log_level, log_sink): + """Set up logging configuration.""" logging.Formatter.converter = time.gmtime logger = logging.getLogger() @@ -477,6 +514,7 @@ def _setup_logging(log_format, log_level, log_sink): def run(app_root, handler, lambda_runtime_api_addr): + """Run Lambda runtime.""" sys.stdout = Unbuffered(sys.stdout) sys.stderr = Unbuffered(sys.stderr) diff --git a/awslambdaric/lambda_context.py b/awslambdaric/lambda_context.py index e0a3363..e827993 100644 --- a/awslambdaric/lambda_context.py +++ b/awslambdaric/lambda_context.py @@ -1,6 +1,4 @@ -""" -Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import logging import os @@ -9,6 +7,8 @@ class LambdaContext(object): + """Lambda context object.""" + def __init__( self, invoke_id, @@ -18,6 +18,7 @@ def __init__( invoked_function_arn=None, tenant_id=None, ): + """Initialize Lambda context.""" self.aws_request_id = invoke_id self.log_group_name = os.environ.get("AWS_LAMBDA_LOG_GROUP_NAME") self.log_stream_name = os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME") @@ -45,11 +46,13 @@ def __init__( self._epoch_deadline_time_in_ms = epoch_deadline_time_in_ms def get_remaining_time_in_millis(self): + """Get remaining time in milliseconds.""" epoch_now_in_ms = int(time.time() * 1000) delta_ms = self._epoch_deadline_time_in_ms - epoch_now_in_ms return delta_ms if delta_ms > 0 else 0 def log(self, msg): + """Log a message.""" for handler in logging.getLogger().handlers: if hasattr(handler, "log_sink"): handler.log_sink.log(str(msg)) @@ -74,6 +77,8 @@ def __repr__(self): class CognitoIdentity(object): + """Cognito identity information.""" + __slots__ = ["cognito_identity_id", "cognito_identity_pool_id"] def __repr__(self): @@ -86,6 +91,8 @@ def __repr__(self): class Client(object): + """Client information.""" + __slots__ = [ "installation_id", "app_title", @@ -107,6 +114,8 @@ def __repr__(self): class ClientContext(object): + """Client context information.""" + __slots__ = ["custom", "env", "client"] def __repr__(self): @@ -120,6 +129,7 @@ def __repr__(self): def make_obj_from_dict(_class, _dict, fields=None): + """Create object from dictionary.""" if _dict is None: return None obj = _class() @@ -128,6 +138,7 @@ def make_obj_from_dict(_class, _dict, fields=None): def set_obj_from_dict(obj, _dict, fields=None): + """Set object attributes from dictionary.""" if fields is None: fields = obj.__class__.__slots__ for field in fields: diff --git a/awslambdaric/lambda_literals.py b/awslambdaric/lambda_literals.py index 2585b89..a2a0746 100644 --- a/awslambdaric/lambda_literals.py +++ b/awslambdaric/lambda_literals.py @@ -1,6 +1,4 @@ -""" -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" lambda_warning = "LAMBDA_WARNING" diff --git a/awslambdaric/lambda_runtime_client.py b/awslambdaric/lambda_runtime_client.py index ba4ad92..fd5fdf1 100644 --- a/awslambdaric/lambda_runtime_client.py +++ b/awslambdaric/lambda_runtime_client.py @@ -1,6 +1,4 @@ -""" -Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import sys from awslambdaric import __version__ @@ -29,15 +27,22 @@ def _user_agent(): class InvocationRequest(object): + """Lambda invocation request.""" + def __init__(self, **kwds): + """Initialize invocation request.""" self.__dict__.update(kwds) def __eq__(self, other): + """Check equality.""" return self.__dict__ == other.__dict__ class LambdaRuntimeClientError(Exception): + """Lambda runtime client error.""" + def __init__(self, endpoint, response_code, response_body): + """Initialize runtime client error.""" self.endpoint = endpoint self.response_code = response_code self.response_body = response_body @@ -47,12 +52,15 @@ def __init__(self, endpoint, response_code, response_body): class LambdaRuntimeClient(object): + """Lambda runtime client.""" + marshaller = LambdaMarshaller() """marshaller is a class attribute that determines the unmarshalling and marshalling logic of a function's event and response. It allows for function authors to override the the default implementation, LambdaMarshaller which unmarshals and marshals JSON, to an instance of a class that implements the same interface.""" def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False): + """Initialize runtime client.""" self.lambda_runtime_address = lambda_runtime_address self.use_thread_for_polling_next = use_thread_for_polling_next if self.use_thread_for_polling_next: @@ -65,6 +73,7 @@ def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False): def call_rapid( self, http_method, endpoint, expected_http_code, payload=None, headers=None ): + """Call RAPID endpoint.""" # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`. # Importing them lazily to speed up critical path of a common case. import http.client @@ -84,6 +93,7 @@ def call_rapid( raise LambdaRuntimeClientError(endpoint, response.code, response_body) def post_init_error(self, error_response_data, error_type_override=None): + """Post initialization error.""" import http endpoint = "/2018-06-01/runtime/init/error" @@ -99,12 +109,14 @@ def post_init_error(self, error_response_data, error_type_override=None): ) def restore_next(self): + """Restore next invocation.""" import http endpoint = "/2018-06-01/runtime/restore/next" self.call_rapid("GET", endpoint, http.HTTPStatus.OK) def report_restore_error(self, restore_error_data): + """Report restore error.""" import http endpoint = "/2018-06-01/runtime/restore/error" @@ -114,6 +126,7 @@ def report_restore_error(self, restore_error_data): ) def wait_next_invocation(self): + """Wait for next invocation.""" # Calling runtime_client.next() from a separate thread unblocks the main thread, # which can then process signals. if self.use_thread_for_polling_next: @@ -145,6 +158,7 @@ def wait_next_invocation(self): def post_invocation_result( self, invoke_id, result_data, content_type="application/json" ): + """Post invocation result.""" runtime_client.post_invocation_result( invoke_id, ( @@ -156,6 +170,7 @@ def post_invocation_result( ) def post_invocation_error(self, invoke_id, error_response_data, xray_fault): + """Post invocation error.""" max_header_size = 1024 * 1024 # 1MiB xray_fault = xray_fault if len(xray_fault.encode()) < max_header_size else "" runtime_client.post_error(invoke_id, error_response_data, xray_fault) diff --git a/awslambdaric/lambda_runtime_exception.py b/awslambdaric/lambda_runtime_exception.py index 3ea5b29..3c2e41a 100644 --- a/awslambdaric/lambda_runtime_exception.py +++ b/awslambdaric/lambda_runtime_exception.py @@ -1,9 +1,9 @@ -""" -Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" class FaultException(Exception): + """Exception class for Lambda runtime faults.""" + MARSHAL_ERROR = "Runtime.MarshalError" UNMARSHAL_ERROR = "Runtime.UnmarshalError" USER_CODE_SYNTAX_ERROR = "Runtime.UserCodeSyntaxError" @@ -17,6 +17,7 @@ class FaultException(Exception): LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError" def __init__(self, exception_type, msg, trace=None): + """Initialize FaultException.""" self.msg = msg self.exception_type = exception_type self.trace = trace diff --git a/awslambdaric/lambda_runtime_hooks_runner.py b/awslambdaric/lambda_runtime_hooks_runner.py index 8aee181..ab67e63 100644 --- a/awslambdaric/lambda_runtime_hooks_runner.py +++ b/awslambdaric/lambda_runtime_hooks_runner.py @@ -1,10 +1,14 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 +"""Lambda runtime hooks runner. + +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" from snapshot_restore_py import get_before_snapshot, get_after_restore def run_before_snapshot(): + """Run before snapshot hooks.""" before_snapshot_callables = get_before_snapshot() while before_snapshot_callables: # Using pop as before checkpoint callables are executed in the reverse order of their registration @@ -13,6 +17,7 @@ def run_before_snapshot(): def run_after_restore(): + """Run after restore hooks.""" after_restore_callables = get_after_restore() for func, args, kwargs in after_restore_callables: func(*args, **kwargs) diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py index 9ddbcfb..93a1d63 100644 --- a/awslambdaric/lambda_runtime_log_utils.py +++ b/awslambdaric/lambda_runtime_log_utils.py @@ -1,6 +1,4 @@ -""" -Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import json import logging @@ -36,11 +34,14 @@ class LogFormat(IntEnum): + """Log format enumeration.""" + JSON = 0b0 TEXT = 0b1 @classmethod def from_str(cls, value: str): + """Convert string to LogFormat.""" if value and value.upper() == "JSON": return cls.JSON.value return cls.TEXT.value @@ -77,7 +78,10 @@ def _format_log_level(record: logging.LogRecord) -> int: class JsonFormatter(logging.Formatter): + """JSON formatter for Lambda logs.""" + def __init__(self): + """Initialize the JSON formatter.""" super().__init__(datefmt=_DATETIME_FORMAT) @staticmethod @@ -108,6 +112,7 @@ def __format_location(record: logging.LogRecord): return f"{record.pathname}:{record.funcName}:{record.lineno}" def format(self, record: logging.LogRecord) -> str: + """Format log record as JSON.""" record.levelno = _format_log_level(record) record.levelname = logging.getLevelName(record.levelno) record._frame_type = _JSON_FRAME_TYPES.get( diff --git a/awslambdaric/lambda_runtime_marshaller.py b/awslambdaric/lambda_runtime_marshaller.py index 4256066..fe0dd8f 100644 --- a/awslambdaric/lambda_runtime_marshaller.py +++ b/awslambdaric/lambda_runtime_marshaller.py @@ -1,6 +1,4 @@ -""" -Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -""" +"""Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.""" import decimal import math @@ -14,7 +12,10 @@ # to get the good parts of Decimal support, we'll special-case NaN decimals and otherwise duplicate the encoding for decimals the same way simplejson does # We also set 'ensure_ascii=False' so that the encoded json contains unicode characters instead of unicode escape sequences class Encoder(json.JSONEncoder): + """Custom JSON encoder for Lambda responses.""" + def __init__(self): + """Initialize the encoder.""" if os.environ.get("AWS_EXECUTION_ENV") in { "AWS_Lambda_python3.12", "AWS_Lambda_python3.13", @@ -24,6 +25,7 @@ def __init__(self): super().__init__(use_decimal=False, allow_nan=True) def default(self, obj): + """Handle special object types during encoding.""" if isinstance(obj, decimal.Decimal): if obj.is_nan(): return math.nan @@ -32,14 +34,19 @@ def default(self, obj): def to_json(obj): + """Convert object to JSON string.""" return Encoder().encode(obj) class LambdaMarshaller: + """Marshaller for Lambda requests and responses.""" + def __init__(self): + """Initialize the marshaller.""" self.jsonEncoder = Encoder() def unmarshal_request(self, request, content_type="application/json"): + """Unmarshal incoming request.""" if content_type != "application/json": return request try: @@ -52,6 +59,7 @@ def unmarshal_request(self, request, content_type="application/json"): ) def marshal_response(self, response): + """Marshal response for Lambda.""" if isinstance(response, bytes): return response, "application/unknown" diff --git a/pyproject.toml b/pyproject.toml index 7c0c05b..0007d8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,15 +51,5 @@ select = [ "C90", # mccabe complexity ] -# Ignore rules that are too strict for existing codebases -ignore = [ - "E501", # Line too long (handled by formatter) - "PLR0913", # Too many arguments - "E722", # Bare except - "PLW0603", # Global statement - "UP031", # % formatting vs f-strings - "E402", # Module import not at top -] - [tool.ruff.format] quote-style = "double" diff --git a/scripts/dev.py b/scripts/dev.py index 636d4b8..e5e91b6 100644 --- a/scripts/dev.py +++ b/scripts/dev.py @@ -71,11 +71,16 @@ def test_rie(): print("Testing with RIE using pre-built wheel") run(["./scripts/test-rie.sh"]) +def check_docstr(): + print("Checking docstrings") + run(["poetry", "run", "ruff", "check", "--select", "D", "--ignore", "D105", "awslambdaric/"]) + def main(): parser = argparse.ArgumentParser(description="Development scripts") parser.add_argument("command", choices=[ - "init", "test", "lint", "format", "clean", "build", "build-container", "test-rie" + "init", "test", "lint", "format", "clean", "build", "build-container", "test-rie", + "check-docstr" ]) args = parser.parse_args() @@ -89,6 +94,7 @@ def main(): "build": build, "build-container": build_container, "test-rie": test_rie, + "check-docstr": check_docstr, } command_map[args.command]()