瀏覽代碼

feat: ollama support (#2003)

takatost 1 年之前
父節點
當前提交
cca9edc97a
共有 21 個文件被更改,包括 1293 次插入13 次删除
  1. 26 3
      api/core/app_runner/generate_task_pipeline.py
  2. 1 0
      api/core/model_runtime/model_providers/_position.yaml
  3. 2 2
      api/core/model_runtime/model_providers/localai/localai.yaml
  4. 0 0
      api/core/model_runtime/model_providers/ollama/__init__.py
  5. 12 0
      api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg
  6. 12 0
      api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg
  7. 0 0
      api/core/model_runtime/model_providers/ollama/llm/__init__.py
  8. 615 0
      api/core/model_runtime/model_providers/ollama/llm/llm.py
  9. 17 0
      api/core/model_runtime/model_providers/ollama/ollama.py
  10. 98 0
      api/core/model_runtime/model_providers/ollama/ollama.yaml
  11. 0 0
      api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py
  12. 221 0
      api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
  13. 1 0
      api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
  14. 2 2
      api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
  15. 2 2
      api/core/model_runtime/model_providers/openllm/openllm.yaml
  16. 2 2
      api/core/model_runtime/model_providers/xinference/xinference.yaml
  17. 18 2
      api/core/prompt/prompt_transform.py
  18. 3 0
      api/tests/integration_tests/.env.example
  19. 0 0
      api/tests/integration_tests/model_runtime/ollama/__init__.py
  20. 190 0
      api/tests/integration_tests/model_runtime/ollama/test_llm.py
  21. 71 0
      api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py

+ 26 - 3
api/core/app_runner/generate_task_pipeline.py

@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
                     "files": files
                 })
         else:
-            prompts.append({
+            prompt_message = prompt_messages[0]
+            text = ''
+            files = []
+            if isinstance(prompt_message.content, list):
+                for content in prompt_message.content:
+                    if content.type == PromptMessageContentType.TEXT:
+                        content = cast(TextPromptMessageContent, content)
+                        text += content.data
+                    else:
+                        content = cast(ImagePromptMessageContent, content)
+                        files.append({
+                            "type": 'image',
+                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                            "detail": content.detail.value
+                        })
+            else:
+                text = prompt_message.content
+
+            params = {
                 "role": 'user',
-                "text": prompt_messages[0].content
-            })
+                "text": text,
+            }
+
+            if files:
+                params['files'] = files
+
+            prompts.append(params)
 
         return prompts
 

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -6,6 +6,7 @@
 - huggingface_hub
 - cohere
 - togetherai
+- ollama
 - zhipuai
 - baichuan
 - spark

+ 2 - 2
api/core/model_runtime/model_providers/localai/localai.yaml

@@ -54,5 +54,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your LocalAI, for example https://example.com/xxx
+        zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080

+ 0 - 0
api/core/model_runtime/model_providers/ollama/__init__.py


文件差異過大導致無法顯示
+ 12 - 0
api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg


文件差異過大導致無法顯示
+ 12 - 0
api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg


+ 0 - 0
api/core/model_runtime/model_providers/ollama/llm/__init__.py


+ 615 - 0
api/core/model_runtime/model_providers/ollama/llm/llm.py

