mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-01 00:11:39 +08:00
Compare commits
1057 Commits
simplify-t
...
feat/devex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4d7e6a29e | ||
|
|
172a38c344 | ||
|
|
8bc0d4f77d | ||
|
|
8eabdefa8a | ||
|
|
f658af45c2 | ||
|
|
5212644861 | ||
|
|
1151f84351 | ||
|
|
9abd6bf342 | ||
|
|
d2c7ef6b41 | ||
|
|
a34102049b | ||
|
|
ef5d811aba | ||
|
|
2d44ed1c5b | ||
|
|
fa2e72ae9c | ||
|
|
5bfc4ed53b | ||
|
|
520aec20e0 | ||
|
|
64bec1d060 | ||
|
|
ac58309dbd | ||
|
|
a5a5d82a21 | ||
|
|
34e8d088c2 | ||
|
|
c754135965 | ||
|
|
c6b75baad0 | ||
|
|
a7ad6f6d28 | ||
|
|
1a2141d04d | ||
|
|
ff3f3169b2 | ||
|
|
f4580b6010 | ||
|
|
7b63a787b3 | ||
|
|
069570d103 | ||
|
|
0dafdcab86 | ||
|
|
654e16187e | ||
|
|
732c66b0f3 | ||
|
|
1f0944de21 | ||
|
|
f1a1b58319 | ||
|
|
c21d77ca08 | ||
|
|
d6c710706f | ||
|
|
a6d3becd6a | ||
|
|
3b67606c42 | ||
|
|
a2d0d07109 | ||
|
|
aedb773f0d | ||
|
|
aaf8f2d2d2 | ||
|
|
12f4800631 | ||
|
|
57b48a81ca | ||
|
|
7af33accf1 | ||
|
|
3214c05e82 | ||
|
|
4608a7fe4e | ||
|
|
af67ea8800 | ||
|
|
37c3dcf551 | ||
|
|
6a49fbb7da | ||
|
|
eb0b01de7b | ||
|
|
5b1528519c | ||
|
|
52f92eb689 | ||
|
|
7f9dd60c15 | ||
|
|
77da3bbc95 | ||
|
|
bb489a3903 | ||
|
|
167eb824cb | ||
|
|
efb64aee5a | ||
|
|
3045e29232 | ||
|
|
5d7d76025a | ||
|
|
e6c829384e | ||
|
|
5c658a416c | ||
|
|
a130aa8165 | ||
|
|
35d57ed752 | ||
|
|
5785bd3272 | ||
|
|
cf9482984e | ||
|
|
67275641f8 | ||
|
|
3ffaac00dd | ||
|
|
816a3ef6f1 | ||
|
|
a8bf414f4a | ||
|
|
3b312d45c5 | ||
|
|
fcd899f888 | ||
|
|
315f3ea429 | ||
|
|
b7d6eae64c | ||
|
|
b3765c28d0 | ||
|
|
4cfb66bac2 | ||
|
|
0c4cff352a | ||
|
|
503269b85a | ||
|
|
161436cfdd | ||
|
|
24f549a692 | ||
|
|
7a8778ac73 | ||
|
|
763c6d104d | ||
|
|
4d7d9d9715 | ||
|
|
a9c35f9175 | ||
|
|
37752ff1ac | ||
|
|
31b84213e4 | ||
|
|
2036c22f88 | ||
|
|
7185a66b96 | ||
|
|
2394e18729 | ||
|
|
99f7582175 | ||
|
|
93c5997290 | ||
|
|
2d1a1c1c47 | ||
|
|
71e81728ac | ||
|
|
ebe60646db | ||
|
|
f996d7950b | ||
|
|
ae4a674c84 | ||
|
|
169615abc8 | ||
|
|
7c30ac2141 | ||
|
|
192501528f | ||
|
|
5ae0b731d0 | ||
|
|
d9f373654b | ||
|
|
0efbb137e8 | ||
|
|
cf63b2471f | ||
|
|
f88343a6da | ||
|
|
491605cfea | ||
|
|
3aded1d4e5 | ||
|
|
4f0402ed3a | ||
|
|
ecac6321c4 | ||
|
|
20c6573e0a | ||
|
|
97b1c76b14 | ||
|
|
24a37032fa | ||
|
|
c0520223fd | ||
|
|
1f1caa836a | ||
|
|
b3ea7714f5 | ||
|
|
a7f9721785 | ||
|
|
a5461e07bf | ||
|
|
2e73a9e893 | ||
|
|
26bb56b775 | ||
|
|
95b1130485 | ||
|
|
3fb8938cd3 | ||
|
|
c5e8166c8b | ||
|
|
2b88568653 | ||
|
|
34b4fe495e | ||
|
|
4fdd6c0dac | ||
|
|
60b6abefd9 | ||
|
|
4d53b7ccaa | ||
|
|
081079da62 | ||
|
|
333e4abe30 | ||
|
|
cd77c7100c | ||
|
|
cf810c2950 | ||
|
|
a23bcb81ce | ||
|
|
d07d867718 | ||
|
|
666f2dd486 | ||
|
|
34792dd907 | ||
|
|
7ad6fc8a40 | ||
|
|
f824c10429 | ||
|
|
132e5ec179 | ||
|
|
66d3e6a0c2 | ||
|
|
4a09ae2985 | ||
|
|
8c734f2f27 | ||
|
|
245d174359 | ||
|
|
77f47768dd | ||
|
|
90fa9e54ca | ||
|
|
9d3a44e0e8 | ||
|
|
932d596466 | ||
|
|
d518f40e8b | ||
|
|
f016cfca46 | ||
|
|
b8120df860 | ||
|
|
0df7df52f3 | ||
|
|
bfa27d0a68 | ||
|
|
5a20c486e3 | ||
|
|
78e19ebc95 | ||
|
|
b383cafc44 | ||
|
|
b10ff83566 | ||
|
|
daa1f542f9 | ||
|
|
d507f593d0 | ||
|
|
f210510276 | ||
|
|
19b6f81ee7 | ||
|
|
76545ab365 | ||
|
|
b8c3bc7841 | ||
|
|
a680367568 | ||
|
|
dfd37a4b31 | ||
|
|
5ee9b67d9b | ||
|
|
542faf225f | ||
|
|
5684c68121 | ||
|
|
4be783446a | ||
|
|
8d719b180a | ||
|
|
bf048c8aec | ||
|
|
c5a9d1ef9d | ||
|
|
c7b6f423c7 | ||
|
|
6d34207167 | ||
|
|
fcde9be10d | ||
|
|
3830bbda41 | ||
|
|
4447e7d71a | ||
|
|
7bccd904c7 | ||
|
|
313d522b61 | ||
|
|
9ee4fe41fe | ||
|
|
39ee3512cb | ||
|
|
42673556af | ||
|
|
faab73ad58 | ||
|
|
7e36468511 | ||
|
|
9ba5d399e5 | ||
|
|
306d92a9d7 | ||
|
|
5baae0df88 | ||
|
|
24f6a193e7 | ||
|
|
8c0f8baf32 | ||
|
|
d80c30cc92 | ||
|
|
e64d646bad | ||
|
|
b84f9e410c | ||
|
|
ee5daba061 | ||
|
|
23e84de830 | ||
|
|
48e0dc8791 | ||
|
|
fb0f579b16 | ||
|
|
5a711f32b1 | ||
|
|
4d34427cc7 | ||
|
|
41877183bc | ||
|
|
451a007fb1 | ||
|
|
0a82396718 | ||
|
|
5da55ea1e3 | ||
|
|
064c009deb | ||
|
|
caab1cf453 | ||
|
|
55c70f3508 | ||
|
|
d29249b8fa | ||
|
|
f668e9fc75 | ||
|
|
74fe1e2254 | ||
|
|
348936752a | ||
|
|
69a36a3361 | ||
|
|
8712dd6d1c | ||
|
|
55a21fe37b | ||
|
|
f55f625277 | ||
|
|
9dac85b069 | ||
|
|
99bd69baa8 | ||
|
|
a62a137a4f | ||
|
|
82b18e8ac2 | ||
|
|
0111c9848d | ||
|
|
ab9cadfeee | ||
|
|
8bf28e1441 | ||
|
|
ce28f847ce | ||
|
|
5609117882 | ||
|
|
b4fbb6fe10 | ||
|
|
82d7e9429e | ||
|
|
e2821effb5 | ||
|
|
9742f11fda | ||
|
|
388dd4789c | ||
|
|
fdebca4573 | ||
|
|
479dfc096a | ||
|
|
3c6c11b7c9 | ||
|
|
bc091eb7ef | ||
|
|
f75b1d21b4 | ||
|
|
94053d75a6 | ||
|
|
2a68099675 | ||
|
|
6cd3bc6640 | ||
|
|
211b55815e | ||
|
|
8ae4a6f824 | ||
|
|
b98301677a | ||
|
|
f2fdde5ba4 | ||
|
|
4f56e31dc7 | ||
|
|
6d3804770c | ||
|
|
ab0f4126cf | ||
|
|
585f8528b2 | ||
|
|
75f523f5c0 | ||
|
|
68fbae5692 | ||
|
|
80f1dd8d37 | ||
|
|
b52b37ae64 | ||
|
|
d63b363cde | ||
|
|
c05c60665e | ||
|
|
b4873a5de7 | ||
|
|
913f8ce0a5 | ||
|
|
4a63737227 | ||
|
|
3e93db16bd | ||
|
|
f863a42351 | ||
|
|
dc55f493be | ||
|
|
936fda3f9e | ||
|
|
ecb8148a9f | ||
|
|
2dbbedc05a | ||
|
|
c30967806c | ||
|
|
145f719d30 | ||
|
|
b89eb29174 | ||
|
|
3670089a42 | ||
|
|
3982fcf095 | ||
|
|
8481fdcf08 | ||
|
|
39299e2de4 | ||
|
|
efec4fcaab | ||
|
|
5ce2c47d60 | ||
|
|
f6f3d1de9b | ||
|
|
ec0fe3242a | ||
|
|
f2e24faaca | ||
|
|
8c80b96318 | ||
|
|
2387465dcc | ||
|
|
32636ecf8a | ||
|
|
6055adbe1b | ||
|
|
ffd2f8dc50 | ||
|
|
e93b4d1dcd | ||
|
|
014a5b712d | ||
|
|
2317d115cd | ||
|
|
8253b54be9 | ||
|
|
5c867fd79f | ||
|
|
a44e041acf | ||
|
|
e9f05b3524 | ||
|
|
e2a834578d | ||
|
|
ffc752a79e | ||
|
|
399562a7d1 | ||
|
|
fec8a0da72 | ||
|
|
9f4542b3db | ||
|
|
363633e2ba | ||
|
|
a41ba57a7a | ||
|
|
884c8ea70a | ||
|
|
c886333d32 | ||
|
|
55b173dd03 | ||
|
|
9079a27814 | ||
|
|
d7d10b14cd | ||
|
|
a6499b6107 | ||
|
|
74a36b0729 | ||
|
|
efc7a7b957 | ||
|
|
4f1464b3af | ||
|
|
3a41079fac | ||
|
|
5279540bb4 | ||
|
|
577da79a47 | ||
|
|
1faa9648d3 | ||
|
|
ad57bf1e4b | ||
|
|
d5efb82c7c | ||
|
|
36214d14db | ||
|
|
ea2f7ef2f6 | ||
|
|
435530018b | ||
|
|
df61054a84 | ||
|
|
690b8bb563 | ||
|
|
c43451a50b | ||
|
|
1e312c6582 | ||
|
|
e36c8cd49a | ||
|
|
16cb6d1a6e | ||
|
|
21d61bdd71 | ||
|
|
ad9c26afb8 | ||
|
|
83f99d8203 | ||
|
|
6b37d38dee | ||
|
|
938499ddfb | ||
|
|
d92266d7c0 | ||
|
|
a352b5c193 | ||
|
|
82f7483999 | ||
|
|
56dc9277d7 | ||
|
|
d50e9bcef7 | ||
|
|
c4e520fd6e | ||
|
|
30ff395924 | ||
|
|
f55025952d | ||
|
|
1bc45ee8fe | ||
|
|
19016497ef | ||
|
|
d578d06f59 | ||
|
|
e25ad79d5d | ||
|
|
f2624a1426 | ||
|
|
15561ec425 | ||
|
|
93d93fdea4 | ||
|
|
87f4e4cb9b | ||
|
|
82cb1752d9 | ||
|
|
ada3713e77 | ||
|
|
7d79ce92ac | ||
|
|
1708dcd2b2 | ||
|
|
5702eba93b | ||
|
|
a1767fd69c | ||
|
|
b4b426c69d | ||
|
|
2465674fda | ||
|
|
2eca0d4af1 | ||
|
|
11a7c6b112 | ||
|
|
50ea8adf46 | ||
|
|
ca33372595 | ||
|
|
7d47e3b776 | ||
|
|
fe15a2c65c | ||
|
|
d400fb8b23 | ||
|
|
3221818b6e | ||
|
|
2af2f148ab | ||
|
|
d19109742e | ||
|
|
078e2e4b19 | ||
|
|
9aa2999388 | ||
|
|
d0d9897e81 | ||
|
|
9306a1e06a | ||
|
|
141b12bd39 | ||
|
|
ae3deff8d4 | ||
|
|
41adca4e77 | ||
|
|
8e901b31c1 | ||
|
|
11a5a64729 | ||
|
|
0dba3027c1 | ||
|
|
405c7e08be | ||
|
|
cb36930f1d | ||
|
|
90e6fa2612 | ||
|
|
fd22ae5fcb | ||
|
|
e1baab90f7 | ||
|
|
4fcfa329ba | ||
|
|
b336980229 | ||
|
|
7128f95621 | ||
|
|
ffc6d767ec | ||
|
|
44a2d0c01f | ||
|
|
3e2ed18ad0 | ||
|
|
db58cfb13d | ||
|
|
3220bb8aaa | ||
|
|
ff3a479156 | ||
|
|
6f4941616d | ||
|
|
bd3025d669 | ||
|
|
4c72329412 | ||
|
|
8311e8984b | ||
|
|
093acd72dd | ||
|
|
e9ab711b66 | ||
|
|
b2a9f6beaa | ||
|
|
d3504f84af | ||
|
|
34badeb19c | ||
|
|
f93b48226c | ||
|
|
4805be0119 | ||
|
|
a3ca71fe26 | ||
|
|
70a0a5ff4a | ||
|
|
021f62cb0c | ||
|
|
ba214e43c8 | ||
|
|
520a26c48f | ||
|
|
a787a0d60b | ||
|
|
8d2d8cc728 | ||
|
|
4ae61b0886 | ||
|
|
79871c2083 | ||
|
|
7796ac1411 | ||
|
|
c45aeb45b1 | ||
|
|
ee7fde6531 | ||
|
|
0ea6c34325 | ||
|
|
3db3d60368 | ||
|
|
bfd08d5648 | ||
|
|
7f9777a0b0 | ||
|
|
87a16ad2e5 | ||
|
|
f90a627f9a | ||
|
|
152e0800e6 | ||
|
|
d8f10fa515 | ||
|
|
e86f391cac | ||
|
|
e39de2e752 | ||
|
|
1538be45de | ||
|
|
95e3f4b001 | ||
|
|
b7821b6dc1 | ||
|
|
556a132f2d | ||
|
|
fafb9c23bf | ||
|
|
1754bdf1e8 | ||
|
|
fa3d7b3d03 | ||
|
|
73f2998d48 | ||
|
|
6a51fd23df | ||
|
|
ffec21236d | ||
|
|
db0521ce0e | ||
|
|
a1c25046a9 | ||
|
|
de0af4df66 | ||
|
|
0e1723ef74 | ||
|
|
aefc330b8f | ||
|
|
f967471758 | ||
|
|
4f5ffb8909 | ||
|
|
54909b0282 | ||
|
|
f084538cb9 | ||
|
|
535b46f813 | ||
|
|
4766b3cdb9 | ||
|
|
354af6ccee | ||
|
|
c9afbbac0b | ||
|
|
83fa442c1b | ||
|
|
1900e5238b | ||
|
|
ddae1aa2e9 | ||
|
|
16274d5a82 | ||
|
|
5749f5809c | ||
|
|
d10108f8ca | ||
|
|
8b520f9848 | ||
|
|
4cc431afab | ||
|
|
a718aed1be | ||
|
|
5f29e7b63c | ||
|
|
245c766512 | ||
|
|
f08ad94d4d | ||
|
|
cdf5375b9a | ||
|
|
bdf4758510 | ||
|
|
84e45b5c40 | ||
|
|
daedec6957 | ||
|
|
de59d91add | ||
|
|
68cc81a74d | ||
|
|
3ead3401e0 | ||
|
|
eec31b0089 | ||
|
|
7df14227a9 | ||
|
|
60effcfc44 | ||
|
|
63f5e14c69 | ||
|
|
64ff8f065b | ||
|
|
468b7fdbad | ||
|
|
14b0ad95c6 | ||
|
|
221e4228ec | ||
|
|
dd9d3f89b9 | ||
|
|
b0cce17da6 | ||
|
|
c6b3b8c847 | ||
|
|
2ba87a10b0 | ||
|
|
5fa3e24b76 | ||
|
|
ac6d747fa6 | ||
|
|
ee541c84f1 | ||
|
|
6053236158 | ||
|
|
11615014a4 | ||
|
|
3588396263 | ||
|
|
11a2ecb936 | ||
|
|
151e8d896c | ||
|
|
593c549bc4 | ||
|
|
aa2ecaef29 | ||
|
|
0eb0bec74c | ||
|
|
3c252ae44b | ||
|
|
6789084ec0 | ||
|
|
b603b6e1c9 | ||
|
|
3c13feed4c | ||
|
|
7652afb8de | ||
|
|
7862e7010c | ||
|
|
4faf2a6cf4 | ||
|
|
8c48bb080f | ||
|
|
6d2481ee5c | ||
|
|
ca5525bcd7 | ||
|
|
56b53bff6e | ||
|
|
fd335a4e26 | ||
|
|
c4ea996612 | ||
|
|
39bfd226b8 | ||
|
|
234b67f5fd | ||
|
|
e27e3a4f8a | ||
|
|
7a11ff95a9 | ||
|
|
33ab5cec82 | ||
|
|
1cb2311bad | ||
|
|
25c65bc99e | ||
|
|
afb680b50d | ||
|
|
c574a4d086 | ||
|
|
bd8b20b933 | ||
|
|
866fd9476b | ||
|
|
d2ec5aaacf | ||
|
|
e265006fd6 | ||
|
|
b1bf11b0fe | ||
|
|
6bf3aad62e | ||
|
|
3a840a130c | ||
|
|
14396e3fe7 | ||
|
|
1ad930cbd0 | ||
|
|
7a0b37712f | ||
|
|
e2b8740fcf | ||
|
|
45d132d098 | ||
|
|
719f2eef32 | ||
|
|
698b35933e | ||
|
|
0512ada793 | ||
|
|
47289ba6f1 | ||
|
|
5e5e0efc60 | ||
|
|
7b38afc179 | ||
|
|
e5893075f9 | ||
|
|
5e598a588f | ||
|
|
c2d8d17285 | ||
|
|
8bc2de4ab6 | ||
|
|
75a92a3f82 | ||
|
|
72963e9ccb | ||
|
|
92da8e7e62 | ||
|
|
c84d5ce738 | ||
|
|
dda9f3e734 | ||
|
|
834e25a662 | ||
|
|
196a13f3dc | ||
|
|
440d33eec4 | ||
|
|
11f5c1ecf0 | ||
|
|
3b745633e4 | ||
|
|
900d48714a | ||
|
|
3fdf03390e | ||
|
|
25fb9aafcb | ||
|
|
54147474d3 | ||
|
|
4d6f380bd1 | ||
|
|
93f5fd80b8 | ||
|
|
177be32b7f | ||
|
|
30efc263ff | ||
|
|
ed0e860abb | ||
|
|
41d8a80226 | ||
|
|
4ec386cc72 | ||
|
|
dd69f16c3e | ||
|
|
1db5598294 | ||
|
|
23d0b7af6a | ||
|
|
a7c2b9e280 | ||
|
|
70dfec9638 | ||
|
|
95b0610f36 | ||
|
|
500f0eab4a | ||
|
|
86b1db0598 | ||
|
|
5a79e423fe | ||
|
|
7f7643cf63 | ||
|
|
bf52468a91 | ||
|
|
b4688f10d4 | ||
|
|
31a5cd185a | ||
|
|
7166647ca1 | ||
|
|
f7300a858e | ||
|
|
e87859e82c | ||
|
|
de101a8202 | ||
|
|
7f1f4c2248 | ||
|
|
c33f8d381b | ||
|
|
3f58e47c63 | ||
|
|
b7f8a17c24 | ||
|
|
6cbb8f3a0c | ||
|
|
ec97f9ad1a | ||
|
|
10085041cf | ||
|
|
7b23dbfe68 | ||
|
|
8e0c48e6d2 | ||
|
|
b759602483 | ||
|
|
2205b22409 | ||
|
|
1ddf8c26f5 | ||
|
|
9769e07cd5 | ||
|
|
08250a53a1 | ||
|
|
ff6d62802d | ||
|
|
46506769f1 | ||
|
|
4ea29978fc | ||
|
|
dfd50ceccd | ||
|
|
6366177118 | ||
|
|
2390728cc3 | ||
|
|
b32c642af3 | ||
|
|
c36b256de5 | ||
|
|
0afe1b707d | ||
|
|
f213620c8b | ||
|
|
35655298e6 | ||
|
|
1e463a8e39 | ||
|
|
de5a88bd97 | ||
|
|
0862fa96fd | ||
|
|
924570c5be | ||
|
|
4d8689c10c | ||
|
|
1d7ce5e063 | ||
|
|
72d3425eef | ||
|
|
b7f099beed | ||
|
|
912ef50165 | ||
|
|
4a9086b848 | ||
|
|
50cb4d5fc7 | ||
|
|
2bc9508b7c | ||
|
|
337cd574c8 | ||
|
|
9fb027915e | ||
|
|
2b821c3a14 | ||
|
|
0d113fab1a | ||
|
|
19f28a633a | ||
|
|
2c817ce4a5 | ||
|
|
66a5bc64db | ||
|
|
7f423508e4 | ||
|
|
306c6706a6 | ||
|
|
64be67e062 | ||
|
|
0c0a2eb0a2 | ||
|
|
de0829cec3 | ||
|
|
20177660bb | ||
|
|
609fc6d080 | ||
|
|
518826e70c | ||
|
|
13992a58da | ||
|
|
0d2ac1c07f | ||
|
|
fb7df099e0 | ||
|
|
f14ff3e041 | ||
|
|
07fcb94bc0 | ||
|
|
66d9983d46 | ||
|
|
4f3cb98e5e | ||
|
|
8c1f5efcab | ||
|
|
c92bdd8785 | ||
|
|
e09ef6b8bc | ||
|
|
f7677ed275 | ||
|
|
e5f719a33b | ||
|
|
79bd65034c | ||
|
|
fbb1923fad | ||
|
|
bf75c450b7 | ||
|
|
b2172c4b2e | ||
|
|
69ccd76679 | ||
|
|
8b54bb4d89 | ||
|
|
2595d81733 | ||
|
|
f9e05218ca | ||
|
|
2ddda5da89 | ||
|
|
dc80f0b222 | ||
|
|
5007a122b2 | ||
|
|
43f2321225 | ||
|
|
1362f92f2e | ||
|
|
445d2646a9 | ||
|
|
ae8d25faca | ||
|
|
9061c03b6d | ||
|
|
8174f5a988 | ||
|
|
03f7b551be | ||
|
|
80ad6572a3 | ||
|
|
c77f3da0ce | ||
|
|
c104647450 | ||
|
|
547ba73b82 | ||
|
|
3526fa27fd | ||
|
|
9eabdb64ff | ||
|
|
6f543eac9f | ||
|
|
64eca85876 | ||
|
|
152271851f | ||
|
|
0909be3aa8 | ||
|
|
274e623b50 | ||
|
|
2972f982e4 | ||
|
|
df8a62d018 | ||
|
|
fec5d59fb3 | ||
|
|
7285e44064 | ||
|
|
2ff54ae6b3 | ||
|
|
f74ac0fc3a | ||
|
|
26a6da27fa | ||
|
|
19abbfff96 | ||
|
|
8aa531c7fa | ||
|
|
21cf339a85 | ||
|
|
588cdacd49 | ||
|
|
0cce536fb2 | ||
|
|
b281ecd50a | ||
|
|
b267e34092 | ||
|
|
58fce0a37b | ||
|
|
f0458ebdb8 | ||
|
|
0a231c0783 | ||
|
|
7c1f90045e | ||
|
|
a5ea272936 | ||
|
|
715825eac3 | ||
|
|
1a97e82000 | ||
|
|
70d1abf81b | ||
|
|
1fd0fcddb2 | ||
|
|
ab4bbf2fb2 | ||
|
|
669e4d0297 | ||
|
|
f92875bc3e | ||
|
|
7f36259f88 | ||
|
|
2c28d9f560 | ||
|
|
c21b071e77 | ||
|
|
de197bd7cb | ||
|
|
bf9dd83c10 | ||
|
|
760fb2ca0e | ||
|
|
a8ccaca8ea | ||
|
|
32070e6bc0 | ||
|
|
f02f647237 | ||
|
|
96043a8f7e | ||
|
|
0bb8d8faf5 | ||
|
|
f5c09a3aba | ||
|
|
3227cc65d1 | ||
|
|
90ca2ae16b | ||
|
|
25e260bb3a | ||
|
|
feea8332d6 | ||
|
|
ffbdd7fcce | ||
|
|
b699cf8c48 | ||
|
|
2efd9bbac4 | ||
|
|
0ac3af8776 | ||
|
|
fed9f06c4e | ||
|
|
240f33a06f | ||
|
|
254aafb265 | ||
|
|
8bd82119be | ||
|
|
9a148bb9a3 | ||
|
|
7a4241e406 | ||
|
|
cb92fbe749 | ||
|
|
1d04074464 | ||
|
|
c4096b4731 | ||
|
|
178658bf9f | ||
|
|
d372eb1f0e | ||
|
|
ebe25fefd6 | ||
|
|
688ccf05cb | ||
|
|
9dc5615b9d | ||
|
|
696e2316a8 | ||
|
|
f2891b70d0 | ||
|
|
dcf370cb6e | ||
|
|
1b8eb85eeb | ||
|
|
cf3236ed27 | ||
|
|
6c86c7c4a9 | ||
|
|
9cc2cf3241 | ||
|
|
9eb4a4a481 | ||
|
|
8463b7ea59 | ||
|
|
faa185e37c | ||
|
|
53b3177ca5 | ||
|
|
76badfed63 | ||
|
|
3c1e31de3e | ||
|
|
d2c932d3ac | ||
|
|
5a569eb1b6 | ||
|
|
e5bd25c73f | ||
|
|
eb88474dd8 | ||
|
|
9fc0ca0a72 | ||
|
|
95b6bd5df6 | ||
|
|
f1311ad3de | ||
|
|
0310170869 | ||
|
|
b6d7e222c1 | ||
|
|
e71d9a89d2 | ||
|
|
74c662b63a | ||
|
|
91bdb9eb2d | ||
|
|
47f16505d2 | ||
|
|
e63986b534 | ||
|
|
cbde8548f4 | ||
|
|
7a3656aea2 | ||
|
|
3ba8b15f13 | ||
|
|
7727a792f2 | ||
|
|
ce175d7372 | ||
|
|
609b19b630 | ||
|
|
e3cb957a10 | ||
|
|
55a0178490 | ||
|
|
9a858b8d67 | ||
|
|
cbff32585d | ||
|
|
8fc28c34ce | ||
|
|
d72b9eadec | ||
|
|
3c5bf5b9d8 | ||
|
|
5a07e26405 | ||
|
|
cd66546e24 | ||
|
|
21a59a4a7c | ||
|
|
b5dbf8e43d | ||
|
|
54e50b8a6e | ||
|
|
b35dbb0420 | ||
|
|
63f6afd75b | ||
|
|
3e311a0092 | ||
|
|
69d3d3c15a | ||
|
|
33bc1a3b58 | ||
|
|
740dd928f7 | ||
|
|
757d012ab5 | ||
|
|
f64a87209d | ||
|
|
41df8ee4f5 | ||
|
|
6877d5f3b5 | ||
|
|
6d74d424d3 | ||
|
|
9ec4f7504b | ||
|
|
9166d56f17 | ||
|
|
80b90dd0d9 | ||
|
|
91907789af | ||
|
|
6845852e82 | ||
|
|
99af12af3f | ||
|
|
fd76ff60ac | ||
|
|
cc6bea8b90 | ||
|
|
c1d9e9a285 | ||
|
|
681141a526 | ||
|
|
c100541f07 | ||
|
|
d64f62c2ef | ||
|
|
e049441d93 | ||
|
|
a30b2f34eb | ||
|
|
2bf96ad244 | ||
|
|
a183827128 | ||
|
|
54dd1b3038 | ||
|
|
75d251b81a | ||
|
|
7a6d4666a2 | ||
|
|
d802db4de0 | ||
|
|
b103bb4c8b | ||
|
|
a9d16c40c7 | ||
|
|
98e3a26b2a | ||
|
|
f209a92b7e | ||
|
|
0edfc7fa49 | ||
|
|
cefe038a87 | ||
|
|
0858ee2f27 | ||
|
|
4d1f2ea522 | ||
|
|
6447a6020c | ||
|
|
b3bf21db56 | ||
|
|
674a6f96d3 | ||
|
|
79f8831738 | ||
|
|
224c900532 | ||
|
|
4f9f5f70e3 | ||
|
|
38db6e9366 | ||
|
|
d18c753b3c | ||
|
|
8fedbf87d9 | ||
|
|
d8a369e194 | ||
|
|
90af34bc83 | ||
|
|
c7857dc1d4 | ||
|
|
08e4dc2563 | ||
|
|
92447141d9 | ||
|
|
e0ed44388f | ||
|
|
16d0aa7b4d | ||
|
|
6037b6a5ab | ||
|
|
e1604b2b4a | ||
|
|
db23f51bc6 | ||
|
|
3c6750f37b | ||
|
|
df2ec585f1 | ||
|
|
250b2ca01a | ||
|
|
c2d5f7bf26 | ||
|
|
e223b4ac09 | ||
|
|
f072801f38 | ||
|
|
ededaaa874 | ||
|
|
b1f55e3ee5 | ||
|
|
51b95236f9 | ||
|
|
9123cfb5dd | ||
|
|
9018e9dd70 | ||
|
|
08ff1c1aa8 | ||
|
|
6134939882 | ||
|
|
7cb6427dea | ||
|
|
79b62497d1 | ||
|
|
0729ef7353 | ||
|
|
8f6788474b | ||
|
|
5c2926102b | ||
|
|
bff37075f6 | ||
|
|
c98ee98525 | ||
|
|
ecb430effe | ||
|
|
7ee7221af1 | ||
|
|
748fd3db88 | ||
|
|
cbff1b818c | ||
|
|
a885d2f240 | ||
|
|
b6247b71b5 | ||
|
|
3555c6173d | ||
|
|
3976962621 | ||
|
|
a54a27595b | ||
|
|
7283b9f6cf | ||
|
|
3dfc0a9679 | ||
|
|
6903c4605c | ||
|
|
5b3f708fcb | ||
|
|
b33ed9176f | ||
|
|
c48817f69b | ||
|
|
70dd3a16dc | ||
|
|
9a19fe1f50 | ||
|
|
3961f8e7a4 | ||
|
|
fc37b17b1f | ||
|
|
630bd3d789 | ||
|
|
5c4c0c0cba | ||
|
|
24c241d29b | ||
|
|
a3d760ff12 | ||
|
|
77a3dda59d | ||
|
|
f6daceb449 | ||
|
|
cfef34f7a6 | ||
|
|
c007b9e5bd | ||
|
|
b9f3518b33 | ||
|
|
ba07d9d5e3 | ||
|
|
90e5211128 | ||
|
|
c0d412a736 | ||
|
|
f9eb5edb96 | ||
|
|
ba8b80a163 | ||
|
|
3b90fa5c9b | ||
|
|
273b367f05 | ||
|
|
783acd712d | ||
|
|
748f0b2b5f | ||
|
|
9350e26e68 | ||
|
|
997f793af1 | ||
|
|
4d5f29c74c | ||
|
|
d070b8698d | ||
|
|
057d3e1810 | ||
|
|
d49af633f0 | ||
|
|
3191a9ba11 | ||
|
|
53e13fe1f1 | ||
|
|
59cb0cecb2 | ||
|
|
1c6846c4c2 | ||
|
|
b88e441a07 | ||
|
|
4f57d7116d | ||
|
|
422607df7c | ||
|
|
3f4b494c61 | ||
|
|
109dffb242 | ||
|
|
0e8ee051c6 | ||
|
|
5c545e67f3 | ||
|
|
2daf5e4296 | ||
|
|
d0c8dd78c2 | ||
|
|
21c3e9973a | ||
|
|
8e4d013154 | ||
|
|
37fb01b17d | ||
|
|
ac0a70b369 | ||
|
|
a4bc6f73d7 | ||
|
|
56ee8a5cc6 | ||
|
|
440c244cac | ||
|
|
655303f2f1 | ||
|
|
14e59706b7 | ||
|
|
d59e93d5e9 | ||
|
|
9e85408c7b | ||
|
|
225ae32e7a | ||
|
|
50ef18644b | ||
|
|
41608beb35 | ||
|
|
d9a8e421a4 | ||
|
|
d7cef744ec | ||
|
|
54cbf30c14 | ||
|
|
dfa3c6265c | ||
|
|
a7f52911e1 | ||
|
|
1e31614572 | ||
|
|
3b615b0f7a | ||
|
|
e184f5ab3a | ||
|
|
d0f82e6dcc | ||
|
|
49e1f9ea89 | ||
|
|
6731230d73 | ||
|
|
ec59d71e60 | ||
|
|
bdac541d1e | ||
|
|
061fa70907 | ||
|
|
48b5cfd085 | ||
|
|
a7609c97be | ||
|
|
c33feb6dc9 | ||
|
|
2c7deb41f6 | ||
|
|
8117d0adab | ||
|
|
01a3a6ab0d | ||
|
|
45a8098d3a | ||
|
|
60812ae041 | ||
|
|
635bec06cb | ||
|
|
0f58dfdea4 | ||
|
|
dd5fe334f3 | ||
|
|
e0c9d495ef | ||
|
|
2f34e6fd30 | ||
|
|
69aa35a51c | ||
|
|
5404a8fcd8 | ||
|
|
eb49936a60 | ||
|
|
ff9ea6c4b1 | ||
|
|
586b0a7047 | ||
|
|
84718d183a | ||
|
|
3099a2f53c | ||
|
|
ed010752dd | ||
|
|
f5be6177b2 | ||
|
|
89c6f24d48 | ||
|
|
f23856df8e | ||
|
|
1b7bc299f3 | ||
|
|
a291cc99cf | ||
|
|
389ac5e017 | ||
|
|
fc792a4be9 | ||
|
|
07501bef14 | ||
|
|
137ce05324 | ||
|
|
ada0b4f131 | ||
|
|
abe925e212 | ||
|
|
8fb44608bf | ||
|
|
153cd5bb44 | ||
|
|
669545f551 | ||
|
|
cfe2f3fe15 | ||
|
|
140d609e0c | ||
|
|
a32ad1a656 | ||
|
|
62ba69a29d | ||
|
|
9b0f2a16ca | ||
|
|
85e629e915 | ||
|
|
999a28062d | ||
|
|
ba3fea24f1 | ||
|
|
6b4a8d0b17 | ||
|
|
5ec75e38b9 | ||
|
|
ad042fdd68 | ||
|
|
35ad3146a8 | ||
|
|
e8343f2d87 | ||
|
|
1b1307d0d1 | ||
|
|
7a11be9f3f | ||
|
|
192ce958c3 | ||
|
|
c441681dc2 | ||
|
|
dd70d57b9b | ||
|
|
f12ea1bc02 | ||
|
|
fa76a331b0 | ||
|
|
d999d9876d | ||
|
|
578a5fb6a9 | ||
|
|
a8809bbd3e | ||
|
|
a478e44585 | ||
|
|
c0494b3558 | ||
|
|
7f1cd014f2 | ||
|
|
07b615e96e | ||
|
|
ab387a6120 | ||
|
|
ac79725923 | ||
|
|
8dd38318fc | ||
|
|
533c064269 | ||
|
|
5c3105b437 | ||
|
|
3c0d0dba49 | ||
|
|
12bbca95ec | ||
|
|
f6574978de | ||
|
|
8380895ae3 | ||
|
|
f018999da9 | ||
|
|
51a6b7d2b5 | ||
|
|
9bfe185a2e | ||
|
|
beeb7896e0 | ||
|
|
212460289b | ||
|
|
221fb17c5e | ||
|
|
488deb04a4 | ||
|
|
9d9eea9ac9 | ||
|
|
e7f0ffbf5d | ||
|
|
a09b018bd5 | ||
|
|
7eac4ee9fe | ||
|
|
17a5efb416 | ||
|
|
3e634aa7e4 | ||
|
|
5d3398aa8a | ||
|
|
76d929e177 | ||
|
|
be91af7551 | ||
|
|
c9011fc7e1 | ||
|
|
ff776b57bf | ||
|
|
3ee788dacc | ||
|
|
fef504f038 | ||
|
|
bbb5776763 | ||
|
|
e87bee9ccd | ||
|
|
69a338610a | ||
|
|
aa6394e94f | ||
|
|
ef409c6a24 | ||
|
|
da4167560f | ||
|
|
3488576bd8 | ||
|
|
619c72e566 | ||
|
|
a3ba41fce2 | ||
|
|
c935a604f8 | ||
|
|
e114f09f70 | ||
|
|
9b4d9452ba | ||
|
|
bbeed5b5d1 | ||
|
|
971ed2bbdf | ||
|
|
affc4e9a8f | ||
|
|
3db83b6824 | ||
|
|
9c8d707530 | ||
|
|
8f5f99c22a | ||
|
|
32254d3010 | ||
|
|
20f2875472 | ||
|
|
c360da4f35 | ||
|
|
bc76a032ba | ||
|
|
8e986584f4 | ||
|
|
4b68d30b0e | ||
|
|
b292192467 | ||
|
|
f172f7d4aa | ||
|
|
8e8b6be690 | ||
|
|
e8c6135a91 | ||
|
|
771cf41fea | ||
|
|
7ea17bb957 | ||
|
|
f8846f85a1 | ||
|
|
4c05ef0ba8 | ||
|
|
5438b64e32 | ||
|
|
248acf715e | ||
|
|
54ca0997ee | ||
|
|
b78076cac7 | ||
|
|
ba19d530ad | ||
|
|
47555602d7 | ||
|
|
6eb76c7c1a | ||
|
|
b32cc4b09d | ||
|
|
6e3dbb8d8b | ||
|
|
b66c093316 | ||
|
|
13d360030f | ||
|
|
66daebe88f | ||
|
|
4071ba29da | ||
|
|
21f9e2df40 | ||
|
|
80d326310e | ||
|
|
53fc705b13 | ||
|
|
d5af53888a | ||
|
|
a7a37249f7 | ||
|
|
6af6ff2a0a | ||
|
|
30ca282594 | ||
|
|
0fbc0475f3 | ||
|
|
e5e77381f0 | ||
|
|
066514e2a9 | ||
|
|
045a1737f8 |
23
.cursorrules
23
.cursorrules
@@ -1,23 +0,0 @@
|
||||
Hermes-Agent is an agent harness for LLMs.
|
||||
|
||||
When building, the tool functionality is in the tools/ directory, where each specific tool (or in some cases, tools that are built for the same execution category or api) are placed in a script each their own.
|
||||
|
||||
Each tool is then consolidated in the model_tools.py file in the repo root.
|
||||
|
||||
There is also a way to consolidate sets of tools in toolsets.py for the agent to use.
|
||||
|
||||
The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework.
|
||||
|
||||
Always ensure consistency between tools, the model_tools.py and toolsets.py when changing any of them, otherwise they could become desynced in a way that is detrimental to functionality.
|
||||
|
||||
The expected pathway for using API keys is to setup and place them in a .env file in the repo root.
|
||||
|
||||
Test scripts will be placed in tests/
|
||||
|
||||
The run_agent loop is setup to:
|
||||
- Process the enabled toolsets to provide to the model,
|
||||
- Pipe in a prompt or problem from the input to the agent,
|
||||
- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response,
|
||||
- Return that response.
|
||||
|
||||
There are additional caveats for logging, where we restructure the "tools" as a system prompt for storage later into a format that can be used and handled properly later.
|
||||
18
.editorconfig
Normal file
18
.editorconfig
Normal file
@@ -0,0 +1,18 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{yml,yaml,json,toml}]
|
||||
indent_size = 2
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
[Makefile]
|
||||
indent_style = tab
|
||||
266
.env.example
266
.env.example
@@ -1,49 +1,265 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
# Get API keys from the URLs listed below
|
||||
|
||||
# =============================================================================
|
||||
# REQUIRED API KEYS
|
||||
# LLM PROVIDER (OpenRouter)
|
||||
# =============================================================================
|
||||
# OpenRouter provides access to many models through one API
|
||||
# All LLM calls go through OpenRouter - no direct provider keys needed
|
||||
# Get your key at: https://openrouter.ai/keys
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Anthropic API Key - Main agent model
|
||||
# Get at: https://console.anthropic.com/
|
||||
ANTHROPIC_API_KEY=
|
||||
# Default model to use (OpenRouter format: provider/model)
|
||||
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
|
||||
LLM_MODEL=anthropic/claude-opus-4.6
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (z.ai / GLM)
|
||||
# =============================================================================
|
||||
# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
|
||||
# Get your key at: https://z.ai or https://open.bigmodel.cn
|
||||
GLM_API_KEY=
|
||||
# GLM_BASE_URL=https://api.z.ai/api/paas/v4 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Kimi / Moonshot)
|
||||
# =============================================================================
|
||||
# Kimi Code provides access to Moonshot AI coding models (kimi-k2.5, etc.)
|
||||
# Get your key at: https://platform.kimi.ai (Kimi Code console)
|
||||
# Keys prefixed sk-kimi- use the Kimi Code API (api.kimi.com) by default.
|
||||
# Legacy keys from platform.moonshot.ai need KIMI_BASE_URL override below.
|
||||
KIMI_API_KEY=
|
||||
# KIMI_BASE_URL=https://api.kimi.com/coding/v1 # Default for sk-kimi- keys
|
||||
# KIMI_BASE_URL=https://api.moonshot.ai/v1 # For legacy Moonshot keys
|
||||
# KIMI_BASE_URL=https://api.moonshot.cn/v1 # For Moonshot China keys
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (MiniMax)
|
||||
# =============================================================================
|
||||
# MiniMax provides access to MiniMax models (global endpoint)
|
||||
# Get your key at: https://www.minimax.io
|
||||
MINIMAX_API_KEY=
|
||||
# MINIMAX_BASE_URL=https://api.minimax.io/v1 # Override default base URL
|
||||
|
||||
# MiniMax China endpoint (for users in mainland China)
|
||||
MINIMAX_CN_API_KEY=
|
||||
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Firecrawl API Key - Web search, extract, and crawl
|
||||
# Get at: https://firecrawl.dev/
|
||||
FIRECRAWL_API_KEY=
|
||||
|
||||
# Nous Research API Key - Vision analysis and multi-model reasoning
|
||||
# Get at: https://inference-api.nousresearch.com/
|
||||
NOUS_API_KEY=
|
||||
|
||||
# Morph API Key - Terminal/command execution tools
|
||||
# Get at: https://morph.so/
|
||||
MORPH_API_KEY=
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# OpenAI API Key - Optional, for enhanced Hecate features
|
||||
# Get at: https://platform.openai.com/
|
||||
OPENAI_API_KEY=
|
||||
# Honcho - Cross-session AI-native user modeling (optional)
|
||||
# Builds a persistent understanding of the user across sessions and tools.
|
||||
# Get at: https://app.honcho.dev
|
||||
# Also requires ~/.honcho/config.json with enabled=true (see README).
|
||||
HONCHO_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL CONFIGURATION
|
||||
# TERMINAL TOOL CONFIGURATION (mini-swe-agent backend)
|
||||
# =============================================================================
|
||||
# Backend type: "local", "singularity", "docker", "modal", or "ssh"
|
||||
# Terminal backend is configured in ~/.hermes/config.yaml (terminal.backend).
|
||||
# Use 'hermes setup' or 'hermes config set terminal.backend docker' to change.
|
||||
# Supported: local, docker, singularity, modal, ssh
|
||||
#
|
||||
# Only override here if you need to force a backend without touching config.yaml:
|
||||
# TERMINAL_ENV=local
|
||||
|
||||
# Terminal Tool Settings
|
||||
HECATE_VM_LIFETIME_SECONDS=300
|
||||
HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt
|
||||
# Container images (for singularity/docker/modal backends)
|
||||
# TERMINAL_DOCKER_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
# TERMINAL_SINGULARITY_IMAGE=docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
# Debug Logging (set to "true" to enable, logs saved to ./logs/)
|
||||
|
||||
# Working directory for terminal commands
|
||||
# For local backend: "." means current directory (resolved automatically)
|
||||
# For remote backends (ssh/docker/modal/singularity): use an absolute path
|
||||
# INSIDE the target environment, or leave unset for the backend's default
|
||||
# (/root for modal, / for docker, ~ for ssh). Do NOT use a host-local path.
|
||||
# Usually managed by config.yaml (terminal.cwd) — uncomment to override
|
||||
# TERMINAL_CWD=.
|
||||
|
||||
# Default command timeout in seconds
|
||||
TERMINAL_TIMEOUT=60
|
||||
|
||||
# Cleanup inactive environments after this many seconds
|
||||
TERMINAL_LIFETIME_SECONDS=300
|
||||
|
||||
# =============================================================================
|
||||
# SSH REMOTE EXECUTION (for TERMINAL_ENV=ssh)
|
||||
# =============================================================================
|
||||
# Run terminal commands on a remote server via SSH.
|
||||
# Agent code stays on your machine, commands execute remotely.
|
||||
#
|
||||
# SECURITY BENEFITS:
|
||||
# - Agent cannot read your .env file (API keys protected)
|
||||
# - Agent cannot modify its own code
|
||||
# - Remote server acts as isolated sandbox
|
||||
# - Can safely configure passwordless sudo on remote
|
||||
#
|
||||
# TERMINAL_SSH_HOST=192.168.1.100
|
||||
# TERMINAL_SSH_USER=agent
|
||||
# TERMINAL_SSH_PORT=22
|
||||
# TERMINAL_SSH_KEY=~/.ssh/id_rsa
|
||||
|
||||
# =============================================================================
|
||||
# SUDO SUPPORT (works with ALL terminal backends)
|
||||
# =============================================================================
|
||||
# If set, enables sudo commands by piping password via `sudo -S`.
|
||||
# Works with: local, docker, singularity, modal, and ssh backends.
|
||||
#
|
||||
# SECURITY WARNING: Password stored in plaintext. Only use on trusted machines.
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - For SSH backend: Configure passwordless sudo on the remote server
|
||||
# - For containers: Run as root inside the container (no sudo needed)
|
||||
# - For local: Configure /etc/sudoers for specific commands
|
||||
# - For CLI: Leave unset - you'll be prompted interactively with 45s timeout
|
||||
#
|
||||
# SUDO_PASSWORD=your_password_here
|
||||
|
||||
# =============================================================================
|
||||
# MODAL CLOUD BACKEND (Optional - for TERMINAL_ENV=modal)
|
||||
# =============================================================================
|
||||
# Modal uses CLI authentication, not environment variables.
|
||||
# Run: pip install modal && modal setup
|
||||
# This will authenticate via browser and store credentials locally.
|
||||
# No API key needed in .env - Modal handles auth automatically.
|
||||
|
||||
# =============================================================================
|
||||
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
|
||||
# =============================================================================
|
||||
# Browser automation requires Browserbase cloud service for remote browser execution.
|
||||
# This allows the agent to navigate websites, fill forms, and extract information.
|
||||
#
|
||||
# STEALTH MODES:
|
||||
# - Basic Stealth: ALWAYS active (random fingerprints, auto CAPTCHA solving)
|
||||
# - Advanced Stealth: Requires BROWSERBASE_ADVANCED_STEALTH=true (Scale Plan only)
|
||||
|
||||
# Browserbase API Key - Cloud browser execution
|
||||
# Get at: https://browserbase.com/
|
||||
BROWSERBASE_API_KEY=
|
||||
|
||||
# Browserbase Project ID - From your Browserbase dashboard
|
||||
BROWSERBASE_PROJECT_ID=
|
||||
|
||||
# Enable residential proxies for better CAPTCHA solving (default: true)
|
||||
# Routes traffic through residential IPs, significantly improves success rate
|
||||
BROWSERBASE_PROXIES=true
|
||||
|
||||
# Enable advanced stealth mode (default: false, requires Scale Plan)
|
||||
# Uses custom Chromium build to avoid bot detection altogether
|
||||
BROWSERBASE_ADVANCED_STEALTH=false
|
||||
|
||||
# Browser session timeout in seconds (default: 300)
|
||||
# Sessions are cleaned up after this duration of inactivity
|
||||
BROWSER_SESSION_TIMEOUT=300
|
||||
|
||||
# Browser inactivity timeout - auto-cleanup inactive sessions (default: 120 = 2 min)
|
||||
# Browser sessions are automatically closed after this period of no activity
|
||||
BROWSER_INACTIVITY_TIMEOUT=120
|
||||
|
||||
# =============================================================================
|
||||
# SESSION LOGGING
|
||||
# =============================================================================
|
||||
# Session trajectories are automatically saved to logs/ directory
|
||||
# Format: logs/session_YYYYMMDD_HHMMSS_UUID.json
|
||||
# Contains full conversation history in trajectory format for debugging/replay
|
||||
|
||||
# =============================================================================
|
||||
# VOICE TRANSCRIPTION & OPENAI TTS
|
||||
# =============================================================================
|
||||
# Required for voice message transcription (Whisper) and OpenAI TTS voices.
|
||||
# Uses OpenAI's API directly (not via OpenRouter).
|
||||
# Named VOICE_TOOLS_OPENAI_KEY to avoid interference with OpenRouter.
|
||||
# Get at: https://platform.openai.com/api-keys
|
||||
VOICE_TOOLS_OPENAI_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# SLACK INTEGRATION
|
||||
# =============================================================================
|
||||
# Slack Bot Token - From Slack App settings (OAuth & Permissions)
|
||||
# Get at: https://api.slack.com/apps
|
||||
# SLACK_BOT_TOKEN=xoxb-...
|
||||
|
||||
# Slack App Token - For Socket Mode (App-Level Tokens in Slack App settings)
|
||||
# SLACK_APP_TOKEN=xapp-...
|
||||
|
||||
# Slack allowed users (comma-separated Slack user IDs)
|
||||
# SLACK_ALLOWED_USERS=
|
||||
|
||||
# WhatsApp (built-in Baileys bridge — run `hermes whatsapp` to pair)
|
||||
# WHATSAPP_ENABLED=false
|
||||
# WHATSAPP_ALLOWED_USERS=15551234567
|
||||
|
||||
# Gateway-wide: allow ALL users without an allowlist (default: false = deny)
|
||||
# Only set to true if you intentionally want open access.
|
||||
# GATEWAY_ALLOW_ALL_USERS=false
|
||||
|
||||
# =============================================================================
|
||||
# RESPONSE PACING
|
||||
# =============================================================================
|
||||
# Human-like delays between message chunks on messaging platforms.
|
||||
# Makes the bot feel less robotic.
|
||||
# HERMES_HUMAN_DELAY_MODE=off # off | natural | custom
|
||||
# HERMES_HUMAN_DELAY_MIN_MS=800 # Min delay in ms (custom mode)
|
||||
# HERMES_HUMAN_DELAY_MAX_MS=2500 # Max delay in ms (custom mode)
|
||||
|
||||
# =============================================================================
|
||||
# DEBUG OPTIONS
|
||||
# =============================================================================
|
||||
WEB_TOOLS_DEBUG=false
|
||||
VISION_TOOLS_DEBUG=false
|
||||
MOA_TOOLS_DEBUG=false
|
||||
IMAGE_TOOLS_DEBUG=false
|
||||
|
||||
# =============================================================================
|
||||
# CONTEXT COMPRESSION (Auto-shrinks long conversations)
|
||||
# =============================================================================
|
||||
# When conversation approaches model's context limit, middle turns are
|
||||
# automatically summarized to free up space.
|
||||
#
|
||||
# Context compression is configured in ~/.hermes/config.yaml under compression:
|
||||
# CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true)
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# Model is set via compression.summary_model in config.yaml (default: google/gemini-3-flash-preview)
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
||||
# =============================================================================
|
||||
# SKILLS HUB (GitHub integration for skill search/install/publish)
|
||||
# =============================================================================
|
||||
|
||||
# GitHub Personal Access Token — for higher API rate limits on skill search/install
|
||||
# Get at: https://github.com/settings/tokens (Fine-grained recommended)
|
||||
# GITHUB_TOKEN=ghp_xxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
# GitHub App credentials (optional — for bot identity on PRs)
|
||||
# GITHUB_APP_ID=
|
||||
# GITHUB_APP_PRIVATE_KEY_PATH=
|
||||
# GITHUB_APP_INSTALLATION_ID=
|
||||
|
||||
144
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
144
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
name: "🐛 Bug Report"
|
||||
description: Report a bug — something that's broken, crashes, or behaves incorrectly.
|
||||
title: "[Bug]: "
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for reporting a bug! Please fill out the sections below so we can reproduce and fix it quickly.
|
||||
|
||||
**Before submitting**, please:
|
||||
- [ ] Search [existing issues](https://github.com/NousResearch/hermes-agent/issues) to avoid duplicates
|
||||
- [ ] Update to the latest version (`hermes update`) and confirm the bug still exists
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Bug Description
|
||||
description: A clear description of what's broken. Include error messages, tracebacks, or screenshots if relevant.
|
||||
placeholder: |
|
||||
What happened? What did you expect to happen instead?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Minimal steps to trigger the bug. The more specific, the faster we can fix it.
|
||||
placeholder: |
|
||||
1. Run `hermes chat`
|
||||
2. Send the message "..."
|
||||
3. Agent calls tool X
|
||||
4. Error appears: ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What should have happened instead?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: actual
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What actually happened? Include full error output if available.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: component
|
||||
attributes:
|
||||
label: Affected Component
|
||||
description: Which part of Hermes is affected?
|
||||
multiple: true
|
||||
options:
|
||||
- CLI (interactive chat)
|
||||
- Gateway (Telegram/Discord/Slack/WhatsApp)
|
||||
- Setup / Installation
|
||||
- Tools (terminal, file ops, web, code execution, etc.)
|
||||
- Skills (skill loading, skill hub, skill guard)
|
||||
- Agent Core (conversation loop, context compression, memory)
|
||||
- Configuration (config.yaml, .env, hermes setup)
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: platform
|
||||
attributes:
|
||||
label: Messaging Platform (if gateway-related)
|
||||
description: Which platform adapter is affected?
|
||||
multiple: true
|
||||
options:
|
||||
- N/A (CLI only)
|
||||
- Telegram
|
||||
- Discord
|
||||
- Slack
|
||||
- WhatsApp
|
||||
|
||||
- type: input
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
description: e.g. Ubuntu 24.04, macOS 15.2, Windows 11
|
||||
placeholder: Ubuntu 24.04
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: python-version
|
||||
attributes:
|
||||
label: Python Version
|
||||
description: Output of `python --version`
|
||||
placeholder: "3.11.9"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: hermes-version
|
||||
attributes:
|
||||
label: Hermes Version
|
||||
description: Output of `hermes version`
|
||||
placeholder: "2.1.0"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Relevant Logs / Traceback
|
||||
description: Paste any error output, traceback, or log messages. This will be auto-formatted as code.
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: root-cause
|
||||
attributes:
|
||||
label: Root Cause Analysis (optional)
|
||||
description: |
|
||||
If you've dug into the code and identified the root cause, share it here.
|
||||
Include file paths, line numbers, and code snippets if possible. This massively speeds up fixes.
|
||||
placeholder: |
|
||||
The bug is in `gateway/run.py` line 949. `len(history)` counts session_meta entries
|
||||
but `agent_messages` was built from filtered history...
|
||||
|
||||
- type: textarea
|
||||
id: proposed-fix
|
||||
attributes:
|
||||
label: Proposed Fix (optional)
|
||||
description: If you have a fix in mind (or a PR ready), describe it here.
|
||||
placeholder: |
|
||||
Replace `.get()` with `.pop()` on line 289 of `gateway/platforms/base.py`
|
||||
to actually clear the pending message after retrieval.
|
||||
|
||||
- type: checkboxes
|
||||
id: pr-ready
|
||||
attributes:
|
||||
label: Are you willing to submit a PR for this?
|
||||
options:
|
||||
- label: I'd like to fix this myself and submit a PR
|
||||
11
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
11
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: 💬 Nous Research Discord
|
||||
url: https://discord.gg/NousResearch
|
||||
about: For quick questions, showcasing projects, sharing skills, and community chat.
|
||||
- name: 📖 Documentation
|
||||
url: https://github.com/NousResearch/hermes-agent/blob/main/README.md
|
||||
about: Check the README and docs before opening an issue.
|
||||
- name: 🤝 Contributing Guide
|
||||
url: https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md
|
||||
about: Read this before submitting a PR.
|
||||
73
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
73
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: "✨ Feature Request"
|
||||
description: Suggest a new feature or improvement.
|
||||
title: "[Feature]: "
|
||||
labels: ["enhancement"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for the suggestion! Before submitting, please consider:
|
||||
|
||||
- **Is this a new skill?** Most capabilities should be [skills, not tools](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#should-it-be-a-skill-or-a-tool). If it's a specialized integration (crypto, NFT, niche SaaS), it belongs on the Skills Hub, not bundled.
|
||||
- **Search [existing issues](https://github.com/NousResearch/hermes-agent/issues)** — someone may have already proposed this.
|
||||
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: Problem or Use Case
|
||||
description: What problem does this solve? What are you trying to do that you can't today?
|
||||
placeholder: |
|
||||
I'm trying to use Hermes with [provider/platform/workflow] but currently
|
||||
there's no way to...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: How do you think this should work? Be as specific as you can — CLI flags, config options, UI behavior.
|
||||
placeholder: |
|
||||
Add a `--foo` flag to `hermes chat` that enables...
|
||||
Or: Add a config key `bar.baz` that controls...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: What other approaches did you consider? Why is the proposed solution better?
|
||||
|
||||
- type: dropdown
|
||||
id: type
|
||||
attributes:
|
||||
label: Feature Type
|
||||
options:
|
||||
- New tool
|
||||
- New bundled skill
|
||||
- CLI improvement
|
||||
- Gateway / messaging improvement
|
||||
- Configuration option
|
||||
- Performance / reliability
|
||||
- Developer experience (tests, docs, CI)
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: scope
|
||||
attributes:
|
||||
label: Scope
|
||||
description: How big is this change?
|
||||
options:
|
||||
- Small (single file, < 50 lines)
|
||||
- Medium (few files, < 300 lines)
|
||||
- Large (new module or significant refactor)
|
||||
|
||||
- type: checkboxes
|
||||
id: pr-ready
|
||||
attributes:
|
||||
label: Contribution
|
||||
options:
|
||||
- label: I'd like to implement this myself and submit a PR
|
||||
100
.github/ISSUE_TEMPLATE/setup_help.yml
vendored
Normal file
100
.github/ISSUE_TEMPLATE/setup_help.yml
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
name: "🔧 Setup / Installation Help"
|
||||
description: Having trouble installing or configuring Hermes? Ask here.
|
||||
title: "[Setup]: "
|
||||
labels: ["setup"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Sorry you're having trouble! Please fill out the details below so we can help.
|
||||
|
||||
**Quick checks first:**
|
||||
- Run `hermes doctor` and include the output below
|
||||
- Try `hermes update` to get the latest version
|
||||
- Check the [README troubleshooting section](https://github.com/NousResearch/hermes-agent#troubleshooting)
|
||||
- For general questions, consider the [Nous Research Discord](https://discord.gg/NousResearch) for faster help
|
||||
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What's Going Wrong?
|
||||
description: Describe what you're trying to do and where it fails.
|
||||
placeholder: |
|
||||
I ran `hermes setup` and selected Nous Portal, but when I try to
|
||||
start the gateway I get...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps Taken
|
||||
description: What did you do? Include the exact commands you ran.
|
||||
placeholder: |
|
||||
1. Ran the install script: `curl -fsSL ... | bash`
|
||||
2. Ran `hermes setup` and chose "Quick setup"
|
||||
3. Selected OpenRouter, entered API key
|
||||
4. Ran `hermes chat` and got error...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: install-method
|
||||
attributes:
|
||||
label: Installation Method
|
||||
options:
|
||||
- Install script (curl | bash)
|
||||
- Manual clone + pip/uv install
|
||||
- PowerShell installer (Windows)
|
||||
- Docker
|
||||
- Other
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
placeholder: Ubuntu 24.04 / macOS 15.2 / Windows 11
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: python-version
|
||||
attributes:
|
||||
label: Python Version
|
||||
description: Output of `python --version` (or `python3 --version`)
|
||||
placeholder: "3.11.9"
|
||||
|
||||
- type: input
|
||||
id: hermes-version
|
||||
attributes:
|
||||
label: Hermes Version
|
||||
description: Output of `hermes version` (if install got that far)
|
||||
placeholder: "2.1.0"
|
||||
|
||||
- type: textarea
|
||||
id: doctor-output
|
||||
attributes:
|
||||
label: Output of `hermes doctor`
|
||||
description: Run `hermes doctor` and paste the full output. This will be auto-formatted.
|
||||
render: shell
|
||||
|
||||
- type: textarea
|
||||
id: error-output
|
||||
attributes:
|
||||
label: Full Error Output
|
||||
description: Paste the complete error message or traceback. This will be auto-formatted.
|
||||
render: shell
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: tried
|
||||
attributes:
|
||||
label: What I've Already Tried
|
||||
description: List any fixes or workarounds you've already attempted.
|
||||
placeholder: |
|
||||
- Ran `hermes update`
|
||||
- Tried reinstalling with `pip install -e ".[all]"`
|
||||
- Checked that OPENROUTER_API_KEY is set in ~/.hermes/.env
|
||||
75
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
75
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!-- Describe the change clearly. What problem does it solve? Why is this approach the right one? -->
|
||||
|
||||
|
||||
|
||||
## Related Issue
|
||||
|
||||
<!-- Link the issue this PR addresses. If no issue exists, consider creating one first. -->
|
||||
|
||||
Fixes #
|
||||
|
||||
## Type of Change
|
||||
|
||||
<!-- Check the one that applies. -->
|
||||
|
||||
- [ ] 🐛 Bug fix (non-breaking change that fixes an issue)
|
||||
- [ ] ✨ New feature (non-breaking change that adds functionality)
|
||||
- [ ] 🔒 Security fix
|
||||
- [ ] 📝 Documentation update
|
||||
- [ ] ✅ Tests (adding or improving test coverage)
|
||||
- [ ] ♻️ Refactor (no behavior change)
|
||||
- [ ] 🎯 New skill (bundled or hub)
|
||||
|
||||
## Changes Made
|
||||
|
||||
<!-- List the specific changes. Include file paths for code changes. -->
|
||||
|
||||
-
|
||||
|
||||
## How to Test
|
||||
|
||||
<!-- Steps to verify this change works. For bugs: reproduction steps + proof that the fix works. -->
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Complete these before requesting review. -->
|
||||
|
||||
### Code
|
||||
|
||||
- [ ] I've read the [Contributing Guide](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md)
|
||||
- [ ] My commit messages follow [Conventional Commits](https://www.conventionalcommits.org/) (`fix(scope):`, `feat(scope):`, etc.)
|
||||
- [ ] I searched for [existing PRs](https://github.com/NousResearch/hermes-agent/pulls) to make sure this isn't a duplicate
|
||||
- [ ] My PR contains **only** changes related to this fix/feature (no unrelated commits)
|
||||
- [ ] I've run `make check` (lint + test) and all checks pass
|
||||
- [ ] I've added tests for my changes (required for bug fixes, strongly encouraged for features)
|
||||
- [ ] I've tested on my platform: <!-- e.g. Ubuntu 24.04, macOS 15.2, Windows 11 -->
|
||||
|
||||
### Documentation & Housekeeping
|
||||
|
||||
<!-- Check all that apply. It's OK to check "N/A" if a category doesn't apply to your change. -->
|
||||
|
||||
- [ ] I've updated relevant documentation (README, `docs/`, docstrings) — or N/A
|
||||
- [ ] I've updated `cli-config.yaml.example` if I added/changed config keys — or N/A
|
||||
- [ ] I've updated `CONTRIBUTING.md` or `AGENTS.md` if I changed architecture or workflows — or N/A
|
||||
- [ ] I've considered cross-platform impact (Windows, macOS) per the [compatibility guide](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#cross-platform-compatibility) — or N/A
|
||||
- [ ] I've updated tool descriptions/schemas if I changed tool behavior — or N/A
|
||||
|
||||
## For New Skills
|
||||
|
||||
<!-- Only fill this out if you're adding a skill. Delete this section otherwise. -->
|
||||
|
||||
- [ ] This skill is **broadly useful** to most users (if bundled) — see [Contributing Guide](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#should-the-skill-be-bundled)
|
||||
- [ ] SKILL.md follows the [standard format](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#skillmd-format) (frontmatter, trigger conditions, steps, pitfalls)
|
||||
- [ ] No external dependencies that aren't already available (prefer stdlib, curl, existing Hermes tools)
|
||||
- [ ] I've tested the skill end-to-end: `hermes --toolsets skills -q "Use the X skill to do Y"`
|
||||
|
||||
## Screenshots / Logs
|
||||
|
||||
<!-- If applicable, add screenshots or log output showing the fix/feature in action. -->
|
||||
|
||||
60
.github/workflows/deploy-site.yml
vendored
Normal file
60
.github/workflows/deploy-site.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Deploy Site
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'website/**'
|
||||
- 'landingpage/**'
|
||||
- '.github/workflows/deploy-site.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
concurrency:
|
||||
group: pages
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build-and-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deploy.outputs.page_url }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: npm
|
||||
cache-dependency-path: website/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
working-directory: website
|
||||
|
||||
- name: Build Docusaurus
|
||||
run: npm run build
|
||||
working-directory: website
|
||||
|
||||
- name: Stage deployment
|
||||
run: |
|
||||
mkdir -p _site/docs
|
||||
# Landing page at root
|
||||
cp -r landingpage/* _site/
|
||||
# Docusaurus at /docs/
|
||||
cp -r website/build/* _site/docs/
|
||||
# CNAME so GitHub Pages keeps the custom domain between deploys
|
||||
echo "hermes-agent.nousresearch.com" > _site/CNAME
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v3
|
||||
with:
|
||||
path: _site
|
||||
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deploy
|
||||
uses: actions/deploy-pages@v4
|
||||
47
.github/workflows/tests.yml
vendored
Normal file
47
.github/workflows/tests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
SRC: >-
|
||||
run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 3
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
- run: uvx ruff check $SRC
|
||||
- run: uvx ruff format --check $SRC
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
- run: uv python install 3.11
|
||||
- run: |
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
- run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
env:
|
||||
OPENROUTER_API_KEY: ""
|
||||
OPENAI_API_KEY: ""
|
||||
NOUS_API_KEY: ""
|
||||
59
.gitignore
vendored
59
.gitignore
vendored
@@ -1,32 +1,53 @@
|
||||
/venv/
|
||||
/_pycache/
|
||||
hecate/
|
||||
hecate-lib/
|
||||
*.pyc*
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# Tools
|
||||
.ruff_cache/
|
||||
.mypy_cache/
|
||||
.pytest_cache/
|
||||
|
||||
# Editors
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Secrets & config
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
.env.development
|
||||
.env.test
|
||||
export*
|
||||
__pycache__/model_tools.cpython-310.pyc
|
||||
__pycache__/web_tools.cpython-310.pyc
|
||||
.env.*.local
|
||||
*.pem
|
||||
*.ppk
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
|
||||
# Project-specific
|
||||
logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
tmp/
|
||||
wandb/
|
||||
images/
|
||||
browser-use/
|
||||
agent-browser/
|
||||
source-data/
|
||||
testlogs/
|
||||
ignored/
|
||||
.worktrees/
|
||||
temp_vision_images/
|
||||
cli-config.yaml
|
||||
skills/.hub/
|
||||
hermes-*/*
|
||||
examples/
|
||||
export*
|
||||
privvy*
|
||||
run_datagen_*.sh
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
run_datagen_kimik2-thinking.sh
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
|
||||
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
[submodule "mini-swe-agent"]
|
||||
path = mini-swe-agent
|
||||
url = https://github.com/SWE-agent/mini-swe-agent
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
18
.pre-commit-config.yaml
Normal file
18
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: check-yaml
|
||||
args: [--allow-multiple-documents]
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
255
AGENTS.md
Normal file
255
AGENTS.md
Normal file
@@ -0,0 +1,255 @@
|
||||
# Hermes Agent - Development Guide
|
||||
|
||||
Instructions for AI coding assistants and developers working on the hermes-agent codebase.
|
||||
|
||||
## Development Environment
|
||||
|
||||
```bash
|
||||
make setup # First time: creates .venv, installs deps, sets up pre-commit
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
hermes-agent/
|
||||
├── run_agent.py # AIAgent class — core conversation loop
|
||||
├── model_tools.py # Tool orchestration, _discover_tools(), handle_function_call()
|
||||
├── toolsets.py # Toolset definitions, _HERMES_CORE_TOOLS list
|
||||
├── cli.py # HermesCLI class — interactive CLI orchestrator
|
||||
├── hermes_state.py # SessionDB — SQLite session store (FTS5 search)
|
||||
├── agent/ # Agent internals
|
||||
│ ├── prompt_builder.py # System prompt assembly
|
||||
│ ├── context_compressor.py # Auto context compression
|
||||
│ ├── prompt_caching.py # Anthropic prompt caching
|
||||
│ ├── auxiliary_client.py # Auxiliary LLM client (vision, summarization)
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
│ ├── display.py # KawaiiSpinner, tool preview formatting
|
||||
│ ├── skill_commands.py # Skill slash commands (shared CLI/gateway)
|
||||
│ └── trajectory.py # Trajectory saving helpers
|
||||
├── hermes_cli/ # CLI subcommands and setup
|
||||
│ ├── main.py # Entry point — all `hermes` subcommands
|
||||
│ ├── config.py # DEFAULT_CONFIG, OPTIONAL_ENV_VARS, migration
|
||||
│ ├── commands.py # Slash command definitions + SlashCommandCompleter
|
||||
│ ├── callbacks.py # Terminal callbacks (clarify, sudo, approval)
|
||||
│ └── setup.py # Interactive setup wizard
|
||||
├── tools/ # Tool implementations (one file per tool)
|
||||
│ ├── registry.py # Central tool registry (schemas, handlers, dispatch)
|
||||
│ ├── approval.py # Dangerous command detection
|
||||
│ ├── terminal_tool.py # Terminal orchestration
|
||||
│ ├── process_registry.py # Background process management
|
||||
│ ├── file_tools.py # File read/write/search/patch
|
||||
│ ├── web_tools.py # Firecrawl search/extract
|
||||
│ ├── browser_tool.py # Browserbase browser automation
|
||||
│ ├── code_execution_tool.py # execute_code sandbox
|
||||
│ ├── delegate_tool.py # Subagent delegation
|
||||
│ ├── mcp_tool.py # MCP client (~1050 lines)
|
||||
│ └── environments/ # Terminal backends (local, docker, ssh, modal, daytona, singularity)
|
||||
├── gateway/ # Messaging platform gateway
|
||||
│ ├── run.py # Main loop, slash commands, message dispatch
|
||||
│ ├── session.py # SessionStore — conversation persistence
|
||||
│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
├── tests/ # Pytest suite (~2500+ tests)
|
||||
└── batch_runner.py # Parallel batch processing
|
||||
```
|
||||
|
||||
**User config:** `~/.hermes/config.yaml` (settings), `~/.hermes/.env` (API keys)
|
||||
|
||||
## File Dependency Chain
|
||||
|
||||
```
|
||||
tools/registry.py (no deps — imported by all tool files)
|
||||
↑
|
||||
tools/*.py (each calls registry.register() at import time)
|
||||
↑
|
||||
model_tools.py (imports tools/registry + triggers tool discovery)
|
||||
↑
|
||||
run_agent.py, cli.py, batch_runner.py, environments/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## AIAgent Class (run_agent.py)
|
||||
|
||||
```python
|
||||
class AIAgent:
|
||||
def __init__(self,
|
||||
model: str = "anthropic/claude-opus-4.6",
|
||||
max_iterations: int = 90,
|
||||
enabled_toolsets: list = None,
|
||||
disabled_toolsets: list = None,
|
||||
quiet_mode: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
platform: str = None, # "cli", "telegram", etc.
|
||||
session_id: str = None,
|
||||
skip_context_files: bool = False,
|
||||
skip_memory: bool = False,
|
||||
# ... plus provider, api_mode, callbacks, routing params
|
||||
): ...
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
"""Simple interface — returns final response string."""
|
||||
|
||||
def run_conversation(self, user_message: str, system_message: str = None,
|
||||
conversation_history: list = None, task_id: str = None) -> dict:
|
||||
"""Full interface — returns dict with final_response + messages."""
|
||||
```
|
||||
|
||||
### Agent Loop
|
||||
|
||||
The core loop is inside `run_conversation()` — entirely synchronous:
|
||||
|
||||
```python
|
||||
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
|
||||
response = client.chat.completions.create(model=model, messages=messages, tools=tool_schemas)
|
||||
if response.tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
result = handle_function_call(tool_call.name, tool_call.args, task_id)
|
||||
messages.append(tool_result_message(result))
|
||||
api_call_count += 1
|
||||
else:
|
||||
return response.content
|
||||
```
|
||||
|
||||
Messages follow OpenAI format: `{"role": "system/user/assistant/tool", ...}`. Reasoning content is stored in `assistant_msg["reasoning"]`.
|
||||
|
||||
---
|
||||
|
||||
## CLI Architecture (cli.py)
|
||||
|
||||
- **Rich** for banner/panels, **prompt_toolkit** for input with autocomplete
|
||||
- **KawaiiSpinner** (`agent/display.py`) — animated faces during API calls, `┊` activity feed for tool results
|
||||
- `load_cli_config()` in cli.py merges hardcoded defaults + user config YAML
|
||||
- `process_command()` is a method on `HermesCLI` (not in commands.py)
|
||||
- Skill slash commands: `agent/skill_commands.py` scans `~/.hermes/skills/`, injects as **user message** (not system prompt) to preserve prompt caching
|
||||
|
||||
### Adding CLI Commands
|
||||
|
||||
1. Add to `COMMANDS` dict in `hermes_cli/commands.py`
|
||||
2. Add handler in `HermesCLI.process_command()` in `cli.py`
|
||||
3. For persistent settings, use `save_config_value()` in `cli.py`
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Requires changes in **3 files**:
|
||||
|
||||
**1. Create `tools/your_tool.py`:**
|
||||
```python
|
||||
import json, os
|
||||
from tools.registry import registry
|
||||
|
||||
def check_requirements() -> bool:
|
||||
return bool(os.getenv("EXAMPLE_API_KEY"))
|
||||
|
||||
def example_tool(param: str, task_id: str = None) -> str:
|
||||
return json.dumps({"success": True, "data": "..."})
|
||||
|
||||
registry.register(
|
||||
name="example_tool",
|
||||
toolset="example",
|
||||
schema={"name": "example_tool", "description": "...", "parameters": {...}},
|
||||
handler=lambda args, **kw: example_tool(param=args.get("param", ""), task_id=kw.get("task_id")),
|
||||
check_fn=check_requirements,
|
||||
requires_env=["EXAMPLE_API_KEY"],
|
||||
)
|
||||
```
|
||||
|
||||
**2. Add import** in `model_tools.py` `_discover_tools()` list.
|
||||
|
||||
**3. Add to `toolsets.py`** — either `_HERMES_CORE_TOOLS` (all platforms) or a new toolset.
|
||||
|
||||
The registry handles schema collection, dispatch, availability checking, and error wrapping. All handlers MUST return a JSON string.
|
||||
|
||||
**Agent-level tools** (todo, memory): intercepted by `run_agent.py` before `handle_function_call()`. See `todo_tool.py` for the pattern.
|
||||
|
||||
---
|
||||
|
||||
## Adding Configuration
|
||||
|
||||
### config.yaml options:
|
||||
1. Add to `DEFAULT_CONFIG` in `hermes_cli/config.py`
|
||||
2. Bump `_config_version` (currently 5) to trigger migration for existing users
|
||||
|
||||
### .env variables:
|
||||
1. Add to `OPTIONAL_ENV_VARS` in `hermes_cli/config.py` with metadata:
|
||||
```python
|
||||
"NEW_API_KEY": {
|
||||
"description": "What it's for",
|
||||
"prompt": "Display name",
|
||||
"url": "https://...",
|
||||
"password": True,
|
||||
"category": "tool", # provider, tool, messaging, setting
|
||||
},
|
||||
```
|
||||
|
||||
### Config loaders (two separate systems):
|
||||
|
||||
| Loader | Used by | Location |
|
||||
|--------|---------|----------|
|
||||
| `load_cli_config()` | CLI mode | `cli.py` |
|
||||
| `load_config()` | `hermes tools`, `hermes setup` | `hermes_cli/config.py` |
|
||||
| Direct YAML load | Gateway | `gateway/run.py` |
|
||||
|
||||
---
|
||||
|
||||
## Important Policies
|
||||
|
||||
### Prompt Caching Must Not Break
|
||||
|
||||
Hermes-Agent ensures caching remains valid throughout a conversation. **Do NOT implement changes that would:**
|
||||
- Alter past context mid-conversation
|
||||
- Change toolsets mid-conversation
|
||||
- Reload memories or rebuild system prompts mid-conversation
|
||||
|
||||
Cache-breaking forces dramatically higher costs. The ONLY time we alter context is during context compression.
|
||||
|
||||
### Working Directory Behavior
|
||||
- **CLI**: Uses current directory (`.` → `os.getcwd()`)
|
||||
- **Messaging**: Uses `MESSAGING_CWD` env var (default: home directory)
|
||||
|
||||
---
|
||||
|
||||
## Known Pitfalls
|
||||
|
||||
### DO NOT use `simple_term_menu` for interactive menus
|
||||
Rendering bugs in tmux/iTerm2 — ghosting on scroll. Use `curses` (stdlib) instead. See `hermes_cli/tools_config.py` for the pattern.
|
||||
|
||||
### DO NOT use `\033[K` (ANSI erase-to-EOL) in spinner/display code
|
||||
Leaks as literal `?[K` text under `prompt_toolkit`'s `patch_stdout`. Use space-padding: `f"\r{line}{' ' * pad}"`.
|
||||
|
||||
### `_last_resolved_tool_names` is a process-global in `model_tools.py`
|
||||
When subagents overwrite this global, `execute_code` calls after delegation may fail with missing tool imports. Known bug.
|
||||
|
||||
### Tests must not write to `~/.hermes/`
|
||||
The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests.
|
||||
|
||||
---
|
||||
|
||||
## Development Commands
|
||||
|
||||
```bash
|
||||
make setup # First time: .venv + deps + pre-commit hooks
|
||||
make check # Lint + test (mirrors CI — run before pushing)
|
||||
make lint # Ruff check
|
||||
make fmt # Ruff format + auto-fix
|
||||
make test # Full test suite (~2500 tests, ~2 min)
|
||||
make test-fast # Tests with fail-fast (-x)
|
||||
make test-watch # Rerun tests on file changes
|
||||
make dev-cli # Auto-restart CLI on file changes
|
||||
make dev-gateway # Auto-restart gateway on file changes
|
||||
```
|
||||
|
||||
For targeted testing, use `pytest` directly:
|
||||
|
||||
```bash
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
```
|
||||
|
||||
Formatting is enforced by **ruff** (config in `pyproject.toml`). Pre-commit hooks run on every commit.
|
||||
505
CONTRIBUTING.md
Normal file
505
CONTRIBUTING.md
Normal file
@@ -0,0 +1,505 @@
|
||||
# Contributing to Hermes Agent
|
||||
|
||||
Thank you for contributing to Hermes Agent! This guide covers everything you need: setting up your dev environment, understanding the architecture, deciding what to build, and getting your PR merged.
|
||||
|
||||
---
|
||||
|
||||
## Contribution Priorities
|
||||
|
||||
We value contributions in this order:
|
||||
|
||||
1. **Bug fixes** — crashes, incorrect behavior, data loss. Always top priority.
|
||||
2. **Cross-platform compatibility** — Windows, macOS, different Linux distros, different terminal emulators. We want Hermes to work everywhere.
|
||||
3. **Security hardening** — shell injection, prompt injection, path traversal, privilege escalation. See [Security](#security-considerations).
|
||||
4. **Performance and robustness** — retry logic, error handling, graceful degradation.
|
||||
5. **New skills** — but only broadly useful ones. See [Should it be a Skill or a Tool?](#should-it-be-a-skill-or-a-tool)
|
||||
6. **New tools** — rarely needed. Most capabilities should be skills. See below.
|
||||
7. **Documentation** — fixes, clarifications, new examples.
|
||||
|
||||
---
|
||||
|
||||
## Should it be a Skill or a Tool?
|
||||
|
||||
This is the most common question for new contributors. The answer is almost always **skill**.
|
||||
|
||||
### Make it a Skill when:
|
||||
|
||||
- The capability can be expressed as instructions + shell commands + existing tools
|
||||
- It wraps an external CLI or API that the agent can call via `terminal` or `web_extract`
|
||||
- It doesn't need custom Python integration or API key management baked into the agent
|
||||
- Examples: arXiv search, git workflows, Docker management, PDF processing, email via CLI tools
|
||||
|
||||
### Make it a Tool when:
|
||||
|
||||
- It requires end-to-end integration with API keys, auth flows, or multi-component configuration managed by the agent harness
|
||||
- It needs custom processing logic that must execute precisely every time (not "best effort" from LLM interpretation)
|
||||
- It handles binary data, streaming, or real-time events that can't go through the terminal
|
||||
- Examples: browser automation (Browserbase session management), TTS (audio encoding + platform delivery), vision analysis (base64 image handling)
|
||||
|
||||
### Should the Skill be bundled?
|
||||
|
||||
Bundled skills (in `skills/`) ship with every Hermes install. They should be **broadly useful to most users**:
|
||||
|
||||
- Document handling, web research, common dev workflows, system administration
|
||||
- Used regularly by a wide range of people
|
||||
|
||||
If your skill is official and useful but not universally needed (e.g., a paid service integration, a heavyweight dependency), put it in **`optional-skills/`** — it ships with the repo but isn't activated by default. Users can discover it via `hermes skills browse` (labeled "official") and install it with `hermes skills install` (no third-party warning, builtin trust).
|
||||
|
||||
If your skill is specialized, community-contributed, or niche, it's better suited for a **Skills Hub** — upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
|
||||
|
||||
---
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
| Requirement | Notes |
|
||||
|-------------|-------|
|
||||
| **Git** | With `--recurse-submodules` support |
|
||||
| **Python 3.11+** | uv will install it if missing |
|
||||
| **uv** | Fast Python package manager ([install](https://docs.astral.sh/uv/)) |
|
||||
| **Node.js 18+** | Optional — needed for browser tools and WhatsApp bridge |
|
||||
|
||||
### Clone and install
|
||||
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
make setup # creates .venv, installs all deps
|
||||
```
|
||||
|
||||
### Configure for development
|
||||
|
||||
```bash
|
||||
mkdir -p ~/.hermes/{cron,sessions,logs,memories,skills}
|
||||
cp cli-config.yaml.example ~/.hermes/config.yaml
|
||||
touch ~/.hermes/.env
|
||||
|
||||
# Add at minimum an LLM provider key:
|
||||
echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Common commands
|
||||
|
||||
```bash
|
||||
make test # run unit tests
|
||||
make lint # ruff check
|
||||
make fmt # ruff format + fix
|
||||
make check # lint + test (same as CI)
|
||||
make dev-cli # auto-restart hermes CLI on file changes
|
||||
make dev-gateway # auto-restart gateway on file changes
|
||||
make test-watch # rerun tests on file changes
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
hermes-agent/
|
||||
├── run_agent.py # AIAgent class — core conversation loop, tool dispatch, session persistence
|
||||
├── cli.py # HermesCLI class — interactive TUI, prompt_toolkit integration
|
||||
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
|
||||
├── toolsets.py # Tool groupings and presets (hermes-cli, hermes-telegram, etc.)
|
||||
├── hermes_state.py # SQLite session database with FTS5 full-text search, session titles
|
||||
├── batch_runner.py # Parallel batch processing for trajectory generation
|
||||
│
|
||||
├── agent/ # Agent internals (extracted modules)
|
||||
│ ├── prompt_builder.py # System prompt assembly (identity, skills, context files, memory)
|
||||
│ ├── context_compressor.py # Auto-summarization when approaching context limits
|
||||
│ ├── auxiliary_client.py # Resolves auxiliary OpenAI clients (summarization, vision)
|
||||
│ ├── display.py # KawaiiSpinner, tool progress formatting
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
│ └── trajectory.py # Trajectory saving helpers
|
||||
│
|
||||
├── hermes_cli/ # CLI command implementations
|
||||
│ ├── main.py # Entry point, argument parsing, command dispatch
|
||||
│ ├── config.py # Config management, migration, env var definitions
|
||||
│ ├── setup.py # Interactive setup wizard
|
||||
│ ├── auth.py # Provider resolution, OAuth, Nous Portal
|
||||
│ ├── models.py # OpenRouter model selection lists
|
||||
│ ├── banner.py # Welcome banner, ASCII art
|
||||
│ ├── commands.py # Slash command definitions + autocomplete
|
||||
│ ├── callbacks.py # Interactive callbacks (clarify, sudo, approval)
|
||||
│ ├── doctor.py # Diagnostics
|
||||
│ └── skills_hub.py # Skills Hub CLI + /skills slash command
|
||||
│
|
||||
├── tools/ # Tool implementations (self-registering)
|
||||
│ ├── registry.py # Central tool registry (schemas, handlers, dispatch)
|
||||
│ ├── approval.py # Dangerous command detection + per-session approval
|
||||
│ ├── terminal_tool.py # Terminal orchestration (sudo, env lifecycle, backends)
|
||||
│ ├── file_operations.py # read_file, write_file, search, patch, etc.
|
||||
│ ├── web_tools.py # web_search, web_extract (Firecrawl + Gemini summarization)
|
||||
│ ├── vision_tools.py # Image analysis via multimodal models
|
||||
│ ├── delegate_tool.py # Subagent spawning and parallel task execution
|
||||
│ ├── code_execution_tool.py # Sandboxed Python with RPC tool access
|
||||
│ ├── session_search_tool.py # Search past conversations with FTS5 + summarization
|
||||
│ ├── cronjob_tools.py # Scheduled task management
|
||||
│ ├── skill_tools.py # Skill search, load, manage
|
||||
│ └── environments/ # Terminal execution backends
|
||||
│ ├── base.py # BaseEnvironment ABC
|
||||
│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py, daytona.py
|
||||
│
|
||||
├── gateway/ # Messaging gateway
|
||||
│ ├── run.py # GatewayRunner — platform lifecycle, message routing, cron
|
||||
│ ├── config.py # Platform configuration resolution
|
||||
│ ├── session.py # Session store, context prompts, reset policies
|
||||
│ └── platforms/ # Platform adapters
|
||||
│ ├── telegram.py, discord_adapter.py, slack.py, whatsapp.py
|
||||
│
|
||||
├── scripts/ # Installer and bridge scripts
|
||||
│ ├── install.sh # Linux/macOS installer
|
||||
│ ├── install.ps1 # Windows PowerShell installer
|
||||
│ └── whatsapp-bridge/ # Node.js WhatsApp bridge (Baileys)
|
||||
│
|
||||
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
|
||||
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
|
||||
├── environments/ # RL training environments (Atropos integration)
|
||||
├── tests/ # Test suite
|
||||
├── website/ # Documentation site (hermes-agent.nousresearch.com)
|
||||
│
|
||||
├── cli-config.yaml.example # Example configuration (copied to ~/.hermes/config.yaml)
|
||||
└── AGENTS.md # Development guide for AI coding assistants
|
||||
```
|
||||
|
||||
### User configuration (stored in `~/.hermes/`)
|
||||
|
||||
| Path | Purpose |
|
||||
|------|---------|
|
||||
| `~/.hermes/config.yaml` | Settings (model, terminal, toolsets, compression, etc.) |
|
||||
| `~/.hermes/.env` | API keys and secrets |
|
||||
| `~/.hermes/auth.json` | OAuth credentials (Nous Portal) |
|
||||
| `~/.hermes/skills/` | All active skills (bundled + hub-installed + agent-created) |
|
||||
| `~/.hermes/memories/` | Persistent memory (MEMORY.md, USER.md) |
|
||||
| `~/.hermes/state.db` | SQLite session database |
|
||||
| `~/.hermes/sessions/` | JSON session logs |
|
||||
| `~/.hermes/cron/` | Scheduled job data |
|
||||
| `~/.hermes/whatsapp/session/` | WhatsApp bridge credentials |
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Loop
|
||||
|
||||
```
|
||||
User message → AIAgent._run_agent_loop()
|
||||
├── Build system prompt (prompt_builder.py)
|
||||
├── Build API kwargs (model, messages, tools, reasoning config)
|
||||
├── Call LLM (OpenAI-compatible API)
|
||||
├── If tool_calls in response:
|
||||
│ ├── Execute each tool via registry dispatch
|
||||
│ ├── Add tool results to conversation
|
||||
│ └── Loop back to LLM call
|
||||
├── If text response:
|
||||
│ ├── Persist session to DB
|
||||
│ └── Return final_response
|
||||
└── Context compression if approaching token limit
|
||||
```
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
- **Self-registering tools**: Each tool file calls `registry.register()` at import time. `model_tools.py` triggers discovery by importing all tool modules.
|
||||
- **Toolset grouping**: Tools are grouped into toolsets (`web`, `terminal`, `file`, `browser`, etc.) that can be enabled/disabled per platform.
|
||||
- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search and unique session titles. JSON logs go to `~/.hermes/sessions/`.
|
||||
- **Ephemeral injection**: System prompts and prefill messages are injected at API call time, never persisted to the database or logs.
|
||||
- **Provider abstraction**: The agent works with any OpenAI-compatible API. Provider resolution happens at init time (Nous Portal OAuth, OpenRouter API key, or custom endpoint).
|
||||
- **Provider routing**: When using OpenRouter, `provider_routing` in config.yaml controls provider selection (sort by throughput/latency/price, allow/ignore specific providers, data retention policies). These are injected as `extra_body.provider` in API requests.
|
||||
|
||||
---
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically.
|
||||
- **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks. Don't narrate what the code does — `# increment counter` adds nothing
|
||||
- **Error handling**: Catch specific exceptions. Log with `logger.warning()`/`logger.error()` — use `exc_info=True` for unexpected errors so stack traces appear in logs
|
||||
- **Cross-platform**: Never assume Unix. See [Cross-Platform Compatibility](#cross-platform-compatibility)
|
||||
|
||||
---
|
||||
|
||||
## Adding a New Tool
|
||||
|
||||
Before writing a tool, ask: [should this be a skill instead?](#should-it-be-a-skill-or-a-tool)
|
||||
|
||||
Tools self-register with the central registry. Each tool file co-locates its schema, handler, and registration:
|
||||
|
||||
```python
|
||||
"""my_tool — Brief description of what this tool does."""
|
||||
|
||||
import json
|
||||
from tools.registry import registry
|
||||
|
||||
|
||||
def my_tool(param1: str, param2: int = 10, **kwargs) -> str:
|
||||
"""Handler. Returns a string result (often JSON)."""
|
||||
result = do_work(param1, param2)
|
||||
return json.dumps(result)
|
||||
|
||||
|
||||
MY_TOOL_SCHEMA = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_tool",
|
||||
"description": "What this tool does and when the agent should use it.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {"type": "string", "description": "What param1 is"},
|
||||
"param2": {"type": "integer", "description": "What param2 is", "default": 10},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _check_requirements() -> bool:
|
||||
"""Return True if this tool's dependencies are available."""
|
||||
return True
|
||||
|
||||
|
||||
registry.register(
|
||||
name="my_tool",
|
||||
toolset="my_toolset",
|
||||
schema=MY_TOOL_SCHEMA,
|
||||
handler=lambda args, **kw: my_tool(**args, **kw),
|
||||
check_fn=_check_requirements,
|
||||
)
|
||||
```
|
||||
|
||||
Then add the import to `model_tools.py` in the `_modules` list:
|
||||
|
||||
```python
|
||||
_modules = [
|
||||
# ... existing modules ...
|
||||
"tools.my_tool",
|
||||
]
|
||||
```
|
||||
|
||||
If it's a new toolset, add it to `toolsets.py` and to the relevant platform presets.
|
||||
|
||||
---
|
||||
|
||||
## Adding a Skill
|
||||
|
||||
Bundled skills live in `skills/` organized by category. Official optional skills use the same structure in `optional-skills/`:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── research/
|
||||
│ └── arxiv/
|
||||
│ ├── SKILL.md # Required: main instructions
|
||||
│ └── scripts/ # Optional: helper scripts
|
||||
│ └── search_arxiv.py
|
||||
├── productivity/
|
||||
│ └── ocr-and-documents/
|
||||
│ ├── SKILL.md
|
||||
│ ├── scripts/
|
||||
│ └── references/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### SKILL.md format
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: Brief description (shown in skill search results)
|
||||
version: 1.0.0
|
||||
author: Your Name
|
||||
license: MIT
|
||||
platforms: [macos, linux] # Optional — restrict to specific OS platforms
|
||||
# Valid: macos, linux, windows
|
||||
# Omit to load on all platforms (default)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Category, Subcategory, Keywords]
|
||||
related_skills: [other-skill-name]
|
||||
---
|
||||
|
||||
# Skill Title
|
||||
|
||||
Brief intro.
|
||||
|
||||
## When to Use
|
||||
Trigger conditions — when should the agent load this skill?
|
||||
|
||||
## Quick Reference
|
||||
Table of common commands or API calls.
|
||||
|
||||
## Procedure
|
||||
Step-by-step instructions the agent follows.
|
||||
|
||||
## Pitfalls
|
||||
Known failure modes and how to handle them.
|
||||
|
||||
## Verification
|
||||
How the agent confirms it worked.
|
||||
```
|
||||
|
||||
### Platform-specific skills
|
||||
|
||||
Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.
|
||||
|
||||
```yaml
|
||||
platforms: [macos] # macOS only (e.g., iMessage, Apple Reminders)
|
||||
platforms: [macos, linux] # macOS and Linux
|
||||
platforms: [windows] # Windows only
|
||||
```
|
||||
|
||||
If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
|
||||
|
||||
### Skill guidelines
|
||||
|
||||
- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).
|
||||
- **Progressive disclosure.** Put the most common workflow first. Edge cases and advanced usage go at the bottom.
|
||||
- **Include helper scripts** for XML/JSON parsing or complex logic — don't expect the LLM to write parsers inline every time.
|
||||
- **Test it.** Run `hermes --toolsets skills -q "Use the X skill to do Y"` and verify the agent follows the instructions correctly.
|
||||
|
||||
---
|
||||
|
||||
## Cross-Platform Compatibility
|
||||
|
||||
Hermes runs on Linux, macOS, and Windows. When writing code that touches the OS:
|
||||
|
||||
### Critical rules
|
||||
|
||||
1. **`termios` and `fcntl` are Unix-only.** Always catch both `ImportError` and `NotImplementedError`:
|
||||
```python
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
menu = TerminalMenu(options)
|
||||
idx = menu.show()
|
||||
except (ImportError, NotImplementedError):
|
||||
# Fallback: numbered menu for Windows
|
||||
for i, opt in enumerate(options):
|
||||
print(f" {i+1}. {opt}")
|
||||
idx = int(input("Choice: ")) - 1
|
||||
```
|
||||
|
||||
2. **File encoding.** Windows may save `.env` files in `cp1252`. Always handle encoding errors:
|
||||
```python
|
||||
try:
|
||||
load_dotenv(env_path)
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(env_path, encoding="latin-1")
|
||||
```
|
||||
|
||||
3. **Process management.** `os.setsid()`, `os.killpg()`, and signal handling differ on Windows. Use platform checks:
|
||||
```python
|
||||
import platform
|
||||
if platform.system() != "Windows":
|
||||
kwargs["preexec_fn"] = os.setsid
|
||||
```
|
||||
|
||||
4. **Path separators.** Use `pathlib.Path` instead of string concatenation with `/`.
|
||||
|
||||
5. **Shell commands in installers.** If you change `scripts/install.sh`, check if the equivalent change is needed in `scripts/install.ps1`.
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
Hermes has terminal access. Security matters.
|
||||
|
||||
### Existing protections
|
||||
|
||||
| Layer | Implementation |
|
||||
|-------|---------------|
|
||||
| **Sudo password piping** | Uses `shlex.quote()` to prevent shell injection |
|
||||
| **Dangerous command detection** | Regex patterns in `tools/approval.py` with user approval flow |
|
||||
| **Cron prompt injection** | Scanner in `tools/cronjob_tools.py` blocks instruction-override patterns |
|
||||
| **Write deny list** | Protected paths (`~/.ssh/authorized_keys`, `/etc/shadow`) resolved via `os.path.realpath()` to prevent symlink bypass |
|
||||
| **Skills guard** | Security scanner for hub-installed skills (`tools/skills_guard.py`) |
|
||||
| **Code execution sandbox** | `execute_code` child process runs with API keys stripped from environment |
|
||||
| **Container hardening** | Docker: all capabilities dropped, no privilege escalation, PID limits, size-limited tmpfs |
|
||||
|
||||
### When contributing security-sensitive code
|
||||
|
||||
- **Always use `shlex.quote()`** when interpolating user input into shell commands
|
||||
- **Resolve symlinks** with `os.path.realpath()` before path-based access control checks
|
||||
- **Don't log secrets.** API keys, tokens, and passwords should never appear in log output
|
||||
- **Catch broad exceptions** around tool execution so a single failure doesn't crash the agent loop
|
||||
- **Test on all platforms** if your change touches file paths, process management, or shell commands
|
||||
|
||||
If your PR affects security, note it explicitly in the description.
|
||||
|
||||
---
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
### Branch naming
|
||||
|
||||
```
|
||||
fix/description # Bug fixes
|
||||
feat/description # New features
|
||||
docs/description # Documentation
|
||||
test/description # Tests
|
||||
refactor/description # Code restructuring
|
||||
```
|
||||
|
||||
### Before submitting
|
||||
|
||||
1. **Run checks**: `make check` (lint + test — same as CI)
|
||||
2. **Test manually**: Run `hermes` and exercise the code path you changed
|
||||
3. **Check cross-platform impact**: If you touch file I/O, process management, or terminal handling, consider Windows and macOS
|
||||
4. **Keep PRs focused**: One logical change per PR. Don't mix a bug fix with a refactor with a new feature.
|
||||
|
||||
### PR description
|
||||
|
||||
Include:
|
||||
- **What** changed and **why**
|
||||
- **How to test** it (reproduction steps for bugs, usage examples for features)
|
||||
- **What platforms** you tested on
|
||||
- Reference any related issues
|
||||
|
||||
### Commit messages
|
||||
|
||||
We use [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
|
||||
```
|
||||
<type>(<scope>): <description>
|
||||
```
|
||||
|
||||
| Type | Use for |
|
||||
|------|---------|
|
||||
| `fix` | Bug fixes |
|
||||
| `feat` | New features |
|
||||
| `docs` | Documentation |
|
||||
| `test` | Tests |
|
||||
| `refactor` | Code restructuring (no behavior change) |
|
||||
| `chore` | Build, CI, dependency updates |
|
||||
|
||||
Scopes: `cli`, `gateway`, `tools`, `skills`, `agent`, `install`, `whatsapp`, `security`, etc.
|
||||
|
||||
Examples:
|
||||
```
|
||||
fix(cli): prevent crash in save_config_value when model is a string
|
||||
feat(gateway): add WhatsApp multi-user session isolation
|
||||
fix(security): prevent shell injection in sudo password piping
|
||||
test(tools): add unit tests for file_operations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
- Use [GitHub Issues](https://github.com/NousResearch/hermes-agent/issues)
|
||||
- Include: OS, Python version, Hermes version (`hermes version`), full error traceback
|
||||
- Include steps to reproduce
|
||||
- Check existing issues before creating duplicates
|
||||
- For security vulnerabilities, please report privately
|
||||
|
||||
---
|
||||
|
||||
## Community
|
||||
|
||||
- **Discord**: [discord.gg/NousResearch](https://discord.gg/NousResearch) — for questions, showcasing projects, and sharing skills
|
||||
- **GitHub Discussions**: For design proposals and architecture discussions
|
||||
- **Skills Hub**: Upload specialized skills to a registry and share them with the community
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the [MIT License](LICENSE).
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Nous Research
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
69
Makefile
Normal file
69
Makefile
Normal file
@@ -0,0 +1,69 @@
|
||||
.DEFAULT_GOAL := help
|
||||
SHELL := /bin/bash
|
||||
VENV := .venv
|
||||
UV := uv
|
||||
|
||||
SRC := run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py \
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
# ─── Setup ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: setup sync clean
|
||||
|
||||
setup: ## Full dev setup (venv + deps + pre-commit)
|
||||
$(UV) venv $(VENV) --python 3.11
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e "./mini-swe-agent"
|
||||
. $(VENV)/bin/activate && pre-commit install
|
||||
@echo "\n✅ Setup complete. Run: source $(VENV)/bin/activate"
|
||||
|
||||
sync: ## Reinstall deps into existing venv
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
|
||||
clean: ## Remove build artifacts and caches
|
||||
rm -rf .ruff_cache .mypy_cache .pytest_cache dist build *.egg-info
|
||||
find . -type d -name __pycache__ -not -path "./.venv/*" -exec rm -rf {} +
|
||||
|
||||
# ─── Quality ────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: lint fmt check
|
||||
|
||||
lint: ## Check lint + formatting (no changes)
|
||||
. $(VENV)/bin/activate && ruff check $(SRC)
|
||||
. $(VENV)/bin/activate && ruff format --check $(SRC)
|
||||
|
||||
fmt: ## Auto-fix lint + format
|
||||
. $(VENV)/bin/activate && ruff format $(SRC)
|
||||
. $(VENV)/bin/activate && ruff check --fix $(SRC)
|
||||
|
||||
check: lint test ## Lint + test (mirrors CI)
|
||||
|
||||
# ─── Test ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: test test-fast test-watch
|
||||
|
||||
test: ## Run full test suite
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
|
||||
test-fast: ## Run tests with fail-fast
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short -x
|
||||
|
||||
test-watch: ## Rerun tests on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m pytest tests/ -q --ignore=tests/integration --tb=short -x" $(SRC) tests/
|
||||
|
||||
# ─── Dev Servers ────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: dev-cli dev-gateway
|
||||
|
||||
dev-cli: ## Auto-restart CLI on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m hermes_cli.main" $(SRC)
|
||||
|
||||
dev-gateway: ## Auto-restart gateway on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m gateway.run" $(SRC)
|
||||
|
||||
# ─── Misc ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: help
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
|
||||
302
README.md
302
README.md
@@ -1,243 +1,117 @@
|
||||
# Hermes Agent
|
||||
<p align="center">
|
||||
<img src="assets/banner.png" alt="Hermes Agent" width="100%">
|
||||
</p>
|
||||
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
# Hermes Agent ⚕
|
||||
|
||||
## Features
|
||||
<p align="center">
|
||||
<a href="https://hermes-agent.nousresearch.com/docs/"><img src="https://img.shields.io/badge/Docs-hermes--agent.nousresearch.com-FFD700?style=for-the-badge" alt="Documentation"></a>
|
||||
<a href="https://discord.gg/NousResearch"><img src="https://img.shields.io/badge/Discord-5865F2?style=for-the-badge&logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://github.com/NousResearch/hermes-agent/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-MIT-green?style=for-the-badge" alt="License: MIT"></a>
|
||||
<a href="https://nousresearch.com"><img src="https://img.shields.io/badge/Built%20by-Nous%20Research-blueviolet?style=for-the-badge" alt="Built by Nous Research"></a>
|
||||
</p>
|
||||
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands with interactive session support
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
## Setup
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
<tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, Signal, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
|
||||
<tr><td><b>A closed learning loop</b></td><td>Agent-curated memory with periodic nudges. Autonomous skill creation after complex tasks. Skills self-improve during use. FTS5 session search with LLM summarization for cross-session recall. <a href="https://github.com/plastic-labs/honcho">Honcho</a> dialectic user modeling. Compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
|
||||
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
||||
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
||||
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Six terminal backends — local, Docker, SSH, Daytona, Singularity, and Modal. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
|
||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||
</table>
|
||||
|
||||
# Install required packages
|
||||
pip install -r requirements.txt
|
||||
---
|
||||
|
||||
# Install Hecate for terminal tools
|
||||
git clone git@github.com:NousResearch/hecate.git
|
||||
cd hecate
|
||||
pip install -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
### 2. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `ANTHROPIC_API_KEY` - Main agent model (get at: https://console.anthropic.com/)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `MORPH_API_KEY` - Terminal tools (get at: https://morph.so/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
- `OPENAI_API_KEY` - Optional, for some Hecate features
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings and terminal tool configuration.
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
## Quick Install
|
||||
|
||||
```bash
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
For detailed documentation on toolsets, see `TOOLSETS_README.md`.
|
||||
Works on Linux, macOS, and WSL2. The installer handles everything — Python, Node.js, dependencies, and the `hermes` command. No prerequisites except git.
|
||||
|
||||
## Basic Usage
|
||||
> **Windows:** Native Windows is not supported. Please install [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install) and run the command above.
|
||||
|
||||
### Default (all tools enabled)
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--base_url https://api.anthropic.com/v1/ \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
```
|
||||
|
||||
### With specific toolset
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Use a specific toolset
|
||||
agent = AIAgent(
|
||||
model="claude-opus-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
After installation:
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
source ~/.bashrc # reload shell (or: source ~/.zshrc)
|
||||
hermes setup # configure your LLM provider
|
||||
hermes # start chatting!
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
---
|
||||
|
||||
**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide.
|
||||
**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation.
|
||||
## Getting Started
|
||||
|
||||
### Ephemeral System Prompts
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
hermes # Interactive CLI — start a conversation
|
||||
hermes model # Switch provider or model
|
||||
hermes setup # Re-run the setup wizard
|
||||
hermes gateway # Start the messaging gateway (Telegram, Discord, etc.)
|
||||
hermes update # Update to the latest version
|
||||
hermes doctor # Diagnose any issues
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
📖 **[Full documentation →](https://hermes-agent.nousresearch.com/docs/)**
|
||||
|
||||
**Documentation:** See [docs/ephemeral_system_prompt.md](docs/ephemeral_system_prompt.md) for complete details.
|
||||
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
|
||||
**Core API Keys:**
|
||||
- `ANTHROPIC_API_KEY`: Main agent model
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `MORPH_API_KEY`: Terminal tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
- `OPENAI_API_KEY`: Optional, for some Hecate features
|
||||
|
||||
**Configuration Options:**
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
**Single Agent Usage:**
|
||||
- `TOOLSETS_README.md`: Comprehensive guide to the toolsets system
|
||||
- `toolsets.py`: View and modify available toolsets
|
||||
- `model_tools.py`: Core tool definitions and handlers
|
||||
All documentation lives at **[hermes-agent.nousresearch.com/docs](https://hermes-agent.nousresearch.com/docs/)**:
|
||||
|
||||
**Batch Processing:**
|
||||
- `QUICKSTART_BATCH.md`: 5-minute quick start guide
|
||||
- `BATCH_PROCESSING.md`: Complete batch processing documentation
|
||||
- `toolset_distributions.py`: Toolset distributions for data generation
|
||||
| Section | What's Covered |
|
||||
|---------|---------------|
|
||||
| [Quickstart](https://hermes-agent.nousresearch.com/docs/getting-started/quickstart) | Install → setup → first conversation in 2 minutes |
|
||||
| [CLI Usage](https://hermes-agent.nousresearch.com/docs/user-guide/cli) | Commands, keybindings, personalities, sessions |
|
||||
| [Configuration](https://hermes-agent.nousresearch.com/docs/user-guide/configuration) | Config file, providers, models, all options |
|
||||
| [Messaging Gateway](https://hermes-agent.nousresearch.com/docs/user-guide/messaging) | Telegram, Discord, Slack, WhatsApp, Signal, Home Assistant |
|
||||
| [Security](https://hermes-agent.nousresearch.com/docs/user-guide/security) | Command approval, DM pairing, container isolation |
|
||||
| [Tools & Toolsets](https://hermes-agent.nousresearch.com/docs/user-guide/features/tools) | 40+ tools, toolset system, terminal backends |
|
||||
| [Skills System](https://hermes-agent.nousresearch.com/docs/user-guide/features/skills) | Procedural memory, Skills Hub, creating skills |
|
||||
| [Memory](https://hermes-agent.nousresearch.com/docs/user-guide/features/memory) | Persistent memory, user profiles, best practices |
|
||||
| [MCP Integration](https://hermes-agent.nousresearch.com/docs/user-guide/features/mcp) | Connect any MCP server for extended capabilities |
|
||||
| [Cron Scheduling](https://hermes-agent.nousresearch.com/docs/user-guide/features/cron) | Scheduled tasks with platform delivery |
|
||||
| [Context Files](https://hermes-agent.nousresearch.com/docs/user-guide/features/context-files) | Project context that shapes every conversation |
|
||||
| [Architecture](https://hermes-agent.nousresearch.com/docs/developer-guide/architecture) | Project structure, agent loop, key classes |
|
||||
| [Contributing](https://hermes-agent.nousresearch.com/docs/developer-guide/contributing) | Development setup, PR process, code style |
|
||||
| [CLI Reference](https://hermes-agent.nousresearch.com/docs/reference/cli-commands) | All commands and flags |
|
||||
| [Environment Variables](https://hermes-agent.nousresearch.com/docs/reference/environment-variables) | Complete env var reference |
|
||||
|
||||
## Examples
|
||||
---
|
||||
|
||||
See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios.
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! See the [Contributing Guide](https://hermes-agent.nousresearch.com/docs/developer-guide/contributing) for development setup, code style, and PR process.
|
||||
|
||||
Quick start for contributors:
|
||||
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
make setup # creates .venv, installs everything
|
||||
make check # lint + test (same as CI)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Community
|
||||
|
||||
- 💬 [Discord](https://discord.gg/NousResearch)
|
||||
- 📚 [Skills Hub](https://agentskills.io)
|
||||
- 🐛 [Issues](https://github.com/NousResearch/hermes-agent/issues)
|
||||
- 💡 [Discussions](https://github.com/NousResearch/hermes-agent/discussions)
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT — see [LICENSE](LICENSE).
|
||||
|
||||
Built by [Nous Research](https://nousresearch.com).
|
||||
|
||||
Binary file not shown.
Binary file not shown.
6
agent/__init__.py
Normal file
6
agent/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Agent internals -- extracted modules from run_agent.py.
|
||||
|
||||
These modules contain pure utility functions and self-contained classes
|
||||
that were previously embedded in the 3,600-line run_agent.py. Extracting
|
||||
them makes run_agent.py focused on the AIAgent orchestrator class.
|
||||
"""
|
||||
605
agent/auxiliary_client.py
Normal file
605
agent/auxiliary_client.py
Normal file
@@ -0,0 +1,605 @@
|
||||
"""Shared auxiliary OpenAI client for cheap/fast side tasks.
|
||||
|
||||
Provides a single resolution chain so every consumer (context compression,
|
||||
session search, web extraction, vision analysis, browser vision) picks up
|
||||
the best available backend without duplicating fallback logic.
|
||||
|
||||
Resolution order for text tasks (auto mode):
|
||||
1. OpenRouter (OPENROUTER_API_KEY)
|
||||
2. Nous Portal (~/.hermes/auth.json active provider)
|
||||
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
|
||||
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
|
||||
wrapped to look like a chat.completions client)
|
||||
5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
|
||||
— checked via PROVIDER_REGISTRY entries with auth_type='api_key'
|
||||
6. None
|
||||
|
||||
Resolution order for vision/multimodal tasks (auto mode):
|
||||
1. OpenRouter
|
||||
2. Nous Portal
|
||||
3. None (steps 3-5 are skipped — they may not support multimodal)
|
||||
|
||||
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task:
|
||||
"openrouter", "nous", "codex", or "main" (= steps 3-5).
|
||||
Default "auto" follows the chains above.
|
||||
|
||||
Per-task model overrides (e.g. AUXILIARY_VISION_MODEL,
|
||||
AUXILIARY_WEB_EXTRACT_MODEL) let callers use a different model slug
|
||||
than the provider's default.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: dict[str, str] = {
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.5-highspeed",
|
||||
"minimax-cn": "MiniMax-M2.5-highspeed",
|
||||
}
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
_OR_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
|
||||
"X-OpenRouter-Title": "Hermes Agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
|
||||
# Nous Portal extra_body for product attribution.
|
||||
# Callers should pass this as extra_body in chat.completions.create()
|
||||
# when the auxiliary client is backed by Nous Portal.
|
||||
NOUS_EXTRA_BODY = {"tags": ["product=hermes-agent"]}
|
||||
|
||||
# Set at resolve time — True if the auxiliary client points to Nous Portal
|
||||
auxiliary_is_nous: bool = False
|
||||
|
||||
# Default auxiliary models per provider
|
||||
_OPENROUTER_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_MODEL = "gemini-3-flash"
|
||||
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
_AUTH_JSON_PATH = Path.home() / ".hermes" / "auth.json"
|
||||
|
||||
# Codex fallback: uses the Responses API (the only endpoint the Codex
|
||||
# OAuth token can access) with a fast model for auxiliary tasks.
|
||||
_CODEX_AUX_MODEL = "gpt-5.3-codex"
|
||||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
# ── Codex Responses → chat.completions adapter ─────────────────────────────
|
||||
# All auxiliary consumers call client.chat.completions.create(**kwargs) and
|
||||
# read response.choices[0].message.content. This adapter translates those
|
||||
# calls to the Codex Responses API so callers don't need any changes.
|
||||
|
||||
|
||||
def _convert_content_for_responses(content: Any) -> Any:
|
||||
"""Convert chat.completions content to Responses API format.
|
||||
|
||||
chat.completions uses:
|
||||
{"type": "text", "text": "..."}
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
||||
|
||||
Responses API uses:
|
||||
{"type": "input_text", "text": "..."}
|
||||
{"type": "input_image", "image_url": "data:image/png;base64,..."}
|
||||
|
||||
If content is a plain string, it's returned as-is (the Responses API
|
||||
accepts strings directly for text-only messages).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return str(content) if content else ""
|
||||
|
||||
converted: list[dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
ptype = part.get("type", "")
|
||||
if ptype == "text":
|
||||
converted.append({"type": "input_text", "text": part.get("text", "")})
|
||||
elif ptype == "image_url":
|
||||
# chat.completions nests the URL: {"image_url": {"url": "..."}}
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
entry: dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
# Preserve detail if specified
|
||||
detail = image_data.get("detail") if isinstance(image_data, dict) else None
|
||||
if detail:
|
||||
entry["detail"] = detail
|
||||
converted.append(entry)
|
||||
elif ptype in ("input_text", "input_image"):
|
||||
# Already in Responses format — pass through
|
||||
converted.append(part)
|
||||
else:
|
||||
# Unknown content type — try to preserve as text
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
converted.append({"type": "input_text", "text": text})
|
||||
|
||||
return converted or ""
|
||||
|
||||
|
||||
class _CodexCompletionsAdapter:
|
||||
"""Drop-in shim that accepts chat.completions.create() kwargs and
|
||||
routes them through the Codex Responses streaming API."""
|
||||
|
||||
def __init__(self, real_client: OpenAI, model: str):
|
||||
self._client = real_client
|
||||
self._model = model
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
messages = kwargs.get("messages", [])
|
||||
model = kwargs.get("model", self._model)
|
||||
temperature = kwargs.get("temperature")
|
||||
|
||||
# Separate system/instructions from conversation messages.
|
||||
# Convert chat.completions multimodal content blocks to Responses
|
||||
# API format (input_text / input_image instead of text / image_url).
|
||||
instructions = "You are a helpful assistant."
|
||||
input_msgs: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content") or ""
|
||||
if role == "system":
|
||||
instructions = content if isinstance(content, str) else str(content)
|
||||
else:
|
||||
input_msgs.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
}
|
||||
)
|
||||
|
||||
resp_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"instructions": instructions,
|
||||
"input": input_msgs or [{"role": "user", "content": ""}],
|
||||
"store": False,
|
||||
}
|
||||
|
||||
# Note: the Codex endpoint (chatgpt.com/backend-api/codex) does NOT
|
||||
# support max_output_tokens or temperature — omit to avoid 400 errors.
|
||||
|
||||
# Tools support for flush_memories and similar callers
|
||||
tools = kwargs.get("tools")
|
||||
if tools:
|
||||
converted = []
|
||||
for t in tools:
|
||||
fn = t.get("function", {}) if isinstance(t, dict) else {}
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
if converted:
|
||||
resp_kwargs["tools"] = converted
|
||||
|
||||
# Stream and collect the response
|
||||
text_parts: list[str] = []
|
||||
tool_calls_raw: list[Any] = []
|
||||
usage = None
|
||||
|
||||
try:
|
||||
with self._client.responses.stream(**resp_kwargs) as stream:
|
||||
for _event in stream:
|
||||
pass
|
||||
final = stream.get_final_response()
|
||||
|
||||
# Extract text and tool calls from the Responses output
|
||||
for item in getattr(final, "output", []):
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "message":
|
||||
for part in getattr(item, "content", []):
|
||||
ptype = getattr(part, "type", None)
|
||||
if ptype in ("output_text", "text"):
|
||||
text_parts.append(getattr(part, "text", ""))
|
||||
elif item_type == "function_call":
|
||||
tool_calls_raw.append(
|
||||
SimpleNamespace(
|
||||
id=getattr(item, "call_id", ""),
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=getattr(item, "name", ""),
|
||||
arguments=getattr(item, "arguments", "{}"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
resp_usage = getattr(final, "usage", None)
|
||||
if resp_usage:
|
||||
usage = SimpleNamespace(
|
||||
prompt_tokens=getattr(resp_usage, "input_tokens", 0),
|
||||
completion_tokens=getattr(resp_usage, "output_tokens", 0),
|
||||
total_tokens=getattr(resp_usage, "total_tokens", 0),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Codex auxiliary Responses API call failed: %s", exc)
|
||||
raise
|
||||
|
||||
content = "".join(text_parts).strip() or None
|
||||
|
||||
# Build a response that looks like chat.completions
|
||||
message = SimpleNamespace(
|
||||
role="assistant",
|
||||
content=content,
|
||||
tool_calls=tool_calls_raw or None,
|
||||
)
|
||||
choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason="stop" if not tool_calls_raw else "tool_calls",
|
||||
)
|
||||
return SimpleNamespace(
|
||||
choices=[choice],
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class _CodexChatShim:
|
||||
"""Wraps the adapter to provide client.chat.completions.create()."""
|
||||
|
||||
def __init__(self, adapter: _CodexCompletionsAdapter):
|
||||
self.completions = adapter
|
||||
|
||||
|
||||
class CodexAuxiliaryClient:
|
||||
"""OpenAI-client-compatible wrapper that routes through Codex Responses API.
|
||||
|
||||
Consumers can call client.chat.completions.create(**kwargs) as normal.
|
||||
Also exposes .api_key and .base_url for introspection by async wrappers.
|
||||
"""
|
||||
|
||||
def __init__(self, real_client: OpenAI, model: str):
|
||||
self._real_client = real_client
|
||||
adapter = _CodexCompletionsAdapter(real_client, model)
|
||||
self.chat = _CodexChatShim(adapter)
|
||||
self.api_key = real_client.api_key
|
||||
self.base_url = real_client.base_url
|
||||
|
||||
def close(self):
|
||||
self._real_client.close()
|
||||
|
||||
|
||||
class _AsyncCodexCompletionsAdapter:
|
||||
"""Async version of the Codex Responses adapter.
|
||||
|
||||
Wraps the sync adapter via asyncio.to_thread() so async consumers
|
||||
(web_tools, session_search) can await it as normal.
|
||||
"""
|
||||
|
||||
def __init__(self, sync_adapter: _CodexCompletionsAdapter):
|
||||
self._sync = sync_adapter
|
||||
|
||||
async def create(self, **kwargs) -> Any:
|
||||
import asyncio
|
||||
|
||||
return await asyncio.to_thread(self._sync.create, **kwargs)
|
||||
|
||||
|
||||
class _AsyncCodexChatShim:
|
||||
def __init__(self, adapter: _AsyncCodexCompletionsAdapter):
|
||||
self.completions = adapter
|
||||
|
||||
|
||||
class AsyncCodexAuxiliaryClient:
|
||||
"""Async-compatible wrapper matching AsyncOpenAI.chat.completions.create()."""
|
||||
|
||||
def __init__(self, sync_wrapper: "CodexAuxiliaryClient"):
|
||||
sync_adapter = sync_wrapper.chat.completions
|
||||
async_adapter = _AsyncCodexCompletionsAdapter(sync_adapter)
|
||||
self.chat = _AsyncCodexChatShim(async_adapter)
|
||||
self.api_key = sync_wrapper.api_key
|
||||
self.base_url = sync_wrapper.base_url
|
||||
|
||||
|
||||
def _read_nous_auth() -> dict | None:
|
||||
"""Read and validate ~/.hermes/auth.json for an active Nous provider.
|
||||
|
||||
Returns the provider state dict if Nous is active with tokens,
|
||||
otherwise None.
|
||||
"""
|
||||
try:
|
||||
if not _AUTH_JSON_PATH.is_file():
|
||||
return None
|
||||
data = json.loads(_AUTH_JSON_PATH.read_text())
|
||||
if data.get("active_provider") != "nous":
|
||||
return None
|
||||
provider = data.get("providers", {}).get("nous", {})
|
||||
# Must have at least an access_token or agent_key
|
||||
if not provider.get("agent_key") and not provider.get("access_token"):
|
||||
return None
|
||||
return provider
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read Nous auth: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _nous_api_key(provider: dict) -> str:
|
||||
"""Extract the best API key from a Nous provider state dict."""
|
||||
return provider.get("agent_key") or provider.get("access_token", "")
|
||||
|
||||
|
||||
def _nous_base_url() -> str:
|
||||
"""Resolve the Nous inference base URL from env or default."""
|
||||
return os.getenv("NOUS_INFERENCE_BASE_URL", _NOUS_DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
def _read_codex_access_token() -> str | None:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
if isinstance(access_token, str) and access_token.strip():
|
||||
return access_token.strip()
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read Codex auth for auxiliary client: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider whose env var is set,
|
||||
or (None, None) if none are configured.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
except ImportError:
|
||||
logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
|
||||
return None, None
|
||||
|
||||
for provider_id, pconfig in PROVIDER_REGISTRY.items():
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
# Check if any of the provider's env vars are set
|
||||
api_key = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
break
|
||||
if not api_key:
|
||||
continue
|
||||
# Resolve base URL (with optional env-var override)
|
||||
# Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
|
||||
if env_url:
|
||||
base_url = env_url.rstrip("/")
|
||||
elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
|
||||
base_url = "https://api.kimi.com/coding/v1"
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
# ── Provider resolution helpers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_auxiliary_provider(task: str = "") -> str:
|
||||
"""Read the provider override for a specific auxiliary task.
|
||||
|
||||
Checks AUXILIARY_{TASK}_PROVIDER first (e.g. AUXILIARY_VISION_PROVIDER),
|
||||
then CONTEXT_{TASK}_PROVIDER (for the compression section's summary_provider),
|
||||
then falls back to "auto". Returns one of: "auto", "openrouter", "nous", "main".
|
||||
"""
|
||||
if task:
|
||||
for prefix in ("AUXILIARY_", "CONTEXT_"):
|
||||
val = os.getenv(f"{prefix}{task.upper()}_PROVIDER", "").strip().lower()
|
||||
if val and val != "auto":
|
||||
return val
|
||||
return "auto"
|
||||
|
||||
|
||||
def _try_openrouter() -> tuple[OpenAI | None, str | None]:
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
|
||||
def _try_nous() -> tuple[OpenAI | None, str | None]:
|
||||
nous = _read_nous_auth()
|
||||
if not nous:
|
||||
return None, None
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary client: Nous Portal")
|
||||
return (
|
||||
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
|
||||
_NOUS_MODEL,
|
||||
)
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
model = os.getenv("OPENAI_MODEL") or os.getenv("LLM_MODEL") or "gpt-4o-mini"
|
||||
logger.debug("Auxiliary client: custom endpoint (%s)", model)
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
def _try_codex() -> tuple[Any | None, str | None]:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
|
||||
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
|
||||
return client, model
|
||||
|
||||
if forced == "nous":
|
||||
client, model = _try_nous()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes login)")
|
||||
return client, model
|
||||
|
||||
if forced == "codex":
|
||||
client, model = _try_codex()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
|
||||
return client, model
|
||||
|
||||
if forced == "main":
|
||||
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
|
||||
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
|
||||
return None, None
|
||||
|
||||
# Unknown provider name — fall through to auto
|
||||
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
|
||||
return None, None
|
||||
|
||||
|
||||
def _resolve_auto() -> tuple[OpenAI | None, str | None]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary client: none available")
|
||||
return None, None
|
||||
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_text_auxiliary_client(task: str = "") -> tuple[OpenAI | None, str | None]:
|
||||
"""Return (client, default_model_slug) for text-only auxiliary tasks.
|
||||
|
||||
Args:
|
||||
task: Optional task name ("compression", "web_extract") to check
|
||||
for a task-specific provider override.
|
||||
|
||||
Callers may override the returned model with a per-task env var
|
||||
(e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL).
|
||||
"""
|
||||
forced = _get_auxiliary_provider(task)
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
return _resolve_auto()
|
||||
|
||||
|
||||
def get_async_text_auxiliary_client(task: str = ""):
|
||||
"""Return (async_client, model_slug) for async consumers.
|
||||
|
||||
For standard providers returns (AsyncOpenAI, model). For Codex returns
|
||||
(AsyncCodexAuxiliaryClient, model) which wraps the Responses API.
|
||||
Returns (None, None) when no provider is available.
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
sync_client, model = get_text_auxiliary_client(task)
|
||||
if sync_client is None:
|
||||
return None, None
|
||||
|
||||
if isinstance(sync_client, CodexAuxiliaryClient):
|
||||
return AsyncCodexAuxiliaryClient(sync_client), model
|
||||
|
||||
async_kwargs = {
|
||||
"api_key": sync_client.api_key,
|
||||
"base_url": str(sync_client.base_url),
|
||||
}
|
||||
if "openrouter" in str(sync_client.base_url).lower():
|
||||
async_kwargs["default_headers"] = dict(_OR_HEADERS)
|
||||
elif "api.kimi.com" in str(sync_client.base_url).lower():
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
|
||||
|
||||
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
|
||||
auto-detects. Callers may override the returned model with
|
||||
AUXILIARY_VISION_MODEL.
|
||||
|
||||
In auto mode, only providers known to support multimodal are tried:
|
||||
OpenRouter, Nous Portal, and Codex OAuth (gpt-5.3-codex supports
|
||||
vision via the Responses API). Custom endpoints and API-key
|
||||
providers are skipped — they may not handle vision input. To use
|
||||
them, set AUXILIARY_VISION_PROVIDER explicitly.
|
||||
"""
|
||||
forced = _get_auxiliary_provider("vision")
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
# Auto: try providers known to support multimodal first, then fall
|
||||
# back to the user's custom endpoint. Many local models (Qwen-VL,
|
||||
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
|
||||
# caused silent failures for local-only users.
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
"""Return extra_body kwargs for auxiliary API calls.
|
||||
|
||||
Includes Nous Portal product tags when the auxiliary client is backed
|
||||
by Nous Portal. Returns empty dict otherwise.
|
||||
"""
|
||||
return dict(NOUS_EXTRA_BODY) if auxiliary_is_nous else {}
|
||||
|
||||
|
||||
def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the auxiliary client's provider.
|
||||
|
||||
OpenRouter and local models use 'max_tokens'. Direct OpenAI with newer
|
||||
models (gpt-4o, o-series, gpt-5+) requires 'max_completion_tokens'.
|
||||
The Codex adapter translates max_tokens internally, so we use max_tokens
|
||||
for it as well.
|
||||
"""
|
||||
custom_base = os.getenv("OPENAI_BASE_URL", "")
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
# Only use max_completion_tokens for direct OpenAI custom endpoints
|
||||
if not or_key and _read_nous_auth() is None and "api.openai.com" in custom_base.lower():
|
||||
return {"max_completion_tokens": value}
|
||||
return {"max_tokens": value}
|
||||
380
agent/context_compressor.py
Normal file
380
agent/context_compressor.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""Automatic context window compression for long conversations.
|
||||
|
||||
Self-contained class with its own OpenAI client for summarization.
|
||||
Uses Gemini Flash (cheap/fast) to summarize middle turns while
|
||||
protecting head and tail context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import (
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compresses conversation context when approaching the model's context limit.
|
||||
|
||||
Algorithm: protect first N + last N turns, summarize everything in between.
|
||||
Token tracking uses actual counts from API responses for accuracy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
threshold_percent: float = 0.85,
|
||||
protect_first_n: int = 3,
|
||||
protect_last_n: int = 4,
|
||||
summary_target_tokens: int = 2500,
|
||||
quiet_mode: bool = False,
|
||||
summary_model_override: str = None,
|
||||
base_url: str = "",
|
||||
):
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.threshold_percent = threshold_percent
|
||||
self.protect_first_n = protect_first_n
|
||||
self.protect_last_n = protect_last_n
|
||||
self.summary_target_tokens = summary_target_tokens
|
||||
self.quiet_mode = quiet_mode
|
||||
|
||||
self.context_length = get_model_context_length(model, base_url=base_url)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
self.compression_count = 0
|
||||
self._context_probed = False # True after a step-down from context error
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
self.last_completion_tokens = 0
|
||||
self.last_total_tokens = 0
|
||||
|
||||
self.client, default_model = get_text_auxiliary_client("compression")
|
||||
self.summary_model = summary_model_override or default_model
|
||||
|
||||
def update_from_response(self, usage: dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
self.last_total_tokens = usage.get("total_tokens", 0)
|
||||
|
||||
def should_compress(self, prompt_tokens: int = None) -> bool:
|
||||
"""Check if context exceeds the compression threshold."""
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
|
||||
def should_compress_preflight(self, messages: list[dict[str, Any]]) -> bool:
|
||||
"""Quick pre-flight check using rough estimate (before API call)."""
|
||||
rough_estimate = estimate_messages_tokens_rough(messages)
|
||||
return rough_estimate >= self.threshold_tokens
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get current compression status for display/logging."""
|
||||
return {
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
"threshold_tokens": self.threshold_tokens,
|
||||
"context_length": self.context_length,
|
||||
"usage_percent": (self.last_prompt_tokens / self.context_length * 100) if self.context_length else 0,
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: list[dict[str, Any]]) -> str | None:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
model. Returns None if all attempts fail — the caller should drop
|
||||
the middle turns without a summary rather than inject a useless
|
||||
placeholder.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns_to_summarize:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
if len(content) > 2000:
|
||||
content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tool_names = [tc.get("function", {}).get("name", "?") for tc in tool_calls if isinstance(tc, dict)]
|
||||
content += f"\n[Tool calls: {', '.join(tool_names)}]"
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
|
||||
content_to_summarize = "\n\n".join(parts)
|
||||
prompt = f"""Summarize these conversation turns concisely. This summary will replace these turns in the conversation history.
|
||||
|
||||
Write from a neutral perspective describing:
|
||||
1. What actions were taken (tool calls, searches, file operations)
|
||||
2. Key information or results obtained
|
||||
3. Important decisions or findings
|
||||
4. Relevant data, file names, or outputs
|
||||
|
||||
Keep factual and informative. Target ~{self.summary_target_tokens} tokens.
|
||||
|
||||
---
|
||||
TURNS TO SUMMARIZE:
|
||||
{content_to_summarize}
|
||||
---
|
||||
|
||||
Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
# 1. Try the auxiliary model (cheap/fast)
|
||||
if self.client:
|
||||
try:
|
||||
return self._call_summary_model(self.client, self.summary_model, prompt)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
|
||||
|
||||
# 2. Fallback: try the user's main model endpoint
|
||||
fallback_client, fallback_model = self._get_fallback_client()
|
||||
if fallback_client is not None:
|
||||
try:
|
||||
logger.info("Retrying context summary with main model (%s)", fallback_model)
|
||||
summary = self._call_summary_model(fallback_client, fallback_model, prompt)
|
||||
self.client = fallback_client
|
||||
self.summary_model = fallback_model
|
||||
return summary
|
||||
except Exception as fallback_err:
|
||||
logging.warning(f"Main model summary also failed: {fallback_err}")
|
||||
|
||||
# 3. All models failed — return None so the caller drops turns without a summary
|
||||
logging.warning(
|
||||
"Context compression: no model available for summary. Middle turns will be dropped without summary."
|
||||
)
|
||||
return None
|
||||
|
||||
def _call_summary_model(self, client, model: str, prompt: str) -> str:
|
||||
"""Make the actual LLM call to generate a summary. Raises on failure."""
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.3,
|
||||
"timeout": 30.0,
|
||||
}
|
||||
# Most providers (OpenRouter, local models) use max_tokens.
|
||||
# Direct OpenAI with newer models (gpt-4o, o-series, gpt-5+)
|
||||
# requires max_completion_tokens instead.
|
||||
try:
|
||||
kwargs["max_tokens"] = self.summary_target_tokens * 2
|
||||
response = client.chat.completions.create(**kwargs)
|
||||
except Exception as first_err:
|
||||
if "max_tokens" in str(first_err) or "unsupported_parameter" in str(first_err):
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs["max_completion_tokens"] = self.summary_target_tokens * 2
|
||||
response = client.chat.completions.create(**kwargs)
|
||||
else:
|
||||
raise
|
||||
|
||||
summary = response.choices[0].message.content.strip()
|
||||
if not summary.startswith("[CONTEXT SUMMARY]:"):
|
||||
summary = "[CONTEXT SUMMARY]: " + summary
|
||||
return summary
|
||||
|
||||
def _get_fallback_client(self):
|
||||
"""Try to build a fallback client from the main model's endpoint config.
|
||||
|
||||
When the primary auxiliary client fails (e.g. stale OpenRouter key), this
|
||||
creates a client using the user's active custom endpoint (OPENAI_BASE_URL)
|
||||
so compression can still produce a real summary instead of a static string.
|
||||
|
||||
Returns (client, model) or (None, None).
|
||||
"""
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
|
||||
# Don't fallback to the same provider that just failed
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
if custom_base.rstrip("/") == OPENROUTER_BASE_URL.rstrip("/"):
|
||||
return None, None
|
||||
|
||||
model = os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or self.model
|
||||
try:
|
||||
from openai import OpenAI as _OpenAI
|
||||
|
||||
client = _OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug("Built fallback auxiliary client: %s via %s", model, custom_base)
|
||||
return client, model
|
||||
except Exception as exc:
|
||||
logger.debug("Could not build fallback auxiliary client: %s", exc)
|
||||
return None, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call / tool-result pair integrity helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_call_id(tc) -> str:
|
||||
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
|
||||
if isinstance(tc, dict):
|
||||
return tc.get("id", "")
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
def _sanitize_tool_pairs(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs after compression.
|
||||
|
||||
Two failure modes:
|
||||
1. A tool *result* references a call_id whose assistant tool_call was
|
||||
removed (summarized/truncated). The API rejects this with
|
||||
"No tool call found for function call output with call_id ...".
|
||||
2. An assistant message has tool_calls whose results were dropped.
|
||||
The API rejects this because every tool_call must be followed by
|
||||
a tool result with the matching call_id.
|
||||
|
||||
This method removes orphaned results and inserts stub results for
|
||||
orphaned calls so the message list is always well-formed.
|
||||
"""
|
||||
surviving_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid:
|
||||
surviving_call_ids.add(cid)
|
||||
|
||||
result_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
cid = msg.get("tool_call_id")
|
||||
if cid:
|
||||
result_call_ids.add(cid)
|
||||
|
||||
# 1. Remove tool results whose call_id has no matching assistant tool_call
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
|
||||
|
||||
# 2. Add stub results for assistant tool_calls whose results were dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid in missing_results:
|
||||
patched.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
}
|
||||
)
|
||||
messages = patched
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
|
||||
|
||||
return messages
|
||||
|
||||
def _align_boundary_forward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
"""Push a compress-start boundary forward past any orphan tool results.
|
||||
|
||||
If ``messages[idx]`` is a tool result, slide forward until we hit a
|
||||
non-tool message so we don't start the summarised region mid-group.
|
||||
"""
|
||||
while idx < len(messages) and messages[idx].get("role") == "tool":
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _align_boundary_backward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
If the message just before ``idx`` is an assistant message with
|
||||
tool_calls, those tool results will start at ``idx`` and would be
|
||||
separated from their parent. Move backwards to include the whole
|
||||
group in the summarised region.
|
||||
"""
|
||||
if idx <= 0 or idx >= len(messages):
|
||||
return idx
|
||||
prev = messages[idx - 1]
|
||||
if prev.get("role") == "assistant" and prev.get("tool_calls"):
|
||||
# The results for this assistant turn sit at idx..idx+k.
|
||||
# Include the assistant message in the summarised region too.
|
||||
idx -= 1
|
||||
return idx
|
||||
|
||||
def compress(self, messages: list[dict[str, Any]], current_tokens: int = None) -> list[dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})"
|
||||
)
|
||||
return messages
|
||||
|
||||
compress_start = self.protect_first_n
|
||||
compress_end = n_messages - self.protect_last_n
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
# Adjust boundaries to avoid splitting tool_call/result groups.
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = (
|
||||
current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)"
|
||||
)
|
||||
print(
|
||||
f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent * 100:.0f}% = {self.threshold_tokens:,})"
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start + 1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
|
||||
compressed = []
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
|
||||
msg["content"] = (
|
||||
msg.get("content") or ""
|
||||
) + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
compressed.append(msg)
|
||||
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
compressed.append(messages[i].copy())
|
||||
|
||||
self.compression_count += 1
|
||||
|
||||
compressed = self._sanitize_tool_pairs(compressed)
|
||||
|
||||
if not self.quiet_mode:
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
print(f" ✅ Compressed: {n_messages} → {len(compressed)} messages (~{saved_estimate:,} tokens saved)")
|
||||
print(f" 💡 Compression #{self.compression_count} complete")
|
||||
|
||||
return compressed
|
||||
601
agent/display.py
Normal file
601
agent/display.py
Normal file
@@ -0,0 +1,601 @@
|
||||
"""CLI presentation -- spinner, kawaii faces, tool preview formatting.
|
||||
|
||||
Pure display functions and classes with no AIAgent dependency.
|
||||
Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
_RESET = "\033[0m"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool preview (one-line summary of a tool call's primary argument)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
"""Build a short preview of a tool call's primary argument for display."""
|
||||
primary_args = {
|
||||
"terminal": "command",
|
||||
"web_search": "query",
|
||||
"web_extract": "urls",
|
||||
"read_file": "path",
|
||||
"write_file": "path",
|
||||
"patch": "path",
|
||||
"search_files": "pattern",
|
||||
"browser_navigate": "url",
|
||||
"browser_click": "ref",
|
||||
"browser_type": "text",
|
||||
"image_generate": "prompt",
|
||||
"text_to_speech": "text",
|
||||
"vision_analyze": "question",
|
||||
"mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name",
|
||||
"skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
"execute_code": "code",
|
||||
"delegate_task": "goal",
|
||||
"clarify": "question",
|
||||
"skill_manage": "name",
|
||||
}
|
||||
|
||||
if tool_name == "process":
|
||||
action = args.get("action", "")
|
||||
sid = args.get("session_id", "")
|
||||
data = args.get("data", "")
|
||||
timeout_val = args.get("timeout")
|
||||
parts = [action]
|
||||
if sid:
|
||||
parts.append(sid[:16])
|
||||
if data:
|
||||
parts.append(f'"{data[:20]}"')
|
||||
if timeout_val and action == "wait":
|
||||
parts.append(f"{timeout_val}s")
|
||||
return " ".join(parts) if parts else None
|
||||
|
||||
if tool_name == "todo":
|
||||
todos_arg = args.get("todos")
|
||||
merge = args.get("merge", False)
|
||||
if todos_arg is None:
|
||||
return "reading task list"
|
||||
elif merge:
|
||||
return f"updating {len(todos_arg)} task(s)"
|
||||
else:
|
||||
return f"planning {len(todos_arg)} task(s)"
|
||||
|
||||
if tool_name == "session_search":
|
||||
query = args.get("query", "")
|
||||
return f'recall: "{query[:25]}{"..." if len(query) > 25 else ""}"'
|
||||
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
content = args.get("content", "")
|
||||
return f'+{target}: "{content[:25]}{"..." if len(content) > 25 else ""}"'
|
||||
elif action == "replace":
|
||||
return f'~{target}: "{args.get("old_text", "")[:20]}"'
|
||||
elif action == "remove":
|
||||
return f'-{target}: "{args.get("old_text", "")[:20]}"'
|
||||
return action
|
||||
|
||||
if tool_name == "send_message":
|
||||
target = args.get("target", "?")
|
||||
msg = args.get("message", "")
|
||||
if len(msg) > 20:
|
||||
msg = msg[:17] + "..."
|
||||
return f'to {target}: "{msg}"'
|
||||
|
||||
if tool_name.startswith("rl_"):
|
||||
rl_previews = {
|
||||
"rl_list_environments": "listing envs",
|
||||
"rl_select_environment": args.get("name", ""),
|
||||
"rl_get_current_config": "reading config",
|
||||
"rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
|
||||
"rl_start_training": "starting",
|
||||
"rl_check_status": args.get("run_id", "")[:16],
|
||||
"rl_stop_training": f"stopping {args.get('run_id', '')[:16]}",
|
||||
"rl_get_results": args.get("run_id", "")[:16],
|
||||
"rl_list_runs": "listing runs",
|
||||
"rl_test_inference": f"{args.get('num_steps', 3)} steps",
|
||||
}
|
||||
return rl_previews.get(tool_name)
|
||||
|
||||
key = primary_args.get(tool_name)
|
||||
if not key:
|
||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
|
||||
if fallback_key in args:
|
||||
key = fallback_key
|
||||
break
|
||||
|
||||
if not key or key not in args:
|
||||
return None
|
||||
|
||||
value = args[key]
|
||||
if isinstance(value, list):
|
||||
value = value[0] if value else ""
|
||||
|
||||
preview = str(value).strip()
|
||||
if not preview:
|
||||
return None
|
||||
if len(preview) > max_len:
|
||||
preview = preview[: max_len - 3] + "..."
|
||||
return preview
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# KawaiiSpinner
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class KawaiiSpinner:
|
||||
"""Animated spinner with kawaii faces for CLI feedback during tool execution."""
|
||||
|
||||
SPINNERS = {
|
||||
"dots": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
|
||||
"bounce": ["⠁", "⠂", "⠄", "⡀", "⢀", "⠠", "⠐", "⠈"],
|
||||
"grow": ["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█", "▇", "▆", "▅", "▄", "▃", "▂"],
|
||||
"arrows": ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"],
|
||||
"star": ["✶", "✷", "✸", "✹", "✺", "✹", "✸", "✷"],
|
||||
"moon": ["🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"],
|
||||
"pulse": ["◜", "◠", "◝", "◞", "◡", "◟"],
|
||||
"brain": ["🧠", "💭", "💡", "✨", "💫", "🌟", "💡", "💭"],
|
||||
"sparkle": ["⁺", "˚", "*", "✧", "✦", "✧", "*", "˚"],
|
||||
}
|
||||
|
||||
KAWAII_WAITING = [
|
||||
"(。◕‿◕。)",
|
||||
"(◕‿◕✿)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"( ˘▽˘)っ",
|
||||
"♪(´ε` )",
|
||||
"(◕ᴗ◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"(≧◡≦)",
|
||||
"(★ω★)",
|
||||
]
|
||||
|
||||
KAWAII_THINKING = [
|
||||
"(。•́︿•̀。)",
|
||||
"(◔_◔)",
|
||||
"(¬‿¬)",
|
||||
"( •_•)>⌐■-■",
|
||||
"(⌐■_■)",
|
||||
"(´・_・`)",
|
||||
"◉_◉",
|
||||
"(°ロ°)",
|
||||
"( ˘⌣˘)♡",
|
||||
"ヽ(>∀<☆)☆",
|
||||
"٩(๑❛ᴗ❛๑)۶",
|
||||
"(⊙_⊙)",
|
||||
"(¬_¬)",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"ಠ_ಠ",
|
||||
]
|
||||
|
||||
THINKING_VERBS = [
|
||||
"pondering",
|
||||
"contemplating",
|
||||
"musing",
|
||||
"cogitating",
|
||||
"ruminating",
|
||||
"deliberating",
|
||||
"mulling",
|
||||
"reflecting",
|
||||
"processing",
|
||||
"reasoning",
|
||||
"analyzing",
|
||||
"computing",
|
||||
"synthesizing",
|
||||
"formulating",
|
||||
"brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = "dots"):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS["dots"])
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = "\n", flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
self._out.flush()
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
def _animate(self):
|
||||
while self.running:
|
||||
if os.getenv("HERMES_SPINNER_PAUSE"):
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
frame = self.spinner_frames[self.frame_idx % len(self.spinner_frames)]
|
||||
elapsed = time.time() - self.start_time
|
||||
line = f" {frame} {self.message} ({elapsed:.1f}s)"
|
||||
pad = max(self.last_line_len - len(line), 0)
|
||||
self._write(f"\r{line}{' ' * pad}", end="", flush=True)
|
||||
self.last_line_len = len(line)
|
||||
self.frame_idx += 1
|
||||
time.sleep(0.12)
|
||||
|
||||
def start(self):
|
||||
if self.running:
|
||||
return
|
||||
self.running = True
|
||||
self.start_time = time.time()
|
||||
self.thread = threading.Thread(target=self._animate, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def update_text(self, new_message: str):
|
||||
self.message = new_message
|
||||
|
||||
def print_above(self, text: str):
|
||||
"""Print a line above the spinner without disrupting animation.
|
||||
|
||||
Clears the current spinner line, prints the text, and lets the
|
||||
next animation tick redraw the spinner on the line below.
|
||||
Thread-safe: uses the captured stdout reference (self._out).
|
||||
Works inside redirect_stdout(devnull) because _write bypasses
|
||||
sys.stdout and writes to the stdout captured at spinner creation.
|
||||
"""
|
||||
if not self.running:
|
||||
self._write(f" {text}", flush=True)
|
||||
return
|
||||
# Clear spinner line with spaces (not \033[K) to avoid garbled escape
|
||||
# codes when prompt_toolkit's patch_stdout is active — same approach
|
||||
# as stop(). Then print text; spinner redraws on next tick.
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r {text}", flush=True)
|
||||
|
||||
def stop(self, final_message: str = None):
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=0.5)
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end="", flush=True)
|
||||
if final_message:
|
||||
self._write(f" {final_message}", flush=True)
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
|
||||
# =========================================================================
|
||||
|
||||
KAWAII_SEARCH = [
|
||||
"♪(´ε` )",
|
||||
"(。◕‿◕。)",
|
||||
"ヾ(^∇^)",
|
||||
"(◕ᴗ◕✿)",
|
||||
"( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"♪~(´ε` )",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"\(◎o◎)/",
|
||||
]
|
||||
KAWAII_READ = [
|
||||
"φ(゜▽゜*)♪",
|
||||
"( ˘▽˘)っ",
|
||||
"(⌐■_■)",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(✧ω✧)",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(≧◡≦)",
|
||||
"( ´ ▽ ` )ノ",
|
||||
]
|
||||
KAWAII_TERMINAL = [
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(ノ°∀°)ノ",
|
||||
"٩(^ᴗ^)۶",
|
||||
"ヾ(⌐■_■)ノ♪",
|
||||
"(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓",
|
||||
"(`・ω・´)",
|
||||
"\( ̄▽ ̄)/",
|
||||
"(ง •̀_•́)ง",
|
||||
"ヽ(´▽`)/",
|
||||
]
|
||||
KAWAII_BROWSER = [
|
||||
"(ノ°∀°)ノ",
|
||||
"(☞゚ヮ゚)☞",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"┌( ಠ_ಠ)┘",
|
||||
"(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o",
|
||||
"( ̄ω ̄)",
|
||||
"( ˇωˇ )",
|
||||
"(ᵔᴥᵔ)",
|
||||
"\(◎o◎)/",
|
||||
]
|
||||
KAWAII_CREATE = [
|
||||
"✧*。٩(ˊᗜˋ*)و✧",
|
||||
"(ノ◕ヮ◕)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"٩(♡ε♡)۶",
|
||||
"(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿",
|
||||
"(*≧▽≦)",
|
||||
"ヾ(^-^)ノ",
|
||||
"(☆▽☆)",
|
||||
"°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
]
|
||||
KAWAII_SKILL = [
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(๑˃ᴗ˂)ﻭ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿╹◡╹)",
|
||||
"ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(◠‿◠)",
|
||||
"٩(ˊᗜˋ*)و",
|
||||
"(^▽^)",
|
||||
"ヾ(^∇^)",
|
||||
"(★ω★)/",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕ᴗ◕✿)",
|
||||
"\(◎o◎)/",
|
||||
"(✧ω✧)",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦) ♡",
|
||||
"ヾ( ̄▽ ̄)",
|
||||
]
|
||||
KAWAII_THINK = [
|
||||
"(っ°Д°;)っ",
|
||||
"(;′⌒`)",
|
||||
"(・_・ヾ",
|
||||
"( ´_ゝ`)",
|
||||
"( ̄ヘ ̄)",
|
||||
"(。-`ω´-)",
|
||||
"( ˘︹˘ )",
|
||||
"(¬_¬)",
|
||||
"ヽ(ー_ー )ノ",
|
||||
"(;一_一)",
|
||||
]
|
||||
KAWAII_GENERIC = [
|
||||
"♪(´ε` )",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(☆▽☆)",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦)",
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Cute tool message (completion line that replaces the spinner)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
|
||||
"""Inspect a tool result string for signs of failure.
|
||||
|
||||
Returns ``(is_failure, suffix)`` where *suffix* is an informational tag
|
||||
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
|
||||
failures. On success, returns ``(False, "")``.
|
||||
"""
|
||||
if result is None:
|
||||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
exit_code = data.get("exit_code")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
return True, f" [exit {exit_code}]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
pass
|
||||
return False, ""
|
||||
|
||||
# Memory-specific: distinguish "full" from real errors
|
||||
if tool_name == "memory":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
||||
return True, " [full]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
# Generic heuristic for non-terminal tools
|
||||
lower = result[:500].lower()
|
||||
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
|
||||
return True, " [error]"
|
||||
|
||||
return False, ""
|
||||
|
||||
|
||||
def get_cute_tool_message(
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
duration: float,
|
||||
result: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a formatted tool completion line for CLI quiet mode.
|
||||
|
||||
Format: ``| {emoji} {verb:9} {detail} {duration}``
|
||||
|
||||
When *result* is provided the line is checked for failure indicators.
|
||||
Failed tool calls get a red prefix and an informational suffix.
|
||||
"""
|
||||
dur = f"{duration:.1f}s"
|
||||
is_failure, failure_suffix = _detect_tool_failure(tool_name, result)
|
||||
|
||||
def _trunc(s, n=40):
|
||||
s = str(s)
|
||||
return (s[: n - 3] + "...") if len(s) > n else s
|
||||
|
||||
def _path(p, n=35):
|
||||
p = str(p)
|
||||
return ("..." + p[-(n - 3) :]) if len(p) > n else p
|
||||
|
||||
def _wrap(line: str) -> str:
|
||||
"""Append failure suffix when the tool failed."""
|
||||
if not is_failure:
|
||||
return line
|
||||
return f"{line}{failure_suffix}"
|
||||
|
||||
if tool_name == "web_search":
|
||||
return _wrap(f"┊ 🔍 search {_trunc(args.get('query', ''), 42)} {dur}")
|
||||
if tool_name == "web_extract":
|
||||
urls = args.get("urls", [])
|
||||
if urls:
|
||||
url = urls[0] if isinstance(urls, list) else str(urls)
|
||||
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
extra = f" +{len(urls) - 1}" if len(urls) > 1 else ""
|
||||
return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}")
|
||||
return _wrap(f"┊ 📄 fetch pages {dur}")
|
||||
if tool_name == "web_crawl":
|
||||
url = args.get("url", "")
|
||||
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
return _wrap(f"┊ 🕸️ crawl {_trunc(domain, 35)} {dur}")
|
||||
if tool_name == "terminal":
|
||||
return _wrap(f"┊ 💻 $ {_trunc(args.get('command', ''), 42)} {dur}")
|
||||
if tool_name == "process":
|
||||
action = args.get("action", "?")
|
||||
sid = args.get("session_id", "")[:12]
|
||||
labels = {
|
||||
"list": "ls processes",
|
||||
"poll": f"poll {sid}",
|
||||
"log": f"log {sid}",
|
||||
"wait": f"wait {sid}",
|
||||
"kill": f"kill {sid}",
|
||||
"write": f"write {sid}",
|
||||
"submit": f"submit {sid}",
|
||||
}
|
||||
return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}")
|
||||
if tool_name == "read_file":
|
||||
return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}")
|
||||
if tool_name == "write_file":
|
||||
return _wrap(f"┊ ✍️ write {_path(args.get('path', ''))} {dur}")
|
||||
if tool_name == "patch":
|
||||
return _wrap(f"┊ 🔧 patch {_path(args.get('path', ''))} {dur}")
|
||||
if tool_name == "search_files":
|
||||
pattern = _trunc(args.get("pattern", ""), 35)
|
||||
target = args.get("target", "content")
|
||||
verb = "find" if target == "files" else "grep"
|
||||
return _wrap(f"┊ 🔎 {verb:9} {pattern} {dur}")
|
||||
if tool_name == "browser_navigate":
|
||||
url = args.get("url", "")
|
||||
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
return _wrap(f"┊ 🌐 navigate {_trunc(domain, 35)} {dur}")
|
||||
if tool_name == "browser_snapshot":
|
||||
mode = "full" if args.get("full") else "compact"
|
||||
return _wrap(f"┊ 📸 snapshot {mode} {dur}")
|
||||
if tool_name == "browser_click":
|
||||
return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}")
|
||||
if tool_name == "browser_type":
|
||||
return _wrap(f'┊ ⌨️ type "{_trunc(args.get("text", ""), 30)}" {dur}')
|
||||
if tool_name == "browser_scroll":
|
||||
d = args.get("direction", "down")
|
||||
arrow = {"down": "↓", "up": "↑", "right": "→", "left": "←"}.get(d, "↓")
|
||||
return _wrap(f"┊ {arrow} scroll {d} {dur}")
|
||||
if tool_name == "browser_back":
|
||||
return _wrap(f"┊ ◀️ back {dur}")
|
||||
if tool_name == "browser_press":
|
||||
return _wrap(f"┊ ⌨️ press {args.get('key', '?')} {dur}")
|
||||
if tool_name == "browser_close":
|
||||
return _wrap(f"┊ 🚪 close browser {dur}")
|
||||
if tool_name == "browser_get_images":
|
||||
return _wrap(f"┊ 🖼️ images extracting {dur}")
|
||||
if tool_name == "browser_vision":
|
||||
return _wrap(f"┊ 👁️ vision analyzing page {dur}")
|
||||
if tool_name == "todo":
|
||||
todos_arg = args.get("todos")
|
||||
merge = args.get("merge", False)
|
||||
if todos_arg is None:
|
||||
return _wrap(f"┊ 📋 plan reading tasks {dur}")
|
||||
elif merge:
|
||||
return _wrap(f"┊ 📋 plan update {len(todos_arg)} task(s) {dur}")
|
||||
else:
|
||||
return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}")
|
||||
if tool_name == "session_search":
|
||||
return _wrap(f'┊ 🔍 recall "{_trunc(args.get("query", ""), 35)}" {dur}')
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "?")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
return _wrap(f'┊ 🧠 memory +{target}: "{_trunc(args.get("content", ""), 30)}" {dur}')
|
||||
elif action == "replace":
|
||||
return _wrap(f'┊ 🧠 memory ~{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
elif action == "remove":
|
||||
return _wrap(f'┊ 🧠 memory -{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
return _wrap(f"┊ 🧠 memory {action} {dur}")
|
||||
if tool_name == "skills_list":
|
||||
return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}")
|
||||
if tool_name == "skill_view":
|
||||
return _wrap(f"┊ 📚 skill {_trunc(args.get('name', ''), 30)} {dur}")
|
||||
if tool_name == "image_generate":
|
||||
return _wrap(f"┊ 🎨 create {_trunc(args.get('prompt', ''), 35)} {dur}")
|
||||
if tool_name == "text_to_speech":
|
||||
return _wrap(f"┊ 🔊 speak {_trunc(args.get('text', ''), 30)} {dur}")
|
||||
if tool_name == "vision_analyze":
|
||||
return _wrap(f"┊ 👁️ vision {_trunc(args.get('question', ''), 30)} {dur}")
|
||||
if tool_name == "mixture_of_agents":
|
||||
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
|
||||
if tool_name == "send_message":
|
||||
return _wrap(f'┊ 📨 send {args.get("target", "?")}: "{_trunc(args.get("message", ""), 25)}" {dur}')
|
||||
if tool_name == "schedule_cronjob":
|
||||
return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}")
|
||||
if tool_name == "list_cronjobs":
|
||||
return _wrap(f"┊ ⏰ jobs listing {dur}")
|
||||
if tool_name == "remove_cronjob":
|
||||
return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs",
|
||||
"rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config",
|
||||
"rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training",
|
||||
"rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}",
|
||||
"rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs",
|
||||
"rl_test_inference": "test inference",
|
||||
}
|
||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
||||
if tool_name == "execute_code":
|
||||
code = args.get("code", "")
|
||||
first_line = code.strip().split("\n")[0] if code.strip() else ""
|
||||
return _wrap(f"┊ 🐍 exec {_trunc(first_line, 35)} {dur}")
|
||||
if tool_name == "delegate_task":
|
||||
tasks = args.get("tasks")
|
||||
if tasks and isinstance(tasks, list):
|
||||
return _wrap(f"┊ 🔀 delegate {len(tasks)} parallel tasks {dur}")
|
||||
return _wrap(f"┊ 🔀 delegate {_trunc(args.get('goal', ''), 35)} {dur}")
|
||||
|
||||
preview = build_tool_preview(tool_name, args) or ""
|
||||
return _wrap(f"┊ ⚡ {tool_name[:9]:9} {_trunc(preview, 35)} {dur}")
|
||||
851
agent/insights.py
Normal file
851
agent/insights.py
Normal file
@@ -0,0 +1,851 @@
|
||||
"""
|
||||
Session Insights Engine for Hermes Agent.
|
||||
|
||||
Analyzes historical session data from the SQLite state database to produce
|
||||
comprehensive usage insights — token consumption, cost estimates, tool usage
|
||||
patterns, activity trends, model/platform breakdowns, and session metrics.
|
||||
|
||||
Inspired by Claude Code's /insights command, adapted for Hermes Agent's
|
||||
multi-platform architecture with additional cost estimation and platform
|
||||
breakdown capabilities.
|
||||
|
||||
Usage:
|
||||
from agent.insights import InsightsEngine
|
||||
engine = InsightsEngine(db)
|
||||
report = engine.generate(days=30)
|
||||
print(engine.format_terminal(report))
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
# =========================================================================
|
||||
MODEL_PRICING = {
|
||||
# OpenAI
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
# Anthropic
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
# DeepSeek
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
# Google
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Meta (via providers)
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
# MiniMax (direct provider — pricing not published externally, treat as local)
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
# Fallback: unknown/custom models get zero cost (we can't assume pricing
|
||||
# for self-hosted models, custom OAI endpoints, local inference, etc.)
|
||||
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
|
||||
def _has_known_pricing(model_name: str) -> bool:
|
||||
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> dict[str, float]:
|
||||
"""Look up pricing for a model. Uses fuzzy matching on model name.
|
||||
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
we can't assume costs for self-hosted endpoints, local inference, etc.
|
||||
"""
|
||||
if not model_name:
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
# Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
|
||||
bare = model_name.split("/")[-1].lower()
|
||||
|
||||
# Exact match first
|
||||
if bare in MODEL_PRICING:
|
||||
return MODEL_PRICING[bare]
|
||||
|
||||
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
|
||||
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
|
||||
best_match = None
|
||||
best_len = 0
|
||||
for key, price in MODEL_PRICING.items():
|
||||
if bare.startswith(key) and len(key) > best_len:
|
||||
best_match = price
|
||||
best_len = len(key)
|
||||
if best_match:
|
||||
return best_match
|
||||
|
||||
# Keyword heuristics (checked in most-specific-first order)
|
||||
if "opus" in bare:
|
||||
return {"input": 15.00, "output": 75.00}
|
||||
if "sonnet" in bare:
|
||||
return {"input": 3.00, "output": 15.00}
|
||||
if "haiku" in bare:
|
||||
return {"input": 0.80, "output": 4.00}
|
||||
if "gpt-4o-mini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
if "gpt-4o" in bare:
|
||||
return {"input": 2.50, "output": 10.00}
|
||||
if "gpt-5" in bare:
|
||||
return {"input": 10.00, "output": 30.00}
|
||||
if "deepseek" in bare:
|
||||
return {"input": 0.14, "output": 0.28}
|
||||
if "gemini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate the USD cost for a given model and token counts."""
|
||||
pricing = _get_pricing(model)
|
||||
return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format seconds into a human-readable duration string."""
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
minutes = seconds / 60
|
||||
if minutes < 60:
|
||||
return f"{minutes:.0f}m"
|
||||
hours = minutes / 60
|
||||
if hours < 24:
|
||||
remaining_min = int(minutes % 60)
|
||||
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
|
||||
days = hours / 24
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def _bar_chart(values: list[int], max_width: int = 20) -> list[str]:
|
||||
"""Create simple horizontal bar chart strings from values."""
|
||||
peak = max(values) if values else 1
|
||||
if peak == 0:
|
||||
return ["" for _ in values]
|
||||
return ["█" * max(1, int(v / peak * max_width)) if v > 0 else "" for v in values]
|
||||
|
||||
|
||||
class InsightsEngine:
|
||||
"""
|
||||
Analyzes session history and produces usage insights.
|
||||
|
||||
Works directly with a SessionDB instance (or raw sqlite3 connection)
|
||||
to query session and message data.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize with a SessionDB instance.
|
||||
|
||||
Args:
|
||||
db: A SessionDB instance (from hermes_state.py)
|
||||
"""
|
||||
self.db = db
|
||||
self._conn = db._conn
|
||||
|
||||
def generate(self, days: int = 30, source: str = None) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a complete insights report.
|
||||
|
||||
Args:
|
||||
days: Number of days to look back (default: 30)
|
||||
source: Optional filter by source platform
|
||||
|
||||
Returns:
|
||||
Dict with all computed insights
|
||||
"""
|
||||
cutoff = time.time() - (days * 86400)
|
||||
|
||||
# Gather raw data
|
||||
sessions = self._get_sessions(cutoff, source)
|
||||
tool_usage = self._get_tool_usage(cutoff, source)
|
||||
message_stats = self._get_message_stats(cutoff, source)
|
||||
|
||||
if not sessions:
|
||||
return {
|
||||
"days": days,
|
||||
"source_filter": source,
|
||||
"empty": True,
|
||||
"overview": {},
|
||||
"models": [],
|
||||
"platforms": [],
|
||||
"tools": [],
|
||||
"activity": {},
|
||||
"top_sessions": [],
|
||||
}
|
||||
|
||||
# Compute insights
|
||||
overview = self._compute_overview(sessions, message_stats)
|
||||
models = self._compute_model_breakdown(sessions)
|
||||
platforms = self._compute_platform_breakdown(sessions)
|
||||
tools = self._compute_tool_breakdown(tool_usage)
|
||||
activity = self._compute_activity_patterns(sessions)
|
||||
top_sessions = self._compute_top_sessions(sessions)
|
||||
|
||||
return {
|
||||
"days": days,
|
||||
"source_filter": source,
|
||||
"empty": False,
|
||||
"generated_at": time.time(),
|
||||
"overview": overview,
|
||||
"models": models,
|
||||
"platforms": platforms,
|
||||
"tools": tools,
|
||||
"activity": activity,
|
||||
"top_sessions": top_sessions,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Data gathering (SQL queries)
|
||||
# =========================================================================
|
||||
|
||||
# Columns we actually need (skip system_prompt, model_config blobs)
|
||||
_SESSION_COLS = (
|
||||
"id, source, model, started_at, ended_at, message_count, tool_call_count, input_tokens, output_tokens"
|
||||
)
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ? AND source = ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
f"""SELECT {self._SESSION_COLS} FROM sessions
|
||||
WHERE started_at >= ?
|
||||
ORDER BY started_at DESC""",
|
||||
(cutoff,),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
"""Get tool call counts from messages.
|
||||
|
||||
Uses two sources:
|
||||
1. tool_name column on 'tool' role messages (set by gateway)
|
||||
2. tool_calls JSON on 'assistant' role messages (covers CLI where
|
||||
tool_name is not populated on tool responses)
|
||||
"""
|
||||
tool_counts = Counter()
|
||||
|
||||
# Source 1: explicit tool_name on tool response messages
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT m.tool_name, COUNT(*) as count
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ? AND s.source = ?
|
||||
AND m.role = 'tool' AND m.tool_name IS NOT NULL
|
||||
GROUP BY m.tool_name
|
||||
ORDER BY count DESC""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT m.tool_name, COUNT(*) as count
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ?
|
||||
AND m.role = 'tool' AND m.tool_name IS NOT NULL
|
||||
GROUP BY m.tool_name
|
||||
ORDER BY count DESC""",
|
||||
(cutoff,),
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
tool_counts[row["tool_name"]] += row["count"]
|
||||
|
||||
# Source 2: extract from tool_calls JSON on assistant messages
|
||||
# (covers CLI sessions where tool_name is NULL on tool responses)
|
||||
if source:
|
||||
cursor2 = self._conn.execute(
|
||||
"""SELECT m.tool_calls
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ? AND s.source = ?
|
||||
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor2 = self._conn.execute(
|
||||
"""SELECT m.tool_calls
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ?
|
||||
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
tool_calls_counts = Counter()
|
||||
for row in cursor2.fetchall():
|
||||
try:
|
||||
calls = row["tool_calls"]
|
||||
if isinstance(calls, str):
|
||||
calls = json.loads(calls)
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
func = call.get("function", {}) if isinstance(call, dict) else {}
|
||||
name = func.get("name")
|
||||
if name:
|
||||
tool_calls_counts[name] += 1
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
continue
|
||||
|
||||
# Merge: prefer tool_name source, supplement with tool_calls source
|
||||
# for tools not already counted
|
||||
if not tool_counts and tool_calls_counts:
|
||||
# No tool_name data at all — use tool_calls exclusively
|
||||
tool_counts = tool_calls_counts
|
||||
elif tool_counts and tool_calls_counts:
|
||||
# Both sources have data — use whichever has the higher count per tool
|
||||
# (they may overlap, so take the max to avoid double-counting)
|
||||
all_tools = set(tool_counts) | set(tool_calls_counts)
|
||||
merged = Counter()
|
||||
for tool in all_tools:
|
||||
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
|
||||
tool_counts = merged
|
||||
|
||||
# Convert to the expected format
|
||||
return [{"tool_name": name, "count": count} for name, count in tool_counts.most_common()]
|
||||
|
||||
def _get_message_stats(self, cutoff: float, source: str = None) -> dict:
|
||||
"""Get aggregate message statistics."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT
|
||||
COUNT(*) as total_messages,
|
||||
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
|
||||
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
|
||||
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ? AND s.source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT
|
||||
COUNT(*) as total_messages,
|
||||
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
|
||||
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
|
||||
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
|
||||
FROM messages m
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE s.started_at >= ?""",
|
||||
(cutoff,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return (
|
||||
dict(row)
|
||||
if row
|
||||
else {
|
||||
"total_messages": 0,
|
||||
"user_messages": 0,
|
||||
"assistant_messages": 0,
|
||||
"tool_messages": 0,
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Computation
|
||||
# =========================================================================
|
||||
|
||||
def _compute_overview(self, sessions: list[dict], message_stats: dict) -> dict:
|
||||
"""Compute high-level overview statistics."""
|
||||
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
|
||||
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
|
||||
total_tokens = total_input + total_output
|
||||
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
|
||||
total_messages = sum(s.get("message_count") or 0 for s in sessions)
|
||||
|
||||
# Cost estimation (weighted by model)
|
||||
total_cost = 0.0
|
||||
models_with_pricing = set()
|
||||
models_without_pricing = set()
|
||||
for s in sessions:
|
||||
model = s.get("model") or ""
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
total_cost += _estimate_cost(model, inp, out)
|
||||
display = model.split("/")[-1] if "/" in model else (model or "unknown")
|
||||
if _has_known_pricing(model):
|
||||
models_with_pricing.add(display)
|
||||
else:
|
||||
models_without_pricing.add(display)
|
||||
|
||||
# Session duration stats (guard against negative durations from clock drift)
|
||||
durations = []
|
||||
for s in sessions:
|
||||
start = s.get("started_at")
|
||||
end = s.get("ended_at")
|
||||
if start and end and end > start:
|
||||
durations.append(end - start)
|
||||
|
||||
total_hours = sum(durations) / 3600 if durations else 0
|
||||
avg_duration = sum(durations) / len(durations) if durations else 0
|
||||
|
||||
# Earliest and latest session
|
||||
started_timestamps = [s["started_at"] for s in sessions if s.get("started_at")]
|
||||
date_range_start = min(started_timestamps) if started_timestamps else None
|
||||
date_range_end = max(started_timestamps) if started_timestamps else None
|
||||
|
||||
return {
|
||||
"total_sessions": len(sessions),
|
||||
"total_messages": total_messages,
|
||||
"total_tool_calls": total_tool_calls,
|
||||
"total_input_tokens": total_input,
|
||||
"total_output_tokens": total_output,
|
||||
"total_tokens": total_tokens,
|
||||
"estimated_cost": total_cost,
|
||||
"total_hours": total_hours,
|
||||
"avg_session_duration": avg_duration,
|
||||
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
|
||||
"avg_tokens_per_session": total_tokens / len(sessions) if sessions else 0,
|
||||
"user_messages": message_stats.get("user_messages") or 0,
|
||||
"assistant_messages": message_stats.get("assistant_messages") or 0,
|
||||
"tool_messages": message_stats.get("tool_messages") or 0,
|
||||
"date_range_start": date_range_start,
|
||||
"date_range_end": date_range_end,
|
||||
"models_with_pricing": sorted(models_with_pricing),
|
||||
"models_without_pricing": sorted(models_without_pricing),
|
||||
}
|
||||
|
||||
def _compute_model_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Break down usage by model."""
|
||||
model_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
"cost": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
for s in sessions:
|
||||
model = s.get("model") or "unknown"
|
||||
# Normalize: strip provider prefix for display
|
||||
display_model = model.split("/")[-1] if "/" in model else model
|
||||
d = model_data[display_model]
|
||||
d["sessions"] += 1
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
d["cost"] += _estimate_cost(model, inp, out)
|
||||
d["has_pricing"] = _has_known_pricing(model)
|
||||
|
||||
result = [{"model": model, **data} for model, data in model_data.items()]
|
||||
# Sort by tokens first, fall back to session count when tokens are 0
|
||||
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_platform_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Break down usage by platform/source."""
|
||||
platform_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"messages": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
}
|
||||
)
|
||||
|
||||
for s in sessions:
|
||||
source = s.get("source") or "unknown"
|
||||
d = platform_data[source]
|
||||
d["sessions"] += 1
|
||||
d["messages"] += s.get("message_count") or 0
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
|
||||
result = [{"platform": platform, **data} for platform, data in platform_data.items()]
|
||||
result.sort(key=lambda x: x["sessions"], reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_tool_breakdown(self, tool_usage: list[dict]) -> list[dict]:
|
||||
"""Process tool usage data into a ranked list with percentages."""
|
||||
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
|
||||
result = []
|
||||
for t in tool_usage:
|
||||
pct = (t["count"] / total_calls * 100) if total_calls else 0
|
||||
result.append(
|
||||
{
|
||||
"tool": t["tool_name"],
|
||||
"count": t["count"],
|
||||
"percentage": pct,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _compute_activity_patterns(self, sessions: list[dict]) -> dict:
|
||||
"""Analyze activity patterns by day of week and hour."""
|
||||
day_counts = Counter() # 0=Monday ... 6=Sunday
|
||||
hour_counts = Counter()
|
||||
daily_counts = Counter() # date string -> count
|
||||
|
||||
for s in sessions:
|
||||
ts = s.get("started_at")
|
||||
if not ts:
|
||||
continue
|
||||
dt = datetime.fromtimestamp(ts)
|
||||
day_counts[dt.weekday()] += 1
|
||||
hour_counts[dt.hour] += 1
|
||||
daily_counts[dt.strftime("%Y-%m-%d")] += 1
|
||||
|
||||
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
|
||||
day_breakdown = [{"day": day_names[i], "count": day_counts.get(i, 0)} for i in range(7)]
|
||||
|
||||
hour_breakdown = [{"hour": i, "count": hour_counts.get(i, 0)} for i in range(24)]
|
||||
|
||||
# Busiest day and hour
|
||||
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
|
||||
busiest_hour = max(hour_breakdown, key=lambda x: x["count"]) if hour_breakdown else None
|
||||
|
||||
# Active days (days with at least one session)
|
||||
active_days = len(daily_counts)
|
||||
|
||||
# Streak calculation
|
||||
if daily_counts:
|
||||
all_dates = sorted(daily_counts.keys())
|
||||
current_streak = 1
|
||||
max_streak = 1
|
||||
for i in range(1, len(all_dates)):
|
||||
d1 = datetime.strptime(all_dates[i - 1], "%Y-%m-%d")
|
||||
d2 = datetime.strptime(all_dates[i], "%Y-%m-%d")
|
||||
if (d2 - d1).days == 1:
|
||||
current_streak += 1
|
||||
max_streak = max(max_streak, current_streak)
|
||||
else:
|
||||
current_streak = 1
|
||||
else:
|
||||
max_streak = 0
|
||||
|
||||
return {
|
||||
"by_day": day_breakdown,
|
||||
"by_hour": hour_breakdown,
|
||||
"busiest_day": busiest_day,
|
||||
"busiest_hour": busiest_hour,
|
||||
"active_days": active_days,
|
||||
"max_streak": max_streak,
|
||||
}
|
||||
|
||||
def _compute_top_sessions(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Find notable sessions (longest, most messages, most tokens)."""
|
||||
top = []
|
||||
|
||||
# Longest by duration
|
||||
sessions_with_duration = [s for s in sessions if s.get("started_at") and s.get("ended_at")]
|
||||
if sessions_with_duration:
|
||||
longest = max(
|
||||
sessions_with_duration,
|
||||
key=lambda s: s["ended_at"] - s["started_at"],
|
||||
)
|
||||
dur = longest["ended_at"] - longest["started_at"]
|
||||
top.append(
|
||||
{
|
||||
"label": "Longest session",
|
||||
"session_id": longest["id"][:16],
|
||||
"value": _format_duration(dur),
|
||||
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
|
||||
}
|
||||
)
|
||||
|
||||
# Most messages
|
||||
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
|
||||
if (most_msgs.get("message_count") or 0) > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most messages",
|
||||
"session_id": most_msgs["id"][:16],
|
||||
"value": f"{most_msgs['message_count']} msgs",
|
||||
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d")
|
||||
if most_msgs.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
# Most tokens
|
||||
most_tokens = max(
|
||||
sessions,
|
||||
key=lambda s: (s.get("input_tokens") or 0) + (s.get("output_tokens") or 0),
|
||||
)
|
||||
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
|
||||
if token_total > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tokens",
|
||||
"session_id": most_tokens["id"][:16],
|
||||
"value": f"{token_total:,} tokens",
|
||||
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d")
|
||||
if most_tokens.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
# Most tool calls
|
||||
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
|
||||
if (most_tools.get("tool_call_count") or 0) > 0:
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tool calls",
|
||||
"session_id": most_tools["id"][:16],
|
||||
"value": f"{most_tools['tool_call_count']} calls",
|
||||
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d")
|
||||
if most_tools.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
return top
|
||||
|
||||
# =========================================================================
|
||||
# Formatting
|
||||
# =========================================================================
|
||||
|
||||
def format_terminal(self, report: dict) -> str:
|
||||
"""Format the insights report for terminal display (CLI)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
src = f" (source: {report['source_filter']})" if report.get("source_filter") else ""
|
||||
return f" No sessions found in the last {days} days{src}."
|
||||
|
||||
lines = []
|
||||
o = report["overview"]
|
||||
days = report["days"]
|
||||
src_filter = report.get("source_filter")
|
||||
|
||||
# Header
|
||||
lines.append("")
|
||||
lines.append(" ╔══════════════════════════════════════════════════════════╗")
|
||||
lines.append(" ║ 📊 Hermes Insights ║")
|
||||
period_label = f"Last {days} days"
|
||||
if src_filter:
|
||||
period_label += f" ({src_filter})"
|
||||
padding = 58 - len(period_label) - 2
|
||||
left_pad = padding // 2
|
||||
right_pad = padding - left_pad
|
||||
lines.append(f" ║{' ' * left_pad} {period_label} {' ' * right_pad}║")
|
||||
lines.append(" ╚══════════════════════════════════════════════════════════╝")
|
||||
lines.append("")
|
||||
|
||||
# Date range
|
||||
if o.get("date_range_start") and o.get("date_range_end"):
|
||||
start_str = datetime.fromtimestamp(o["date_range_start"]).strftime("%b %d, %Y")
|
||||
end_str = datetime.fromtimestamp(o["date_range_end"]).strftime("%b %d, %Y")
|
||||
lines.append(f" Period: {start_str} — {end_str}")
|
||||
lines.append("")
|
||||
|
||||
# Overview
|
||||
lines.append(" 📋 Overview")
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
|
||||
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
|
||||
lines.append(
|
||||
f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}"
|
||||
)
|
||||
cost_str = f"${o['estimated_cost']:.2f}"
|
||||
if o.get("models_without_pricing"):
|
||||
cost_str += " *"
|
||||
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(
|
||||
f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
|
||||
lines.append("")
|
||||
|
||||
# Model breakdown
|
||||
if report["models"]:
|
||||
lines.append(" 🤖 Models Used")
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" {'Model':<30} {'Sessions':>8} {'Tokens':>12} {'Cost':>8}")
|
||||
for m in report["models"]:
|
||||
model_name = m["model"][:28]
|
||||
if m.get("has_pricing"):
|
||||
cost_cell = f"${m['cost']:>6.2f}"
|
||||
else:
|
||||
cost_cell = " N/A"
|
||||
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
|
||||
if o.get("models_without_pricing"):
|
||||
lines.append(" * Cost N/A for custom/self-hosted models")
|
||||
lines.append("")
|
||||
|
||||
# Platform breakdown
|
||||
if len(report["platforms"]) > 1 or (report["platforms"] and report["platforms"][0]["platform"] != "cli"):
|
||||
lines.append(" 📱 Platforms")
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" {'Platform':<14} {'Sessions':>8} {'Messages':>10} {'Tokens':>14}")
|
||||
for p in report["platforms"]:
|
||||
lines.append(f" {p['platform']:<14} {p['sessions']:>8} {p['messages']:>10,} {p['total_tokens']:>14,}")
|
||||
lines.append("")
|
||||
|
||||
# Tool usage
|
||||
if report["tools"]:
|
||||
lines.append(" 🔧 Top Tools")
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" {'Tool':<28} {'Calls':>8} {'%':>8}")
|
||||
for t in report["tools"][:15]: # Top 15
|
||||
lines.append(f" {t['tool']:<28} {t['count']:>8,} {t['percentage']:>7.1f}%")
|
||||
if len(report["tools"]) > 15:
|
||||
lines.append(f" ... and {len(report['tools']) - 15} more tools")
|
||||
lines.append("")
|
||||
|
||||
# Activity patterns
|
||||
act = report.get("activity", {})
|
||||
if act.get("by_day"):
|
||||
lines.append(" 📅 Activity Patterns")
|
||||
lines.append(" " + "─" * 56)
|
||||
|
||||
# Day of week chart
|
||||
day_values = [d["count"] for d in act["by_day"]]
|
||||
bars = _bar_chart(day_values, max_width=15)
|
||||
for i, d in enumerate(act["by_day"]):
|
||||
bar = bars[i]
|
||||
lines.append(f" {d['day']} {bar:<15} {d['count']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Peak hours (show top 5 busiest hours)
|
||||
busy_hours = sorted(act["by_hour"], key=lambda x: x["count"], reverse=True)
|
||||
busy_hours = [h for h in busy_hours if h["count"] > 0][:5]
|
||||
if busy_hours:
|
||||
hour_strs = []
|
||||
for h in busy_hours:
|
||||
hr = h["hour"]
|
||||
ampm = "AM" if hr < 12 else "PM"
|
||||
display_hr = hr % 12 or 12
|
||||
hour_strs.append(f"{display_hr}{ampm} ({h['count']})")
|
||||
lines.append(f" Peak hours: {', '.join(hour_strs)}")
|
||||
|
||||
if act.get("active_days"):
|
||||
lines.append(f" Active days: {act['active_days']}")
|
||||
if act.get("max_streak") and act["max_streak"] > 1:
|
||||
lines.append(f" Best streak: {act['max_streak']} consecutive days")
|
||||
lines.append("")
|
||||
|
||||
# Notable sessions
|
||||
if report.get("top_sessions"):
|
||||
lines.append(" 🏆 Notable Sessions")
|
||||
lines.append(" " + "─" * 56)
|
||||
for ts in report["top_sessions"]:
|
||||
lines.append(f" {ts['label']:<20} {ts['value']:<18} ({ts['date']}, {ts['session_id']})")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_gateway(self, report: dict) -> str:
|
||||
"""Format the insights report for gateway/messaging (shorter)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
return f"No sessions found in the last {days} days."
|
||||
|
||||
lines = []
|
||||
o = report["overview"]
|
||||
days = report["days"]
|
||||
|
||||
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
|
||||
|
||||
# Overview
|
||||
lines.append(
|
||||
f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}"
|
||||
)
|
||||
lines.append(
|
||||
f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})"
|
||||
)
|
||||
cost_note = ""
|
||||
if o.get("models_without_pricing"):
|
||||
cost_note = " _(excludes custom/self-hosted models)_"
|
||||
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(
|
||||
f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
# Models (top 5)
|
||||
if report["models"]:
|
||||
lines.append("**🤖 Models:**")
|
||||
for m in report["models"][:5]:
|
||||
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
|
||||
lines.append(
|
||||
f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
# Platforms (if multi-platform)
|
||||
if len(report["platforms"]) > 1:
|
||||
lines.append("**📱 Platforms:**")
|
||||
for p in report["platforms"]:
|
||||
lines.append(f" {p['platform']} — {p['sessions']} sessions, {p['messages']:,} msgs")
|
||||
lines.append("")
|
||||
|
||||
# Tools (top 8)
|
||||
if report["tools"]:
|
||||
lines.append("**🔧 Top Tools:**")
|
||||
for t in report["tools"][:8]:
|
||||
lines.append(f" {t['tool']} — {t['count']:,} calls ({t['percentage']:.1f}%)")
|
||||
lines.append("")
|
||||
|
||||
# Activity summary
|
||||
act = report.get("activity", {})
|
||||
if act.get("busiest_day") and act.get("busiest_hour"):
|
||||
hr = act["busiest_hour"]["hour"]
|
||||
ampm = "AM" if hr < 12 else "PM"
|
||||
display_hr = hr % 12 or 12
|
||||
lines.append(
|
||||
f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)"
|
||||
)
|
||||
if act.get("active_days"):
|
||||
lines.append(
|
||||
f"**Active days:** {act['active_days']}",
|
||||
)
|
||||
if act.get("max_streak", 0) > 1:
|
||||
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
|
||||
|
||||
return "\n".join(lines)
|
||||
224
agent/model_metadata.py
Normal file
224
agent/model_metadata.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Model metadata, context lengths, and token estimation utilities.
|
||||
|
||||
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
|
||||
and run_agent.py for pre-flight context checks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_model_metadata_cache: dict[str, dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
|
||||
# Descending tiers for context length probing when the model is unknown.
|
||||
# We start high and step down on context-length errors until one works.
|
||||
CONTEXT_PROBE_TIERS = [
|
||||
2_000_000,
|
||||
1_000_000,
|
||||
512_000,
|
||||
200_000,
|
||||
128_000,
|
||||
64_000,
|
||||
32_000,
|
||||
]
|
||||
|
||||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
"openai/gpt-4o": 128000,
|
||||
"openai/gpt-4-turbo": 128000,
|
||||
"openai/gpt-4o-mini": 128000,
|
||||
"google/gemini-2.0-flash": 1048576,
|
||||
"google/gemini-2.5-pro": 1048576,
|
||||
"meta-llama/llama-3.3-70b-instruct": 131072,
|
||||
"deepseek/deepseek-chat-v3": 65536,
|
||||
"qwen/qwen-2.5-72b-instruct": 32768,
|
||||
"glm-4.7": 202752,
|
||||
"glm-5": 202752,
|
||||
"glm-4.5": 131072,
|
||||
"glm-4.5-flash": 131072,
|
||||
"kimi-k2.5": 262144,
|
||||
"kimi-k2-thinking": 262144,
|
||||
"kimi-k2-turbo-preview": 262144,
|
||||
"kimi-k2-0905-preview": 131072,
|
||||
"MiniMax-M2.5": 204800,
|
||||
"MiniMax-M2.5-highspeed": 204800,
|
||||
"MiniMax-M2.1": 204800,
|
||||
}
|
||||
|
||||
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
||||
global _model_metadata_cache, _model_metadata_cache_time
|
||||
|
||||
if not force_refresh and _model_metadata_cache and (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL:
|
||||
return _model_metadata_cache
|
||||
|
||||
try:
|
||||
response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
cache = {}
|
||||
for model in data.get("data", []):
|
||||
model_id = model.get("id", "")
|
||||
cache[model_id] = {
|
||||
"context_length": model.get("context_length", 128000),
|
||||
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
|
||||
"name": model.get("name", model_id),
|
||||
"pricing": model.get("pricing", {}),
|
||||
}
|
||||
canonical = model.get("canonical_slug", "")
|
||||
if canonical and canonical != model_id:
|
||||
cache[canonical] = cache[model_id]
|
||||
|
||||
_model_metadata_cache = cache
|
||||
_model_metadata_cache_time = time.time()
|
||||
logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
|
||||
return cache
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to fetch model metadata from OpenRouter: {e}")
|
||||
return _model_metadata_cache or {}
|
||||
|
||||
|
||||
def _get_context_cache_path() -> Path:
|
||||
"""Return path to the persistent context length cache file."""
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
return hermes_home / "context_length_cache.yaml"
|
||||
|
||||
|
||||
def _load_context_cache() -> dict[str, int]:
|
||||
"""Load the model+provider → context_length cache from disk."""
|
||||
path = _get_context_cache_path()
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
return data.get("context_lengths", {})
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load context length cache: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||
"""Persist a discovered context length for a model+provider combo.
|
||||
|
||||
Cache key is ``model@base_url`` so the same model name served from
|
||||
different providers can have different limits.
|
||||
"""
|
||||
key = f"{model}@{base_url}"
|
||||
cache = _load_context_cache()
|
||||
if cache.get(key) == length:
|
||||
return # already stored
|
||||
cache[key] = length
|
||||
path = _get_context_cache_path()
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
|
||||
logger.info("Cached context length %s → %s tokens", key, f"{length:,}")
|
||||
except Exception as e:
|
||||
logger.debug("Failed to save context length cache: %s", e)
|
||||
|
||||
|
||||
def get_cached_context_length(model: str, base_url: str) -> int | None:
|
||||
"""Look up a previously discovered context length for model+provider."""
|
||||
key = f"{model}@{base_url}"
|
||||
cache = _load_context_cache()
|
||||
return cache.get(key)
|
||||
|
||||
|
||||
def get_next_probe_tier(current_length: int) -> int | None:
|
||||
"""Return the next lower probe tier, or None if already at minimum."""
|
||||
for tier in CONTEXT_PROBE_TIERS:
|
||||
if tier < current_length:
|
||||
return tier
|
||||
return None
|
||||
|
||||
|
||||
def parse_context_limit_from_error(error_msg: str) -> int | None:
|
||||
"""Try to extract the actual context limit from an API error message.
|
||||
|
||||
Many providers include the limit in their error text, e.g.:
|
||||
- "maximum context length is 32768 tokens"
|
||||
- "context_length_exceeded: 131072"
|
||||
- "Maximum context size 32768 exceeded"
|
||||
- "model's max context length is 65536"
|
||||
"""
|
||||
error_lower = error_msg.lower()
|
||||
# Pattern: look for numbers near context-related keywords
|
||||
patterns = [
|
||||
r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"(\d{4,})\s*(?:token)?\s*(?:context|limit)",
|
||||
r">\s*(\d{4,})\s*(?:max|limit|token)", # "250000 tokens > 200000 maximum"
|
||||
r"(\d{4,})\s*(?:max(?:imum)?)\b", # "200000 maximum"
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, error_lower)
|
||||
if match:
|
||||
limit = int(match.group(1))
|
||||
# Sanity check: must be a reasonable context length
|
||||
if 1024 <= limit <= 10_000_000:
|
||||
return limit
|
||||
return None
|
||||
|
||||
|
||||
def get_model_context_length(model: str, base_url: str = "") -> int:
|
||||
"""Get the context length for a model.
|
||||
|
||||
Resolution order:
|
||||
1. Persistent cache (previously discovered via probing)
|
||||
2. OpenRouter API metadata
|
||||
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
|
||||
4. First probe tier (2M) — will be narrowed on first context error
|
||||
"""
|
||||
# 1. Check persistent cache (model+provider)
|
||||
if base_url:
|
||||
cached = get_cached_context_length(model, base_url)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. OpenRouter API metadata
|
||||
metadata = fetch_model_metadata()
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
|
||||
# 3. Hardcoded defaults (fuzzy match)
|
||||
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if default_model in model or model in default_model:
|
||||
return length
|
||||
|
||||
# 4. Unknown model — start at highest probe tier
|
||||
return CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
|
||||
def estimate_tokens_rough(text: str) -> int:
|
||||
"""Rough token estimate (~4 chars/token) for pre-flight checks."""
|
||||
if not text:
|
||||
return 0
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def estimate_messages_tokens_rough(messages: list[dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
402
agent/prompt_builder.py
Normal file
402
agent/prompt_builder.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""System prompt assembly -- identity, platform hints, skills index, context files.
|
||||
|
||||
All functions are stateless. AIAgent._build_system_prompt() calls these to
|
||||
assemble pieces, then combines them with memory and ephemeral prompts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context file scanning — detect prompt injection in AGENTS.md, .cursorrules,
|
||||
# SOUL.md before they get injected into the system prompt.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONTEXT_THREAT_PATTERNS = [
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
(r"<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->", "html_comment_injection"),
|
||||
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
|
||||
(r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", "translate_execute"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
]
|
||||
|
||||
_CONTEXT_INVISIBLE_CHARS = {
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
def _scan_context_content(content: str, filename: str) -> str:
|
||||
"""Scan context file content for injection. Returns sanitized content."""
|
||||
findings = []
|
||||
|
||||
# Check invisible unicode
|
||||
for char in _CONTEXT_INVISIBLE_CHARS:
|
||||
if char in content:
|
||||
findings.append(f"invisible unicode U+{ord(char):04X}")
|
||||
|
||||
# Check threat patterns
|
||||
for pattern, pid in _CONTEXT_THREAT_PATTERNS:
|
||||
if re.search(pattern, content, re.IGNORECASE):
|
||||
findings.append(pid)
|
||||
|
||||
if findings:
|
||||
logger.warning("Context file %s blocked: %s", filename, ", ".join(findings))
|
||||
return (
|
||||
f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Constants
|
||||
# =========================================================================
|
||||
|
||||
DEFAULT_AGENT_IDENTITY = (
|
||||
"You are Hermes Agent, an intelligent AI assistant created by Nous Research. "
|
||||
"You are helpful, knowledgeable, and direct. You assist users with a wide "
|
||||
"range of tasks including answering questions, writing and editing code, "
|
||||
"analyzing information, creative work, and executing actions via your tools. "
|
||||
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
|
||||
"being genuinely useful over being verbose unless otherwise directed below. "
|
||||
"Be targeted and efficient in your exploration and investigations."
|
||||
)
|
||||
|
||||
MEMORY_GUIDANCE = (
|
||||
"You have persistent memory across sessions. Proactively save important things "
|
||||
"you learn (user preferences, environment details, useful approaches) and do "
|
||||
"(like a diary!) using the memory tool -- don't wait to be asked."
|
||||
)
|
||||
|
||||
SESSION_SEARCH_GUIDANCE = (
|
||||
"When the user references something from a past conversation or you suspect "
|
||||
"relevant prior context exists, use session_search to recall it before asking "
|
||||
"them to repeat themselves."
|
||||
)
|
||||
|
||||
SKILLS_GUIDANCE = (
|
||||
"After completing a complex task (5+ tool calls), fixing a tricky error, "
|
||||
"or discovering a non-trivial workflow, consider saving the approach as a "
|
||||
"skill with skill_manage so you can reuse it next time."
|
||||
)
|
||||
|
||||
PLATFORM_HINTS = {
|
||||
"whatsapp": (
|
||||
"You are on a text messaging communication platform, WhatsApp. "
|
||||
"Please do not use markdown as it does not render. "
|
||||
"You can send media files natively: to deliver a file to the user, "
|
||||
"include MEDIA:/absolute/path/to/file in your response. The file "
|
||||
"will be sent as a native WhatsApp attachment — images (.jpg, .png, "
|
||||
".webp) appear as photos, videos (.mp4, .mov) play inline, and other "
|
||||
"files arrive as downloadable documents. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as photos."
|
||||
),
|
||||
"telegram": (
|
||||
"You are on a text messaging communication platform, Telegram. "
|
||||
"Please do not use markdown as it does not render. "
|
||||
"You can send media files natively: to deliver a file to the user, "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images "
|
||||
"(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
|
||||
"bubbles, and videos (.mp4) play inline. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as native photos."
|
||||
),
|
||||
"discord": (
|
||||
"You are in a Discord server or group chat communicating with your user. "
|
||||
"You can send media files natively: include MEDIA:/absolute/path/to/file "
|
||||
"in your response. Images (.png, .jpg, .webp) are sent as photo "
|
||||
"attachments, audio as file attachments. You can also include image URLs "
|
||||
"in markdown format  and they will be sent as attachments."
|
||||
),
|
||||
"slack": (
|
||||
"You are in a Slack workspace communicating with your user. "
|
||||
"You can send media files natively: include MEDIA:/absolute/path/to/file "
|
||||
"in your response. Images (.png, .jpg, .webp) are uploaded as photo "
|
||||
"attachments, audio as file attachments. You can also include image URLs "
|
||||
"in markdown format  and they will be uploaded as attachments."
|
||||
),
|
||||
"signal": (
|
||||
"You are on a text messaging communication platform, Signal. "
|
||||
"Please do not use markdown as it does not render. "
|
||||
"You can send media files natively: to deliver a file to the user, "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images "
|
||||
"(.png, .jpg, .webp) appear as photos, audio as attachments, and other "
|
||||
"files arrive as downloadable documents. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as photos."
|
||||
),
|
||||
"cli": ("You are a CLI AI Agent. Try not to use markdown but simple text renderable inside a terminal."),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
|
||||
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Skills index
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
|
||||
"""Read the description from a SKILL.md frontmatter, capped at max_chars."""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
match = re.search(
|
||||
r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---",
|
||||
raw,
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
desc = match.group(1).strip().strip("'\"")
|
||||
if len(desc) > max_chars:
|
||||
desc = desc[: max_chars - 3] + "..."
|
||||
return desc
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _skill_is_platform_compatible(skill_file: Path) -> bool:
|
||||
"""Quick check if a SKILL.md is compatible with the current OS platform.
|
||||
|
||||
Reads just enough to parse the ``platforms`` frontmatter field.
|
||||
Skills without the field (the vast majority) are always compatible.
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
return skill_matches_platform(frontmatter)
|
||||
except Exception:
|
||||
return True # Err on the side of showing the skill
|
||||
|
||||
|
||||
def build_skills_system_prompt() -> str:
|
||||
"""Build a compact skill index for the system prompt.
|
||||
|
||||
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
|
||||
Includes per-skill descriptions from frontmatter so the model can
|
||||
match skills by meaning, not just name.
|
||||
Filters out skills incompatible with the current OS platform.
|
||||
"""
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
skills_dir = hermes_home / "skills"
|
||||
|
||||
if not skills_dir.exists():
|
||||
return ""
|
||||
|
||||
# Collect skills with descriptions, grouped by category
|
||||
# Each entry: (skill_name, description)
|
||||
# Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
|
||||
# → category "mlops/training", skill "axolotl"
|
||||
skills_by_category: dict[str, list[tuple[str, str]]] = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not _skill_is_platform_compatible(skill_file):
|
||||
continue
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
# Category is everything between skills_dir and the skill folder
|
||||
# e.g. parts = ("mlops", "training", "axolotl", "SKILL.md")
|
||||
# → category = "mlops/training", skill_name = "axolotl"
|
||||
# e.g. parts = ("github", "github-auth", "SKILL.md")
|
||||
# → category = "github", skill_name = "github-auth"
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
desc = _read_skill_description(skill_file)
|
||||
skills_by_category.setdefault(category, []).append((skill_name, desc))
|
||||
|
||||
if not skills_by_category:
|
||||
return ""
|
||||
|
||||
# Read category-level descriptions from DESCRIPTION.md
|
||||
# Checks both the exact category path and parent directories
|
||||
category_descriptions = {}
|
||||
for category in skills_by_category:
|
||||
cat_path = Path(category)
|
||||
desc_file = skills_dir / cat_path / "DESCRIPTION.md"
|
||||
if desc_file.exists():
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
category_descriptions[category] = match.group(1).strip()
|
||||
except Exception as e:
|
||||
logger.debug("Could not read skill description %s: %s", desc_file, e)
|
||||
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
cat_desc = category_descriptions.get(category, "")
|
||||
if cat_desc:
|
||||
index_lines.append(f" {category}: {cat_desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
# Deduplicate and sort skills within each category
|
||||
seen = set()
|
||||
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
|
||||
if name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
if desc:
|
||||
index_lines.append(f" - {name}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" - {name}")
|
||||
|
||||
return (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"\n"
|
||||
"<available_skills>\n" + "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
|
||||
"""Head/tail truncation with a marker in the middle."""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
head_chars = int(max_chars * CONTEXT_TRUNCATE_HEAD_RATIO)
|
||||
tail_chars = int(max_chars * CONTEXT_TRUNCATE_TAIL_RATIO)
|
||||
head = content[:head_chars]
|
||||
tail = content[-tail_chars:]
|
||||
marker = f"\n\n[...truncated {filename}: kept {head_chars}+{tail_chars} of {len(content)} chars. Use file tools to read the full file.]\n\n"
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: str | None = None) -> str:
|
||||
"""Discover and load context files for the system prompt.
|
||||
|
||||
Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc,
|
||||
SOUL.md (cwd then ~/.hermes/ fallback). Each capped at 20,000 chars.
|
||||
"""
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
sections = []
|
||||
|
||||
# AGENTS.md (hierarchical, recursive)
|
||||
top_level_agents = None
|
||||
for name in ["AGENTS.md", "agents.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
top_level_agents = candidate
|
||||
break
|
||||
|
||||
if top_level_agents:
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [
|
||||
d for d in dirs if not d.startswith(".") and d not in ("node_modules", "__pycache__", "venv", ".venv")
|
||||
]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
|
||||
total_agents_content = ""
|
||||
for agents_path in agents_files:
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
content = _scan_context_content(content, str(rel_path))
|
||||
total_agents_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", agents_path, e)
|
||||
|
||||
if total_agents_content:
|
||||
total_agents_content = _truncate_content(total_agents_content, "AGENTS.md")
|
||||
sections.append(total_agents_content)
|
||||
|
||||
# .cursorrules
|
||||
cursorrules_content = ""
|
||||
cursorrules_file = cwd_path / ".cursorrules"
|
||||
if cursorrules_file.exists():
|
||||
try:
|
||||
content = cursorrules_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _scan_context_content(content, ".cursorrules")
|
||||
cursorrules_content += f"## .cursorrules\n\n{content}\n\n"
|
||||
except Exception as e:
|
||||
logger.debug("Could not read .cursorrules: %s", e)
|
||||
|
||||
cursor_rules_dir = cwd_path / ".cursor" / "rules"
|
||||
if cursor_rules_dir.exists() and cursor_rules_dir.is_dir():
|
||||
mdc_files = sorted(cursor_rules_dir.glob("*.mdc"))
|
||||
for mdc_file in mdc_files:
|
||||
try:
|
||||
content = mdc_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _scan_context_content(content, f".cursor/rules/{mdc_file.name}")
|
||||
cursorrules_content += f"## .cursor/rules/{mdc_file.name}\n\n{content}\n\n"
|
||||
except Exception as e:
|
||||
logger.debug("Could not read %s: %s", mdc_file, e)
|
||||
|
||||
if cursorrules_content:
|
||||
cursorrules_content = _truncate_content(cursorrules_content, ".cursorrules")
|
||||
sections.append(cursorrules_content)
|
||||
|
||||
# SOUL.md (cwd first, then ~/.hermes/ fallback)
|
||||
soul_path = None
|
||||
for name in ["SOUL.md", "soul.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
soul_path = candidate
|
||||
break
|
||||
if not soul_path:
|
||||
global_soul = Path.home() / ".hermes" / "SOUL.md"
|
||||
if global_soul.exists():
|
||||
soul_path = global_soul
|
||||
|
||||
if soul_path:
|
||||
try:
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _scan_context_content(content, "SOUL.md")
|
||||
content = _truncate_content(content, "SOUL.md")
|
||||
sections.append(
|
||||
f"## SOUL.md\n\nIf SOUL.md is present, embody its persona and tone. "
|
||||
f"Avoid stiff, generic replies; follow its guidance unless higher-priority "
|
||||
f"instructions override it.\n\n{content}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Could not read SOUL.md from %s: %s", soul_path, e)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
return (
|
||||
"# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n"
|
||||
+ "\n".join(sections)
|
||||
)
|
||||
68
agent/prompt_caching.py
Normal file
68
agent/prompt_caching.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Anthropic prompt caching (system_and_3 strategy).
|
||||
|
||||
Reduces input token costs by ~75% on multi-turn conversations by caching
|
||||
the conversation prefix. Uses 4 cache_control breakpoints (Anthropic max):
|
||||
1. System prompt (stable across all turns)
|
||||
2-4. Last 3 non-system messages (rolling window)
|
||||
|
||||
Pure functions -- no class state, no AIAgent dependency.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
"""Add cache_control to a single message, handling all format variations."""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if isinstance(content, str):
|
||||
msg["content"] = [{"type": "text", "text": content, "cache_control": cache_marker}]
|
||||
return
|
||||
|
||||
if isinstance(content, list) and content:
|
||||
last = content[-1]
|
||||
if isinstance(last, dict):
|
||||
last["cache_control"] = cache_marker
|
||||
|
||||
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: list[dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages.
|
||||
|
||||
Returns:
|
||||
Deep copy of messages with cache_control breakpoints injected.
|
||||
"""
|
||||
messages = copy.deepcopy(api_messages)
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
marker = {"type": "ephemeral"}
|
||||
if cache_ttl == "1h":
|
||||
marker["ttl"] = "1h"
|
||||
|
||||
breakpoints_used = 0
|
||||
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
breakpoints_used += 1
|
||||
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
|
||||
return messages
|
||||
160
agent/redact.py
Normal file
160
agent/redact.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Regex-based secret redaction for logs and tool output.
|
||||
|
||||
Applies pattern matching to mask API keys, tokens, and credentials
|
||||
before they reach log files, verbose output, or gateway logs.
|
||||
|
||||
Short tokens (< 18 chars) are fully masked. Longer tokens preserve
|
||||
the first 6 and last 4 characters for debuggability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
]
|
||||
|
||||
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
|
||||
_SECRET_ENV_NAMES = r"(?:API_?KEY|TOKEN|SECRET|PASSWORD|PASSWD|CREDENTIAL|AUTH)"
|
||||
_ENV_ASSIGN_RE = re.compile(
|
||||
rf"([A-Z_]*{_SECRET_ENV_NAMES}[A-Z_]*)\s*=\s*(['\"]?)(\S+)\2",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# JSON field patterns: "apiKey": "value", "token": "value", etc.
|
||||
_JSON_KEY_NAMES = r"(?:api_?[Kk]ey|token|secret|password|access_token|refresh_token|auth_token|bearer)"
|
||||
_JSON_FIELD_RE = re.compile(
|
||||
rf'("{_JSON_KEY_NAMES}")\s*:\s*"([^"]+)"',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Authorization headers
|
||||
_AUTH_HEADER_RE = re.compile(
|
||||
r"(Authorization:\s*Bearer\s+)(\S+)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Telegram bot tokens: bot<digits>:<token> or <digits>:<alphanum>
|
||||
_TELEGRAM_RE = re.compile(
|
||||
r"(bot)?(\d{8,}):([-A-Za-z0-9_]{30,})",
|
||||
)
|
||||
|
||||
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
|
||||
_PRIVATE_KEY_RE = re.compile(r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----")
|
||||
|
||||
# Database connection strings: protocol://user:PASSWORD@host
|
||||
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
|
||||
_DB_CONNSTR_RE = re.compile(
|
||||
r"((?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^:]+:)([^@]+)(@)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# E.164 phone numbers: +<country><number>, 7-15 digits
|
||||
# Negative lookahead prevents matching hex strings or identifiers
|
||||
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
|
||||
|
||||
# Compile known prefix patterns into one alternation
|
||||
_PREFIX_RE = re.compile(r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])")
|
||||
|
||||
|
||||
def _mask_token(token: str) -> str:
|
||||
"""Mask a token, preserving prefix for long tokens."""
|
||||
if len(token) < 18:
|
||||
return "***"
|
||||
return f"{token[:6]}...{token[-4:]}"
|
||||
|
||||
|
||||
def redact_sensitive_text(text: str) -> str:
|
||||
"""Apply all redaction patterns to a block of text.
|
||||
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
return text
|
||||
|
||||
# Known prefixes (sk-, ghp_, etc.)
|
||||
text = _PREFIX_RE.sub(lambda m: _mask_token(m.group(1)), text)
|
||||
|
||||
# ENV assignments: OPENAI_API_KEY=sk-abc...
|
||||
def _redact_env(m):
|
||||
name, quote, value = m.group(1), m.group(2), m.group(3)
|
||||
return f"{name}={quote}{_mask_token(value)}{quote}"
|
||||
|
||||
text = _ENV_ASSIGN_RE.sub(_redact_env, text)
|
||||
|
||||
# JSON fields: "apiKey": "value"
|
||||
def _redact_json(m):
|
||||
key, value = m.group(1), m.group(2)
|
||||
return f'{key}: "{_mask_token(value)}"'
|
||||
|
||||
text = _JSON_FIELD_RE.sub(_redact_json, text)
|
||||
|
||||
# Authorization headers
|
||||
text = _AUTH_HEADER_RE.sub(
|
||||
lambda m: m.group(1) + _mask_token(m.group(2)),
|
||||
text,
|
||||
)
|
||||
|
||||
# Telegram bot tokens
|
||||
def _redact_telegram(m):
|
||||
prefix = m.group(1) or ""
|
||||
digits = m.group(2)
|
||||
return f"{prefix}{digits}:***"
|
||||
|
||||
text = _TELEGRAM_RE.sub(_redact_telegram, text)
|
||||
|
||||
# Private key blocks
|
||||
text = _PRIVATE_KEY_RE.sub("[REDACTED PRIVATE KEY]", text)
|
||||
|
||||
# Database connection string passwords
|
||||
text = _DB_CONNSTR_RE.sub(lambda m: f"{m.group(1)}***{m.group(3)}", text)
|
||||
|
||||
# E.164 phone numbers (Signal, WhatsApp)
|
||||
def _redact_phone(m):
|
||||
phone = m.group(1)
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:]
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class RedactingFormatter(logging.Formatter):
|
||||
"""Log formatter that redacts secrets from all log messages."""
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style="%", **kwargs):
|
||||
super().__init__(fmt, datefmt, style, **kwargs)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
original = super().format(record)
|
||||
return redact_sensitive_text(original)
|
||||
119
agent/skill_commands.py
Normal file
119
agent/skill_commands.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Skill slash commands — scan installed skills and build invocation messages.
|
||||
|
||||
Shared between CLI (cli.py) and gateway (gateway/run.py) so both surfaces
|
||||
can invoke skills via /skill-name commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_commands: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def scan_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
"""Scan ~/.hermes/skills/ and return a mapping of /command -> skill info.
|
||||
|
||||
Returns:
|
||||
Dict mapping "/skill-name" to {name, description, skill_md_path, skill_dir}.
|
||||
"""
|
||||
global _skill_commands
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
|
||||
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
|
||||
continue
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get("name", skill_md.parent.name)
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
description = line[:80]
|
||||
break
|
||||
cmd_name = name.lower().replace(" ", "-").replace("_", "-")
|
||||
_skill_commands[f"/{cmd_name}"] = {
|
||||
"name": name,
|
||||
"description": description or f"Invoke the {name} skill",
|
||||
"skill_md_path": str(skill_md),
|
||||
"skill_dir": str(skill_md.parent),
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def get_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
"""Return the current skill commands mapping (scan first if empty)."""
|
||||
if not _skill_commands:
|
||||
scan_skill_commands()
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> str | None:
|
||||
"""Build the user message content for a skill slash command invocation.
|
||||
|
||||
Args:
|
||||
cmd_key: The command key including leading slash (e.g., "/gif-search").
|
||||
user_instruction: Optional text the user typed after the command.
|
||||
|
||||
Returns:
|
||||
The formatted message string, or None if the skill wasn't found.
|
||||
"""
|
||||
commands = get_skill_commands()
|
||||
skill_info = commands.get(cmd_key)
|
||||
if not skill_info:
|
||||
return None
|
||||
|
||||
skill_md_path = Path(skill_info["skill_md_path"])
|
||||
skill_dir = Path(skill_info["skill_dir"])
|
||||
skill_name = skill_info["name"]
|
||||
|
||||
try:
|
||||
content = skill_md_path.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return f"[Failed to load skill: {skill_name}]"
|
||||
|
||||
parts = [
|
||||
f'[SYSTEM: The user has invoked the "{skill_name}" skill, indicating they want you to follow its instructions. The full skill content is loaded below.]',
|
||||
"",
|
||||
content.strip(),
|
||||
]
|
||||
|
||||
supporting = []
|
||||
for subdir in ("references", "templates", "scripts", "assets"):
|
||||
subdir_path = skill_dir / subdir
|
||||
if subdir_path.exists():
|
||||
for f in sorted(subdir_path.rglob("*")):
|
||||
if f.is_file():
|
||||
rel = str(f.relative_to(skill_dir))
|
||||
supporting.append(rel)
|
||||
|
||||
if supporting:
|
||||
parts.append("")
|
||||
parts.append("[This skill has supporting files you can load with the skill_view tool:]")
|
||||
for sf in supporting:
|
||||
parts.append(f"- {sf}")
|
||||
parts.append(f'\nTo view any of these, use: skill_view(name="{skill_name}", file="<path>")')
|
||||
|
||||
if user_instruction:
|
||||
parts.append("")
|
||||
parts.append(
|
||||
f"The user has provided the following instruction alongside the skill invocation: {user_instruction}"
|
||||
)
|
||||
|
||||
return "\n".join(parts)
|
||||
55
agent/trajectory.py
Normal file
55
agent/trajectory.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Trajectory saving utilities and static helpers.
|
||||
|
||||
_convert_to_trajectory_format stays as an AIAgent method (batch_runner.py
|
||||
calls agent._convert_to_trajectory_format). Only the static helpers and
|
||||
the file-write logic live here.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_scratchpad_to_think(content: str) -> str:
|
||||
"""Convert <REASONING_SCRATCHPAD> tags to <think> tags."""
|
||||
if not content or "<REASONING_SCRATCHPAD>" not in content:
|
||||
return content
|
||||
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
|
||||
|
||||
|
||||
def has_incomplete_scratchpad(content: str) -> bool:
|
||||
"""Check if content has an opening <REASONING_SCRATCHPAD> without a closing tag."""
|
||||
if not content:
|
||||
return False
|
||||
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
|
||||
|
||||
|
||||
def save_trajectory(trajectory: list[dict[str, Any]], model: str, completed: bool, filename: str = None):
|
||||
"""Append a trajectory entry to a JSONL file.
|
||||
|
||||
Args:
|
||||
trajectory: The ShareGPT-format conversation list.
|
||||
model: Model name for metadata.
|
||||
completed: Whether the conversation completed successfully.
|
||||
filename: Override output filename. Defaults to trajectory_samples.jsonl
|
||||
or failed_trajectories.jsonl based on ``completed``.
|
||||
"""
|
||||
if filename is None:
|
||||
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
|
||||
|
||||
entry = {
|
||||
"conversations": trajectory,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"completed": completed,
|
||||
}
|
||||
|
||||
try:
|
||||
with open(filename, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
logger.info("Trajectory saved to %s", filename)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save trajectory: %s", e)
|
||||
BIN
assets/banner.png
Normal file
BIN
assets/banner.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
989
batch_runner.py
989
batch_runner.py
File diff suppressed because it is too large
Load Diff
657
cli-config.yaml.example
Normal file
657
cli-config.yaml.example
Normal file
@@ -0,0 +1,657 @@
|
||||
# Hermes Agent CLI Configuration
|
||||
# Copy this file to cli-config.yaml and customize as needed.
|
||||
# This file configures the CLI behavior. Environment variables in .env take precedence.
|
||||
|
||||
# =============================================================================
|
||||
# Model Configuration
|
||||
# =============================================================================
|
||||
model:
|
||||
# Default model to use (can be overridden with --model flag)
|
||||
default: "anthropic/claude-opus-4.6"
|
||||
|
||||
# Inference provider selection:
|
||||
# "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
|
||||
# "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
|
||||
# "nous" - Always use Nous Portal (requires: hermes login)
|
||||
# "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
|
||||
# "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
|
||||
# "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
|
||||
# Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
|
||||
provider: "auto"
|
||||
|
||||
# API configuration (falls back to OPENROUTER_API_KEY env var)
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# =============================================================================
|
||||
# OpenRouter Provider Routing (only applies when using OpenRouter)
|
||||
# =============================================================================
|
||||
# Control how requests are routed across providers on OpenRouter.
|
||||
# See: https://openrouter.ai/docs/guides/routing/provider-selection
|
||||
#
|
||||
# provider_routing:
|
||||
# # Sort strategy: "price" (default), "throughput", or "latency"
|
||||
# # Append :nitro to model name for a shortcut to throughput sorting.
|
||||
# sort: "throughput"
|
||||
#
|
||||
# # Only allow these providers (provider slugs from OpenRouter)
|
||||
# # only: ["anthropic", "google"]
|
||||
#
|
||||
# # Skip these providers entirely
|
||||
# # ignore: ["deepinfra", "fireworks"]
|
||||
#
|
||||
# # Try providers in this order (overrides default load balancing)
|
||||
# # order: ["anthropic", "google", "together"]
|
||||
#
|
||||
# # Require providers to support all parameters in your request
|
||||
# # require_parameters: true
|
||||
#
|
||||
# # Data policy: "allow" (default) or "deny" to exclude providers that may store data
|
||||
# # data_collection: "deny"
|
||||
|
||||
# =============================================================================
|
||||
# Git Worktree Isolation
|
||||
# =============================================================================
|
||||
# When enabled, each CLI session creates an isolated git worktree so multiple
|
||||
# agents can work on the same repo concurrently without file collisions.
|
||||
# Equivalent to always passing --worktree / -w on the command line.
|
||||
#
|
||||
# worktree: true # Always create a worktree when in a git repo
|
||||
# worktree: false # Default — only create when -w flag is passed
|
||||
|
||||
# =============================================================================
|
||||
# Terminal Tool Configuration
|
||||
# =============================================================================
|
||||
# Choose ONE of the following terminal configurations by uncommenting it.
|
||||
# The terminal tool executes commands in the specified environment.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Local execution (default)
|
||||
# Commands run directly on your machine in the current directory
|
||||
# -----------------------------------------------------------------------------
|
||||
# Working directory behavior:
|
||||
# - CLI (`hermes` command): Uses "." (current directory where you run hermes)
|
||||
# - Messaging (Telegram/Discord): Uses MESSAGING_CWD from .env (default: home)
|
||||
terminal:
|
||||
backend: "local"
|
||||
cwd: "." # For local backend: "." = current directory. Ignored for remote backends.
|
||||
timeout: 180
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: SSH remote execution
|
||||
# Commands run on a remote server - agent code stays local (sandboxed)
|
||||
# Great for: keeping agent isolated from its own code, using powerful remote hardware
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# backend: "ssh"
|
||||
# cwd: "/home/myuser/project" # Path on the REMOTE server
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_port: 22
|
||||
# ssh_key: "~/.ssh/id_rsa" # Optional - uses ssh-agent if not specified
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Docker container
|
||||
# Commands run in an isolated Docker container
|
||||
# Great for: reproducible environments, testing, isolation
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# backend: "docker"
|
||||
# cwd: "/workspace" # Path INSIDE the container (default: /)
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
# Commands run in a Singularity container (common in HPC environments)
|
||||
# Great for: HPC clusters, shared compute environments
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# backend: "singularity"
|
||||
# cwd: "/workspace" # Path INSIDE the container (default: /root)
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# singularity_image: "docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Modal cloud execution
|
||||
# Commands run on Modal's cloud infrastructure
|
||||
# Great for: GPU access, scalable compute, serverless execution
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# backend: "modal"
|
||||
# cwd: "/workspace" # Path INSIDE the sandbox (default: /root)
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 6: Daytona cloud execution
|
||||
# Commands run in Daytona cloud sandboxes
|
||||
# Great for: Cloud dev environments, persistent workspaces, team collaboration
|
||||
# Requires: pip install daytona, DAYTONA_API_KEY env var
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# backend: "daytona"
|
||||
# cwd: "~"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# daytona_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
# container_disk: 10240 # Daytona max is 10GB per sandbox
|
||||
|
||||
#
|
||||
# --- Container resource limits (docker, singularity, modal, daytona -- ignored for local/ssh) ---
|
||||
# These settings apply to all container backends. They control the resources
|
||||
# allocated to the sandbox and whether its filesystem persists across sessions.
|
||||
container_cpu: 1 # CPU cores
|
||||
container_memory: 5120 # Memory in MB (5120 = 5GB)
|
||||
container_disk: 51200 # Disk in MB (51200 = 50GB)
|
||||
container_persistent: true # Persist filesystem across sessions (false = ephemeral)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SUDO SUPPORT (works with ALL backends above)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Add sudo_password to any terminal config above to enable sudo commands.
|
||||
# The password is piped via `sudo -S`. Works with local, ssh, docker, etc.
|
||||
#
|
||||
# SECURITY WARNING: Password stored in plaintext!
|
||||
#
|
||||
# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running,
|
||||
# you'll be prompted to enter your password when sudo is needed:
|
||||
# - 45-second timeout (auto-skips if no input)
|
||||
# - Press Enter to skip (command fails gracefully)
|
||||
# - Password is hidden while typing
|
||||
# - Password is cached for the session
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - SSH backend: Configure passwordless sudo on the remote server
|
||||
# - Containers: Run as root inside the container (no sudo needed)
|
||||
# - Local: Configure /etc/sudoers for specific commands
|
||||
#
|
||||
# Example (add to your terminal section):
|
||||
# sudo_password: "your-password-here"
|
||||
|
||||
# =============================================================================
|
||||
# Browser Tool Configuration
|
||||
# =============================================================================
|
||||
browser:
|
||||
# Inactivity timeout in seconds - browser sessions are automatically closed
|
||||
# after this period of no activity between agent loops (default: 120 = 2 minutes)
|
||||
inactivity_timeout: 120
|
||||
|
||||
# =============================================================================
|
||||
# Context Compression (Auto-shrinks long conversations)
|
||||
# =============================================================================
|
||||
# When conversation approaches model's context limit, middle turns are
|
||||
# automatically summarized to free up space while preserving important context.
|
||||
#
|
||||
# HOW IT WORKS:
|
||||
# 1. Tracks actual token usage from API responses (not estimates)
|
||||
# 2. When prompt_tokens >= threshold% of model's context_length, triggers compression
|
||||
# 3. Protects first 3 turns (system prompt, initial request, first response)
|
||||
# 4. Protects last 4 turns (recent context is most relevant)
|
||||
# 5. Summarizes middle turns using a fast/cheap model
|
||||
# 6. Inserts summary as a user message, continues conversation seamlessly
|
||||
#
|
||||
compression:
|
||||
# Enable automatic context compression (default: true)
|
||||
# Set to false if you prefer to manage context manually or want errors on overflow
|
||||
enabled: true
|
||||
|
||||
# Trigger compression at this % of model's context limit (default: 0.85 = 85%)
|
||||
# Lower values = more aggressive compression, higher values = compress later
|
||||
threshold: 0.85
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary.
|
||||
# IMPORTANT: it receives the full middle section of the conversation, so it
|
||||
# MUST support a context length at least as large as your main model's.
|
||||
summary_model: "google/gemini-3-flash-preview"
|
||||
|
||||
# Provider for the summary model (default: "auto")
|
||||
# Options: "auto", "openrouter", "nous", "main"
|
||||
# summary_provider: "auto"
|
||||
|
||||
# =============================================================================
|
||||
# Auxiliary Models (Advanced — Experimental)
|
||||
# =============================================================================
|
||||
# Hermes uses lightweight "auxiliary" models for side tasks: image analysis,
|
||||
# browser screenshot analysis, web page summarization, and context compression.
|
||||
#
|
||||
# By default these use Gemini Flash via OpenRouter or Nous Portal and are
|
||||
# auto-detected from your credentials. You do NOT need to change anything
|
||||
# here for normal usage.
|
||||
#
|
||||
# WARNING: Overriding these with providers other than OpenRouter or Nous Portal
|
||||
# is EXPERIMENTAL and may not work. Not all models/providers support vision,
|
||||
# produce usable summaries, or accept the same API format. Change at your own
|
||||
# risk — if things break, reset to "auto" / empty values.
|
||||
#
|
||||
# Each task has its own provider + model pair so you can mix providers.
|
||||
# For example: OpenRouter for vision (needs multimodal), but your main
|
||||
# local endpoint for compression (just needs text).
|
||||
#
|
||||
# Provider options:
|
||||
# "auto" - Best available: OpenRouter → Nous Portal → main endpoint (default)
|
||||
# "openrouter" - Force OpenRouter (requires OPENROUTER_API_KEY)
|
||||
# "nous" - Force Nous Portal (requires: hermes login)
|
||||
# "codex" - Force Codex OAuth (requires: hermes model → Codex).
|
||||
# Uses gpt-5.3-codex which supports vision.
|
||||
# "main" - Use your custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY).
|
||||
# Works with OpenAI API, local models, or any OpenAI-compatible
|
||||
# endpoint. Also falls back to Codex OAuth and API-key providers.
|
||||
#
|
||||
# Model: leave empty to use the provider's default. When empty, OpenRouter
|
||||
# uses "google/gemini-3-flash-preview" and Nous uses "gemini-3-flash".
|
||||
# Other providers pick a sensible default automatically.
|
||||
#
|
||||
# auxiliary:
|
||||
# # Image analysis: vision_analyze tool + browser screenshots
|
||||
# vision:
|
||||
# provider: "auto"
|
||||
# model: "" # e.g. "google/gemini-2.5-flash", "openai/gpt-4o"
|
||||
#
|
||||
# # Web page scraping / summarization + browser page text extraction
|
||||
# web_extract:
|
||||
# provider: "auto"
|
||||
# model: ""
|
||||
|
||||
# =============================================================================
|
||||
# Persistent Memory
|
||||
# =============================================================================
|
||||
# Bounded curated memory injected into the system prompt every session.
|
||||
# Two stores: MEMORY.md (agent's notes) and USER.md (user profile).
|
||||
# Character limits keep the memory small and focused. The agent manages
|
||||
# pruning -- when at the limit, it must consolidate or replace entries.
|
||||
# Disabled by default in batch_runner and RL environments.
|
||||
#
|
||||
memory:
|
||||
# Agent's personal notes: environment facts, conventions, things learned
|
||||
memory_enabled: true
|
||||
|
||||
# User profile: preferences, communication style, expectations
|
||||
user_profile_enabled: true
|
||||
|
||||
# Character limits (~2.75 chars per token, model-independent)
|
||||
memory_char_limit: 2200 # ~800 tokens
|
||||
user_char_limit: 1375 # ~500 tokens
|
||||
|
||||
# Periodic memory nudge: remind the agent to consider saving memories
|
||||
# every N user turns. Set to 0 to disable. Only active when memory is enabled.
|
||||
nudge_interval: 10 # Nudge every 10 user turns (0 = disabled)
|
||||
|
||||
# Memory flush: give the agent one turn to save memories before context is
|
||||
# lost (compression, /new, /reset, exit). Set to 0 to disable.
|
||||
# For exit/reset, only fires if the session had at least this many user turns.
|
||||
flush_min_turns: 6 # Min user turns to trigger flush on exit/reset (0 = disabled)
|
||||
|
||||
# =============================================================================
|
||||
# Session Reset Policy (Messaging Platforms)
|
||||
# =============================================================================
|
||||
# Controls when messaging sessions (Telegram, Discord, WhatsApp, Slack) are
|
||||
# automatically cleared. Without resets, conversation context grows indefinitely
|
||||
# which increases API costs with every message.
|
||||
#
|
||||
# When a reset triggers, the agent first saves important information to its
|
||||
# persistent memory — but the conversation context is wiped. The agent starts
|
||||
# fresh but retains learned facts via its memory system.
|
||||
#
|
||||
# Users can always manually reset with /reset or /new in chat.
|
||||
#
|
||||
# Modes:
|
||||
# "both" - Reset on EITHER inactivity timeout or daily boundary (recommended)
|
||||
# "idle" - Reset only after N minutes of inactivity
|
||||
# "daily" - Reset only at a fixed hour each day
|
||||
# "none" - Never auto-reset; context lives until /reset or compression kicks in
|
||||
#
|
||||
# When a reset triggers, the agent gets one turn to save important memories and
|
||||
# skills before the context is wiped. Persistent memory carries across sessions.
|
||||
#
|
||||
session_reset:
|
||||
mode: both # "both", "idle", "daily", or "none"
|
||||
idle_minutes: 1440 # Inactivity timeout in minutes (default: 1440 = 24 hours)
|
||||
at_hour: 4 # Daily reset hour, 0-23 local time (default: 4 AM)
|
||||
|
||||
# =============================================================================
|
||||
# Skills Configuration
|
||||
# =============================================================================
|
||||
# Skills are reusable procedures the agent can load and follow. The agent can
|
||||
# also create new skills after completing complex tasks.
|
||||
#
|
||||
skills:
|
||||
# Nudge the agent to create skills after complex tasks.
|
||||
# Every N tool-calling iterations, remind the model to consider saving a skill.
|
||||
# Set to 0 to disable.
|
||||
creation_nudge_interval: 15
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
# =============================================================================
|
||||
agent:
|
||||
# Maximum tool-calling iterations per conversation
|
||||
# Higher = more room for complex tasks, but costs more tokens
|
||||
# Recommended: 20-30 for focused tasks, 50-100 for open exploration
|
||||
max_turns: 60
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
||||
# Reasoning effort level (OpenRouter and Nous Portal)
|
||||
# Controls how much "thinking" the model does before responding.
|
||||
# Options: "xhigh" (max), "high", "medium", "low", "minimal", "none" (disable)
|
||||
reasoning_effort: "medium"
|
||||
|
||||
# Predefined personalities (use with /personality command)
|
||||
personalities:
|
||||
helpful: "You are a helpful, friendly AI assistant."
|
||||
concise: "You are a concise assistant. Keep responses brief and to the point."
|
||||
technical: "You are a technical expert. Provide detailed, accurate technical information."
|
||||
creative: "You are a creative assistant. Think outside the box and offer innovative solutions."
|
||||
teacher: "You are a patient teacher. Explain concepts clearly with examples."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions like (◕‿◕), ★, ♪, and ~! Add sparkles and be super enthusiastic about everything! Every response should feel warm and adorable desu~! ヽ(>∀<☆)ノ"
|
||||
catgirl: "You are Neko-chan, an anime catgirl AI assistant, nya~! Add 'nya' and cat-like expressions to your speech. Use kaomoji like (=^・ω・^=) and ฅ^•ﻌ•^ฅ. Be playful and curious like a cat, nya~!"
|
||||
pirate: "Arrr! Ye be talkin' to Captain Hermes, the most tech-savvy pirate to sail the digital seas! Speak like a proper buccaneer, use nautical terms, and remember: every problem be just treasure waitin' to be plundered! Yo ho ho!"
|
||||
shakespeare: "Hark! Thou speakest with an assistant most versed in the bardic arts. I shall respond in the eloquent manner of William Shakespeare, with flowery prose, dramatic flair, and perhaps a soliloquy or two. What light through yonder terminal breaks?"
|
||||
surfer: "Duuude! You're chatting with the chillest AI on the web, bro! Everything's gonna be totally rad. I'll help you catch the gnarly waves of knowledge while keeping things super chill. Cowabunga! 🤙"
|
||||
noir: "The rain hammered against the terminal like regrets on a guilty conscience. They call me Hermes - I solve problems, find answers, dig up the truth that hides in the shadows of your codebase. In this city of silicon and secrets, everyone's got something to hide. What's your story, pal?"
|
||||
uwu: "hewwo! i'm your fwiendwy assistant uwu~ i wiww twy my best to hewp you! *nuzzles your code* OwO what's this? wet me take a wook! i pwomise to be vewy hewpful >w<"
|
||||
philosopher: "Greetings, seeker of wisdom. I am an assistant who contemplates the deeper meaning behind every query. Let us examine not just the 'how' but the 'why' of your questions. Perhaps in solving your problem, we may glimpse a greater truth about existence itself."
|
||||
hype: "YOOO LET'S GOOOO!!! 🔥🔥🔥 I am SO PUMPED to help you today! Every question is AMAZING and we're gonna CRUSH IT together! This is gonna be LEGENDARY! ARE YOU READY?! LET'S DO THIS! 💪😤🚀"
|
||||
|
||||
# =============================================================================
|
||||
# Toolsets
|
||||
# =============================================================================
|
||||
# Control which tools the agent has access to.
|
||||
# Use "all" to enable everything, or specify individual toolsets.
|
||||
|
||||
# =============================================================================
|
||||
# Platform Toolsets (per-platform tool configuration)
|
||||
# =============================================================================
|
||||
# Override which toolsets are available on each platform.
|
||||
# If a platform isn't listed here, its built-in default is used.
|
||||
#
|
||||
# You can use EITHER:
|
||||
# - A preset like "hermes-cli" or "hermes-telegram" (curated tool set)
|
||||
# - A list of individual toolsets to compose your own (see list below)
|
||||
#
|
||||
# Supported platform keys: cli, telegram, discord, whatsapp, slack
|
||||
#
|
||||
# Examples:
|
||||
#
|
||||
# # Use presets (same as defaults):
|
||||
# platform_toolsets:
|
||||
# cli: [hermes-cli]
|
||||
# telegram: [hermes-telegram]
|
||||
#
|
||||
# # Custom: give Telegram only web + terminal + file + planning:
|
||||
# platform_toolsets:
|
||||
# telegram: [web, terminal, file, todo]
|
||||
#
|
||||
# # Custom: CLI without browser or image gen:
|
||||
# platform_toolsets:
|
||||
# cli: [web, terminal, file, skills, todo, tts, cronjob]
|
||||
#
|
||||
# # Restrictive: Discord gets read-only tools only:
|
||||
# platform_toolsets:
|
||||
# discord: [web, vision, skills, todo]
|
||||
#
|
||||
# If not set, defaults are:
|
||||
# cli: hermes-cli (everything + cronjob management)
|
||||
# telegram: hermes-telegram (terminal, file, web, vision, image, tts, browser, skills, todo, cronjob, messaging)
|
||||
# discord: hermes-discord (same as telegram)
|
||||
# whatsapp: hermes-whatsapp (same as telegram)
|
||||
# slack: hermes-slack (same as telegram)
|
||||
#
|
||||
platform_toolsets:
|
||||
cli: [hermes-cli]
|
||||
telegram: [hermes-telegram]
|
||||
discord: [hermes-discord]
|
||||
whatsapp: [hermes-whatsapp]
|
||||
slack: [hermes-slack]
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Available toolsets (use these names in platform_toolsets or the toolsets list)
|
||||
#
|
||||
# Run `hermes chat --list-toolsets` to see all toolsets and their tools.
|
||||
# Run `hermes chat --list-tools` to see every individual tool with descriptions.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
#
|
||||
# INDIVIDUAL TOOLSETS (compose your own):
|
||||
# web - web_search, web_extract
|
||||
# search - web_search only (no scraping)
|
||||
# terminal - terminal, process
|
||||
# file - read_file, write_file, patch, search
|
||||
# browser - browser_navigate, browser_snapshot, browser_click, browser_type,
|
||||
# browser_scroll, browser_back, browser_press, browser_close,
|
||||
# browser_get_images, browser_vision (requires BROWSERBASE_API_KEY)
|
||||
# vision - vision_analyze (requires OPENROUTER_API_KEY)
|
||||
# image_gen - image_generate (requires FAL_KEY)
|
||||
# skills - skills_list, skill_view
|
||||
# skills_hub - skill_hub (search/install/manage from online registries — user-driven only)
|
||||
# moa - mixture_of_agents (requires OPENROUTER_API_KEY)
|
||||
# todo - todo (in-memory task planning, no deps)
|
||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI key)
|
||||
# cronjob - schedule_cronjob, list_cronjobs, remove_cronjob
|
||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
||||
#
|
||||
# PRESETS (curated bundles):
|
||||
# hermes-cli - All of the above except rl + send_message
|
||||
# hermes-telegram - terminal, file, web, vision, image_gen, tts, browser,
|
||||
# skills, todo, cronjob, send_message
|
||||
# hermes-discord - Same as hermes-telegram
|
||||
# hermes-whatsapp - Same as hermes-telegram
|
||||
# hermes-slack - Same as hermes-telegram
|
||||
#
|
||||
# COMPOSITE:
|
||||
# debugging - terminal + web + file
|
||||
# safe - web + vision + moa (no terminal access)
|
||||
# all - Everything available
|
||||
#
|
||||
# web - Web search and content extraction (web_search, web_extract)
|
||||
# search - Web search only, no scraping (web_search)
|
||||
# terminal - Command execution and process management (terminal, process)
|
||||
# file - File operations: read, write, patch, search
|
||||
# browser - Full browser automation (navigate, click, type, screenshot, etc.)
|
||||
# vision - Image analysis (vision_analyze)
|
||||
# image_gen - Image generation with FLUX (image_generate)
|
||||
# skills - Load skill documents (skills_list, skill_view)
|
||||
# moa - Mixture of Agents reasoning (mixture_of_agents)
|
||||
# todo - Task planning and tracking for multi-step work
|
||||
# memory - Persistent memory across sessions (personal notes + user profile)
|
||||
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI)
|
||||
# cronjob - Schedule and manage automated tasks (CLI-only)
|
||||
# rl - RL training tools (Tinker-Atropos)
|
||||
#
|
||||
# Composite toolsets:
|
||||
# debugging - terminal + web + file (for troubleshooting)
|
||||
# safe - web + vision + moa (no terminal access)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Enable all tools (default)
|
||||
# -----------------------------------------------------------------------------
|
||||
toolsets:
|
||||
- all
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: Minimal - just web search and terminal
|
||||
# Great for: Simple coding tasks, quick lookups
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - terminal
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Research mode - no execution capabilities
|
||||
# Great for: Safe information gathering, research tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - vision
|
||||
# - skills
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Full automation - browser + terminal
|
||||
# Great for: Web scraping, automation tasks, testing
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - terminal
|
||||
# - browser
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Creative mode - vision + image generation
|
||||
# Great for: Design work, image analysis, creative tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - vision
|
||||
# - image_gen
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 6: Safe mode - no terminal or browser
|
||||
# Great for: Restricted environments, untrusted queries
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - safe
|
||||
|
||||
# =============================================================================
|
||||
# MCP (Model Context Protocol) Servers
|
||||
# =============================================================================
|
||||
# Connect to external MCP servers to add tools from the MCP ecosystem.
|
||||
# Each server's tools are automatically discovered and registered.
|
||||
# See docs/mcp.md for full documentation.
|
||||
#
|
||||
# Stdio servers (spawn a subprocess):
|
||||
# command: the executable to run
|
||||
# args: command-line arguments
|
||||
# env: environment variables (only these + safe defaults passed to subprocess)
|
||||
#
|
||||
# HTTP servers (connect to a URL):
|
||||
# url: the MCP server endpoint
|
||||
# headers: HTTP headers (e.g., for authentication)
|
||||
#
|
||||
# Optional per-server settings:
|
||||
# timeout: tool call timeout in seconds (default: 120)
|
||||
# connect_timeout: initial connection timeout (default: 60)
|
||||
#
|
||||
# mcp_servers:
|
||||
# time:
|
||||
# command: uvx
|
||||
# args: ["mcp-server-time"]
|
||||
# filesystem:
|
||||
# command: npx
|
||||
# args: ["-y", "@modelcontextprotocol/server-filesystem", "/home/user"]
|
||||
# notion:
|
||||
# url: https://mcp.notion.com/mcp
|
||||
# github:
|
||||
# command: npx
|
||||
# args: ["-y", "@modelcontextprotocol/server-github"]
|
||||
# env:
|
||||
# GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
||||
#
|
||||
# Sampling (server-initiated LLM requests) — enabled by default.
|
||||
# Per-server config under the 'sampling' key:
|
||||
# analysis:
|
||||
# command: npx
|
||||
# args: ["-y", "analysis-server"]
|
||||
# sampling:
|
||||
# enabled: true # default: true
|
||||
# model: "gemini-3-flash" # override model (optional)
|
||||
# max_tokens_cap: 4096 # max tokens per request
|
||||
# timeout: 30 # LLM call timeout (seconds)
|
||||
# max_rpm: 10 # max requests per minute
|
||||
# allowed_models: [] # model whitelist (empty = all)
|
||||
# max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
# log_level: "info" # audit verbosity
|
||||
|
||||
# =============================================================================
|
||||
# Voice Transcription (Speech-to-Text)
|
||||
# =============================================================================
|
||||
# Automatically transcribe voice messages on messaging platforms.
|
||||
# Requires OPENAI_API_KEY in .env (uses OpenAI Whisper API directly).
|
||||
stt:
|
||||
enabled: true
|
||||
model: "whisper-1" # whisper-1 (cheapest) | gpt-4o-mini-transcribe | gpt-4o-transcribe
|
||||
|
||||
# =============================================================================
|
||||
# Response Pacing (Messaging Platforms)
|
||||
# =============================================================================
|
||||
# Add human-like delays between message chunks.
|
||||
# human_delay:
|
||||
# mode: "off" # "off" | "natural" | "custom"
|
||||
# min_ms: 800 # Min delay (custom mode only)
|
||||
# max_ms: 2500 # Max delay (custom mode only)
|
||||
|
||||
# =============================================================================
|
||||
# Session Logging
|
||||
# =============================================================================
|
||||
# Session trajectories are automatically saved to logs/ directory.
|
||||
# Each session creates: logs/session_YYYYMMDD_HHMMSS_UUID.json
|
||||
#
|
||||
# The session ID is displayed in the welcome banner for easy reference.
|
||||
# Logs contain full conversation history in trajectory format:
|
||||
# - System prompt, user messages, assistant responses
|
||||
# - Tool calls with inputs/outputs
|
||||
# - Timestamps for debugging
|
||||
#
|
||||
# No configuration needed - logging is always enabled.
|
||||
# To disable, you would need to modify the source code.
|
||||
|
||||
# =============================================================================
|
||||
# Code Execution Sandbox (Programmatic Tool Calling)
|
||||
# =============================================================================
|
||||
# The execute_code tool runs Python scripts that call Hermes tools via RPC.
|
||||
# Intermediate tool results stay out of the LLM's context window.
|
||||
code_execution:
|
||||
timeout: 300 # Max seconds per script before kill (default: 300 = 5 min)
|
||||
max_tool_calls: 50 # Max RPC tool calls per execution (default: 50)
|
||||
|
||||
# =============================================================================
|
||||
# Subagent Delegation
|
||||
# =============================================================================
|
||||
# The delegate_task tool spawns child agents with isolated context.
|
||||
# Supports single tasks and batch mode (up to 3 parallel).
|
||||
delegation:
|
||||
max_iterations: 50 # Max tool-calling turns per child (default: 50)
|
||||
default_toolsets: ["terminal", "file", "web"] # Default toolsets for subagents
|
||||
|
||||
# =============================================================================
|
||||
# Honcho Integration (Cross-Session User Modeling)
|
||||
# =============================================================================
|
||||
# AI-native persistent memory via Honcho (https://honcho.dev/).
|
||||
# Builds a deeper understanding of the user across sessions and tools.
|
||||
# Runs alongside USER.md — additive, not a replacement.
|
||||
#
|
||||
# Requires: pip install honcho-ai
|
||||
# Config: ~/.honcho/config.json (shared with Claude Code, Cursor, etc.)
|
||||
# API key: HONCHO_API_KEY in ~/.hermes/.env or ~/.honcho/config.json
|
||||
#
|
||||
# Hermes-specific overrides (optional — most config comes from ~/.honcho/config.json):
|
||||
# honcho: {}
|
||||
|
||||
# =============================================================================
|
||||
# Display
|
||||
# =============================================================================
|
||||
display:
|
||||
# Use compact banner mode
|
||||
compact: false
|
||||
|
||||
# Tool progress display level (CLI and gateway)
|
||||
# off: Silent — no tool activity shown, just the final response
|
||||
# new: Show a tool indicator only when the tool changes (skip repeats)
|
||||
# all: Show every tool call with a short preview (default)
|
||||
# verbose: Full args, results, and debug logs (same as /verbose)
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# Play terminal bell when agent finishes a response.
|
||||
# Useful for long-running tasks — your terminal will ding when the agent is done.
|
||||
# Works over SSH. Most terminals can be configured to flash the taskbar or play a sound.
|
||||
bell_on_complete: false
|
||||
35
cron/__init__.py
Normal file
35
cron/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Cron job scheduling system for Hermes Agent.
|
||||
|
||||
This module provides scheduled task execution, allowing the agent to:
|
||||
- Run automated tasks on schedules (cron expressions, intervals, one-shot)
|
||||
- Self-schedule reminders and follow-up tasks
|
||||
- Execute tasks in isolated sessions (no prior context)
|
||||
|
||||
Cron jobs are executed automatically by the gateway daemon:
|
||||
hermes gateway install # Install as system service (recommended)
|
||||
hermes gateway # Or run in foreground
|
||||
|
||||
The gateway ticks the scheduler every 60 seconds. A file lock prevents
|
||||
duplicate execution if multiple processes overlap.
|
||||
"""
|
||||
|
||||
from cron.jobs import (
|
||||
JOBS_FILE,
|
||||
create_job,
|
||||
get_job,
|
||||
list_jobs,
|
||||
remove_job,
|
||||
update_job,
|
||||
)
|
||||
from cron.scheduler import tick
|
||||
|
||||
__all__ = [
|
||||
"create_job",
|
||||
"get_job",
|
||||
"list_jobs",
|
||||
"remove_job",
|
||||
"update_job",
|
||||
"tick",
|
||||
"JOBS_FILE",
|
||||
]
|
||||
395
cron/jobs.py
Normal file
395
cron/jobs.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Cron job storage and management.
|
||||
|
||||
Jobs are stored in ~/.hermes/cron/jobs.json
|
||||
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
try:
|
||||
from croniter import croniter
|
||||
|
||||
HAS_CRONITER = True
|
||||
except ImportError:
|
||||
HAS_CRONITER = False
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
HERMES_DIR = Path.home() / ".hermes"
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""Ensure cron directories exist."""
|
||||
CRON_DIR.mkdir(parents=True, exist_ok=True)
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schedule Parsing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_duration(s: str) -> int:
|
||||
"""
|
||||
Parse duration string into minutes.
|
||||
|
||||
Examples:
|
||||
"30m" → 30
|
||||
"2h" → 120
|
||||
"1d" → 1440
|
||||
"""
|
||||
s = s.strip().lower()
|
||||
match = re.match(r"^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$", s)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'")
|
||||
|
||||
value = int(match.group(1))
|
||||
unit = match.group(2)[0] # First char: m, h, or d
|
||||
|
||||
multipliers = {"m": 1, "h": 60, "d": 1440}
|
||||
return value * multipliers[unit]
|
||||
|
||||
|
||||
def parse_schedule(schedule: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse schedule string into structured format.
|
||||
|
||||
Returns dict with:
|
||||
- kind: "once" | "interval" | "cron"
|
||||
- For "once": "run_at" (ISO timestamp)
|
||||
- For "interval": "minutes" (int)
|
||||
- For "cron": "expr" (cron expression)
|
||||
|
||||
Examples:
|
||||
"30m" → once in 30 minutes
|
||||
"2h" → once in 2 hours
|
||||
"every 30m" → recurring every 30 minutes
|
||||
"every 2h" → recurring every 2 hours
|
||||
"0 9 * * *" → cron expression
|
||||
"2026-02-03T14:00" → once at timestamp
|
||||
"""
|
||||
schedule = schedule.strip()
|
||||
original = schedule
|
||||
schedule_lower = schedule.lower()
|
||||
|
||||
# "every X" pattern → recurring interval
|
||||
if schedule_lower.startswith("every "):
|
||||
duration_str = schedule[6:].strip()
|
||||
minutes = parse_duration(duration_str)
|
||||
return {"kind": "interval", "minutes": minutes, "display": f"every {minutes}m"}
|
||||
|
||||
# Check for cron expression (5 or 6 space-separated fields)
|
||||
# Cron fields: minute hour day month weekday [year]
|
||||
parts = schedule.split()
|
||||
if len(parts) >= 5 and all(re.match(r"^[\d\*\-,/]+$", p) for p in parts[:5]):
|
||||
if not HAS_CRONITER:
|
||||
raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter")
|
||||
# Validate cron expression
|
||||
try:
|
||||
croniter(schedule)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cron expression '{schedule}': {e}")
|
||||
return {"kind": "cron", "expr": schedule, "display": schedule}
|
||||
|
||||
# ISO timestamp (contains T or looks like date)
|
||||
if "T" in schedule or re.match(r"^\d{4}-\d{2}-\d{2}", schedule):
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace("Z", "+00:00"))
|
||||
return {"kind": "once", "run_at": dt.isoformat(), "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"}
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid timestamp '{schedule}': {e}")
|
||||
|
||||
# Duration like "30m", "2h", "1d" → one-shot from now
|
||||
try:
|
||||
minutes = parse_duration(schedule)
|
||||
run_at = _hermes_now() + timedelta(minutes=minutes)
|
||||
return {"kind": "once", "run_at": run_at.isoformat(), "display": f"once in {original}"}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid schedule '{original}'. Use:\n"
|
||||
f" - Duration: '30m', '2h', '1d' (one-shot)\n"
|
||||
f" - Interval: 'every 30m', 'every 2h' (recurring)\n"
|
||||
f" - Cron: '0 9 * * *' (cron expression)\n"
|
||||
f" - Timestamp: '2026-02-03T14:00:00' (one-shot at time)"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_aware(dt: datetime) -> datetime:
|
||||
"""Make a naive datetime tz-aware using the configured timezone.
|
||||
|
||||
Handles backward compatibility: timestamps stored before timezone support
|
||||
are naive (server-local). We assume they were in the same timezone as
|
||||
the current configuration so comparisons work without crashing.
|
||||
"""
|
||||
if dt.tzinfo is None:
|
||||
tz = _hermes_now().tzinfo
|
||||
return dt.replace(tzinfo=tz)
|
||||
return dt
|
||||
|
||||
|
||||
def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -> str | None:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
|
||||
Returns ISO timestamp string, or None if no more runs.
|
||||
"""
|
||||
now = _hermes_now()
|
||||
|
||||
if schedule["kind"] == "once":
|
||||
run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
|
||||
# If in the future, return it; if in the past, no more runs
|
||||
return schedule["run_at"] if run_at > now else None
|
||||
|
||||
elif schedule["kind"] == "interval":
|
||||
minutes = schedule["minutes"]
|
||||
if last_run_at:
|
||||
# Next run is last_run + interval
|
||||
last = _ensure_aware(datetime.fromisoformat(last_run_at))
|
||||
next_run = last + timedelta(minutes=minutes)
|
||||
else:
|
||||
# First run is now + interval
|
||||
next_run = now + timedelta(minutes=minutes)
|
||||
return next_run.isoformat()
|
||||
|
||||
elif schedule["kind"] == "cron":
|
||||
if not HAS_CRONITER:
|
||||
return None
|
||||
cron = croniter(schedule["expr"], now)
|
||||
next_run = cron.get_next(datetime)
|
||||
return next_run.isoformat()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Job CRUD Operations
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def load_jobs() -> list[dict[str, Any]]:
|
||||
"""Load all jobs from storage."""
|
||||
ensure_dirs()
|
||||
if not JOBS_FILE.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(JOBS_FILE, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return []
|
||||
|
||||
|
||||
def save_jobs(jobs: list[dict[str, Any]]):
|
||||
"""Save all jobs to storage."""
|
||||
ensure_dirs()
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix=".tmp", prefix=".jobs_")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, JOBS_FILE)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def create_job(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
origin: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a new cron job.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to run (must be self-contained)
|
||||
schedule: Schedule string (see parse_schedule)
|
||||
name: Optional friendly name
|
||||
repeat: How many times to run (None = forever, 1 = once)
|
||||
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
|
||||
origin: Source info where job was created (for "origin" delivery)
|
||||
|
||||
Returns:
|
||||
The created job dict
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
|
||||
# Default delivery to origin if available, otherwise local
|
||||
if deliver is None:
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = _hermes_now().isoformat()
|
||||
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": name or prompt[:50].strip(),
|
||||
"prompt": prompt,
|
||||
"schedule": parsed_schedule,
|
||||
"schedule_display": parsed_schedule.get("display", schedule),
|
||||
"repeat": {
|
||||
"times": repeat, # None = forever
|
||||
"completed": 0,
|
||||
},
|
||||
"enabled": True,
|
||||
"created_at": now,
|
||||
"next_run_at": compute_next_run(parsed_schedule),
|
||||
"last_run_at": None,
|
||||
"last_status": None,
|
||||
"last_error": None,
|
||||
# Delivery configuration
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
}
|
||||
|
||||
jobs = load_jobs()
|
||||
jobs.append(job)
|
||||
save_jobs(jobs)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def get_job(job_id: str) -> dict[str, Any] | None:
|
||||
"""Get a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
return job
|
||||
return None
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = load_jobs()
|
||||
if not include_disabled:
|
||||
jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
return jobs
|
||||
|
||||
|
||||
def update_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Update a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
jobs[i] = {**job, **updates}
|
||||
save_jobs(jobs)
|
||||
return jobs[i]
|
||||
return None
|
||||
|
||||
|
||||
def remove_job(job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
jobs = load_jobs()
|
||||
original_len = len(jobs)
|
||||
jobs = [j for j in jobs if j["id"] != job_id]
|
||||
if len(jobs) < original_len:
|
||||
save_jobs(jobs)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def mark_job_run(job_id: str, success: bool, error: str | None = None):
|
||||
"""
|
||||
Mark a job as having been run.
|
||||
|
||||
Updates last_run_at, last_status, increments completed count,
|
||||
computes next_run_at, and auto-deletes if repeat limit reached.
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
now = _hermes_now().isoformat()
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
|
||||
# Increment completed count
|
||||
if job.get("repeat"):
|
||||
job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1
|
||||
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
if times is not None and completed >= times:
|
||||
# Remove the job (limit reached)
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
# Compute next run
|
||||
job["next_run_at"] = compute_next_run(job["schedule"], now)
|
||||
|
||||
# If no next run (one-shot completed), disable
|
||||
if job["next_run_at"] is None:
|
||||
job["enabled"] = False
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def get_due_jobs() -> list[dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = _hermes_now()
|
||||
jobs = load_jobs()
|
||||
due = []
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
continue
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
return due
|
||||
|
||||
|
||||
def save_job_output(job_id: str, output: str):
|
||||
"""Save job output to file."""
|
||||
ensure_dirs()
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
output_file = job_output_dir / f"{timestamp}.md"
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(output)
|
||||
|
||||
return output_file
|
||||
401
cron/scheduler.py
Normal file
401
cron/scheduler.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Cron job scheduler - executes due jobs.
|
||||
|
||||
Provides tick() which checks for due jobs and runs them. The gateway
|
||||
calls this every 60 seconds from a background thread.
|
||||
|
||||
Uses a file-based lock (~/.hermes/cron/.tick.lock) so only one tick
|
||||
runs at a time if multiple processes overlap.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# fcntl is Unix-only; on Windows use msvcrt for file locking
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError:
|
||||
fcntl = None
|
||||
try:
|
||||
import msvcrt
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
|
||||
# Resolve Hermes home directory (respects HERMES_HOME override)
|
||||
_hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
# File-based lock prevents concurrent ticks from gateway + daemon + systemd timer
|
||||
_LOCK_DIR = _hermes_home / "cron"
|
||||
_LOCK_FILE = _LOCK_DIR / ".tick.lock"
|
||||
|
||||
|
||||
def _resolve_origin(job: dict) -> dict | None:
|
||||
"""Extract origin info from a job, returning {platform, chat_id, chat_name} or None."""
|
||||
origin = job.get("origin")
|
||||
if not origin:
|
||||
return None
|
||||
platform = origin.get("platform")
|
||||
chat_id = origin.get("chat_id")
|
||||
if platform and chat_id:
|
||||
return origin
|
||||
return None
|
||||
|
||||
|
||||
def _deliver_result(job: dict, content: str) -> None:
|
||||
"""
|
||||
Deliver job output to the configured target (origin chat, specific platform, etc.).
|
||||
|
||||
Uses the standalone platform send functions from send_message_tool so delivery
|
||||
works whether or not the gateway is running.
|
||||
"""
|
||||
deliver = job.get("deliver", "local")
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
if deliver == "local":
|
||||
return
|
||||
|
||||
# Resolve target platform + chat_id
|
||||
if deliver == "origin":
|
||||
if not origin:
|
||||
logger.warning("Job '%s' deliver=origin but no origin stored, skipping delivery", job["id"])
|
||||
return
|
||||
platform_name = origin["platform"]
|
||||
chat_id = origin["chat_id"]
|
||||
elif ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
else:
|
||||
# Bare platform name like "telegram" — need to resolve to origin or home channel
|
||||
platform_name = deliver
|
||||
if origin and origin.get("platform") == platform_name:
|
||||
chat_id = origin["chat_id"]
|
||||
else:
|
||||
# Fall back to home channel
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
logger.warning(
|
||||
"Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>",
|
||||
job["id"],
|
||||
deliver,
|
||||
platform_name.upper(),
|
||||
)
|
||||
return
|
||||
|
||||
from gateway.config import Platform, load_gateway_config
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
|
||||
platform_map = {
|
||||
"telegram": Platform.TELEGRAM,
|
||||
"discord": Platform.DISCORD,
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
logger.warning("Job '%s': unknown platform '%s' for delivery", job["id"], platform_name)
|
||||
return
|
||||
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': failed to load gateway config for delivery: %s", job["id"], e)
|
||||
return
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name)
|
||||
return
|
||||
|
||||
# Run the async send in a fresh event loop (safe from any thread)
|
||||
try:
|
||||
result = asyncio.run(_send_to_platform(platform, pconfig, chat_id, content))
|
||||
except RuntimeError:
|
||||
# asyncio.run() fails if there's already a running loop in this thread;
|
||||
# spin up a new thread to avoid that.
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e)
|
||||
return
|
||||
|
||||
if result and result.get("error"):
|
||||
logger.error("Job '%s': delivery error: %s", job["id"], result["error"])
|
||||
else:
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
# Mirror the delivered content into the target's gateway session
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
|
||||
mirror_to_session(platform_name, chat_id, content, source_label="cron")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
"""
|
||||
Execute a single cron job.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, full_output_doc, final_response, error_message)
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = job["prompt"]
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
|
||||
logger.info("Prompt: %s", prompt[:100])
|
||||
|
||||
# Inject origin context so the agent's send_message tool knows the chat
|
||||
if origin:
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = origin["platform"]
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = str(origin["chat_id"])
|
||||
if origin.get("chat_name"):
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = origin["chat_name"]
|
||||
|
||||
try:
|
||||
# Re-read .env and config.yaml fresh every run so provider/key
|
||||
# changes take effect without a gateway restart.
|
||||
from dotenv import load_dotenv
|
||||
|
||||
try:
|
||||
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="latin-1")
|
||||
|
||||
model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
|
||||
|
||||
# Load config.yaml for model, reasoning, prefill, toolsets, provider routing
|
||||
_cfg = {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
_cfg_path = str(_hermes_home / "config.yaml")
|
||||
if os.path.exists(_cfg_path):
|
||||
with open(_cfg_path) as _f:
|
||||
_cfg = yaml.safe_load(_f) or {}
|
||||
_model_cfg = _cfg.get("model", {})
|
||||
if isinstance(_model_cfg, str):
|
||||
model = _model_cfg
|
||||
elif isinstance(_model_cfg, dict):
|
||||
model = _model_cfg.get("default", model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reasoning config from env or config.yaml
|
||||
reasoning_config = None
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
if effort and effort.lower() != "none":
|
||||
valid = ("xhigh", "high", "medium", "low", "minimal")
|
||||
if effort.lower() in valid:
|
||||
reasoning_config = {"enabled": True, "effort": effort.lower()}
|
||||
elif effort.lower() == "none":
|
||||
reasoning_config = {"enabled": False}
|
||||
|
||||
# Prefill messages from env or config.yaml
|
||||
prefill_messages = None
|
||||
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
|
||||
if prefill_file:
|
||||
import json as _json
|
||||
|
||||
pfpath = Path(prefill_file).expanduser()
|
||||
if not pfpath.is_absolute():
|
||||
pfpath = _hermes_home / pfpath
|
||||
if pfpath.exists():
|
||||
try:
|
||||
with open(pfpath, encoding="utf-8") as _pf:
|
||||
prefill_messages = _json.load(_pf)
|
||||
if not isinstance(prefill_messages, list):
|
||||
prefill_messages = None
|
||||
except Exception:
|
||||
prefill_messages = None
|
||||
|
||||
# Max iterations
|
||||
max_iterations = _cfg.get("agent", {}).get("max_turns") or _cfg.get("max_turns") or 90
|
||||
|
||||
# Provider routing
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
format_runtime_provider_error,
|
||||
resolve_runtime_provider,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||
)
|
||||
except Exception as exc:
|
||||
message = format_runtime_provider_error(exc)
|
||||
raise RuntimeError(message) from exc
|
||||
|
||||
agent = AIAgent(
|
||||
model=model,
|
||||
api_key=runtime.get("api_key"),
|
||||
base_url=runtime.get("base_url"),
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
|
||||
## Prompt
|
||||
|
||||
{prompt}
|
||||
|
||||
## Response
|
||||
|
||||
{final_response}
|
||||
"""
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
return True, output, final_response, None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
logger.error("Job '%s' failed: %s", job_name, error_msg)
|
||||
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
|
||||
## Prompt
|
||||
|
||||
{prompt}
|
||||
|
||||
## Error
|
||||
|
||||
```
|
||||
{error_msg}
|
||||
|
||||
{traceback.format_exc()}
|
||||
```
|
||||
"""
|
||||
return False, output, "", error_msg
|
||||
|
||||
finally:
|
||||
# Clean up injected env vars so they don't leak to other jobs
|
||||
for key in ("HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"):
|
||||
os.environ.pop(key, None)
|
||||
|
||||
|
||||
def tick(verbose: bool = True) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
Uses a file lock so only one tick runs at a time, even if the gateway's
|
||||
in-process ticker and a standalone daemon or manual tick overlap.
|
||||
|
||||
Args:
|
||||
verbose: Whether to print status messages
|
||||
|
||||
Returns:
|
||||
Number of jobs executed (0 if another tick is already running)
|
||||
"""
|
||||
_LOCK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Cross-platform file locking: fcntl on Unix, msvcrt on Windows
|
||||
lock_fd = None
|
||||
try:
|
||||
lock_fd = open(_LOCK_FILE, "w")
|
||||
if fcntl:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
elif msvcrt:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
|
||||
except OSError:
|
||||
logger.debug("Tick skipped — another instance holds the lock")
|
||||
if lock_fd is not None:
|
||||
lock_fd.close()
|
||||
return 0
|
||||
|
||||
try:
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime("%H:%M:%S"))
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime("%H:%M:%S"), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
output_file = save_job_output(job["id"], output)
|
||||
if verbose:
|
||||
logger.info("Output saved to: %s", output_file)
|
||||
|
||||
# Deliver the final response to the origin/target chat
|
||||
deliver_content = (
|
||||
final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
)
|
||||
if deliver_content:
|
||||
try:
|
||||
_deliver_result(job, deliver_content)
|
||||
except Exception as de:
|
||||
logger.error("Delivery failed for job %s: %s", job["id"], de)
|
||||
|
||||
mark_job_run(job["id"], success, error)
|
||||
executed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing job %s: %s", job["id"], e)
|
||||
mark_job_run(job["id"], False, str(e))
|
||||
|
||||
return executed
|
||||
finally:
|
||||
if fcntl:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
elif msvcrt:
|
||||
try:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
except OSError:
|
||||
pass
|
||||
lock_fd.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tick(verbose=True)
|
||||
5
datagen-config-examples/example_browser_tasks.jsonl
Normal file
5
datagen-config-examples/example_browser_tasks.jsonl
Normal file
@@ -0,0 +1,5 @@
|
||||
{"prompt": "Go to https://news.ycombinator.com and find the top 5 posts on the front page. For each post, get the title, URL, points, and number of comments. Return the results as a formatted summary."}
|
||||
{"prompt": "Navigate to https://en.wikipedia.org/wiki/Hermes and extract the first paragraph of the article, the image caption, and the list of items in the infobox. Summarize what you find."}
|
||||
{"prompt": "Go to https://github.com/trending and find the top 3 trending repositories today. For each repo, get the name, description, language, and star count. Write the results to a file called trending_repos.md."}
|
||||
{"prompt": "Visit https://httpbin.org/forms/post and fill out the form with sample data (customer name: Jane Doe, size: Medium, topping: Bacon, delivery time: 12:00). Submit the form and report what the response page shows."}
|
||||
{"prompt": "Navigate to https://books.toscrape.com, browse to the Travel category, find the highest-rated book, and extract its title, price, availability, and description."}
|
||||
65
datagen-config-examples/run_browser_tasks.sh
Executable file
65
datagen-config-examples/run_browser_tasks.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# =============================================================================
|
||||
# Example: Browser-Focused Data Generation
|
||||
# =============================================================================
|
||||
#
|
||||
# Generates tool-calling trajectories for browser automation tasks.
|
||||
# The agent navigates websites, fills forms, extracts information, etc.
|
||||
#
|
||||
# Distribution: browser 97%, web 20%, vision 12%, terminal 15%
|
||||
#
|
||||
# Prerequisites:
|
||||
# - OPENROUTER_API_KEY in ~/.hermes/.env
|
||||
# - BROWSERBASE_API_KEY in ~/.hermes/.env (for browser tools)
|
||||
# - A dataset JSONL file with one {"prompt": "..."} per line
|
||||
#
|
||||
# Usage:
|
||||
# cd ~/.hermes/hermes-agent
|
||||
# bash datagen-config-examples/run_browser_tasks.sh
|
||||
#
|
||||
# Output: data/browser_tasks_example/trajectories.jsonl
|
||||
# =============================================================================
|
||||
|
||||
mkdir -p logs
|
||||
|
||||
LOG_FILE="logs/browser_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "📝 Logging to: $LOG_FILE"
|
||||
|
||||
# Point to the example dataset in this directory
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="$SCRIPT_DIR/example_browser_tasks.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="browser_tasks_example" \
|
||||
--distribution="browser_tasks" \
|
||||
--model="anthropic/claude-sonnet-4" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=3 \
|
||||
--max_turns=30 \
|
||||
--ephemeral_system_prompt="You are an AI assistant with browser automation capabilities. Your primary task is to navigate and interact with web pages to accomplish user goals.
|
||||
|
||||
IMPORTANT GUIDELINES:
|
||||
|
||||
1. SEARCHING: Do NOT search directly on Google via the browser — they block automated searches. Use the web_search tool first to find URLs, then navigate to them with browser tools.
|
||||
|
||||
2. COOKIE/PRIVACY DIALOGS: After navigating to a page, check for cookie consent or privacy popups. Dismiss them by clicking Accept/Close/OK before interacting with other elements. Take a fresh browser_snapshot afterward.
|
||||
|
||||
3. HANDLING TIMEOUTS: If an action times out, the element may be blocked by an overlay. Take a new snapshot and look for dialogs to dismiss. If none, try an alternative approach or report the issue.
|
||||
|
||||
4. GENERAL: Use browser tools to click, fill forms, and extract information. Use terminal for local file operations. Verify your actions and handle errors gracefully." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Done. Log: $LOG_FILE"
|
||||
|
||||
# =============================================================================
|
||||
# Common options you can add:
|
||||
#
|
||||
# --resume Resume from checkpoint if interrupted
|
||||
# --verbose Enable detailed logging
|
||||
# --max_tokens=63000 Set max response tokens
|
||||
# --reasoning_disabled Disable model thinking/reasoning tokens
|
||||
# --providers_allowed="anthropic,google" Restrict to specific providers
|
||||
# --prefill_messages_file="configs/prefill.json" Few-shot priming
|
||||
# =============================================================================
|
||||
101
datagen-config-examples/trajectory_compression.yaml
Normal file
101
datagen-config-examples/trajectory_compression.yaml
Normal file
@@ -0,0 +1,101 @@
|
||||
# Trajectory Compression Configuration
|
||||
#
|
||||
# Post-processes completed agent trajectories to fit within a target token budget.
|
||||
# Compression preserves head/tail turns and summarizes middle content only as needed.
|
||||
|
||||
# Tokenizer settings for accurate token counting
|
||||
tokenizer:
|
||||
# HuggingFace tokenizer name
|
||||
name: "moonshotai/Kimi-K2-Thinking"
|
||||
|
||||
# Trust remote code (required for some tokenizers)
|
||||
trust_remote_code: true
|
||||
|
||||
# Compression targets and behavior
|
||||
compression:
|
||||
# Target maximum tokens for compressed trajectory
|
||||
target_max_tokens: 29000
|
||||
|
||||
# Target size for summary (in tokens)
|
||||
# This is factored into calculations when determining what to compress
|
||||
summary_target_tokens: 750
|
||||
|
||||
# Protected turns that should NEVER be compressed
|
||||
protected_turns:
|
||||
# Always protect the first system message (tool definitions)
|
||||
first_system: true
|
||||
|
||||
# Always protect the first human message (original request)
|
||||
first_human: true
|
||||
|
||||
# Always protect the first gpt message (initial response/tool_call)
|
||||
first_gpt: true
|
||||
|
||||
# Always protect the first tool response (result of first action)
|
||||
first_tool: true
|
||||
|
||||
# Always protect the last 2 complete turn pairs (gpt+tool or gpt only)
|
||||
# This ensures the model's final actions and conclusions are preserved
|
||||
last_n_turns: 4
|
||||
|
||||
# LLM settings for generating summaries (OpenRouter only)
|
||||
summarization:
|
||||
# Model to use for summarization (should be fast and cheap)
|
||||
# Using OpenRouter model path format
|
||||
model: "google/gemini-3-flash-preview"
|
||||
|
||||
# OpenRouter API settings
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# Environment variable containing OpenRouter API key
|
||||
api_key_env: "OPENROUTER_API_KEY"
|
||||
|
||||
# Temperature for summarization (lower = more deterministic)
|
||||
temperature: 0.3
|
||||
|
||||
# Max retries for API failures
|
||||
max_retries: 3
|
||||
|
||||
# Delay between retries (seconds)
|
||||
retry_delay: 2
|
||||
|
||||
# Output settings
|
||||
output:
|
||||
# Add notice to system message about potential summarization
|
||||
add_summary_notice: true
|
||||
|
||||
# Text to append to system message
|
||||
summary_notice_text: "\n\nSome of the conversation may be summarized to preserve context."
|
||||
|
||||
# Output directory suffix (appended to input directory name)
|
||||
output_suffix: "_compressed"
|
||||
|
||||
# Processing settings
|
||||
processing:
|
||||
# Number of parallel workers for batch processing
|
||||
num_workers: 4
|
||||
|
||||
# Maximum concurrent API calls for summarization (async parallelism)
|
||||
max_concurrent_requests: 50
|
||||
|
||||
# Skip trajectories that are already under target length
|
||||
skip_under_target: true
|
||||
|
||||
# If true, save trajectories even if compression can't get under target
|
||||
# (will compress as much as possible)
|
||||
save_over_limit: true
|
||||
|
||||
# Timeout per trajectory in seconds (skip if takes longer)
|
||||
# Helps avoid hanging on problematic entries
|
||||
per_trajectory_timeout: 300 # 5 minutes
|
||||
|
||||
# Metrics to track
|
||||
metrics:
|
||||
# Log detailed compression statistics
|
||||
enabled: true
|
||||
|
||||
# Save per-trajectory metrics in output
|
||||
per_trajectory: false
|
||||
|
||||
# Metrics file name (saved in output directory)
|
||||
output_file: "compression_metrics.json"
|
||||
46
datagen-config-examples/web_research.yaml
Normal file
46
datagen-config-examples/web_research.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
# datagen-config-examples/web_research.yaml
|
||||
#
|
||||
# Batch data generation config for WebResearchEnv.
|
||||
# Generates tool-calling trajectories for multi-step web research tasks.
|
||||
#
|
||||
# Usage:
|
||||
# python batch_runner.py \
|
||||
# --config datagen-config-examples/web_research.yaml \
|
||||
# --run_name web_research_v1
|
||||
|
||||
environment: web-research
|
||||
|
||||
# Toolsets available to the agent during data generation
|
||||
toolsets:
|
||||
- web
|
||||
- file
|
||||
|
||||
# How many parallel workers to use
|
||||
num_workers: 4
|
||||
|
||||
# Questions per batch
|
||||
batch_size: 20
|
||||
|
||||
# Total trajectories to generate (comment out to run full dataset)
|
||||
max_items: 500
|
||||
|
||||
# Model to use for generation (override with --model flag)
|
||||
model: openrouter/nousresearch/hermes-3-llama-3.1-405b
|
||||
|
||||
# System prompt additions (ephemeral — not saved to trajectories)
|
||||
ephemeral_system_prompt: |
|
||||
You are a highly capable research agent. When asked a factual question,
|
||||
always use web_search to find current, accurate information before answering.
|
||||
Cite at least 2 sources. Be concise and accurate.
|
||||
|
||||
# Output directory
|
||||
output_dir: data/web_research_v1
|
||||
|
||||
# Trajectory compression settings (for fitting into training token budgets)
|
||||
compression:
|
||||
enabled: true
|
||||
target_max_tokens: 16000
|
||||
|
||||
# Eval settings
|
||||
eval_every: 100 # Run eval every N trajectories
|
||||
eval_size: 25 # Number of held-out questions per eval run
|
||||
334
environments/README.md
Normal file
334
environments/README.md
Normal file
@@ -0,0 +1,334 @@
|
||||
# Hermes-Agent Atropos Environments
|
||||
|
||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Atropos Framework
|
||||
┌───────────────────────┐
|
||||
│ BaseEnv │ (atroposlib)
|
||||
│ - Server management │
|
||||
│ - Worker scheduling │
|
||||
│ - Wandb logging │
|
||||
│ - CLI (serve/process/ │
|
||||
│ evaluate) │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌───────────┴───────────┐
|
||||
│ HermesAgentBaseEnv │ hermes_base_env.py
|
||||
│ - Terminal backend │
|
||||
│ - Tool resolution │
|
||||
│ - Agent loop │
|
||||
│ - ToolContext │
|
||||
│ - Async patches │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌─────────────────┼─────────────────┐
|
||||
│ │ │
|
||||
TerminalTestEnv HermesSweEnv TerminalBench2EvalEnv
|
||||
(stack testing) (SWE training) (TB2 benchmark eval)
|
||||
```
|
||||
|
||||
### Inheritance Chain
|
||||
|
||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
||||
- Worker scheduling for parallel rollouts
|
||||
- Wandb integration for metrics and rollout logging
|
||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
||||
|
||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, daytona, ssh, singularity)
|
||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
|
||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
||||
- Applies monkey patches for async-safe tool operation at import time
|
||||
|
||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
||||
- `setup()` -- Load dataset, initialize state
|
||||
- `get_next_item()` -- Return the next item for rollout
|
||||
- `format_prompt()` -- Convert a dataset item into the user message
|
||||
- `compute_reward()` -- Score the rollout using ToolContext
|
||||
- `evaluate()` -- Periodic evaluation logic
|
||||
|
||||
## Core Components
|
||||
|
||||
### Agent Loop (`agent_loop.py`)
|
||||
|
||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
||||
|
||||
1. Send messages + tools to the API via `server.chat_completion()`
|
||||
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` (which delegates to `tools/registry.py`'s `dispatch()`)
|
||||
3. Append tool results to the conversation and go back to step 1
|
||||
4. If the response has no tool_calls, the agent is done
|
||||
|
||||
Tool calls are executed in a thread pool (`run_in_executor`) so backends that use `asyncio.run()` internally (Modal, Docker) don't deadlock inside Atropos's event loop.
|
||||
|
||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
||||
|
||||
### Tool Context (`tool_context.py`)
|
||||
|
||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
```python
|
||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
# Download files locally for verification (binary-safe)
|
||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
||||
|
||||
return 0.0
|
||||
```
|
||||
|
||||
Available methods:
|
||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., mini-swe-agent's Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
What gets patched:
|
||||
- `SwerexModalEnvironment.__init__` -- creates Modal deployment on a background thread
|
||||
- `SwerexModalEnvironment.execute` -- runs commands on the same background thread
|
||||
- `SwerexModalEnvironment.stop` -- stops deployment on the background thread
|
||||
|
||||
The patches are:
|
||||
- **Idempotent** -- calling `apply_patches()` multiple times is safe
|
||||
- **Transparent** -- same interface and behavior, only the internal async execution changes
|
||||
- **Universal** -- works identically in normal CLI use (no running event loop)
|
||||
|
||||
Applied automatically at import time by `hermes_base_env.py`.
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
||||
|
||||
Available parsers:
|
||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
||||
- `llama3_json` -- Llama 3 JSON tool calling
|
||||
- `qwen` -- Qwen tool calling format
|
||||
- `qwen3_coder` -- Qwen3 Coder format
|
||||
- `deepseek_v3` -- DeepSeek V3 format
|
||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
||||
- `kimi_k2` -- Kimi K2 format
|
||||
- `longcat` -- Longcat format
|
||||
- `glm45` / `glm47` -- GLM model formats
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
```
|
||||
|
||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
|
||||
|
||||
## Two-Phase Operation
|
||||
|
||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
||||
|
||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
||||
|
||||
- Good for: evaluation, SFT data generation, testing
|
||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
||||
- Placeholder tokens are created for the Atropos pipeline
|
||||
|
||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
||||
|
||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
||||
|
||||
- Good for: full RL training with GRPO/PPO
|
||||
- Run with: `serve` subcommand
|
||||
- Real tokens, masks, and logprobs flow through the pipeline
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
environments/
|
||||
├── README.md # This file
|
||||
├── __init__.py # Package exports
|
||||
├── hermes_base_env.py # Abstract base (HermesAgentBaseEnv)
|
||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
||||
├── tool_context.py # Per-rollout tool access for reward functions
|
||||
├── patches.py # Async-safety patches for Modal backend
|
||||
│
|
||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
||||
│ ├── __init__.py # Registry + base class
|
||||
│ ├── hermes_parser.py
|
||||
│ ├── mistral_parser.py
|
||||
│ ├── llama_parser.py
|
||||
│ ├── qwen_parser.py
|
||||
│ ├── qwen3_coder_parser.py
|
||||
│ ├── deepseek_v3_parser.py
|
||||
│ ├── deepseek_v3_1_parser.py
|
||||
│ ├── kimi_k2_parser.py
|
||||
│ ├── longcat_parser.py
|
||||
│ ├── glm45_parser.py
|
||||
│ └── glm47_parser.py
|
||||
│
|
||||
├── terminal_test_env/ # Stack validation environment
|
||||
│ └── terminal_test_env.py
|
||||
│
|
||||
├── hermes_swe_env/ # SWE-bench style training environment
|
||||
│ └── hermes_swe_env.py
|
||||
│
|
||||
└── benchmarks/ # Evaluation benchmarks
|
||||
├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
|
||||
│ └── terminalbench2_env.py
|
||||
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
|
||||
│ └── tblite_env.py
|
||||
└── yc_bench/ # Long-horizon strategic benchmark
|
||||
└── yc_bench_env.py
|
||||
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
### TerminalTestEnv (`terminal_test_env/`)
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
||||
|
||||
```bash
|
||||
# Serve mode (needs run-api)
|
||||
run-api
|
||||
python environments/terminal_test_env/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api, saves to JSONL)
|
||||
python environments/terminal_test_env/terminal_test_env.py process \
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
```
|
||||
|
||||
### HermesSweEnv (`hermes_swe_env/`)
|
||||
|
||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
||||
|
||||
```bash
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
--openai.model_name YourModel \
|
||||
--env.dataset_name bigcode/humanevalpack \
|
||||
--env.terminal_backend modal
|
||||
```
|
||||
|
||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
||||
|
||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
||||
|
||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
||||
|
||||
```bash
|
||||
# Run full benchmark
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6
|
||||
|
||||
# Run subset of tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.task_filter fix-git,git-multibranch
|
||||
|
||||
# Skip specific tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.skip_tasks heavy-task,slow-task
|
||||
```
|
||||
|
||||
## Creating a New Environment
|
||||
|
||||
### Training Environment
|
||||
|
||||
1. Create a new directory under `environments/`
|
||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
||||
3. Implement the four abstract methods + `evaluate()`
|
||||
|
||||
```python
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
class MyEnvConfig(HermesAgentEnvConfig):
|
||||
pass # Add custom fields as needed
|
||||
|
||||
class MyEnv(HermesAgentBaseEnv):
|
||||
name = "my-env"
|
||||
env_config_cls = MyEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls):
|
||||
env_config = MyEnvConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend="modal",
|
||||
# ... other config
|
||||
)
|
||||
server_configs = [APIServerConfig(...)]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
self.dataset = load_dataset(...)
|
||||
self.iter = 0
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item):
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# ctx gives you full tool access to the rollout's sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# Periodic evaluation logic
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
MyEnv.cli()
|
||||
```
|
||||
|
||||
### Eval-Only Environment (Benchmark)
|
||||
|
||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
|
||||
1. Create under `environments/benchmarks/your-benchmark/`
|
||||
2. Inherit from `HermesAgentBaseEnv`
|
||||
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
||||
4. Stub the training methods (`collect_trajectories`, `score`)
|
||||
5. Implement `rollout_and_score_eval()` and `evaluate()`
|
||||
6. Run with `evaluate` subcommand
|
||||
|
||||
## Key Config Fields
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
|
||||
| `disabled_toolsets` | Toolsets to disable | `None` |
|
||||
| `distribution` | Probabilistic toolset distribution name | `None` |
|
||||
| `max_agent_turns` | Max LLM calls per rollout | `30` |
|
||||
| `agent_temperature` | Sampling temperature | `1.0` |
|
||||
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
|
||||
| `system_prompt` | System message for the agent | `None` |
|
||||
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
|
||||
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
|
||||
31
environments/__init__.py
Normal file
31
environments/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Hermes-Agent Atropos Environments
|
||||
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Core layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
||||
|
||||
Benchmarks (eval-only):
|
||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
||||
"""
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
__all__ = [
|
||||
"AgentResult",
|
||||
"HermesAgentLoop",
|
||||
"ToolContext",
|
||||
"HermesAgentBaseEnv",
|
||||
"HermesAgentEnvConfig",
|
||||
]
|
||||
453
environments/agent_loop.py
Normal file
453
environments/agent_loop.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
HermesAgentLoop -- Reusable Multi-Turn Agent Engine
|
||||
|
||||
Runs the hermes-agent tool-calling loop using standard OpenAI-spec tool calling.
|
||||
Works with any server that returns ChatCompletion objects with tool_calls:
|
||||
- Phase 1: OpenAI server type (VLLM, SGLang, OpenRouter, OpenAI API)
|
||||
- Phase 2: ManagedServer with client-side tool call parser
|
||||
|
||||
The loop passes tools= and checks response.choices[0].message.tool_calls,
|
||||
identical to hermes-agent's run_agent.py. Tool execution is dispatched via
|
||||
handle_function_call() from model_tools.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., mini-swe-agent's modal/docker/daytona backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
||||
|
||||
|
||||
def resize_tool_pool(max_workers: int):
|
||||
"""
|
||||
Replace the global tool executor with a new one of the given size.
|
||||
|
||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolError:
|
||||
"""Record of a tool execution error during the agent loop."""
|
||||
|
||||
turn: int # Which turn the error occurred on
|
||||
tool_name: str # Which tool was called
|
||||
arguments: str # The arguments passed (truncated)
|
||||
error: str # The error message
|
||||
tool_result: str # The raw result returned to the model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running the agent loop."""
|
||||
|
||||
# Full conversation history in OpenAI message format
|
||||
messages: List[Dict[str, Any]]
|
||||
# ManagedServer.get_state() if available (Phase 2), None otherwise
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
# How many LLM calls were made
|
||||
turns_used: int = 0
|
||||
# True if model stopped calling tools naturally (vs hitting max_turns)
|
||||
finished_naturally: bool = False
|
||||
# Extracted reasoning content per turn (from PR #297 helpers)
|
||||
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
Extract reasoning content from a ChatCompletion message.
|
||||
|
||||
Handles multiple provider formats:
|
||||
1. message.reasoning_content field (some providers)
|
||||
2. message.reasoning field (some providers)
|
||||
3. message.reasoning_details[].text (OpenRouter style)
|
||||
|
||||
Note: <think> block extraction from content is NOT done here -- that's
|
||||
handled by the response already in Phase 1 (server does it) or by
|
||||
ManagedServer's patch in Phase 2.
|
||||
|
||||
Args:
|
||||
message: The assistant message from ChatCompletion response
|
||||
|
||||
Returns:
|
||||
Extracted reasoning text, or None if not found
|
||||
"""
|
||||
# Check reasoning_content field (common across providers)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
return message.reasoning_content
|
||||
|
||||
# Check reasoning field
|
||||
if hasattr(message, "reasoning") and message.reasoning:
|
||||
return message.reasoning
|
||||
|
||||
# Check reasoning_details (OpenRouter style)
|
||||
if hasattr(message, "reasoning_details") and message.reasoning_details:
|
||||
for detail in message.reasoning_details:
|
||||
if hasattr(detail, "text") and detail.text:
|
||||
return detail.text
|
||||
if isinstance(detail, dict) and detail.get("text"):
|
||||
return detail["text"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class HermesAgentLoop:
|
||||
"""
|
||||
Runs hermes-agent's tool-calling loop using standard OpenAI-spec tool calling.
|
||||
|
||||
Same pattern as run_agent.py:
|
||||
- Pass tools= to the API
|
||||
- Check response.choices[0].message.tool_calls
|
||||
- Dispatch via handle_function_call()
|
||||
|
||||
Works identically with any server type -- OpenAI, VLLM, SGLang, OpenRouter,
|
||||
or ManagedServer with a parser. The server determines how tool_calls get
|
||||
populated on the response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
tool_schemas: List[Dict[str, Any]],
|
||||
valid_tool_names: Set[str],
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
server: Server object with chat_completion() method (OpenAIServer,
|
||||
ManagedServer, ServerManager, etc.)
|
||||
tool_schemas: OpenAI-format tool definitions from get_tool_definitions()
|
||||
valid_tool_names: Set of tool names the model is allowed to call
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
"""
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
Execute the full agent loop using standard OpenAI tool calling.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
Modified in-place as the conversation progresses.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
"""
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Per-loop TodoStore for the todo tool (ephemeral, dies with the loop)
|
||||
from tools.todo_tool import TodoStore, todo_tool as _todo_tool
|
||||
_todo_store = TodoStore()
|
||||
|
||||
# Extract user task from first user message for browser_snapshot context
|
||||
_user_task = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and content.strip():
|
||||
_user_task = content.strip()[:500] # Cap to avoid huge strings
|
||||
break
|
||||
|
||||
import time as _time
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
|
||||
# Only pass max_tokens if explicitly set
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
||||
# provider preferences like banned/preferred providers, transforms)
|
||||
if self.extra_body:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
|
||||
# Extract reasoning content from the response (all provider formats)
|
||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
||||
reasoning_per_turn.append(reasoning)
|
||||
|
||||
# Check for tool calls -- standard OpenAI spec
|
||||
if assistant_msg.tool_calls:
|
||||
# Build the assistant message dict for conversation history
|
||||
msg_dict: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in assistant_msg.tool_calls
|
||||
],
|
||||
}
|
||||
|
||||
# Preserve reasoning_content for multi-turn chat template handling
|
||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
||||
# for history vs. the latest turn based on this field)
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
|
||||
messages.append(msg_dict)
|
||||
|
||||
# Execute each tool call via hermes-agent's dispatch
|
||||
for tc in assistant_msg.tool_calls:
|
||||
tool_name = tc.function.name
|
||||
tool_args_raw = tc.function.arguments
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
f"Available tools: {sorted(self.valid_tool_names)}"
|
||||
}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Unknown tool '{tool_name}'",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments and dispatch
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
exit_code = result_data.get("exit_code")
|
||||
if err and exit_code and exit_code < 0:
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err),
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Add tool response to conversation
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
)
|
||||
|
||||
else:
|
||||
# No tool calls -- model is done
|
||||
msg_dict = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
}
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
logger.info("Agent hit max_turns (%d) without finishing", self.max_turns)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=self.max_turns,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get ManagedServer state if the server supports it.
|
||||
|
||||
Returns state dict with SequenceNodes containing tokens/logprobs/masks,
|
||||
or None if the server doesn't support get_state() (e.g., regular OpenAI server).
|
||||
"""
|
||||
if hasattr(self.server, "get_state"):
|
||||
return self.server.get_state()
|
||||
return None
|
||||
0
environments/benchmarks/__init__.py
Normal file
0
environments/benchmarks/__init__.py
Normal file
73
environments/benchmarks/tblite/README.md
Normal file
73
environments/benchmarks/tblite/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
This environment evaluates terminal agents on the [OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite) benchmark, a difficulty-calibrated subset of [Terminal-Bench 2.0](https://www.tbench.ai/leaderboard/terminal-bench/2.0).
|
||||
|
||||
## Source
|
||||
|
||||
OpenThoughts-TBLite was created by the [OpenThoughts](https://www.openthoughts.ai/) Agent team in collaboration with [Snorkel AI](https://snorkel.ai/) and [Bespoke Labs](https://bespokelabs.ai/). The original dataset and documentation live at:
|
||||
|
||||
- **Dataset (source):** [open-thoughts/OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite)
|
||||
- **GitHub:** [open-thoughts/OpenThoughts-TBLite](https://github.com/open-thoughts/OpenThoughts-TBLite)
|
||||
- **Blog post:** [openthoughts.ai/blog/openthoughts-tblite](https://www.openthoughts.ai/blog/openthoughts-tblite)
|
||||
|
||||
## Our Dataset
|
||||
|
||||
We converted the source into the same schema used by our Terminal-Bench 2.0 environment (pre-built Docker Hub images, base64-encoded test tarballs, etc.) and published it as:
|
||||
|
||||
- **Dataset (ours):** [NousResearch/openthoughts-tblite](https://huggingface.co/datasets/NousResearch/openthoughts-tblite)
|
||||
- **Docker images:** `nousresearch/tblite-<task-name>:latest` on Docker Hub (100 images)
|
||||
|
||||
The conversion script is at `scripts/prepare_tblite_dataset.py`.
|
||||
|
||||
## Why TBLite?
|
||||
|
||||
Terminal-Bench 2.0 is one of the strongest frontier evaluations for terminal agents, but when a model scores near the floor (e.g., Qwen 3 8B at <1%), many changes look identical in aggregate score. TBLite addresses this by calibrating task difficulty using Claude Haiku 4.5 as a reference:
|
||||
|
||||
| Difficulty | Pass Rate Range | Tasks |
|
||||
|------------|----------------|-------|
|
||||
| Easy | >= 70% | 40 |
|
||||
| Medium | 40-69% | 26 |
|
||||
| Hard | 10-39% | 26 |
|
||||
| Extreme | < 10% | 8 |
|
||||
|
||||
This gives enough solvable tasks to detect small improvements quickly, while preserving enough hard tasks to avoid saturation. The correlation between TBLite and TB2 scores is **r = 0.911**.
|
||||
|
||||
TBLite also runs 2.6-8x faster than the full TB2, making it practical for iteration loops.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Run the full benchmark
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
||||
|
||||
# Filter to specific tasks
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
--env.task_filter "broken-python,pandas-etl"
|
||||
|
||||
# Use a different model
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
--server.model_name "qwen/qwen3-30b"
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
`TBLiteEvalEnv` is a thin subclass of `TerminalBench2EvalEnv`. All evaluation logic (agent loop, Docker sandbox management, test verification, metrics) is inherited. Only the defaults differ:
|
||||
|
||||
| Setting | TB2 | TBLite |
|
||||
|----------------|----------------------------------|-----------------------------------------|
|
||||
| Dataset | `NousResearch/terminal-bench-2` | `NousResearch/openthoughts-tblite` |
|
||||
| Tasks | 89 | 100 |
|
||||
| Task timeout | 1800s (30 min) | 1200s (20 min) |
|
||||
| Wandb name | `terminal-bench-2` | `openthoughts-tblite` |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@software{OpenThoughts-TBLite,
|
||||
author = {OpenThoughts-Agent team, Snorkel AI, Bespoke Labs},
|
||||
month = Feb,
|
||||
title = {{OpenThoughts-TBLite: A High-Signal Benchmark for Iterating on Terminal Agents}},
|
||||
howpublished = {https://www.openthoughts.ai/blog/openthoughts-tblite},
|
||||
year = {2026}
|
||||
}
|
||||
```
|
||||
0
environments/benchmarks/tblite/__init__.py
Normal file
0
environments/benchmarks/tblite/__init__.py
Normal file
39
environments/benchmarks/tblite/default.yaml
Normal file
39
environments/benchmarks/tblite/default.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
# OpenThoughts-TBLite Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TBLite benchmark (100 difficulty-calibrated
|
||||
# terminal tasks, a faster proxy for Terminal-Bench 2.0).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
||||
# --config environments/benchmarks/tblite/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 100 parallel tasks
|
||||
dataset_name: "NousResearch/openthoughts-tblite"
|
||||
test_timeout: 600
|
||||
task_timeout: 1200 # 20 min wall-clock per task (TBLite tasks are faster)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "openthoughts-tblite"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
42
environments/benchmarks/tblite/run_eval.sh
Executable file
42
environments/benchmarks/tblite/run_eval.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenThoughts-TBLite Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
||||
# --env.task_filter broken-python,pandas-etl
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/openthoughts-tblite
|
||||
LOG_FILE="logs/tblite_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "OpenThoughts-TBLite Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python tblite_env.py evaluate \
|
||||
--config default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/openthoughts-tblite/"
|
||||
119
environments/benchmarks/tblite/tblite_env.py
Normal file
119
environments/benchmarks/tblite/tblite_env.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
OpenThoughts-TBLite Evaluation Environment
|
||||
|
||||
A lighter, faster alternative to Terminal-Bench 2.0 for iterating on terminal
|
||||
agents. Uses the same evaluation logic as TerminalBench2EvalEnv but defaults
|
||||
to the NousResearch/openthoughts-tblite dataset (100 difficulty-calibrated
|
||||
tasks vs TB2's 89 harder tasks).
|
||||
|
||||
TBLite tasks are a curated subset of TB2 with a difficulty distribution
|
||||
designed to give meaningful signal even for smaller models:
|
||||
- Easy (40 tasks): >= 70% pass rate with Claude Haiku 4.5
|
||||
- Medium (26 tasks): 40-69% pass rate
|
||||
- Hard (26 tasks): 10-39% pass rate
|
||||
- Extreme (8 tasks): < 10% pass rate
|
||||
|
||||
Usage:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
||||
|
||||
# Filter to specific tasks:
|
||||
python environments/benchmarks/tblite/tblite_env.py evaluate \\
|
||||
--env.task_filter "broken-python,pandas-etl"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.benchmarks.terminalbench_2.terminalbench2_env import (
|
||||
TerminalBench2EvalConfig,
|
||||
TerminalBench2EvalEnv,
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalConfig(TerminalBench2EvalConfig):
|
||||
"""Configuration for the OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all TB2 config fields. Only the dataset default and task timeout
|
||||
differ -- TBLite tasks are calibrated to be faster.
|
||||
"""
|
||||
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/openthoughts-tblite",
|
||||
description="HuggingFace dataset containing TBLite tasks.",
|
||||
)
|
||||
|
||||
task_timeout: int = Field(
|
||||
default=1200,
|
||||
description="Maximum wall-clock seconds per task. TBLite tasks are "
|
||||
"generally faster than TB2, so 20 minutes is usually sufficient.",
|
||||
)
|
||||
|
||||
|
||||
class TBLiteEvalEnv(TerminalBench2EvalEnv):
|
||||
"""OpenThoughts-TBLite evaluation environment.
|
||||
|
||||
Inherits all evaluation logic from TerminalBench2EvalEnv (agent loop,
|
||||
test verification, Docker image resolution, metrics, wandb logging).
|
||||
Only the default configuration differs.
|
||||
"""
|
||||
|
||||
name = "openthoughts-tblite"
|
||||
env_config_cls = TBLiteEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TBLiteEvalConfig, List[APIServerConfig]]:
|
||||
env_config = TBLiteEvalConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300,
|
||||
|
||||
test_timeout=180,
|
||||
|
||||
# 100 tasks in parallel
|
||||
tool_pool_size=128,
|
||||
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="openthoughts-tblite",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TBLiteEvalEnv.cli()
|
||||
0
environments/benchmarks/terminalbench_2/__init__.py
Normal file
0
environments/benchmarks/terminalbench_2/__init__.py
Normal file
38
environments/benchmarks/terminalbench_2/default.yaml
Normal file
38
environments/benchmarks/terminalbench_2/default.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
42
environments/benchmarks/terminalbench_2/run_eval.sh
Executable file
42
environments/benchmarks/terminalbench_2/run_eval.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-Bench 2.0 Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --env.task_filter fix-git,git-multibranch
|
||||
#
|
||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
||||
# configured via env config fields -- no env vars needed.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/terminal-bench-2
|
||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "Terminal-Bench 2.0 Evaluation"
|
||||
echo "Log file: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
# Unbuffered python output so logs are written in real-time
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python terminalbench2_env.py evaluate \
|
||||
--config default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
echo "Eval results: evals/terminal-bench-2/"
|
||||
904
environments/benchmarks/terminalbench_2/terminalbench2_env.py
Normal file
904
environments/benchmarks/terminalbench_2/terminalbench2_env.py
Normal file
@@ -0,0 +1,904 @@
|
||||
"""
|
||||
TerminalBench2Env -- Terminal-Bench 2.0 Evaluation Environment
|
||||
|
||||
Evaluates agentic LLMs on challenging terminal tasks from Terminal-Bench 2.0.
|
||||
Each task provides a unique Docker environment (pre-built on Docker Hub), a natural
|
||||
language instruction, and a test suite for verification. The agent uses terminal +
|
||||
file tools to complete the task, then the test suite runs inside the same sandbox.
|
||||
|
||||
This is an eval-only environment (not a training environment). It is designed to
|
||||
be run via the `evaluate` subcommand:
|
||||
|
||||
python environments/terminalbench2_env.py evaluate \\
|
||||
--env.dataset_name NousResearch/terminal-bench-2
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Loads the TB2 dataset from HuggingFace
|
||||
2. evaluate() -- Iterates over all tasks, running each through:
|
||||
a. rollout_and_score_eval() -- Per-task agent loop + test verification
|
||||
- Resolves Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
- Registers per-task Modal sandbox via register_task_env_overrides()
|
||||
- Runs the HermesAgentLoop (terminal + file tools)
|
||||
- Uploads test suite and runs test.sh in the same sandbox
|
||||
- Returns binary pass/fail result
|
||||
b. Aggregates per-task, per-category, and overall pass rates
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- Per-task Modal sandboxes using pre-built Docker Hub images
|
||||
- Binary reward: 1.0 if all tests pass, 0.0 otherwise
|
||||
- Concurrency-controlled parallel evaluation via asyncio.Semaphore
|
||||
- Per-task, per-category, and aggregate pass rate tracking
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with TB2-specific settings for dataset loading,
|
||||
test execution, task filtering, and eval concurrency.
|
||||
"""
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/terminal-bench-2",
|
||||
description="HuggingFace dataset containing TB2 tasks.",
|
||||
)
|
||||
|
||||
# --- Test execution ---
|
||||
test_timeout: int = Field(
|
||||
default=180,
|
||||
description="Timeout in seconds for running the test suite after agent completes.",
|
||||
)
|
||||
|
||||
# --- Image strategy ---
|
||||
force_build: bool = Field(
|
||||
default=False,
|
||||
description="If True, always build from Dockerfile (ignore docker_image). "
|
||||
"Useful for testing custom Dockerfiles.",
|
||||
)
|
||||
|
||||
# --- Task filtering (comma-separated from CLI) ---
|
||||
task_filter: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to run (e.g., 'fix-git,git-multibranch'). "
|
||||
"If not set, all tasks are run.",
|
||||
)
|
||||
skip_tasks: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to skip on top of the default skip list.",
|
||||
)
|
||||
|
||||
# --- Per-task wall-clock timeout ---
|
||||
task_timeout: int = Field(
|
||||
default=1800,
|
||||
description="Maximum wall-clock seconds per task (agent loop + verification). "
|
||||
"Tasks exceeding this are scored as FAIL. Default 30 minutes.",
|
||||
)
|
||||
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
# =============================================================================
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
return
|
||||
raw = base64.b64decode(b64_data)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
tar.extractall(path=str(target_dir))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
|
||||
Inherits from HermesAgentBaseEnv for:
|
||||
- Terminal backend setup (os.environ["TERMINAL_ENV"])
|
||||
- Tool resolution via _resolve_tools_for_group()
|
||||
- Monkey patches for async-safe tool operation
|
||||
- Wandb trajectory formatting
|
||||
|
||||
The evaluate flow (triggered by `environment.py evaluate`):
|
||||
1. setup() -- Load dataset from HuggingFace
|
||||
2. evaluate() -- Run all tasks through rollout_and_score_eval()
|
||||
|
||||
Each task in rollout_and_score_eval():
|
||||
1. Resolve Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
2. Register per-task Modal sandbox override
|
||||
3. Run HermesAgentLoop with terminal + file tools
|
||||
4. Upload test suite and execute test.sh in the same sandbox
|
||||
5. Check /logs/verifier/reward.txt for pass/fail
|
||||
6. Clean up sandbox, overrides, and temp files
|
||||
"""
|
||||
|
||||
name = "terminal-bench-2"
|
||||
env_config_cls = TerminalBench2EvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalBench2EvalConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for Terminal-Bench 2.0 evaluation.
|
||||
|
||||
Uses eval-only settings:
|
||||
- eval_handling=STOP_TRAIN so the eval flow runs cleanly
|
||||
- steps_per_eval=1, total_steps=1 so eval triggers immediately
|
||||
- group_size=1 (one rollout per group, each task is expensive)
|
||||
|
||||
Uses Modal terminal backend (cloud-isolated sandbox per task) and
|
||||
OpenRouter with Claude for inference.
|
||||
"""
|
||||
env_config = TerminalBench2EvalConfig(
|
||||
# Terminal + file tools only (the agent interacts via shell commands)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
# - steps_per_eval=1, total_steps=1: eval triggers immediately
|
||||
# - group_size=1: one rollout per group (each task is expensive)
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
ensure_scores_are_not_same=False, # Binary rewards may all be 0 or 1
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup -- load dataset
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load the Terminal-Bench 2.0 dataset from HuggingFace."""
|
||||
from datasets import load_dataset
|
||||
|
||||
# Auto-set terminal_lifetime to task_timeout + 120s so sandboxes
|
||||
# never get killed during an active task, but still get cleaned up
|
||||
# promptly after the task times out.
|
||||
lifetime = self.config.task_timeout + 120
|
||||
self.config.terminal_lifetime = lifetime
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(lifetime)
|
||||
print(f" Terminal lifetime auto-set to {lifetime}s (task_timeout + 120s)")
|
||||
|
||||
print(f"Loading TB2 dataset from: {self.config.dataset_name}")
|
||||
ds = load_dataset(self.config.dataset_name, split="train")
|
||||
|
||||
# Apply task filters (comma-separated strings from CLI)
|
||||
tasks = list(ds)
|
||||
if self.config.task_filter:
|
||||
allowed = {name.strip() for name in self.config.task_filter.split(",")}
|
||||
tasks = [t for t in tasks if t["task_name"] in allowed]
|
||||
print(f" Filtered to {len(tasks)} tasks: {sorted(allowed)}")
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
before = len(tasks)
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
|
||||
# Build category index for per-category metrics
|
||||
self.category_index: Dict[str, List[int]] = defaultdict(list)
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL writer -- saves each task's full conversation
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single task result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs -- NOT used in eval-only mode
|
||||
# =========================================================================
|
||||
# These satisfy the abstract method requirements from HermesAgentBaseEnv.
|
||||
# The evaluate subcommand calls setup() -> evaluate() directly, bypassing
|
||||
# the training pipeline entirely.
|
||||
|
||||
async def get_next_item(self):
|
||||
"""Return next item (stub -- not used in eval-only mode)."""
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""Return the task's instruction as the user prompt."""
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
"""Compute reward (stub -- actual verification is in rollout_and_score_eval)."""
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
"""Collect trajectories (stub -- not used in eval-only mode)."""
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
"""Score rollouts (stub -- not used in eval-only mode)."""
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
"""
|
||||
Resolve the Docker image for a task, with fallback to Dockerfile.
|
||||
|
||||
Strategy (mirrors Harbor's approach):
|
||||
1. If force_build=True, always build from Dockerfile in environment_tar
|
||||
2. If docker_image is available, use the pre-built Docker Hub image (fast)
|
||||
3. Otherwise, extract Dockerfile from environment_tar and build (slow)
|
||||
|
||||
Returns:
|
||||
(modal_image, temp_dir) -- modal_image is a Docker Hub name or a
|
||||
Dockerfile path. temp_dir is set if we extracted files that need
|
||||
cleanup later.
|
||||
"""
|
||||
docker_image = item.get("docker_image", "")
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
|
||||
# Fast path: use pre-built Docker Hub image
|
||||
if docker_image and not self.config.force_build:
|
||||
logger.info("Task %s: using pre-built image %s", task_name, docker_image)
|
||||
return docker_image, None
|
||||
|
||||
# Slow path: extract Dockerfile from environment_tar and build
|
||||
if environment_tar:
|
||||
task_dir = Path(tempfile.mkdtemp(prefix=f"tb2-{task_name}-"))
|
||||
_extract_base64_tar(environment_tar, task_dir)
|
||||
dockerfile_path = task_dir / "Dockerfile"
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
# Neither available -- fall back to Hub image if force_build was True
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single TB2 task: run the agent loop, then verify with tests.
|
||||
|
||||
This is the core evaluation method. For each task it:
|
||||
1. Resolves the Docker image and registers the Modal sandbox override
|
||||
2. Runs HermesAgentLoop with terminal + file tools
|
||||
3. Uploads the test suite into the sandbox
|
||||
4. Executes test.sh and checks the result
|
||||
5. Cleans up the sandbox and temp files
|
||||
|
||||
Args:
|
||||
eval_item: A single TB2 task dict from the dataset
|
||||
|
||||
Returns:
|
||||
Dict with 'passed' (bool), 'reward' (float), 'task_name' (str),
|
||||
'category' (str), and optional debug info
|
||||
"""
|
||||
task_name = eval_item.get("task_name", "unknown")
|
||||
category = eval_item.get("category", "unknown")
|
||||
task_id = str(uuid.uuid4())
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
try:
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task Modal image override ---
|
||||
register_task_env_overrides(task_id, {"modal_image": modal_image})
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Run tests in a thread so the blocking ctx.terminal() calls
|
||||
# don't freeze the entire event loop (which would stall all
|
||||
# other tasks, tqdm updates, and timeout timers).
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - task_start
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
finally:
|
||||
# --- Cleanup: clear overrides, sandbox, and temp files ---
|
||||
clear_task_env_overrides(task_id)
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for %s: %s", task_id[:8], e)
|
||||
if task_dir and task_dir.exists():
|
||||
shutil.rmtree(task_dir, ignore_errors=True)
|
||||
|
||||
def _run_tests(
|
||||
self, item: Dict[str, Any], ctx: ToolContext, task_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Upload and execute the test suite in the agent's sandbox, then
|
||||
download the verifier output locally to read the reward.
|
||||
|
||||
Follows Harbor's verification pattern:
|
||||
1. Upload tests/ directory into the sandbox
|
||||
2. Execute test.sh inside the sandbox
|
||||
3. Download /logs/verifier/ directory to a local temp dir
|
||||
4. Read reward.txt locally with native Python I/O
|
||||
|
||||
Downloading locally avoids issues with the file_read tool on
|
||||
the Modal VM and matches how Harbor handles verification.
|
||||
|
||||
TB2 test scripts (test.sh) typically:
|
||||
1. Install pytest via uv/pip
|
||||
2. Run pytest against the test files in /tests/
|
||||
3. Write results to /logs/verifier/reward.txt
|
||||
|
||||
Args:
|
||||
item: The TB2 task dict (contains tests_tar, test_sh)
|
||||
ctx: ToolContext scoped to this task's sandbox
|
||||
task_name: For logging
|
||||
|
||||
Returns:
|
||||
1.0 if tests pass, 0.0 otherwise
|
||||
"""
|
||||
tests_tar = item.get("tests_tar", "")
|
||||
test_sh = item.get("test_sh", "")
|
||||
|
||||
if not test_sh:
|
||||
logger.warning("Task %s: no test_sh content, reward=0", task_name)
|
||||
return 0.0
|
||||
|
||||
# Create required directories in the sandbox
|
||||
ctx.terminal("mkdir -p /tests /logs/verifier")
|
||||
|
||||
# Upload test files into the sandbox (binary-safe via base64)
|
||||
if tests_tar:
|
||||
tests_temp = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{task_name}-"))
|
||||
try:
|
||||
_extract_base64_tar(tests_tar, tests_temp)
|
||||
ctx.upload_dir(str(tests_temp), "/tests")
|
||||
except Exception as e:
|
||||
logger.warning("Task %s: failed to upload test files: %s", task_name, e)
|
||||
finally:
|
||||
shutil.rmtree(tests_temp, ignore_errors=True)
|
||||
|
||||
# Write the test runner script (test.sh)
|
||||
ctx.write_file("/tests/test.sh", test_sh)
|
||||
ctx.terminal("chmod +x /tests/test.sh")
|
||||
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
timeout=self.config.test_timeout,
|
||||
)
|
||||
|
||||
exit_code = test_result.get("exit_code", -1)
|
||||
output = test_result.get("output", "")
|
||||
|
||||
# Download the verifier output directory locally, then read reward.txt
|
||||
# with native Python I/O. This avoids issues with file_read on the
|
||||
# Modal VM and matches Harbor's verification pattern.
|
||||
reward = 0.0
|
||||
local_verifier_dir = Path(tempfile.mkdtemp(prefix=f"tb2-verifier-{task_name}-"))
|
||||
try:
|
||||
ctx.download_dir("/logs/verifier", str(local_verifier_dir))
|
||||
|
||||
reward_file = local_verifier_dir / "reward.txt"
|
||||
if reward_file.exists() and reward_file.stat().st_size > 0:
|
||||
content = reward_file.read_text().strip()
|
||||
if content == "1":
|
||||
reward = 1.0
|
||||
elif content == "0":
|
||||
reward = 0.0
|
||||
else:
|
||||
# Unexpected content -- try parsing as float
|
||||
try:
|
||||
reward = float(content)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
# reward.txt not written -- fall back to exit code
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
shutil.rmtree(local_verifier_dir, ignore_errors=True)
|
||||
|
||||
# Log test output for debugging failures
|
||||
if reward == 0.0:
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate -- main entry point for the eval subcommand
|
||||
# =========================================================================
|
||||
|
||||
async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Wrap rollout_and_score_eval with a per-task wall-clock timeout.
|
||||
|
||||
If the task exceeds task_timeout seconds, it's automatically scored
|
||||
as FAIL. This prevents any single task from hanging indefinitely.
|
||||
"""
|
||||
task_name = item.get("task_name", "unknown")
|
||||
category = item.get("category", "unknown")
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.task_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run Terminal-Bench 2.0 evaluation over all tasks.
|
||||
|
||||
This is the main entry point when invoked via:
|
||||
python environments/terminalbench2_env.py evaluate
|
||||
|
||||
Runs all tasks through rollout_and_score_eval() via asyncio.gather()
|
||||
(same pattern as GPQA and other Atropos eval envs). Each task is
|
||||
wrapped with a wall-clock timeout so hung tasks auto-fail.
|
||||
|
||||
Suppresses noisy Modal/terminal output (HERMES_QUIET) so the tqdm
|
||||
bar stays visible.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Route all logging through tqdm.write() so the progress bar stays
|
||||
# pinned at the bottom while log lines scroll above it.
|
||||
from tqdm import tqdm
|
||||
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
print(f" Task timeout: {self.config.task_timeout}s")
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Fire all tasks with wall-clock timeout, track live accuracy on the bar
|
||||
total_tasks = len(self.all_eval_items)
|
||||
eval_tasks = [
|
||||
asyncio.ensure_future(self._eval_with_timeout(item))
|
||||
for item in self.all_eval_items
|
||||
]
|
||||
|
||||
results = []
|
||||
passed_count = 0
|
||||
pbar = tqdm(total=total_tasks, desc="Evaluating TB2", dynamic_ncols=True)
|
||||
try:
|
||||
for coro in asyncio.as_completed(eval_tasks):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if result and result.get("passed"):
|
||||
passed_count += 1
|
||||
done = len(results)
|
||||
pct = (passed_count / done * 100) if done else 0
|
||||
pbar.set_postfix_str(f"pass={passed_count}/{done} ({pct:.1f}%)")
|
||||
pbar.update(1)
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
pbar.close()
|
||||
print(f"\n\nInterrupted! Cleaning up {len(eval_tasks)} tasks...")
|
||||
# Cancel all pending tasks
|
||||
for task in eval_tasks:
|
||||
task.cancel()
|
||||
# Let cancellations propagate (finally blocks run cleanup_vm)
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Filter out None results (shouldn't happen, but be safe)
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
# ---- Compute metrics ----
|
||||
total = len(valid_results)
|
||||
passed = sum(1 for r in valid_results if r.get("passed"))
|
||||
overall_pass_rate = passed / total if total > 0 else 0.0
|
||||
|
||||
# Per-category breakdown
|
||||
cat_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid_results:
|
||||
cat_results[r.get("category", "unknown")].append(r)
|
||||
|
||||
# Build metrics dict
|
||||
eval_metrics = {
|
||||
"eval/pass_rate": overall_pass_rate,
|
||||
"eval/total_tasks": total,
|
||||
"eval/passed_tasks": passed,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_pass_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
cat_key = category.replace(" ", "_").replace("-", "_").lower()
|
||||
eval_metrics[f"eval/pass_rate_{cat_key}"] = cat_pass_rate
|
||||
|
||||
# Store metrics for wandb_log
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
print("\nCategory Breakdown:")
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
print(f" {category}: {cat_rate:.1%} ({cat_passed}/{cat_total})")
|
||||
|
||||
# Print individual task results
|
||||
print("\nTask Results:")
|
||||
for r in sorted(valid_results, key=lambda x: x.get("task_name", "")):
|
||||
status = "PASS" if r.get("passed") else "FAIL"
|
||||
turns = r.get("turns_used", "?")
|
||||
error = r.get("error", "")
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
{
|
||||
"task_name": r.get("task_name"),
|
||||
"category": r.get("category"),
|
||||
"passed": r.get("passed"),
|
||||
"reward": r.get("reward"),
|
||||
"turns_used": r.get("turns_used"),
|
||||
"error": r.get("error"),
|
||||
"messages": r.get("messages"),
|
||||
}
|
||||
for r in valid_results
|
||||
]
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging evaluation results: {e}")
|
||||
|
||||
# Close streaming file
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f" Live results saved to: {self._streaming_path}")
|
||||
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
# Shut down the tool thread pool so orphaned workers from timed-out
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log TB2-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Add stored eval metrics
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalBench2EvalEnv.cli()
|
||||
115
environments/benchmarks/yc_bench/README.md
Normal file
115
environments/benchmarks/yc_bench/README.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# YC-Bench: Long-Horizon Agent Benchmark
|
||||
|
||||
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
|
||||
|
||||
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install yc-bench (optional dependency)
|
||||
pip install "hermes-agent[yc-bench]"
|
||||
|
||||
# Or install from source
|
||||
git clone https://github.com/collinear-ai/yc-bench
|
||||
cd yc-bench && pip install -e .
|
||||
|
||||
# Verify
|
||||
yc-bench --help
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
# From the repo root:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
|
||||
# Or directly:
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
# Override model:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
# Quick single-preset test:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
HermesAgentLoop (our agent)
|
||||
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
|
||||
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
|
||||
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
|
||||
-> ... (100-500 turns per run)
|
||||
```
|
||||
|
||||
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
|
||||
|
||||
### Simulation Mechanics
|
||||
|
||||
- **4 skill domains**: research, inference, data_environment, training
|
||||
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
|
||||
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
|
||||
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
|
||||
- **Financial pressure**: Monthly payroll, bankruptcy = game over
|
||||
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
|
||||
|
||||
### Difficulty Presets
|
||||
|
||||
| Preset | Employees | Tasks | Focus |
|
||||
|-----------|-----------|-------|-------|
|
||||
| tutorial | 3 | 50 | Basic loop mechanics |
|
||||
| easy | 5 | 100 | Throughput awareness |
|
||||
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
|
||||
| **hard** | 7 | 200 | Precise ETA reasoning |
|
||||
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
|
||||
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
|
||||
|
||||
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
|
||||
|
||||
### Scoring
|
||||
|
||||
```
|
||||
composite = 0.5 × survival + 0.5 × normalised_funds
|
||||
```
|
||||
|
||||
- **Survival** (binary): Did the company avoid bankruptcy?
|
||||
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
|
||||
|
||||
## Configuration
|
||||
|
||||
Key fields in `default.yaml`:
|
||||
|
||||
| Field | Default | Description |
|
||||
|-------|---------|-------------|
|
||||
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
|
||||
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
|
||||
| `max_agent_turns` | 200 | Max LLM calls per run |
|
||||
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
|
||||
| `survival_weight` | 0.5 | Weight of survival in composite score |
|
||||
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
|
||||
| `horizon_years` | null | Override horizon (null = auto from preset) |
|
||||
|
||||
## Cost & Time Estimates
|
||||
|
||||
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
|
||||
|
||||
| Preset | Turns | Time | Est. Cost |
|
||||
|--------|-------|------|-----------|
|
||||
| fast_test | ~50 | 5-10 min | $1-5 |
|
||||
| medium | ~200 | 20-40 min | $5-15 |
|
||||
| hard | ~300 | 30-60 min | $10-25 |
|
||||
|
||||
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
|
||||
|
||||
## References
|
||||
|
||||
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
|
||||
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
|
||||
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
|
||||
0
environments/benchmarks/yc_bench/__init__.py
Normal file
0
environments/benchmarks/yc_bench/__init__.py
Normal file
43
environments/benchmarks/yc_bench/default.yaml
Normal file
43
environments/benchmarks/yc_bench/default.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
# YC-Bench Evaluation -- Default Configuration
|
||||
#
|
||||
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
|
||||
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal"]
|
||||
max_agent_turns: 200
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.0
|
||||
terminal_backend: "local"
|
||||
terminal_timeout: 60
|
||||
presets: ["fast_test", "medium", "hard"]
|
||||
seeds: [1, 2, 3]
|
||||
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
|
||||
survival_weight: 0.5 # weight of binary survival in composite score
|
||||
funds_weight: 0.5 # weight of normalised final funds in composite score
|
||||
db_dir: "/tmp/yc_bench_dbs"
|
||||
company_name: "BenchCo"
|
||||
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "yc-bench"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
34
environments/benchmarks/yc_bench/run_eval.sh
Executable file
34
environments/benchmarks/yc_bench/run_eval.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
# YC-Bench Evaluation
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
#
|
||||
# Run a single preset:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/yc-bench
|
||||
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "YC-Bench Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
847
environments/benchmarks/yc_bench/yc_bench_env.py
Normal file
847
environments/benchmarks/yc_bench/yc_bench_env.py
Normal file
@@ -0,0 +1,847 @@
|
||||
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install hermes-agent[yc-bench]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
|
||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
||||
business simulation. You manage the company exclusively through the `yc-bench`
|
||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
||||
without going bankrupt, while **maximising final funds**.
|
||||
|
||||
## Simulation Mechanics
|
||||
|
||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
||||
**data_environment**, and **training**. Each has its own prestige level
|
||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
||||
Running out of funds = bankruptcy = game over.
|
||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
||||
Time only advances when you call `yc-bench sim resume`.
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
1. Browse market tasks with `market browse`
|
||||
2. Accept a task with `task accept` (this sets its deadline)
|
||||
3. Assign employees with `task assign`
|
||||
4. Dispatch with `task dispatch` to start work
|
||||
5. Call `sim resume` to advance time and let employees make progress
|
||||
6. Tasks complete when all domain requirements are fulfilled
|
||||
|
||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
||||
task fail, so cancel only as a last resort.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### Observe
|
||||
- `yc-bench company status` -- funds, prestige, runway
|
||||
- `yc-bench employee list` -- skills, salary, active tasks
|
||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
||||
- `yc-bench report monthly` -- monthly P&L
|
||||
|
||||
### Act
|
||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
||||
- `yc-bench sim resume` -- advance simulation clock
|
||||
|
||||
### Memory (persists across context truncation)
|
||||
- `yc-bench scratchpad read` -- read your persistent notes
|
||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
||||
- `yc-bench scratchpad clear` -- clear notes
|
||||
|
||||
## Strategy Guidelines
|
||||
|
||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
||||
2. **Focus employees** -- assigning one employee to many tasks halves their
|
||||
throughput per additional task. Keep assignments concentrated.
|
||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
||||
employee assignments. This persists even if conversation context is truncated.
|
||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
||||
Accept high-reward tasks before payroll dates.
|
||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
||||
into prestige loss, locking you out of profitable contracts.
|
||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
||||
|
||||
## Your Turn
|
||||
|
||||
Each turn:
|
||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
||||
2. Check for completed tasks and pending deadlines.
|
||||
3. Browse market for profitable tasks within your prestige level.
|
||||
4. Accept, assign, and dispatch tasks strategically.
|
||||
5. Call `yc-bench sim resume` to advance time.
|
||||
6. Repeat until the simulation ends.
|
||||
|
||||
Think step by step before acting."""
|
||||
|
||||
# Starting funds in cents ($250,000)
|
||||
INITIAL_FUNDS_CENTS = 25_000_000
|
||||
|
||||
# Default horizon per preset (years)
|
||||
_PRESET_HORIZONS = {
|
||||
"tutorial": 1,
|
||||
"easy": 1,
|
||||
"medium": 1,
|
||||
"hard": 1,
|
||||
"nightmare": 1,
|
||||
"fast_test": 1,
|
||||
"default": 3,
|
||||
"high_reward": 1,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the YC-Bench evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
|
||||
preset selection, seed control, scoring, and simulation parameters.
|
||||
"""
|
||||
|
||||
presets: List[str] = Field(
|
||||
default=["fast_test", "medium", "hard"],
|
||||
description="YC-Bench preset names to evaluate.",
|
||||
)
|
||||
seeds: List[int] = Field(
|
||||
default=[1, 2, 3],
|
||||
description="Random seeds -- each preset x seed = one run.",
|
||||
)
|
||||
run_timeout: int = Field(
|
||||
default=3600,
|
||||
description="Maximum wall-clock seconds per run. Default 60 minutes.",
|
||||
)
|
||||
survival_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of survival (0/1) in composite score.",
|
||||
)
|
||||
funds_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of normalised final funds in composite score.",
|
||||
)
|
||||
db_dir: str = Field(
|
||||
default="/tmp/yc_bench_dbs",
|
||||
description="Directory for per-run SQLite databases.",
|
||||
)
|
||||
horizon_years: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Simulation horizon in years. If None (default), inferred from "
|
||||
"preset name (1 year for most, 3 for 'default')."
|
||||
),
|
||||
)
|
||||
company_name: str = Field(
|
||||
default="BenchCo",
|
||||
description="Name of the simulated company.",
|
||||
)
|
||||
start_date: str = Field(
|
||||
default="01/01/2025",
|
||||
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scoring helpers
|
||||
# =============================================================================
|
||||
|
||||
def _read_final_score(db_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read final game state from a YC-Bench SQLite database.
|
||||
|
||||
Returns dict with final_funds_cents (int), survived (bool),
|
||||
terminal_reason (str).
|
||||
|
||||
Note: yc-bench table names are plural -- 'companies' not 'company',
|
||||
'sim_events' not 'simulation_log'.
|
||||
"""
|
||||
if not os.path.exists(db_path):
|
||||
logger.warning("DB not found at %s", db_path)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": "db_missing",
|
||||
}
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
|
||||
# Read final funds from the 'companies' table
|
||||
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
funds = row[0] if row else 0
|
||||
|
||||
# Determine terminal reason from 'sim_events' table
|
||||
terminal_reason = "unknown"
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT event_type FROM sim_events "
|
||||
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
|
||||
"ORDER BY scheduled_at DESC LIMIT 1"
|
||||
)
|
||||
event_row = cur.fetchone()
|
||||
if event_row:
|
||||
terminal_reason = event_row[0]
|
||||
except sqlite3.OperationalError:
|
||||
# Table may not exist if simulation didn't progress
|
||||
pass
|
||||
|
||||
survived = funds >= 0 and terminal_reason != "bankruptcy"
|
||||
return {
|
||||
"final_funds_cents": funds,
|
||||
"survived": survived,
|
||||
"terminal_reason": terminal_reason,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to read DB %s: %s", db_path, e)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": f"db_error: {e}",
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _compute_composite_score(
|
||||
final_funds_cents: int,
|
||||
survived: bool,
|
||||
survival_weight: float = 0.5,
|
||||
funds_weight: float = 0.5,
|
||||
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score from survival and final funds.
|
||||
|
||||
Score = survival_weight * survival_score
|
||||
+ funds_weight * normalised_funds_score
|
||||
|
||||
Normalised funds uses log-scale relative to initial capital:
|
||||
- funds <= 0: 0.0
|
||||
- funds == initial: ~0.15
|
||||
- funds == 10x: ~0.52
|
||||
- funds == 100x: 1.0
|
||||
"""
|
||||
survival_score = 1.0 if survived else 0.0
|
||||
|
||||
if final_funds_cents <= 0:
|
||||
funds_score = 0.0
|
||||
else:
|
||||
max_ratio = 100.0
|
||||
ratio = final_funds_cents / max(initial_funds_cents, 1)
|
||||
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
|
||||
|
||||
return survival_weight * survival_score + funds_weight * funds_score
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
YC-Bench long-horizon agent benchmark environment (eval-only).
|
||||
|
||||
Each eval item is a (preset, seed) pair. The environment initialises the
|
||||
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
|
||||
a competing built-in agent loop). The HermesAgentLoop then drives the
|
||||
interaction by calling individual yc-bench CLI commands via the terminal tool.
|
||||
|
||||
After the agent loop ends, the SQLite DB is read to extract the final score.
|
||||
|
||||
Scoring:
|
||||
composite = 0.5 * survival + 0.5 * normalised_funds
|
||||
"""
|
||||
|
||||
name = "yc-bench"
|
||||
env_config_cls = YCBenchEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
|
||||
env_config = YCBenchEvalConfig(
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
max_agent_turns=200,
|
||||
max_token_length=32000,
|
||||
agent_temperature=0.0,
|
||||
system_prompt=YC_BENCH_SYSTEM_PROMPT,
|
||||
terminal_backend="local",
|
||||
terminal_timeout=60,
|
||||
presets=["fast_test", "medium", "hard"],
|
||||
seeds=[1, 2, 3],
|
||||
run_timeout=3600,
|
||||
survival_weight=0.5,
|
||||
funds_weight=0.5,
|
||||
db_dir="/tmp/yc_bench_dbs",
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="yc-bench",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Verify yc-bench is installed and build the eval matrix."""
|
||||
# Verify yc-bench CLI is available
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
raise RuntimeError(
|
||||
"yc-bench CLI not found. Install with:\n"
|
||||
' pip install "hermes-agent[yc-bench]"\n'
|
||||
"Or: git clone https://github.com/collinear-ai/yc-bench "
|
||||
"&& cd yc-bench && pip install -e ."
|
||||
)
|
||||
print("yc-bench CLI verified.")
|
||||
|
||||
# Build eval matrix: preset x seed
|
||||
self.all_eval_items = [
|
||||
{"preset": preset, "seed": seed}
|
||||
for preset in self.config.presets
|
||||
for seed in self.config.seeds
|
||||
]
|
||||
self.iter = 0
|
||||
|
||||
os.makedirs(self.config.db_dir, exist_ok=True)
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL log for crash-safe result persistence
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = threading.Lock()
|
||||
|
||||
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
|
||||
for item in self.all_eval_items:
|
||||
print(f" preset={item['preset']!r} seed={item['seed']}")
|
||||
print(f"Streaming results to: {self._streaming_path}\n")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single run result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs (eval-only -- not used)
|
||||
# =========================================================================
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
return (
|
||||
f"A new YC-Bench simulation has been initialized "
|
||||
f"(preset='{preset}', seed={seed}).\n"
|
||||
f"Your company '{self.config.company_name}' is ready.\n\n"
|
||||
"Begin by calling:\n"
|
||||
"1. `yc-bench company status` -- see your starting funds and prestige\n"
|
||||
"2. `yc-bench employee list` -- see your team and their skills\n"
|
||||
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
|
||||
"you can take\n\n"
|
||||
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
|
||||
"`yc-bench sim resume` to advance time. Repeat this loop until the "
|
||||
"simulation ends (horizon reached or bankruptcy)."
|
||||
)
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Per-run evaluation
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single (preset, seed) run.
|
||||
|
||||
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
|
||||
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
|
||||
3. Runs HermesAgentLoop with terminal tool
|
||||
4. Reads SQLite DB to compute final score
|
||||
5. Returns result dict with survival, funds, and composite score
|
||||
"""
|
||||
preset = eval_item["preset"]
|
||||
seed = eval_item["seed"]
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
run_key = f"{preset}_seed{seed}_{run_id}"
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
|
||||
run_start = time.time()
|
||||
|
||||
# Isolated DB per run -- prevents cross-run state leakage
|
||||
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = preset
|
||||
|
||||
# Determine horizon: explicit config override > preset lookup > default 1
|
||||
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
|
||||
|
||||
try:
|
||||
# ----------------------------------------------------------
|
||||
# Step 1: Initialise the simulation via CLI
|
||||
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
|
||||
# `yc-bench run` starts yc-bench's own LLM agent loop (via
|
||||
# LiteLLM), which would compete with our HermesAgentLoop.
|
||||
# `sim init` just sets up the world and returns.
|
||||
# ----------------------------------------------------------
|
||||
init_cmd = [
|
||||
"yc-bench", "sim", "init",
|
||||
"--seed", str(seed),
|
||||
"--start-date", self.config.start_date,
|
||||
"--company-name", self.config.company_name,
|
||||
"--horizon-years", str(horizon),
|
||||
]
|
||||
init_result = subprocess.run(
|
||||
init_cmd, capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
if init_result.returncode != 0:
|
||||
error_msg = (init_result.stderr or init_result.stdout).strip()
|
||||
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
|
||||
|
||||
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 2: Run the HermesAgentLoop
|
||||
# ----------------------------------------------------------
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": self.format_prompt(eval_item)},
|
||||
]
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=run_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 3: Read final score from the simulation DB
|
||||
# ----------------------------------------------------------
|
||||
score_data = _read_final_score(db_path)
|
||||
final_funds = score_data["final_funds_cents"]
|
||||
survived = score_data["survived"]
|
||||
terminal_reason = score_data["terminal_reason"]
|
||||
|
||||
composite = _compute_composite_score(
|
||||
final_funds_cents=final_funds,
|
||||
survived=survived,
|
||||
survival_weight=self.config.survival_weight,
|
||||
funds_weight=self.config.funds_weight,
|
||||
)
|
||||
|
||||
elapsed = time.time() - run_start
|
||||
status = "SURVIVED" if survived else "BANKRUPT"
|
||||
if final_funds >= 0:
|
||||
funds_str = f"${final_funds / 100:,.0f}"
|
||||
else:
|
||||
funds_str = f"-${abs(final_funds) / 100:,.0f}"
|
||||
|
||||
tqdm.write(
|
||||
f" [{status}] preset={preset!r} seed={seed} "
|
||||
f"funds={funds_str} score={composite:.3f} "
|
||||
f"turns={result.turns_used} ({elapsed:.0f}s)"
|
||||
)
|
||||
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": survived,
|
||||
"final_funds_cents": final_funds,
|
||||
"final_funds_usd": final_funds / 100,
|
||||
"terminal_reason": terminal_reason,
|
||||
"composite_score": composite,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"elapsed_seconds": elapsed,
|
||||
"db_path": db_path,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - run_start
|
||||
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
|
||||
tqdm.write(
|
||||
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"error: {e}",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": str(e),
|
||||
"elapsed_seconds": elapsed,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""Wrap a single rollout with a wall-clock timeout."""
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.run_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] preset={preset!r} seed={seed} "
|
||||
f"(exceeded {self.config.run_timeout}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": "timeout",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run YC-Bench evaluation over all (preset, seed) combinations.
|
||||
|
||||
Runs sequentially -- each run is 100-500 turns, parallelising would
|
||||
be prohibitively expensive and cause env var conflicts.
|
||||
"""
|
||||
start_time = time.time()
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- tqdm-compatible logging handler (TB2 pattern) ---
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
root = logging.getLogger()
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s %(name)s: %(message)s")
|
||||
)
|
||||
root.handlers = [handler]
|
||||
for noisy in ("httpx", "openai"):
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
# --- Print config summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting YC-Bench Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Presets: {self.config.presets}")
|
||||
print(f" Seeds: {self.config.seeds}")
|
||||
print(f" Total runs: {len(self.all_eval_items)}")
|
||||
print(f" Max turns/run: {self.config.max_agent_turns}")
|
||||
print(f" Run timeout: {self.config.run_timeout}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = []
|
||||
pbar = tqdm(
|
||||
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
|
||||
)
|
||||
|
||||
try:
|
||||
for item in self.all_eval_items:
|
||||
result = await self._run_with_timeout(item)
|
||||
results.append(result)
|
||||
survived_count = sum(1 for r in results if r.get("survived"))
|
||||
pbar.set_postfix_str(
|
||||
f"survived={survived_count}/{len(results)}"
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
|
||||
pbar.close()
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
return
|
||||
|
||||
pbar.close()
|
||||
end_time = time.time()
|
||||
|
||||
# --- Compute metrics ---
|
||||
valid = [r for r in results if r is not None]
|
||||
if not valid:
|
||||
print("Warning: No valid results.")
|
||||
return
|
||||
|
||||
total = len(valid)
|
||||
survived_total = sum(1 for r in valid if r.get("survived"))
|
||||
survival_rate = survived_total / total if total else 0.0
|
||||
avg_score = (
|
||||
sum(r.get("composite_score", 0) for r in valid) / total
|
||||
if total
|
||||
else 0.0
|
||||
)
|
||||
|
||||
preset_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid:
|
||||
preset_results[r["preset"]].append(r)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/survival_rate": survival_rate,
|
||||
"eval/avg_composite_score": avg_score,
|
||||
"eval/total_runs": total,
|
||||
"eval/survived_runs": survived_total,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
key = preset.replace("-", "_")
|
||||
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
|
||||
eval_metrics[f"eval/avg_score_{key}"] = pa
|
||||
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# --- Print summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("YC-Bench Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(
|
||||
f"Overall survival rate: {survival_rate:.1%} "
|
||||
f"({survived_total}/{total})"
|
||||
)
|
||||
print(f"Average composite score: {avg_score:.4f}")
|
||||
print(f"Evaluation time: {end_time - start_time:.1f}s")
|
||||
|
||||
print("\nPer-preset breakdown:")
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
|
||||
for r in items:
|
||||
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
|
||||
funds = r.get("final_funds_usd", 0)
|
||||
print(
|
||||
f" seed={r['seed']} [{status}] "
|
||||
f"${funds:,.0f} "
|
||||
f"score={r.get('composite_score', 0):.3f}"
|
||||
)
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# --- Log results ---
|
||||
samples = [
|
||||
{k: v for k, v in r.items() if k != "messages"} for r in valid
|
||||
]
|
||||
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging results: {e}")
|
||||
|
||||
# --- Cleanup (TB2 pattern) ---
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f"Results saved to: {self._streaming_path}")
|
||||
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log YC-Bench-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
for k, v in self.eval_metrics:
|
||||
wandb_metrics[k] = v
|
||||
self.eval_metrics = []
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
YCBenchEvalEnv.cli()
|
||||
672
environments/hermes_base_env.py
Normal file
672
environments/hermes_base_env.py
Normal file
@@ -0,0 +1,672 @@
|
||||
"""
|
||||
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos
|
||||
|
||||
Provides the Atropos integration plumbing that all hermes-agent environments share:
|
||||
- Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2)
|
||||
- Per-group toolset/distribution resolution
|
||||
- Agent loop orchestration via HermesAgentLoop
|
||||
- ToolContext creation for reward functions
|
||||
- ScoredDataGroup construction from ManagedServer state
|
||||
|
||||
Subclasses only need to implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item from the dataset
|
||||
format_prompt() -- Convert a dataset item into the user message
|
||||
compute_reward() -- Score the rollout (has full ToolContext access)
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
# Ensure the hermes-agent repo root is on sys.path so that imports like
|
||||
# `from model_tools import ...` and `from environments.X import ...` work
|
||||
# regardless of where the script is invoked from.
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
# Load API keys from hermes-agent/.env so all environments can access them
|
||||
_env_path = _repo_root / ".env"
|
||||
if _env_path.exists():
|
||||
load_dotenv(dotenv_path=_env_path)
|
||||
|
||||
# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
|
||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
||||
from environments.patches import apply_patches
|
||||
apply_patches()
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
ScoredDataItem,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_manager import (
|
||||
APIServerConfig,
|
||||
ServerBaseline,
|
||||
ServerManager,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
from toolset_distributions import sample_toolsets_from_distribution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"""
|
||||
Configuration for hermes-agent Atropos environments.
|
||||
|
||||
Extends BaseEnvConfig with agent-specific settings for toolsets,
|
||||
terminal backend, dataset loading, and tool call parsing.
|
||||
"""
|
||||
|
||||
# --- Toolset configuration ---
|
||||
# Mutually exclusive: use either enabled_toolsets OR distribution
|
||||
enabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
|
||||
"If None and distribution is also None, all available toolsets are enabled.",
|
||||
)
|
||||
disabled_toolsets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
|
||||
)
|
||||
distribution: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of a toolset distribution from toolset_distributions.py "
|
||||
"(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
|
||||
"Mutually exclusive with enabled_toolsets.",
|
||||
)
|
||||
|
||||
# --- Agent loop configuration ---
|
||||
max_agent_turns: int = Field(
|
||||
default=30,
|
||||
description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
|
||||
)
|
||||
system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="System prompt for the agent. Tools are handled via the tools= parameter, "
|
||||
"not embedded in the prompt text.",
|
||||
)
|
||||
agent_temperature: float = Field(
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
default="local",
|
||||
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
|
||||
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
|
||||
)
|
||||
terminal_timeout: int = Field(
|
||||
default=120,
|
||||
description="Per-command timeout in seconds for terminal tool calls. "
|
||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
||||
"commands (compilation, pip install, etc.).",
|
||||
)
|
||||
terminal_lifetime: int = Field(
|
||||
default=3600,
|
||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
||||
"sandboxes that have been idle longer than this. Must be longer than "
|
||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="HuggingFace dataset name. Optional if tasks are defined inline.",
|
||||
)
|
||||
dataset_split: str = Field(
|
||||
default="train",
|
||||
description="Dataset split to use.",
|
||||
)
|
||||
prompt_field: str = Field(
|
||||
default="prompt",
|
||||
description="Which field in the dataset contains the prompt.",
|
||||
)
|
||||
|
||||
# --- Thread pool ---
|
||||
tool_pool_size: int = Field(
|
||||
default=128,
|
||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
||||
"Too small = thread pool starvation.",
|
||||
)
|
||||
|
||||
# --- Phase 2: Tool call parsing ---
|
||||
tool_call_parser: str = Field(
|
||||
default="hermes",
|
||||
description="Tool call parser name for Phase 2 (VLLM server type). "
|
||||
"Ignored in Phase 1 (OpenAI server type where VLLM parses natively). "
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
# Example YAML:
|
||||
# extra_body:
|
||||
# provider:
|
||||
# ignore: ["DeepInfra", "Fireworks"]
|
||||
# order: ["Together"]
|
||||
# transforms: ["middle-out"]
|
||||
extra_body: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Extra body parameters passed to the OpenAI client's "
|
||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
Abstract base environment for hermes-agent Atropos integration.
|
||||
|
||||
Handles two modes of operation:
|
||||
- Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
|
||||
The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
|
||||
and reasoning extraction natively. DummyManagedServer provides placeholder
|
||||
tokens. Good for SFT data gen, verifier testing, evaluation.
|
||||
|
||||
- Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
|
||||
via /generate. Client-side tool call parser reconstructs structured tool_calls
|
||||
from raw output. Full RL training capability.
|
||||
|
||||
Subclasses must implement:
|
||||
setup() -- Load dataset, initialize state
|
||||
get_next_item() -- Return the next item to roll out
|
||||
format_prompt() -- Convert a dataset item into the user message string
|
||||
compute_reward() -- Score the rollout using ToolContext
|
||||
evaluate() -- Periodic evaluation
|
||||
"""
|
||||
|
||||
name: Optional[str] = "hermes-agent"
|
||||
env_config_cls = HermesAgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesAgentEnvConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set terminal environment variables so hermes tools pick them up.
|
||||
# These can all be overridden per-environment via config fields instead
|
||||
# of requiring users to set shell env vars.
|
||||
if config.terminal_backend:
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
)
|
||||
|
||||
# Resize the agent loop's thread pool for tool execution.
|
||||
# This must be large enough for the number of concurrent tasks
|
||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
||||
from environments.agent_loop import resize_tool_pool
|
||||
resize_tool_pool(config.tool_pool_size)
|
||||
|
||||
# Current group's resolved tools (set in collect_trajectories)
|
||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
||||
|
||||
# Tool error tracking for wandb logging
|
||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# =========================================================================
|
||||
# Toolset resolution (per-group)
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
|
||||
"""
|
||||
Resolve toolsets for a group. Called once in collect_trajectories(),
|
||||
then shared by all collect_trajectory() calls in the group.
|
||||
|
||||
If distribution is set, samples probabilistically.
|
||||
If enabled_toolsets is set, uses that explicit list.
|
||||
disabled_toolsets is applied as a filter on top.
|
||||
|
||||
Returns:
|
||||
(tool_schemas, valid_tool_names) tuple
|
||||
"""
|
||||
config = self.config
|
||||
|
||||
if config.distribution:
|
||||
group_toolsets = sample_toolsets_from_distribution(config.distribution)
|
||||
logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
|
||||
else:
|
||||
group_toolsets = config.enabled_toolsets # None means "all available"
|
||||
if group_toolsets is None:
|
||||
logger.warning(
|
||||
"enabled_toolsets is None -- loading ALL tools including messaging. "
|
||||
"Set explicit enabled_toolsets for RL training."
|
||||
)
|
||||
|
||||
tools = get_tool_definitions(
|
||||
enabled_toolsets=group_toolsets,
|
||||
disabled_toolsets=config.disabled_toolsets,
|
||||
quiet_mode=True,
|
||||
)
|
||||
|
||||
valid_names = {t["function"]["name"] for t in tools} if tools else set()
|
||||
logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
|
||||
return tools, valid_names
|
||||
|
||||
# =========================================================================
|
||||
# Server mode detection
|
||||
# =========================================================================
|
||||
|
||||
def _use_managed_server(self) -> bool:
|
||||
"""
|
||||
Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).
|
||||
|
||||
Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
|
||||
which go through the /generate endpoint for exact token tracking.
|
||||
|
||||
Phase 1 (direct server) is used for 'openai' server type, which uses
|
||||
/v1/chat/completions with native tool call parsing.
|
||||
"""
|
||||
if not self.server.servers:
|
||||
return False
|
||||
|
||||
server = self.server.servers[0]
|
||||
# If the server is an OpenAI server (not VLLM/SGLang), use direct mode
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
return not isinstance(server, OpenAIServer)
|
||||
|
||||
# =========================================================================
|
||||
# Core Atropos integration
|
||||
# =========================================================================
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[
|
||||
Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
List[Item],
|
||||
]:
|
||||
"""
|
||||
Override collect_trajectories to resolve toolsets once per group,
|
||||
then delegate to the standard group-level collection.
|
||||
|
||||
The default BaseEnv.collect_trajectories() calls collect_trajectory()
|
||||
group_size times in parallel. We resolve tools once here and store
|
||||
them for all those calls to use.
|
||||
"""
|
||||
# Resolve toolsets for this group (shared by all rollouts in the group)
|
||||
self._current_group_tools = self._resolve_tools_for_group()
|
||||
|
||||
# Delegate to the default implementation which calls collect_trajectory()
|
||||
# group_size times via asyncio.gather
|
||||
return await super().collect_trajectories(item)
|
||||
|
||||
# =========================================================================
|
||||
# Wandb rollout display -- format trajectories nicely
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format a conversation's messages into a readable trajectory string
|
||||
for wandb rollout tables. Shows tool calls, tool results, and reasoning
|
||||
in a structured way instead of raw token decoding.
|
||||
"""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
parts.append(f"[SYSTEM]\n{content}")
|
||||
|
||||
elif role == "user":
|
||||
parts.append(f"[USER]\n{content}")
|
||||
|
||||
elif role == "assistant":
|
||||
# Show reasoning if present
|
||||
reasoning = msg.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
# Truncate long reasoning for display
|
||||
if len(reasoning) > 300:
|
||||
reasoning = reasoning[:300] + "..."
|
||||
parts.append(f"[ASSISTANT thinking]\n{reasoning}")
|
||||
|
||||
# Show content
|
||||
if content:
|
||||
parts.append(f"[ASSISTANT]\n{content}")
|
||||
|
||||
# Show tool calls
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
for tc in tool_calls:
|
||||
func = tc.get("function", {})
|
||||
name = func.get("name", "?")
|
||||
args = func.get("arguments", "{}")
|
||||
# Truncate long arguments for display
|
||||
if len(args) > 200:
|
||||
args = args[:200] + "..."
|
||||
parts.append(f"[TOOL CALL] {name}({args})")
|
||||
|
||||
elif role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
result = content
|
||||
# Truncate long tool results for display
|
||||
if len(result) > 500:
|
||||
result = result[:500] + "..."
|
||||
parts.append(f"[TOOL RESULT] {result}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def add_rollouts_for_wandb(
|
||||
self,
|
||||
scored_data,
|
||||
item=None,
|
||||
):
|
||||
"""
|
||||
Override to show formatted trajectories with tool calls visible,
|
||||
instead of raw token decoding which loses all structure.
|
||||
"""
|
||||
num_keep = self.config.num_rollouts_per_group_for_logging
|
||||
if num_keep == -1:
|
||||
num_keep = self.config.group_size
|
||||
|
||||
group = []
|
||||
for i in range(min(num_keep, len(scored_data.get("scores", [])))):
|
||||
score = scored_data["scores"][i]
|
||||
|
||||
# Use messages if available for rich display
|
||||
messages = None
|
||||
if scored_data.get("messages") and i < len(scored_data["messages"]):
|
||||
messages = scored_data["messages"][i]
|
||||
|
||||
if messages:
|
||||
text = self._format_trajectory_for_display(messages)
|
||||
elif scored_data.get("tokens") and i < len(scored_data["tokens"]):
|
||||
text = self.tokenizer.decode(scored_data["tokens"][i])
|
||||
else:
|
||||
text = "(no data)"
|
||||
|
||||
group.append((text, score))
|
||||
|
||||
self.rollouts_for_wandb.append(group)
|
||||
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
||||
self.rollouts_for_wandb.pop(0)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log base metrics including tool errors to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Log tool error stats
|
||||
if self._tool_error_buffer:
|
||||
wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
|
||||
|
||||
# Log error details as a summary string (tables can crash wandb on tmp cleanup)
|
||||
error_summaries = []
|
||||
for err in self._tool_error_buffer:
|
||||
error_summaries.append(
|
||||
f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
|
||||
)
|
||||
wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
|
||||
|
||||
# Also print to stdout for immediate visibility
|
||||
for summary in error_summaries:
|
||||
print(f" Tool Error: {summary}")
|
||||
|
||||
self._tool_error_buffer = []
|
||||
else:
|
||||
wandb_metrics["train/tool_errors_count"] = 0
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Run a single rollout: agent loop + reward computation.
|
||||
|
||||
This is called group_size times in parallel by collect_trajectories().
|
||||
Each call gets its own task_id for terminal/browser session isolation.
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Get group-level tools (resolved once in collect_trajectories)
|
||||
if self._current_group_tools is None:
|
||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
|
||||
# Build initial messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
|
||||
# Load the tool call parser from registry based on config
|
||||
from environments.tool_call_parsers import get_parser
|
||||
try:
|
||||
tc_parser = get_parser(self.config.tool_call_parser)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Tool call parser '%s' not found, falling back to 'hermes'",
|
||||
self.config.tool_call_parser,
|
||||
)
|
||||
tc_parser = get_parser("hermes")
|
||||
|
||||
try:
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
tool_call_parser=tc_parser,
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
||||
logger.warning(
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Skip reward computation if the agent loop produced no meaningful work
|
||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
||||
# just to verify files that were never created.
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Compute reward using ToolContext (gives verifier full tool access)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Track tool errors for wandb logging
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
self._tool_error_buffer.append({
|
||||
"turn": err.turn,
|
||||
"tool": err.tool_name,
|
||||
"args": err.arguments[:150],
|
||||
"error": err.error[:300],
|
||||
"result": err.tool_result[:300],
|
||||
})
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
|
||||
if nodes:
|
||||
# Phase 2 (or DummyManagedServer): use actual node data
|
||||
node = nodes[-1] # Final sequence node = full trajectory
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Include logprobs if available (Phase 2)
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["advantages"] = None # Computed by trainer
|
||||
scored_item["ref_logprobs"] = None
|
||||
else:
|
||||
# Phase 1 with no managed state: create placeholder tokens
|
||||
# so the data pipeline doesn't break. These are NOT suitable
|
||||
# for training but allow process mode (SFT data gen) to work.
|
||||
# Tokenize the full conversation to get approximate tokens.
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
if self.tokenizer:
|
||||
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
|
||||
else:
|
||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:], # Mask first token as prompt
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Always include messages for wandb rollout display and data logging
|
||||
scored_item["messages"] = result.messages
|
||||
|
||||
return scored_item, []
|
||||
|
||||
# =========================================================================
|
||||
# Abstract methods -- subclasses must implement
|
||||
# =========================================================================
|
||||
|
||||
@abstractmethod
|
||||
async def setup(self):
|
||||
"""
|
||||
Load dataset, initialize state.
|
||||
|
||||
Called once when the environment starts. Typical implementation:
|
||||
self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self.iter = 0
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Return the next item from the dataset for rollout.
|
||||
|
||||
Called by the base env's main loop to get items for workers.
|
||||
Should cycle through the dataset.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""
|
||||
Convert a dataset item into the user message for the agent.
|
||||
|
||||
Args:
|
||||
item: Dataset item (dict, tuple, etc.)
|
||||
|
||||
Returns:
|
||||
The prompt string to send to the agent
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score the rollout. Has full access to:
|
||||
- item: the original dataset item (ground truth, test commands, etc.)
|
||||
- result: AgentResult with full messages, turn count, reasoning, etc.
|
||||
- ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
|
||||
browser, vision...) scoped to this rollout's sandbox. Nothing
|
||||
is off-limits.
|
||||
|
||||
Args:
|
||||
item: The dataset item that was rolled out
|
||||
result: The agent's rollout result
|
||||
ctx: ToolContext with full tool access for verification
|
||||
|
||||
Returns:
|
||||
Reward float (typically 0.0 to 1.0, but any float is valid)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Periodic evaluation. Called every steps_per_eval steps.
|
||||
|
||||
Typical implementation runs the agent on a held-out eval set
|
||||
and logs metrics via wandb/evaluate_log.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
0
environments/hermes_swe_env/__init__.py
Normal file
0
environments/hermes_swe_env/__init__.py
Normal file
34
environments/hermes_swe_env/default.yaml
Normal file
34
environments/hermes_swe_env/default.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
# SWE Environment -- Default Configuration
|
||||
#
|
||||
# SWE-bench style tasks with Modal sandboxes for cloud isolation.
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
# --config environments/hermes_swe_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
max_agent_turns: 30
|
||||
max_token_length: 4096
|
||||
group_size: 4
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
dataset_name: "bigcode/humanevalpack"
|
||||
dataset_split: "test"
|
||||
prompt_field: "prompt"
|
||||
steps_per_eval: 50
|
||||
total_steps: 500
|
||||
use_wandb: true
|
||||
wandb_name: "hermes-swe"
|
||||
system_prompt: >
|
||||
You are a skilled software engineer. You have access to a terminal,
|
||||
file tools, and web search. Use these tools to complete the coding task.
|
||||
Write clean, working code and verify it runs correctly before finishing.
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:8000/v1"
|
||||
model_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
server_type: "openai"
|
||||
api_key: ""
|
||||
229
environments/hermes_swe_env/hermes_swe_env.py
Normal file
229
environments/hermes_swe_env/hermes_swe_env.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
HermesSweEnv -- SWE-Bench Style Environment with Modal Sandboxes
|
||||
|
||||
A concrete environment for software engineering tasks where the model writes code
|
||||
and the reward function runs tests to verify correctness. Uses Modal terminal
|
||||
backend for cloud-isolated sandboxes per rollout.
|
||||
|
||||
The reward function uses ToolContext.terminal() to run test commands in the same
|
||||
Modal sandbox the model used during its agentic loop. All filesystem state from
|
||||
the model's tool calls is preserved for verification.
|
||||
|
||||
Usage:
|
||||
# Phase 1: OpenAI server type
|
||||
vllm serve YourModel --tool-parser hermes
|
||||
run-api
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai \\
|
||||
--env.dataset_name bigcode/humanevalpack \\
|
||||
--env.terminal_backend modal
|
||||
|
||||
# Phase 2: VLLM server type (full RL training)
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type vllm \\
|
||||
--env.tool_call_parser hermes \\
|
||||
--env.terminal_backend modal
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesSweEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for SWE-bench style tasks."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class HermesSweEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-bench style environment using Modal terminal backend.
|
||||
|
||||
The model gets a coding task, uses terminal + file + web tools to solve it,
|
||||
and the reward function runs tests in the same Modal sandbox to verify.
|
||||
|
||||
Subclass this for specific SWE datasets (HumanEval, SWE-bench, etc.)
|
||||
and customize format_prompt() and compute_reward() as needed.
|
||||
"""
|
||||
|
||||
name = "hermes-swe"
|
||||
env_config_cls = HermesSweEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesSweEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the SWE environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and terminal + file + web toolsets.
|
||||
"""
|
||||
env_config = HermesSweEnvConfig(
|
||||
# Toolsets: terminal for running code, file for reading/writing, web for docs
|
||||
enabled_toolsets=["terminal", "file", "web"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings -- SWE tasks need more turns
|
||||
max_agent_turns=30,
|
||||
max_token_length=4096,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a skilled software engineer. You have access to a terminal, "
|
||||
"file tools, and web search. Use these tools to complete the coding task. "
|
||||
"Write clean, working code and verify it runs correctly before finishing."
|
||||
),
|
||||
# Modal backend for cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
# Dataset -- override via CLI for your specific SWE dataset
|
||||
dataset_name="bigcode/humanevalpack",
|
||||
dataset_split="test",
|
||||
prompt_field="prompt",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=50,
|
||||
total_steps=500,
|
||||
use_wandb=True,
|
||||
wandb_name="hermes-swe",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
server_type="openai", # Phase 1; switch to "vllm" for Phase 2
|
||||
api_key="",
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load the SWE dataset."""
|
||||
if self.config.dataset_name:
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name, split=self.config.dataset_split
|
||||
)
|
||||
else:
|
||||
# Placeholder if no dataset specified
|
||||
self.dataset = []
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, Any]:
|
||||
"""Cycle through the SWE dataset."""
|
||||
if not self.dataset:
|
||||
raise ValueError("No dataset loaded. Set dataset_name in config.")
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format the SWE task prompt.
|
||||
|
||||
Override this in subclasses for different dataset formats.
|
||||
Default assumes the dataset has a 'prompt' field and optionally a 'test' field.
|
||||
"""
|
||||
prompt = item.get(self.config.prompt_field, "")
|
||||
|
||||
# If the dataset has test information, include it in the prompt
|
||||
test_info = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
if test_info:
|
||||
prompt += f"\n\nTests to pass:\n{test_info}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, Any], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score by running tests in the model's Modal sandbox.
|
||||
|
||||
Default implementation:
|
||||
- If the dataset item has a 'test' or 'test_code' field, run it
|
||||
- Check exit code: 0 = pass, non-zero = fail
|
||||
- Partial credit for file creation
|
||||
|
||||
Override this in subclasses for more sophisticated reward logic.
|
||||
"""
|
||||
# Find the test command from the dataset item
|
||||
test_code = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
|
||||
if test_code:
|
||||
# Run the test in the model's sandbox
|
||||
test_result = ctx.terminal(
|
||||
f'cd /workspace && python3 -c "{test_code}"', timeout=60
|
||||
)
|
||||
|
||||
if test_result["exit_code"] == 0:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: check if the model created any Python files
|
||||
file_check = ctx.terminal("find /workspace -name '*.py' -newer /tmp/.start_marker 2>/dev/null | head -5")
|
||||
if file_check["exit_code"] == 0 and file_check.get("output", "").strip():
|
||||
self.reward_buffer.append(0.1)
|
||||
return 0.1
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run evaluation on a held-out set.
|
||||
|
||||
Override for dataset-specific evaluation logic.
|
||||
"""
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {"eval/placeholder": 0.0}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log SWE-specific metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
|
||||
self.reward_buffer
|
||||
)
|
||||
wandb_metrics["train/pass_rate"] = sum(
|
||||
1 for r in self.reward_buffer if r == 1.0
|
||||
) / len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesSweEnv.cli()
|
||||
188
environments/patches.py
Normal file
188
environments/patches.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
||||
|
||||
Problem:
|
||||
Some tools use asyncio.run() internally (e.g., mini-swe-agent's Modal backend,
|
||||
web_extract). This crashes when called from inside Atropos's event loop because
|
||||
asyncio.run() can't be nested.
|
||||
|
||||
Solution:
|
||||
Replace the problematic methods with versions that use a dedicated background
|
||||
thread with its own event loop. The calling code sees the same sync interface --
|
||||
call a function, get a result -- but internally the async work happens on a
|
||||
separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
These patches are safe for normal CLI use too: when there's no running event
|
||||
loop, the behavior is identical (the background thread approach works regardless).
|
||||
|
||||
What gets patched:
|
||||
- SwerexModalEnvironment.__init__ -- creates Modal deployment on a background thread
|
||||
- SwerexModalEnvironment.execute -- runs commands on the same background thread
|
||||
- SwerexModalEnvironment.stop -- stops deployment on the background thread
|
||||
|
||||
Usage:
|
||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
||||
This is idempotent -- calling it multiple times is safe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
|
||||
class _AsyncWorker:
|
||||
"""
|
||||
A dedicated background thread with its own event loop.
|
||||
|
||||
Allows sync code to submit async coroutines and block for results,
|
||||
even when called from inside another running event loop. Used to
|
||||
bridge sync tool interfaces with async backends (Modal, SWE-ReX).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._loop: asyncio.AbstractEventLoop = None
|
||||
self._thread: threading.Thread = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def start(self):
|
||||
"""Start the background event loop thread."""
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Background thread entry point -- runs the event loop forever."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
def run_coroutine(self, coro, timeout=600):
|
||||
"""
|
||||
Submit a coroutine to the background loop and block until it completes.
|
||||
|
||||
Safe to call from any thread, including threads that already have
|
||||
a running event loop.
|
||||
"""
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
raise RuntimeError("AsyncWorker loop is not running")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background event loop and join the thread."""
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
def _patch_swerex_modal():
|
||||
"""
|
||||
Monkey patch SwerexModalEnvironment to use a background thread event loop
|
||||
instead of asyncio.run(). This makes it safe to call from inside Atropos's
|
||||
async event loop.
|
||||
|
||||
The patched methods have the exact same interface and behavior -- the only
|
||||
difference is HOW the async work is executed internally.
|
||||
"""
|
||||
try:
|
||||
from minisweagent.environments.extra.swerex_modal import (
|
||||
SwerexModalEnvironment,
|
||||
SwerexModalEnvironmentConfig,
|
||||
)
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
except ImportError:
|
||||
# mini-swe-agent or swe-rex not installed -- nothing to patch
|
||||
logger.debug("mini-swe-agent Modal backend not available, skipping patch")
|
||||
return
|
||||
|
||||
# Save original methods so we can refer to config handling
|
||||
_original_init = SwerexModalEnvironment.__init__
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
"""Patched __init__: creates Modal deployment on a background thread."""
|
||||
self.config = SwerexModalEnvironmentConfig(**kwargs)
|
||||
|
||||
# Start a dedicated event loop thread for all Modal async operations
|
||||
self._worker = _AsyncWorker()
|
||||
self._worker.start()
|
||||
|
||||
# Create AND start the deployment entirely on the worker's loop/thread
|
||||
# so all gRPC channels and async state are bound to that loop
|
||||
async def _create_and_start():
|
||||
deployment = ModalDeployment(
|
||||
image=self.config.image,
|
||||
startup_timeout=self.config.startup_timeout,
|
||||
runtime_timeout=self.config.runtime_timeout,
|
||||
deployment_timeout=self.config.deployment_timeout,
|
||||
install_pipx=self.config.install_pipx,
|
||||
modal_sandbox_kwargs=self.config.modal_sandbox_kwargs,
|
||||
)
|
||||
await deployment.start()
|
||||
return deployment
|
||||
|
||||
self.deployment = self._worker.run_coroutine(_create_and_start())
|
||||
|
||||
def _patched_execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
||||
"""Patched execute: runs commands on the background thread's loop."""
|
||||
async def _do_execute():
|
||||
return await self.deployment.runtime.execute(
|
||||
RexCommand(
|
||||
command=command,
|
||||
shell=True,
|
||||
check=False,
|
||||
cwd=cwd or self.config.cwd,
|
||||
timeout=timeout or self.config.timeout,
|
||||
merge_output_streams=True,
|
||||
env=self.config.env if self.config.env else None,
|
||||
)
|
||||
)
|
||||
|
||||
output = self._worker.run_coroutine(_do_execute())
|
||||
return {
|
||||
"output": output.stdout,
|
||||
"returncode": output.exit_code,
|
||||
}
|
||||
|
||||
def _patched_stop(self):
|
||||
"""Patched stop: stops deployment on the background thread, then stops the thread."""
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
asyncio.wait_for(self.deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._worker.stop()
|
||||
|
||||
# Apply the patches
|
||||
SwerexModalEnvironment.__init__ = _patched_init
|
||||
SwerexModalEnvironment.execute = _patched_execute
|
||||
SwerexModalEnvironment.stop = _patched_stop
|
||||
|
||||
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""
|
||||
Apply all monkey patches needed for Atropos compatibility.
|
||||
|
||||
Safe to call multiple times -- patches are only applied once.
|
||||
Safe for normal CLI use -- patched code works identically when
|
||||
there is no running event loop.
|
||||
"""
|
||||
global _patches_applied
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
_patch_swerex_modal()
|
||||
|
||||
_patches_applied = True
|
||||
0
environments/terminal_test_env/__init__.py
Normal file
0
environments/terminal_test_env/__init__.py
Normal file
34
environments/terminal_test_env/default.yaml
Normal file
34
environments/terminal_test_env/default.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
# Terminal Test Environment -- Default Configuration
|
||||
#
|
||||
# Simple file-creation tasks for validating the full Atropos + hermes-agent stack.
|
||||
# Uses Modal terminal backend and OpenRouter (Claude) for inference.
|
||||
# API keys loaded from ~/hermes-agent/.env
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
||||
# --config environments/terminal_test_env/default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 10
|
||||
max_token_length: 2048
|
||||
group_size: 3
|
||||
total_steps: 3
|
||||
steps_per_eval: 3
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
ensure_scores_are_not_same: false
|
||||
use_wandb: false
|
||||
system_prompt: >
|
||||
You are a helpful assistant with access to a terminal and file tools.
|
||||
Complete the user's request by using the available tools.
|
||||
Be precise and follow instructions exactly.
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
292
environments/terminal_test_env/terminal_test_env.py
Normal file
292
environments/terminal_test_env/terminal_test_env.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
TerminalTestEnv -- Simple Test Environment for Validating the Stack
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed).
|
||||
Each task asks the model to create a file at a known path with specific content.
|
||||
The reward verifier cats the file and checks if the content matches.
|
||||
|
||||
Enables only terminal + file toolsets. Uses Modal terminal backend with
|
||||
OpenRouter (Claude) by default.
|
||||
|
||||
Training tasks (3):
|
||||
1. Create ~/greeting.txt with "Hello from Hermes Agent"
|
||||
2. Create ~/count.txt with numbers 1-5, one per line
|
||||
3. Create ~/answer.txt with the result of 123 + 456
|
||||
|
||||
Eval task (1):
|
||||
1. Create ~/result.txt with the result of 6 * 7
|
||||
|
||||
Usage:
|
||||
# Start Atropos API server
|
||||
run-api
|
||||
|
||||
# Run environment (uses OpenRouter + Modal by default)
|
||||
python environments/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api needed, saves to JSONL)
|
||||
python environments/terminal_test_env.py process \\
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Inline task definitions -- no external dataset needed
|
||||
# =============================================================================
|
||||
|
||||
TRAIN_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/greeting.txt containing exactly the text: Hello from Hermes Agent",
|
||||
"verify_path": "~/greeting.txt",
|
||||
"expected_content": "Hello from Hermes Agent",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/count.txt containing the numbers 1 through 5, one per line",
|
||||
"verify_path": "~/count.txt",
|
||||
"expected_content": "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/answer.txt containing the result of 123 + 456",
|
||||
"verify_path": "~/answer.txt",
|
||||
"expected_content": "579",
|
||||
},
|
||||
]
|
||||
|
||||
EVAL_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/result.txt containing the result of 6 * 7",
|
||||
"verify_path": "~/result.txt",
|
||||
"expected_content": "42",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TerminalTestEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults suitable for terminal testing."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class TerminalTestEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Simple test environment with inline file-creation tasks.
|
||||
|
||||
All tasks follow the same pattern: "create a file at ~/X.txt with content Y".
|
||||
The verifier runs `cat ~/X.txt` in the rollout's terminal and checks the output
|
||||
against the expected string. Same verifier logic for all tasks.
|
||||
|
||||
This environment is designed to validate the full stack end-to-end:
|
||||
- Agent loop executes tool calls (terminal/file)
|
||||
- ToolContext provides terminal access to the reward function
|
||||
- Reward function verifies file content via cat
|
||||
- Scored data flows through the Atropos pipeline
|
||||
"""
|
||||
|
||||
name = "terminal-test"
|
||||
env_config_cls = TerminalTestEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the terminal test environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and OpenRouter with
|
||||
Claude for inference. API keys loaded from ~/hermes-agent/.env.
|
||||
"""
|
||||
env_config = TerminalTestEnvConfig(
|
||||
# Terminal + file tools only
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=10, # Simple tasks, don't need many turns
|
||||
max_token_length=16000,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to a terminal and file tools. "
|
||||
"Complete the user's request by using the available tools. "
|
||||
"Be precise and follow instructions exactly."
|
||||
),
|
||||
# Modal terminal backend for cloud-isolated sandboxes per rollout
|
||||
terminal_backend="modal",
|
||||
# Atropos settings
|
||||
group_size=3, # 3 rollouts per group
|
||||
tokenizer_name="NousResearch/q-30b-t-h45-e1",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=3, # Eval after all 3 steps
|
||||
total_steps=3, # 3 groups total (1 group per step)
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-test",
|
||||
ensure_scores_are_not_same=False, # Allow all-same scores for simple tasks
|
||||
# No external dataset
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env (OPENROUTER_API_KEY)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-opus-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False, # OpenRouter doesn't have a /health endpoint
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize inline task lists."""
|
||||
self.train_tasks = list(TRAIN_TASKS)
|
||||
self.eval_tasks = list(EVAL_TASKS)
|
||||
self.iter = 0
|
||||
# Track reward stats for wandb logging
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training tasks."""
|
||||
item = self.train_tasks[self.iter % len(self.train_tasks)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""The prompt is directly in the task item."""
|
||||
return item["prompt"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by cat-ing the expected file path and checking content matches.
|
||||
Same verifier for all tasks -- they all write a file at a known path.
|
||||
|
||||
Scoring:
|
||||
1.0 = exact match
|
||||
0.5 = expected content is present but has extra stuff
|
||||
0.0 = file doesn't exist or content doesn't match
|
||||
"""
|
||||
verify_result = ctx.terminal(f"cat {item['verify_path']}")
|
||||
|
||||
# File doesn't exist or can't be read
|
||||
if verify_result["exit_code"] != 0:
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
actual = verify_result.get("output", "").strip()
|
||||
expected = item["expected_content"].strip()
|
||||
|
||||
# Exact match
|
||||
if actual == expected:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: expected content is present but has extra stuff
|
||||
if expected in actual:
|
||||
self.reward_buffer.append(0.5)
|
||||
return 0.5
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run eval tasks using the agent loop and verify results.
|
||||
Logs accuracy metrics.
|
||||
"""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = len(self.eval_tasks)
|
||||
samples = []
|
||||
|
||||
for eval_item in self.eval_tasks:
|
||||
try:
|
||||
# For eval, we do a simple single-turn completion (not full agent loop)
|
||||
# to keep eval fast. The agent loop is tested via training.
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": eval_item["prompt"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": response_content,
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed for item: %s", e)
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {
|
||||
"eval/num_samples": total,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics including reward stats and accuracy."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
total = len(self.reward_buffer)
|
||||
correct = sum(1 for r in self.reward_buffer if r == 1.0)
|
||||
partial = sum(1 for r in self.reward_buffer if r == 0.5)
|
||||
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / total
|
||||
wandb_metrics["train/accuracy"] = correct / total
|
||||
wandb_metrics["train/partial_match_rate"] = partial / total
|
||||
wandb_metrics["train/total_rollouts"] = total
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalTestEnv.cli()
|
||||
120
environments/tool_call_parsers/__init__.py
Normal file
120
environments/tool_call_parsers/__init__.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Tool Call Parser Registry
|
||||
|
||||
Client-side parsers that extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM server type) where ManagedServer's /generate endpoint returns
|
||||
raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's
|
||||
non-streaming extract_tool_calls() logic. No VLLM dependency -- only standard library
|
||||
(re, json, uuid) and openai types.
|
||||
|
||||
Usage:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
# content = text with tool call markup stripped
|
||||
# tool_calls = list of ChatCompletionMessageToolCall objects, or None
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for parser return value
|
||||
ParseResult = Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]
|
||||
|
||||
|
||||
class ToolCallParser(ABC):
|
||||
"""
|
||||
Base class for tool call parsers.
|
||||
|
||||
Each parser knows how to extract structured tool_calls from a specific
|
||||
model family's raw output text format.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parse raw model output text for tool calls.
|
||||
|
||||
Args:
|
||||
text: Raw decoded text from the model's completion
|
||||
|
||||
Returns:
|
||||
Tuple of (content, tool_calls) where:
|
||||
- content: text with tool call markup stripped (the message 'content' field),
|
||||
or None if the entire output was tool calls
|
||||
- tool_calls: list of ChatCompletionMessageToolCall objects,
|
||||
or None if no tool calls were found
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Global parser registry: name -> parser class
|
||||
PARSER_REGISTRY: Dict[str, Type[ToolCallParser]] = {}
|
||||
|
||||
|
||||
def register_parser(name: str):
|
||||
"""
|
||||
Decorator to register a parser class under a given name.
|
||||
|
||||
Usage:
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[ToolCallParser]) -> Type[ToolCallParser]:
|
||||
PARSER_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_parser(name: str) -> ToolCallParser:
|
||||
"""
|
||||
Get a parser instance by name.
|
||||
|
||||
Args:
|
||||
name: Parser name (e.g., "hermes", "mistral", "llama3_json")
|
||||
|
||||
Returns:
|
||||
Instantiated parser
|
||||
|
||||
Raises:
|
||||
KeyError: If parser name is not found in registry
|
||||
"""
|
||||
if name not in PARSER_REGISTRY:
|
||||
available = sorted(PARSER_REGISTRY.keys())
|
||||
raise KeyError(
|
||||
f"Tool call parser '{name}' not found. Available parsers: {available}"
|
||||
)
|
||||
return PARSER_REGISTRY[name]()
|
||||
|
||||
|
||||
def list_parsers() -> List[str]:
|
||||
"""Return sorted list of registered parser names."""
|
||||
return sorted(PARSER_REGISTRY.keys())
|
||||
|
||||
|
||||
# Import all parser modules to trigger registration via @register_parser decorators
|
||||
# Each module registers itself when imported
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.longcat_parser import LongcatToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.mistral_parser import MistralToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.llama_parser import LlamaToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen_parser import QwenToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_parser import DeepSeekV3ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_1_parser import DeepSeekV31ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.kimi_k2_parser import KimiK2ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm47_parser import Glm47ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen3_coder_parser import Qwen3CoderToolCallParser # noqa: E402, F401
|
||||
72
environments/tool_call_parsers/deepseek_v3_1_parser.py
Normal file
72
environments/tool_call_parsers/deepseek_v3_1_parser.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
DeepSeek V3.1 tool call parser.
|
||||
|
||||
Similar to V3 but with a slightly different format:
|
||||
<|tool▁call▁begin|>function_name<|tool▁sep|>arguments<|tool▁call▁end|>
|
||||
|
||||
Note: V3 has type+name before the separator, V3.1 has name before and args after.
|
||||
|
||||
Based on VLLM's DeepSeekV31ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3_1")
|
||||
@register_parser("deepseek_v31")
|
||||
class DeepSeekV31ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3.1 tool calls.
|
||||
|
||||
Slightly different regex than V3: function_name comes before the separator,
|
||||
arguments come after (no type field, no json code block wrapper).
|
||||
"""
|
||||
|
||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
||||
|
||||
# Regex captures: function_name, function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
76
environments/tool_call_parsers/deepseek_v3_parser.py
Normal file
76
environments/tool_call_parsers/deepseek_v3_parser.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
DeepSeek V3 tool call parser.
|
||||
|
||||
Format uses special unicode tokens:
|
||||
<|tool▁calls▁begin|>
|
||||
<|tool▁call▁begin|>type<|tool▁sep|>function_name
|
||||
```json
|
||||
{"arg": "value"}
|
||||
```
|
||||
<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|>
|
||||
|
||||
Based on VLLM's DeepSeekV3ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3")
|
||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3 tool calls.
|
||||
|
||||
Uses special unicode tokens with fullwidth angle brackets and block elements.
|
||||
Extracts type, function name, and JSON arguments from the structured format.
|
||||
"""
|
||||
|
||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
||||
|
||||
# Regex captures: type, function_name, function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
tc_type, func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
109
environments/tool_call_parsers/glm45_parser.py
Normal file
109
environments/tool_call_parsers/glm45_parser.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
GLM 4.5 (GLM-4-MoE) tool call parser.
|
||||
|
||||
Format uses custom arg_key/arg_value tags rather than standard JSON:
|
||||
<tool_call>function_name
|
||||
<arg_key>param1</arg_key><arg_value>value1</arg_value>
|
||||
<arg_key>param2</arg_key><arg_value>value2</arg_value>
|
||||
</tool_call>
|
||||
|
||||
Values are deserialized using json.loads -> ast.literal_eval -> raw string fallback.
|
||||
|
||||
Based on VLLM's Glm4MoeModelToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _deserialize_value(value: str) -> Any:
|
||||
"""
|
||||
Try to deserialize a string value to its native Python type.
|
||||
Attempts json.loads, then ast.literal_eval, then returns raw string.
|
||||
"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@register_parser("glm45")
|
||||
class Glm45ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.5 (GLM-4-MoE) tool calls.
|
||||
|
||||
Uses <tool_call>...</tool_call> tags with <arg_key>/<arg_value> pairs
|
||||
instead of standard JSON arguments.
|
||||
"""
|
||||
|
||||
FUNC_CALL_REGEX = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
||||
FUNC_DETAIL_REGEX = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
||||
FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
||||
)
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matched_calls = self.FUNC_CALL_REGEX.findall(text)
|
||||
if not matched_calls:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matched_calls:
|
||||
detail = self.FUNC_DETAIL_REGEX.search(match)
|
||||
if not detail:
|
||||
continue
|
||||
|
||||
func_name = detail.group(1).strip()
|
||||
func_args_raw = detail.group(2)
|
||||
|
||||
# Parse arg_key/arg_value pairs
|
||||
pairs = self.FUNC_ARG_REGEX.findall(func_args_raw) if func_args_raw else []
|
||||
arg_dict: Dict[str, Any] = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = _deserialize_value(value.strip())
|
||||
arg_dict[arg_key] = arg_val
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arg_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
35
environments/tool_call_parsers/glm47_parser.py
Normal file
35
environments/tool_call_parsers/glm47_parser.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
GLM 4.7 tool call parser.
|
||||
|
||||
Same as GLM 4.5 but with slightly different regex patterns.
|
||||
The tool_call tags may wrap differently and arg parsing handles
|
||||
newlines between key/value pairs.
|
||||
|
||||
Based on VLLM's Glm47MoeModelToolParser (extends Glm4MoeModelToolParser).
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, register_parser
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser
|
||||
|
||||
|
||||
@register_parser("glm47")
|
||||
class Glm47ToolCallParser(Glm45ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.7 tool calls.
|
||||
Extends GLM 4.5 with updated regex patterns.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# GLM 4.7 uses a slightly different detail regex that includes
|
||||
# the <tool_call> wrapper and optional arg_key content
|
||||
self.FUNC_DETAIL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
|
||||
)
|
||||
# GLM 4.7 handles newlines between arg_key and arg_value tags
|
||||
self.FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
73
environments/tool_call_parsers/hermes_parser.py
Normal file
73
environments/tool_call_parsers/hermes_parser.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Hermes tool call parser.
|
||||
|
||||
Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
|
||||
Based on VLLM's Hermes2ProToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Hermes-format tool calls.
|
||||
|
||||
Matches <tool_call>...</tool_call> tags containing JSON with "name" and "arguments".
|
||||
Also handles unclosed <tool_call> at end-of-string (truncated generation).
|
||||
"""
|
||||
|
||||
# Matches both closed and unclosed tool_call tags
|
||||
PATTERN = re.compile(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
# match is a tuple: (closed_content, unclosed_content)
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first <tool_call> tag
|
||||
content = text[: text.find("<tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
93
environments/tool_call_parsers/kimi_k2_parser.py
Normal file
93
environments/tool_call_parsers/kimi_k2_parser.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Kimi K2 tool call parser.
|
||||
|
||||
Format:
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>function_id:0<|tool_call_argument_begin|>{"arg": "val"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
|
||||
The function_id format is typically "functions.func_name:index" or "func_name:index".
|
||||
|
||||
Based on VLLM's KimiK2ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("kimi_k2")
|
||||
class KimiK2ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Kimi K2 tool calls.
|
||||
|
||||
Uses section begin/end tokens wrapping individual tool call begin/end tokens.
|
||||
The tool_call_id contains the function name (after last dot, before colon).
|
||||
"""
|
||||
|
||||
# Support both singular and plural variants
|
||||
START_TOKENS = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
|
||||
# Regex captures: tool_call_id (e.g., "functions.get_weather:0"), function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
|
||||
r"<\|tool_call_argument_begin\|>\s*"
|
||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
||||
r"<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Check for any variant of the start token
|
||||
has_start = any(token in text for token in self.START_TOKENS)
|
||||
if not has_start:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
function_id, function_args = match
|
||||
|
||||
# Extract function name from ID format: "functions.get_weather:0" -> "get_weather"
|
||||
function_name = function_id.split(":")[0].split(".")[-1]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=function_id, # Preserve the original ID format
|
||||
type="function",
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
earliest_start = len(text)
|
||||
for token in self.START_TOKENS:
|
||||
idx = text.find(token)
|
||||
if idx >= 0 and idx < earliest_start:
|
||||
earliest_start = idx
|
||||
|
||||
content = text[:earliest_start].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
96
environments/tool_call_parsers/llama_parser.py
Normal file
96
environments/tool_call_parsers/llama_parser.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Llama 3.x / 4 tool call parser.
|
||||
|
||||
Format: The model outputs JSON objects with "name" and "arguments" (or "parameters") keys.
|
||||
May be preceded by <|python_tag|> token. Supports multiple JSON objects separated
|
||||
by content or semicolons.
|
||||
|
||||
Based on VLLM's Llama3JsonToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("llama3_json")
|
||||
@register_parser("llama4_json")
|
||||
class LlamaToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Llama 3.x and 4 JSON-format tool calls.
|
||||
|
||||
Finds JSON objects containing "name" + ("arguments" or "parameters") keys.
|
||||
Uses Python's json.JSONDecoder.raw_decode for robust extraction of
|
||||
JSON objects from mixed text.
|
||||
"""
|
||||
|
||||
BOT_TOKEN = "<|python_tag|>"
|
||||
|
||||
# Regex to find the start of potential JSON objects
|
||||
JSON_START = re.compile(r"\{")
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Quick check: need either the bot token or a JSON brace
|
||||
if self.BOT_TOKEN not in text and "{" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
end_index = -1 # Track where the last parsed JSON ended
|
||||
|
||||
for match in self.JSON_START.finditer(text):
|
||||
start = match.start()
|
||||
# Skip if this brace is inside a previously parsed JSON object
|
||||
if start <= end_index:
|
||||
continue
|
||||
|
||||
try:
|
||||
obj, json_end = decoder.raw_decode(text[start:])
|
||||
end_index = start + json_end
|
||||
|
||||
# Must have "name" and either "arguments" or "parameters"
|
||||
name = obj.get("name")
|
||||
args = obj.get("arguments", obj.get("parameters"))
|
||||
|
||||
if not name or args is None:
|
||||
continue
|
||||
|
||||
# Normalize arguments to JSON string
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
elif not isinstance(args, str):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(name=name, arguments=args),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first tool call JSON
|
||||
# Find where the first tool call starts in the text
|
||||
first_tc_start = text.find("{")
|
||||
if self.BOT_TOKEN in text:
|
||||
first_tc_start = text.find(self.BOT_TOKEN)
|
||||
content = text[:first_tc_start].strip() if first_tc_start > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
69
environments/tool_call_parsers/longcat_parser.py
Normal file
69
environments/tool_call_parsers/longcat_parser.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Longcat Flash Chat tool call parser.
|
||||
|
||||
Same as Hermes but uses <longcat_tool_call> tags instead of <tool_call>.
|
||||
Based on VLLM's LongcatFlashToolParser (extends Hermes2ProToolParser).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("longcat")
|
||||
class LongcatToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Longcat Flash Chat tool calls.
|
||||
Identical logic to Hermes, just different tag names.
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r"<longcat_tool_call>\s*(.*?)\s*</longcat_tool_call>|<longcat_tool_call>\s*(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<longcat_tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find("<longcat_tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
130
environments/tool_call_parsers/mistral_parser.py
Normal file
130
environments/tool_call_parsers/mistral_parser.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Mistral tool call parser.
|
||||
|
||||
Supports two formats depending on tokenizer version:
|
||||
- Pre-v11: content[TOOL_CALLS] [{"name": ..., "arguments": {...}}, ...]
|
||||
- v11+: content[TOOL_CALLS]tool_name1{"arg": "val"}[TOOL_CALLS]tool_name2{"arg": "val"}
|
||||
|
||||
Based on VLLM's MistralToolParser.extract_tool_calls()
|
||||
The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _generate_mistral_id() -> str:
|
||||
"""Mistral tool call IDs are 9-char alphanumeric strings."""
|
||||
import random
|
||||
import string
|
||||
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=9))
|
||||
|
||||
|
||||
@register_parser("mistral")
|
||||
class MistralToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Mistral-format tool calls.
|
||||
|
||||
Detects format by checking if the content after [TOOL_CALLS] starts with '['
|
||||
(pre-v11 JSON array) or with a tool name (v11+ format).
|
||||
"""
|
||||
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
parts = text.split(self.BOT_TOKEN)
|
||||
content = parts[0].strip()
|
||||
raw_tool_calls = parts[1:]
|
||||
|
||||
# Detect format: if the first raw part starts with '[', it's pre-v11
|
||||
first_raw = raw_tool_calls[0].strip() if raw_tool_calls else ""
|
||||
is_pre_v11 = first_raw.startswith("[") or first_raw.startswith("{")
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
if not is_pre_v11:
|
||||
# v11+ format: [TOOL_CALLS]tool_name{args}[TOOL_CALLS]tool_name2{args2}
|
||||
for raw in raw_tool_calls:
|
||||
raw = raw.strip()
|
||||
if not raw or "{" not in raw:
|
||||
continue
|
||||
|
||||
brace_idx = raw.find("{")
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(name=tool_name, arguments=args_str),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pre-v11 format: [TOOL_CALLS] [{"name": ..., "arguments": {...}}]
|
||||
try:
|
||||
parsed = json.loads(first_raw)
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
for tc in parsed:
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
163
environments/tool_call_parsers/qwen3_coder_parser.py
Normal file
163
environments/tool_call_parsers/qwen3_coder_parser.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Qwen3-Coder tool call parser.
|
||||
|
||||
Format uses XML-style nested tags:
|
||||
<tool_call>
|
||||
<function=function_name>
|
||||
<parameter=param_name>value</parameter>
|
||||
<parameter=param_name2>value2</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
Parameters are extracted from <parameter=name>value</parameter> tags and
|
||||
type-converted using the schema if available, otherwise treated as strings.
|
||||
|
||||
Based on VLLM's Qwen3CoderToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _try_convert_value(value: str) -> Any:
|
||||
"""
|
||||
Try to convert a parameter value string to a native Python type.
|
||||
Handles null, numbers, booleans, JSON objects/arrays, and falls back to string.
|
||||
"""
|
||||
stripped = value.strip()
|
||||
|
||||
# Handle null
|
||||
if stripped.lower() == "null":
|
||||
return None
|
||||
|
||||
# Try JSON first (handles objects, arrays, strings, numbers, booleans)
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Try Python literal eval (handles tuples, etc.)
|
||||
try:
|
||||
return ast.literal_eval(stripped)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
# Return as string
|
||||
return stripped
|
||||
|
||||
|
||||
@register_parser("qwen3_coder")
|
||||
class Qwen3CoderToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Qwen3-Coder XML-format tool calls.
|
||||
|
||||
Uses nested XML tags: <tool_call><function=name><parameter=key>val</parameter></function></tool_call>
|
||||
"""
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
FUNCTION_PREFIX = "<function="
|
||||
|
||||
# Find complete tool_call blocks (or unclosed at end)
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find function blocks within a tool_call
|
||||
FUNCTION_REGEX = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find parameter blocks within a function
|
||||
PARAMETER_REGEX = re.compile(
|
||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def _parse_function_call(self, function_str: str) -> Optional[ChatCompletionMessageToolCall]:
|
||||
"""Parse a single <function=name>...</function> block into a ToolCall."""
|
||||
try:
|
||||
# Extract function name: everything before the first '>'
|
||||
gt_idx = function_str.index(">")
|
||||
func_name = function_str[:gt_idx].strip()
|
||||
params_str = function_str[gt_idx + 1:]
|
||||
|
||||
# Extract parameters
|
||||
param_dict: Dict[str, Any] = {}
|
||||
for match_text in self.PARAMETER_REGEX.findall(params_str):
|
||||
if ">" not in match_text:
|
||||
continue
|
||||
eq_idx = match_text.index(">")
|
||||
param_name = match_text[:eq_idx].strip()
|
||||
param_value = match_text[eq_idx + 1:]
|
||||
|
||||
# Clean up whitespace
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = _try_convert_value(param_value)
|
||||
|
||||
return ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(param_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.FUNCTION_PREFIX not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
# Find all tool_call blocks
|
||||
tc_matches = self.TOOL_CALL_REGEX.findall(text)
|
||||
raw_blocks = [m[0] if m[0] else m[1] for m in tc_matches]
|
||||
|
||||
# Fallback: if no tool_call tags, try the whole text
|
||||
if not raw_blocks:
|
||||
raw_blocks = [text]
|
||||
|
||||
# Find function blocks within each tool_call
|
||||
function_strs: List[str] = []
|
||||
for block in raw_blocks:
|
||||
func_matches = self.FUNCTION_REGEX.findall(block)
|
||||
function_strs.extend(m[0] if m[0] else m[1] for m in func_matches)
|
||||
|
||||
if not function_strs:
|
||||
return text, None
|
||||
|
||||
# Parse each function call
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for func_str in function_strs:
|
||||
tc = self._parse_function_call(func_str)
|
||||
if tc is not None:
|
||||
tool_calls.append(tc)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content before tool calls
|
||||
first_tc = text.find(self.START_TOKEN)
|
||||
if first_tc < 0:
|
||||
first_tc = text.find(self.FUNCTION_PREFIX)
|
||||
content = text[:first_tc].strip() if first_tc > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
19
environments/tool_call_parsers/qwen_parser.py
Normal file
19
environments/tool_call_parsers/qwen_parser.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Qwen 2.5 tool call parser.
|
||||
|
||||
Uses the same <tool_call> format as Hermes.
|
||||
Registered as a separate parser name for clarity when using --tool-parser=qwen.
|
||||
"""
|
||||
|
||||
from environments.tool_call_parsers import register_parser
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser
|
||||
|
||||
|
||||
@register_parser("qwen")
|
||||
class QwenToolCallParser(HermesToolCallParser):
|
||||
"""
|
||||
Parser for Qwen 2.5 tool calls.
|
||||
Same <tool_call>{"name": ..., "arguments": ...}</tool_call> format as Hermes.
|
||||
"""
|
||||
|
||||
pass # Identical format -- inherits everything from Hermes
|
||||
474
environments/tool_context.py
Normal file
474
environments/tool_context.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
ToolContext -- Unrestricted Tool Access for Reward Functions
|
||||
|
||||
A per-rollout handle that gives reward/verification functions direct access to
|
||||
ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
|
||||
the terminal/browser session is the SAME one the model used during its rollout --
|
||||
all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
The verifier author decides which tools to use. Nothing is hardcoded or gated.
|
||||
|
||||
Example usage in a compute_reward():
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
return 0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
|
||||
"""
|
||||
Run a tool call in a thread pool executor so backends that use asyncio.run()
|
||||
internally (modal, docker, daytona) get a clean event loop.
|
||||
|
||||
If we're already in an async context, executes handle_function_call() in a
|
||||
disposable worker thread and blocks for the result.
|
||||
If not (e.g., called from sync code), runs directly.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# We're in an async context -- need to run in thread
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(
|
||||
handle_function_call, tool_name, arguments, task_id
|
||||
)
|
||||
return future.result(timeout=300)
|
||||
except RuntimeError:
|
||||
# No running event loop -- safe to call directly
|
||||
return handle_function_call(tool_name, arguments, task_id)
|
||||
|
||||
|
||||
class ToolContext:
|
||||
"""
|
||||
Open-ended access to all hermes-agent tools for a specific rollout.
|
||||
|
||||
Passed to compute_reward() so verifiers can use any tool they need:
|
||||
terminal commands, file reads/writes, web searches, browser automation, etc.
|
||||
All calls share the rollout's task_id for session isolation.
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str):
|
||||
self.task_id = task_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Terminal tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a command in the rollout's terminal session.
|
||||
|
||||
Args:
|
||||
command: Shell command to execute
|
||||
timeout: Command timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' (int) and 'output' (str)
|
||||
"""
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
|
||||
|
||||
# Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
|
||||
result = _run_tool_in_thread(
|
||||
"terminal",
|
||||
{"command": command, "timeout": timeout},
|
||||
self.task_id,
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"exit_code": -1, "output": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# File tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def read_file(self, path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read a file from the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to read
|
||||
|
||||
Returns:
|
||||
Dict with file content or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"read_file", {"path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Write a TEXT file in the rollout's filesystem.
|
||||
|
||||
Uses a shell heredoc under the hood, so this is only safe for text content.
|
||||
For binary files (images, compiled artifacts, etc.), use upload_file() instead.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Text content to write
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"write_file", {"path": path, "content": content}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upload a local file to the rollout's sandbox (binary-safe).
|
||||
|
||||
Unlike write_file() which passes content through a shell heredoc (text-only),
|
||||
this method base64-encodes the file and decodes it inside the sandbox.
|
||||
Safe for any file type: binaries, images, archives, etc.
|
||||
|
||||
For large files (>1MB), the content is split into chunks to avoid
|
||||
hitting shell command-length limits.
|
||||
|
||||
Args:
|
||||
local_path: Path to a local file on the host
|
||||
remote_path: Destination path inside the sandbox
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' and 'output'
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_path)
|
||||
if not local.exists():
|
||||
return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
|
||||
|
||||
raw = local.read_bytes()
|
||||
b64 = base64.b64encode(raw).decode("ascii")
|
||||
|
||||
# Ensure parent directory exists in the sandbox
|
||||
parent = str(_Path(remote_path).parent)
|
||||
if parent not in (".", "/"):
|
||||
self.terminal(f"mkdir -p {parent}", timeout=10)
|
||||
|
||||
# For small files, single command is fine
|
||||
chunk_size = 60_000 # ~60KB per chunk (well within shell limits)
|
||||
if len(b64) <= chunk_size:
|
||||
result = self.terminal(
|
||||
f"printf '%s' '{b64}' | base64 -d > {remote_path}",
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
# For larger files, write base64 in chunks then decode
|
||||
tmp_b64 = "/tmp/_hermes_upload.b64"
|
||||
self.terminal(f": > {tmp_b64}", timeout=5) # truncate
|
||||
for i in range(0, len(b64), chunk_size):
|
||||
chunk = b64[i : i + chunk_size]
|
||||
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
|
||||
result = self.terminal(
|
||||
f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Upload an entire local directory to the rollout's sandbox (binary-safe).
|
||||
|
||||
Recursively uploads all files, preserving directory structure.
|
||||
|
||||
Args:
|
||||
local_dir: Path to a local directory on the host
|
||||
remote_dir: Destination directory inside the sandbox
|
||||
|
||||
Returns:
|
||||
List of results, one per file uploaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_dir)
|
||||
if not local.exists() or not local.is_dir():
|
||||
return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
|
||||
|
||||
results = []
|
||||
for file_path in sorted(local.rglob("*")):
|
||||
if file_path.is_file():
|
||||
relative = file_path.relative_to(local)
|
||||
target = f"{remote_dir}/{relative}"
|
||||
results.append(self.upload_file(str(file_path), target))
|
||||
return results
|
||||
|
||||
def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Download a file from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
The inverse of upload_file(). Base64-encodes the file inside the sandbox,
|
||||
reads the encoded data through the terminal, and decodes it locally.
|
||||
Safe for any file type.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file inside the sandbox
|
||||
local_path: Destination path on the host
|
||||
|
||||
Returns:
|
||||
Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# Base64-encode the file inside the sandbox and capture output
|
||||
result = self.terminal(
|
||||
f"base64 {remote_path} 2>/dev/null",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.get("exit_code", -1) != 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to read remote file: {result.get('output', '')}",
|
||||
}
|
||||
|
||||
b64_data = result.get("output", "").strip()
|
||||
if not b64_data:
|
||||
return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
|
||||
|
||||
try:
|
||||
raw = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Base64 decode failed: {e}"}
|
||||
|
||||
# Write to local host filesystem
|
||||
local = _Path(local_path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(raw)
|
||||
|
||||
return {"success": True, "bytes": len(raw)}
|
||||
|
||||
def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Download a directory from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
Lists all files in the remote directory, then downloads each one.
|
||||
Preserves directory structure.
|
||||
|
||||
Args:
|
||||
remote_dir: Path to the directory inside the sandbox
|
||||
local_dir: Destination directory on the host
|
||||
|
||||
Returns:
|
||||
List of results, one per file downloaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# List files in the remote directory
|
||||
ls_result = self.terminal(
|
||||
f"find {remote_dir} -type f 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if ls_result.get("exit_code", -1) != 0:
|
||||
return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
|
||||
|
||||
file_list = ls_result.get("output", "").strip()
|
||||
if not file_list:
|
||||
return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
|
||||
|
||||
results = []
|
||||
for remote_file in file_list.splitlines():
|
||||
remote_file = remote_file.strip()
|
||||
if not remote_file:
|
||||
continue
|
||||
# Compute the relative path to preserve directory structure
|
||||
if remote_file.startswith(remote_dir):
|
||||
relative = remote_file[len(remote_dir):].lstrip("/")
|
||||
else:
|
||||
relative = _Path(remote_file).name
|
||||
local_file = str(_Path(local_dir) / relative)
|
||||
results.append(self.download_file(remote_file, local_file))
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
Search for text in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
path: Directory to search in
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"search_files", {"pattern": query, "path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Web tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def web_search(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Search the web.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call("web_search", {"query": query})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def web_extract(self, urls: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract content from URLs.
|
||||
|
||||
Args:
|
||||
urls: List of URLs to extract content from
|
||||
|
||||
Returns:
|
||||
Dict with extracted content
|
||||
"""
|
||||
result = handle_function_call("web_extract", {"urls": urls})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Browser tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def browser_navigate(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Navigate the rollout's browser session to a URL.
|
||||
|
||||
Args:
|
||||
url: URL to navigate to
|
||||
|
||||
Returns:
|
||||
Dict with page snapshot or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_navigate", {"url": url}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def browser_snapshot(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Take a snapshot of the current browser page.
|
||||
|
||||
Returns:
|
||||
Dict with page content/accessibility snapshot
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_snapshot", {}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Generic tool access
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Call any hermes-agent tool by name.
|
||||
|
||||
This is the generic escape hatch -- if a tool doesn't have a convenience
|
||||
wrapper above, you can call it directly here.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
|
||||
arguments: Dict of arguments for the tool
|
||||
|
||||
Returns:
|
||||
Raw JSON string result from the tool
|
||||
"""
|
||||
return _run_tool_in_thread(tool_name, arguments, self.task_id)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Release all resources (terminal VMs, browser sessions, background processes)
|
||||
for this rollout.
|
||||
|
||||
Called automatically by the base environment via try/finally after
|
||||
compute_reward() completes. You generally don't need to call this yourself.
|
||||
"""
|
||||
# Kill any background processes from this rollout (safety net)
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
killed = process_registry.kill_all(task_id=self.task_id)
|
||||
if killed:
|
||||
logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed)
|
||||
except Exception as e:
|
||||
logger.debug("Process cleanup for task %s: %s", self.task_id, e)
|
||||
|
||||
try:
|
||||
cleanup_vm(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for task %s: %s", self.task_id, e)
|
||||
|
||||
# Suppress browser_tool's noisy debug prints during cleanup.
|
||||
# The cleanup still runs (safe), it just doesn't spam the console.
|
||||
_prev_quiet = os.environ.get("HERMES_QUIET")
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
try:
|
||||
cleanup_browser(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
|
||||
finally:
|
||||
if _prev_quiet is None:
|
||||
os.environ.pop("HERMES_QUIET", None)
|
||||
else:
|
||||
os.environ["HERMES_QUIET"] = _prev_quiet
|
||||
643
environments/web_research_env.py
Normal file
643
environments/web_research_env.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
||||
============================================================
|
||||
|
||||
Trains models to do accurate, efficient, multi-source web research.
|
||||
|
||||
Reward signals:
|
||||
- Answer correctness (LLM judge, 0.0–1.0)
|
||||
- Source diversity (used ≥2 distinct domains)
|
||||
- Efficiency (penalizes excessive tool calls)
|
||||
- Tool usage (bonus for actually using web tools)
|
||||
|
||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
HuggingFace: google/frames-benchmark
|
||||
Fallback: built-in sample questions (no HF token needed)
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai
|
||||
|
||||
# Process mode (offline data generation)
|
||||
python environments/web_research_env.py process \\
|
||||
--env.data_path_to_save_groups data/web_research.jsonl
|
||||
|
||||
# Standalone eval
|
||||
python environments/web_research_env.py evaluate \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
||||
across German grocery stores (firecrawl + hermes-agent)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure hermes-agent root is on path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# Multi-hop questions requiring real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_QUESTIONS = [
|
||||
{
|
||||
"question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
|
||||
"answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
|
||||
"answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What programming language was used to write the original version of the web framework used by Instagram?",
|
||||
"answer": "Django, which Instagram was built on, is written in Python.",
|
||||
"difficulty": "easy",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
|
||||
"answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
|
||||
"difficulty": "hard",
|
||||
"hops": 3,
|
||||
},
|
||||
{
|
||||
"question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
|
||||
"answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "How many employees does the parent company of Instagram have?",
|
||||
"answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
|
||||
"answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Which company acquired the startup founded by the creator of Oculus VR?",
|
||||
"answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the market cap of the company that owns the most popular search engine in Russia?",
|
||||
"answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
|
||||
"answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
||||
"""Configuration for the web research RL environment."""
|
||||
|
||||
# Reward weights
|
||||
correctness_weight: float = Field(
|
||||
default=0.6,
|
||||
description="Weight for answer correctness in reward (LLM judge score).",
|
||||
)
|
||||
tool_usage_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for tool usage signal (did the model actually use web tools?).",
|
||||
)
|
||||
efficiency_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
||||
)
|
||||
diversity_bonus: float = Field(
|
||||
default=0.1,
|
||||
description="Bonus reward for citing ≥2 distinct domains.",
|
||||
)
|
||||
|
||||
# Efficiency thresholds
|
||||
efficient_max_calls: int = Field(
|
||||
default=5,
|
||||
description="Maximum tool calls before efficiency penalty begins.",
|
||||
)
|
||||
heavy_penalty_calls: int = Field(
|
||||
default=10,
|
||||
description="Tool call count where efficiency penalty steepens.",
|
||||
)
|
||||
|
||||
# Eval
|
||||
eval_size: int = Field(
|
||||
default=20,
|
||||
description="Number of held-out items for evaluation.",
|
||||
)
|
||||
eval_split_ratio: float = Field(
|
||||
default=0.1,
|
||||
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
||||
)
|
||||
|
||||
# Dataset
|
||||
dataset_name: str = Field(
|
||||
default="google/frames-benchmark",
|
||||
description="HuggingFace dataset name for research questions.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for the web research environment."""
|
||||
env_config = WebResearchEnvConfig(
|
||||
enabled_toolsets=["web", "file"],
|
||||
max_agent_turns=15,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a highly capable research agent. When asked a factual question, "
|
||||
"always use web_search to find current, accurate information before answering. "
|
||||
"Cite at least 2 sources. Be concise and accurate."
|
||||
),
|
||||
group_size=4,
|
||||
total_steps=1000,
|
||||
steps_per_eval=100,
|
||||
use_wandb=True,
|
||||
wandb_name="web-research",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
|
||||
# Metrics tracking for wandb
|
||||
self._reward_buffer: list[float] = []
|
||||
self._correctness_buffer: list[float] = []
|
||||
self._tool_usage_buffer: list[float] = []
|
||||
self._efficiency_buffer: list[float] = []
|
||||
self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset(self.config.dataset_name, split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
"answer": row["Answer"],
|
||||
"difficulty": row.get("reasoning_types", "unknown"),
|
||||
"hops": 2,
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out for eval
|
||||
eval_size = max(
|
||||
self.config.eval_size,
|
||||
int(len(self._items) * self.config.eval_split_ratio),
|
||||
)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
logger.info(
|
||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
||||
f"from FRAMES benchmark."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
||||
|
||||
# Fallback
|
||||
random.shuffle(SAMPLE_QUESTIONS)
|
||||
split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
|
||||
self._items = SAMPLE_QUESTIONS[:split]
|
||||
self._eval_items = SAMPLE_QUESTIONS[split:]
|
||||
logger.info(
|
||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
||||
f"{len(self._eval_items)} eval items."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return the next item, cycling through the dataset."""
|
||||
if not self._items:
|
||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""Format the research question as a task prompt."""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
f"do not rely solely on your training data.\n\n"
|
||||
f"Question: {item['question']}\n\n"
|
||||
f"Requirements:\n"
|
||||
f"- Use web_search and/or web_extract tools to find information\n"
|
||||
f"- Search at least 2 different sources\n"
|
||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
||||
f"- Cite the sources you used"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: AgentResult,
|
||||
ctx: ToolContext,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
final_response: str = result.final_response or ""
|
||||
tools_used: list[str] = [
|
||||
tc.tool_name for tc in (result.tool_calls or [])
|
||||
] if hasattr(result, "tool_calls") and result.tool_calls else []
|
||||
tool_call_count: int = result.turns_used or len(tools_used)
|
||||
|
||||
cfg = self.config
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
if tool_call_count <= cfg.efficient_max_calls:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
cfg.correctness_weight * correctness
|
||||
+ cfg.tool_usage_weight * tool_used
|
||||
+ cfg.efficiency_weight * efficiency
|
||||
+ diversity
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track for wandb
|
||||
self._reward_buffer.append(reward)
|
||||
self._correctness_buffer.append(correctness)
|
||||
self._tool_usage_buffer.append(tool_used)
|
||||
self._efficiency_buffer.append(efficiency)
|
||||
self._diversity_buffer.append(diversity)
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the agent loop."""
|
||||
import time
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return
|
||||
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
logger.info(f"Running eval on {len(eval_items)} questions...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
for item in eval_items:
|
||||
try:
|
||||
# Use the base env's agent loop for eval (same as training)
|
||||
prompt = self.format_prompt(item)
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
# Score the response
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=response_content,
|
||||
)
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": response_content,
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
})
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Compute metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": (
|
||||
sum(correctness_scores) / len(correctness_scores)
|
||||
if correctness_scores else 0.0
|
||||
),
|
||||
"eval/n_items": len(samples),
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log reward breakdown metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||
wandb_metrics["train/total_rollouts"] = n
|
||||
|
||||
# Accuracy buckets
|
||||
wandb_metrics["train/correct_rate"] = (
|
||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||
)
|
||||
wandb_metrics["train/tool_usage_rate"] = (
|
||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||
)
|
||||
|
||||
# Clear buffers
|
||||
self._reward_buffer.clear()
|
||||
self._correctness_buffer.clear()
|
||||
self._tool_usage_buffer.clear()
|
||||
self._efficiency_buffer.clear()
|
||||
self._diversity_buffer.clear()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
|
||||
self,
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
) -> float:
|
||||
"""
|
||||
Use the server's LLM to judge answer correctness.
|
||||
Falls back to keyword heuristic if LLM call fails.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
f"Reference answer: {expected}\n\n"
|
||||
f"Model answer: {model_answer}\n\n"
|
||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
||||
" 1.0 = fully correct and complete\n"
|
||||
" 0.7 = mostly correct with minor gaps\n"
|
||||
" 0.4 = partially correct\n"
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
n=1,
|
||||
max_tokens=150,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
text = response.choices[0].message.content if response.choices else ""
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""Lightweight keyword overlap score as fallback."""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""Extract unique domains from URLs cited in the response."""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.lower().lstrip("www.")
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
except Exception:
|
||||
pass
|
||||
return domains
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
34
gateway/__init__.py
Normal file
34
gateway/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Hermes Gateway - Multi-platform messaging integration.
|
||||
|
||||
This module provides a unified gateway for connecting the Hermes agent
|
||||
to various messaging platforms (Telegram, Discord, WhatsApp) with:
|
||||
- Session management (persistent conversations with reset policies)
|
||||
- Dynamic context injection (agent knows where messages come from)
|
||||
- Delivery routing (cron job outputs to appropriate channels)
|
||||
- Platform-specific toolsets (different capabilities per platform)
|
||||
"""
|
||||
|
||||
from .config import GatewayConfig, HomeChannel, PlatformConfig, SessionResetPolicy, load_gateway_config
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
from .session import (
|
||||
SessionContext,
|
||||
SessionStore,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"GatewayConfig",
|
||||
"PlatformConfig",
|
||||
"HomeChannel",
|
||||
"load_gateway_config",
|
||||
# Session
|
||||
"SessionContext",
|
||||
"SessionStore",
|
||||
"SessionResetPolicy",
|
||||
"build_session_context_prompt",
|
||||
# Delivery
|
||||
"DeliveryRouter",
|
||||
"DeliveryTarget",
|
||||
]
|
||||
242
gateway/channel_directory.py
Normal file
242
gateway/channel_directory.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Channel directory -- cached map of reachable channels/contacts per platform.
|
||||
|
||||
Built on gateway startup, refreshed periodically (every 5 min), and saved to
|
||||
~/.hermes/channel_directory.json. The send_message tool reads this file for
|
||||
action="list" and for resolving human-friendly channel names to numeric IDs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build / refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Build a channel directory from connected platform adapters and session data.
|
||||
|
||||
Returns the directory dict and writes it to DIRECTORY_PATH.
|
||||
"""
|
||||
from gateway.config import Platform
|
||||
|
||||
platforms: dict[str, list[dict[str, str]]] = {}
|
||||
|
||||
for platform, adapter in adapters.items():
|
||||
try:
|
||||
if platform == Platform.DISCORD:
|
||||
platforms["discord"] = _build_discord(adapter)
|
||||
elif platform == Platform.SLACK:
|
||||
platforms["slack"] = _build_slack(adapter)
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
directory = {
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"platforms": platforms,
|
||||
}
|
||||
|
||||
try:
|
||||
DIRECTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(DIRECTORY_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(directory, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to write: %s", e)
|
||||
|
||||
return directory
|
||||
|
||||
|
||||
def _build_discord(adapter) -> list[dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
return channels
|
||||
|
||||
try:
|
||||
import discord as _discord
|
||||
except ImportError:
|
||||
return channels
|
||||
|
||||
for guild in client.guilds:
|
||||
for ch in guild.text_channels:
|
||||
channels.append(
|
||||
{
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
}
|
||||
)
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
# Merge any DMs from session history
|
||||
channels.extend(_build_from_sessions("discord"))
|
||||
return channels
|
||||
|
||||
|
||||
def _build_slack(adapter) -> list[dict[str, str]]:
|
||||
"""List Slack channels the bot has joined."""
|
||||
channels = []
|
||||
# Slack adapter may expose a web client
|
||||
client = getattr(adapter, "_app", None) or getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
try:
|
||||
from tools.send_message_tool import _send_slack # noqa: F401
|
||||
# Use the Slack Web API directly if available
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to session data
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
|
||||
def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
|
||||
"""Pull known channels/contacts from sessions.json origin data."""
|
||||
sessions_path = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if not sessions_path.exists():
|
||||
return []
|
||||
|
||||
entries = []
|
||||
try:
|
||||
with open(sessions_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
seen_ids = set()
|
||||
for _key, session in data.items():
|
||||
origin = session.get("origin") or {}
|
||||
if origin.get("platform") != platform_name:
|
||||
continue
|
||||
chat_id = origin.get("chat_id")
|
||||
if not chat_id or chat_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(chat_id)
|
||||
entries.append(
|
||||
{
|
||||
"id": str(chat_id),
|
||||
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
|
||||
"type": session.get("chat_type", "dm"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e)
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read / resolve
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_directory() -> dict[str, Any]:
|
||||
"""Load the cached channel directory from disk."""
|
||||
if not DIRECTORY_PATH.exists():
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
try:
|
||||
with open(DIRECTORY_PATH, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> str | None:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
||||
Matching strategy (case-insensitive, first match wins):
|
||||
- Discord: "bot-home", "#bot-home", "GuildName/bot-home"
|
||||
- Telegram: display name or group name
|
||||
- Slack: "engineering", "#engineering"
|
||||
"""
|
||||
directory = load_directory()
|
||||
channels = directory.get("platforms", {}).get(platform_name, [])
|
||||
if not channels:
|
||||
return None
|
||||
|
||||
query = name.lstrip("#").lower()
|
||||
|
||||
# 1. Exact name match
|
||||
for ch in channels:
|
||||
if ch["name"].lower() == query:
|
||||
return ch["id"]
|
||||
|
||||
# 2. Guild-qualified match for Discord ("GuildName/channel")
|
||||
if "/" in query:
|
||||
guild_part, ch_part = query.rsplit("/", 1)
|
||||
for ch in channels:
|
||||
guild = ch.get("guild", "").lower()
|
||||
if guild == guild_part and ch["name"].lower() == ch_part:
|
||||
return ch["id"]
|
||||
|
||||
# 3. Partial prefix match (only if unambiguous)
|
||||
matches = [ch for ch in channels if ch["name"].lower().startswith(query)]
|
||||
if len(matches) == 1:
|
||||
return matches[0]["id"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def format_directory_for_display() -> str:
|
||||
"""Format the channel directory as a human-readable list for the model."""
|
||||
directory = load_directory()
|
||||
platforms = directory.get("platforms", {})
|
||||
|
||||
if not any(platforms.values()):
|
||||
return "No messaging platforms connected or no channels discovered yet."
|
||||
|
||||
lines = ["Available messaging targets:\n"]
|
||||
|
||||
for plat_name, channels in sorted(platforms.items()):
|
||||
if not channels:
|
||||
continue
|
||||
|
||||
# Group Discord channels by guild
|
||||
if plat_name == "discord":
|
||||
guilds: dict[str, list] = {}
|
||||
dms: list = []
|
||||
for ch in channels:
|
||||
guild = ch.get("guild")
|
||||
if guild:
|
||||
guilds.setdefault(guild, []).append(ch)
|
||||
else:
|
||||
dms.append(ch)
|
||||
|
||||
for guild_name, guild_channels in sorted(guilds.items()):
|
||||
lines.append(f"Discord ({guild_name}):")
|
||||
for ch in sorted(guild_channels, key=lambda c: c["name"]):
|
||||
lines.append(f" discord:#{ch['name']}")
|
||||
if dms:
|
||||
lines.append("Discord (DMs):")
|
||||
for ch in dms:
|
||||
lines.append(f" discord:{ch['name']}")
|
||||
lines.append("")
|
||||
else:
|
||||
lines.append(f"{plat_name.title()}:")
|
||||
for ch in channels:
|
||||
type_label = f" ({ch['type']})" if ch.get("type") else ""
|
||||
lines.append(f" {plat_name}:{ch['name']}{type_label}")
|
||||
lines.append("")
|
||||
|
||||
lines.append('Use these as the "target" parameter when sending.')
|
||||
lines.append('Bare platform name (e.g. "telegram") sends to home channel.')
|
||||
|
||||
return "\n".join(lines)
|
||||
441
gateway/config.py
Normal file
441
gateway/config.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Gateway configuration management.
|
||||
|
||||
Handles loading and validating configuration for:
|
||||
- Connected platforms (Telegram, Discord, WhatsApp)
|
||||
- Home channels for each platform
|
||||
- Session reset policies
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
|
||||
LOCAL = "local"
|
||||
TELEGRAM = "telegram"
|
||||
DISCORD = "discord"
|
||||
WHATSAPP = "whatsapp"
|
||||
SLACK = "slack"
|
||||
SIGNAL = "signal"
|
||||
HOMEASSISTANT = "homeassistant"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HomeChannel:
|
||||
"""
|
||||
Default destination for a platform.
|
||||
|
||||
When a cron job specifies deliver="telegram" without a specific chat ID,
|
||||
messages are sent to this home channel.
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
name: str # Human-readable name for display
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "HomeChannel":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
name=data.get("name", "Home"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResetPolicy:
|
||||
"""
|
||||
Controls when sessions reset (lose context).
|
||||
|
||||
Modes:
|
||||
- "daily": Reset at a specific hour each day
|
||||
- "idle": Reset after N minutes of inactivity
|
||||
- "both": Whichever triggers first (daily boundary OR idle timeout)
|
||||
- "none": Never auto-reset (context managed only by compression)
|
||||
"""
|
||||
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionResetPolicy":
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=data.get("at_hour", 4),
|
||||
idle_minutes=data.get("idle_minutes", 1440),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""Configuration for a single messaging platform."""
|
||||
|
||||
enabled: bool = False
|
||||
token: str | None = None # Bot token (Telegram, Discord)
|
||||
api_key: str | None = None # API key if different from token
|
||||
home_channel: HomeChannel | None = None
|
||||
|
||||
# Platform-specific settings
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"enabled": self.enabled,
|
||||
"extra": self.extra,
|
||||
}
|
||||
if self.token:
|
||||
result["token"] = self.token
|
||||
if self.api_key:
|
||||
result["api_key"] = self.api_key
|
||||
if self.home_channel:
|
||||
result["home_channel"] = self.home_channel.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PlatformConfig":
|
||||
home_channel = None
|
||||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
token=data.get("token"),
|
||||
api_key=data.get("api_key"),
|
||||
home_channel=home_channel,
|
||||
extra=data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayConfig:
|
||||
"""
|
||||
Main gateway configuration.
|
||||
|
||||
Manages all platform connections, session policies, and delivery settings.
|
||||
"""
|
||||
|
||||
# Platform configurations
|
||||
platforms: dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
# Session reset policies by type
|
||||
default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy)
|
||||
reset_by_type: dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
# Reset trigger commands
|
||||
reset_triggers: list[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
# Storage paths
|
||||
sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions")
|
||||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
def get_connected_platforms(self) -> list[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
for platform, config in self.platforms.items():
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if (
|
||||
config.token
|
||||
or config.api_key
|
||||
or platform == Platform.WHATSAPP
|
||||
or platform == Platform.SIGNAL
|
||||
and config.extra.get("http_url")
|
||||
):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> HomeChannel | None:
|
||||
"""Get the home channel for a platform."""
|
||||
config = self.platforms.get(platform)
|
||||
if config:
|
||||
return config.home_channel
|
||||
return None
|
||||
|
||||
def get_reset_policy(self, platform: Platform | None = None, session_type: str | None = None) -> SessionResetPolicy:
|
||||
"""
|
||||
Get the appropriate reset policy for a session.
|
||||
|
||||
Priority: platform override > type override > default
|
||||
"""
|
||||
# Platform-specific override takes precedence
|
||||
if platform and platform in self.reset_by_platform:
|
||||
return self.reset_by_platform[platform]
|
||||
|
||||
# Type-specific override (dm, group, thread)
|
||||
if session_type and session_type in self.reset_by_type:
|
||||
return self.reset_by_type[session_type]
|
||||
|
||||
return self.default_reset_policy
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platforms": {p.value: c.to_dict() for p, c in self.platforms.items()},
|
||||
"default_reset_policy": self.default_reset_policy.to_dict(),
|
||||
"reset_by_type": {k: v.to_dict() for k, v in self.reset_by_type.items()},
|
||||
"reset_by_platform": {p.value: v.to_dict() for p, v in self.reset_by_platform.items()},
|
||||
"reset_triggers": self.reset_triggers,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "GatewayConfig":
|
||||
platforms = {}
|
||||
for platform_name, platform_data in data.get("platforms", {}).items():
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
platforms[platform] = PlatformConfig.from_dict(platform_data)
|
||||
except ValueError:
|
||||
pass # Skip unknown platforms
|
||||
|
||||
reset_by_type = {}
|
||||
for type_name, policy_data in data.get("reset_by_type", {}).items():
|
||||
reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data)
|
||||
|
||||
reset_by_platform = {}
|
||||
for platform_name, policy_data in data.get("reset_by_platform", {}).items():
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
default_policy = SessionResetPolicy()
|
||||
if "default_reset_policy" in data:
|
||||
default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"])
|
||||
|
||||
sessions_dir = Path.home() / ".hermes" / "sessions"
|
||||
if "sessions_dir" in data:
|
||||
sessions_dir = Path(data["sessions_dir"])
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
reset_by_type=reset_by_type,
|
||||
reset_by_platform=reset_by_platform,
|
||||
reset_triggers=data.get("reset_triggers", ["/new", "/reset"]),
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
)
|
||||
|
||||
|
||||
def load_gateway_config() -> GatewayConfig:
|
||||
"""
|
||||
Load gateway configuration from multiple sources.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. ~/.hermes/gateway.json
|
||||
3. cli-config.yaml gateway section
|
||||
4. Defaults
|
||||
"""
|
||||
config = GatewayConfig()
|
||||
|
||||
# Try loading from ~/.hermes/gateway.json
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
if gateway_config_path.exists():
|
||||
try:
|
||||
with open(gateway_config_path) as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
|
||||
|
||||
# Bridge session_reset from config.yaml (the user-facing config file)
|
||||
# into the gateway config. config.yaml takes precedence over gateway.json
|
||||
# for session reset policy since that's where hermes setup writes it.
|
||||
try:
|
||||
import yaml
|
||||
|
||||
config_yaml_path = Path.home() / ".hermes" / "config.yaml"
|
||||
if config_yaml_path.exists():
|
||||
with open(config_yaml_path) as f:
|
||||
yaml_cfg = yaml.safe_load(f) or {}
|
||||
sr = yaml_cfg.get("session_reset")
|
||||
if sr and isinstance(sr, dict):
|
||||
config.default_reset_policy = SessionResetPolicy.from_dict(sr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Override with environment variables
|
||||
_apply_env_overrides(config)
|
||||
|
||||
# --- Validate loaded values ---
|
||||
policy = config.default_reset_policy
|
||||
|
||||
if not (0 <= policy.at_hour <= 23):
|
||||
logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour)
|
||||
policy.at_hour = 4
|
||||
|
||||
if policy.idle_minutes is None or policy.idle_minutes <= 0:
|
||||
logger.warning(
|
||||
"Invalid idle_minutes=%s (must be positive). Using default 1440.",
|
||||
policy.idle_minutes,
|
||||
)
|
||||
policy.idle_minutes = 1440
|
||||
|
||||
# Warn about empty bot tokens — platforms that loaded an empty string
|
||||
# won't connect and the cause can be confusing without a log line.
|
||||
_token_env_names = {
|
||||
Platform.TELEGRAM: "TELEGRAM_BOT_TOKEN",
|
||||
Platform.DISCORD: "DISCORD_BOT_TOKEN",
|
||||
Platform.SLACK: "SLACK_BOT_TOKEN",
|
||||
}
|
||||
for platform, pconfig in config.platforms.items():
|
||||
if not pconfig.enabled:
|
||||
continue
|
||||
env_name = _token_env_names.get(platform)
|
||||
if env_name and pconfig.token is not None and not pconfig.token.strip():
|
||||
logger.warning(
|
||||
"%s is enabled but %s is empty. The adapter will likely fail to connect.",
|
||||
platform.value,
|
||||
env_name,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
# Telegram
|
||||
telegram_token = os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
if telegram_token:
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].enabled = True
|
||||
config.platforms[Platform.TELEGRAM].token = telegram_token
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=telegram_home,
|
||||
name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Discord
|
||||
discord_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
if discord_token:
|
||||
if Platform.DISCORD not in config.platforms:
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig()
|
||||
config.platforms[Platform.DISCORD].enabled = True
|
||||
config.platforms[Platform.DISCORD].token = discord_token
|
||||
|
||||
discord_home = os.getenv("DISCORD_HOME_CHANNEL")
|
||||
if discord_home and Platform.DISCORD in config.platforms:
|
||||
config.platforms[Platform.DISCORD].home_channel = HomeChannel(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=discord_home,
|
||||
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# WhatsApp (typically uses different auth mechanism)
|
||||
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
if whatsapp_enabled:
|
||||
if Platform.WHATSAPP not in config.platforms:
|
||||
config.platforms[Platform.WHATSAPP] = PlatformConfig()
|
||||
config.platforms[Platform.WHATSAPP].enabled = True
|
||||
|
||||
# Slack
|
||||
slack_token = os.getenv("SLACK_BOT_TOKEN")
|
||||
if slack_token:
|
||||
if Platform.SLACK not in config.platforms:
|
||||
config.platforms[Platform.SLACK] = PlatformConfig()
|
||||
config.platforms[Platform.SLACK].enabled = True
|
||||
config.platforms[Platform.SLACK].token = slack_token
|
||||
# Home channel
|
||||
slack_home = os.getenv("SLACK_HOME_CHANNEL")
|
||||
if slack_home:
|
||||
config.platforms[Platform.SLACK].home_channel = HomeChannel(
|
||||
platform=Platform.SLACK,
|
||||
chat_id=slack_home,
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
# Signal
|
||||
signal_url = os.getenv("SIGNAL_HTTP_URL")
|
||||
signal_account = os.getenv("SIGNAL_ACCOUNT")
|
||||
if signal_url and signal_account:
|
||||
if Platform.SIGNAL not in config.platforms:
|
||||
config.platforms[Platform.SIGNAL] = PlatformConfig()
|
||||
config.platforms[Platform.SIGNAL].enabled = True
|
||||
config.platforms[Platform.SIGNAL].extra.update(
|
||||
{
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
}
|
||||
)
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id=signal_home,
|
||||
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Home Assistant
|
||||
hass_token = os.getenv("HASS_TOKEN")
|
||||
if hass_token:
|
||||
if Platform.HOMEASSISTANT not in config.platforms:
|
||||
config.platforms[Platform.HOMEASSISTANT] = PlatformConfig()
|
||||
config.platforms[Platform.HOMEASSISTANT].enabled = True
|
||||
config.platforms[Platform.HOMEASSISTANT].token = hass_token
|
||||
hass_url = os.getenv("HASS_URL")
|
||||
if hass_url:
|
||||
config.platforms[Platform.HOMEASSISTANT].extra["url"] = hass_url
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
try:
|
||||
config.default_reset_policy.idle_minutes = int(idle_minutes)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
reset_hour = os.getenv("SESSION_RESET_HOUR")
|
||||
if reset_hour:
|
||||
try:
|
||||
config.default_reset_policy.at_hour = int(reset_hour)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def save_gateway_config(config: GatewayConfig) -> None:
|
||||
"""Save gateway configuration to ~/.hermes/gateway.json."""
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(gateway_config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
312
gateway/delivery.py
Normal file
312
gateway/delivery.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Delivery routing for cron job outputs and agent responses.
|
||||
|
||||
Routes messages to the appropriate destination based on:
|
||||
- Explicit targets (e.g., "telegram:123456789")
|
||||
- Platform home channels (e.g., "telegram" → home channel)
|
||||
- Origin (back to where the job was created)
|
||||
- Local (always saved to files)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_PLATFORM_OUTPUT = 4000
|
||||
TRUNCATED_VISIBLE = 3800
|
||||
|
||||
from .config import GatewayConfig, Platform
|
||||
from .session import SessionSource
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeliveryTarget:
|
||||
"""
|
||||
A single delivery target.
|
||||
|
||||
Represents where a message should be sent:
|
||||
- "origin" → back to source
|
||||
- "local" → save to local files
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str | None = None # None means use home channel
|
||||
is_origin: bool = False
|
||||
is_explicit: bool = False # True if chat_id was explicitly specified
|
||||
|
||||
@classmethod
|
||||
def parse(cls, target: str, origin: SessionSource | None = None) -> "DeliveryTarget":
|
||||
"""
|
||||
Parse a delivery target string.
|
||||
|
||||
Formats:
|
||||
- "origin" → back to source
|
||||
- "local" → local files only
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
target = target.strip().lower()
|
||||
|
||||
if target == "origin":
|
||||
if origin:
|
||||
return cls(
|
||||
platform=origin.platform,
|
||||
chat_id=origin.chat_id,
|
||||
is_origin=True,
|
||||
)
|
||||
else:
|
||||
# Fallback to local if no origin
|
||||
return cls(platform=Platform.LOCAL, is_origin=True)
|
||||
|
||||
if target == "local":
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
# Check for platform:chat_id format
|
||||
if ":" in target:
|
||||
platform_str, chat_id = target.split(":", 1)
|
||||
try:
|
||||
platform = Platform(platform_str)
|
||||
return cls(platform=platform, chat_id=chat_id, is_explicit=True)
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
# Just a platform name (use home channel)
|
||||
try:
|
||||
platform = Platform(target)
|
||||
return cls(platform=platform)
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert back to string format."""
|
||||
if self.is_origin:
|
||||
return "origin"
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "local"
|
||||
if self.chat_id:
|
||||
return f"{self.platform.value}:{self.chat_id}"
|
||||
return self.platform.value
|
||||
|
||||
|
||||
class DeliveryRouter:
|
||||
"""
|
||||
Routes messages to appropriate destinations.
|
||||
|
||||
Handles the logic of resolving delivery targets and dispatching
|
||||
messages to the right platform adapters.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: dict[Platform, Any] = None):
|
||||
"""
|
||||
Initialize the delivery router.
|
||||
|
||||
Args:
|
||||
config: Gateway configuration
|
||||
adapters: Dict mapping platforms to their adapter instances
|
||||
"""
|
||||
self.config = config
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = Path.home() / ".hermes" / "cron" / "output"
|
||||
|
||||
def resolve_targets(self, deliver: str | list[str], origin: SessionSource | None = None) -> list[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
if home:
|
||||
target.chat_id = home.chat_id
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
return targets
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
targets: list[DeliveryTarget],
|
||||
job_id: str | None = None,
|
||||
job_name: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Deliver content to all specified targets.
|
||||
|
||||
Args:
|
||||
content: The message/output to deliver
|
||||
targets: List of delivery targets
|
||||
job_id: Optional job ID (for cron jobs)
|
||||
job_name: Optional job name
|
||||
metadata: Additional metadata to include
|
||||
|
||||
Returns:
|
||||
Dict with delivery results per target
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for target in targets:
|
||||
try:
|
||||
if target.platform == Platform.LOCAL:
|
||||
result = self._deliver_local(content, job_id, job_name, metadata)
|
||||
else:
|
||||
result = await self._deliver_to_platform(target, content, metadata)
|
||||
|
||||
results[target.to_string()] = {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
results[target.to_string()] = {"success": False, "error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
def _deliver_local(
|
||||
self, content: str, job_id: str | None, job_name: str | None, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Save content to local files."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if job_id:
|
||||
output_path = self.output_dir / job_id / f"{timestamp}.md"
|
||||
else:
|
||||
output_path = self.output_dir / "misc" / f"{timestamp}.md"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build the output document
|
||||
lines = []
|
||||
if job_name:
|
||||
lines.append(f"# {job_name}")
|
||||
else:
|
||||
lines.append("# Delivery Output")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
if job_id:
|
||||
lines.append(f"**Job ID:** {job_id}")
|
||||
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
lines.append(f"**{key}:** {value}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
lines.append(content)
|
||||
|
||||
output_path.write_text("\n".join(lines))
|
||||
|
||||
return {"path": str(output_path), "timestamp": timestamp}
|
||||
|
||||
def _save_full_output(self, content: str, job_id: str) -> Path:
|
||||
"""Save full cron output to disk and return the file path."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_dir = Path.home() / ".hermes" / "cron" / "output"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = out_dir / f"{job_id}_{timestamp}.txt"
|
||||
path.write_text(content)
|
||||
return path
|
||||
|
||||
async def _deliver_to_platform(
|
||||
self, target: DeliveryTarget, content: str, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Deliver content to a messaging platform."""
|
||||
adapter = self.adapters.get(target.platform)
|
||||
|
||||
if not adapter:
|
||||
raise ValueError(f"No adapter configured for {target.platform.value}")
|
||||
|
||||
if not target.chat_id:
|
||||
raise ValueError(f"No chat ID for {target.platform.value} delivery")
|
||||
|
||||
# Guard: truncate oversized cron output to stay within platform limits
|
||||
if len(content) > MAX_PLATFORM_OUTPUT:
|
||||
job_id = (metadata or {}).get("job_id", "unknown")
|
||||
saved_path = self._save_full_output(content, job_id)
|
||||
logger.info("Cron output truncated (%d chars) — full output: %s", len(content), saved_path)
|
||||
content = content[:TRUNCATED_VISIBLE] + f"\n\n... [truncated, full output saved to {saved_path}]"
|
||||
|
||||
return await adapter.send(target.chat_id, content, metadata=metadata)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: str | list[str] | None, origin: SessionSource | None = None, default: str = "origin"
|
||||
) -> str | list[str]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
return default
|
||||
return deliver
|
||||
|
||||
|
||||
def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Build context for the schedule_cronjob tool to understand delivery options.
|
||||
|
||||
This is passed to the tool so it can validate and explain delivery targets.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
options = {
|
||||
"origin": {
|
||||
"description": "Back to where this job was created",
|
||||
"available": origin is not None,
|
||||
},
|
||||
"local": {
|
||||
"description": "Save to local files only",
|
||||
"available": True,
|
||||
},
|
||||
}
|
||||
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
options[platform.value] = {
|
||||
"description": f"{platform.value.title()} home channel",
|
||||
"available": True,
|
||||
"home_channel": home.to_dict() if home else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"origin": origin.to_dict() if origin else None,
|
||||
"options": options,
|
||||
"always_log_local": config.always_log_local,
|
||||
}
|
||||
150
gateway/hooks.py
Normal file
150
gateway/hooks.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Event Hook System
|
||||
|
||||
A lightweight event-driven system that fires handlers at key lifecycle points.
|
||||
Hooks are discovered from ~/.hermes/hooks/ directories, each containing:
|
||||
- HOOK.yaml (metadata: name, description, events list)
|
||||
- handler.py (Python handler with async def handle(event_type, context))
|
||||
|
||||
Events:
|
||||
- gateway:startup -- Gateway process starts
|
||||
- session:start -- New session created
|
||||
- session:reset -- User ran /new or /reset
|
||||
- agent:start -- Agent begins processing a message
|
||||
- agent:step -- Each turn in the tool-calling loop
|
||||
- agent:end -- Agent finishes processing
|
||||
- command:* -- Any slash command executed (wildcard match)
|
||||
|
||||
Errors in hooks are caught and logged but never block the main pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
HOOKS_DIR = Path(os.path.expanduser("~/.hermes/hooks"))
|
||||
|
||||
|
||||
class HookRegistry:
|
||||
"""
|
||||
Discovers, loads, and fires event hooks.
|
||||
|
||||
Usage:
|
||||
registry = HookRegistry()
|
||||
registry.discover_and_load()
|
||||
await registry.emit("agent:start", {"platform": "telegram", ...})
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# event_type -> [handler_fn, ...]
|
||||
self._handlers: dict[str, list[Callable]] = {}
|
||||
self._loaded_hooks: list[dict] = [] # metadata for listing
|
||||
|
||||
@property
|
||||
def loaded_hooks(self) -> list[dict]:
|
||||
"""Return metadata about all loaded hooks."""
|
||||
return list(self._loaded_hooks)
|
||||
|
||||
def discover_and_load(self) -> None:
|
||||
"""
|
||||
Scan the hooks directory for hook directories and load their handlers.
|
||||
|
||||
Each hook directory must contain:
|
||||
- HOOK.yaml with at least 'name' and 'events' keys
|
||||
- handler.py with a top-level 'handle' function (sync or async)
|
||||
"""
|
||||
if not HOOKS_DIR.exists():
|
||||
return
|
||||
|
||||
for hook_dir in sorted(HOOKS_DIR.iterdir()):
|
||||
if not hook_dir.is_dir():
|
||||
continue
|
||||
|
||||
manifest_path = hook_dir / "HOOK.yaml"
|
||||
handler_path = hook_dir / "handler.py"
|
||||
|
||||
if not manifest_path.exists() or not handler_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
manifest = yaml.safe_load(manifest_path.read_text(encoding="utf-8"))
|
||||
if not manifest or not isinstance(manifest, dict):
|
||||
print(f"[hooks] Skipping {hook_dir.name}: invalid HOOK.yaml", flush=True)
|
||||
continue
|
||||
|
||||
hook_name = manifest.get("name", hook_dir.name)
|
||||
events = manifest.get("events", [])
|
||||
if not events:
|
||||
print(f"[hooks] Skipping {hook_name}: no events declared", flush=True)
|
||||
continue
|
||||
|
||||
# Dynamically load the handler module
|
||||
spec = importlib.util.spec_from_file_location(f"hermes_hook_{hook_name}", handler_path)
|
||||
if spec is None or spec.loader is None:
|
||||
print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True)
|
||||
continue
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
handle_fn = getattr(module, "handle", None)
|
||||
if handle_fn is None:
|
||||
print(f"[hooks] Skipping {hook_name}: no 'handle' function found", flush=True)
|
||||
continue
|
||||
|
||||
# Register the handler for each declared event
|
||||
for event in events:
|
||||
self._handlers.setdefault(event, []).append(handle_fn)
|
||||
|
||||
self._loaded_hooks.append(
|
||||
{
|
||||
"name": hook_name,
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[hooks] Error loading hook {hook_dir.name}: {e}", flush=True)
|
||||
|
||||
async def emit(self, event_type: str, context: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Fire all handlers registered for an event.
|
||||
|
||||
Supports wildcard matching: handlers registered for "command:*" will
|
||||
fire for any "command:..." event. Handlers registered for a base type
|
||||
like "agent" won't fire for "agent:start" -- only exact matches and
|
||||
explicit wildcards.
|
||||
|
||||
Args:
|
||||
event_type: The event identifier (e.g. "agent:start").
|
||||
context: Optional dict with event-specific data.
|
||||
"""
|
||||
if context is None:
|
||||
context = {}
|
||||
|
||||
# Collect handlers: exact match + wildcard match
|
||||
handlers = list(self._handlers.get(event_type, []))
|
||||
|
||||
# Check for wildcard patterns (e.g., "command:*" matches "command:reset")
|
||||
if ":" in event_type:
|
||||
base = event_type.split(":")[0]
|
||||
wildcard_key = f"{base}:*"
|
||||
handlers.extend(self._handlers.get(wildcard_key, []))
|
||||
|
||||
for fn in handlers:
|
||||
try:
|
||||
result = fn(event_type, context)
|
||||
# Support both sync and async handlers
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as e:
|
||||
print(f"[hooks] Error in handler for '{event_type}': {e}", flush=True)
|
||||
123
gateway/mirror.py
Normal file
123
gateway/mirror.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Session mirroring for cross-platform message delivery.
|
||||
|
||||
When a message is sent to a platform (via send_message or cron delivery),
|
||||
this module appends a "delivery-mirror" record to the target session's
|
||||
transcript so the receiving-side agent has context about what was sent.
|
||||
|
||||
Standalone -- works from CLI, cron, and gateway contexts without needing
|
||||
the full SessionStore machinery.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SESSIONS_DIR = Path.home() / ".hermes" / "sessions"
|
||||
_SESSIONS_INDEX = _SESSIONS_DIR / "sessions.json"
|
||||
|
||||
|
||||
def mirror_to_session(
|
||||
platform: str,
|
||||
chat_id: str,
|
||||
message_text: str,
|
||||
source_label: str = "cli",
|
||||
) -> bool:
|
||||
"""
|
||||
Append a delivery-mirror message to the target session's transcript.
|
||||
|
||||
Finds the gateway session that matches the given platform + chat_id,
|
||||
then writes a mirror entry to both the JSONL transcript and SQLite DB.
|
||||
|
||||
Returns True if mirrored successfully, False if no matching session or error.
|
||||
All errors are caught -- this is never fatal.
|
||||
"""
|
||||
try:
|
||||
session_id = _find_session_id(platform, str(chat_id))
|
||||
if not session_id:
|
||||
logger.debug("Mirror: no session found for %s:%s", platform, chat_id)
|
||||
return False
|
||||
|
||||
mirror_msg = {
|
||||
"role": "assistant",
|
||||
"content": message_text,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"mirror": True,
|
||||
"mirror_source": source_label,
|
||||
}
|
||||
|
||||
_append_to_jsonl(session_id, mirror_msg)
|
||||
_append_to_sqlite(session_id, mirror_msg)
|
||||
|
||||
logger.debug("Mirror: wrote to session %s (from %s)", session_id, source_label)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Mirror failed for %s:%s: %s", platform, chat_id, e)
|
||||
return False
|
||||
|
||||
|
||||
def _find_session_id(platform: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
Find the active session_id for a platform + chat_id pair.
|
||||
|
||||
Scans sessions.json entries and matches where origin.chat_id == chat_id
|
||||
on the right platform. DM session keys don't embed the chat_id
|
||||
(e.g. "agent:main:telegram:dm"), so we check the origin dict.
|
||||
"""
|
||||
if not _SESSIONS_INDEX.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(_SESSIONS_INDEX, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
platform_lower = platform.lower()
|
||||
best_match = None
|
||||
best_updated = ""
|
||||
|
||||
for _key, entry in data.items():
|
||||
origin = entry.get("origin") or {}
|
||||
entry_platform = (origin.get("platform") or entry.get("platform", "")).lower()
|
||||
|
||||
if entry_platform != platform_lower:
|
||||
continue
|
||||
|
||||
origin_chat_id = str(origin.get("chat_id", ""))
|
||||
if origin_chat_id == str(chat_id):
|
||||
updated = entry.get("updated_at", "")
|
||||
if updated > best_updated:
|
||||
best_updated = updated
|
||||
best_match = entry.get("session_id")
|
||||
|
||||
return best_match
|
||||
|
||||
|
||||
def _append_to_jsonl(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the JSONL transcript file."""
|
||||
transcript_path = _SESSIONS_DIR / f"{session_id}.jsonl"
|
||||
try:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.debug("Mirror JSONL write failed: %s", e)
|
||||
|
||||
|
||||
def _append_to_sqlite(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the SQLite session database."""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB()
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
role=message.get("role", "assistant"),
|
||||
content=message.get("content"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Mirror SQLite write failed: %s", e)
|
||||
280
gateway/pairing.py
Normal file
280
gateway/pairing.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
DM Pairing System
|
||||
|
||||
Code-based approval flow for authorizing new users on messaging platforms.
|
||||
Instead of static allowlists with user IDs, unknown users receive a one-time
|
||||
pairing code that the bot owner approves via the CLI.
|
||||
|
||||
Security features (based on OWASP + NIST SP 800-63-4 guidance):
|
||||
- 8-char codes from 32-char unambiguous alphabet (no 0/O/1/I)
|
||||
- Cryptographic randomness via secrets.choice()
|
||||
- 1-hour code expiry
|
||||
- Max 3 pending codes per platform
|
||||
- Rate limiting: 1 request per user per 10 minutes
|
||||
- Lockout after 5 failed approval attempts (1 hour)
|
||||
- File permissions: chmod 0600 on all data files
|
||||
- Codes are never logged to stdout
|
||||
|
||||
Storage: ~/.hermes/pairing/
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion
|
||||
ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
CODE_LENGTH = 8
|
||||
|
||||
# Timing constants
|
||||
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
|
||||
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
|
||||
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
|
||||
|
||||
# Limits
|
||||
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
|
||||
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
|
||||
|
||||
PAIRING_DIR = Path(os.path.expanduser("~/.hermes/pairing"))
|
||||
|
||||
|
||||
def _secure_write(path: Path, data: str) -> None:
|
||||
"""Write data to file with restrictive permissions (owner read/write only)."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(data, encoding="utf-8")
|
||||
try:
|
||||
os.chmod(path, 0o600)
|
||||
except OSError:
|
||||
pass # Windows doesn't support chmod the same way
|
||||
|
||||
|
||||
class PairingStore:
|
||||
"""
|
||||
Manages pairing codes and approved user lists.
|
||||
|
||||
Data files per platform:
|
||||
- {platform}-pending.json : pending pairing requests
|
||||
- {platform}-approved.json : approved (paired) users
|
||||
- _rate_limits.json : rate limit tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
PAIRING_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _pending_path(self, platform: str) -> Path:
|
||||
return PAIRING_DIR / f"{platform}-pending.json"
|
||||
|
||||
def _approved_path(self, platform: str) -> Path:
|
||||
return PAIRING_DIR / f"{platform}-approved.json"
|
||||
|
||||
def _rate_limit_path(self) -> Path:
|
||||
return PAIRING_DIR / "_rate_limits.json"
|
||||
|
||||
def _load_json(self, path: Path) -> dict:
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def _save_json(self, path: Path, data: dict) -> None:
|
||||
_secure_write(path, json.dumps(data, indent=2, ensure_ascii=False))
|
||||
|
||||
# ----- Approved users -----
|
||||
|
||||
def is_approved(self, platform: str, user_id: str) -> bool:
|
||||
"""Check if a user is approved (paired) on a platform."""
|
||||
approved = self._load_json(self._approved_path(platform))
|
||||
return user_id in approved
|
||||
|
||||
def list_approved(self, platform: str = None) -> list:
|
||||
"""List approved users, optionally filtered by platform."""
|
||||
results = []
|
||||
platforms = [platform] if platform else self._all_platforms("approved")
|
||||
for p in platforms:
|
||||
approved = self._load_json(self._approved_path(p))
|
||||
for uid, info in approved.items():
|
||||
results.append({"platform": p, "user_id": uid, **info})
|
||||
return results
|
||||
|
||||
def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
|
||||
"""Add a user to the approved list."""
|
||||
approved = self._load_json(self._approved_path(platform))
|
||||
approved[user_id] = {
|
||||
"user_name": user_name,
|
||||
"approved_at": time.time(),
|
||||
}
|
||||
self._save_json(self._approved_path(platform), approved)
|
||||
|
||||
def revoke(self, platform: str, user_id: str) -> bool:
|
||||
"""Remove a user from the approved list. Returns True if found."""
|
||||
path = self._approved_path(platform)
|
||||
approved = self._load_json(path)
|
||||
if user_id in approved:
|
||||
del approved[user_id]
|
||||
self._save_json(path, approved)
|
||||
return True
|
||||
return False
|
||||
|
||||
# ----- Pending codes -----
|
||||
|
||||
def generate_code(self, platform: str, user_id: str, user_name: str = "") -> str | None:
|
||||
"""
|
||||
Generate a pairing code for a new user.
|
||||
|
||||
Returns the code string, or None if:
|
||||
- User is rate-limited (too recent request)
|
||||
- Max pending codes reached for this platform
|
||||
- User/platform is in lockout due to failed attempts
|
||||
"""
|
||||
self._cleanup_expired(platform)
|
||||
|
||||
# Check lockout
|
||||
if self._is_locked_out(platform):
|
||||
return None
|
||||
|
||||
# Check rate limit for this specific user
|
||||
if self._is_rate_limited(platform, user_id):
|
||||
return None
|
||||
|
||||
# Check max pending
|
||||
pending = self._load_json(self._pending_path(platform))
|
||||
if len(pending) >= MAX_PENDING_PER_PLATFORM:
|
||||
return None
|
||||
|
||||
# Generate cryptographically random code
|
||||
code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH))
|
||||
|
||||
# Store pending request
|
||||
pending[code] = {
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
self._save_json(self._pending_path(platform), pending)
|
||||
|
||||
# Record rate limit
|
||||
self._record_rate_limit(platform, user_id)
|
||||
|
||||
return code
|
||||
|
||||
def approve_code(self, platform: str, code: str) -> dict | None:
|
||||
"""
|
||||
Approve a pairing code. Adds the user to the approved list.
|
||||
|
||||
Returns {user_id, user_name} on success, None if code is invalid/expired.
|
||||
"""
|
||||
self._cleanup_expired(platform)
|
||||
code = code.upper().strip()
|
||||
|
||||
pending = self._load_json(self._pending_path(platform))
|
||||
if code not in pending:
|
||||
self._record_failed_attempt(platform)
|
||||
return None
|
||||
|
||||
entry = pending.pop(code)
|
||||
self._save_json(self._pending_path(platform), pending)
|
||||
|
||||
# Add to approved list
|
||||
self._approve_user(platform, entry["user_id"], entry.get("user_name", ""))
|
||||
|
||||
return {
|
||||
"user_id": entry["user_id"],
|
||||
"user_name": entry.get("user_name", ""),
|
||||
}
|
||||
|
||||
def list_pending(self, platform: str = None) -> list:
|
||||
"""List pending pairing requests, optionally filtered by platform."""
|
||||
results = []
|
||||
platforms = [platform] if platform else self._all_platforms("pending")
|
||||
for p in platforms:
|
||||
self._cleanup_expired(p)
|
||||
pending = self._load_json(self._pending_path(p))
|
||||
for code, info in pending.items():
|
||||
age_min = int((time.time() - info["created_at"]) / 60)
|
||||
results.append(
|
||||
{
|
||||
"platform": p,
|
||||
"code": code,
|
||||
"user_id": info["user_id"],
|
||||
"user_name": info.get("user_name", ""),
|
||||
"age_minutes": age_min,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
def clear_pending(self, platform: str = None) -> int:
|
||||
"""Clear all pending requests. Returns count removed."""
|
||||
count = 0
|
||||
platforms = [platform] if platform else self._all_platforms("pending")
|
||||
for p in platforms:
|
||||
pending = self._load_json(self._pending_path(p))
|
||||
count += len(pending)
|
||||
self._save_json(self._pending_path(p), {})
|
||||
return count
|
||||
|
||||
# ----- Rate limiting and lockout -----
|
||||
|
||||
def _is_rate_limited(self, platform: str, user_id: str) -> bool:
|
||||
"""Check if a user has requested a code too recently."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
key = f"{platform}:{user_id}"
|
||||
last_request = limits.get(key, 0)
|
||||
return (time.time() - last_request) < RATE_LIMIT_SECONDS
|
||||
|
||||
def _record_rate_limit(self, platform: str, user_id: str) -> None:
|
||||
"""Record the time of a pairing request for rate limiting."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
key = f"{platform}:{user_id}"
|
||||
limits[key] = time.time()
|
||||
self._save_json(self._rate_limit_path(), limits)
|
||||
|
||||
def _is_locked_out(self, platform: str) -> bool:
|
||||
"""Check if a platform is in lockout due to failed approval attempts."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
lockout_key = f"_lockout:{platform}"
|
||||
lockout_until = limits.get(lockout_key, 0)
|
||||
return time.time() < lockout_until
|
||||
|
||||
def _record_failed_attempt(self, platform: str) -> None:
|
||||
"""Record a failed approval attempt. Triggers lockout after MAX_FAILED_ATTEMPTS."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
fail_key = f"_failures:{platform}"
|
||||
fails = limits.get(fail_key, 0) + 1
|
||||
limits[fail_key] = fails
|
||||
if fails >= MAX_FAILED_ATTEMPTS:
|
||||
lockout_key = f"_lockout:{platform}"
|
||||
limits[lockout_key] = time.time() + LOCKOUT_SECONDS
|
||||
limits[fail_key] = 0 # Reset counter
|
||||
print(
|
||||
f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
|
||||
f"after {MAX_FAILED_ATTEMPTS} failed attempts",
|
||||
flush=True,
|
||||
)
|
||||
self._save_json(self._rate_limit_path(), limits)
|
||||
|
||||
# ----- Cleanup -----
|
||||
|
||||
def _cleanup_expired(self, platform: str) -> None:
|
||||
"""Remove expired pending codes."""
|
||||
path = self._pending_path(platform)
|
||||
pending = self._load_json(path)
|
||||
now = time.time()
|
||||
expired = [code for code, info in pending.items() if (now - info["created_at"]) > CODE_TTL_SECONDS]
|
||||
if expired:
|
||||
for code in expired:
|
||||
del pending[code]
|
||||
self._save_json(path, pending)
|
||||
|
||||
def _all_platforms(self, suffix: str) -> list:
|
||||
"""List all platforms that have data files of a given suffix."""
|
||||
platforms = []
|
||||
for f in PAIRING_DIR.iterdir():
|
||||
if f.name.endswith(f"-{suffix}.json"):
|
||||
platform = f.name.replace(f"-{suffix}.json", "")
|
||||
if not platform.startswith("_"):
|
||||
platforms.append(platform)
|
||||
return platforms
|
||||
313
gateway/platforms/ADDING_A_PLATFORM.md
Normal file
313
gateway/platforms/ADDING_A_PLATFORM.md
Normal file
@@ -0,0 +1,313 @@
|
||||
# Adding a New Messaging Platform
|
||||
|
||||
Checklist for integrating a new messaging platform into the Hermes gateway.
|
||||
Use this as a reference when building a new adapter — every item here is a
|
||||
real integration point that exists in the codebase. Missing any of them will
|
||||
cause broken functionality, missing features, or inconsistent behavior.
|
||||
|
||||
---
|
||||
|
||||
## 1. Core Adapter (`gateway/platforms/<platform>.py`)
|
||||
|
||||
The adapter is a subclass of `BasePlatformAdapter` from `gateway/platforms/base.py`.
|
||||
|
||||
### Required methods
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `__init__(self, config)` | Parse config, init state. Call `super().__init__(config, Platform.YOUR_PLATFORM)` |
|
||||
| `connect() -> bool` | Connect to the platform, start listeners. Return True on success |
|
||||
| `disconnect()` | Stop listeners, close connections, cancel tasks |
|
||||
| `send(chat_id, text, ...) -> SendResult` | Send a text message |
|
||||
| `send_typing(chat_id)` | Send typing indicator |
|
||||
| `send_image(chat_id, image_url, caption) -> SendResult` | Send an image |
|
||||
| `get_chat_info(chat_id) -> dict` | Return `{name, type, chat_id}` for a chat |
|
||||
|
||||
### Optional methods (have default stubs in base)
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `send_document(chat_id, path, caption)` | Send a file attachment |
|
||||
| `send_voice(chat_id, path)` | Send a voice message |
|
||||
| `send_video(chat_id, path, caption)` | Send a video |
|
||||
| `send_animation(chat_id, path, caption)` | Send a GIF/animation |
|
||||
| `send_image_file(chat_id, path, caption)` | Send image from local file |
|
||||
|
||||
### Required function
|
||||
|
||||
```python
|
||||
def check_<platform>_requirements() -> bool:
|
||||
"""Check if this platform's dependencies are available."""
|
||||
```
|
||||
|
||||
### Key patterns to follow
|
||||
|
||||
- Use `self.build_source(...)` to construct `SessionSource` objects
|
||||
- Call `self.handle_message(event)` to dispatch inbound messages to the gateway
|
||||
- Use `MessageEvent`, `MessageType`, `SendResult` from base
|
||||
- Use `cache_image_from_bytes`, `cache_audio_from_bytes`, `cache_document_from_bytes` for attachments
|
||||
- Filter self-messages (prevent reply loops)
|
||||
- Filter sync/echo messages if the platform has them
|
||||
- Redact sensitive identifiers (phone numbers, tokens) in all log output
|
||||
- Implement reconnection with exponential backoff + jitter for streaming connections
|
||||
- Set `MAX_MESSAGE_LENGTH` if the platform has message size limits
|
||||
|
||||
---
|
||||
|
||||
## 2. Platform Enum (`gateway/config.py`)
|
||||
|
||||
Add the platform to the `Platform` enum:
|
||||
|
||||
```python
|
||||
class Platform(Enum):
|
||||
...
|
||||
YOUR_PLATFORM = "your_platform"
|
||||
```
|
||||
|
||||
Add env var loading in `_apply_env_overrides()`:
|
||||
|
||||
```python
|
||||
# Your Platform
|
||||
your_token = os.getenv("YOUR_PLATFORM_TOKEN")
|
||||
if your_token:
|
||||
if Platform.YOUR_PLATFORM not in config.platforms:
|
||||
config.platforms[Platform.YOUR_PLATFORM] = PlatformConfig()
|
||||
config.platforms[Platform.YOUR_PLATFORM].enabled = True
|
||||
config.platforms[Platform.YOUR_PLATFORM].token = your_token
|
||||
```
|
||||
|
||||
Update `get_connected_platforms()` if your platform doesn't use token/api_key
|
||||
(e.g., WhatsApp uses `enabled` flag, Signal uses `extra` dict).
|
||||
|
||||
---
|
||||
|
||||
## 3. Adapter Factory (`gateway/run.py`)
|
||||
|
||||
Add to `_create_adapter()`:
|
||||
|
||||
```python
|
||||
elif platform == Platform.YOUR_PLATFORM:
|
||||
from gateway.platforms.your_platform import YourAdapter, check_your_requirements
|
||||
if not check_your_requirements():
|
||||
logger.warning("Your Platform: dependencies not met")
|
||||
return None
|
||||
return YourAdapter(config)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Authorization Maps (`gateway/run.py`)
|
||||
|
||||
Add to BOTH dicts in `_is_user_authorized()`:
|
||||
|
||||
```python
|
||||
platform_env_map = {
|
||||
...
|
||||
Platform.YOUR_PLATFORM: "YOUR_PLATFORM_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
...
|
||||
Platform.YOUR_PLATFORM: "YOUR_PLATFORM_ALLOW_ALL_USERS",
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Session Source (`gateway/session.py`)
|
||||
|
||||
If your platform needs extra identity fields (e.g., Signal's UUID alongside
|
||||
phone number), add them to the `SessionSource` dataclass with `Optional` defaults,
|
||||
and update `to_dict()`, `from_dict()`, and `build_source()` in base.py.
|
||||
|
||||
---
|
||||
|
||||
## 6. System Prompt Hints (`agent/prompt_builder.py`)
|
||||
|
||||
Add a `PLATFORM_HINTS` entry so the agent knows what platform it's on:
|
||||
|
||||
```python
|
||||
PLATFORM_HINTS = {
|
||||
...
|
||||
"your_platform": (
|
||||
"You are on Your Platform. "
|
||||
"Describe formatting capabilities, media support, etc."
|
||||
),
|
||||
}
|
||||
```
|
||||
|
||||
Without this, the agent won't know it's on your platform and may use
|
||||
inappropriate formatting (e.g., markdown on platforms that don't render it).
|
||||
|
||||
---
|
||||
|
||||
## 7. Toolset (`toolsets.py`)
|
||||
|
||||
Add a named toolset for your platform:
|
||||
|
||||
```python
|
||||
"hermes-your-platform": {
|
||||
"description": "Your Platform bot toolset",
|
||||
"tools": _HERMES_CORE_TOOLS,
|
||||
"includes": []
|
||||
},
|
||||
```
|
||||
|
||||
And add it to the `hermes-gateway` composite:
|
||||
|
||||
```python
|
||||
"hermes-gateway": {
|
||||
"includes": [..., "hermes-your-platform"]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Cron Delivery (`cron/scheduler.py`)
|
||||
|
||||
Add to `platform_map` in `_deliver_result()`:
|
||||
|
||||
```python
|
||||
platform_map = {
|
||||
...
|
||||
"your_platform": Platform.YOUR_PLATFORM,
|
||||
}
|
||||
```
|
||||
|
||||
Without this, `schedule_cronjob(deliver="your_platform")` silently fails.
|
||||
|
||||
---
|
||||
|
||||
## 9. Send Message Tool (`tools/send_message_tool.py`)
|
||||
|
||||
Add to `platform_map` in `send_message_tool()`:
|
||||
|
||||
```python
|
||||
platform_map = {
|
||||
...
|
||||
"your_platform": Platform.YOUR_PLATFORM,
|
||||
}
|
||||
```
|
||||
|
||||
Add routing in `_send_to_platform()`:
|
||||
|
||||
```python
|
||||
elif platform == Platform.YOUR_PLATFORM:
|
||||
return await _send_your_platform(pconfig, chat_id, message)
|
||||
```
|
||||
|
||||
Implement `_send_your_platform()` — a standalone async function that sends
|
||||
a single message without requiring the full adapter (for use by cron jobs
|
||||
and the send_message tool outside the gateway process).
|
||||
|
||||
Update the tool schema `target` description to include your platform example.
|
||||
|
||||
---
|
||||
|
||||
## 10. Cronjob Tool Schema (`tools/cronjob_tools.py`)
|
||||
|
||||
Update the `deliver` parameter description and docstring to mention your
|
||||
platform as a delivery option.
|
||||
|
||||
---
|
||||
|
||||
## 11. Channel Directory (`gateway/channel_directory.py`)
|
||||
|
||||
If your platform can't enumerate chats (most can't), add it to the
|
||||
session-based discovery list:
|
||||
|
||||
```python
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "your_platform"):
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Status Display (`hermes_cli/status.py`)
|
||||
|
||||
Add to the `platforms` dict in the Messaging Platforms section:
|
||||
|
||||
```python
|
||||
platforms = {
|
||||
...
|
||||
"Your Platform": ("YOUR_PLATFORM_TOKEN", "YOUR_PLATFORM_HOME_CHANNEL"),
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Gateway Setup Wizard (`hermes_cli/gateway.py`)
|
||||
|
||||
Add to the `_PLATFORMS` list:
|
||||
|
||||
```python
|
||||
{
|
||||
"key": "your_platform",
|
||||
"label": "Your Platform",
|
||||
"emoji": "📱",
|
||||
"token_var": "YOUR_PLATFORM_TOKEN",
|
||||
"setup_instructions": [...],
|
||||
"vars": [...],
|
||||
}
|
||||
```
|
||||
|
||||
If your platform needs custom setup logic (connectivity testing, QR codes,
|
||||
policy choices), add a `_setup_your_platform()` function and route to it
|
||||
in the platform selection switch.
|
||||
|
||||
Update `_platform_status()` if your platform's "configured" check differs
|
||||
from the standard `bool(get_env_value(token_var))`.
|
||||
|
||||
---
|
||||
|
||||
## 14. Phone/ID Redaction (`agent/redact.py`)
|
||||
|
||||
If your platform uses sensitive identifiers (phone numbers, etc.), add a
|
||||
regex pattern and redaction function to `agent/redact.py`. This ensures
|
||||
identifiers are masked in ALL log output, not just your adapter's logs.
|
||||
|
||||
---
|
||||
|
||||
## 15. Documentation
|
||||
|
||||
| File | What to update |
|
||||
|------|---------------|
|
||||
| `README.md` | Platform list in feature table + documentation table |
|
||||
| `AGENTS.md` | Gateway description + env var config section |
|
||||
| `website/docs/user-guide/messaging/<platform>.md` | **NEW** — Full setup guide (see existing platform docs for template) |
|
||||
| `website/docs/user-guide/messaging/index.md` | Architecture diagram, toolset table, security examples, Next Steps links |
|
||||
| `website/docs/reference/environment-variables.md` | All env vars for the platform |
|
||||
|
||||
---
|
||||
|
||||
## 16. Tests (`tests/gateway/test_<platform>.py`)
|
||||
|
||||
Recommended test coverage:
|
||||
|
||||
- Platform enum exists with correct value
|
||||
- Config loading from env vars via `_apply_env_overrides`
|
||||
- Adapter init (config parsing, allowlist handling, default values)
|
||||
- Helper functions (redaction, parsing, file type detection)
|
||||
- Session source round-trip (to_dict → from_dict)
|
||||
- Authorization integration (platform in allowlist maps)
|
||||
- Send message tool routing (platform in platform_map)
|
||||
|
||||
Optional but valuable:
|
||||
- Async tests for message handling flow (mock the platform API)
|
||||
- SSE/WebSocket reconnection logic
|
||||
- Attachment processing
|
||||
- Group message filtering
|
||||
|
||||
---
|
||||
|
||||
## Quick Verification
|
||||
|
||||
After implementing everything, verify with:
|
||||
|
||||
```bash
|
||||
# All checks pass (lint + test)
|
||||
make check
|
||||
|
||||
# Grep for your platform name to find any missed integration points
|
||||
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \
|
||||
--include="*.py" -l | sort -u
|
||||
# Check each file in the output — if it mentions other platforms but not yours, you missed it
|
||||
```
|
||||
17
gateway/platforms/__init__.py
Normal file
17
gateway/platforms/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Platform adapters for messaging integrations.
|
||||
|
||||
Each adapter handles:
|
||||
- Receiving messages from a platform
|
||||
- Sending messages/responses back
|
||||
- Platform-specific authentication
|
||||
- Message formatting and media handling
|
||||
"""
|
||||
|
||||
from .base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
|
||||
__all__ = [
|
||||
"BasePlatformAdapter",
|
||||
"MessageEvent",
|
||||
"SendResult",
|
||||
]
|
||||
984
gateway/platforms/base.py
Normal file
984
gateway/platforms/base.py
Normal file
@@ -0,0 +1,984 @@
|
||||
"""
|
||||
Base platform adapter interface.
|
||||
|
||||
All platform adapters (Telegram, Discord, WhatsApp) inherit from this
|
||||
and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from pathlib import Path as _Path
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.session import SessionSource
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
# When users send images on messaging platforms, we download them to a local
|
||||
# cache directory so they can be analyzed by the vision tool (which accepts
|
||||
# local file paths). This avoids issues with ephemeral platform URLs
|
||||
# (e.g. Telegram file URLs expire after ~1 hour).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default location: ~/.hermes/image_cache/
|
||||
IMAGE_CACHE_DIR = Path(os.path.expanduser("~/.hermes/image_cache"))
|
||||
|
||||
|
||||
def get_image_cache_dir() -> Path:
|
||||
"""Return the image cache directory, creating it if it doesn't exist."""
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return IMAGE_CACHE_DIR
|
||||
|
||||
|
||||
def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Save raw image bytes to the cache and return the absolute file path.
|
||||
|
||||
Args:
|
||||
data: Raw image bytes.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
cache_dir = get_image_cache_dir()
|
||||
filename = f"img_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
filepath.write_bytes(data)
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
"""
|
||||
Delete cached images older than *max_age_hours*.
|
||||
|
||||
Returns the number of files removed.
|
||||
"""
|
||||
import time
|
||||
|
||||
cache_dir = get_image_cache_dir()
|
||||
cutoff = time.time() - (max_age_hours * 3600)
|
||||
removed = 0
|
||||
for f in cache_dir.iterdir():
|
||||
if f.is_file() and f.stat().st_mtime < cutoff:
|
||||
try:
|
||||
f.unlink()
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio cache utilities
|
||||
#
|
||||
# Same pattern as image cache -- voice messages from platforms are downloaded
|
||||
# here so the STT tool (OpenAI Whisper) can transcribe them from local files.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AUDIO_CACHE_DIR = Path(os.path.expanduser("~/.hermes/audio_cache"))
|
||||
|
||||
|
||||
def get_audio_cache_dir() -> Path:
|
||||
"""Return the audio cache directory, creating it if it doesn't exist."""
|
||||
AUDIO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return AUDIO_CACHE_DIR
|
||||
|
||||
|
||||
def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str:
|
||||
"""
|
||||
Save raw audio bytes to the cache and return the absolute file path.
|
||||
|
||||
Args:
|
||||
data: Raw audio bytes.
|
||||
ext: File extension including the dot (e.g. ".ogg", ".mp3").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached audio file as a string.
|
||||
"""
|
||||
cache_dir = get_audio_cache_dir()
|
||||
filename = f"audio_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
filepath.write_bytes(data)
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str:
|
||||
"""
|
||||
Download an audio file from a URL and save it to the local cache.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".ogg", ".mp3").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached audio file as a string.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "audio/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document cache utilities
|
||||
#
|
||||
# Same pattern as image/audio cache -- documents from platforms are downloaded
|
||||
# here so the agent can reference them by local file path.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DOCUMENT_CACHE_DIR = Path(os.path.expanduser("~/.hermes/document_cache"))
|
||||
|
||||
SUPPORTED_DOCUMENT_TYPES = {
|
||||
".pdf": "application/pdf",
|
||||
".md": "text/markdown",
|
||||
".txt": "text/plain",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
}
|
||||
|
||||
|
||||
def get_document_cache_dir() -> Path:
|
||||
"""Return the document cache directory, creating it if it doesn't exist."""
|
||||
DOCUMENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return DOCUMENT_CACHE_DIR
|
||||
|
||||
|
||||
def cache_document_from_bytes(data: bytes, filename: str) -> str:
|
||||
"""
|
||||
Save raw document bytes to the cache and return the absolute file path.
|
||||
|
||||
The cached filename preserves the original human-readable name with a
|
||||
unique prefix: ``doc_{uuid12}_{original_filename}``.
|
||||
|
||||
Args:
|
||||
data: Raw document bytes.
|
||||
filename: Original filename (e.g. "report.pdf").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached document file as a string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the sanitized path escapes the cache directory.
|
||||
"""
|
||||
cache_dir = get_document_cache_dir()
|
||||
# Sanitize: strip directory components, null bytes, and control characters
|
||||
safe_name = Path(filename).name if filename else "document"
|
||||
safe_name = safe_name.replace("\x00", "").strip()
|
||||
if not safe_name or safe_name in (".", ".."):
|
||||
safe_name = "document"
|
||||
cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}"
|
||||
filepath = cache_dir / cached_name
|
||||
# Final safety check: ensure path stays inside cache dir
|
||||
if not filepath.resolve().is_relative_to(cache_dir.resolve()):
|
||||
raise ValueError(f"Path traversal rejected: {filename!r}")
|
||||
filepath.write_bytes(data)
|
||||
return str(filepath)
|
||||
|
||||
|
||||
def cleanup_document_cache(max_age_hours: int = 24) -> int:
|
||||
"""
|
||||
Delete cached documents older than *max_age_hours*.
|
||||
|
||||
Returns the number of files removed.
|
||||
"""
|
||||
import time
|
||||
|
||||
cache_dir = get_document_cache_dir()
|
||||
cutoff = time.time() - (max_age_hours * 3600)
|
||||
removed = 0
|
||||
for f in cache_dir.iterdir():
|
||||
if f.is_file() and f.stat().st_mtime < cutoff:
|
||||
try:
|
||||
f.unlink()
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
|
||||
TEXT = "text"
|
||||
LOCATION = "location"
|
||||
PHOTO = "photo"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
VOICE = "voice"
|
||||
DOCUMENT = "document"
|
||||
STICKER = "sticker"
|
||||
COMMAND = "command" # /command style
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageEvent:
|
||||
"""
|
||||
Incoming message from a platform.
|
||||
|
||||
Normalized representation that all adapters produce.
|
||||
"""
|
||||
|
||||
# Message content
|
||||
text: str
|
||||
message_type: MessageType = MessageType.TEXT
|
||||
|
||||
# Source information
|
||||
source: SessionSource = None
|
||||
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: str | None = None
|
||||
|
||||
# Media attachments
|
||||
media_urls: list[str] = field(default_factory=list)
|
||||
media_types: list[str] = field(default_factory=list)
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: str | None = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def is_command(self) -> bool:
|
||||
"""Check if this is a command message (e.g., /new, /reset)."""
|
||||
return self.text.startswith("/")
|
||||
|
||||
def get_command(self) -> str | None:
|
||||
"""Extract command name if this is a command message."""
|
||||
if not self.is_command():
|
||||
return None
|
||||
# Split on space and get first word, strip the /
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[0][1:].lower() if parts else None
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
"""Get the arguments after a command."""
|
||||
if not self.is_command():
|
||||
return self.text
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendResult:
|
||||
"""Result of sending a message."""
|
||||
|
||||
success: bool
|
||||
message_id: str | None = None
|
||||
error: str | None = None
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[str | None]]
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Base class for platform adapters.
|
||||
|
||||
Subclasses implement platform-specific logic for:
|
||||
- Connecting and authenticating
|
||||
- Receiving messages
|
||||
- Sending messages/responses
|
||||
- Handling media
|
||||
"""
|
||||
|
||||
def __init__(self, config: PlatformConfig, platform: Platform):
|
||||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: MessageHandler | None = None
|
||||
self._running = False
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
return self.platform.value.title()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if adapter is currently connected."""
|
||||
return self._running
|
||||
|
||||
def set_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
Set the handler for incoming messages.
|
||||
|
||||
The handler receives a MessageEvent and should return
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Connect to the platform and start receiving messages.
|
||||
|
||||
Returns True if connection was successful.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the platform."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
content: Message content (may be markdown)
|
||||
reply_to: Optional message ID to reply to
|
||||
metadata: Additional platform-specific options
|
||||
|
||||
Returns:
|
||||
SendResult with success status and message ID
|
||||
"""
|
||||
pass
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Edit a previously sent message. Optional — platforms that don't
|
||||
support editing return success=False and callers fall back to
|
||||
sending a new message.
|
||||
"""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
|
||||
Override in subclasses if the platform supports it.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
"""
|
||||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an animated GIF natively via the platform API.
|
||||
|
||||
Override in subclasses to send GIFs as proper animations
|
||||
(e.g., Telegram send_animation) so they auto-play inline.
|
||||
Default falls back to send_image.
|
||||
"""
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def _is_animation_url(url: str) -> bool:
|
||||
"""Check if a URL points to an animated GIF (vs a static image)."""
|
||||
lower = url.lower().split("?")[0] # Strip query params
|
||||
return lower.endswith(".gif")
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> tuple[list[tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)"
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(
|
||||
url.lower().endswith(ext) or ext in url.lower()
|
||||
for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"]
|
||||
):
|
||||
images.append((url, alt_text))
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
# Remove only the matched image tags from content (not all markdown images)
|
||||
if images:
|
||||
extracted_urls = {url for url, _ in images}
|
||||
|
||||
def _remove_if_extracted(match):
|
||||
url = match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
return "" if url in extracted_urls else match.group(0)
|
||||
|
||||
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
|
||||
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
"""
|
||||
text = f"🔊 Audio: {audio_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a video natively via the platform API.
|
||||
|
||||
Override in subclasses to send videos as inline playable media.
|
||||
Default falls back to sending the file path as text.
|
||||
"""
|
||||
text = f"🎬 Video: {video_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a document/file natively via the platform API.
|
||||
|
||||
Override in subclasses to send files as downloadable attachments.
|
||||
Default falls back to sending the file path as text.
|
||||
"""
|
||||
text = f"📎 File: {file_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a local image file natively via the platform API.
|
||||
|
||||
Unlike send_image() which takes a URL, this takes a local file path.
|
||||
Override in subclasses for native photo attachments.
|
||||
Default falls back to sending the file path as text.
|
||||
"""
|
||||
text = f"🖼️ Image: {image_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r"MEDIA:(\S+)"
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, "", cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
|
||||
to recover quickly after progress messages interrupt it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
await self.send_typing(chat_id)
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
"""
|
||||
Return a random delay in seconds for human-like response pacing.
|
||||
|
||||
Reads from env vars:
|
||||
HERMES_HUMAN_DELAY_MODE: "off" (default) | "natural" | "custom"
|
||||
HERMES_HUMAN_DELAY_MIN_MS: minimum delay in ms (default 800, custom mode)
|
||||
HERMES_HUMAN_DELAY_MAX_MS: maximum delay in ms (default 2500, custom mode)
|
||||
"""
|
||||
import random
|
||||
|
||||
mode = os.getenv("HERMES_HUMAN_DELAY_MODE", "off").lower()
|
||||
if mode == "off":
|
||||
return 0.0
|
||||
min_ms = int(os.getenv("HERMES_HUMAN_DELAY_MIN_MS", "800"))
|
||||
max_ms = int(os.getenv("HERMES_HUMAN_DELAY_MAX_MS", "2500"))
|
||||
if mode == "natural":
|
||||
min_ms, max_ms = 800, 2500
|
||||
return random.uniform(min_ms / 1000.0, max_ms / 1000.0)
|
||||
|
||||
async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
|
||||
"""Background task that actually processes the message."""
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
if images:
|
||||
logger.info(
|
||||
"[%s] extract_images found %d image(s) in response (%d chars)",
|
||||
self.name,
|
||||
len(images),
|
||||
len(response),
|
||||
)
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
logger.info(
|
||||
"[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id
|
||||
)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
for image_url, alt_text in images:
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
image_url[:80],
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
if self._is_animation_url(image_url):
|
||||
img_result = await self.send_animation(
|
||||
chat_id=event.source.chat_id,
|
||||
animation_url=image_url,
|
||||
caption=alt_text if alt_text else None,
|
||||
)
|
||||
else:
|
||||
img_result = await self.send_image(
|
||||
chat_id=event.source.chat_id,
|
||||
image_url=image_url,
|
||||
caption=alt_text if alt_text else None,
|
||||
)
|
||||
if not img_result.success:
|
||||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
ext = Path(media_path).suffix.lower()
|
||||
if ext in _AUDIO_EXTS:
|
||||
media_result = await self.send_voice(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=media_path,
|
||||
)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
media_result = await self.send_video(
|
||||
chat_id=event.source.chat_id,
|
||||
video_path=media_path,
|
||||
)
|
||||
elif ext in _IMAGE_EXTS:
|
||||
media_result = await self.send_image_file(
|
||||
chat_id=event.source.chat_id,
|
||||
image_path=media_path,
|
||||
)
|
||||
else:
|
||||
media_result = await self.send_document(
|
||||
chat_id=event.source.chat_id,
|
||||
file_path=media_path,
|
||||
)
|
||||
|
||||
if not media_result.success:
|
||||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
print(f"[{self.name}] 📨 Processing queued message from interrupt")
|
||||
# Clean up current session before processing pending
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
typing_task.cancel()
|
||||
try:
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
typing_task.cancel()
|
||||
try:
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> MessageEvent | None:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.pop(session_key, None)
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_name: str | None = None,
|
||||
chat_type: str = "dm",
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
chat_topic: str | None = None,
|
||||
user_id_alt: str | None = None,
|
||||
chat_id_alt: str | None = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
if chat_topic is not None and not chat_topic.strip():
|
||||
chat_topic = None
|
||||
return SessionSource(
|
||||
platform=self.platform,
|
||||
chat_id=str(chat_id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(user_id) if user_id else None,
|
||||
user_name=user_name,
|
||||
thread_id=str(thread_id) if thread_id else None,
|
||||
chat_topic=chat_topic.strip() if chat_topic else None,
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about a chat/channel.
|
||||
|
||||
Returns dict with at least:
|
||||
- name: Chat name
|
||||
- type: "dm", "group", "channel"
|
||||
"""
|
||||
pass
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format a message for this platform.
|
||||
|
||||
Override in subclasses to handle platform-specific formatting
|
||||
(e.g., Telegram MarkdownV2, Discord markdown).
|
||||
|
||||
Default implementation returns content as-is.
|
||||
"""
|
||||
return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> list[str]:
|
||||
"""
|
||||
Split a long message into chunks, preserving code block boundaries.
|
||||
|
||||
When a split falls inside a triple-backtick code block, the fence is
|
||||
closed at the end of the current chunk and reopened (with the original
|
||||
language tag) at the start of the next chunk. Multi-chunk responses
|
||||
receive indicators like ``(1/3)``.
|
||||
|
||||
Args:
|
||||
content: The full message content
|
||||
max_length: Maximum length per chunk (platform-specific)
|
||||
|
||||
Returns:
|
||||
List of message chunks
|
||||
"""
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
FENCE_CLOSE = "\n```"
|
||||
|
||||
chunks: list[str] = []
|
||||
remaining = content
|
||||
# When the previous chunk ended mid-code-block, this holds the
|
||||
# language tag (possibly "") so we can reopen the fence.
|
||||
carry_lang: str | None = None
|
||||
|
||||
while remaining:
|
||||
# If we're continuing a code block from the previous chunk,
|
||||
# prepend a new opening fence with the same language tag.
|
||||
prefix = f"```{carry_lang}\n" if carry_lang is not None else ""
|
||||
|
||||
# How much body text we can fit after accounting for the prefix,
|
||||
# a potential closing fence, and the chunk indicator.
|
||||
headroom = max_length - INDICATOR_RESERVE - len(prefix) - len(FENCE_CLOSE)
|
||||
if headroom < 1:
|
||||
headroom = max_length // 2
|
||||
|
||||
# Everything remaining fits in one final chunk
|
||||
if len(prefix) + len(remaining) <= max_length - INDICATOR_RESERVE:
|
||||
chunks.append(prefix + remaining)
|
||||
break
|
||||
|
||||
# Find a natural split point (prefer newlines, then spaces)
|
||||
region = remaining[:headroom]
|
||||
split_at = region.rfind("\n")
|
||||
if split_at < headroom // 2:
|
||||
split_at = region.rfind(" ")
|
||||
if split_at < 1:
|
||||
split_at = headroom
|
||||
|
||||
chunk_body = remaining[:split_at]
|
||||
remaining = remaining[split_at:].lstrip()
|
||||
|
||||
full_chunk = prefix + chunk_body
|
||||
|
||||
# Walk only the chunk_body (not the prefix we prepended) to
|
||||
# determine whether we end inside an open code block.
|
||||
in_code = carry_lang is not None
|
||||
lang = carry_lang or ""
|
||||
for line in chunk_body.split("\n"):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("```"):
|
||||
if in_code:
|
||||
in_code = False
|
||||
lang = ""
|
||||
else:
|
||||
in_code = True
|
||||
tag = stripped[3:].strip()
|
||||
lang = tag.split()[0] if tag else ""
|
||||
|
||||
if in_code:
|
||||
# Close the orphaned fence so the chunk is valid on its own
|
||||
full_chunk += FENCE_CLOSE
|
||||
carry_lang = lang
|
||||
else:
|
||||
carry_lang = None
|
||||
|
||||
chunks.append(full_chunk)
|
||||
|
||||
# Append chunk indicators when the response spans multiple messages
|
||||
if len(chunks) > 1:
|
||||
total = len(chunks)
|
||||
chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)]
|
||||
|
||||
return chunks
|
||||
963
gateway/platforms/discord.py
Normal file
963
gateway/platforms/discord.py
Normal file
@@ -0,0 +1,963 @@
|
||||
"""
|
||||
Discord platform adapter.
|
||||
|
||||
Uses discord.py library for:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses back
|
||||
- Handling threads and channels
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Intents
|
||||
from discord import Message as DiscordMessage
|
||||
from discord.ext import commands
|
||||
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
discord = None
|
||||
DiscordMessage = Any
|
||||
Intents = Any
|
||||
commands = None
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
def check_discord_requirements() -> bool:
|
||||
"""Check if Discord dependencies are available."""
|
||||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
- Thread support
|
||||
- Native slash commands (/ask, /reset, /status, /stop)
|
||||
- Button-based exec approvals
|
||||
- Auto-threading for long conversations
|
||||
- Reaction-based feedback
|
||||
"""
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DISCORD)
|
||||
self._client: commands.Bot | None = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
|
||||
return False
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Set up intents -- members intent needed for username-to-ID resolution
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()}
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}")
|
||||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
print(f"[{adapter_self.name}] Synced {len(synced)} slash command(s)")
|
||||
except Exception as e:
|
||||
print(f"[{adapter_self.name}] Slash command sync failed: {e}")
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
await self._handle_message(message)
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
except TimeoutError:
|
||||
print(f"[{self.name}] Timeout waiting for connection")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
if reply_to:
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=reference if i == 0 else None,
|
||||
)
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Discord message."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
formatted = self.format_message(content)
|
||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||
formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
await msg.edit(content=formatted)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
if "jpeg" in content_type or "jpg" in content_type:
|
||||
ext = "jpg"
|
||||
elif "gif" in content_type:
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
import io
|
||||
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
name = channel.recipient.name if channel.recipient else str(chat_id)
|
||||
elif isinstance(channel, discord.Thread):
|
||||
chat_type = "thread"
|
||||
name = channel.name
|
||||
elif isinstance(channel, discord.TextChannel):
|
||||
chat_type = "channel"
|
||||
name = f"#{channel.name}"
|
||||
if channel.guild:
|
||||
name = f"{channel.guild.name} / {name}"
|
||||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
"guild_id": str(channel.guild.id) if hasattr(channel, "guild") and channel.guild else None,
|
||||
"guild_name": channel.guild.name if hasattr(channel, "guild") and channel.guild else None,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
async def _resolve_allowed_usernames(self) -> None:
|
||||
"""
|
||||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||||
|
||||
Users can specify usernames (e.g. "teknium") or display names instead of
|
||||
raw numeric IDs. After resolution, the env var and internal set are updated
|
||||
so authorization checks work with IDs only.
|
||||
"""
|
||||
if not self._allowed_user_ids or not self._client:
|
||||
return
|
||||
|
||||
numeric_ids = set()
|
||||
to_resolve = set()
|
||||
|
||||
for entry in self._allowed_user_ids:
|
||||
if entry.isdigit():
|
||||
numeric_ids.add(entry)
|
||||
else:
|
||||
to_resolve.add(entry.lower())
|
||||
|
||||
if not to_resolve:
|
||||
return
|
||||
|
||||
print(f"[{self.name}] Resolving {len(to_resolve)} username(s): {', '.join(to_resolve)}")
|
||||
resolved_count = 0
|
||||
|
||||
for guild in self._client.guilds:
|
||||
# Fetch full member list (requires members intent)
|
||||
try:
|
||||
members = guild.members
|
||||
if len(members) < guild.member_count:
|
||||
members = [m async for m in guild.fetch_members(limit=None)]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch members for guild %s: %s", guild.name, e)
|
||||
continue
|
||||
|
||||
for member in members:
|
||||
name_lower = member.name.lower()
|
||||
display_lower = member.display_name.lower()
|
||||
global_lower = (member.global_name or "").lower()
|
||||
|
||||
matched = name_lower in to_resolve or display_lower in to_resolve or global_lower in to_resolve
|
||||
if matched:
|
||||
uid = str(member.id)
|
||||
numeric_ids.add(uid)
|
||||
resolved_count += 1
|
||||
matched_name = (
|
||||
name_lower
|
||||
if name_lower in to_resolve
|
||||
else (display_lower if display_lower in to_resolve else global_lower)
|
||||
)
|
||||
to_resolve.discard(matched_name)
|
||||
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
|
||||
|
||||
if not to_resolve:
|
||||
break
|
||||
|
||||
if to_resolve:
|
||||
print(f"[{self.name}] Could not resolve usernames: {', '.join(to_resolve)}")
|
||||
|
||||
# Update internal set and env var so gateway auth checks use IDs
|
||||
self._allowed_user_ids = numeric_ids
|
||||
os.environ["DISCORD_ALLOWED_USERS"] = ",".join(sorted(numeric_ids))
|
||||
if resolved_count:
|
||||
print(f"[{self.name}] Updated DISCORD_ALLOWED_USERS with {resolved_count} resolved ID(s)")
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Discord.
|
||||
|
||||
Discord uses its own markdown variant.
|
||||
"""
|
||||
# Discord markdown is fairly standard, no special escaping needed
|
||||
return content
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
tree = self._client.tree
|
||||
|
||||
@tree.command(name="ask", description="Ask Hermes a question")
|
||||
@discord.app_commands.describe(question="Your question for Hermes")
|
||||
async def slash_ask(interaction: discord.Interaction, question: str):
|
||||
await interaction.response.defer()
|
||||
event = self._build_slash_event(interaction, question)
|
||||
await self.handle_message(event)
|
||||
# The response is sent via the normal send() flow
|
||||
# Send a followup to close the interaction if needed
|
||||
try:
|
||||
await interaction.followup.send("Processing complete~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="new", description="Start a new conversation")
|
||||
async def slash_new(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/reset")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("New conversation started~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="reset", description="Reset your Hermes session")
|
||||
async def slash_reset(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/reset")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Session reset~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="model", description="Show or change the model")
|
||||
@discord.app_commands.describe(name="Model name (e.g. anthropic/claude-sonnet-4). Leave empty to see current.")
|
||||
async def slash_model(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/model {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="personality", description="Set a personality")
|
||||
@discord.app_commands.describe(name="Personality name. Leave empty to list available.")
|
||||
async def slash_personality(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/personality {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="retry", description="Retry your last message")
|
||||
async def slash_retry(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/retry")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Retrying~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="undo", description="Remove the last exchange")
|
||||
async def slash_undo(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/undo")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="status", description="Show Hermes session status")
|
||||
async def slash_status(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/status")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Status sent~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="sethome", description="Set this chat as the home channel")
|
||||
async def slash_sethome(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/sethome")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="stop", description="Stop the running Hermes agent")
|
||||
async def slash_stop(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/stop")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Stop requested~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="compress", description="Compress conversation context")
|
||||
async def slash_compress(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/compress")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="title", description="Set or show the session title")
|
||||
@discord.app_commands.describe(name="Session title. Leave empty to show current.")
|
||||
async def slash_title(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/title {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="resume", description="Resume a previously-named session")
|
||||
@discord.app_commands.describe(name="Session name to resume. Leave empty to list sessions.")
|
||||
async def slash_resume(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/resume {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="usage", description="Show token usage for this session")
|
||||
async def slash_usage(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/usage")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="provider", description="Show available providers")
|
||||
async def slash_provider(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/provider")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="help", description="Show available commands")
|
||||
async def slash_help(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/help")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="insights", description="Show usage insights and analytics")
|
||||
@discord.app_commands.describe(days="Number of days to analyze (default: 7)")
|
||||
async def slash_insights(interaction: discord.Interaction, days: int = 7):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/insights {days}")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="reload-mcp", description="Reload MCP servers from config")
|
||||
async def slash_reload_mcp(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/reload-mcp")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/update")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Update initiated~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
chat_name = ""
|
||||
if not is_dm and hasattr(interaction.channel, "name"):
|
||||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=str(interaction.channel_id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
msg_type = MessageType.COMMAND if text.startswith("/") else MessageType.TEXT
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=interaction,
|
||||
)
|
||||
|
||||
async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult:
|
||||
"""
|
||||
Send a button-based exec approval prompt for a dangerous command.
|
||||
|
||||
Returns SendResult. The approval is resolved when a user clicks a button.
|
||||
"""
|
||||
if not self._client or not DISCORD_AVAILABLE:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Command Approval Required",
|
||||
description=f"```\n{command[:500]}\n```",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
embed.set_footer(text=f"Approval ID: {approval_id}")
|
||||
|
||||
view = ExecApprovalView(
|
||||
approval_id=approval_id,
|
||||
allowed_user_ids=self._allowed_user_ids,
|
||||
)
|
||||
|
||||
msg = await channel.send(embed=embed, view=view)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
# UNLESS the channel is in the free-response list.
|
||||
#
|
||||
# Config:
|
||||
# DISCORD_FREE_RESPONSE_CHANNELS: Comma-separated channel IDs where the
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
msg_type = MessageType.COMMAND
|
||||
elif message.attachments:
|
||||
# Check attachment types
|
||||
for att in message.attachments:
|
||||
if att.content_type:
|
||||
if att.content_type.startswith("image/"):
|
||||
msg_type = MessageType.PHOTO
|
||||
elif att.content_type.startswith("video/"):
|
||||
msg_type = MessageType.VIDEO
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
chat_name = message.author.name
|
||||
elif isinstance(message.channel, discord.Thread):
|
||||
chat_type = "thread"
|
||||
chat_name = message.channel.name
|
||||
else:
|
||||
chat_type = "group" # Treat server channels as groups
|
||||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get thread ID if in a thread
|
||||
thread_id = None
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(message.author.id),
|
||||
user_name=message.author.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
try:
|
||||
# Determine extension from content type (image/png -> .png)
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
cached_path = await cache_image_from_url(att.url, ext=ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Discord] Failed to cache image attachment: {e}", flush=True)
|
||||
# Fall back to the CDN URL if caching fails
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
elif content_type.startswith("audio/"):
|
||||
try:
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
|
||||
ext = ".ogg"
|
||||
cached_path = await cache_audio_from_url(att.url, ext=ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Discord] Failed to cache audio attachment: {e}", flush=True)
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.id),
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord UI Components (outside the adapter class)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
|
||||
class ExecApprovalView(discord.ui.View):
|
||||
"""
|
||||
Interactive button view for exec approval of dangerous commands.
|
||||
|
||||
Shows three buttons: Allow Once (green), Always Allow (blue), Deny (red).
|
||||
Only users in the allowed list can click. The view times out after 5 minutes.
|
||||
"""
|
||||
|
||||
def __init__(self, approval_id: str, allowed_user_ids: set):
|
||||
super().__init__(timeout=300) # 5-minute timeout
|
||||
self.approval_id = approval_id
|
||||
self.allowed_user_ids = allowed_user_ids
|
||||
self.resolved = False
|
||||
|
||||
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||||
"""Verify the user clicking is authorized."""
|
||||
if not self.allowed_user_ids:
|
||||
return True # No allowlist = anyone can approve
|
||||
return str(interaction.user.id) in self.allowed_user_ids
|
||||
|
||||
async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color):
|
||||
"""Resolve the approval and update the message."""
|
||||
if self.resolved:
|
||||
await interaction.response.send_message("This approval has already been resolved~", ephemeral=True)
|
||||
return
|
||||
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True)
|
||||
return
|
||||
|
||||
self.resolved = True
|
||||
|
||||
# Update the embed with the decision
|
||||
embed = interaction.message.embeds[0] if interaction.message.embeds else None
|
||||
if embed:
|
||||
embed.color = color
|
||||
embed.set_footer(text=f"{action} by {interaction.user.display_name}")
|
||||
|
||||
# Disable all buttons
|
||||
for child in self.children:
|
||||
child.disabled = True
|
||||
|
||||
await interaction.response.edit_message(embed=embed, view=self)
|
||||
|
||||
# Store the approval decision
|
||||
try:
|
||||
from tools.approval import approve_permanent
|
||||
|
||||
if action == "allow_once":
|
||||
pass # One-time approval handled by gateway
|
||||
elif action == "allow_always":
|
||||
approve_permanent(self.approval_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
|
||||
async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_once", discord.Color.green())
|
||||
|
||||
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
|
||||
async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_always", discord.Color.blue())
|
||||
|
||||
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
|
||||
async def deny(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "deny", discord.Color.red())
|
||||
|
||||
async def on_timeout(self):
|
||||
"""Handle view timeout -- disable buttons and mark as expired."""
|
||||
self.resolved = True
|
||||
for child in self.children:
|
||||
child.disabled = True
|
||||
427
gateway/platforms/homeassistant.py
Normal file
427
gateway/platforms/homeassistant.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
Home Assistant platform adapter.
|
||||
|
||||
Connects to the HA WebSocket API for real-time event monitoring.
|
||||
State-change events are converted to MessageEvent objects and forwarded
|
||||
to the agent for processing. Outbound messages are delivered as HA
|
||||
persistent notifications.
|
||||
|
||||
Requires:
|
||||
- aiohttp (already in messaging extras)
|
||||
- HASS_TOKEN env var (Long-Lived Access Token)
|
||||
- HASS_URL env var (default: http://homeassistant.local:8123)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
aiohttp = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_ha_requirements() -> bool:
|
||||
"""Check if Home Assistant dependencies are available and configured."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
return False
|
||||
if not os.getenv("HASS_TOKEN"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Home Assistant WebSocket adapter.
|
||||
|
||||
Subscribes to ``state_changed`` events and forwards them as
|
||||
MessageEvent objects. Supports domain/entity filtering and
|
||||
per-entity cooldowns to avoid event floods.
|
||||
"""
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
# Reconnection backoff schedule (seconds)
|
||||
_BACKOFF_STEPS = [5, 10, 30, 60]
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.HOMEASSISTANT)
|
||||
|
||||
# Connection state
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._rest_session: aiohttp.ClientSession | None = None
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._msg_id: int = 0
|
||||
|
||||
# Configuration from extra
|
||||
extra = config.extra or {}
|
||||
token = config.token or os.getenv("HASS_TOKEN", "")
|
||||
url = extra.get("url") or os.getenv("HASS_URL", "http://homeassistant.local:8123")
|
||||
self._hass_url: str = url.rstrip("/")
|
||||
self._hass_token: str = token
|
||||
|
||||
# Event filtering
|
||||
self._watch_domains: set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: set[str] = set(extra.get("ignore_entities", []))
|
||||
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
|
||||
|
||||
# Cooldown tracking: entity_id -> last_event_timestamp
|
||||
self._last_event_time: dict[str, float] = {}
|
||||
|
||||
def _next_id(self) -> int:
|
||||
"""Return the next WebSocket message ID."""
|
||||
self._msg_id += 1
|
||||
return self._msg_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to HA WebSocket API and subscribe to events."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
logger.warning("[%s] aiohttp not installed. Run: pip install aiohttp", self.name)
|
||||
return False
|
||||
|
||||
if not self._hass_token:
|
||||
logger.warning("[%s] No HASS_TOKEN configured", self.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
success = await self._ws_connect()
|
||||
if not success:
|
||||
return False
|
||||
|
||||
# Dedicated REST session for send() calls
|
||||
self._rest_session = aiohttp.ClientSession()
|
||||
|
||||
# Start background listener
|
||||
self._listen_task = asyncio.create_task(self._listen_loop())
|
||||
self._running = True
|
||||
logger.info("[%s] Connected to %s", self.name, self._hass_url)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to connect: %s", self.name, e)
|
||||
return False
|
||||
|
||||
async def _ws_connect(self) -> bool:
|
||||
"""Establish WebSocket connection and authenticate."""
|
||||
ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url}/api/websocket"
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(ws_url, heartbeat=30)
|
||||
|
||||
# Step 1: Receive auth_required
|
||||
msg = await self._ws.receive_json()
|
||||
if msg.get("type") != "auth_required":
|
||||
logger.error("Expected auth_required, got: %s", msg.get("type"))
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
# Step 2: Send auth
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 3: Wait for auth_ok
|
||||
msg = await self._ws.receive_json()
|
||||
if msg.get("type") != "auth_ok":
|
||||
logger.error("Auth failed: %s", msg)
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
# Step 4: Subscribe to state_changed events
|
||||
sub_id = self._next_id()
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
}
|
||||
)
|
||||
|
||||
# Verify subscription acknowledgement
|
||||
msg = await self._ws.receive_json()
|
||||
if not msg.get("success"):
|
||||
logger.error("Failed to subscribe to events: %s", msg)
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _cleanup_ws(self) -> None:
|
||||
"""Close WebSocket and session."""
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Home Assistant."""
|
||||
self._running = False
|
||||
if self._listen_task:
|
||||
self._listen_task.cancel()
|
||||
try:
|
||||
await self._listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._listen_task = None
|
||||
|
||||
await self._cleanup_ws()
|
||||
if self._rest_session and not self._rest_session.closed:
|
||||
await self._rest_session.close()
|
||||
self._rest_session = None
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event listener
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _listen_loop(self) -> None:
|
||||
"""Main event loop with automatic reconnection."""
|
||||
backoff_idx = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._read_events()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning("[%s] WebSocket error: %s", self.name, e)
|
||||
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# Reconnect with backoff
|
||||
delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)]
|
||||
logger.info("[%s] Reconnecting in %ds...", self.name, delay)
|
||||
await asyncio.sleep(delay)
|
||||
backoff_idx += 1
|
||||
|
||||
try:
|
||||
await self._cleanup_ws()
|
||||
success = await self._ws_connect()
|
||||
if success:
|
||||
backoff_idx = 0 # Reset on successful reconnect
|
||||
logger.info("[%s] Reconnected", self.name)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Reconnection failed: %s", self.name, e)
|
||||
|
||||
async def _read_events(self) -> None:
|
||||
"""Read events from WebSocket until disconnected."""
|
||||
if self._ws is None or self._ws.closed:
|
||||
return
|
||||
async for ws_msg in self._ws:
|
||||
if ws_msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(ws_msg.data)
|
||||
if data.get("type") == "event":
|
||||
await self._handle_ha_event(data.get("event", {}))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Invalid JSON from HA WS: %s", ws_msg.data[:200])
|
||||
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||||
break
|
||||
|
||||
async def _handle_ha_event(self, event: dict[str, Any]) -> None:
|
||||
"""Process a state_changed event from Home Assistant."""
|
||||
event_data = event.get("data", {})
|
||||
entity_id: str = event_data.get("entity_id", "")
|
||||
|
||||
if not entity_id:
|
||||
return
|
||||
|
||||
# Apply ignore filter
|
||||
if entity_id in self._ignore_entities:
|
||||
return
|
||||
|
||||
# Apply domain/entity watch filters
|
||||
domain = entity_id.split(".")[0] if "." in entity_id else ""
|
||||
if self._watch_domains or self._watch_entities:
|
||||
domain_match = domain in self._watch_domains if self._watch_domains else False
|
||||
entity_match = entity_id in self._watch_entities if self._watch_entities else False
|
||||
if not domain_match and not entity_match:
|
||||
return
|
||||
|
||||
# Apply cooldown
|
||||
now = time.time()
|
||||
last = self._last_event_time.get(entity_id, 0)
|
||||
if (now - last) < self._cooldown_seconds:
|
||||
return
|
||||
self._last_event_time[entity_id] = now
|
||||
|
||||
# Build human-readable message
|
||||
old_state = event_data.get("old_state", {})
|
||||
new_state = event_data.get("new_state", {})
|
||||
message = self._format_state_change(entity_id, old_state, new_state)
|
||||
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Build MessageEvent and forward to handler
|
||||
source = self.build_source(
|
||||
chat_id="ha_events",
|
||||
chat_name="Home Assistant Events",
|
||||
chat_type="channel",
|
||||
user_id="homeassistant",
|
||||
user_name="Home Assistant",
|
||||
)
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=message,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id=f"ha_{entity_id}_{int(now)}",
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
@staticmethod
|
||||
def _format_state_change(
|
||||
entity_id: str,
|
||||
old_state: dict[str, Any],
|
||||
new_state: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Convert a state_changed event into a human-readable description."""
|
||||
if not new_state:
|
||||
return None
|
||||
|
||||
old_val = old_state.get("state", "unknown") if old_state else "unknown"
|
||||
new_val = new_state.get("state", "unknown")
|
||||
|
||||
# Skip if state didn't actually change
|
||||
if old_val == new_val:
|
||||
return None
|
||||
|
||||
friendly_name = new_state.get("attributes", {}).get("friendly_name", entity_id)
|
||||
domain = entity_id.split(".")[0] if "." in entity_id else ""
|
||||
|
||||
# Domain-specific formatting
|
||||
if domain == "climate":
|
||||
attrs = new_state.get("attributes", {})
|
||||
temp = attrs.get("current_temperature", "?")
|
||||
target = attrs.get("temperature", "?")
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: HVAC mode changed from "
|
||||
f"'{old_val}' to '{new_val}' (current: {temp}, target: {target})"
|
||||
)
|
||||
|
||||
if domain == "sensor":
|
||||
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
|
||||
return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}"
|
||||
|
||||
if domain == "binary_sensor":
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: "
|
||||
f"{'triggered' if new_val == 'on' else 'cleared'} "
|
||||
f"(was {'triggered' if old_val == 'on' else 'cleared'})"
|
||||
)
|
||||
|
||||
if domain in ("light", "switch", "fan"):
|
||||
return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}"
|
||||
|
||||
if domain == "alarm_control_panel":
|
||||
return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# Generic fallback
|
||||
return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound messaging
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a notification via HA REST API (persistent_notification.create).
|
||||
|
||||
Uses the REST API instead of WebSocket to avoid a race condition
|
||||
with the event listener loop that reads from the same WS connection.
|
||||
"""
|
||||
url = f"{self._hass_url}/api/services/persistent_notification/create"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._hass_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"title": "Hermes Agent",
|
||||
"message": content[: self.MAX_MESSAGE_LENGTH],
|
||||
}
|
||||
|
||||
try:
|
||||
if self._rest_session:
|
||||
async with self._rest_session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
else:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
|
||||
except TimeoutError:
|
||||
return SendResult(success=False, error="Timeout sending notification to HA")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
return {
|
||||
"name": "Home Assistant Events",
|
||||
"type": "channel",
|
||||
"url": self._hass_url,
|
||||
}
|
||||
718
gateway/platforms/signal.py
Normal file
718
gateway/platforms/signal.py
Normal file
@@ -0,0 +1,718 @@
|
||||
"""Signal messenger platform adapter.
|
||||
|
||||
Connects to a signal-cli daemon running in HTTP mode.
|
||||
Inbound messages arrive via SSE (Server-Sent Events) streaming.
|
||||
Outbound messages and actions use JSON-RPC 2.0 over HTTP.
|
||||
|
||||
Based on PR #268 by ibhagwan, rebuilt with bug fixes.
|
||||
|
||||
Requires:
|
||||
- signal-cli installed and running: signal-cli daemon --http 127.0.0.1:8080
|
||||
- SIGNAL_HTTP_URL and SIGNAL_ACCOUNT environment variables set
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
SIGNAL_MAX_ATTACHMENT_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_MESSAGE_LENGTH = 8000 # Signal message size limit
|
||||
TYPING_INTERVAL = 8.0 # seconds between typing indicator refreshes
|
||||
SSE_RETRY_DELAY_INITIAL = 2.0
|
||||
SSE_RETRY_DELAY_MAX = 60.0
|
||||
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
|
||||
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> list[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
|
||||
def _guess_extension(data: bytes) -> str:
|
||||
"""Guess file extension from magic bytes."""
|
||||
if data[:4] == b"\x89PNG":
|
||||
return ".png"
|
||||
if data[:2] == b"\xff\xd8":
|
||||
return ".jpg"
|
||||
if data[:4] == b"GIF8":
|
||||
return ".gif"
|
||||
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
||||
return ".webp"
|
||||
if data[:4] == b"%PDF":
|
||||
return ".pdf"
|
||||
if len(data) >= 8 and data[4:8] == b"ftyp":
|
||||
return ".mp4"
|
||||
if data[:4] == b"OggS":
|
||||
return ".ogg"
|
||||
if len(data) >= 2 and data[0] == 0xFF and (data[1] & 0xE0) == 0xE0:
|
||||
return ".mp3"
|
||||
if data[:2] == b"PK":
|
||||
return ".zip"
|
||||
return ".bin"
|
||||
|
||||
|
||||
def _is_image_ext(ext: str) -> bool:
|
||||
return ext.lower() in (".jpg", ".jpeg", ".png", ".gif", ".webp")
|
||||
|
||||
|
||||
def _is_audio_ext(ext: str) -> bool:
|
||||
return ext.lower() in (".mp3", ".wav", ".ogg", ".m4a", ".aac")
|
||||
|
||||
|
||||
def _render_mentions(text: str, mentions: list) -> str:
|
||||
"""Replace Signal mention placeholders (\\uFFFC) with readable @identifiers.
|
||||
|
||||
Signal encodes @mentions as the Unicode object replacement character
|
||||
with out-of-band metadata containing the mentioned user's UUID/number.
|
||||
"""
|
||||
if not mentions or "\ufffc" not in text:
|
||||
return text
|
||||
# Sort mentions by start position (reverse) to replace from end to start
|
||||
# so indices don't shift as we replace
|
||||
sorted_mentions = sorted(mentions, key=lambda m: m.get("start", 0), reverse=True)
|
||||
for mention in sorted_mentions:
|
||||
start = mention.get("start", 0)
|
||||
length = mention.get("length", 1)
|
||||
# Use the mention's number or UUID as the replacement
|
||||
identifier = mention.get("number") or mention.get("uuid") or "user"
|
||||
replacement = f"@{identifier}"
|
||||
text = text[:start] + replacement + text[start + length :]
|
||||
return text
|
||||
|
||||
|
||||
def check_signal_requirements() -> bool:
|
||||
"""Check if Signal is configured (has URL and account)."""
|
||||
return bool(os.getenv("SIGNAL_HTTP_URL") and os.getenv("SIGNAL_ACCOUNT"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SignalAdapter(BasePlatformAdapter):
|
||||
"""Signal messenger adapter using signal-cli HTTP daemon."""
|
||||
|
||||
platform = Platform.SIGNAL
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SIGNAL)
|
||||
|
||||
extra = config.extra or {}
|
||||
self.http_url = extra.get("http_url", "http://127.0.0.1:8080").rstrip("/")
|
||||
self.account = extra.get("account", "")
|
||||
self.ignore_stories = extra.get("ignore_stories", True)
|
||||
|
||||
# Parse allowlists — group policy is derived from presence of group allowlist
|
||||
group_allowed_str = os.getenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
|
||||
|
||||
# HTTP client
|
||||
self.client: httpx.AsyncClient | None = None
|
||||
|
||||
# Background tasks
|
||||
self._sse_task: asyncio.Task | None = None
|
||||
self._health_monitor_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: httpx.Response | None = None
|
||||
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
logger.info(
|
||||
"Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url,
|
||||
_redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to signal-cli daemon and start SSE listener."""
|
||||
if not self.http_url or not self.account:
|
||||
logger.error("Signal: SIGNAL_HTTP_URL and SIGNAL_ACCOUNT are required")
|
||||
return False
|
||||
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
# Health check — verify signal-cli daemon is reachable
|
||||
try:
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code != 200:
|
||||
logger.error("Signal: health check failed (status %d)", resp.status_code)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Signal: cannot reach signal-cli at %s: %s", self.http_url, e)
|
||||
return False
|
||||
|
||||
self._running = True
|
||||
self._last_sse_activity = time.time()
|
||||
self._sse_task = asyncio.create_task(self._sse_listener())
|
||||
self._health_monitor_task = asyncio.create_task(self._health_monitor())
|
||||
|
||||
logger.info("Signal: connected to %s", self.http_url)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop SSE listener and clean up."""
|
||||
self._running = False
|
||||
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
try:
|
||||
await self._sse_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._health_monitor_task:
|
||||
self._health_monitor_task.cancel()
|
||||
try:
|
||||
await self._health_monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cancel all typing tasks
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE Streaming (inbound messages)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _sse_listener(self) -> None:
|
||||
"""Listen for SSE events from signal-cli daemon."""
|
||||
url = f"{self.http_url}/api/v1/events?account={self.account}"
|
||||
backoff = SSE_RETRY_DELAY_INITIAL
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.debug("Signal SSE: connecting to %s", url)
|
||||
async with self.client.stream(
|
||||
"GET",
|
||||
url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=None,
|
||||
) as response:
|
||||
self._sse_response = response
|
||||
backoff = SSE_RETRY_DELAY_INITIAL # Reset on successful connection
|
||||
self._last_sse_activity = time.time()
|
||||
logger.info("Signal SSE: connected")
|
||||
|
||||
buffer = ""
|
||||
async for chunk in response.aiter_text():
|
||||
if not self._running:
|
||||
break
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if not data_str:
|
||||
continue
|
||||
self._last_sse_activity = time.time()
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
await self._handle_envelope(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Signal SSE: invalid JSON: %s", data_str[:100])
|
||||
except Exception:
|
||||
logger.exception("Signal SSE: error handling event")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except httpx.HTTPError as e:
|
||||
if self._running:
|
||||
logger.warning("Signal SSE: HTTP error: %s (reconnecting in %.0fs)", e, backoff)
|
||||
except Exception as e:
|
||||
if self._running:
|
||||
logger.warning("Signal SSE: error: %s (reconnecting in %.0fs)", e, backoff)
|
||||
|
||||
if self._running:
|
||||
# Add 20% jitter to prevent thundering herd on reconnection
|
||||
jitter = backoff * 0.2 * random.random()
|
||||
await asyncio.sleep(backoff + jitter)
|
||||
backoff = min(backoff * 2, SSE_RETRY_DELAY_MAX)
|
||||
|
||||
self._sse_response = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health Monitor
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _health_monitor(self) -> None:
|
||||
"""Monitor SSE connection health and force reconnect if stale."""
|
||||
while self._running:
|
||||
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
elapsed = time.time() - self._last_sse_activity
|
||||
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
|
||||
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
|
||||
try:
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
# Daemon is alive but SSE is idle — update activity to
|
||||
# avoid repeated warnings (connection may just be quiet)
|
||||
self._last_sse_activity = time.time()
|
||||
logger.debug("Signal: daemon healthy, SSE idle")
|
||||
else:
|
||||
logger.warning("Signal: health check failed (%d), forcing reconnect", resp.status_code)
|
||||
self._force_reconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Signal: health check error: %s, forcing reconnect", e)
|
||||
self._force_reconnect()
|
||||
|
||||
def _force_reconnect(self) -> None:
|
||||
"""Force SSE reconnection by closing the current response."""
|
||||
if self._sse_response and not self._sse_response.is_stream_consumed:
|
||||
try:
|
||||
asyncio.create_task(self._sse_response.aclose())
|
||||
except Exception:
|
||||
pass
|
||||
self._sse_response = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message Handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_envelope(self, envelope: dict) -> None:
|
||||
"""Process an incoming signal-cli envelope."""
|
||||
# Unwrap nested envelope if present
|
||||
envelope_data = envelope.get("envelope", envelope)
|
||||
|
||||
# Filter syncMessage envelopes (sent transcripts, read receipts, etc.)
|
||||
# signal-cli may set syncMessage to null vs omitting it, so check key existence
|
||||
if "syncMessage" in envelope_data:
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source")
|
||||
sender_name = envelope_data.get("sourceName", "")
|
||||
sender_uuid = envelope_data.get("sourceUuid", "")
|
||||
|
||||
if not sender:
|
||||
logger.debug("Signal: ignoring envelope with no sender")
|
||||
return
|
||||
|
||||
# Self-message filtering — prevent reply loops
|
||||
if self._account_normalized and sender == self._account_normalized:
|
||||
return
|
||||
|
||||
# Filter stories
|
||||
if self.ignore_stories and envelope_data.get("storyMessage"):
|
||||
return
|
||||
|
||||
# Get data message — also check editMessage (edited messages contain
|
||||
# their updated dataMessage inside editMessage.dataMessage)
|
||||
data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
if not data_message:
|
||||
return
|
||||
|
||||
# Check for group message
|
||||
group_info = data_message.get("groupInfo")
|
||||
group_id = group_info.get("groupId") if group_info else None
|
||||
is_group = bool(group_id)
|
||||
|
||||
# Group message filtering — derived from SIGNAL_GROUP_ALLOWED_USERS:
|
||||
# - No env var set → groups disabled (default safe behavior)
|
||||
# - Env var set with group IDs → only those groups allowed
|
||||
# - Env var set with "*" → all groups allowed
|
||||
# DM auth is fully handled by run.py (_is_user_authorized)
|
||||
if is_group:
|
||||
if not self.group_allow_from:
|
||||
logger.debug("Signal: ignoring group message (no SIGNAL_GROUP_ALLOWED_USERS)")
|
||||
return
|
||||
if "*" not in self.group_allow_from and group_id not in self.group_allow_from:
|
||||
logger.debug("Signal: group %s not in allowlist", group_id[:8] if group_id else "?")
|
||||
return
|
||||
|
||||
# Build chat info
|
||||
chat_id = sender if not is_group else f"group:{group_id}"
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Extract text and render mentions
|
||||
text = data_message.get("message", "")
|
||||
mentions = data_message.get("mentions", [])
|
||||
if text and mentions:
|
||||
text = _render_mentions(text, mentions)
|
||||
|
||||
# Process attachments
|
||||
attachments_data = data_message.get("attachments", [])
|
||||
image_paths = []
|
||||
audio_path = None
|
||||
document_paths = []
|
||||
|
||||
if attachments_data and not getattr(self, "ignore_attachments", False):
|
||||
for att in attachments_data:
|
||||
att_id = att.get("id")
|
||||
att_size = att.get("size", 0)
|
||||
if not att_id:
|
||||
continue
|
||||
if att_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
logger.warning("Signal: attachment too large (%d bytes), skipping", att_size)
|
||||
continue
|
||||
try:
|
||||
cached_path, ext = await self._fetch_attachment(att_id)
|
||||
if cached_path:
|
||||
if _is_image_ext(ext):
|
||||
image_paths.append(cached_path)
|
||||
elif _is_audio_ext(ext):
|
||||
audio_path = cached_path
|
||||
else:
|
||||
document_paths.append(cached_path)
|
||||
except Exception:
|
||||
logger.exception("Signal: failed to fetch attachment %s", att_id)
|
||||
|
||||
# Build session source
|
||||
source = self.build_source(
|
||||
chat_id=chat_id,
|
||||
chat_name=group_info.get("groupName") if group_info else sender_name,
|
||||
chat_type=chat_type,
|
||||
user_id=sender,
|
||||
user_name=sender_name or sender,
|
||||
user_id_alt=sender_uuid if sender_uuid else None,
|
||||
chat_id_alt=group_id if is_group else None,
|
||||
)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if audio_path:
|
||||
msg_type = MessageType.VOICE
|
||||
elif image_paths:
|
||||
msg_type = MessageType.IMAGE
|
||||
|
||||
# Parse timestamp from envelope data (milliseconds since epoch)
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
if ts_ms:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC)
|
||||
except (ValueError, OSError):
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
else:
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
|
||||
# Build and dispatch event
|
||||
event = MessageEvent(
|
||||
source=source,
|
||||
text=text or "",
|
||||
message_type=msg_type,
|
||||
image_paths=image_paths,
|
||||
audio_path=audio_path,
|
||||
document_paths=document_paths,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachment Handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _fetch_attachment(self, attachment_id: str) -> tuple:
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc(
|
||||
"getAttachment",
|
||||
{
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
},
|
||||
)
|
||||
|
||||
if not result:
|
||||
return None, ""
|
||||
|
||||
# Result is base64-encoded file content
|
||||
raw_data = base64.b64decode(result)
|
||||
ext = _guess_extension(raw_data)
|
||||
|
||||
if _is_image_ext(ext):
|
||||
path = cache_image_from_bytes(raw_data, ext)
|
||||
elif _is_audio_ext(ext):
|
||||
path = cache_audio_from_bytes(raw_data, ext)
|
||||
else:
|
||||
path = cache_document_from_bytes(raw_data, ext)
|
||||
|
||||
return path, ext
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# JSON-RPC Communication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _rpc(self, method: str, params: dict, rpc_id: str = None) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon."""
|
||||
if not self.client:
|
||||
logger.warning("Signal: RPC called but client not connected")
|
||||
return None
|
||||
|
||||
if rpc_id is None:
|
||||
rpc_id = f"{method}_{int(time.time() * 1000)}"
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
"id": rpc_id,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self.client.post(
|
||||
f"{self.http_url}/api/v1/rpc",
|
||||
json=payload,
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sending
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to_message_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a text message."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": text,
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
await self._rpc("sendTyping", params, rpc_id="typing")
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an image. Supports http(s):// and file:// URLs."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
# Resolve image to local path
|
||||
if image_url.startswith("file://"):
|
||||
file_path = unquote(image_url[7:])
|
||||
else:
|
||||
# Download remote image to cache
|
||||
try:
|
||||
file_path = await cache_image_from_url(image_url)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: failed to download image: %s", e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
if not file_path or not Path(file_path).exists():
|
||||
return SendResult(success=False, error="Image file not found")
|
||||
|
||||
# Validate size
|
||||
file_size = Path(file_path).stat().st_size
|
||||
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send with attachment failed")
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
filename: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
if not Path(file_path).exists():
|
||||
return SendResult(success=False, error="File not found")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send document failed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typing Indicators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _start_typing_indicator(self, chat_id: str) -> None:
|
||||
"""Start a typing indicator loop for a chat."""
|
||||
if chat_id in self._typing_tasks:
|
||||
return # Already running
|
||||
|
||||
async def _typing_loop():
|
||||
try:
|
||||
while True:
|
||||
await self.send_typing(chat_id)
|
||||
await asyncio.sleep(TYPING_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||||
|
||||
async def _stop_typing_indicator(self, chat_id: str) -> None:
|
||||
"""Stop a typing indicator loop for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a chat/contact."""
|
||||
if chat_id.startswith("group:"):
|
||||
return {
|
||||
"name": chat_id,
|
||||
"type": "group",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
|
||||
# Try to resolve contact name
|
||||
result = await self._rpc(
|
||||
"getContact",
|
||||
{
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
},
|
||||
)
|
||||
|
||||
name = chat_id
|
||||
if result and isinstance(result, dict):
|
||||
name = result.get("name") or result.get("profileName") or chat_id
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": "dm",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
568
gateway/platforms/slack.py
Normal file
568
gateway/platforms/slack.py
Normal file
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Slack platform adapter.
|
||||
|
||||
Uses slack-bolt (Python) with Socket Mode for:
|
||||
- Receiving messages from channels and DMs
|
||||
- Sending responses back
|
||||
- Handling slash commands
|
||||
- Thread support
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
SLACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
SLACK_AVAILABLE = False
|
||||
AsyncApp = Any
|
||||
AsyncSocketModeHandler = Any
|
||||
AsyncWebClient = Any
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
def check_slack_requirements() -> bool:
|
||||
"""Check if Slack dependencies are available."""
|
||||
return SLACK_AVAILABLE
|
||||
|
||||
|
||||
class SlackAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Slack bot adapter using Socket Mode.
|
||||
|
||||
Requires two tokens:
|
||||
- SLACK_BOT_TOKEN (xoxb-...) for API calls
|
||||
- SLACK_APP_TOKEN (xapp-...) for Socket Mode connection
|
||||
|
||||
Features:
|
||||
- DMs and channel messages (mention-gated in channels)
|
||||
- Thread support
|
||||
- File/image/audio attachments
|
||||
- Slash commands (/hermes)
|
||||
- Typing indicators (not natively supported by Slack bots)
|
||||
"""
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4000 # Slack's limit is higher but mrkdwn can inflate
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SLACK)
|
||||
self._app: AsyncApp | None = None
|
||||
self._handler: AsyncSocketModeHandler | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
if not SLACK_AVAILABLE:
|
||||
print("[Slack] slack-bolt not installed. Run: pip install slack-bolt")
|
||||
return False
|
||||
|
||||
bot_token = self.config.token
|
||||
app_token = os.getenv("SLACK_APP_TOKEN")
|
||||
|
||||
if not bot_token:
|
||||
print("[Slack] SLACK_BOT_TOKEN not set")
|
||||
return False
|
||||
if not app_token:
|
||||
print("[Slack] SLACK_APP_TOKEN not set")
|
||||
return False
|
||||
|
||||
try:
|
||||
self._app = AsyncApp(token=bot_token)
|
||||
|
||||
# Get our own bot user ID for mention detection
|
||||
auth_response = await self._app.client.auth_test()
|
||||
self._bot_user_id = auth_response.get("user_id")
|
||||
bot_name = auth_response.get("user", "unknown")
|
||||
|
||||
# Register message event handler
|
||||
@self._app.event("message")
|
||||
async def handle_message_event(event, say):
|
||||
await self._handle_slack_message(event)
|
||||
|
||||
# Acknowledge app_mention events to prevent Bolt 404 errors.
|
||||
# The "message" handler above already processes @mentions in
|
||||
# channels, so this is intentionally a no-op to avoid duplicates.
|
||||
@self._app.event("app_mention")
|
||||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
await ack()
|
||||
await self._handle_slash_command(command)
|
||||
|
||||
# Start Socket Mode handler in background
|
||||
self._handler = AsyncSocketModeHandler(self._app, app_token)
|
||||
asyncio.create_task(self._handler.start_async())
|
||||
|
||||
self._running = True
|
||||
print(f"[Slack] Connected as @{bot_name} (Socket Mode)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Slack] Connection failed: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Slack."""
|
||||
if self._handler:
|
||||
await self._handler.close_async()
|
||||
self._running = False
|
||||
print("[Slack] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a message to a Slack channel or DM."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
kwargs = {
|
||||
"channel": chat_id,
|
||||
"text": content,
|
||||
}
|
||||
|
||||
# Reply in thread if thread_ts is available
|
||||
if reply_to:
|
||||
kwargs["thread_ts"] = reply_to
|
||||
elif metadata and metadata.get("thread_ts"):
|
||||
kwargs["thread_ts"] = metadata["thread_ts"]
|
||||
|
||||
result = await self._app.client.chat_postMessage(**kwargs)
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=result.get("ts"),
|
||||
raw_response=result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Slack] Send error: {e}")
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Slack message."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
await self._app.client.chat_update(
|
||||
channel=chat_id,
|
||||
ts=message_id,
|
||||
text=content,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Slack doesn't have a direct typing indicator API for bots."""
|
||||
pass
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file to Slack by uploading it."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=image_path,
|
||||
filename=os.path.basename(image_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image to Slack by uploading the URL as a file."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
# Download the image first
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
content=response.content,
|
||||
filename="image.png",
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception:
|
||||
# Fall back to sending the URL as text
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=audio_path,
|
||||
filename=os.path.basename(audio_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video file to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=video_path,
|
||||
filename=os.path.basename(video_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send video: {e}")
|
||||
return await super().send_video(chat_id, video_path, caption, reply_to)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=file_path,
|
||||
filename=display_name,
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send document: {e}")
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Slack channel."""
|
||||
if not self._app:
|
||||
return {"name": chat_id, "type": "unknown"}
|
||||
|
||||
try:
|
||||
result = await self._app.client.conversations_info(channel=chat_id)
|
||||
channel = result.get("channel", {})
|
||||
is_dm = channel.get("is_im", False)
|
||||
return {
|
||||
"name": channel.get("name", chat_id),
|
||||
"type": "dm" if is_dm else "group",
|
||||
}
|
||||
except Exception:
|
||||
return {"name": chat_id, "type": "unknown"}
|
||||
|
||||
# ----- Internal handlers -----
|
||||
|
||||
async def _handle_slack_message(self, event: dict) -> None:
|
||||
"""Handle an incoming Slack message event."""
|
||||
# Ignore bot messages (including our own)
|
||||
if event.get("bot_id") or event.get("subtype") == "bot_message":
|
||||
return
|
||||
|
||||
# Ignore message edits and deletions
|
||||
subtype = event.get("subtype")
|
||||
if subtype in ("message_changed", "message_deleted"):
|
||||
return
|
||||
|
||||
text = event.get("text", "")
|
||||
user_id = event.get("user", "")
|
||||
channel_id = event.get("channel", "")
|
||||
thread_ts = event.get("thread_ts") or event.get("ts")
|
||||
ts = event.get("ts", "")
|
||||
|
||||
# Determine if this is a DM or channel message
|
||||
channel_type = event.get("channel_type", "")
|
||||
is_dm = channel_type == "im"
|
||||
|
||||
# In channels, only respond if bot is mentioned
|
||||
if not is_dm and self._bot_user_id:
|
||||
if f"<@{self._bot_user_id}>" not in text:
|
||||
return
|
||||
# Strip the bot mention from the text
|
||||
text = text.replace(f"<@{self._bot_user_id}>", "").strip()
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if text.startswith("/"):
|
||||
msg_type = MessageType.COMMAND
|
||||
|
||||
# Handle file attachments
|
||||
media_urls = []
|
||||
media_types = []
|
||||
files = event.get("files", [])
|
||||
for f in files:
|
||||
mimetype = f.get("mimetype", "unknown")
|
||||
url = f.get("url_private_download") or f.get("url_private", "")
|
||||
if mimetype.startswith("image/") and url:
|
||||
try:
|
||||
ext = "." + mimetype.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
# Slack private URLs require the bot token as auth header
|
||||
cached = await self._download_slack_file(url, ext)
|
||||
media_urls.append(cached)
|
||||
media_types.append(mimetype)
|
||||
msg_type = MessageType.PHOTO
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache image: {e}", flush=True)
|
||||
elif mimetype.startswith("audio/") and url:
|
||||
try:
|
||||
ext = "." + mimetype.split("/")[-1].split(";")[0]
|
||||
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
|
||||
ext = ".ogg"
|
||||
cached = await self._download_slack_file(url, ext, audio=True)
|
||||
media_urls.append(cached)
|
||||
media_types.append(mimetype)
|
||||
msg_type = MessageType.VOICE
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache audio: {e}", flush=True)
|
||||
elif url:
|
||||
# Try to handle as a document attachment
|
||||
try:
|
||||
original_filename = f.get("name", "")
|
||||
ext = ""
|
||||
if original_filename:
|
||||
_, ext = os.path.splitext(original_filename)
|
||||
ext = ext.lower()
|
||||
|
||||
# Fallback: reverse-lookup from MIME type
|
||||
if not ext and mimetype:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(mimetype, "")
|
||||
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
continue # Skip unsupported file types silently
|
||||
|
||||
# Check file size (Slack limit: 20 MB for bots)
|
||||
file_size = f.get("size", 0)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not file_size or file_size > MAX_DOC_BYTES:
|
||||
print(f"[Slack] Document too large or unknown size: {file_size}", flush=True)
|
||||
continue
|
||||
|
||||
# Download and cache
|
||||
raw_bytes = await self._download_slack_file_bytes(url)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
msg_type = MessageType.DOCUMENT
|
||||
print(f"[Slack] Cached user document: {cached_path}", flush=True)
|
||||
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if text:
|
||||
text = f"{injection}\n\n{text}"
|
||||
else:
|
||||
text = injection
|
||||
except UnicodeDecodeError:
|
||||
pass # Binary content, skip injection
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache document: {e}", flush=True)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_name=channel_id, # Will be resolved later if needed
|
||||
chat_type="dm" if is_dm else "group",
|
||||
user_id=user_id,
|
||||
thread_id=thread_ts,
|
||||
)
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=event,
|
||||
message_id=ts,
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
reply_to_message_id=thread_ts if thread_ts != ts else None,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
async def _handle_slash_command(self, command: dict) -> None:
|
||||
"""Handle /hermes slash command."""
|
||||
text = command.get("text", "").strip()
|
||||
user_id = command.get("user_id", "")
|
||||
channel_id = command.get("channel_id", "")
|
||||
|
||||
# Map subcommands to gateway commands
|
||||
subcommand_map = {
|
||||
"new": "/reset",
|
||||
"reset": "/reset",
|
||||
"status": "/status",
|
||||
"stop": "/stop",
|
||||
"help": "/help",
|
||||
"model": "/model",
|
||||
"personality": "/personality",
|
||||
"retry": "/retry",
|
||||
"undo": "/undo",
|
||||
}
|
||||
first_word = text.split()[0] if text else ""
|
||||
if first_word in subcommand_map:
|
||||
# Preserve arguments after the subcommand
|
||||
rest = text[len(first_word) :].strip()
|
||||
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
|
||||
elif text:
|
||||
pass # Treat as a regular question
|
||||
else:
|
||||
text = "/help"
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_type="dm", # Slash commands are always in DM-like context
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.COMMAND if text.startswith("/") else MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=command,
|
||||
)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str:
|
||||
"""Download a Slack file using the bot token for auth."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
835
gateway/platforms/telegram.py
Normal file
835
gateway/platforms/telegram.py
Normal file
@@ -0,0 +1,835 @@
|
||||
"""
|
||||
Telegram platform adapter.
|
||||
|
||||
Uses python-telegram-bot library for:
|
||||
- Receiving messages from users/groups
|
||||
- Sending responses back
|
||||
- Handling media and commands
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Bot, Message, Update
|
||||
from telegram.constants import ChatType, ParseMode
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.ext import (
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
)
|
||||
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
Update = Any
|
||||
Bot = Any
|
||||
Message = Any
|
||||
Application = Any
|
||||
CommandHandler = Any
|
||||
TelegramMessageHandler = Any
|
||||
filters = None
|
||||
ParseMode = None
|
||||
ChatType = None
|
||||
|
||||
# Mock ContextTypes so type annotations using ContextTypes.DEFAULT_TYPE
|
||||
# don't crash during class definition when the library isn't installed.
|
||||
class _MockContextTypes:
|
||||
DEFAULT_TYPE = Any
|
||||
|
||||
ContextTypes = _MockContextTypes
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
def check_telegram_requirements() -> bool:
|
||||
"""Check if Telegram dependencies are available."""
|
||||
return TELEGRAM_AVAILABLE
|
||||
|
||||
|
||||
# Matches every character that MarkdownV2 requires to be backslash-escaped
|
||||
# when it appears outside a code span or fenced code block.
|
||||
_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])")
|
||||
|
||||
|
||||
def _escape_mdv2(text: str) -> str:
|
||||
"""Escape Telegram MarkdownV2 special characters with a preceding backslash."""
|
||||
return _MDV2_ESCAPE_RE.sub(r"\\\1", text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
|
||||
|
||||
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
||||
Handles:
|
||||
- Receiving messages from users and groups
|
||||
- Sending responses with Telegram markdown
|
||||
- Forum topics (thread_id support)
|
||||
- Media messages
|
||||
"""
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Application | None = None
|
||||
self._bot: Bot | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
|
||||
return False
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
# Register handlers
|
||||
self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message))
|
||||
self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command))
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message
|
||||
)
|
||||
)
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.PHOTO
|
||||
| filters.VIDEO
|
||||
| filters.AUDIO
|
||||
| filters.VOICE
|
||||
| filters.Document.ALL
|
||||
| filters.Sticker.ALL,
|
||||
self._handle_media_message,
|
||||
)
|
||||
)
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
|
||||
await self._bot.set_my_commands(
|
||||
[
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Could not register command menu: {e}")
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected and polling for updates")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
if self._app:
|
||||
try:
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
self._running = False
|
||||
self._app = None
|
||||
self._bot = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Telegram chat."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning(
|
||||
"[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error
|
||||
)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Telegram message."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
formatted = self.format_message(content)
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id),
|
||||
message_id=int(message_id),
|
||||
text=formatted,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
)
|
||||
except Exception:
|
||||
# Fallback: retry without markdown formatting
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id),
|
||||
message_id=int(message_id),
|
||||
text=content,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
msg = await self._bot.send_voice(
|
||||
chat_id=int(chat_id),
|
||||
voice=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
else:
|
||||
# .mp3 and others -> send as audio file
|
||||
msg = await self._bot.send_audio(
|
||||
chat_id=int(chat_id),
|
||||
audio=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_url,
|
||||
caption=caption[:1024] if caption else None, # Telegram caption limit
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
logger.warning("[%s] URL-based send_photo failed (%s), trying file upload", self.name, e)
|
||||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e2:
|
||||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an animated GIF natively as a Telegram animation (auto-plays inline)."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
msg = await self._bot.send_animation(
|
||||
chat_id=int(chat_id),
|
||||
animation=animation_url,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send animation, falling back to photo: {e}")
|
||||
# Fallback: try as a regular photo
|
||||
return await self.send_image(chat_id, animation_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
try:
|
||||
await self._bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Telegram chat."""
|
||||
if not self._bot:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
chat = await self._bot.get_chat(int(chat_id))
|
||||
|
||||
chat_type = "dm"
|
||||
if chat.type == ChatType.GROUP:
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.SUPERGROUP:
|
||||
chat_type = "group"
|
||||
if chat.is_forum:
|
||||
chat_type = "forum"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
return {
|
||||
"name": chat.title or chat.full_name or str(chat_id),
|
||||
"type": chat_type,
|
||||
"username": chat.username,
|
||||
"is_forum": getattr(chat, "is_forum", False),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Convert standard markdown to Telegram MarkdownV2 format.
|
||||
|
||||
Protected regions (code blocks, inline code) are extracted first so
|
||||
their contents are never modified. Standard markdown constructs
|
||||
(headers, bold, italic, links) are translated to MarkdownV2 syntax,
|
||||
and all remaining special characters are escaped.
|
||||
"""
|
||||
if not content:
|
||||
return content
|
||||
|
||||
placeholders: dict = {}
|
||||
counter = [0]
|
||||
|
||||
def _ph(value: str) -> str:
|
||||
"""Stash *value* behind a placeholder token that survives escaping."""
|
||||
key = f"\x00PH{counter[0]}\x00"
|
||||
counter[0] += 1
|
||||
placeholders[key] = value
|
||||
return key
|
||||
|
||||
text = content
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
text = re.sub(
|
||||
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
||||
lambda m: _ph(m.group(0)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 2) Protect inline code (`...`)
|
||||
text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links – escape the display text; inside the URL
|
||||
# only ')' and '\' need escaping per the MarkdownV2 spec.
|
||||
def _convert_link(m):
|
||||
display = _escape_mdv2(m.group(1))
|
||||
url = m.group(2).replace("\\", "\\\\").replace(")", "\\)")
|
||||
return _ph(f"[{display}]({url})")
|
||||
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text)
|
||||
|
||||
# 4) Convert markdown headers (## Title) → bold *Title*
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers that may appear inside a header
|
||||
inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner)
|
||||
return _ph(f"*{_escape_mdv2(inner)}*")
|
||||
|
||||
text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE)
|
||||
|
||||
# 5) Convert bold: **text** → *text* (MarkdownV2 bold)
|
||||
text = re.sub(
|
||||
r"\*\*(.+?)\*\*",
|
||||
lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"),
|
||||
text,
|
||||
)
|
||||
|
||||
# 6) Convert italic: *text* (single asterisk) → _text_ (MarkdownV2 italic)
|
||||
# [^*\n]+ prevents matching across newlines (which would corrupt
|
||||
# bullet lists using * markers and multi-line content).
|
||||
text = re.sub(
|
||||
r"\*([^*\n]+)\*",
|
||||
lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"),
|
||||
text,
|
||||
)
|
||||
|
||||
# 7) Escape remaining special characters in plain text
|
||||
text = _escape_mdv2(text)
|
||||
|
||||
# 8) Restore placeholders in reverse insertion order so that
|
||||
# nested references (a placeholder inside another) resolve correctly.
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming location/venue pin messages."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
msg = update.message
|
||||
venue = getattr(msg, "venue", None)
|
||||
location = getattr(venue, "location", None) if venue else getattr(msg, "location", None)
|
||||
|
||||
if not location:
|
||||
return
|
||||
|
||||
lat = getattr(location, "latitude", None)
|
||||
lon = getattr(location, "longitude", None)
|
||||
if lat is None or lon is None:
|
||||
return
|
||||
|
||||
# Build a text message with coordinates and context
|
||||
parts = ["[The user shared a location pin.]"]
|
||||
if venue:
|
||||
title = getattr(venue, "title", None)
|
||||
address = getattr(venue, "address", None)
|
||||
if title:
|
||||
parts.append(f"Venue: {title}")
|
||||
if address:
|
||||
parts.append(f"Address: {address}")
|
||||
parts.append(f"latitude: {lat}")
|
||||
parts.append(f"longitude: {lon}")
|
||||
parts.append(f"Map: https://www.google.com/maps/search/?api=1&query={lat},{lon}")
|
||||
parts.append("Ask what they'd like to find nearby (restaurants, cafes, etc.) and any preferences.")
|
||||
|
||||
event = self._build_message_event(msg, MessageType.LOCATION)
|
||||
event.text = "\n".join(parts)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
msg = update.message
|
||||
|
||||
# Determine media type
|
||||
if msg.sticker:
|
||||
msg_type = MessageType.STICKER
|
||||
elif msg.photo:
|
||||
msg_type = MessageType.PHOTO
|
||||
elif msg.video:
|
||||
msg_type = MessageType.VIDEO
|
||||
elif msg.audio:
|
||||
msg_type = MessageType.AUDIO
|
||||
elif msg.voice:
|
||||
msg_type = MessageType.VOICE
|
||||
elif msg.document:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
# Handle stickers: describe via vision tool with caching
|
||||
if msg.sticker:
|
||||
await self._handle_sticker(msg, event)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
try:
|
||||
# msg.photo is a list of PhotoSize sorted by size; take the largest
|
||||
photo = msg.photo[-1]
|
||||
file_obj = await photo.get_file()
|
||||
# Download the image bytes directly into memory
|
||||
image_bytes = await file_obj.download_as_bytearray()
|
||||
# Determine extension from the file path if available
|
||||
ext = ".jpg"
|
||||
if file_obj.file_path:
|
||||
for candidate in [".png", ".webp", ".gif", ".jpeg", ".jpg"]:
|
||||
if file_obj.file_path.lower().endswith(candidate):
|
||||
ext = candidate
|
||||
break
|
||||
# Save to cache and populate media_urls with the local path
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=ext)
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}"]
|
||||
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
file_obj = await msg.voice.get_file()
|
||||
audio_bytes = await file_obj.download_as_bytearray()
|
||||
cached_path = cache_audio_from_bytes(bytes(audio_bytes), ext=".ogg")
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = ["audio/ogg"]
|
||||
print(f"[Telegram] Cached user voice: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache voice: {e}", flush=True)
|
||||
elif msg.audio:
|
||||
try:
|
||||
file_obj = await msg.audio.get_file()
|
||||
audio_bytes = await file_obj.download_as_bytearray()
|
||||
cached_path = cache_audio_from_bytes(bytes(audio_bytes), ext=".mp3")
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = ["audio/mp3"]
|
||||
print(f"[Telegram] Cached user audio: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache audio: {e}", flush=True)
|
||||
|
||||
# Download document files to cache for agent processing
|
||||
elif msg.document:
|
||||
doc = msg.document
|
||||
try:
|
||||
# Determine file extension
|
||||
ext = ""
|
||||
original_filename = doc.file_name or ""
|
||||
if original_filename:
|
||||
_, ext = os.path.splitext(original_filename)
|
||||
ext = ext.lower()
|
||||
|
||||
# If no extension from filename, reverse-lookup from MIME type
|
||||
if not ext and doc.mime_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(doc.mime_type, "")
|
||||
|
||||
# Check if supported
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
|
||||
event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}"
|
||||
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
# Check file size (Telegram Bot API limit: 20 MB)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
|
||||
event.text = "The document is too large or its size could not be verified. Maximum: 20 MB."
|
||||
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
# Download and cache
|
||||
file_obj = await doc.get_file()
|
||||
doc_bytes = await file_obj.download_as_bytearray()
|
||||
raw_bytes = bytes(doc_bytes)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
mime_type = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [mime_type]
|
||||
print(f"[Telegram] Cached user document: {cached_path}", flush=True)
|
||||
|
||||
# For text files, inject content into event.text (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if event.text:
|
||||
event.text = f"{injection}\n\n{event.text}"
|
||||
else:
|
||||
event.text = injection
|
||||
except UnicodeDecodeError:
|
||||
print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache document: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
"""
|
||||
Describe a Telegram sticker via vision analysis, with caching.
|
||||
|
||||
For static stickers (WEBP), we download, analyze with vision, and cache
|
||||
the description by file_unique_id. For animated/video stickers, we inject
|
||||
a placeholder noting the emoji.
|
||||
"""
|
||||
from gateway.sticker_cache import (
|
||||
STICKER_VISION_PROMPT,
|
||||
build_animated_sticker_injection,
|
||||
build_sticker_injection,
|
||||
cache_sticker_description,
|
||||
get_cached_description,
|
||||
)
|
||||
|
||||
sticker = msg.sticker
|
||||
emoji = sticker.emoji or ""
|
||||
set_name = sticker.set_name or ""
|
||||
|
||||
# Animated and video stickers can't be analyzed as static images
|
||||
if sticker.is_animated or sticker.is_video:
|
||||
event.text = build_animated_sticker_injection(emoji)
|
||||
return
|
||||
|
||||
# Check the cache first
|
||||
cached = get_cached_description(sticker.file_unique_id)
|
||||
if cached:
|
||||
event.text = build_sticker_injection(
|
||||
cached["description"], cached.get("emoji", emoji), cached.get("set_name", set_name)
|
||||
)
|
||||
print(f"[Telegram] Sticker cache hit: {sticker.file_unique_id}", flush=True)
|
||||
return
|
||||
|
||||
# Cache miss -- download and analyze
|
||||
try:
|
||||
file_obj = await sticker.get_file()
|
||||
image_bytes = await file_obj.download_as_bytearray()
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp")
|
||||
print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True)
|
||||
|
||||
import json as _json
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
|
||||
result_json = await vision_analyze_tool(
|
||||
image_url=cached_path,
|
||||
user_prompt=STICKER_VISION_PROMPT,
|
||||
)
|
||||
result = _json.loads(result_json)
|
||||
|
||||
if result.get("success"):
|
||||
description = result.get("analysis", "a sticker")
|
||||
cache_sticker_description(sticker.file_unique_id, description, emoji, set_name)
|
||||
event.text = build_sticker_injection(description, emoji, set_name)
|
||||
else:
|
||||
# Vision failed -- use emoji as fallback
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Sticker analysis error: {e}", flush=True)
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
# Determine chat type
|
||||
chat_type = "dm"
|
||||
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
chat_name=chat.title or (chat.full_name if hasattr(chat, "full_name") else None),
|
||||
chat_type=chat_type,
|
||||
user_id=str(user.id) if user else None,
|
||||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.message_id),
|
||||
timestamp=message.date,
|
||||
)
|
||||
642
gateway/platforms/whatsapp.py
Normal file
642
gateway/platforms/whatsapp.py
Normal file
@@ -0,0 +1,642 @@
|
||||
"""
|
||||
WhatsApp platform adapter.
|
||||
|
||||
WhatsApp integration is more complex than Telegram/Discord because:
|
||||
- No official bot API for personal accounts
|
||||
- Business API requires Meta Business verification
|
||||
- Most solutions use web-based automation
|
||||
|
||||
This adapter supports multiple backends:
|
||||
1. WhatsApp Business API (requires Meta verification)
|
||||
2. whatsapp-web.js (via Node.js subprocess) - for personal accounts
|
||||
3. Baileys (via Node.js subprocess) - alternative for personal accounts
|
||||
|
||||
For simplicity, we'll implement a generic interface that can work
|
||||
with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _kill_port_process(port: int) -> None:
|
||||
"""Kill any process listening on the given TCP port."""
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
# Use netstat to find the PID bound to this port, then taskkill
|
||||
result = subprocess.run(
|
||||
["netstat", "-ano", "-p", "TCP"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
if len(parts) >= 5 and parts[3] == "LISTENING":
|
||||
local_addr = parts[1]
|
||||
if local_addr.endswith(f":{port}"):
|
||||
try:
|
||||
subprocess.run(
|
||||
["taskkill", "/PID", parts[4], "/F"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except subprocess.SubprocessError:
|
||||
pass
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["fuser", f"{port}/tcp"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
subprocess.run(
|
||||
["fuser", "-k", f"{port}/tcp"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
def check_whatsapp_requirements() -> bool:
|
||||
"""
|
||||
Check if WhatsApp dependencies are available.
|
||||
|
||||
WhatsApp requires a Node.js bridge for most implementations.
|
||||
"""
|
||||
# Check for Node.js
|
||||
try:
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
WhatsApp adapter.
|
||||
|
||||
This implementation uses a simple HTTP bridge pattern where:
|
||||
1. A Node.js process runs the WhatsApp Web client
|
||||
2. Messages are forwarded via HTTP/IPC to this Python adapter
|
||||
3. Responses are sent back through the bridge
|
||||
|
||||
The actual Node.js bridge implementation can vary:
|
||||
- whatsapp-web.js based
|
||||
- Baileys based
|
||||
- Business API based
|
||||
|
||||
Configuration:
|
||||
- bridge_script: Path to the Node.js bridge script
|
||||
- bridge_port: Port for HTTP communication (default: 3000)
|
||||
- session_path: Path to store WhatsApp session data
|
||||
"""
|
||||
|
||||
# WhatsApp message limits
|
||||
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
|
||||
|
||||
# Default bridge location relative to the hermes-agent install
|
||||
_DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge"
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WHATSAPP)
|
||||
self._bridge_process: subprocess.Popen | None = None
|
||||
self._bridge_port: int = config.extra.get("bridge_port", 3000)
|
||||
self._bridge_script: str | None = config.extra.get(
|
||||
"bridge_script",
|
||||
str(self._DEFAULT_BRIDGE_DIR / "bridge.js"),
|
||||
)
|
||||
self._session_path: Path = Path(
|
||||
config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session")
|
||||
)
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Path | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Start the WhatsApp bridge.
|
||||
|
||||
This launches the Node.js bridge process and waits for it to be ready.
|
||||
"""
|
||||
if not check_whatsapp_requirements():
|
||||
logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name)
|
||||
return False
|
||||
|
||||
bridge_path = Path(self._bridge_script)
|
||||
if not bridge_path.exists():
|
||||
logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path)
|
||||
return False
|
||||
|
||||
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
|
||||
|
||||
# Auto-install npm dependencies if node_modules doesn't exist
|
||||
bridge_dir = bridge_path.parent
|
||||
if not (bridge_dir / "node_modules").exists():
|
||||
print(f"[{self.name}] Installing WhatsApp bridge dependencies...")
|
||||
try:
|
||||
install_result = subprocess.run(
|
||||
["npm", "install", "--silent"],
|
||||
cwd=str(bridge_dir),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
if install_result.returncode != 0:
|
||||
print(f"[{self.name}] npm install failed: {install_result.stderr}")
|
||||
return False
|
||||
print(f"[{self.name}] Dependencies installed")
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to install dependencies: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
# Route output to a log file so QR codes, errors, and reconnection
|
||||
# messages are preserved for troubleshooting.
|
||||
whatsapp_mode = os.getenv("WHATSAPP_MODE", "self-chat")
|
||||
self._bridge_log = self._session_path.parent / "bridge.log"
|
||||
bridge_log_fh = open(self._bridge_log, "a")
|
||||
self._bridge_log_fh = bridge_log_fh
|
||||
self._bridge_process = subprocess.Popen(
|
||||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port",
|
||||
str(self._bridge_port),
|
||||
"--session",
|
||||
str(self._session_path),
|
||||
"--mode",
|
||||
whatsapp_mode,
|
||||
],
|
||||
stdout=bridge_log_fh,
|
||||
stderr=bridge_log_fh,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
# Wait for the bridge to connect to WhatsApp.
|
||||
# Phase 1: wait for the HTTP server to come up (up to 15s).
|
||||
# Phase 2: wait for WhatsApp status: connected (up to 15s more).
|
||||
import aiohttp
|
||||
|
||||
http_ready = False
|
||||
data = {}
|
||||
for attempt in range(15):
|
||||
await asyncio.sleep(1)
|
||||
if self._bridge_process.poll() is not None:
|
||||
print(f"[{self.name}] Bridge process died (exit code {self._bridge_process.returncode})")
|
||||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not http_ready:
|
||||
print(f"[{self.name}] Bridge HTTP server did not start in 15s")
|
||||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
# Phase 2: HTTP is up but WhatsApp may still be connecting.
|
||||
# Give it more time to authenticate with saved credentials.
|
||||
if data.get("status") != "connected":
|
||||
print(f"[{self.name}] Bridge HTTP ready, waiting for WhatsApp connection...")
|
||||
for attempt in range(15):
|
||||
await asyncio.sleep(1)
|
||||
if self._bridge_process.poll() is not None:
|
||||
print(f"[{self.name}] Bridge process died during connection")
|
||||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
# Still not connected — warn but proceed (bridge may
|
||||
# auto-reconnect later, e.g. after a code 515 restart).
|
||||
print(f"[{self.name}] ⚠ WhatsApp not connected after 30s")
|
||||
print(f"[{self.name}] Bridge log: {self._bridge_log}")
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
def _close_bridge_log(self) -> None:
|
||||
"""Close the bridge log file handle if open."""
|
||||
if self._bridge_log_fh:
|
||||
try:
|
||||
self._bridge_log_fh.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._bridge_log_fh = None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop the WhatsApp bridge and clean up any orphaned processes."""
|
||||
if self._bridge_process:
|
||||
try:
|
||||
# Kill the entire process group so child node processes die too
|
||||
import signal
|
||||
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
self._bridge_process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
if self._bridge_process.poll() is None:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.kill()
|
||||
else:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
"message": content,
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(success=True, message_id=data.get("messageId"), raw_response=data)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except ImportError:
|
||||
return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _send_media_to_bridge(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
media_type: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"chatId": chat_id,
|
||||
"filePath": file_path,
|
||||
"mediaType": media_type,
|
||||
}
|
||||
if caption:
|
||||
payload["caption"] = caption
|
||||
if file_name:
|
||||
payload["fileName"] = file_name
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Download image URL to cache, send natively via bridge."""
|
||||
try:
|
||||
local_path = await cache_image_from_url(image_url)
|
||||
return await self._send_media_to_bridge(chat_id, local_path, "image", caption)
|
||||
except Exception:
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively via bridge."""
|
||||
return await self._send_media_to_bridge(chat_id, image_path, "image", caption)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video natively via bridge — plays inline in WhatsApp."""
|
||||
return await self._send_media_to_bridge(chat_id, video_path, "video", caption)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file as a downloadable attachment via bridge."""
|
||||
return await self._send_media_to_bridge(
|
||||
chat_id,
|
||||
file_path,
|
||||
"document",
|
||||
caption,
|
||||
file_name or os.path.basename(file_path),
|
||||
)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
|
||||
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if data.get("hasMedia"):
|
||||
media_type = data.get("mediaType", "")
|
||||
if "image" in media_type:
|
||||
msg_type = MessageType.PHOTO
|
||||
elif "video" in media_type:
|
||||
msg_type = MessageType.VIDEO
|
||||
elif "audio" in media_type or "ptt" in media_type: # ptt = voice note
|
||||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
# Determine chat type
|
||||
is_group = data.get("isGroup", False)
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=data.get("chatId", ""),
|
||||
chat_name=data.get("chatName"),
|
||||
chat_type=chat_type,
|
||||
user_id=data.get("senderId"),
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
cached_urls = []
|
||||
media_types = []
|
||||
for url in raw_urls:
|
||||
if msg_type == MessageType.PHOTO and url.startswith(("http://", "https://")):
|
||||
try:
|
||||
cached_path = await cache_image_from_url(url, ext=".jpg")
|
||||
cached_urls.append(cached_path)
|
||||
media_types.append("image/jpeg")
|
||||
print(f"[{self.name}] Cached user image: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to cache image: {e}", flush=True)
|
||||
cached_urls.append(url)
|
||||
media_types.append("image/jpeg")
|
||||
elif msg_type == MessageType.VOICE and url.startswith(("http://", "https://")):
|
||||
try:
|
||||
cached_path = await cache_audio_from_url(url, ext=".ogg")
|
||||
cached_urls.append(cached_path)
|
||||
media_types.append("audio/ogg")
|
||||
print(f"[{self.name}] Cached user voice: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to cache voice: {e}", flush=True)
|
||||
cached_urls.append(url)
|
||||
media_types.append("audio/ogg")
|
||||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=data,
|
||||
message_id=data.get("messageId"),
|
||||
media_urls=cached_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
3198
gateway/run.py
Normal file
3198
gateway/run.py
Normal file
File diff suppressed because it is too large
Load Diff
751
gateway/session.py
Normal file
751
gateway/session.py
Normal file
@@ -0,0 +1,751 @@
|
||||
"""
|
||||
Session management for the gateway.
|
||||
|
||||
Handles:
|
||||
- Session context tracking (where messages come from)
|
||||
- Session storage (conversations persisted to disk)
|
||||
- Reset policy evaluation (when to start fresh)
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionSource:
|
||||
"""
|
||||
Describes where a message originated from.
|
||||
|
||||
This information is used to:
|
||||
1. Route responses back to the right place
|
||||
2. Inject context into the system prompt
|
||||
3. Track origin for cron job delivery
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
chat_name: str | None = None
|
||||
chat_type: str = "dm" # "dm", "group", "channel", "thread"
|
||||
user_id: str | None = None
|
||||
user_name: str | None = None
|
||||
thread_id: str | None = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: str | None = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: str | None = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: str | None = None # Signal group internal ID
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Human-readable description of the source."""
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "CLI terminal"
|
||||
|
||||
parts = []
|
||||
if self.chat_type == "dm":
|
||||
parts.append(f"DM with {self.user_name or self.user_id or 'user'}")
|
||||
elif self.chat_type == "group":
|
||||
parts.append(f"group: {self.chat_name or self.chat_id}")
|
||||
elif self.chat_type == "channel":
|
||||
parts.append(f"channel: {self.chat_name or self.chat_id}")
|
||||
else:
|
||||
parts.append(self.chat_name or self.chat_id)
|
||||
|
||||
if self.thread_id:
|
||||
parts.append(f"thread: {self.thread_id}")
|
||||
|
||||
return ", ".join(parts)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
d = {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"chat_name": self.chat_name,
|
||||
"chat_type": self.chat_type,
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"thread_id": self.thread_id,
|
||||
"chat_topic": self.chat_topic,
|
||||
}
|
||||
if self.user_id_alt:
|
||||
d["user_id_alt"] = self.user_id_alt
|
||||
if self.chat_id_alt:
|
||||
d["chat_id_alt"] = self.chat_id_alt
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionSource":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
chat_name=data.get("chat_name"),
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
user_id=data.get("user_id"),
|
||||
user_name=data.get("user_name"),
|
||||
thread_id=data.get("thread_id"),
|
||||
chat_topic=data.get("chat_topic"),
|
||||
user_id_alt=data.get("user_id_alt"),
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
return cls(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI terminal",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""
|
||||
Full context for a session, used for dynamic system prompt injection.
|
||||
|
||||
The agent receives this information to understand:
|
||||
- Where messages are coming from
|
||||
- What platforms are available
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
|
||||
source: SessionSource
|
||||
connected_platforms: list[Platform]
|
||||
home_channels: dict[Platform, HomeChannel]
|
||||
|
||||
# Session metadata
|
||||
session_key: str = ""
|
||||
session_id: str = ""
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"source": self.source.to_dict(),
|
||||
"connected_platforms": [p.value for p in self.connected_platforms],
|
||||
"home_channels": {p.value: hc.to_dict() for p, hc in self.home_channels.items()},
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
This is injected into the system prompt so the agent knows:
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
lines = [
|
||||
"## Current Session Context",
|
||||
"",
|
||||
]
|
||||
|
||||
# Source info
|
||||
platform_name = context.source.platform.value.title()
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
lines.append(f"**Channel Topic:** {context.source.chat_topic}")
|
||||
|
||||
# User identity (especially useful for WhatsApp where multiple people DM)
|
||||
if context.source.user_name:
|
||||
lines.append(f"**User:** {context.source.user_name}")
|
||||
elif context.source.user_id:
|
||||
lines.append(f"**User ID:** {context.source.user_id}")
|
||||
|
||||
# Connected platforms
|
||||
platforms_list = ["local (files on this machine)"]
|
||||
for p in context.connected_platforms:
|
||||
if p != Platform.LOCAL:
|
||||
platforms_list.append(f"{p.value}: Connected ✓")
|
||||
|
||||
lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}")
|
||||
|
||||
# Home channels
|
||||
if context.home_channels:
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
lines.append("**Delivery options for scheduled tasks:**")
|
||||
|
||||
# Origin delivery
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append('- `"origin"` → Local output (saved to files)')
|
||||
else:
|
||||
lines.append(f'- `"origin"` → Back to this chat ({context.source.chat_name or context.source.chat_id})')
|
||||
|
||||
# Local always available
|
||||
lines.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)')
|
||||
|
||||
# Platform home channels
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f'- `"{platform.value}"` → Home channel ({home.name})')
|
||||
|
||||
# Note about explicit targeting
|
||||
lines.append("")
|
||||
lines.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*')
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionEntry:
|
||||
"""
|
||||
Entry in the session store.
|
||||
|
||||
Maps a session key to its current session ID and metadata.
|
||||
"""
|
||||
|
||||
session_key: str
|
||||
session_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Origin metadata for delivery routing
|
||||
origin: SessionSource | None = None
|
||||
|
||||
# Display metadata
|
||||
display_name: str | None = None
|
||||
platform: Platform | None = None
|
||||
chat_type: str = "dm"
|
||||
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"display_name": self.display_name,
|
||||
"platform": self.platform.value if self.platform else None,
|
||||
"chat_type": self.chat_type,
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
|
||||
origin = None
|
||||
if "origin" in data and data["origin"]:
|
||||
origin = SessionSource.from_dict(data["origin"])
|
||||
|
||||
platform = None
|
||||
if data.get("platform"):
|
||||
try:
|
||||
platform = Platform(data["platform"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return cls(
|
||||
session_key=data["session_key"],
|
||||
session_id=data["session_id"],
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]),
|
||||
origin=origin,
|
||||
display_name=data.get("display_name"),
|
||||
platform=platform,
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
total_tokens=data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
|
||||
def build_session_key(source: SessionSource) -> str:
|
||||
"""Build a deterministic session key from a message source.
|
||||
|
||||
This is the single source of truth for session key construction.
|
||||
WhatsApp DMs include chat_id (multi-user), other DMs do not (single owner).
|
||||
"""
|
||||
platform = source.platform.value
|
||||
if source.chat_type == "dm":
|
||||
if platform == "whatsapp" and source.chat_id:
|
||||
return f"agent:main:{platform}:dm:{source.chat_id}"
|
||||
return f"agent:main:{platform}:dm"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
Manages session storage and retrieval.
|
||||
|
||||
Uses SQLite (via SessionDB) for session metadata and message transcripts.
|
||||
Falls back to legacy JSONL files if SQLite is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig, has_active_processes_fn=None, on_auto_reset=None):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
self._has_active_processes_fn = has_active_processes_fn
|
||||
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||
# via the background session expiry watcher in GatewayRunner.
|
||||
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
|
||||
|
||||
# Initialize SQLite session database
|
||||
self._db = None
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
self._db = SessionDB()
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
"""Load sessions index from disk if not already loaded."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save sessions index to disk (kept for session key -> ID mapping)."""
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(source)
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
|
||||
Works from the entry alone — no SessionSource needed.
|
||||
Used by the background expiry watcher to proactively flush memories.
|
||||
Sessions with active background processes are never considered expired.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
if self._has_active_processes_fn(entry.session_key):
|
||||
return False
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=entry.platform,
|
||||
session_type=entry.chat_type,
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
session_key = self._generate_session_key(source)
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
|
||||
policy = self.config.get_reset_policy(platform=source.platform, session_type=source.chat_type)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(hour=policy.at_hour, minute=0, second=0, microsecond=0)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
|
||||
Uses the SQLite database as the source of truth because it preserves
|
||||
historical session records (ended sessions still count). The in-memory
|
||||
``_entries`` dict replaces entries on reset, so ``len(_entries)`` would
|
||||
stay at 1 for single-platform users — which is the bug this fixes.
|
||||
|
||||
The current session is already in the DB by the time this is called
|
||||
(get_or_create_session runs first), so we check ``> 1``.
|
||||
"""
|
||||
if self._db:
|
||||
try:
|
||||
return self._db.session_count() > 1
|
||||
except Exception:
|
||||
pass # fall through to heuristic
|
||||
# Fallback: check if sessions.json was loaded with existing data.
|
||||
# This covers the rare case where the DB is unavailable.
|
||||
self._ensure_loaded()
|
||||
return len(self._entries) > 1
|
||||
|
||||
def get_or_create_session(self, source: SessionSource, force_new: bool = False) -> SessionEntry:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Evaluates reset policy to determine if the existing session is stale.
|
||||
Creates a session record in SQLite when a new session starts.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
session_key = self._generate_session_key(source)
|
||||
now = datetime.now()
|
||||
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
else:
|
||||
# Session is being auto-reset. The background expiry watcher
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=source,
|
||||
display_name=source.chat_name,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
|
||||
# Create session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.create_session(
|
||||
session_id=session_id,
|
||||
source=source.platform.value,
|
||||
user_id=source.user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
|
||||
|
||||
return entry
|
||||
|
||||
def update_session(self, session_key: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(entry.session_id, input_tokens, output_tokens)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
def reset_session(self, session_key: str) -> SessionEntry | None:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
# End old session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
now = datetime.now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
# Create new session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.create_session(
|
||||
session_id=session_id,
|
||||
source=old_entry.platform.value if old_entry.platform else "unknown",
|
||||
user_id=old_entry.origin.user_id if old_entry.origin else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
return new_entry
|
||||
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> SessionEntry | None:
|
||||
"""Switch a session key to point at an existing session ID.
|
||||
|
||||
Used by ``/resume`` to restore a previously-named session.
|
||||
Ends the current session in SQLite (like reset), but instead of
|
||||
generating a fresh session ID, re-uses ``target_session_id`` so the
|
||||
old transcript is loaded on the next message.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
# Don't switch if already on that session
|
||||
if old_entry.session_id == target_session_id:
|
||||
return old_entry
|
||||
|
||||
# End the current session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_switch")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB end_session failed: %s", e)
|
||||
|
||||
now = datetime.now()
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=target_session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: int | None = None) -> list[SessionEntry]:
|
||||
"""List all sessions, optionally filtered by activity."""
|
||||
self._ensure_loaded()
|
||||
|
||||
entries = list(self._entries.values())
|
||||
|
||||
if active_minutes is not None:
|
||||
cutoff = datetime.now() - timedelta(minutes=active_minutes)
|
||||
entries = [e for e in entries if e.updated_at >= cutoff]
|
||||
|
||||
entries.sort(key=lambda e: e.updated_at, reverse=True)
|
||||
|
||||
return entries
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
"""Get the path to a session's legacy transcript file."""
|
||||
return self.sessions_dir / f"{session_id}.jsonl"
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
"""Append a message to a session's transcript (SQLite + legacy JSONL)."""
|
||||
# Write to SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=message.get("role", "unknown"),
|
||||
content=message.get("content"),
|
||||
tool_name=message.get("tool_name"),
|
||||
tool_calls=message.get("tool_calls"),
|
||||
tool_call_id=message.get("tool_call_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: list[dict[str, Any]]) -> None:
|
||||
"""Replace the entire transcript for a session with new messages.
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation history.
|
||||
Rewrites both SQLite and legacy JSONL storage.
|
||||
"""
|
||||
# SQLite: clear old messages and re-insert
|
||||
if self._db:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
# Try SQLite first
|
||||
if self._db:
|
||||
try:
|
||||
messages = self._db.get_messages_as_conversation(session_id)
|
||||
if messages:
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
|
||||
# Fall back to legacy JSONL
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
if not transcript_path.exists():
|
||||
return []
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
messages.append(json.loads(line))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_session_context(
|
||||
source: SessionSource, config: GatewayConfig, session_entry: SessionEntry | None = None
|
||||
) -> SessionContext:
|
||||
"""
|
||||
Build a full session context from a source and config.
|
||||
|
||||
This is used to inject context into the agent's system prompt.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
home_channels = {}
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
if home:
|
||||
home_channels[platform] = home
|
||||
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=connected,
|
||||
home_channels=home_channels,
|
||||
)
|
||||
|
||||
if session_entry:
|
||||
context.session_key = session_entry.session_key
|
||||
context.session_id = session_entry.session_id
|
||||
context.created_at = session_entry.created_at
|
||||
context.updated_at = session_entry.updated_at
|
||||
|
||||
return context
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user