Kaynağa Gözat

feat: claude api support (#572)

John Wang 1 yıl önce
ebeveyn
işleme
7599f79a17
52 değiştirilmiş dosya ile 637 ekleme ve 349 silme
  1. 33 1
      api/commands.py
  2. 6 0
      api/config.py
  3. 2 2
      api/controllers/console/app/audio.py
  4. 6 6
      api/controllers/console/app/completion.py
  5. 1 1
      api/controllers/console/app/error.py
  6. 4 4
      api/controllers/console/app/generator.py
  7. 6 6
      api/controllers/console/app/message.py
  8. 4 4
      api/controllers/console/datasets/datasets_document.py
  9. 2 2
      api/controllers/console/datasets/hit_testing.py
  10. 2 2
      api/controllers/console/explore/audio.py
  11. 6 6
      api/controllers/console/explore/completion.py
  12. 6 6
      api/controllers/console/explore/message.py
  13. 21 6
      api/controllers/console/workspace/providers.py
  14. 2 2
      api/controllers/service_api/app/audio.py
  15. 6 6
      api/controllers/service_api/app/completion.py
  16. 2 2
      api/controllers/service_api/dataset/document.py
  17. 2 2
      api/controllers/web/audio.py
  18. 6 6
      api/controllers/web/completion.py
  19. 6 6
      api/controllers/web/message.py
  20. 8 0
      api/core/__init__.py
  21. 1 1
      api/core/callback_handler/llm_callback_handler.py
  22. 36 20
      api/core/completion.py
  23. 19 2
      api/core/constant/llm_constant.py
  24. 3 2
      api/core/conversation_message_task.py
  25. 2 0
      api/core/embedding/cached_embedding.py
  26. 15 1
      api/core/generator/llm_generator.py
  27. 1 1
      api/core/index/index.py
  28. 3 0
      api/core/llm/error.py
  29. 44 28
      api/core/llm/llm_builder.py
  30. 127 12
      api/core/llm/provider/anthropic_provider.py
  31. 2 3
      api/core/llm/provider/azure_provider.py
  32. 25 29
      api/core/llm/provider/base.py
  33. 4 4
      api/core/llm/provider/llm_provider_service.py
  34. 11 0
      api/core/llm/provider/openai_provider.py
  35. 19 36
      api/core/llm/streamable_azure_chat_open_ai.py
  36. 5 11
      api/core/llm/streamable_azure_open_ai.py
  37. 39 0
      api/core/llm/streamable_chat_anthropic.py
  38. 17 34
      api/core/llm/streamable_chat_open_ai.py
  39. 5 11
      api/core/llm/streamable_open_ai.py
  40. 3 2
      api/core/llm/whisper.py
  41. 27 0
      api/core/llm/wrappers/anthropic_wrapper.py
  42. 1 25
      api/core/llm/wrappers/openai_wrapper.py
  43. 5 5
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  44. 2 2
      api/core/tool/dataset_index_tool.py
  45. 16 1
      api/events/event_handlers/create_provider_when_tenant_created.py
  46. 16 1
      api/events/event_handlers/create_provider_when_tenant_updated.py
  47. 2 1
      api/requirements.txt
  48. 28 4
      api/services/app_model_config_service.py
  49. 1 6
      api/services/audio_service.py
  50. 1 1
      api/services/hit_testing_service.py
  51. 24 34
      api/services/provider_service.py
  52. 2 2
      api/services/workspace_service.py

+ 33 - 1
api/commands.py

@@ -18,7 +18,8 @@ from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider
+from models.provider import Provider, ProviderName
+from services.provider_service import ProviderService
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
     click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
 
 
+@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
+def sync_anthropic_hosted_providers():
+    click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
+    count = 0
+
+    page = 1
+    while True:
+        try:
+            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for tenant in tenants:
+            try:
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
+                ProviderService.create_system_provider(
+                    tenant,
+                    ProviderName.ANTHROPIC.value,
+                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+                    True
+                )
+                count += 1
+            except Exception as e:
+                click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
     app.cli.add_command(recreate_all_dataset_indexes)
+    app.cli.add_command(sync_anthropic_hosted_providers)

+ 6 - 0
api/config.py

@@ -51,6 +51,8 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'DEFAULT_LLM_PROVIDER': 'openai',
+    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
+    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
     'TENANT_DOCUMENT_COUNT': 100
 }
 