@@ -0,0 +1,615 @@
+import json
+import logging
+import re
+from decimal import Decimal
+from typing import Optional, Generator, Union, List, cast
+from urllib.parse import urljoin
+
+import requests
+
+from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
+    UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
+    TextPromptMessageContent, SystemPromptMessage
+from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
+    PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
+    ModelFeature
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
+    LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
+    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaLargeLanguageModel(LargeLanguageModel):
+    """
+    Model class for Ollama large language model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        return self._generate(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model, credentials)
+
+        if model_mode == LLMMode.CHAT:
+            # chat model
+            return self._num_tokens_from_messages(prompt_messages)
+        else:
+            first_prompt_message = prompt_messages[0]
+            if isinstance(first_prompt_message.content, str):
+                text = first_prompt_message.content
+            else:
+                text = ''
+                for message_content in first_prompt_message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        text = message_content.data
+                        break
+            return self._get_num_tokens_by_gpt2(text)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._generate(
+                model=model,
+                credentials=credentials,
+                prompt_messages=[UserPromptMessage(content="ping")],
+                model_parameters={
+                    'num_predict': 5
+                },
+                stream=False
+            )
+        except InvokeError as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+        except Exception as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm completion model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        headers = {
+            'Content-Type': 'application/json'
+        }
+
+        endpoint_url = credentials['base_url']
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        # prepare the payload for a simple ping to the model
+        data = {
+            'model': model,
+            'stream': stream
+        }
+
+        if 'format' in model_parameters:
+            data['format'] = model_parameters['format']
+            del model_parameters['format']
+
+        data['options'] = model_parameters or {}
+
+        if stop:
+            data['stop'] = "\n".join(stop)
+
+        completion_type = LLMMode.value_of(credentials['mode'])
+
+        if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'api/chat')
+            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+        else:
+            endpoint_url = urljoin(endpoint_url, 'api/generate')
+            first_prompt_message = prompt_messages[0]
+            if isinstance(first_prompt_message, UserPromptMessage):
+                first_prompt_message = cast(UserPromptMessage, first_prompt_message)
+                if isinstance(first_prompt_message.content, str):
+                    data['prompt'] = first_prompt_message.content
+                else:
+                    text = ''
+                    images = []
+                    for message_content in first_prompt_message.content:
+                        if message_content.type == PromptMessageContentType.TEXT:
+                            message_content = cast(TextPromptMessageContent, message_content)
+                            text = message_content.data
+                        elif message_content.type == PromptMessageContentType.IMAGE:
+                            message_content = cast(ImagePromptMessageContent, message_content)
+                            image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                            images.append(image_data)
+
+                    data['prompt'] = text
+                    data['images'] = images
+
+        # send a post request to validate the credentials
+        response = requests.post(
+            endpoint_url,
+            headers=headers,
+            json=data,
+            timeout=(10, 60),
+            stream=stream
+        )
+
+        response.encoding = "utf-8"
+        if response.status_code != 200:
+            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
+                                  response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm completion response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param completion_type: completion type
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm result
+        """
+        response_json = response.json()
+
+        if completion_type is LLMMode.CHAT:
+            message = response_json.get('message', {})
+            response_content = message.get('content', '')
+        else:
+            response_content = response_json['response']
+
+        assistant_message = AssistantPromptMessage(content=response_content)
+
+        if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
+            # transform usage
+            prompt_tokens = response_json["prompt_eval_count"]
+            completion_tokens = response_json["eval_count"]
+        else:
+            # calculate num tokens
+            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+            completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        result = LLMResult(
+            model=response_json["model"],
+            prompt_messages=prompt_messages,
+            message=assistant_message,
+            usage=usage,
+        )
+
+        return result
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
+                                         response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm completion stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param completion_type: completion type
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator result
+        """
+        full_text = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+            completion_tokens = self._get_num_tokens_by_gpt2(full_text)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
+            if not chunk:
+                continue
+
+            try:
+                chunk_json = json.loads(chunk)
+                # stream ended
+            except json.JSONDecodeError as e:
+                yield create_final_llm_result_chunk(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(content=""),
+                    finish_reason="Non-JSON encountered."
+                )
+
+                chunk_index += 1
+                break
+
+            if completion_type is LLMMode.CHAT:
+                if not chunk_json:
+                    continue
+
+                if 'message' not in chunk_json:
+                    text = ''
+                else:
+                    text = chunk_json.get('message').get('content', '')
+            else:
+                if not chunk_json:
+                    continue
+
+                # transform assistant message to prompt message
+                text = chunk_json['response']
+
+            assistant_prompt_message = AssistantPromptMessage(
+                content=text
+            )
+
+            full_text += text
+
+            if chunk_json['done']:
+                # calculate num tokens
+                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
+                    # transform usage
+                    prompt_tokens = chunk_json["prompt_eval_count"]
+                    completion_tokens = chunk_json["eval_count"]
+                else:
+                    # calculate num tokens
+                    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
+                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=chunk_json['model'],
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                        finish_reason='stop',
+                        usage=usage
+                    )
+                )
+            else:
+                yield LLMResultChunk(
+                    model=chunk_json['model'],
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+            chunk_index += 1
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for Ollama API
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                text = ''
+                images = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        text = message_content.data
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                        images.append(image_data)
+
+                message_dict = {"role": "user", "content": text, "images": images}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
+    def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
+        """
+        Calculate num tokens.
+
+        :param messages: messages
+        """
+        num_tokens = 0
+        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+        return num_tokens
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        Get customizable model schema.
+
+        :param model: model name
+        :param credentials: credentials
+
+        :return: model schema
+        """
+        extras = {}
+
+        if 'vision_support' in credentials and credentials['vision_support'] == 'true':
+            extras['features'] = [ModelFeature.VISION]
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                zh_Hans=model,
+                en_US=model
+            ),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: credentials.get('mode'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name=DefaultParameterName.TEMPERATURE.value,
+                    use_template=DefaultParameterName.TEMPERATURE.value,
+                    label=I18nObject(en_US="Temperature"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="The temperature of the model. "
+                                          "Increasing the temperature will make the model answer "
+                                          "more creatively. (Default: 0.8)"),
+                    default=0.8,
+                    min=0,
+                    max=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_P.value,
+                    use_template=DefaultParameterName.TOP_P.value,
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
+                                          "more diverse text, while a lower value (e.g., 0.5) will generate more "
+                                          "focused and conservative text. (Default: 0.9)"),
+                    default=0.9,
+                    min=0,
+                    max=1
+                ),
+                ParameterRule(
+                    name="top_k",
+                    label=I18nObject(en_US="Top K"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Reduces the probability of generating nonsense. "
+                                          "A higher value (e.g. 100) will give more diverse answers, "
+                                          "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
+                    default=40,
+                    min=1,
+                    max=100
+                ),
+                ParameterRule(
+                    name='repeat_penalty',
+                    label=I18nObject(en_US="Repeat Penalty"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
+                                          "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
+                                          "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
+                    default=1.1,
+                    min=-2,
+                    max=2
+                ),
+                ParameterRule(
+                    name='num_predict',
+                    use_template='max_tokens',
+                    label=I18nObject(en_US="Num Predict"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
+                                          "(Default: 128, -1 = infinite generation, -2 = fill context)"),
+                    default=128,
+                    min=-2,
+                    max=int(credentials.get('max_tokens', 4096)),
+                ),
+                ParameterRule(
+                    name='mirostat',
+                    label=I18nObject(en_US="Mirostat sampling"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
+                                          "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
+                    default=0,
+                    min=0,
+                    max=2
+                ),
+                ParameterRule(
+                    name='mirostat_eta',
+                    label=I18nObject(en_US="Mirostat Eta"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
+                                          "the generated text. A lower learning rate will result in slower adjustments, "
+                                          "while a higher learning rate will make the algorithm more responsive. "
+                                          "(Default: 0.1)"),
+                    default=0.1,
+                    precision=1
+                ),
+                ParameterRule(
+                    name='mirostat_tau',
+                    label=I18nObject(en_US="Mirostat Tau"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
+                                          "A lower value will result in more focused and coherent text. (Default: 5.0)"),
+                    default=5.0,
+                    precision=1
+                ),
+                ParameterRule(
+                    name='num_ctx',
+                    label=I18nObject(en_US="Size of context window"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
+                                          "(Default: 2048)"),
+                    default=2048,
+                    min=1
+                ),
+                ParameterRule(
+                    name='num_gpu',
+                    label=I18nObject(en_US="Num GPU"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="The number of layers to send to the GPU(s). "
+                                          "On macOS it defaults to 1 to enable metal support, 0 to disable."),
+                    default=1,
+                    min=0,
+                    max=1
+                ),
+                ParameterRule(
+                    name='num_thread',
+                    label=I18nObject(en_US="Num Thread"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Sets the number of threads to use during computation. "
+                                          "By default, Ollama will detect this for optimal performance. "
+                                          "It is recommended to set this value to the number of physical CPU cores "
+                                          "your system has (as opposed to the logical number of cores)."),
+                    min=1,
+                ),
+                ParameterRule(
+                    name='repeat_last_n',
+                    label=I18nObject(en_US="Repeat last N"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
+                                          "(Default: 64, 0 = disabled, -1 = num_ctx)"),
+                    default=64,
+                    min=-1
+                ),
+                ParameterRule(
+                    name='tfs_z',
+                    label=I18nObject(en_US="TFS Z"),
+                    type=ParameterType.FLOAT,
+                    help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
+                                          "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
+                                          "while a value of 1.0 disables this setting. (default: 1)"),
+                    default=1,
+                    precision=1
+                ),
+                ParameterRule(
+                    name='seed',
+                    label=I18nObject(en_US="Seed"),
+                    type=ParameterType.INT,
+                    help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
+                                          "a specific number will make the model generate the same text for "
+                                          "the same prompt. (Default: 0)"),
+                    default=0
+                ),
+                ParameterRule(
+                    name='format',
+                    label=I18nObject(en_US="Format"),
+                    type=ParameterType.STRING,
+                    help=I18nObject(en_US="the format to return a response in."
+                                          " Currently the only accepted value is json."),
+                    options=['json'],
+                )
+            ],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get('input_price', 0)),
+                output=Decimal(credentials.get('output_price', 0)),
+                unit=Decimal(credentials.get('unit', 0)),
+                currency=credentials.get('currency', "USD")
+            ),
+            **extras
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeAuthorizationError: [
+                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
+            ],
+            InvokeBadRequestError: [
+                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
+                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
+            ],
+            InvokeRateLimitError: [
+                requests.exceptions.RetryError  # Too many requests sent in a short period of time
+            ],
+            InvokeServerUnavailableError: [
+                requests.exceptions.ConnectionError,  # Engine Overloaded
+                requests.exceptions.HTTPError  # Server Error
+            ],
+            InvokeConnectionError: [
+                requests.exceptions.ConnectTimeout,  # Timeout
+                requests.exceptions.ReadTimeout  # Timeout
+            ]
+        }

