Files
ietf-draft-analyzer/demo/act-ect-mcp/src/poc/agent.py
Christian Nennemann 9a0dc899a8 feat: add ACT+ECT over MCP demo with LangGraph agent
End-to-end PoC demonstrating Agent Context Token authorization and
Execution Context Token accountability over MCP tool calls, using a
LangGraph agent with ES256-signed JWT tokens and DAG verification.
2026-04-12 12:43:22 +00:00

448 lines
15 KiB
Python

"""LangGraph ReAct agent that calls MCP tools with ACT + ECT on every request.
Flow per run
------------
1. ``mint_mandate`` — user issues a Phase 1 ACT mandate that authorises the
agent to use ``mcp.search``, ``mcp.summarize``, plus session-level actions.
2. ``MultiServerMCPClient`` opens a streamable-HTTP session to the MCP
server. The session's ``httpx.AsyncClient`` has event hooks installed
(``_install_ect_hooks``) that, on every outgoing POST to /mcp:
* build an ECT over the request body (inp_hash),
* sign the request per RFC 9421 with ``wimse-aud=mcp-server``,
* attach ``Authorization: Bearer <ACT>``, ``Wimse-ECT: <ect>``,
``Content-Digest``, ``Signature-Input`` and ``Signature``.
Each ECT's ``pred`` chains to the mandate plus all prior tool-call ECTs
in this run, so the ECT DAG captures the per-tool-call ordering.
3. ``create_react_agent`` runs a LangGraph ReAct loop with ChatOllama; the
LLM decides when/what to call. The token plumbing is transparent to
the model.
4. After the agent finishes its response, a single Phase 2 ACT execution
record is minted that summarises the run (ACT §3.2: one mandate → one
record; jti preserved). The record's ``inp_hash`` covers the task
purpose and ``out_hash`` covers the final assistant message.
"""
from __future__ import annotations
import argparse
import asyncio
import hashlib
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, AsyncIterator
import httpx
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent
from act.crypto import b64url_sha256
from .http_sig import SignedRequest, content_digest, sign_request
from .keys import Identity, load_identities
from .tokens import (
MintedMandate,
MintedRecord,
MintedECT,
exec_act_for_rpc_method,
mint_ect,
mint_exec_record,
mint_mandate,
)
# Module-level logger used by the hooks and by run_once in this file.
LOG = logging.getLogger("poc.agent")
# Name of the MCP server identity; run_once passes it as the audience when
# minting the mandate and when opening the MCP client.
SERVER_IDENTITY_NAME = "mcp-server"
# ---- Session ledger ---------------------------------------------------------
@dataclass
class LedgerEntry:
    """One record of the session ledger, serialisable as a JSON line.

    ``compact`` carries the serialised token string and ``jti`` its
    identifier; ``metadata`` holds whatever extra fields the caller wants
    logged next to the token.
    """

    kind: str  # "mandate" | "ect" | "record"
    compact: str
    jti: str
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_json(self) -> str:
        """Return the entry as a JSON object with no insignificant whitespace."""
        # Key order is fixed here so every ledger line has the same layout.
        payload = {
            "kind": self.kind,
            "jti": self.jti,
            "compact": self.compact,
            "metadata": self.metadata,
        }
        return json.dumps(payload, separators=(",", ":"))
@dataclass
class SessionLedger:
    """Mutable per-run state: the mandate plus the tool-call ECTs collected so far.

    ``tool_ects`` grows as tool-call ECTs are appended to it, and
    ``tool_ect_pred`` turns the current contents into the ``pred`` list for
    the next one, which is how the ECT DAG is built up. ``final_record``
    holds the single ACT Phase 2 record once the run has minted it.
    """

    path: Path
    mandate: MintedMandate
    tool_ects: list[MintedECT] = field(default_factory=list)
    final_record: MintedRecord | None = None

    def write_entry(self, entry: LedgerEntry) -> None:
        """Append ``entry`` to the ledger file as one JSON line.

        Missing parent directories of ``self.path`` are created first.
        """
        ledger_file = self.path
        ledger_file.parent.mkdir(parents=True, exist_ok=True)
        with ledger_file.open("a", encoding="utf-8") as sink:
            sink.write(f"{entry.to_json()}\n")

    def tool_ect_pred(self) -> list[str]:
        """pred list for the *next* tool-call ECT: mandate jti, then prior tool ECT jtis."""
        return [
            self.mandate.mandate.jti,
            *(prior.payload.jti for prior in self.tool_ects),
        ]
# ---- httpx event hooks ------------------------------------------------------
def _rpc_method_and_tool(body: bytes) -> tuple[str | None, str | None]:
"""Sniff a JSON-RPC request body for (method, tool_name)."""
try:
obj = json.loads(body.decode("utf-8"))
except Exception:
return None, None
if not isinstance(obj, dict):
return None, None
method = obj.get("method")
if not isinstance(method, str):
return None, None
tool_name = None
if method == "tools/call":
params = obj.get("params") or {}
name = params.get("name") if isinstance(params, dict) else None
if isinstance(name, str):
tool_name = name
return method, tool_name
def _install_ect_hooks(
    client: httpx.AsyncClient,
    *,
    agent: Identity,
    audience: str,
    ledger: SessionLedger,
    mcp_path: str = "/mcp",
) -> None:
    """Attach request/response event hooks that inject ACT+ECT+sig headers.

    The request hook acts on every request whose URL path ends with
    ``mcp_path``. For a JSON-RPC body it mints an ECT over the body, signs
    the request, and sets the ``authorization``, ``wimse-ect``,
    ``content-digest``, ``signature-input`` and ``signature`` headers.

    The response hook writes the ECT that was sent to the ledger. A
    ``tools/call`` ECT is added to ``ledger.tool_ects`` (and therefore to the
    ``pred`` of later tool calls) only when the response status is 2xx, so an
    ECT from a request the server did not answer successfully never becomes
    a predecessor of later ones.

    Args:
        client: httpx client whose ``event_hooks`` lists are extended in place.
        agent: Identity used to mint ECTs and sign requests; ``kid`` and
            ``private_key`` are read from it.
        audience: Used as the ECT audience and as ``wimse_aud`` when signing.
        ledger: Per-run state supplying the mandate and collecting tool ECTs.
        mcp_path: Path suffix that selects the requests the hooks act on.
    """
    state_key = "_poc_ect_state"

    async def on_request(request: httpx.Request) -> None:
        if not request.url.path.endswith(mcp_path):
            return
        # httpx may have already serialized body into request.content.
        body = request.content or b""
        method, tool_name = _rpc_method_and_tool(body)
        if method is None:
            # Not JSON-RPC — still attach mandate so middleware can 403
            # rather than 401, but skip ECT minting. The PoC never
            # triggers this path; keep it permissive to ease debugging.
            request.headers["authorization"] = f"Bearer {ledger.mandate.compact}"
            return
        try:
            exec_act = exec_act_for_rpc_method(method, tool_name)
        except ValueError:
            LOG.warning("unknown tool in tools/call: %r", tool_name)
            return

        # Session-setup calls (initialize, tools/list, ping, …) don't grow
        # the tool-call DAG — they point only at the mandate. Tool-call
        # ECTs chain off the mandate plus every prior tool-call ECT.
        is_tool_call = method == "tools/call"
        if is_tool_call:
            pred_jtis = ledger.tool_ect_pred()
        else:
            pred_jtis = [ledger.mandate.mandate.jti]

        ect = mint_ect(
            agent=agent,
            audience=audience,
            exec_act=exec_act,
            pred_jtis=pred_jtis,
            inp_body=body,
        )
        signed: SignedRequest = sign_request(
            method=request.method,
            target_uri=str(request.url),
            body=body,
            wimse_ect=ect.compact,
            wimse_aud=audience,
            keyid=agent.kid,
            private_key=agent.private_key,
        )
        request.headers["authorization"] = f"Bearer {ledger.mandate.compact}"
        request.headers["wimse-ect"] = ect.compact
        request.headers["content-digest"] = signed.content_digest
        request.headers["signature-input"] = signed.signature_input
        request.headers["signature"] = signed.signature

        # Stash on the request so the response hook can correlate the HTTP
        # exchange with the ECT we just sent and log it to the ledger. The
        # response hook below reads "ect", "exec_act", "method" and
        # "tool_name"; the remaining keys are kept for debugging.
        setattr(request, state_key, {
            "ect": ect,
            "exec_act": exec_act,
            "method": method,
            "tool_name": tool_name,
            "inp_hash": b64url_sha256(body),
            "pred_jtis": pred_jtis,
            "request_body": body,
        })

    async def on_response(response: httpx.Response) -> None:
        st = getattr(response.request, state_key, None)
        if not st:
            # Request was not decorated by on_request; nothing to record.
            return
        method: str = st["method"]
        ect = st["ect"]

        metadata: dict[str, Any]
        if method == "tools/call":
            if response.is_success:
                # Only ECTs from successfully answered calls become
                # predecessors of later tool-call ECTs.
                ledger.tool_ects.append(ect)
            else:
                LOG.warning(
                    "tools/call ECT %s got HTTP %d; not adding it to the pred chain",
                    ect.payload.jti,
                    response.status_code,
                )
            metadata = {
                "method": method,
                "tool_name": st["tool_name"],
                "exec_act": st["exec_act"],
                "pred": list(ect.payload.pred),
            }
        else:
            metadata = {
                "method": method,
                "exec_act": st["exec_act"],
                "session_only": True,
            }

        # Every ECT that was actually sent is logged, chained or not.
        ledger.write_entry(
            LedgerEntry(
                kind="ect",
                compact=ect.compact,
                jti=ect.payload.jti,
                metadata=metadata,
            )
        )

    client.event_hooks["request"].append(on_request)
    client.event_hooks["response"].append(on_response)
# ---- MCP client factory -----------------------------------------------------
def make_httpx_client_factory(agent: Identity, audience: str, ledger: SessionLedger):
    """Return an httpx_client_factory that installs our hooks on each client.

    The returned callable accepts optional ``headers``, ``timeout`` and
    ``auth``, builds an ``httpx.AsyncClient`` that follows redirects, and
    passes it through ``_install_ect_hooks`` before handing it back.
    """
    # Timeout constants come from the MCP SDK's httpx helper module.
    from mcp.shared._httpx_utils import (
        MCP_DEFAULT_SSE_READ_TIMEOUT,
        MCP_DEFAULT_TIMEOUT,
    )

    def factory(
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        # Fall back to the SDK timeouts when the caller gives none.
        effective_timeout = (
            timeout
            if timeout is not None
            else httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT)
        )
        # Only forward headers/auth that were actually supplied.
        optional: dict[str, Any] = {
            key: value
            for key, value in (("headers", headers), ("auth", auth))
            if value is not None
        }
        http_client = httpx.AsyncClient(
            follow_redirects=True, timeout=effective_timeout, **optional
        )
        _install_ect_hooks(http_client, agent=agent, audience=audience, ledger=ledger)
        return http_client

    return factory
# ---- Run an agent turn ------------------------------------------------------
@asynccontextmanager
async def open_mcp_client(
    *, agent: Identity, audience: str, ledger: SessionLedger, url: str
) -> AsyncIterator[MultiServerMCPClient]:
    """Yield a ``MultiServerMCPClient`` for the single ``"poc"`` server at ``url``.

    The server is configured for the ``streamable_http`` transport with an
    httpx client factory that installs the ACT/ECT hooks for this run.
    """
    server_config = {
        "transport": "streamable_http",
        "url": url,
        "httpx_client_factory": make_httpx_client_factory(agent, audience, ledger),
    }
    # No teardown on exit: per the original author's note,
    # MultiServerMCPClient does not expose an explicit close in 0.2.x and
    # sessions are closed per get_tools() call.
    yield MultiServerMCPClient({"poc": server_config})
