121 lines
4.1 KiB
Python
121 lines
4.1 KiB
Python
import typing
|
|
from os import PathLike
|
|
|
|
from starlette.background import BackgroundTask
|
|
from starlette.datastructures import URL
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.types import Receive, Scope, Send
|
|
|
|
try:
|
|
import jinja2
|
|
|
|
# @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
|
|
# hence we try to get pass_context (most installs will be >=3.1)
|
|
# and fall back to contextfunction,
|
|
# adding a type ignore for mypy to let us access an attribute that may not exist
|
|
if hasattr(jinja2, "pass_context"):
|
|
pass_context = jinja2.pass_context
|
|
else: # pragma: nocover
|
|
pass_context = jinja2.contextfunction # type: ignore[attr-defined]
|
|
except ModuleNotFoundError: # pragma: nocover
|
|
jinja2 = None # type: ignore[assignment]
|
|
|
|
|
|
class _TemplateResponse(Response):
|
|
media_type = "text/html"
|
|
|
|
def __init__(
|
|
self,
|
|
template: typing.Any,
|
|
context: dict,
|
|
status_code: int = 200,
|
|
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
|
media_type: typing.Optional[str] = None,
|
|
background: typing.Optional[BackgroundTask] = None,
|
|
):
|
|
self.template = template
|
|
self.context = context
|
|
content = template.render(context)
|
|
super().__init__(content, status_code, headers, media_type, background)
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
request = self.context.get("request", {})
|
|
extensions = request.get("extensions", {})
|
|
if "http.response.debug" in extensions:
|
|
await send(
|
|
{
|
|
"type": "http.response.debug",
|
|
"info": {
|
|
"template": self.template,
|
|
"context": self.context,
|
|
},
|
|
}
|
|
)
|
|
await super().__call__(scope, receive, send)
|
|
|
|
|
|
class Jinja2Templates:
|
|
"""
|
|
templates = Jinja2Templates("templates")
|
|
|
|
return templates.TemplateResponse("index.html", {"request": request})
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
directory: typing.Union[str, PathLike],
|
|
context_processors: typing.Optional[
|
|
typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
|
|
] = None,
|
|
**env_options: typing.Any,
|
|
) -> None:
|
|
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
|
self.env = self._create_env(directory, **env_options)
|
|
self.context_processors = context_processors or []
|
|
|
|
def _create_env(
|
|
self, directory: typing.Union[str, PathLike], **env_options: typing.Any
|
|
) -> "jinja2.Environment":
|
|
@pass_context
|
|
def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
|
|
request = context["request"]
|
|
return request.url_for(name, **path_params)
|
|
|
|
loader = jinja2.FileSystemLoader(directory)
|
|
env_options.setdefault("loader", loader)
|
|
env_options.setdefault("autoescape", True)
|
|
|
|
env = jinja2.Environment(**env_options)
|
|
env.globals["url_for"] = url_for
|
|
return env
|
|
|
|
def get_template(self, name: str) -> "jinja2.Template":
|
|
return self.env.get_template(name)
|
|
|
|
def TemplateResponse(
|
|
self,
|
|
name: str,
|
|
context: dict,
|
|
status_code: int = 200,
|
|
headers: typing.Optional[typing.Mapping[str, str]] = None,
|
|
media_type: typing.Optional[str] = None,
|
|
background: typing.Optional[BackgroundTask] = None,
|
|
) -> _TemplateResponse:
|
|
if "request" not in context:
|
|
raise ValueError('context must include a "request" key')
|
|
|
|
request = typing.cast(Request, context["request"])
|
|
for context_processor in self.context_processors:
|
|
context.update(context_processor(request))
|
|
|
|
template = self.get_template(name)
|
|
return _TemplateResponse(
|
|
template,
|
|
context,
|
|
status_code=status_code,
|
|
headers=headers,
|
|
media_type=media_type,
|
|
background=background,
|
|
)
|