+ 17 - 0
api/core/model_runtime/model_providers/ollama/ollama.py

@@ -0,0 +1,17 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        pass

+ 98 - 0
api/core/model_runtime/model_providers/ollama/ollama.yaml

@@ -0,0 +1,98 @@
+provider: ollama
+label:
+  en_US: Ollama
+icon_large:
+  en_US: icon_l_en.svg
+icon_small:
+  en_US: icon_s_en.svg
+background: "#F9FAFB"
+help:
+  title:
+    en_US: How to integrate with Ollama
+    zh_Hans: 如何集成 Ollama
+  url:
+    en_US: https://docs.dify.ai/advanced/model-configuration/ollama
+supported_model_types:
+  - llm
+  - text-embedding
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: base_url
+      label:
+        zh_Hans: 基础 URL
+        en_US: Base URL
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
+        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        zh_Hans: 模型类型
+        en_US: Completion mode
+      type: select
+      required: true
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: '4096'
+      type: text-input
+      required: true
+    - variable: vision_support
+      label:
+        zh_Hans: 是否支持 Vision
+        en_US: Vision support
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: 'false'
+      type: radio
+      required: false
+      options:
+        - value: 'true'
+          label:
+            en_US: Yes
+            zh_Hans: 是
+        - value: 'false'
+          label:
+            en_US: No
+            zh_Hans: 否