@@ -192,6 +194,10 @@ class Config:
 
         # hosted provider credentials
         self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
+
+        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
+        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
 
         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers

+ 2 - 2
api/controllers/console/app/audio.py

@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/console/app/completion.py

@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:

+ 1 - 1
api/controllers/console/app/error.py

@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
 
 class ProviderQuotaExceededError(BaseHTTPException):
     error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
+    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
                   "Please go to Settings -> Model Provider to complete your own provider credentials."
     code = 400
 

+ 4 - 4
api/controllers/console/app/generator.py

@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
                 account.current_tenant_id,
                 args['prompt_template']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
                 args['audiences'],
                 args['hoping_to_solve']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/console/app/message.py

@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
             raise NotFound("Message not found")
         except ConversationNotExistsError:
             raise NotFound("Conversation not found")
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 4 - 4
api/controllers/console/datasets/datasets_document.py

@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
                 document_data=args,
                 account=current_user
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 2 - 2
api/controllers/console/datasets/hit_testing.py

@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 2 - 2
api/controllers/console/explore/audio.py

@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/console/explore/completion.py

@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/console/explore/message.py

@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 21 - 6
api/controllers/console/workspace/providers.py

@@ -3,6 +3,7 @@ import base64
 import json
 import logging
 
+from flask import current_app
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, abort
 from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
         plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
         """
 
-        ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+        ProviderService.init_supported_provider(current_user.current_tenant)
         providers = Provider.query.filter_by(tenant_id=tenant_id).all()
 
         provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
                        'quota_used': p.quota_used
                    } if p.provider_type == ProviderType.SYSTEM.value else {}),
                 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name))
+                                                                ProviderName(p.provider_name), only_custom=True)
+                if p.provider_type == ProviderType.CUSTOM.value else None
             }
             for p in providers
         ]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)
 
-        if provider_model.is_valid:
+        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
             other_providers = db.session.query(Provider).filter(
                 Provider.tenant_id == tenant.id,
+                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
                 Provider.provider_name != provider,
                 Provider.provider_type == ProviderType.CUSTOM.value
             ).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
 
         db.session.commit()
 
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
 
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()
 
         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
 
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
             provider_model.is_valid = args['is_enabled']
             db.session.commit()
         elif not provider_model:
-            ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+            if provider == ProviderName.OPENAI.value:
+                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
+            elif provider == ProviderName.ANTHROPIC.value:
+                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
+            else:
+                quota_limit = 0
+
+            ProviderService.create_system_provider(
+                tenant,
+                provider,
+                quota_limit,
+                args['is_enabled']
+            )
         else:
             abort(403)
 

+ 2 - 2
api/controllers/service_api/app/audio.py

@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/service_api/app/completion.py

@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:

+ 2 - 2
api/controllers/service_api/dataset/document.py

@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
                 dataset_process_rule=dataset.latest_process_rule,
                 created_from='api'
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

+ 2 - 2
api/controllers/web/audio.py

@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/web/completion.py

@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:

+ 6 - 6
api/controllers/web/message.py

@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

+ 8 - 0
api/core/__init__.py

@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
     api_key: str
 
 
+class HostedAnthropicCredential(BaseModel):
+    api_key: str
+
+
 class HostedLLMCredentials(BaseModel):
     openai: Optional[HostedOpenAICredential] = None
+    anthropic: Optional[HostedAnthropicCredential] = None
 
 
 hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+    if app.config.get("ANTHROPIC_API_KEY"):
+        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

+ 1 - 1
api/core/callback_handler/llm_callback_handler.py

@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any

+ 36 - 20
api/core/completion.py

@@ -118,6 +118,7 @@ class Completion:
         prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
 
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode=mode
         )
@@ -138,7 +140,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+                            pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
 
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
 
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
             if pre_prompt:
                 human_message_prompt += pre_prompt
 
-            query_prompt = "\nHuman: {{query}}\nAI: "
+            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
 
             if memory:
                 # append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )
 
-                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
-                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
-                              - memory.llm.max_tokens - curr_message_tokens
+                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+                model_name = model['name']
+                max_tokens = model.get("completion_params").get('max_tokens')
+                rest_tokens = llm_constant.max_context_token_length[model_name] \
+                              - max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
 
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
                 #         if histories_param not in human_inputs:
                 #             human_inputs[histories_param] = '{{' + histories_param + '}}'
 
-                human_message_prompt += "\n\n" + histories
+                human_message_prompt += "\n\n" if human_message_prompt else ""
+                human_message_prompt += "Here is the chat histories between human and assistant, " \
+                                        "inside <histories></histories> XML tags.\n\n<histories>"
+                human_message_prompt += histories + "</histories>"
 
             human_message_prompt += query_prompt
 
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
             model=app_model_config.model_dict
         )
 
-        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
-        max_tokens = llm.max_tokens
+        model_name = app_model_config.model_dict.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
 
         # get prompt without memory and context
         prompt, _ = cls.get_main_llm_prompt(
             mode=mode,
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                               prompt: Union[str, List[BaseMessage]], mode: str):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
-        max_tokens = final_llm.max_tokens
+        model_name = model.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = model.get("completion_params").get('max_tokens')
 
         if mode == 'completion' and isinstance(final_llm, BaseLLM):
             prompt_tokens = final_llm.get_num_tokens(prompt)
         else:
-            prompt_tokens = final_llm.get_messages_tokens(prompt)
+            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+        llm = LLMBuilder.to_llm_from_model(
             tenant_id=app.tenant_id,
-            model_name='gpt-3.5-turbo',
+            model=app_model_config.model_dict,
             streaming=streaming
         )
 
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
         original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
             inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
 
         cls.recale_llm_max_tokens(
             final_llm=llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode='completion'
         )

+ 19 - 2
api/core/constant/llm_constant.py

@@ -1,6 +1,8 @@
 from _decimal import Decimal
 
 models = {
+    'claude-instant-1': 'anthropic',  # 100,000 tokens
+    'claude-2': 'anthropic',  # 100,000 tokens
     'gpt-4': 'openai',  # 8,192 tokens
     'gpt-4-32k': 'openai',  # 32,768 tokens
     'gpt-3.5-turbo': 'openai',  # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
     'text-curie-001': 'openai',  # 2,049 tokens
     'text-babbage-001': 'openai',  # 2,049 tokens
     'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai'  # 8191 tokens, 1536 dimensions
+    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
+    'whisper-1': 'openai'
 }
 
 max_context_token_length = {
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
     'gpt-4': 8192,
     'gpt-4-32k': 32768,
     'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
     'text-curie-001': 2049,
     'text-babbage-001': 2049,
     'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191
+    'text-embedding-ada-002': 8191,
 }
 
 models_by_mode = {
     'chat': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
         'gpt-3.5-turbo-16k',  # 16,384 tokens
     ],
     'completion': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
 model_currency = 'USD'
 
 model_prices = {
+    'claude-instant-1': {
+        'prompt': Decimal('0.00163'),
+        'completion': Decimal('0.00551'),
+    },
+    'claude-2': {
+        'prompt': Decimal('0.01102'),
+        'completion': Decimal('0.03268'),
+    },
     'gpt-4': {
         'prompt': Decimal('0.03'),
         'completion': Decimal('0.06'),

+ 3 - 2
api/core/conversation_message_task.py

@@ -56,7 +56,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
         self.model_dict['provider'] = provider_name
 
         override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
                 system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                 system_instruction = system_message.content
                 llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-                system_instruction_tokens = llm.get_messages_tokens([system_message])
+                system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
 
         if not self.conversation:
             self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
         if provider and provider.provider_type == ProviderType.SYSTEM.value:
             db.session.query(Provider).filter(
                 Provider.tenant_id == self.app.tenant_id,
+                Provider.provider_name == provider.provider_name,
                 Provider.quota_limit > Provider.quota_used
             ).update({'quota_used': Provider.quota_used + 1})
 

+ 2 - 0
api/core/embedding/cached_embedding.py

@@ -4,6 +4,7 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
         text_embeddings.extend(embedding_results)
         return text_embeddings
 
+    @handle_openai_exceptions
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists

+ 15 - 1
api/core/generator/llm_generator.py

@@ -23,6 +23,10 @@ class LLMGenerator:
     @classmethod
     def generate_conversation_name(cls, tenant_id: str, query, answer):
         prompt = CONVERSATION_TITLE_PROMPT
+
+        if len(query) > 2000:
+            query = query[:300] + "...[TRUNCATED]..." + query[-300:]
+
         prompt = prompt.format(query=query)
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
             if not message.answer:
                 continue
 
-            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+            if len(message.query) > 2000:
+                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
+            else:
+                query = message.query
+
+            if len(message.answer) > 2000:
+                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
+            else:
+                answer = message.answer
+
+            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
             if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                 context += message_qa_text
 

+ 1 - 1
api/core/index/index.py

@@ -17,7 +17,7 @@ class IndexBuilder:
 
             model_credentials = LLMBuilder.get_model_credentials(
                 tenant_id=dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
                 model_name='text-embedding-ada-002'
             )
 

+ 3 - 0
api/core/llm/error.py

@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
     """
     description = "Provider Token Not Init"
 
