Source code for architxt.llm
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel
    from langchain_core.rate_limiters import BaseRateLimiter


def _get_local_chat_model(
    model_provider: str,
    model_name: str,
    *,
    max_tokens: int,
    temperature: float,
    rate_limiter: BaseRateLimiter | None = None,
    openvino: bool = False,
) -> BaseChatModel:
    """Build a chat model that runs locally through a Hugging Face pipeline, optionally on the OpenVINO backend."""
    if model_provider != 'huggingface':
        msg = f'Unsupported model provider for local mode: {model_provider}. Should be huggingface'
        raise ValueError(msg)

    from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
    pipeline = HuggingFacePipeline.from_model_id(
        model_id=model_name,
        task='text-generation',
        # OpenVINO manages device placement itself; otherwise let Accelerate map the model automatically.
        device_map=None if openvino else 'auto',
        backend='openvino' if openvino else 'pt',
        # With the OpenVINO backend, export=True converts the checkpoint to OpenVINO IR at load time.
        model_kwargs={'export': True} if openvino else {'torch_dtype': 'auto'},
        pipeline_kwargs={
            'use_cache': True,
            'do_sample': True,
            'return_full_text': False,
            'repetition_penalty': 1.1,
            'num_return_sequences': 1,
            'pad_token_id': 0,
        },
    )
    return ChatHuggingFace(
        llm=pipeline,
        rate_limiter=rate_limiter,
        max_tokens=max_tokens,
        temperature=temperature,
    )


def get_chat_model(
    model_provider: str,
    model_name: str,
    *,
    max_tokens: int,
    temperature: float,
    rate_limiter: BaseRateLimiter | None = None,
    local: bool = False,
    openvino: bool = False,
) -> BaseChatModel:
    """Return a chat model for the given provider, either a hosted model or a local Hugging Face pipeline."""
    if local:
        return _get_local_chat_model(
            model_provider,
            model_name,
            max_tokens=max_tokens,
            temperature=temperature,
            rate_limiter=rate_limiter,
            openvino=openvino,
        )
    from langchain.chat_models import init_chat_model

    return init_chat_model(
        model_provider=model_provider,
        model=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        rate_limiter=rate_limiter,
    )