Skip to main content

oxide_browser/
forge_config.rs

//! Persistent Forge AI provider configuration (API keys, models).
//!
//! Stored at `{config_dir}/oxide/forge_config.json`. API keys from the environment
//! variables `ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, `GEMINI_API_KEY` (or `GOOGLE_API_KEY`)
//! and `XAI_API_KEY` are merged on load; they only fill in missing keys and never
//! overwrite saved ones.
6
7use std::collections::HashMap;
8use std::fmt;
9use std::path::PathBuf;
10
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13
14/// Supported LLM backends for Oxide Forge.
15#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum ForgeProvider {
18    Anthropic,
19    Openai,
20    Gemini,
21    Xai,
22}
23
24impl ForgeProvider {
25    pub const ALL: [ForgeProvider; 4] = [
26        ForgeProvider::Anthropic,
27        ForgeProvider::Openai,
28        ForgeProvider::Gemini,
29        ForgeProvider::Xai,
30    ];
31
32    pub fn id(self) -> &'static str {
33        match self {
34            ForgeProvider::Anthropic => "anthropic",
35            ForgeProvider::Openai => "openai",
36            ForgeProvider::Gemini => "gemini",
37            ForgeProvider::Xai => "xai",
38        }
39    }
40
41    pub fn label(self) -> &'static str {
42        match self {
43            ForgeProvider::Anthropic => "Anthropic",
44            ForgeProvider::Openai => "OpenAI",
45            ForgeProvider::Gemini => "Google Gemini",
46            ForgeProvider::Xai => "xAI",
47        }
48    }
49
50    pub fn default_model(self) -> &'static str {
51        match self {
52            ForgeProvider::Anthropic => "claude-opus-4-7",
53            ForgeProvider::Openai => "gpt-4o",
54            // Stable id per https://ai.google.dev/gemini-api/docs/models
55            ForgeProvider::Gemini => "gemini-2.5-flash",
56            ForgeProvider::Xai => "grok-2-latest",
57        }
58    }
59
60    pub fn env_var(self) -> &'static str {
61        match self {
62            ForgeProvider::Anthropic => "ANTHROPIC_API_KEY",
63            ForgeProvider::Openai => "OPENAI_API_KEY",
64            ForgeProvider::Gemini => "GEMINI_API_KEY",
65            ForgeProvider::Xai => "XAI_API_KEY",
66        }
67    }
68
69    pub fn from_id(s: &str) -> Option<Self> {
70        match s {
71            "anthropic" => Some(ForgeProvider::Anthropic),
72            "openai" => Some(ForgeProvider::Openai),
73            "gemini" => Some(ForgeProvider::Gemini),
74            "xai" => Some(ForgeProvider::Xai),
75            _ => None,
76        }
77    }
78
79    /// Map persisted model ids to a currently available API model name.
80    pub fn normalize_model(self, model: &str) -> String {
81        let model = model.trim();
82        if model.is_empty() {
83            return self.default_model().to_string();
84        }
85        match self {
86            ForgeProvider::Gemini => normalize_gemini_model(model),
87            _ => model.to_string(),
88        }
89    }
90}
91
92/// Gemini 2.0 and older ids return 404 — see [model deprecations](https://ai.google.dev/gemini-api/docs/models).
93pub fn normalize_gemini_model(model: &str) -> String {
94    match model.trim() {
95        "gemini-2.0-flash"
96        | "gemini-2.0-flash-lite"
97        | "gemini-2.0-flash-001"
98        | "gemini-2.0-flash-lite-001"
99        | "gemini-1.5-flash"
100        | "gemini-1.5-flash-8b"
101        | "gemini-1.5-pro"
102        | "gemini-pro"
103        | "gemini-3-pro-preview"
104        | "gemini-3-pro" => ForgeProvider::Gemini.default_model().to_string(),
105        other => other.to_string(),
106    }
107}
108
109impl fmt::Display for ForgeProvider {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        f.write_str(self.label())
112    }
113}
114
115/// Per-provider credentials and model choice.
116#[derive(Clone, Debug, Default, Serialize, Deserialize)]
117pub struct ForgeProviderSettings {
118    #[serde(default, skip_serializing_if = "String::is_empty")]
119    pub api_key: String,
120    #[serde(default, skip_serializing_if = "String::is_empty")]
121    pub model: String,
122}
123
124impl ForgeProviderSettings {
125    pub fn model_or_default(&self, provider: ForgeProvider) -> String {
126        provider.normalize_model(&self.model)
127    }
128
129    pub fn has_key(&self) -> bool {
130        !self.api_key.trim().is_empty()
131    }
132}
133
134/// User-facing Forge configuration persisted on disk.
135#[derive(Clone, Debug, Serialize, Deserialize)]
136pub struct ForgeUserConfig {
137    #[serde(default = "default_active_provider")]
138    pub active_provider: ForgeProvider,
139    #[serde(default)]
140    pub providers: HashMap<String, ForgeProviderSettings>,
141    #[serde(default)]
142    pub settings_open: bool,
143}
144
145fn default_active_provider() -> ForgeProvider {
146    ForgeProvider::Anthropic
147}
148
149impl Default for ForgeUserConfig {
150    fn default() -> Self {
151        Self {
152            active_provider: ForgeProvider::Anthropic,
153            providers: HashMap::new(),
154            settings_open: false,
155        }
156    }
157}
158
159impl ForgeUserConfig {
160    pub fn config_path() -> PathBuf {
161        dirs::config_dir()
162            .unwrap_or_else(|| PathBuf::from("."))
163            .join("oxide")
164            .join("forge_config.json")
165    }
166
167    pub fn load() -> Self {
168        let path = Self::config_path();
169        let mut cfg = if path.is_file() {
170            std::fs::read_to_string(&path)
171                .ok()
172                .and_then(|s| serde_json::from_str(&s).ok())
173                .unwrap_or_default()
174        } else {
175            ForgeUserConfig::default()
176        };
177        cfg.merge_env_keys();
178        cfg.migrate_deprecated_models();
179        cfg
180    }
181
182    /// Rewrite deprecated Gemini model strings in saved config.
183    fn migrate_deprecated_models(&mut self) {
184        let mut changed = false;
185        if let Some(entry) = self.providers.get_mut(ForgeProvider::Gemini.id()) {
186            let normalized = ForgeProvider::Gemini.normalize_model(&entry.model);
187            if entry.model != normalized {
188                entry.model = normalized;
189                changed = true;
190            }
191        }
192        if changed {
193            let _ = self.save();
194        }
195    }
196
197    pub fn save(&self) -> Result<()> {
198        let path = Self::config_path();
199        if let Some(parent) = path.parent() {
200            std::fs::create_dir_all(parent)
201                .with_context(|| format!("create {}", parent.display()))?;
202        }
203        let json = serde_json::to_string_pretty(self).context("serialise forge config")?;
204        std::fs::write(&path, json).with_context(|| format!("write {}", path.display()))?;
205        Ok(())
206    }
207
208    /// Fill missing keys from environment variables (does not overwrite saved keys).
209    pub fn merge_env_keys(&mut self) {
210        for provider in ForgeProvider::ALL {
211            if let Ok(key) = std::env::var(provider.env_var()) {
212                let trimmed = key.trim();
213                if !trimmed.is_empty() {
214                    let entry = self.provider_mut(provider);
215                    if !entry.has_key() {
216                        entry.api_key = trimmed.to_string();
217                    }
218                }
219            }
220        }
221        // Alternate env name for Gemini.
222        if let Ok(key) = std::env::var("GOOGLE_API_KEY") {
223            let trimmed = key.trim();
224            if !trimmed.is_empty() {
225                let entry = self.provider_mut(ForgeProvider::Gemini);
226                if !entry.has_key() {
227                    entry.api_key = trimmed.to_string();
228                }
229            }
230        }
231    }
232
233    pub fn provider(&self, p: ForgeProvider) -> ForgeProviderSettings {
234        self.providers.get(p.id()).cloned().unwrap_or_default()
235    }
236
237    pub fn provider_mut(&mut self, p: ForgeProvider) -> &mut ForgeProviderSettings {
238        self.providers.entry(p.id().to_string()).or_default()
239    }
240
241    pub fn active_settings(&self) -> ForgeProviderSettings {
242        self.provider(self.active_provider)
243    }
244
245    pub fn active_api_key(&self) -> Option<String> {
246        let key = self.active_settings().api_key.trim().to_string();
247        if key.is_empty() {
248            None
249        } else {
250            Some(key)
251        }
252    }
253
254    pub fn active_model(&self) -> String {
255        self.active_settings()
256            .model_or_default(self.active_provider)
257    }
258
259    pub fn any_provider_configured(&self) -> bool {
260        ForgeProvider::ALL
261            .iter()
262            .any(|p| self.provider(*p).has_key())
263    }
264
265    pub fn configured_providers(&self) -> Vec<ForgeProvider> {
266        ForgeProvider::ALL
267            .iter()
268            .filter(|p| self.provider(**p).has_key())
269            .copied()
270            .collect()
271    }
272
273    pub fn set_api_key(&mut self, provider: ForgeProvider, key: String) {
274        let key = key.trim().to_string();
275        if key.is_empty() {
276            return;
277        }
278        self.provider_mut(provider).api_key = key;
279    }
280
281    pub fn set_model(&mut self, provider: ForgeProvider, model: String) {
282        self.provider_mut(provider).model = model.trim().to_string();
283    }
284}
285
/// Mask an API key for display (`sk-…abcd`).
///
/// Surrounding whitespace is ignored. A key of at most 8 characters is replaced
/// entirely by one `•` per character. Longer keys keep only their first and last
/// 4 characters around a `…`.
///
/// Lengths are counted in `char`s, not bytes. This stops non-ASCII keys (more bytes
/// than chars) from skipping the short-key branch and having every character shown.
pub fn mask_api_key(key: &str) -> String {
    const VISIBLE: usize = 4;

    let key = key.trim();
    let len = key.chars().count();
    if len <= 2 * VISIBLE {
        // Also covers the empty key: "•".repeat(0) == "".
        return "•".repeat(len);
    }
    let prefix: String = key.chars().take(VISIBLE).collect();
    let suffix: String = key.chars().skip(len - VISIBLE).collect();
    format!("{prefix}…{suffix}")
}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn masks_long_keys() {
        let m = mask_api_key("sk-ant-api03-abcdefghijklmnop");
        assert!(m.contains('…'));
        assert!(m.starts_with("sk-a"));
    }

    #[test]
    fn masks_short_and_blank_keys_completely() {
        assert_eq!(mask_api_key(""), "");
        assert_eq!(mask_api_key("  "), "");
        assert_eq!(mask_api_key("abcd"), "••••");
    }

    #[test]
    fn migrates_deprecated_gemini_models() {
        assert_eq!(
            normalize_gemini_model("gemini-2.0-flash"),
            "gemini-2.5-flash"
        );
        assert_eq!(
            normalize_gemini_model("gemini-3-pro-preview"),
            "gemini-2.5-flash"
        );
        assert_eq!(
            normalize_gemini_model("gemini-2.5-flash"),
            "gemini-2.5-flash"
        );
    }

    #[test]
    fn provider_ids_round_trip() {
        for p in ForgeProvider::ALL {
            assert_eq!(ForgeProvider::from_id(p.id()), Some(p));
        }
        assert_eq!(ForgeProvider::from_id("unknown"), None);
    }

    #[test]
    fn blank_model_falls_back_to_provider_default() {
        for p in ForgeProvider::ALL {
            assert_eq!(p.normalize_model("  "), p.default_model());
        }
        assert_eq!(
            ForgeProvider::Openai.normalize_model(" gpt-4o-mini "),
            "gpt-4o-mini"
        );
    }

    #[test]
    fn set_api_key_ignores_blank_keys() {
        let mut cfg = ForgeUserConfig::default();
        cfg.set_api_key(ForgeProvider::Openai, "  sk-test  ".to_string());
        cfg.set_api_key(ForgeProvider::Openai, "   ".to_string());
        assert_eq!(cfg.provider(ForgeProvider::Openai).api_key, "sk-test");
        assert!(!cfg.provider(ForgeProvider::Anthropic).has_key());
        assert_eq!(cfg.configured_providers(), vec![ForgeProvider::Openai]);
    }
}