A simple example that runs the Mistral model using the candle framework. Adapted from candle's Mistral example, but without platform-dependent optimizations and with a UI for the easiest start.

Candle is a framework developed by Hugging Face - the leading AI platform in the world. They started building it because Python, the common choice for AI development, introduces significant performance and devops overhead, while Rust solves these problems, enhances reliability and provides direct access to the WASM and WebGPU ecosystems to easily run models on the client side. As of now it's not as easy to use as PyTorch and is missing some important features, but the future is bright and it already supports a lot of modern models like the one used in this example.

As always let's start with the manifest:

Cargo.toml

[package]
name = "llm-mistral"
edition = "2021"

[[bin]]
name = "serve"
path = "./serve.rs"

[dependencies]
prest = "0.5"

hf-hub = "0.3"
tokenizers = "0.15"
candle-core = "0.3"
candle-transformers = "0.3"

It includes hf-hub, which simplifies model loading; tokenizers, another Hugging Face utility for efficient text pre- and postprocessing for LLMs; and the candle-* crates, which run the calculations of the models.

The core example's code is in:

llm.rs

use candle_core::{DType, Device, Tensor};
use candle_transformers::{
    generation::LogitsProcessor,
    models::quantized_mistral::{Config as QMistralCfg, Model as QMistral},
    quantized_var_builder::VarBuilder,
    utils::apply_repeat_penalty,
};
use hf_hub::{api::sync::Api, Repo};
use prest::*;
use tokenizers::Tokenizer;

/// Builds a ready-to-use [`Mistral`] instance running on the CPU.
///
/// Fetches `tokenizer.json` and `model-q4k.gguf` from the
/// `lmz/candle-mistral` model repo through the hf-hub API, looks up the id of
/// the `</s>` token, and constructs the quantized model together with a
/// `LogitsProcessor` configured from [`MistralConfig::default`].
///
/// # Errors
/// Returns an error if the GGUF weights cannot be loaded or the model cannot
/// be constructed from them.
///
/// # Panics
/// Panics if the hub API cannot be created, either file cannot be fetched,
/// the tokenizer file cannot be parsed, or the vocabulary has no `</s>` entry.
pub fn init() -> Somehow<Mistral> {
    let cfg = MistralConfig::default();
    let timer = std::time::Instant::now();
    info!("started initializing the model...");

    let model_repo = Repo::model("lmz/candle-mistral".to_owned());
    let hub = Api::new().unwrap().repo(model_repo);

    // Tokenizer first, then the (much larger) weights file.
    let tokenizer = Tokenizer::from_file(hub.get("tokenizer.json").unwrap()).unwrap();
    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
    let logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p);

    let gguf_path = hub.get("model-q4k.gguf").unwrap();
    let architecture = QMistralCfg::config_7b_v0_1(true);
    let weights = VarBuilder::from_gguf(&gguf_path, &Device::Cpu)?;
    let model = QMistral::new(&architecture, weights)?;
    info!("initialized the model in {:?}", timer.elapsed());

    Ok(Mistral {
        model,
        logits_processor,
        cfg,
        tokenizer,
        eos_token,
        history: String::new(),
        tokens: Vec::new(),
        current_ctx: 0,
        processed: 0,
    })
}

/// Sampling settings used while generating text with the model.
#[derive(Debug, Clone, PartialEq)]
pub struct MistralConfig {
    /// Seed handed to the `LogitsProcessor` when the model is initialized.
    pub seed: u64,
    /// Penalty factor passed to `apply_repeat_penalty` on every prediction.
    pub repeat_penalty: f32,
    /// Number of most recent tokens the repeat penalty is applied over.
    pub repeat_last_n: usize,
    /// Optional sampling temperature forwarded to the `LogitsProcessor`.
    pub temperature: Option<f64>,
    /// Optional top-p value forwarded to the `LogitsProcessor`.
    pub top_p: Option<f64>,
}