+    def __init__(self, *args, **kwargs):
+        self.description = args[0] if args else self.description
+
 
 class QuotaExceededError(Exception):
     """

+ 44 - 28
api/core/llm/llm_builder.py

@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
 from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName
 
 
 class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
 
     @classmethod
     def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id)
+        provider = cls.get_default_provider(tenant_id, model_name)
 
         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
+        llm_cls = None
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableChatOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureChatOpenAI
+            elif provider == ProviderName.ANTHROPIC.value:
+                llm_cls = StreamableChatAnthropic
         elif mode == 'completion':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureOpenAI
-        else:
-            raise ValueError(f"model name {model_name} is not supported.")
 
+        if not llm_cls:
+            raise ValueError(f"model name {model_name} is not supported.")
 
         model_kwargs = {
+            'model_name': model_name,
+            'temperature': kwargs.get('temperature', 0),
+            'max_tokens': kwargs.get('max_tokens', 256),
             'top_p': kwargs.get('top_p', 1),
             'frequency_penalty': kwargs.get('frequency_penalty', 0),
             'presence_penalty': kwargs.get('presence_penalty', 0),
+            'callbacks': kwargs.get('callbacks', None),
+            'streaming': kwargs.get('streaming', False),
         }
 
-        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+        model_kwargs.update(model_credentials)
+        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
 
-        return llm_cls(
-            model_name=model_name,
-            temperature=kwargs.get('temperature', 0),
-            max_tokens=kwargs.get('max_tokens', 256),
-            **model_extras_kwargs,
-            callbacks=kwargs.get('callbacks', None),
-            streaming=kwargs.get('streaming', False),
-            # request_timeout=None
-            **model_credentials
-        )
+        return llm_cls(**model_kwargs)
 
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@@ -118,14 +119,29 @@ class LLMBuilder:
         return provider_service.get_credentials(model_name)
 
     @classmethod
-    def get_default_provider(cls, tenant_id: str) -> str:
-        provider = BaseProvider.get_valid_provider(tenant_id)
-        if not provider:
-            raise ProviderTokenNotInitError()
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            provider_name = 'openai'
-        else:
-            provider_name = provider.provider_name
+    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+        provider_name = llm_constant.models[model_name]
+
+        if provider_name == 'openai':
+            # get the default provider (openai / azure_openai) for the tenant
+            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+            provider = None
+            if openai_provider:
+                provider = openai_provider
+            elif azure_openai_provider:
+                provider = azure_openai_provider
+
+            if not provider:
+                raise ProviderTokenNotInitError(
+                    f"No valid {provider_name} model provider credentials found. "
+                    f"Please go to Settings -> Model Provider to complete your provider credentials."
+                )
+
+            if provider.provider_type == ProviderType.SYSTEM.value:
+                provider_name = 'openai'
+            else:
+                provider_name = provider.provider_name
 
         return provider_name

+ 127 - 12
api/core/llm/provider/anthropic_provider.py

@@ -1,23 +1,138 @@
-from typing import Optional
+import json
+import logging
+from typing import Optional, Union
 
+import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
+
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
+from core.llm.provider.errors import ValidateFailedError
+from models.provider import ProviderName, ProviderType
 
 
 class AnthropicProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
+        return [
+            {
+                'id': 'claude-instant-1',
+                'name': 'claude-instant-1',
+            },
+            {
+                'id': 'claude-2',
+                'name': 'claude-2',
+            },
+        ]
 
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
+        return self.get_provider_api_key(model_id=model_id)
+
+    def get_provider_name(self):
+        return ProviderName.ANTHROPIC
+
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
-        Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
-        The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
+        Returns the provider configs.
         """
