@@ -5,6 +5,7 @@ from typing import List, Optional, Any
 import openai
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
+from openai import api_requestor
 
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
@@ -105,7 +106,13 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        try:
+            return self._client.generate([prompts], stop, callbacks)
+        finally:
+            thread_context = api_requestor._thread_context
+            if hasattr(thread_context, "session") and thread_context.session:
+                thread_context.session.close()
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
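
Note on the added finally block: it closes the thread-local requests.Session that the pre-1.0 openai Python SDK keeps in api_requestor._thread_context, presumably so long-lived worker threads do not keep idle HTTP connections open between generate() calls. Below is a minimal standalone sketch of the same cleanup, assuming openai < 1.0 (the versions that expose api_requestor); the helper name is hypothetical and not part of the SDK or this codebase.

    # Sketch only: release the per-thread session the openai SDK (< 1.0)
    # stores in api_requestor._thread_context (a threading.local).
    # requests.Session.close() shuts down the session's pooled connections.
    from openai import api_requestor

    def _close_thread_local_openai_session() -> None:
        thread_context = api_requestor._thread_context
        session = getattr(thread_context, "session", None)
        if session is not None:
            session.close()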