impl Default for MistralConfig {
    /// Fixed seed, a mild repeat penalty over the last 64 tokens, and no
    /// explicit temperature or top-p.
    fn default() -> Self {
        Self {
            seed: 123456789,
            repeat_penalty: 1.1,
            repeat_last_n: 64,
            temperature: None,
            top_p: None,
        }
    }
}

/// Quantized Mistral model bundled with its tokenizer, sampler and the
/// running conversation state.
pub struct Mistral {
    /// Quantized model used for the forward pass in `predict`.
    model: QMistral,
    /// Sampler that picks the next token id from the logits.
    logits_processor: LogitsProcessor,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: Tokenizer,
    /// Sampling settings (seed, repeat penalty, temperature, top-p).
    cfg: MistralConfig,
    /// Conversation text so far: prompts as given plus decoded model output.
    pub history: String,
    /// All token ids so far, prompt tokens and generated tokens alike.
    tokens: Vec<u32>,
    /// Id of the `</s>` token; sampling it ends generation.
    eos_token: u32,
    /// Index into `tokens` where the next forward pass starts reading; also
    /// passed to the model as the position offset.
    pub current_ctx: usize,
    /// Start index of the token slice that `try_decode` decodes.
    processed: usize,
}

impl Mistral {
    /// Appends `text` to the conversation and queues its tokens for the model.
    ///
    /// The text is added to `history` verbatim, its token ids are appended to
    /// `tokens`, and `processed` is moved to the end of the token list.
    /// Always returns `Ok(())`; a tokenization failure panics inside `encode`.
    pub fn prompt(&mut self, text: &str) -> Result<(), Error> {
        self.history.push_str(text);
        let mut prompt_ids = self.encode(text);
        self.tokens.append(&mut prompt_ids);
        self.processed = self.tokens.len();
        Ok(())
    }
    /// Generates one more token and folds it into the conversation state.
    ///
    /// Returns `true` while generation should continue, i.e. as long as the
    /// sampled token is not the end-of-sequence token.
    ///
    /// # Panics
    /// Panics if `predict` returns an error.
    pub fn more(&mut self) -> bool {
        let sampled = self.predict().unwrap();
        // Everything before the new token has now been fed to the model, so
        // the next forward pass starts at the position the new token takes.
        self.current_ctx = self.tokens.len();
        self.tokens.push(sampled);
        self.try_decode();
        sampled != self.eos_token
    }
    /// Runs one forward pass over the not-yet-consumed tokens and samples the
    /// next token id.
    ///
    /// The tokens from `current_ctx` onwards are fed to the model with
    /// `current_ctx` as the position offset; the resulting logits get the
    /// repeat penalty applied over the last `cfg.repeat_last_n` tokens before
    /// sampling.
    ///
    /// # Errors
    /// Propagates any error from tensor construction, the forward pass, the
    /// repeat penalty or the sampler.
    fn predict(&mut self) -> Somehow<u32> {
        let offset = self.current_ctx;
        let pending = &self.tokens[offset..];
        let input = Tensor::new(pending, &Device::Cpu)?.unsqueeze(0)?;

        let logits = self
            .model
            .forward(&input, offset)?
            .squeeze(0)?
            .squeeze(0)?
            .to_dtype(DType::F32)?;

        // Penalize only the most recent `repeat_last_n` tokens.
        let window_start = self.tokens.len().saturating_sub(self.cfg.repeat_last_n);
        let penalized = apply_repeat_penalty(
            &logits,
            self.cfg.repeat_penalty,
            &self.tokens[window_start..],
        )?;

        let next_token = self.logits_processor.sample(&penalized)?;
        Ok(next_token)
    }
    /// Tokenizes `input` and returns the resulting token ids.
    ///
    /// The `true` flag is presumably the tokenizer's "add special tokens"
    /// switch - confirm against the tokenizers API.
    ///
    /// # Panics
    /// Panics if the tokenizer fails to encode the input.
    fn encode(&self, input: &str) -> Vec<u32> {
        let encoding = self.tokenizer.encode(input, true).unwrap();
        encoding.get_ids().to_vec()
    }
    fn try_decode(&mut self) {
        let Mistral {
            tokens,
            processed,
            current_ctx,
            ..
        } = self;
        let processed_text = self
            .tokenizer
            .decode(&tokens[*processed..*current_ctx], true)
            .unwrap();
        let updated_text = self.tokenizer.decode(&tokens[*processed..], true).unwrap();
        if updated_text.len() > processed_text.len()
            && updated_text.chars().last().unwrap().is_ascii()
        {
            self.processed = self.current_ctx;
            let new_text = updated_text.split_at(processed_text.len()).1.to_string();
            self.history += &new_text;
        }
    }
}