async def run_once(
    *,
    purpose: str,
    model: str,
    mcp_url: str,
    keys_dir: str,
    ledger_path: str,
    ollama_host: str | None,
) -> dict[str, Any]:
    """Run one mandate → agent loop → execution-record cycle.

    Mints a mandate for ``purpose``, logs it, lets a ReAct agent backed by
    ChatOllama work with the MCP tools at ``mcp_url``, then mints and logs a
    single execution record over the purpose and the final assistant message.

    Returns:
        A dict with ``mandate_jti``, ``record_jti``, the list of tool-call
        ECT jtis (``tool_ects``) and the ``final_message`` text.
    """
    ids = load_identities(keys_dir)
    user, agent = ids["user"], ids["agent"]

    mandate = mint_mandate(
        user=user,
        agent=agent,
        audience=SERVER_IDENTITY_NAME,
        purpose=purpose,
    )
    ledger = SessionLedger(path=Path(ledger_path), mandate=mandate)

    claims = mandate.mandate
    mandate_entry = LedgerEntry(
        kind="mandate",
        compact=mandate.compact,
        jti=claims.jti,
        metadata={
            "iss": claims.iss,
            "sub": claims.sub,
            "aud": claims.aud,
            "task": claims.task.to_dict(),
            "cap": [capability.to_dict() for capability in claims.cap],
        },
    )
    ledger.write_entry(mandate_entry)

    async with open_mcp_client(
        agent=agent, audience=SERVER_IDENTITY_NAME, ledger=ledger, url=mcp_url
    ) as client:
        tools = await client.get_tools()
        LOG.info("loaded %d MCP tools: %s", len(tools), [t.name for t in tools])

        # base_url is only passed when an Ollama host was configured.
        llm_options: dict[str, Any] = {"model": model, "temperature": 0.0}
        if ollama_host:
            llm_options["base_url"] = ollama_host
        graph = create_react_agent(ChatOllama(**llm_options), tools)

        conversation = [
            SystemMessage(
                content=(
                    "You are a research assistant with access to two tools: "
                    "search(query) and summarize(text). "
                    "For the user's task, first call search to gather material, "
                    "then call summarize on the joined results. "
                    "After the summary, reply with the summary and stop."
                )
            ),
            HumanMessage(content=purpose),
        ]
        result = await graph.ainvoke({"messages": conversation})

        # The last message is taken as the final answer; list-shaped content
        # is flattened to a deterministic JSON string so it can be hashed.
        last_message = result["messages"][-1]
        final_text = getattr(last_message, "content", str(last_message))
        if isinstance(final_text, list):
            final_text = json.dumps(final_text, sort_keys=True)

        # ACT §3.2: one mandate → one Phase 2 record (jti preserved). The
        # record summarises the whole invocation; per-tool-call DAG structure
        # lives in the ECTs we already logged.
        final_record = mint_exec_record(
            agent=agent,
            mandate=mandate.mandate,
            exec_act="mcp.summarize",  # terminal exec_act; picked from cap
            pred_jtis=[],  # root task within this run's ACT view
            inp_body=purpose.encode("utf-8"),
            out_body=final_text.encode("utf-8"),
        )
        ledger.final_record = final_record

        record = final_record.record
        record_entry = LedgerEntry(
            kind="record",
            compact=final_record.compact,
            jti=record.jti,
            metadata={
                "exec_act": record.exec_act,
                "status": record.status,
                "pred": list(record.pred),
                "inp_hash": record.inp_hash,
                "out_hash": record.out_hash,
                "n_tool_ects": len(ledger.tool_ects),
            },
        )
        ledger.write_entry(record_entry)

    return {
        "mandate_jti": mandate.mandate.jti,
        "record_jti": final_record.record.jti,
        "tool_ects": [tool_ect.payload.jti for tool_ect in ledger.tool_ects],
        "final_message": final_text,
    }
def main() -> None:
    """CLI entry point: parse options, run one agent turn, print the summary.

    Every option except ``--purpose`` takes its default from an environment
    variable (``POC_MODEL``, ``POC_MCP_URL``, ``POC_KEYS_DIR``, ``POC_LEDGER``,
    ``OLLAMA_HOST``); the log level comes from ``POC_LOG_LEVEL``.
    """
    env = os.environ.get

    logging.basicConfig(
        level=env("POC_LOG_LEVEL", "INFO"),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    cli = argparse.ArgumentParser(description="ACT/ECT MCP PoC agent")
    cli.add_argument(
        "--purpose",
        default="Summarise recent research on agent authorization tokens.",
        help="High-level task the mandate authorises.",
    )
    cli.add_argument("--model", default=env("POC_MODEL", "qwen3:8b"))
    cli.add_argument("--mcp-url", default=env("POC_MCP_URL", "http://127.0.0.1:8765/mcp"))
    cli.add_argument("--keys-dir", default=env("POC_KEYS_DIR", "keys"))
    cli.add_argument("--ledger", default=env("POC_LEDGER", "keys/ledger.jsonl"))
    cli.add_argument("--ollama-host", default=env("OLLAMA_HOST"))
    opts = cli.parse_args()

    outcome = asyncio.run(
        run_once(
            purpose=opts.purpose,
            model=opts.model,
            mcp_url=opts.mcp_url,
            keys_dir=opts.keys_dir,
            ledger_path=opts.ledger,
            ollama_host=opts.ollama_host,
        )
    )
    print(json.dumps(outcome, indent=2))


if __name__ == "__main__":
    main()