-        return {
-            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
-        }
+        try:
+            config = self.get_provider_api_key(only_custom=only_custom)
+        except:
+            config = {
+                'anthropic_api_key': ''
+            }
 
-    def get_provider_name(self):
-        return ProviderName.ANTHROPIC
+        if obfuscated:
+            if not config.get('anthropic_api_key'):
+                config = {
+                    'anthropic_api_key': ''
+                }
+
+            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+            return config
+
+        return config
+
+    def get_encrypted_token(self, config: Union[dict | str]):
+        """
+        Returns the encrypted token.
+        """
+        return json.dumps({
+            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+        })
+
+    def get_decrypted_token(self, token: str):
+        """
+        Returns the decrypted token.
+        """
+        config = json.loads(token)
+        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+        return config
+
+    def get_token_type(self):
+        return dict
+
+    def config_validate(self, config: Union[dict | str]):
+        """
+        Validates the given config.
+        """
+        # check OpenAI / Azure OpenAI credential is valid
+        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
+        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
+
+        provider = None
+        if openai_provider:
+            provider = openai_provider
+        elif azure_openai_provider:
+            provider = azure_openai_provider
+
+        if not provider:
+            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            quota_used = provider.quota_used if provider.quota_used is not None else 0
+            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
+            if quota_used >= quota_limit:
+                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
+                                          f"please configure OpenAI or Azure OpenAI provider first.")
+
+        try:
+            if not isinstance(config, dict):
+                raise ValueError('Config must be a object.')
+
+            if 'anthropic_api_key' not in config:
+                raise ValueError('anthropic_api_key must be provided.')
+
+            chat_llm = ChatAnthropic(
+                model='claude-instant-1',
+                anthropic_api_key=config['anthropic_api_key'],
+                max_tokens_to_sample=10,
+                temperature=0,
+                default_request_timeout=60
+            )
+
+            messages = [
+                HumanMessage(
+                    content="ping"
+                )
+            ]
+
+            chat_llm(messages)
+        except anthropic.APIConnectionError as ex:
+            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
+        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
+            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
+                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
+        except Exception as ex:
+            logging.exception('Anthropic config validation failed')
+            raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
+
+        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

