Przeglądaj źródła

feat: add volcengine maas model provider (#4142)

sino 11 miesięcy temu
rodzic
commit
4aa21242b6
25 zmienionych plików z 1818 dodań i 1 usunięć
  1. 1 0
      api/core/model_runtime/model_providers/_position.yaml
  2. 0 0
      api/core/model_runtime/model_providers/volcengine_maas/__init__.py
  3. 23 0
      api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg
  4. 23 0
      api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg
  5. 8 0
      api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg
  6. 108 0
      api/core/model_runtime/model_providers/volcengine_maas/client.py
  7. 156 0
      api/core/model_runtime/model_providers/volcengine_maas/errors.py
  8. 0 0
      api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py
  9. 284 0
      api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
  10. 12 0
      api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
  11. 0 0
      api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py
  12. 132 0
      api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
  13. 4 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
  14. 1 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
  15. 144 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
  16. 207 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
  17. 43 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
  18. 79 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
  19. 213 0
      api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
  20. 10 0
      api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py
  21. 151 0
      api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml
  22. 7 1
      api/tests/integration_tests/.env.example
  23. 0 0
      api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py
  24. 81 0
      api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py
  25. 131 0
      api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -26,5 +26,6 @@
 - yi
 - openllm
 - localai
+- volcengine_maas
 - openai_api_compatible
 - deepseek

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/__init__.py


+ 23 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg

@@ -0,0 +1,23 @@
+<?xml version="1.0" encoding="utf-8"?>
+<svg viewBox="540.1546 563.7749 150.5263 25" xmlns="http://www.w3.org/2000/svg">
+  <g clip-path="url(#a)" transform="matrix(0.17650800943374634, 0, 0, -0.17650800943374634, 445.969970703125, 671.58935546875)">
+    <path fill="#00e5e5" d="M618.577 552.875l-13.795-59.901c-.218-.948.502-1.853 1.475-1.853h27.575c.973 0 1.693.905 1.474 1.853l-13.795 59.901c-.358 1.557-2.575 1.557-2.934 0"/>
+    <path fill="#00e5e5" d="M542.203 530.153l-8.564-37.188c-.218-.944.499-1.844 1.467-1.844h17.135c.968 0 1.684.9 1.467 1.844l-8.565 37.188c-.359 1.561-2.581 1.561-2.94 0"/>
+    <path fill="#006eff" d="M561.028 575.58l-19.024-82.607c-.219-.948.502-1.853 1.474-1.853h38.034c.973 0 1.693.905 1.474 1.853l-19.024 82.607c-.358 1.557-2.575 1.557-2.934 0"/>
+    <path fill="#006eff" d="M592.95 609.651l-26.871-116.678c-.218-.947.502-1.852 1.475-1.852h53.726c.973 0 1.693.905 1.475 1.852l-26.871 116.678c-.358 1.557-2.575 1.557-2.934 0"/>
+    <path fill="#00e5e5" d="M575.175 564.245l-16.414-71.271c-.218-.948.502-1.853 1.475-1.853h32.812c.973 0 1.693.905 1.475 1.853l-16.414 71.271c-.358 1.557-2.576 1.557-2.934 0"/>
+    <path fill="#4b4b4b" d="M677.358 492.868l-24.407 72.089h13.49l20.804-64.365 21.318 64.365h13.285l-24.614-72.089z"/>
+    <path fill="#4b4b4b" d="M744.98 502.857c3.09 0 5.922.788 8.496 2.368 2.575 1.579 4.617 3.691 6.128 6.333 1.509 2.643 2.266 5.544 2.266 8.703 0 3.09-.757 5.973-2.266 8.651-1.511 2.677-3.553 4.789-6.128 6.334-2.574 1.545-5.406 2.317-8.496 2.317-3.159 0-6.025-.79-8.6-2.369-2.574-1.58-4.618-3.674-6.127-6.282-1.511-2.611-2.266-5.493-2.266-8.651 0-3.159.755-6.06 2.266-8.703 1.509-2.642 3.553-4.754 6.127-6.333 2.575-1.58 5.441-2.368 8.6-2.368m0-10.813c-5.149 0-9.853 1.269-14.109 3.81-4.258 2.539-7.657 5.974-10.196 10.298-2.541 4.326-3.81 9.027-3.81 14.109 0 5.08 1.269 9.766 3.81 14.058 2.539 4.29 5.938 7.706 10.196 10.247 4.256 2.539 8.96 3.811 14.109 3.811 5.08 0 9.766-1.272 14.057-3.811 4.291-2.541 7.689-5.974 10.196-10.299 2.506-4.325 3.759-8.995 3.759-14.006 0-5.082-1.253-9.783-3.759-14.109-2.507-4.324-5.905-7.759-10.196-10.298-4.291-2.541-8.977-3.81-14.057-3.81"/>
+    <path fill="#4b4b4b" d="M782.261 564.957h11.844v-72.089h-11.844z"/>
+    <path fill="#4b4b4b" d="M831.693 492.044c-5.149 0-9.887 1.27-14.212 3.811-4.325 2.539-7.708 5.973-10.144 10.297-2.438 4.327-3.656 8.995-3.656 14.006 0 5.15 1.252 9.887 3.759 14.212 2.505 4.326 5.904 7.74 10.195 10.247 4.29 2.506 8.976 3.76 14.058 3.76 5.835 0 11.071-1.632 15.705-4.893 4.635-3.26 7.982-7.535 10.042-12.821l-10.917-2.986c-1.374 3.02-3.365 5.44-5.973 7.261-2.61 1.817-5.561 2.728-8.857 2.728-3.159 0-6.025-.79-8.599-2.368-2.575-1.58-4.619-3.708-6.128-6.386-1.511-2.678-2.266-5.597-2.266-8.754 0-3.158.755-6.076 2.266-8.753 1.509-2.677 3.569-4.789 6.179-6.333 2.609-1.546 5.458-2.317 8.548-2.317 3.363 0 6.333.908 8.908 2.728 2.575 1.819 4.548 4.205 5.922 7.158l10.917-2.987c-2.129-5.217-5.494-9.458-10.093-12.717-4.601-3.263-9.819-4.893-15.654-4.893"/>
+    <path fill="#4b4b4b" d="M892.043 502.755c3.09 0 5.955.788 8.599 2.368 2.642 1.578 4.754 3.708 6.334 6.386 1.578 2.677 2.369 5.559 2.369 8.649 0 3.09-.791 5.974-2.369 8.652-1.58 2.677-3.692 4.822-6.334 6.436-2.644 1.613-5.509 2.42-8.599 2.42-3.159 0-6.06-.79-8.703-2.368-2.643-1.58-4.721-3.708-6.23-6.386-1.511-2.678-2.266-5.596-2.266-8.754 0-3.158.755-6.059 2.266-8.701 1.509-2.644 3.569-4.756 6.179-6.334 2.608-1.58 5.526-2.368 8.754-2.368m-1.03-10.711c-4.943 0-9.492 1.27-13.646 3.811-4.155 2.539-7.45 5.973-9.886 10.297-2.438 4.327-3.656 8.995-3.656 14.006 0 5.15 1.218 9.887 3.656 14.212 2.436 4.326 5.731 7.74 9.886 10.247 4.154 2.506 8.703 3.76 13.646 3.76 3.638 0 6.968-.739 9.99-2.214 3.02-1.477 5.526-3.486 7.518-6.025v7.312h11.843v-54.582h-11.843v7.313c-1.992-2.541-4.498-4.533-7.518-5.974-3.022-1.442-6.352-2.163-9.99-2.163"/>
+    <path fill="#4b4b4b" d="M931.692 492.868v54.582h11.74v-7.209c1.854 2.402 4.239 4.359 7.158 5.87 2.917 1.51 6.23 2.266 9.938 2.266 3.981 0 7.724-.894 11.226-2.678 3.501-1.786 6.315-4.361 8.444-7.725 2.128-3.364 3.193-7.244 3.193-11.636v-33.47h-11.843v30.277c0 2.883-.602 5.389-1.803 7.518-1.202 2.128-2.816 3.774-4.84 4.943-2.026 1.167-4.274 1.751-6.746 1.751-2.816 0-5.355-.721-7.621-2.163-2.265-1.441-4.016-3.398-5.252-5.87-1.236-2.472-1.854-5.288-1.854-8.444v-28.012z"/>
+    <path fill="#4b4b4b" d="M1019.437 502.857c3.09 0 5.922.788 8.496 2.368 2.575 1.579 4.617 3.691 6.128 6.333 1.509 2.643 2.266 5.544 2.266 8.703 0 3.09-.757 5.973-2.266 8.651-1.511 2.677-3.553 4.789-6.128 6.334-2.574 1.545-5.406 2.317-8.496 2.317-3.159 0-6.025-.79-8.599-2.369-2.575-1.58-4.619-3.674-6.128-6.282-1.511-2.611-2.266-5.493-2.266-8.651 0-3.159.755-6.06 2.266-8.703 1.509-2.642 3.553-4.754 6.128-6.333 2.574-1.58 5.44-2.368 8.599-2.368m0-10.813c-5.149 0-9.853 1.269-14.109 3.81-4.258 2.539-7.657 5.974-10.196 10.298-2.541 4.326-3.81 9.027-3.81 14.109 0 5.08 1.269 9.766 3.81 14.058 2.539 4.29 5.938 7.706 10.196 10.247 4.256 2.539 8.96 3.811 14.109 3.811 5.08 0 9.766-1.272 14.058-3.811 4.29-2.541 7.688-5.974 10.195-10.299 2.506-4.325 3.759-8.995 3.759-14.006 0-5.082-1.253-9.783-3.759-14.109-2.507-4.324-5.905-7.759-10.195-10.298-4.292-2.541-8.978-3.81-14.058-3.81"/>
+    <path fill="#4b4b4b" d="M1057.026 492.868v72.089h51.287v-11.328h-39.135v-19.156h35.016v-11.122h-35.016v-19.155h39.135v-11.328z"/>
+    <path fill="#4b4b4b" d="M1118.92 492.868v54.582h11.74v-7.209c1.854 2.402 4.239 4.359 7.158 5.87 2.917 1.51 6.231 2.266 9.938 2.266 3.981 0 7.724-.894 11.226-2.678 3.501-1.786 6.316-4.361 8.444-7.725 2.128-3.364 3.193-7.244 3.193-11.636v-33.47h-11.843v30.277c0 2.883-.602 5.389-1.803 7.518-1.202 2.128-2.815 3.774-4.84 4.943-2.026 1.167-4.274 1.751-6.745 1.751-2.816 0-5.356-.721-7.621-2.163-2.266-1.441-4.017-3.398-5.253-5.87-1.236-2.472-1.854-5.288-1.854-8.444v-28.012z"/>
+    <path fill="#4b4b4b" d="M1207.077 504.094c3.09 0 5.955.754 8.599 2.266 2.642 1.508 4.754 3.551 6.334 6.127 1.578 2.574 2.369 5.406 2.369 8.496 0 3.089-.773 5.922-2.317 8.496-1.545 2.574-3.657 4.599-6.334 6.076-2.678 1.476-5.561 2.214-8.651 2.214-3.159 0-6.042-.738-8.651-2.214-2.61-1.477-4.686-3.502-6.231-6.076s-2.317-5.407-2.317-8.496c0-3.09.772-5.94 2.317-8.547 1.545-2.611 3.639-4.654 6.283-6.128 2.642-1.477 5.509-2.214 8.599-2.214m1.339-34.912c-4.052 0-7.948.756-11.689 2.266-3.743 1.51-6.969 3.809-9.681 6.899-2.713 3.09-4.618 6.899-5.716 11.431h11.844c.824-2.266 2.024-4.12 3.604-5.56 1.579-1.442 3.363-2.473 5.356-3.091 1.99-.617 4.05-.927 6.179-.927 2.883 0 5.474.584 7.775 1.752 2.3 1.168 4.12 2.986 5.458 5.458 1.339 2.471 2.009 5.595 2.009 9.372v4.427c-1.992-2.334-4.48-4.188-7.467-5.561-2.986-1.373-6.333-2.059-10.041-2.059-5.149 0-9.801 1.2-13.955 3.605-4.154 2.402-7.398 5.697-9.732 9.885-2.335 4.188-3.501 8.822-3.501 13.904 0 5.216 1.184 9.921 3.553 14.109 2.368 4.187 5.628 7.466 9.783 9.835 4.154 2.369 8.77 3.518 13.852 3.45 3.638-.069 6.968-.807 9.99-2.214 3.02-1.408 5.526-3.382 7.518-5.922v7.209h11.843v-50.463c0-5.767-1.22-10.762-3.656-14.985-2.438-4.221-5.699-7.414-9.784-9.577-4.085-2.163-8.599-3.243-13.542-3.243"/>
+    <path fill="#4b4b4b" d="M1247.035 547.45h11.844v-54.582h-11.844zm6.076 6.385c-2.129 0-3.949.703-5.458 2.111-1.511 1.406-2.266 3.175-2.266 5.304 0 2.059.755 3.81 2.266 5.252 1.509 1.442 3.329 2.163 5.458 2.163 2.128 0 3.93-.721 5.407-2.163 1.476-1.442 2.214-3.193 2.214-5.252 0-2.129-.738-3.898-2.214-5.304-1.477-1.408-3.279-2.111-5.407-2.111"/>
+    <path fill="#4b4b4b" d="M1270.515 492.868v54.582h11.74v-7.209c1.854 2.402 4.239 4.359 7.158 5.87 2.917 1.51 6.23 2.266 9.938 2.266 3.981 0 7.724-.894 11.226-2.678 3.501-1.786 6.315-4.361 8.444-7.725 2.128-3.364 3.193-7.244 3.193-11.636v-33.47h-11.843v30.277c0 2.883-.602 5.389-1.803 7.518-1.202 2.128-2.816 3.774-4.84 4.943-2.026 1.167-4.274 1.751-6.746 1.751-2.816 0-5.355-.721-7.621-2.163-2.265-1.441-4.016-3.398-5.252-5.87-1.236-2.472-1.854-5.288-1.854-8.444v-28.012z"/>
+    <path fill="#4b4b4b" d="M1374.428 524.895c-.481 2.334-1.459 4.496-2.935 6.488-1.477 1.991-3.331 3.57-5.562 4.738-2.232 1.167-4.754 1.751-7.569 1.751-2.816 0-5.373-.584-7.672-1.751-2.302-1.168-4.189-2.747-5.665-4.738-1.477-1.992-2.455-4.154-2.935-6.488zm-15.963-32.852c-5.149 0-9.87 1.27-14.16 3.811-4.292 2.539-7.673 5.974-10.145 10.298-2.471 4.326-3.707 8.994-3.707 14.006 0 5.149 1.252 9.886 3.759 14.212 2.506 4.326 5.904 7.74 10.196 10.247 4.29 2.506 8.976 3.759 14.057 3.759 4.531 0 8.754-1.03 12.668-3.089 3.913-2.06 7.157-4.859 9.732-8.394 2.574-3.537 4.238-7.502 4.995-11.894.685-3.09.72-6.146.102-9.167h-43.872c.48-2.746 1.494-5.115 3.038-7.105 1.545-1.992 3.467-3.519 5.768-4.583 2.299-1.065 4.822-1.597 7.569-1.597 2.884 0 5.681.602 8.394 1.803 2.711 1.2 4.959 2.832 6.745 4.892l10.402-2.678c-2.129-4.394-5.597-7.913-10.402-10.556-4.806-2.643-9.853-3.965-15.139-3.965"/>
+  </g>
+</svg>

Plik diff jest za duży
+ 23 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg


+ 8 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="utf-8"?>
+<svg viewBox="-0.006 0 24.6978 24.9156" xmlns="http://www.w3.org/2000/svg">
+  <path d="M20.511 15.3019L17.2442 28.1928C17.2362 28.2282 17.2364 28.2649 17.2447 28.3001C17.2531 28.3354 17.2694 28.3683 17.2923 28.3964C17.3153 28.4244 17.3444 28.4468 17.3773 28.4619C17.4103 28.477 17.4462 28.4844 17.4825 28.4835H24.0137C24.0499 28.4844 24.0859 28.477 24.1188 28.4619C24.1518 28.4468 24.1809 28.4244 24.2038 28.3964C24.2268 28.3683 24.2431 28.3354 24.2514 28.3001C24.2598 28.2649 24.26 28.2282 24.252 28.1928L20.9685 15.3019C20.9541 15.2524 20.924 15.209 20.8827 15.178C20.8415 15.1471 20.7913 15.1304 20.7397 15.1304C20.6882 15.1304 20.638 15.1471 20.5968 15.178C20.5555 15.209 20.5254 15.2524 20.511 15.3019V15.3019Z" fill="#00E5E5" transform="matrix(1.0178890228271484, 0, 0, 1.0178890228271484, -1.952212187461555e-7, -4.077521402283104)"/>
+  <path d="M2.53051 18.2228L-5.28338e-06 28.1924C-0.00799016 28.2277 -0.00780431 28.2644 0.000538111 28.2997C0.00888053 28.335 0.0251596 28.3679 0.0481365 28.3959C0.0711133 28.4239 0.100182 28.4464 0.133131 28.4615C0.166079 28.4766 0.202039 28.484 0.238273 28.4831H5.28025C5.31649 28.484 5.35245 28.4766 5.38539 28.4615C5.41834 28.4464 5.44741 28.4239 5.47039 28.3959C5.49336 28.3679 5.50964 28.335 5.51799 28.2997C5.52633 28.2644 5.52651 28.2277 5.51853 28.1924L2.98563 18.2228C2.97054 18.1742 2.94032 18.1318 2.89938 18.1016C2.85844 18.0714 2.80892 18.0552 2.75807 18.0552C2.70722 18.0552 2.6577 18.0714 2.61676 18.1016C2.57582 18.1318 2.5456 18.1742 2.53051 18.2228V18.2228Z" fill="#00E5E5" transform="matrix(1.0178890228271484, 0, 0, 1.0178890228271484, -1.952212187461555e-7, -4.077521402283104)"/>
+  <path d="M6.99344 9.96839L2.38275 28.1919C2.37498 28.2263 2.37494 28.262 2.38262 28.2964C2.3903 28.3308 2.40552 28.363 2.42717 28.3908C2.44882 28.4186 2.47637 28.4413 2.50783 28.4572C2.53929 28.473 2.57388 28.4817 2.60911 28.4826H11.8329C11.8691 28.4835 11.9051 28.4761 11.938 28.461C11.971 28.4459 12 28.4235 12.023 28.3955C12.046 28.3675 12.0623 28.3345 12.0706 28.2993C12.079 28.264 12.0791 28.2273 12.0712 28.1919L7.44855 9.96839C7.43347 9.91982 7.40325 9.87736 7.36231 9.8472C7.32136 9.81705 7.27185 9.80078 7.221 9.80078C7.17015 9.80078 7.12063 9.81705 7.07969 9.8472C7.03874 9.87736 7.00852 9.91982 6.99344 9.96839Z" fill="#006EFF" transform="matrix(1.0178890228271484, 0, 0, 1.0178890228271484, -1.952212187461555e-7, -4.077521402283104)"/>
+  <path d="M14.9472 4.17346C14.9321 4.1249 14.9019 4.08244 14.861 4.05228C14.82 4.02213 14.7705 4.00586 14.7197 4.00586C14.6688 4.00586 14.6193 4.02213 14.5784 4.05228C14.5374 4.08244 14.5072 4.1249 14.4921 4.17346L8.18963 28.192C8.18165 28.2273 8.18183 28.264 8.19017 28.2993C8.19852 28.3346 8.2148 28.3675 8.23777 28.3955C8.26075 28.4235 8.28982 28.446 8.32277 28.4611C8.35572 28.4762 8.39168 28.4835 8.42791 28.4827H21.0233C21.0596 28.4835 21.0955 28.4762 21.1285 28.4611C21.1614 28.446 21.1905 28.4235 21.2135 28.3955C21.2364 28.3675 21.2527 28.3346 21.2611 28.2993C21.2694 28.264 21.2696 28.2273 21.2616 28.192L14.9472 4.17346Z" fill="#006EFF" transform="matrix(1.0178890228271484, 0, 0, 1.0178890228271484, -1.952212187461555e-7, -4.077521402283104)"/>
+  <path d="M10.3175 12.6188L6.31915 28.1903C6.31074 28.2258 6.31061 28.2628 6.31875 28.2984C6.3269 28.3339 6.34311 28.3672 6.36614 28.3955C6.38916 28.4238 6.41839 28.4465 6.45155 28.4617C6.48472 28.4769 6.52094 28.4844 6.55743 28.4834H14.535C14.5715 28.4844 14.6077 28.4769 14.6409 28.4617C14.674 28.4465 14.7033 28.4238 14.7263 28.3955C14.7493 28.3672 14.7655 28.3339 14.7737 28.2984C14.7818 28.2628 14.7817 28.2258 14.7733 28.1903L10.7726 12.6188C10.7575 12.5702 10.7273 12.5278 10.6863 12.4976C10.6454 12.4674 10.5959 12.4512 10.545 12.4512C10.4942 12.4512 10.4447 12.4674 10.4037 12.4976C10.3628 12.5278 10.3326 12.5702 10.3175 12.6188Z" fill="#00E5E5" transform="matrix(1.0178890228271484, 0, 0, 1.0178890228271484, -1.952212187461555e-7, -4.077521402283104)"/>
+</svg>

+ 108 - 0
api/core/model_runtime/model_providers/volcengine_maas/client.py

@@ -0,0 +1,108 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
+
+
class MaaSClient(MaasService):
    """Client wrapper around the Volcengine MaaS SDK service.

    Extends the raw ``MaasService`` with endpoint-id bookkeeping,
    construction from a Dify credentials dict, and conversion of Dify
    prompt messages into the MaaS wire format.
    """

    def __init__(self, host: str, region: str):
        # The endpoint id is not part of the SDK base constructor; it is
        # injected afterwards via set_endpoint_id().
        self.endpoint_id: str | None = None
        super().__init__(host, region)

    def set_endpoint_id(self, endpoint_id: str):
        self.endpoint_id = endpoint_id

    @classmethod
    def from_credential(cls, credentials: dict) -> 'MaaSClient':
        """Build a fully configured client from provider credentials.

        :param credentials: dict with keys ``api_endpoint_host``,
            ``volc_region``, ``volc_access_key_id``,
            ``volc_secret_access_key`` and ``endpoint_id``
        :raises KeyError: if any required credential is missing
        """
        host = credentials['api_endpoint_host']
        region = credentials['volc_region']
        ak = credentials['volc_access_key_id']
        sk = credentials['volc_secret_access_key']
        endpoint_id = credentials['endpoint_id']

        client = cls(host, region)
        client.set_endpoint_id(endpoint_id)
        client.set_ak(ak)
        client.set_sk(sk)
        return client

    def chat(self, params: dict, messages: list[PromptMessage], stream: bool = False) -> Generator | dict:
        """Invoke the chat endpoint of the configured MaaS endpoint.

        :param params: model parameters, forwarded as ``parameters``
        :param messages: Dify prompt messages, converted before sending
        :param stream: when True, use the SDK streaming chat API
        :return: response dict, or a generator of chunks when streaming
        """
        req = {
            'parameters': params,
            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
        }
        if not stream:
            return super().chat(
                self.endpoint_id,
                req,
            )
        return super().stream_chat(
            self.endpoint_id,
            req,
        )

    def embeddings(self, texts: list[str]) -> dict:
        """Invoke the embeddings endpoint for a batch of texts."""
        req = {
            'input': texts
        }
        return super().embeddings(self.endpoint_id, req)

    @staticmethod
    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
        """Convert a Dify prompt message into a MaaS message dict.

        :raises ValueError: for unsupported message or content types
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": ChatRole.USER,
                                "content": message.content}
            else:
                content = []
                for message_content in message.content:
                    # NOTE(review): multi-part user content only supports
                    # images here; plain-text parts are rejected outright.
                    if message_content.type == PromptMessageContentType.TEXT:
                        raise ValueError(
                            'Content object type only support image_url')
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(
                            ImagePromptMessageContent, message_content)
                        # Strip a data-URI prefix so only raw base64 remains.
                        image_data = re.sub(
                            r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        content.append({
                            'type': 'image_url',
                            'image_url': {
                                'url': '',
                                'image_bytes': image_data,
                                'detail': message_content.detail,
                            }
                        })

                message_dict = {'role': ChatRole.USER, 'content': content}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {'role': ChatRole.ASSISTANT,
                            'content': message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {'role': ChatRole.SYSTEM,
                            'content': message.content}
        else:
            raise ValueError(f"Got unknown PromptMessage type {message}")

        return message_dict

    @staticmethod
    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
        """Run *fn*, translating ``MaasException`` into provider errors.

        Chains with ``from e`` so the original SDK traceback is preserved.
        """
        try:
            resp = fn()
        except MaasException as e:
            raise wrap_error(e) from e

        return resp

+ 156 - 0
api/core/model_runtime/model_providers/volcengine_maas/errors.py

@@ -0,0 +1,156 @@
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
# One MaasException subclass per Volcengine MaaS error code, so that the
# generic SDK exception can be narrowed to a specific type by wrap_error().
# The category dicts defined below group these into auth / bad-request /
# rate-limit / server-unavailable / connection families.
class ClientSDKRequestError(MaasException):
    pass


class SignatureDoesNotMatch(MaasException):
    pass


class RequestTimeout(MaasException):
    pass


class ServiceConnectionTimeout(MaasException):
    pass


class MissingAuthenticationHeader(MaasException):
    pass


class AuthenticationHeaderIsInvalid(MaasException):
    pass


class InternalServiceError(MaasException):
    pass


class MissingParameter(MaasException):
    pass


class InvalidParameter(MaasException):
    pass


class AuthenticationExpire(MaasException):
    pass


class EndpointIsInvalid(MaasException):
    pass


class EndpointIsNotEnable(MaasException):
    pass


class ModelNotSupportStreamMode(MaasException):
    pass


class ReqTextExistRisk(MaasException):
    pass


class RespTextExistRisk(MaasException):
    pass


class EndpointRateLimitExceeded(MaasException):
    pass


class ServiceConnectionRefused(MaasException):
    pass


class ServiceConnectionClosed(MaasException):
    pass


class UnauthorizedUserForEndpoint(MaasException):
    pass


class InvalidEndpointWithNoURL(MaasException):
    pass


class EndpointAccountRpmRateLimitExceeded(MaasException):
    pass


class EndpointAccountTpmRateLimitExceeded(MaasException):
    pass


class ServiceResourceWaitQueueFull(MaasException):
    pass


class EndpointIsPending(MaasException):
    pass


class ServiceNotOpen(MaasException):
    pass
+
+
# Error codes that indicate an authentication/authorization failure
# (surfaced as InvokeAuthorizationError by the model runtimes).
AuthErrors = {
    'SignatureDoesNotMatch': SignatureDoesNotMatch,
    'MissingAuthenticationHeader': MissingAuthenticationHeader,
    'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid,
    'AuthenticationExpire': AuthenticationExpire,
    'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint,
}

# Error codes caused by an invalid request/configuration
# (surfaced as InvokeBadRequestError).
BadRequestErrors = {
    'MissingParameter': MissingParameter,
    'InvalidParameter': InvalidParameter,
    'EndpointIsInvalid': EndpointIsInvalid,
    'EndpointIsNotEnable': EndpointIsNotEnable,
    'ModelNotSupportStreamMode': ModelNotSupportStreamMode,
    'ReqTextExistRisk': ReqTextExistRisk,
    'RespTextExistRisk': RespTextExistRisk,
    'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL,
    'ServiceNotOpen': ServiceNotOpen,
}

# Error codes for quota/rate limiting (surfaced as InvokeRateLimitError).
RateLimitErrors = {
    'EndpointRateLimitExceeded': EndpointRateLimitExceeded,
    'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded,
    'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded,
}

# Error codes for server-side failures (surfaced as InvokeServerUnavailableError).
ServerUnavailableErrors = {
    'InternalServiceError': InternalServiceError,
    'EndpointIsPending': EndpointIsPending,
    'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull,
}

# Error codes for transport-level problems (surfaced as InvokeConnectionError).
ConnectionErrors = {
    'ClientSDKRequestError': ClientSDKRequestError,
    'RequestTimeout': RequestTimeout,
    'ServiceConnectionTimeout': ServiceConnectionTimeout,
    'ServiceConnectionRefused': ServiceConnectionRefused,
    'ServiceConnectionClosed': ServiceConnectionClosed,
}

# Flat code -> exception-class lookup used by wrap_error().
ErrorCodeMap = {
    **AuthErrors,
    **BadRequestErrors,
    **RateLimitErrors,
    **ServerUnavailableErrors,
    **ConnectionErrors,
}
+
+
def wrap_error(e: MaasException) -> Exception:
    """Narrow a generic MaasException to its code-specific subclass.

    :param e: exception raised by the MaaS SDK
    :return: a specific subclass instance when ``e.code`` is known,
        otherwise the original exception unchanged
    """
    # Single lookup instead of get()-then-get() on the same key.
    exc_cls = ErrorCodeMap.get(e.code)
    if exc_cls is not None:
        return exc_cls(e.code_n, e.code, e.message, e.req_id)
    return e

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py


+ 284 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -0,0 +1,284 @@
+import logging
+from collections.abc import Generator
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+    AuthErrors,
+    BadRequestErrors,
+    ConnectionErrors,
+    RateLimitErrors,
+    ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+logger = logging.getLogger(__name__)
+
+
class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
    """Large language model implementation for Volcengine MaaS endpoints."""

    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """
        Invoke the large language model.

        :param model: model name
        :param credentials: provider credentials (endpoint id, keys, region)
        :param model_parameters: sampling parameters (temperature, top_p, ...)
        :param tools: tool definitions (not used by this provider)
        :param stop: stop words
        :param stream: whether to stream the response
        :param user: unique user id
        :return: full result, or a chunk generator when streaming
        """
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials by issuing a minimal non-streaming chat request.

        :raises CredentialsValidateFailedError: when the ping request fails
        """
        client = MaaSClient.from_credential(credentials)
        try:
            client.chat(
                {
                    'max_new_tokens': 16,
                    'temperature': 0.7,
                    'top_p': 0.9,
                    'top_k': 15,
                },
                [UserPromptMessage(content='ping\nAnswer: ')],
            )
        except MaasException as e:
            raise CredentialsValidateFailedError(e.message) from e

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        """Approximate the prompt token count (tools are ignored)."""
        if not prompt_messages:
            return 0
        return self._num_tokens_from_messages(prompt_messages)

    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
        """
        Approximate token count by running the GPT-2 tokenizer over the
        stringified keys and values of the converted MaaS message dicts.

        :param messages: messages
        """
        maas_messages = [
            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
        num_tokens = 0
        for message in maas_messages:
            for key, value in message.items():
                num_tokens += self._get_num_tokens_by_gpt2(str(key))
                num_tokens += self._get_num_tokens_by_gpt2(str(value))

        return num_tokens

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        """Build the request parameters and dispatch the chat call."""
        client = MaaSClient.from_credential(credentials)

        # Start from the static defaults for the base model, then layer
        # credential-level overrides and per-call parameters on top.
        req_params = ModelConfigs.get(
            credentials['base_model_name'], {}).get('req_params', {}).copy()
        # Credential values may come from a form as empty strings, so keep
        # truthiness checks here.
        if credentials.get('context_size'):
            req_params['max_prompt_tokens'] = credentials.get('context_size')
        if credentials.get('max_tokens'):
            req_params['max_new_tokens'] = credentials.get('max_tokens')
        # Model parameters are either absent or real values: use explicit
        # None checks so valid falsy values (temperature=0, top_p=0,
        # penalty 0) are not silently dropped.
        if model_parameters.get('max_tokens') is not None:
            req_params['max_new_tokens'] = model_parameters['max_tokens']
        if model_parameters.get('temperature') is not None:
            req_params['temperature'] = model_parameters['temperature']
        if model_parameters.get('top_p') is not None:
            req_params['top_p'] = model_parameters['top_p']
        if model_parameters.get('top_k') is not None:
            req_params['top_k'] = model_parameters['top_k']
        if model_parameters.get('presence_penalty') is not None:
            req_params['presence_penalty'] = model_parameters['presence_penalty']
        if model_parameters.get('frequency_penalty') is not None:
            req_params['frequency_penalty'] = model_parameters['frequency_penalty']
        if stop:
            req_params['stop'] = stop

        resp = MaaSClient.wrap_exception(
            lambda: client.chat(req_params, prompt_messages, stream))
        if not stream:
            return self._handle_chat_response(model, credentials, prompt_messages, resp)
        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)

    def _handle_stream_chat_response(self, model: str, credentials: dict,
                                     prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
        """Translate streamed MaaS chunks into LLMResultChunk objects."""
        for index, r in enumerate(resp):
            choices = r['choices']
            if not choices:
                continue
            choice = choices[0]
            message = choice['message']
            usage = None
            # Usage is only present on the final chunk(s).
            if r.get('usage'):
                usage = self._calc_usage(model, credentials, r['usage'])
            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=AssistantPromptMessage(
                        content=message['content'] if message['content'] else '',
                        tool_calls=[]
                    ),
                    usage=usage,
                    finish_reason=choice.get('finish_reason'),
                ),
            )

    def _handle_chat_response(self, model: str, credentials: dict,
                              prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
        """Translate a blocking MaaS response into an LLMResult."""
        choices = resp['choices']
        if not choices:
            # NOTE(review): implicitly returns None when the API sends no
            # choices; callers appear to assume this does not happen.
            return
        choice = choices[0]
        message = choice['message']

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=message['content'] if message['content'] else '',
                tool_calls=[],
            ),
            usage=self._calc_usage(model, credentials, resp['usage']),
        )

    def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
        """Convert the provider usage dict into an LLMUsage via the base helper."""
        return self._calc_response_usage(model=model, credentials=credentials,
                                         prompt_tokens=usage['prompt_tokens'],
                                         completion_tokens=usage['completion_tokens']
                                         )

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Define the customizable model schema (parameter rules and model
        properties) for a user-configured endpoint.
        """
        max_tokens = ModelConfigs.get(
            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
        if credentials.get('max_tokens'):
            max_tokens = int(credentials.get('max_tokens'))
        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                )
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='top_k',
                type=ParameterType.INT,
                min=1,
                default=1,
                label=I18nObject(
                    zh_Hans='Top K',
                    en_US='Top K'
                )
            ),
            # Use I18nObject consistently for all labels (the original
            # mixed raw dicts and I18nObject instances).
            ParameterRule(
                name='presence_penalty',
                type=ParameterType.FLOAT,
                use_template='presence_penalty',
                label=I18nObject(
                    en_US='Presence Penalty',
                    zh_Hans='存在惩罚',
                ),
                min=-2.0,
                max=2.0,
            ),
            ParameterRule(
                name='frequency_penalty',
                type=ParameterType.FLOAT,
                use_template='frequency_penalty',
                label=I18nObject(
                    en_US='Frequency Penalty',
                    zh_Hans='频率惩罚',
                ),
                min=-2.0,
                max=2.0,
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=max_tokens,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            ),
        ]

        model_properties = ModelConfigs.get(
            credentials['base_model_name'], {}).get('model_properties', {}).copy()
        if credentials.get('mode'):
            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
        if credentials.get('context_size'):
            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                credentials.get('context_size', 4096))
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            model_properties=model_properties,
            parameter_rules=rules
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        # Wrap in list() so the values match the declared list[...] type
        # instead of dict views.
        return {
            InvokeConnectionError: list(ConnectionErrors.values()),
            InvokeServerUnavailableError: list(ServerUnavailableErrors.values()),
            InvokeRateLimitError: list(RateLimitErrors.values()),
            InvokeAuthorizationError: list(AuthErrors.values()),
            InvokeBadRequestError: list(BadRequestErrors.values()),
        }

+ 12 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -0,0 +1,12 @@
# Built-in configuration for known Volcengine MaaS base models, keyed by the
# base-model name selected in the provider credentials.
# - 'req_params': request-side limits (prompt token budget / generation cap)
#   merged into outgoing chat requests.
# - 'model_properties': feeds the AIModelEntity schema (context size and
#   completion mode) unless overridden by custom credentials.
ModelConfigs = {
    'Skylark2-pro-4k': {
        'req_params': {
            'max_prompt_tokens': 4096,
            'max_new_tokens': 4000,
        },
        'model_properties': {
            'context_size': 4096,
            'mode': 'chat',
        }
    }
}

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py


+ 132 - 0
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -0,0 +1,132 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+    AuthErrors,
+    BadRequestErrors,
+    ConnectionErrors,
+    RateLimitErrors,
+    ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for VolcengineMaaS text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        client = MaaSClient.from_credential(credentials)
        # wrap_exception presumably converts SDK MaasException into unified
        # invoke errors — confirm against client.py
        resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))

        # assumes the MaaS embeddings response carries 'total_tokens' and a
        # 'data' list of {'embedding': [...]} entries
        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=resp['total_tokens'])

        result = TextEmbeddingResult(
            model=model,
            embeddings=[v['embedding'] for v in resp['data']],
            usage=usage
        )

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate token count (GPT-2 tokenizer, not the MaaS tokenizer)
        """
        num_tokens = 0
        for text in texts:
            # use GPT2Tokenizer to get num tokens
            num_tokens += self._get_num_tokens_by_gpt2(text)
        return num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: when a probe embedding call fails
        :return:
        """
        try:
            # cheapest possible live probe: embed a single short string
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except MaasException as e:
            raise CredentialsValidateFailedError(e.message)

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: ConnectionErrors.values(),
            InvokeServerUnavailableError: ServerUnavailableErrors.values(),
            InvokeRateLimitError: RateLimitErrors.values(),
            InvokeAuthorizationError: AuthErrors.values(),
            InvokeBadRequestError: BadRequestErrors.values(),
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # started_at is presumably set by the base model wrapper before
            # _invoke runs — TODO confirm in TextEmbeddingModel
            latency=time.perf_counter() - self.started_at
        )

        return usage

+ 4 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py

@@ -0,0 +1,4 @@
+from .common import ChatRole
+from .maas import MaasException, MaasService
+
+__all__ = ['MaasService', 'ChatRole', 'MaasException']

+ 1 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py

@@ -0,0 +1 @@
+

+ 144 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py

@@ -0,0 +1,144 @@
+# coding : utf-8
+import datetime
+
+import pytz
+
+from .util import Util
+
+
class MetaData:
    """Mutable container for the intermediate values of a V4-style signature."""

    # every field starts as the empty string and is filled in during signing
    _FIELDS = ('algorithm', 'credential_scope', 'signed_headers',
               'date', 'region', 'service')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, '')

    def set_date(self, date):
        self.date = date

    def set_service(self, service):
        self.service = service

    def set_region(self, region):
        self.region = region

    def set_algorithm(self, algorithm):
        self.algorithm = algorithm

    def set_credential_scope(self, credential_scope):
        self.credential_scope = credential_scope

    def set_signed_headers(self, signed_headers):
        self.signed_headers = signed_headers
+
+
class SignResult:
    """Holds the computed signature components for one signed request."""

    # order matters: __str__ reflects the instance-dict insertion order
    _FIELDS = ('xdate', 'xCredential', 'xAlgorithm', 'xSignedHeaders',
               'xSignedQueries', 'xSignature', 'xContextSha256',
               'xSecurityToken', 'authorization')

    def __init__(self):
        for name in self._FIELDS:
            setattr(self, name, '')

    def __str__(self):
        # one "name:value" line per field, in declaration order
        return '\n'.join(f'{name}:{value}' for name, value in self.__dict__.items())
+
+
class Credentials:
    """Access-key credentials plus the service/region signing scope."""

    def __init__(self, ak, sk, service, region, session_token=''):
        self.ak, self.sk = ak, sk
        self.service, self.region = service, region
        # session_token is only sent when non-empty (temporary STS credentials)
        self.session_token = session_token

    def set_ak(self, ak):
        self.ak = ak

    def set_sk(self, sk):
        self.sk = sk

    def set_session_token(self, session_token):
        self.session_token = session_token
+
+
class Signer:
    """HMAC-SHA256 request signing in a SigV4-style scheme (date/region/service scope)."""

    @staticmethod
    def sign(request, credentials):
        """Compute the signature for ``request`` and set its Authorization header.

        Mutates the request in place: normalizes an empty path to '/', sets a
        default Content-Type for non-GET requests, and fills in the X-Date,
        X-Content-Sha256, optional X-Security-Token and Authorization headers.
        """
        if request.path == '':
            request.path = '/'
        if request.method != 'GET' and not ('Content-Type' in request.headers):
            request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'

        format_date = Signer.get_current_format_date()
        request.headers['X-Date'] = format_date
        if credentials.session_token != '':
            request.headers['X-Security-Token'] = credentials.session_token

        md = MetaData()
        md.set_algorithm('HMAC-SHA256')
        md.set_service(credentials.service)
        md.set_region(credentials.region)
        # the credential-scope date is just the YYYYMMDD prefix of X-Date
        md.set_date(format_date[:8])

        # NOTE: hashing the canonical request also records the signed-header
        # list on md, which set_credential_scope/build_auth_header_v4 rely on
        hashed_canon_req = Signer.hashed_canonical_request_v4(request, md)
        md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request']))

        signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req])
        signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
        sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
        request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials)
        return

    @staticmethod
    def hashed_canonical_request_v4(request, meta):
        """Build and SHA-256 the canonical request; records signed headers on ``meta``.

        Side effect: sets the X-Content-Sha256 header on the request.
        """
        body_hash = Util.sha256(request.body)
        request.headers['X-Content-Sha256'] = body_hash

        # only Content-Type/Content-Md5/Host and X-* headers participate in signing
        signed_headers = dict()
        for key in request.headers:
            if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
                signed_headers[key.lower()] = request.headers[key]

        # strip default ports from Host so the signature matches the server's view
        if 'host' in signed_headers:
            v = signed_headers['host']
            if v.find(':') != -1:
                split = v.split(':')
                port = split[1]
                if str(port) == '80' or str(port) == '443':
                    signed_headers['host'] = split[0]

        # canonical header block: lowercase keys, sorted, one "key:value\n" each
        signed_str = ''
        for key in sorted(signed_headers.keys()):
            signed_str += key + ':' + signed_headers[key] + '\n'

        meta.set_signed_headers(';'.join(sorted(signed_headers.keys())))

        canonical_request = '\n'.join(
            [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str,
             meta.signed_headers, body_hash])

        return Util.sha256(canonical_request)

    @staticmethod
    def get_signing_secret_key_v4(sk, date, region, service):
        """Derive the signing key: HMAC chain over date -> region -> service -> 'request'."""
        date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date)
        region = Util.hmac_sha256(date, region)
        service = Util.hmac_sha256(region, service)
        return Util.hmac_sha256(service, 'request')

    @staticmethod
    def build_auth_header_v4(signature, meta, credentials):
        """Format the Authorization header value from the signature and metadata."""
        credential = credentials.ak + '/' + meta.credential_scope
        return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature

    @staticmethod
    def get_current_format_date():
        """Current UTC time in ISO-basic form, e.g. 20240101T120000Z.

        NOTE(review): pytz is only used for UTC here; datetime.timezone.utc
        would be an equivalent stdlib choice.
        """
        return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ")

+ 207 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py

@@ -0,0 +1,207 @@
+import json
+from collections import OrderedDict
+from urllib.parse import urlencode
+
+import requests
+
+from .auth import Signer
+
+VERSION = 'v1.0.137'
+
+
class Service:
    """Generic signed HTTP client for a Volcengine service.

    Holds a ServiceInfo (host, credentials, timeouts) and a mapping of api
    name -> ApiInfo, and issues signed GET/POST/JSON requests via a shared
    requests session.
    """

    def __init__(self, service_info, api_info):
        self.service_info = service_info
        self.api_info = api_info
        # one pooled session for all calls made through this service
        self.session = requests.session()

    def set_ak(self, ak):
        self.service_info.credentials.set_ak(ak)

    def set_sk(self, sk):
        self.service_info.credentials.set_sk(sk)

    def set_session_token(self, session_token):
        self.service_info.credentials.set_session_token(session_token)

    def set_host(self, host):
        self.service_info.host = host

    def set_scheme(self, scheme):
        self.service_info.scheme = scheme

    def get(self, api, params, doseq=0):
        """Signed GET; returns the response text, raises on any non-200 status.

        NOTE(review): raises bare Exception with the response body; callers
        cannot distinguish error kinds.
        """
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]

        r = self.prepare_request(api_info, params, doseq)

        Signer.sign(r, self.service_info.credentials)

        url = r.build(doseq)
        resp = self.session.get(url, headers=r.headers,
                                timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return resp.text
        else:
            raise Exception(resp.text)

    def post(self, api, params, form):
        """Signed form POST; returns the response text.

        NOTE(review): the signature is computed over urlencode(r.form, True)
        while requests re-encodes data=r.form itself — these must encode
        identically for the signature to verify.
        """
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers['Content-Type'] = 'application/x-www-form-urlencoded'
        r.form = self.merge(api_info.form, form)
        r.body = urlencode(r.form, True)
        Signer.sign(r, self.service_info.credentials)

        url = r.build()

        resp = self.session.post(url, headers=r.headers, data=r.form,
                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return resp.text
        else:
            raise Exception(resp.text)

    def json(self, api, params, body):
        """Signed JSON POST; ``body`` is the already-serialized JSON payload.

        Returns the response re-serialized via json.dumps (normalized JSON text).
        """
        if not (api in self.api_info):
            raise Exception("no such api")
        api_info = self.api_info[api]
        r = self.prepare_request(api_info, params)
        r.headers['Content-Type'] = 'application/json'
        r.body = body

        Signer.sign(r, self.service_info.credentials)

        url = r.build()
        resp = self.session.post(url, headers=r.headers, data=r.body,
                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
        if resp.status_code == 200:
            return json.dumps(resp.json())
        else:
            raise Exception(resp.text.encode("utf-8"))

    def put(self, url, file_path, headers):
        """Unsigned PUT of a file; returns (ok, response_bytes)."""
        with open(file_path, 'rb') as f:
            resp = self.session.put(url, headers=headers, data=f)
            if resp.status_code == 200:
                return True, resp.text.encode("utf-8")
            else:
                return False, resp.text.encode("utf-8")

    def put_data(self, url, data, headers):
        """Unsigned PUT of in-memory data; returns (ok, response_bytes)."""
        resp = self.session.put(url, headers=headers, data=data)
        if resp.status_code == 200:
            return True, resp.text.encode("utf-8")
        else:
            return False, resp.text.encode("utf-8")

    def prepare_request(self, api_info, params, doseq=0):
        """Build a Request with merged headers/query for one API call.

        NOTE(review): mutates ``params`` in place (scalars stringified, lists
        comma-joined unless doseq) before merging into the query.
        """
        for key in params:
            if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool:
                params[key] = str(params[key])
            elif type(params[key]) == list:
                if not doseq:
                    params[key] = ','.join(params[key])

        connection_timeout = self.service_info.connection_timeout
        socket_timeout = self.service_info.socket_timeout

        r = Request()
        r.set_schema(self.service_info.scheme)
        r.set_method(api_info.method)
        r.set_connection_timeout(connection_timeout)
        r.set_socket_timeout(socket_timeout)

        headers = self.merge(api_info.header, self.service_info.header)
        headers['Host'] = self.service_info.host
        headers['User-Agent'] = 'volc-sdk-python/' + VERSION
        r.set_headers(headers)

        query = self.merge(api_info.query, params)
        r.set_query(query)

        r.set_host(self.service_info.host)
        r.set_path(api_info.path)

        return r

    @staticmethod
    def merge(param1, param2):
        """Merge two mappings into an OrderedDict; param2 wins on key clashes."""
        od = OrderedDict()
        for key in param1:
            od[key] = param1[key]

        for key in param2:
            od[key] = param2[key]

        return od
+
+
class Request:
    """Mutable HTTP request under construction, consumed by the signer."""

    def __init__(self):
        # string-valued parts default to empty
        for name in ('schema', 'method', 'host', 'path', 'body'):
            setattr(self, name, '')
        # ordered so signed header/query serialization is deterministic
        self.headers = OrderedDict()
        self.query = OrderedDict()
        self.form = {}
        self.connection_timeout = 0
        self.socket_timeout = 0

    def set_schema(self, schema):
        self.schema = schema

    def set_method(self, method):
        self.method = method

    def set_host(self, host):
        self.host = host

    def set_path(self, path):
        self.path = path

    def set_headers(self, headers):
        self.headers = headers

    def set_query(self, query):
        self.query = query

    def set_body(self, body):
        self.body = body

    def set_connection_timeout(self, connection_timeout):
        self.connection_timeout = connection_timeout

    def set_socket_timeout(self, socket_timeout):
        self.socket_timeout = socket_timeout

    def build(self, doseq=0):
        # full URL: scheme://host/path?query (query encoded from the mapping)
        return f'{self.schema}://{self.host}{self.path}?{urlencode(self.query, doseq)}'
+
+
class ServiceInfo:
    """Connection settings for one Volcengine service endpoint."""

    def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'):
        # endpoint host, default headers and credential object
        self.host, self.header, self.credentials = host, header, credentials
        # (connect, read) timeouts in seconds
        self.connection_timeout, self.socket_timeout = connection_timeout, socket_timeout
        self.scheme = scheme
+
+
class ApiInfo:
    """Static description of one API: HTTP method, path and default parameters."""

    def __init__(self, method, path, query, form, header):
        self.method, self.path = method, path
        # default query params, form fields and headers merged into each call
        self.query, self.form, self.header = query, form, header

    def __str__(self):
        return f'method: {self.method}, path: {self.path}'

+ 43 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py

@@ -0,0 +1,43 @@
+import hashlib
+import hmac
+from functools import reduce
+from urllib.parse import quote
+
+
class Util:
    """Canonicalisation and digest helpers shared by the request signer."""

    @staticmethod
    def norm_uri(path):
        # percent-encode the path but keep '/' literal; spaces become %20
        return quote(path).replace('%2F', '/').replace('+', '%20')

    @staticmethod
    def norm_query(params):
        # canonical query string: keys sorted, RFC 3986 unreserved chars kept
        pairs = []
        for key in sorted(params.keys()):
            value = params[key]
            values = value if type(value) is list else [value]
            for item in values:
                pairs.append(quote(key, safe='-_.~') + '=' + quote(item, safe='-_.~'))
        return '&'.join(pairs).replace('+', '%20')

    @staticmethod
    def hmac_sha256(key, content):
        # key: bytes, content: str -> raw HMAC-SHA256 digest
        return hmac.new(key, content.encode('utf-8'), hashlib.sha256).digest()

    @staticmethod
    def sha256(content):
        # accepts str or bytes; returns hex digest
        data = content.encode('utf-8') if isinstance(content, str) else content
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def to_hex(content):
        # lowercase hex string of an iterable of byte values
        return ''.join(format(byte, '02x') for byte in content)

+ 79 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py

@@ -0,0 +1,79 @@
+import json
+import random
+from datetime import datetime
+
+
class ChatRole:
    # Message role constants used in MaaS chat requests and responses.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
+
+
class _Dict(dict):
    """dict that also exposes its keys as attributes; absent keys read as None."""

    def __setattr__(self, key, value):
        self[key] = value

    def __getattr__(self, key):
        # falls through to __missing__ for absent keys
        return self[key]

    def __missing__(self, key):
        return None


def dict_to_object(dict_obj):
    """Recursively convert parsed JSON (nested dicts/lists) into _Dict objects."""
    if isinstance(dict_obj, list):
        return [dict_to_object(item) for item in dict_obj]

    if isinstance(dict_obj, dict):
        converted = _Dict()
        for key, value in dict_obj.items():
            converted[key] = dict_to_object(value)
        return converted

    return dict_obj


def json_to_object(json_str, req_id=None):
    """Parse a JSON string into _Dict objects, tagging non-empty dicts with req_id."""
    obj = dict_to_object(json.loads(json_str))
    if req_id and isinstance(obj, dict) and obj:
        obj["req_id"] = req_id
    return obj
+
+
def gen_req_id():
    """Generate a request id: 14-char timestamp plus 20 uppercase hex chars of randomness."""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    suffix = format(random.randint(0, 2 ** 64 - 1), "020X")
    return timestamp + suffix
+
+
class SSEDecoder:
    """Minimal server-sent-events decoder yielding the payload of 'data:' fields."""

    def __init__(self, source):
        # source: iterable of byte chunks (e.g. a streaming HTTP response)
        self.source = source

    def _read(self):
        # Re-assemble the chunk stream into complete events; an event ends
        # with a blank line (any of the SSE terminator byte pairs).
        buffered = b''
        for chunk in self.source:
            for piece in chunk.splitlines(True):
                buffered += piece
                if buffered.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
                    yield buffered
                    buffered = b''
        # flush whatever is left when the stream closes mid-event
        if buffered:
            yield buffered

    def next(self):
        for event in self._read():
            for line in event.splitlines():
                # lines starting with ':' are SSE comments
                if line.startswith(b':'):
                    continue

                field, _, value = line.partition(b':')

                if field == b'data' and value:
                    yield value

+ 213 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py

@@ -0,0 +1,213 @@
+import copy
+import json
+from collections.abc import Iterator
+
+from .base.auth import Credentials, Signer
+from .base.service import ApiInfo, Service, ServiceInfo
+from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object
+
+
class MaasService(Service):
    """Client for the Volcengine MaaS endpoint API (chat and embeddings).

    Requests are authenticated either with an API key (sent as a Bearer
    token once set via set_apikey) or by AK/SK request signing inherited
    from the Service base class.
    """

    def __init__(self, host, region, connection_timeout=60, socket_timeout=60):
        """
        :param host: API host, e.g. maas-api.ml-platform-cn-beijing.volces.com
        :param region: Volcengine region used in the signing scope
        :param connection_timeout: seconds to wait for the TCP connect
        :param socket_timeout: seconds to wait for each read
        """
        service_info = self.get_service_info(
            host, region, connection_timeout, socket_timeout
        )
        self._apikey = None
        api_info = self.get_api_info()
        super().__init__(service_info, api_info)

    def set_apikey(self, apikey):
        # When set, calls use "Authorization: Bearer <key>" instead of AK/SK signing.
        self._apikey = apikey

    @staticmethod
    def get_service_info(host, region, connection_timeout, socket_timeout):
        # AK/SK start empty; they are filled in later via set_ak()/set_sk().
        service_info = ServiceInfo(
            host,
            {"Accept": "application/json"},
            Credentials("", "", "ml_maas", region),
            connection_timeout,
            socket_timeout,
            "https",
        )
        return service_info

    @staticmethod
    def get_api_info():
        # Path templates are formatted with the concrete endpoint id per call.
        api_info = {
            "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}),
            "embeddings": ApiInfo(
                "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}
            ),
        }
        return api_info

    def chat(self, endpoint_id, req):
        """Blocking chat completion; returns the parsed response object."""
        req["stream"] = False
        return self._request(endpoint_id, "chat", req)

    def stream_chat(self, endpoint_id, req):
        """Streaming chat completion.

        :return: iterator of parsed SSE events; stops at the [DONE] sentinel
        :raises MaasException: for error events in the stream or request failures
        """
        req_id = gen_req_id()
        self._validate("chat", req_id)
        apikey = self._apikey

        try:
            req["stream"] = True
            res = self._call(
                endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True
            )

            decoder = SSEDecoder(res)

            def iter_fn():
                for data in decoder.next():
                    if data == b"[DONE]":
                        return

                    res = json_to_object(
                        str(data, encoding="utf-8"), req_id=req_id)

                    # surface server-side error events as exceptions
                    if res.error is not None and res.error.code_n != 0:
                        raise MaasException(
                            res.error.code_n,
                            res.error.code,
                            res.error.message,
                            req_id,
                        )
                    yield res

            return iter_fn()
        except MaasException:
            raise
        except Exception as e:
            # fix: include req_id so client-side failures can be correlated
            raise new_client_sdk_request_error(str(e), req_id)

    def embeddings(self, endpoint_id, req):
        """Blocking embeddings request; returns the parsed response object."""
        return self._request(endpoint_id, "embeddings", req)

    def _request(self, endpoint_id, api, req, params=None):
        """POST ``req`` as JSON to ``api`` on ``endpoint_id`` and parse the result.

        ``params`` previously defaulted to a shared mutable ``{}``; a fresh
        dict is now created per call so no state can leak between requests
        (prepare_request mutates its params argument in place).
        """
        if params is None:
            params = {}
        req_id = gen_req_id()

        self._validate(api, req_id)

        apikey = self._apikey

        try:
            res = self._call(endpoint_id, api, req_id, params,
                             json.dumps(req).encode("utf-8"), apikey)
            resp = dict_to_object(res.json())
            if resp and isinstance(resp, dict):
                resp["req_id"] = req_id
            return resp

        except MaasException:
            raise
        except Exception as e:
            raise new_client_sdk_request_error(str(e), req_id)

    def _validate(self, api, req_id):
        """Ensure either an API key or complete AK/SK credentials exist, and ``api`` is known."""
        credentials_exist = (
            self.service_info.credentials is not None and
            self.service_info.credentials.sk is not None and
            self.service_info.credentials.ak is not None
        )

        if not self._apikey and not credentials_exist:
            raise new_client_sdk_request_error("no valid credential", req_id)

        if api not in self.api_info:
            raise new_client_sdk_request_error("no such api", req_id)

    def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
        """Send one authorized HTTP call and return the raw response.

        :raises MaasException: for non-200 responses (parsed from the error
            body when possible)
        """
        # deepcopy so formatting the path template never mutates the shared ApiInfo
        api_info = copy.deepcopy(self.api_info[api])
        api_info.path = api_info.path.format(endpoint_id=endpoint_id)

        r = self.prepare_request(api_info, params)
        r.headers["x-tt-logid"] = req_id
        r.headers["Content-Type"] = "application/json"
        r.body = body

        if apikey is None:
            Signer.sign(r, self.service_info.credentials)
        else:
            r.headers["Authorization"] = "Bearer " + apikey

        url = r.build()
        res = self.session.post(
            url,
            headers=r.headers,
            data=r.body,
            timeout=(
                self.service_info.connection_timeout,
                self.service_info.socket_timeout,
            ),
            stream=stream,
        )

        if res.status_code != 200:
            raw = res.text.encode()
            res.close()
            try:
                resp = json_to_object(
                    str(raw, encoding="utf-8"), req_id=req_id)
            except Exception:
                # non-JSON error body: wrap the raw bytes
                raise new_client_sdk_request_error(raw, req_id)

            if resp.error:
                raise MaasException(
                    resp.error.code_n, resp.error.code, resp.error.message, req_id
                )
            else:
                raise new_client_sdk_request_error(resp, req_id)

        return res
+
+
class MaasException(Exception):
    """Error raised by the MaaS SDK, carrying the service error code and request id."""

    def __init__(self, code_n, code, message, req_id):
        self.code_n = code_n      # numeric error code
        self.code = code          # symbolic error code
        self.message = message    # human-readable description
        self.req_id = req_id      # request id for correlation

    def __str__(self):
        return (
            "Detailed exception information is listed below.\n"
            f"req_id: {self.req_id}\n"
            f"code_n: {self.code_n}\n"
            f"code: {self.code}\n"
            f"message: {self.message}"
        )


def new_client_sdk_request_error(raw, req_id=""):
    """Build the generic client-side request error (fixed code 1709701)."""
    return MaasException(1709701, "ClientSDKRequestError", f"MaaS SDK request error: {raw}", req_id)
+
+
class BinaryResponseContent:
    """Wrapper around a streaming binary response body (iterable of byte chunks)."""

    def __init__(self, response, request_id) -> None:
        # response: iterable of byte chunks; request_id: used in error reporting
        self.response = response
        self.request_id = request_id

    def stream_to_file(
            self,
            file: str
    ) -> None:
        """Write the streamed body to ``file``; raise if the body is an error payload.

        If the first chunk's repr contains '"error":', all chunks are collected
        as an error body instead of being written, then parsed and raised.

        NOTE(review): is_first is never set to False, so the error check runs
        on every chunk, not only the first — confirm whether that is intended.
        """
        is_first = True
        error_bytes = b''
        with open(file, mode="wb") as f:
            for data in self.response:
                # str(data) is the bytes repr, e.g. b'{"error": ...}'
                if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)):
                    error_bytes += data
                else:
                    f.write(data)

        if len(error_bytes) > 0:
            resp = json_to_object(
                str(error_bytes, encoding="utf-8"), req_id=self.request_id)
            raise MaasException(
                resp.error.code_n, resp.error.code, resp.error.message, self.request_id
            )

    def iter_bytes(self) -> Iterator[bytes]:
        """Yield the raw byte chunks as received."""
        yield from self.response

+ 10 - 0
api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py

@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
class VolcengineMaaSProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """Intentionally a no-op.

        This provider only supports customizable models (see
        volcengine_maas.yaml: configurate_methods), so credentials are
        validated per model by each model class's validate_credentials.
        """
        pass

+ 151 - 0
api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml

@@ -0,0 +1,151 @@
+provider: volcengine_maas
+label:
+  en_US: Volcengine
+description:
+  en_US: Volcengine MaaS models.
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+  zh_Hans: icon_l_zh.svg
+background: "#F9FAFB"
+help:
+  title:
+    en_US: Get your Access Key and Secret Access Key from Volcengine Console
+  url:
+    en_US: https://console.volcengine.com/iam/keymanage/
+supported_model_types:
+  - llm
+  - text-embedding
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your Model Name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: volc_access_key_id
+      required: true
+      label:
+        en_US: Access Key
+        zh_Hans: Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Access Key
+        zh_Hans: 输入您的 Access Key
+    - variable: volc_secret_access_key
+      required: true
+      label:
+        en_US: Secret Access Key
+        zh_Hans: Secret Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Secret Access Key
+        zh_Hans: 输入您的 Secret Access Key
+    - variable: volc_region
+      required: true
+      label:
+        en_US: Volcengine Region
+        zh_Hans: 火山引擎地区
+      type: text-input
+      default: cn-beijing
+      placeholder:
+        en_US: Enter Volcengine Region
+        zh_Hans: 输入火山引擎地域
+    - variable: api_endpoint_host
+      required: true
+      label:
+        en_US: API Endpoint Host
+        zh_Hans: API Endpoint Host
+      type: text-input
+      default: maas-api.ml-platform-cn-beijing.volces.com
+      placeholder:
+        en_US: Enter your API Endpoint Host
+        zh_Hans: 输入 API Endpoint Host
+    - variable: endpoint_id
+      required: true
+      label:
+        en_US: Endpoint ID
+        zh_Hans: Endpoint ID
+      type: text-input
+      placeholder:
+        en_US: Enter your Endpoint ID
+        zh_Hans: 输入您的 Endpoint ID
+    - variable: base_model_name
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Base Model
+        zh_Hans: 基础模型
+      type: select
+      required: true
+      options:
+        - label:
+            en_US: Skylark2-pro-4k
+          value: Skylark2-pro-4k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Custom
+            zh_Hans: 自定义
+          value: Custom
+    - variable: mode
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 模型类型
+        en_US: Completion Mode
+      type: select
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select Completion Mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model Context Size
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 输入您的模型上下文长度
+        en_US: Enter your Model Context Size
+    - variable: max_tokens
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper Bound for Max Tokens
+      default: '4096'
+      type: text-input
+      placeholder:
+        zh_Hans: 输入您的模型最大 token 上限
+        en_US: Enter your model Upper Bound for Max Tokens

+ 7 - 1
api/tests/integration_tests/.env.example

@@ -73,4 +73,10 @@ MOCK_SWITCH=false
 
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTION_API_KEY=
+CODE_EXECUTION_API_KEY=
+
+# Volcengine MaaS Credentials
+VOLC_API_KEY=
+VOLC_SECRET_KEY=
+VOLC_MODEL_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=

+ 0 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py


+ 81 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py

@@ -0,0 +1,81 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
+    VolcengineMaaSTextEmbeddingModel,
+)
+
+
+def test_validate_credentials():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='NOT IMPORTANT',
+            credentials={
+                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+                'volc_region': 'cn-beijing',
+                'volc_access_key_id': 'INVALID',
+                'volc_secret_access_key': 'INVALID',
+                'endpoint_id': 'INVALID',
+            }
+        )
+
+    model.validate_credentials(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+    )
+
+
+def test_invoke_model():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    result = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens > 0
+
+
+def test_get_num_tokens():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 2

+ 131 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py

@@ -0,0 +1,131 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='NOT IMPORTANT',
+            credentials={
+                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+                'volc_region': 'cn-beijing',
+                'volc_access_key_id': 'INVALID',
+                'volc_secret_access_key': 'INVALID',
+                'endpoint_id': 'INVALID',
+            }
+        )
+
+    model.validate_credentials(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+        }
+    )
+
+
+def test_invoke_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.7,
+            'top_p': 1.0,
+            'top_k': 1,
+        },
+        stop=['you'],
+        user="abc-123",
+        stream=False
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_invoke_stream_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.7,
+            'top_p': 1.0,
+            'top_k': 1,
+        },
+        stop=['you'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(
+            chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.get_num_tokens(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        tools=[]
+    )
+
+    assert isinstance(response, int)
+    assert response == 6

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików