Source code
Revision control
Copy as Markdown
Other Tools
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
#ifndef LlamaRuntimeLinker_h__
#define LlamaRuntimeLinker_h__
#include "mozilla/Attributes.h"
#include "mozilla/Types.h"
#include "llama/llama.h"
#include "ggml.h"
struct PRLibrary;
namespace mozilla::llama {
// Format: X(return_type, name, params)
#define MOZINFERENCE_FUNCTION_LIST(X)                                       \
  X(void, llama_log_set,                                                    \
    (void (*callback)(enum ggml_log_level, const char*, void*),             \
     void* user_data))                                                      \
  X(struct llama_model_params, llama_model_default_params, (void))          \
  X(struct llama_model*, llama_model_load_from_file_handle,                 \
    (FILE * file, struct llama_model_params params))                        \
  X(int32_t, llama_model_meta_val_str,                                      \
    (const struct llama_model* model, const char* key, char* buf,           \
     size_t buf_size))                                                      \
  X(struct llama_context_params, llama_context_default_params, (void))      \
  X(struct llama_context*, llama_init_from_model,                           \
    (struct llama_model * model, struct llama_context_params params))       \
  X(void, llama_attach_threadpool,                                          \
    (struct llama_context * ctx, ggml_threadpool_t threadpool,              \
     ggml_threadpool_t threadpool_batch))                                   \
  X(const char*, llama_model_chat_template,                                 \
    (const struct llama_model* model, const char* name))                    \
  X(int32_t, llama_chat_apply_template,                                     \
    (const char* tmpl, const struct llama_chat_message* chat, size_t n_msg, \
     bool add_ass, char* buf, int32_t length))                              \
  X(struct llama_sampler_chain_params, llama_sampler_chain_default_params,  \
    (void))                                                                 \
  X(struct llama_sampler*, llama_sampler_chain_init,                        \
    (struct llama_sampler_chain_params params))                             \
  X(void, llama_sampler_chain_add,                                          \
    (struct llama_sampler * chain, struct llama_sampler * smpl))            \
  X(struct llama_sampler*, llama_sampler_init_greedy, (void))               \
  X(struct llama_sampler*, llama_sampler_init_temp, (float t))              \
  X(struct llama_sampler*, llama_sampler_init_dist, (uint32_t seed))        \
  X(struct llama_sampler*, llama_sampler_init_top_k, (int32_t k))           \
  X(struct llama_sampler*, llama_sampler_init_top_p,                        \
    (float p, size_t min_keep))                                             \
  X(struct llama_sampler*, llama_sampler_init_logit_bias,                   \
    (int32_t n_vocab, int32_t n_logit_bias,                                 \
     const llama_logit_bias* logit_bias))                                   \
  X(void, llama_memory_clear, (llama_memory_t mem, bool data))              \
  X(llama_memory_t, llama_get_memory, (const struct llama_context* ctx))    \
  X(const struct llama_vocab*, llama_model_get_vocab,                       \
    (const struct llama_model* model))                                      \
  X(int32_t, llama_vocab_n_tokens, (const struct llama_vocab* vocab))       \
  X(int32_t, llama_tokenize,                                                \
    (const struct llama_vocab* vocab, const char* text, int32_t text_len,   \
     llama_token* tokens, int32_t n_tokens_max, bool add_special,           \
     bool parse_special))                                                   \
  X(uint32_t, llama_n_ctx, (const struct llama_context* ctx))               \
  X(struct llama_batch, llama_batch_get_one,                                \
    (llama_token * tokens, int32_t n_tokens))                               \
  X(int32_t, llama_memory_seq_pos_max,                                      \
    (llama_memory_t mem, llama_seq_id seq_id))                              \
  X(int32_t, llama_decode,                                                  \
    (struct llama_context * ctx, struct llama_batch batch))                 \
  X(llama_token, llama_sampler_sample,                                      \
    (struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx)) \
  X(bool, llama_vocab_is_eog,                                               \
    (const struct llama_vocab* vocab, llama_token token))                   \
  X(int32_t, llama_token_to_piece,                                          \
    (const struct llama_vocab* vocab, llama_token token, char* buf,         \
     int32_t length, int32_t lstrip, bool special))                         \
  X(void, llama_model_free, (struct llama_model * model))                   \
  X(void, llama_free, (struct llama_context * ctx))                         \
  X(void, llama_sampler_free, (struct llama_sampler * smpl))                \
  X(void, ggml_threadpool_params_init,                                      \
    (struct ggml_threadpool_params * p, int n_threads))                     \
  X(bool, ggml_threadpool_params_match,                                     \
    (const struct ggml_threadpool_params* p0,                               \
     const struct ggml_threadpool_params* p1))                              \
  X(ggml_threadpool_t, ggml_threadpool_new,                                 \
    (struct ggml_threadpool_params * params))                               \
  X(void, ggml_threadpool_free, (ggml_threadpool_t threadpool))
struct LlamaLibWrapper {
  LlamaLibWrapper() = default;
  ~LlamaLibWrapper() = default;
  enum class LinkResult {
    Success,
    NoProvidedLib,
    MissingFunction,
  };
  LinkResult Link();
  void Unlink();
  // Library handle
  PRLibrary* mLlamaLib;
#define DECLARE_FUNCTION_PTR(ret, name, params) ret(*name) params;
  MOZINFERENCE_FUNCTION_LIST(DECLARE_FUNCTION_PTR)
#undef DECLARE_FUNCTION_PTR
};
class LlamaRuntimeLinker {
 public:
  enum LinkStatus {
    LinkStatus_INIT = 0,
    LinkStatus_FAILED,
    LinkStatus_SUCCEEDED,
  };
  // Initialize the dynamic linker, returns true on success
  static bool Init();
  // Get the llama library wrapper
  static LlamaLibWrapper* Get() {
    if (!Init()) {
      return nullptr;
    }
    return &sLlamaLib;
  }
  // Check if the library has been successfully linked
  static bool IsAvailable() { return sLinkStatus == LinkStatus_SUCCEEDED; }
 private:
  static LlamaLibWrapper sLlamaLib;
  static LinkStatus sLinkStatus;
};
}  // namespace mozilla::llama
#endif  // LlamaRuntimeLinker_h__