+ 2 - 3
api/core/llm/provider/azure_provider.py

@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
     def get_provider_name(self):
         return ProviderName.AZURE_OPENAI
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the provider configs.
         """
         try:
-            config = self.get_provider_api_key()
+            config = self.get_provider_api_key(only_custom=only_custom)
         except:
             config = {
                 'openai_api_type': 'azure',
@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
         return config
 
     def get_token_type(self):
-        # TODO: change to dict when implemented
         return dict
 
     def config_validate(self, config: Union[dict | str]):

+ 25 - 29
api/core/llm/provider/base.py

@@ -2,7 +2,7 @@ import base64
 from abc import ABC, abstractmethod
 from typing import Optional, Union
 
-from core import hosted_llm_credentials
+from core.constant import llm_constant
 from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
 from extensions.ext_database import db
 from libs import rsa
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id
 
-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
+    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
         If the provider is not found or not valid, raises a ProviderTokenNotInitError.
         """
-        provider = self.get_provider(prefer_custom)
+        provider = self.get_provider(only_custom)
         if not provider:
-            raise ProviderTokenNotInitError()
+            raise ProviderTokenNotInitError(
+                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
 
         if provider.provider_type == ProviderType.SYSTEM.value:
             quota_used = provider.quota_used if provider.quota_used is not None else 0
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
         else:
             return self.get_decrypted_token(provider.encrypted_config)
 
-    def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
+    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
         """
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
         """
-        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
 
     @classmethod
-    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
+        Provider]:
         """
         Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+        If both CUSTOM and System providers exist.
         """
         query = db.session.query(Provider).filter(
             Provider.tenant_id == tenant_id
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
         if provider_name:
             query = query.filter(Provider.provider_name == provider_name)
 
-        providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+        if only_custom:
+            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
 
-        custom_provider = None
-        system_provider = None
+        providers = query.order_by(Provider.provider_type.asc()).all()
 
         for provider in providers:
             if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
-                custom_provider = provider
+                return provider
             elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
-                system_provider = provider
-
-        if custom_provider:
-            return custom_provider
-        elif system_provider:
-            return system_provider
-        else:
-            return None
+                return provider
 
-    def get_hosted_credentials(self) -> str:
-        if self.get_provider_name() != ProviderName.OPENAI:
-            raise ProviderTokenNotInitError()
+        return None
 
-        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
-            raise ProviderTokenNotInitError()
-
-        return hosted_llm_credentials.openai.api_key
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        raise ProviderTokenNotInitError(
+            f"No valid {self.get_provider_name().value} model provider credentials found. "
+            f"Please go to Settings -> Model Provider to complete your provider credentials."
+        )
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the provider configs.
         """
         try:
-            config = self.get_provider_api_key()
+            config = self.get_provider_api_key(only_custom=only_custom)
         except:
             config = ''
 

+ 4 - 4
api/core/llm/provider/llm_provider_service.py

@@ -31,11 +31,11 @@ class LLMProviderService:
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
         return self.provider.get_credentials(model_id)
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
-        return self.provider.get_provider_configs(obfuscated)
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
 
-    def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
-        return self.provider.get_provider(prefer_custom)
+    def get_provider_db_record(self) -> Optional[Provider]:
+        return self.provider.get_provider()
 
     def config_validate(self, config: Union[dict | str]):
         """

+ 11 - 0
api/core/llm/provider/openai_provider.py

@@ -4,6 +4,8 @@ from typing import Optional, Union
 import openai
 from openai.error import AuthenticationError, OpenAIError
 
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.moderation import Moderation
 from core.llm.provider.base import BaseProvider
 from core.llm.provider.errors import ValidateFailedError
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
         except Exception as ex:
             logging.exception('OpenAI config validation failed')
             raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
+
+        return hosted_llm_credentials.openai.api_key

+ 19 - 36
api/core/llm/streamable_azure_chat_open_ai.py

@@ -1,11 +1,11 @@
-from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableAzureChatOpenAI(AzureChatOpenAI):
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }
 
