class HttpAccessLogMiddleware(BaseHTTPMiddleware):
    """Http access log middleware for FastAPI.

    Emits one access-log line per request after the response has been produced
    and, optionally, one debug line before the request is processed. The values
    interpolated into the messages are read from `request.state.http_info`
    (expected to be a dict); this middleware only reads that dict.

    Fields that are referenced by a format string but missing from
    `request.state.http_info` are rendered as `_MISSING_PLACEHOLDER` instead of
    raising `KeyError`, and a missing 'status_code' falls back to
    `response.status_code`, so a missing or partial `http_info` never fails the
    request.

    Inherits:
        BaseHTTPMiddleware: Base HTTP middleware class from starlette.

    Attributes:
        _DEBUG_MSG_FORMAT_STR (str ): Default http access log debug message format. Defaults to
            '<n>[{request_id}]</n> {client_host} {user_id} "<u>{method} {url_path}</u> HTTP/{http_version}"'.
        _MSG_FORMAT_STR       (str ): Default http access log message format. Defaults to
            '<n><w>[{request_id}]</w></n> {client_host} {user_id} "<u>{method} {url_path}</u> HTTP/{http_version}"
            {status_code} {content_length}B {response_time}ms'.
        _MISSING_PLACEHOLDER  (str ): Text rendered for a format field that is not present in
            the http info dict. Defaults to '-'.
        debug_msg_format_str  (str ): Http access log debug message format.
            Defaults to `HttpAccessLogMiddleware._DEBUG_MSG_FORMAT_STR`.
        msg_format_str        (str ): Http access log message format.
            Defaults to `HttpAccessLogMiddleware._MSG_FORMAT_STR`.
        use_debug_log         (bool): If True, use debug log to log http access log. Defaults to True.
    """

    _DEBUG_MSG_FORMAT_STR = (
        '<n>[{request_id}]</n> {client_host} {user_id} "<u>{method} {url_path}</u>'
        ' HTTP/{http_version}"'
    )
    _MSG_FORMAT_STR = (
        '<n><w>[{request_id}]</w></n> {client_host} {user_id} "<u>{method} {url_path}</u> HTTP/{http_version}"'
        " {status_code} {content_length}B {response_time}ms"
    )
    _MISSING_PLACEHOLDER = "-"

    class _FormatFields(dict):
        """Mapping for `str.format_map` that tolerates unknown field names."""

        def __missing__(self, key):
            # Called by dict.__getitem__ for absent keys; lets formatting
            # degrade to a placeholder instead of raising KeyError.
            return HttpAccessLogMiddleware._MISSING_PLACEHOLDER

    def __init__(
        self,
        app,
        debug_msg_format_str: str = _DEBUG_MSG_FORMAT_STR,
        msg_format_str: str = _MSG_FORMAT_STR,
        use_debug_log: bool = True,
    ):
        """Store the message formats and the debug-log switch.

        Args:
            app                        : ASGI application passed on to `BaseHTTPMiddleware`.
            debug_msg_format_str (str ): Format of the pre-request debug message.
            msg_format_str       (str ): Format of the post-response access-log message.
            use_debug_log        (bool): If True, also log a debug line before the request
                is processed.
        """
        super().__init__(app)
        self.debug_msg_format_str = debug_msg_format_str
        self.msg_format_str = msg_format_str
        self.use_debug_log = use_debug_log

    @staticmethod
    def _read_http_info(request: Request, default: dict) -> dict:
        """Return `request.state.http_info` if it is a dict, otherwise `default`."""
        _http_info = getattr(request.state, "http_info", None)
        if isinstance(_http_info, dict):
            return _http_info
        return default

    @staticmethod
    def _style_for_status(status_code: int, msg_format_str: str) -> "tuple[str, str]":
        """Pick the log level and colourised message format for a status code.

        Args:
            status_code    (int): HTTP status code of the response.
            msg_format_str (str): Message format containing a '{status_code}' field.

        Returns:
            tuple[str, str]: Log level name and the message format with markup
                tags wrapped around the '{status_code}' field.
        """

        def _mark(replacement: str) -> str:
            return msg_format_str.replace("{status_code}", replacement)

        if status_code < 200:
            return "DEBUG", "<d>" + _mark("<n><b><k>{status_code}</k></b></n>") + "</d>"
        if 200 <= status_code < 300:
            return "SUCCESS", "<w>" + _mark("<lvl>{status_code}</lvl>") + "</w>"
        if 300 <= status_code < 400:
            return "INFO", "<d>" + _mark("<n><b><c>{status_code}</c></b></n>") + "</d>"
        if 400 <= status_code < 500:
            return "WARNING", _mark("<r>{status_code}</r>")
        if 500 <= status_code:
            return "ERROR", _mark("<n>{status_code}</n>")
        # Only reachable for values that fail every comparison; keep the defaults.
        return "INFO", msg_format_str

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the request (debug, optional) and the finished response (access log).

        Args:
            request   (Request): Incoming request; `request.state.http_info` is read
                before and after the request is processed.
            call_next          : Callable that processes the request and returns the response.

        Returns:
            Response: The response returned by `call_next`, unmodified.
        """
        _logger = logger.opt(colors=True, record=True).bind(disable_std_handler=True)
        _http_info: dict[str, Any] = self._read_http_info(request, default={})

        # Debug log:
        if self.use_debug_log:
            _debug_msg = self.debug_msg_format_str.format_map(
                self._FormatFields(_http_info)
            )
            _logger.bind(
                http_info=_http_info, disable_http_all_file_handlers=True
            ).debug(_debug_msg)
        # Debug log.

        # Process request:
        response: Response = await call_next(request)
        # Response processed.

        # Re-read after the call: the dict on `request.state` may have been set or
        # replaced while the request was processed (presumably to add response
        # fields such as 'status_code' - verify against the upstream middleware).
        _http_info = self._read_http_info(request, default=_http_info)

        # Http access log:
        # Format from a copy so the shared `request.state.http_info` is not mutated.
        _fields = self._FormatFields(_http_info)
        _fields.setdefault("status_code", response.status_code)

        _level, _msg_format_str = self._style_for_status(
            _fields["status_code"], self.msg_format_str
        )
        _msg = _msg_format_str.format_map(_fields)
        # NOTE: logging is done synchronously here; an earlier variant offloaded
        # these calls with `run_in_threadpool`.
        _logger.bind(http_info=_http_info).log(_level, _msg)
        # Http access log.

        return response