It defines how the model is initialized and how it encodes input, performs inference and decodes output. The prest-based service that works with this model is defined here:

serve.rs

use prest::*;

mod llm;

// Global `LLM` state: a `Mutex` around the `llm::Mistral` built by `llm::init()`.
state!(LLM: Mutex<llm::Mistral> = { Mutex::new(llm::init()?) });

/// Payload accepted by the `/prompt` route.
#[derive(Deserialize)]
struct Prompt {
    /// Text to append to the conversation; filled from the `content` form input.
    pub content: String,
}

/// Entry point: forces model initialization, then serves the UI routes.
///
/// Routes:
/// - `GET /` - the full page.
/// - `POST /prompt` - appends the submitted text to the conversation
///   (prefixed with a newline when the history is not empty) and returns the
///   history fragment in "in progress" mode.
/// - `GET /more` - generates one more token and returns the history fragment.
/// - `GET /reset` - replaces the model with a freshly initialized one and
///   redirects to `/`.
#[init]
async fn main() -> Result {
    info!("Initializing LLM...");
    let _ = *LLM;

    route("/", get(page))
        .route(
            "/prompt",
            post(|Vals(prompt): Vals<Prompt>| async move {
                {
                    // Scoped so the lock is released before `history` takes it again.
                    let mut model = LLM.lock().await;
                    let text = if model.history.is_empty() {
                        prompt.content
                    } else {
                        format!("\n{}", prompt.content)
                    };
                    model.prompt(&text).unwrap();
                }
                history(true).await
            }),
        )
        .route(
            "/more",
            get(|| async {
                let in_progress = {
                    let mut model = LLM.lock().await;
                    model.more()
                };
                history(in_progress).await
            }),
        )
        .route(
            "/reset",
            get(|| async {
                let mut guard = LLM.lock().await;
                *guard = llm::init().unwrap();
                Redirect::to("/")
            }),
        )
        .run()
        .await
}

async fn page() -> Markup {
    html!( html { (Head::with_title("With Mistral LLM"))
        body $"max-w-screen-sm mx-auto mt-8" {
            div {(history(false).await)}
            (Scripts::default())
        }
    })
}

/// Renders the conversation so far together with the controls for the next step.
///
/// While `in_progress` is true the fragment contains an element that requests
/// `/more` on load, plus a "Pause" button. Otherwise it shows the prompt form,
/// whose button label depends on whether any history exists yet. A "Reset"
/// button is always included.
///
/// The conversation text comes from user prompts and model output, so it is
/// rendered through the template's normal escaping rather than injected as
/// raw markup.
async fn history(in_progress: bool) -> Markup {
    let content = LLM.lock().await.history.clone();
    let btn = if content.is_empty() {
        "Start generating"
    } else {
        "Append and continue"
    };
    html!(
        (content.as_str())
        @if in_progress {
            ins get="/more" target="div" trigger="load"{}
            span {"loading..."}
            br{}
            button get="/" target="body" {"Pause"}
        }
        @else {
            form post="/prompt" target="div"  {
                input type="text" name="content" placeholder="Prompt" required {}
                button type="submit" {(btn)}
            }
        }
        button get="/reset" target="body" {"Reset"}
    )
}

Beware that it's a simple and naive implementation designed to be tried out locally. For a real-world SaaS or other types of services the model should be managed differently, but this example is enough to demonstrate the core building blocks.

v0.5.1
made by Egor Dezhic