-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params

+ 5 - 11
api/core/llm/streamable_azure_open_ai.py

@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableAzureOpenAI(AzureOpenAI):
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             prompts: List[str],
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params

+ 39 - 0
api/core/llm/streamable_chat_anthropic.py

@@ -0,0 +1,39 @@
+from typing import List, Optional, Any, Dict
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import BaseMessage, LLMResult
+
+from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
+
+
+class StreamableChatAnthropic(ChatAnthropic):
+    """
+    Wrapper around Anthropic's large language model.
+    """
+
+    @handle_anthropic_exceptions
+    def generate(
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            *,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+    ) -> LLMResult:
+        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
+
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        params['model'] = params.get('model_name')
+        del params['model_name']
+
+        params['max_tokens_to_sample'] = params.get('max_tokens')
+        del params['max_tokens']
+
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        return params

+ 17 - 34
api/core/llm/streamable_chat_open_ai.py

@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableChatOpenAI(ChatOpenAI):
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }
 
-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params

+ 5 - 11
api/core/llm/streamable_open_ai.py

@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableOpenAI(OpenAI):
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             prompts: List[str],
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params

+ 3 - 2
api/core/llm/whisper.py

@@ -1,6 +1,7 @@
 import openai
+
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from models.provider import ProviderName
-from core.llm.error_handle_wraps import handle_llm_exceptions
 from core.llm.provider.base import BaseProvider
 
 