+ 0 - 0
api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py


+ 221 - 0
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

@@ -0,0 +1,221 @@
+import logging
+import time
+from decimal import Decimal
+from typing import Optional
+from urllib.parse import urljoin
+import requests
+import json
+
+import numpy as np
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
+    PriceConfig
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
+from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
+    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for an Ollama text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+
+        # Prepare headers and payload for the request
+        headers = {
+            'Content-Type': 'application/json'
+        }
+
+        endpoint_url = credentials.get('base_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'api/embeddings')
+
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+
+        inputs = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+            # Here token count is only an approximation based on the GPT2 tokenizer
+            num_tokens = self._get_num_tokens_by_gpt2(text)
+
+            if num_tokens >= context_size:
+                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
+                # if num tokens is larger than context length, only use the start
+                inputs.append(text[0: cutoff])
+            else:
+                inputs.append(text)
+
+        batched_embeddings = []
+
+        for text in inputs:
+            # Prepare the payload for the request
+            payload = {
+                'prompt': text,
+                'model': model,
+            }
+
+            # Make the request to the OpenAI API
+            response = requests.post(
+                endpoint_url,
+                headers=headers,
+                data=json.dumps(payload),
+                timeout=(10, 300)
+            )
+
+            response.raise_for_status()  # Raise an exception for HTTP errors
+            response_data = response.json()
+
+            # Extract embeddings and used tokens from the response
+            embeddings = response_data['embedding']
+            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])
+
+            used_tokens += embedding_used_tokens
+            batched_embeddings.append(embeddings)
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=batched_embeddings,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Approximate number of tokens for given messages using GPT2 tokenizer
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                texts=['ping']
+            )
+        except InvokeError as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
+        except Exception as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MAX_CHUNKS: 1,
+            },
+            parameter_rules=[],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get('input_price', 0)),
+                unit=Decimal(credentials.get('unit', 0)),
+                currency=credentials.get('currency', "USD")
+            )
+        )
+
+        return entity
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeAuthorizationError: [
+                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
+            ],
+            InvokeBadRequestError: [
+                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
+                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
+            ],
+            InvokeRateLimitError: [
+                requests.exceptions.RetryError  # Too many requests sent in a short period of time
+            ],
+            InvokeServerUnavailableError: [
+                requests.exceptions.ConnectionError,  # Engine Overloaded
+                requests.exceptions.HTTPError  # Server Error
+            ],
+            InvokeConnectionError: [
+                requests.exceptions.ConnectTimeout,  # Timeout
+                requests.exceptions.ReadTimeout  # Timeout
+            ]
+        }

+ 1 - 0
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                     )
+                    break
 
                 if not chunk_json or len(chunk_json['choices']) == 0:
                     continue

+ 2 - 2
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: Base URL, eg. https://api.openai.com/v1
-        en_US: Base URL, eg. https://api.openai.com/v1
+        zh_Hans: Base URL, e.g. https://api.openai.com/v1
+        en_US: Base URL, e.g. https://api.openai.com/v1
     - variable: mode
       show_on:
         - variable: __model_type

+ 2 - 2
api/core/model_runtime/model_providers/openllm/openllm.yaml

@@ -33,5 +33,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
+        zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
+        en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000

+ 2 - 2
api/core/model_runtime/model_providers/xinference/xinference.yaml

@@ -34,8 +34,8 @@ model_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your Xinference, for example https://example.com/xxx
+        zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+        en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
     - variable: model_uid
       label:
         zh_Hans: 模型UID

+ 18 - 2
api/core/prompt/prompt_transform.py

@@ -121,6 +121,7 @@ class PromptTransform:
                     prompt_template_entity=prompt_template_entity,
                     inputs=inputs,
                     query=query,
+                    files=files,
                     context=context,
                     memory=memory,
                     model_config=model_config
@@ -343,7 +344,14 @@ class PromptTransform:
 
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
-            prompt_message = UserPromptMessage(content=prompt)
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_message = UserPromptMessage(content=prompt_message_contents)
+            else:
+                prompt_message = UserPromptMessage(content=prompt)
 
         return [prompt_message]
 
@@ -434,6 +442,7 @@ class PromptTransform:
                                                        prompt_template_entity: PromptTemplateEntity,
                                                        inputs: dict,
                                                        query: str,
+                                                       files: List[FileObj],
                                                        context: Optional[str],
                                                        memory: Optional[TokenBufferMemory],
                                                        model_config: ModelConfigEntity) -> List[PromptMessage]:
@@ -461,7 +470,14 @@ class PromptTransform:
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-        prompt_messages.append(UserPromptMessage(content=prompt))
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
 
         return prompt_messages
 

+ 3 - 0
api/tests/integration_tests/.env.example

@@ -62,5 +62,8 @@ COHERE_API_KEY=
 # Jina Credentials
 JINA_API_KEY=
 
+# Ollama Credentials
+OLLAMA_BASE_URL=
+
 # Mock Switch
 MOCK_SWITCH=false

+ 0 - 0
api/tests/integration_tests/model_runtime/ollama/__init__.py


文件差異過大導致無法顯示
+ 190 - 0
api/tests/integration_tests/model_runtime/ollama/test_llm.py


+ 71 - 0
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py

@@ -0,0 +1,71 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
+
+
+def test_validate_credentials():
+    model = OllamaEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='mistral:text',
+            credentials={
+                'base_url': 'http://localhost:21434',
+                'mode': 'chat',
+                'context_size': 4096,
+            }
+        )
+
+    model.validate_credentials(
+        model='mistral:text',
+        credentials={
+            'base_url': os.environ.get('OLLAMA_BASE_URL'),
+            'mode': 'chat',
+            'context_size': 4096,
+        }
+    )
+
+
+def test_invoke_model():
+    model = OllamaEmbeddingModel()
+
+    result = model.invoke(
+        model='mistral:text',
+        credentials={
+            'base_url': os.environ.get('OLLAMA_BASE_URL'),
+            'mode': 'chat',
+            'context_size': 4096,
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens == 2
+
+
+def test_get_num_tokens():
+    model = OllamaEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='mistral:text',
+        credentials={
+            'base_url': os.environ.get('OLLAMA_BASE_URL'),
+            'mode': 'chat',
+            'context_size': 4096,
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 2

部分文件因文件數量過多而無法顯示