diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 55132c1..4223c20 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -96,6 +96,7 @@ pub struct ToolSpec { #[derive(Debug, Clone, PartialEq)] pub struct GlobalToolRegistry { plugin_tools: Vec, + enforcer: Option, } impl GlobalToolRegistry { @@ -103,6 +104,7 @@ impl GlobalToolRegistry { pub fn builtin() -> Self { Self { plugin_tools: Vec::new(), + enforcer: None, } } @@ -125,7 +127,7 @@ impl GlobalToolRegistry { } } - Ok(Self { plugin_tools }) + Ok(Self { plugin_tools, enforcer: None }) } pub fn normalize_allowed_tools( @@ -229,7 +231,14 @@ impl GlobalToolRegistry { Ok(builtin.chain(plugin).collect()) } + pub fn set_enforcer(&mut self, enforcer: PermissionEnforcer) { + self.enforcer = Some(enforcer); + } + pub fn execute(&self, name: &str, input: &Value) -> Result { + if let Some(enforcer) = &self.enforcer { + enforce_permission_check(enforcer, name, input)?; + } if mvp_tool_specs().iter().any(|spec| spec.name == name) { return execute_tool(name, input); } @@ -2776,11 +2785,12 @@ impl ApiClient for ProviderRuntimeClient { struct SubagentToolExecutor { allowed_tools: BTreeSet, + enforcer: Option, } impl SubagentToolExecutor { fn new(allowed_tools: BTreeSet) -> Self { - Self { allowed_tools } + Self { allowed_tools, enforcer: None } } } @@ -2793,6 +2803,10 @@ impl ToolExecutor for SubagentToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + if let Some(enforcer) = &self.enforcer { + enforce_permission_check(enforcer, tool_name, &value) + .map_err(ToolError::new)?; + } execute_tool(tool_name, &value).map_err(ToolError::new) } } @@ -4890,7 +4904,7 @@ mod tests { AssistantEvent::MessageStop, ]) } - _ => panic!("unexpected mock stream call"), + _ => unreachable!("extra mock stream call"), } } }