@@ -13,7 +14,7 @@ class Whisper:
             self.client = openai.Audio
             self.credentials = provider.get_credentials()
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def transcribe(self, file):
         return self.client.transcribe(
             model='whisper-1', 

+ 27 - 0
api/core/llm/wrappers/anthropic_wrapper.py

@@ -0,0 +1,27 @@
+import logging
+from functools import wraps
+
+import anthropic
+
+from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMBadRequestError
+
+
+def handle_anthropic_exceptions(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except anthropic.APIConnectionError as e:
+            logging.exception("Failed to connect to Anthropic API.")
+            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
+        except anthropic.RateLimitError:
+            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+        except anthropic.AuthenticationError as e:
+            raise LLMAuthorizationError(f"Anthropic: {e.message}")
+        except anthropic.BadRequestError as e:
+            raise LLMBadRequestError(f"Anthropic: {e.message}")
+        except anthropic.APIStatusError as e:
+            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
+
+    return wrapper

+ 1 - 25
api/core/llm/error_handle_wraps.py → api/core/llm/wrappers/openai_wrapper.py

@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
     LLMBadRequestError
 
 
-def handle_llm_exceptions(func):
+def handle_openai_exceptions(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
         try:
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
             raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
 
     return wrapper
-
-
-def handle_llm_exceptions_async(func):
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        try:
-            return await func(*args, **kwargs)
-        except openai.error.InvalidRequestError as e:
-            logging.exception("Invalid request to OpenAI API.")
-            raise LLMBadRequestError(str(e))
-        except openai.error.APIConnectionError as e:
-            logging.exception("Failed to connect to OpenAI API.")
-            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
-        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
-            logging.exception("OpenAI service unavailable.")
-            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
-        except openai.error.RateLimitError as e:
-            raise LLMRateLimitError(str(e))
-        except openai.error.AuthenticationError as e:
-            raise LLMAuthorizationError(str(e))
-        except openai.error.OpenAIError as e:
-            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
-    return wrapper

+ 5 - 5
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -1,7 +1,7 @@
 from typing import Any, List, Dict, Union
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
+from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
 
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
 class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
-    ai_prefix: str = "AI"
-    llm: Union[StreamableChatOpenAI | StreamableOpenAI]
+    ai_prefix: str = "Assistant"
+    llm: BaseLanguageModel
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
             return chat_messages
 
         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
 
         return chat_messages
 

+ 2 - 2
api/core/tool/dataset_index_tool.py

@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
         else:
             model_credentials = LLMBuilder.get_model_credentials(
                 tenant_id=self.dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                 model_name='text-embedding-ada-002'
             )
 
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
     async def _arun(self, tool_input: str) -> str:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=self.dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
 

+ 16 - 1
api/events/event_handlers/create_provider_when_tenant_created.py

@@ -1,4 +1,7 @@
+from flask import current_app
+
 from events.tenant_event import tenant_was_updated
+from models.provider import ProviderName
 from services.provider_service import ProviderService
 
 
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
 def handle(sender, **kwargs):
     tenant = sender
     if tenant.status == 'normal':
-        ProviderService.create_system_provider(tenant)
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.OPENAI.value,
+            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+            True
+        )
+
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.ANTHROPIC.value,
+            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+            True
+        )

+ 16 - 1
api/events/event_handlers/create_provider_when_tenant_updated.py

@@ -1,4 +1,7 @@
+from flask import current_app
+
 from events.tenant_event import tenant_was_created
+from models.provider import ProviderName
 from services.provider_service import ProviderService
 
 
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
 def handle(sender, **kwargs):
     tenant = sender
     if tenant.status == 'normal':
-        ProviderService.create_system_provider(tenant)
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.OPENAI.value,
+            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+            True
+        )
+
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.ANTHROPIC.value,
+            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+            True
+        )

+ 2 - 1
api/requirements.txt

@@ -10,7 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.209
+langchain==0.0.230
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
@@ -35,3 +35,4 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
+anthropic~=0.3.4

+ 28 - 4
api/services/app_model_config_service.py

@@ -6,6 +6,30 @@ from models.account import Account
 from services.dataset_service import DatasetService
 from core.llm.llm_builder import LLMBuilder
 
+MODEL_PROVIDERS = [
+    'openai',
+    'anthropic',
+]
+
+MODELS_BY_APP_MODE = {
+    'chat': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+    ],
+    'completion': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+        'text-davinci-003',
+    ]
+}
 
 class AppModelConfigService:
     @staticmethod
@@ -125,7 +149,7 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")
         
-        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
+        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
 
         if config["speech_to_text"]["enabled"] and provider_name != 'openai':
             raise ValueError("provider not support speech to text")
@@ -153,14 +177,14 @@ class AppModelConfigService:
             raise ValueError("model must be of object type")
 
         # model.provider
-        if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
-            raise ValueError("model.provider must be 'openai'")
+        if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
+            raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
 
         # model.name
         if 'name' not in config["model"]:
             raise ValueError("model.name is required")
 
-        if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
+        if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
             raise ValueError("model.name must be in the specified model list")
 
         # model.completion_params

+ 1 - 6
api/services/audio_service.py

@@ -27,7 +27,7 @@ class AudioService:
             message = f"Audio size larger than {FILE_SIZE} mb"
             raise AudioTooLargeServiceError(message)
         
-        provider_name = LLMBuilder.get_default_provider(tenant_id)
+        provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
         if provider_name != ProviderName.OPENAI.value:
             raise ProviderNotSupportSpeechToTextServiceError()
 
@@ -37,8 +37,3 @@ class AudioService:
         buffer.name = 'temp.mp3'
 
         return Whisper(provider_service.provider).transcribe(buffer)
-
-
-
-        
-        

+ 1 - 1
api/services/hit_testing_service.py

@@ -31,7 +31,7 @@ class HitTestingService:
 
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
 

+ 24 - 34
api/services/provider_service.py

@@ -10,50 +10,40 @@ from models.provider import *
 class ProviderService:
 
     @staticmethod
-    def init_supported_provider(tenant, edition):
+    def init_supported_provider(tenant):
         """Initialize the model provider, check whether the supported provider has a record"""
 
-        providers = Provider.query.filter_by(tenant_id=tenant.id).all()
+        need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
 
-        openai_provider_exists = False
-        azure_openai_provider_exists = False
-
-        # TODO: The cloud version needs to construct the data of the SYSTEM type
+        providers = db.session.query(Provider).filter(
+            Provider.tenant_id == tenant.id,
+            Provider.provider_type == ProviderType.CUSTOM.value,
+            Provider.provider_name.in_(need_init_provider_names)
+        ).all()
 
+        exists_provider_names = []
         for provider in providers:
-            if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
-                openai_provider_exists = True
-            if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
-                azure_openai_provider_exists = True
+            exists_provider_names.append(provider.provider_name)
 
-        # Initialize the model provider, check whether the supported provider has a record
+        not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
 
-        # Create default providers if they don't exist
-        if not openai_provider_exists:
-            openai_provider = Provider(
-                tenant_id=tenant.id,
-                provider_name=ProviderName.OPENAI.value,
-                provider_type=ProviderType.CUSTOM.value,
-                is_valid=False
-            )
-            db.session.add(openai_provider)
-
-        if not azure_openai_provider_exists:
-            azure_openai_provider = Provider(
-                tenant_id=tenant.id,
-                provider_name=ProviderName.AZURE_OPENAI.value,
-                provider_type=ProviderType.CUSTOM.value,
-                is_valid=False
-            )
-            db.session.add(azure_openai_provider)
+        if not_exists_provider_names:
+            # Initialize the model provider, check whether the supported provider has a record
+            for provider_name in not_exists_provider_names:
+                provider = Provider(
+                    tenant_id=tenant.id,
+                    provider_name=provider_name,
+                    provider_type=ProviderType.CUSTOM.value,
+                    is_valid=False
+                )
+                db.session.add(provider)
 
-        if not openai_provider_exists or not azure_openai_provider_exists:
             db.session.commit()
 
     @staticmethod
-    def get_obfuscated_api_key(tenant, provider_name: ProviderName):
+    def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
         llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
-        return llm_provider_service.get_provider_configs(obfuscated=True)
+        return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
 
     @staticmethod
     def get_token_type(tenant, provider_name: ProviderName):
@@ -73,7 +63,7 @@ class ProviderService:
         return llm_provider_service.get_encrypted_token(configs)
 
     @staticmethod
-    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
+    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
                                is_valid: bool = True):
         if current_app.config['EDITION'] != 'CLOUD':
             return
@@ -90,7 +80,7 @@ class ProviderService:
                 provider_name=provider_name,
                 provider_type=ProviderType.SYSTEM.value,
                 quota_type=ProviderQuotaType.TRIAL.value,
-                quota_limit=200,
+                quota_limit=quota_limit,
                 encrypted_config='',
                 is_valid=is_valid,
             )

+ 2 - 2
api/services/workspace_service.py

@@ -1,6 +1,6 @@
 from extensions.ext_database import db
 from models.account import Tenant
-from models.provider import Provider, ProviderType
+from models.provider import Provider, ProviderType, ProviderName
 
 
 class WorkspaceService:
@@ -33,7 +33,7 @@ class WorkspaceService:
                 if provider.is_valid and provider.encrypted_config:
                     custom_provider = provider
             elif provider.provider_type == ProviderType.SYSTEM.value:
-                if provider.is_valid:
+                if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
                     system_provider = provider
 
         if system_provider and not custom_provider: