
fix: huggingface and replicate. (#1888)

Garfield Dai, 1 year ago
Commit 91ee62d1ab

+ 25 - 14
api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py

@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                 content=chunk.token.text
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if chunk.details:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage,
+                        finish_reason=chunk.details.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    ),
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
         if isinstance(response, str):
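
The gist of this change: usage and `finish_reason` are attached only when `chunk.details` is present, i.e. on the final streamed chunk, while intermediate chunks now yield a bare delta. A minimal, self-contained sketch of that gating pattern, using hypothetical `StreamChunk`/`Details`/`Delta` stand-ins rather than the real huggingface_hub response and result types:

```python
from dataclasses import dataclass
from typing import Iterator, Optional

# Hypothetical stand-ins for the streamed response and the yielded delta.
@dataclass
class Details:
    finish_reason: str

@dataclass
class StreamChunk:
    text: str
    details: Optional[Details] = None  # assumed populated only on the final chunk

@dataclass
class Delta:
    index: int
    text: str
    usage: Optional[dict] = None
    finish_reason: Optional[str] = None

def stream_deltas(chunks: Iterator[StreamChunk]) -> Iterator[Delta]:
    completion = ""
    for index, chunk in enumerate(chunks):
        completion += chunk.text
        if chunk.details:
            # Final chunk: compute usage once and attach the finish reason.
            usage = {"completion_tokens": len(completion.split())}  # placeholder accounting
            yield Delta(index, chunk.text, usage, chunk.details.finish_reason)
        else:
            # Intermediate chunk: plain delta, no usage yet.
            yield Delta(index, chunk.text)

# Only the last delta carries usage / finish_reason.
stream = [StreamChunk("Hello"), StreamChunk(" world", Details("stop"))]
for delta in stream_deltas(stream):
    print(delta)
```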

+ 33 - 15
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
 
         for key, value in input_properties:
-            if key not in ['system_prompt', 'prompt']:
+            if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
                 value_type = value.get('type')
 
                 if not value_type:
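
The added `'stop' not in key` check excludes any stop-related field from the parameter rules derived from the model's input schema, presumably because stop sequences are handled separately (the streaming handler below already checks `stop` against the accumulated completion). A small sketch of the filtering, with an illustrative schema dict standing in for the real Replicate input properties:

```python
# Illustrative input schema; the real properties come from the Replicate
# model's published input schema and may use different field names.
input_properties = {
    "prompt": {"type": "string"},
    "system_prompt": {"type": "string"},
    "stop_sequences": {"type": "string"},
    "temperature": {"type": "number", "default": 0.7},
    "max_new_tokens": {"type": "integer", "default": 512},
}

parameter_rules = []
for key, value in input_properties.items():
    # Skip the prompt fields and anything stop-related; those are wired up
    # explicitly rather than exposed as tunable parameters.
    if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
        value_type = value.get('type')
        if not value_type:
            continue
        parameter_rules.append({"name": key, "type": value_type})

print([rule["name"] for rule in parameter_rules])  # ['temperature', 'max_new_tokens']
```
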
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         index = -1
         current_completion: str = ""
         stop_condition_reached = False
+
+        prediction_output_length = 10000
+        is_prediction_output_finished = False
+
         for output in prediction.output_iterator():
             current_completion += output
 
+            if not is_prediction_output_finished and prediction.status == 'succeeded':
+                prediction_output_length = len(prediction.output) - 1
+                is_prediction_output_finished = True
+
             if stop:
                 for s in stop:
                     if s in current_completion:
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                 content=output if output else ''
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if index < prediction_output_length:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage
+                    )
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
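
The streaming handler now defers usage accounting to the last chunk: once `prediction.status` reports `succeeded`, the full output list is available and `len(prediction.output) - 1` marks the final index, so only a chunk at or past that index gets `usage`. A toy, self-contained sketch of that pattern, with a `FakePrediction` standing in for the real Replicate `Prediction` object:

```python
from typing import Iterator, Optional

class FakePrediction:
    """Toy stand-in for the Replicate Prediction object (illustration only)."""
    def __init__(self, outputs: list[str]):
        self.output = outputs
        self.status = 'processing'

    def output_iterator(self) -> Iterator[str]:
        for i, piece in enumerate(self.output):
            if i == len(self.output) - 1:
                # Assumption: the real client flips this once polling reports completion.
                self.status = 'succeeded'
            yield piece

def stream_chunks(prediction: FakePrediction) -> Iterator[tuple[int, str, Optional[dict]]]:
    index = -1
    prediction_output_length = 10000
    is_prediction_output_finished = False

    for output in prediction.output_iterator():
        index += 1
        # Once the prediction has succeeded, the full output list is known,
        # so the index of the final chunk can be pinned down.
        if not is_prediction_output_finished and prediction.status == 'succeeded':
            prediction_output_length = len(prediction.output) - 1
            is_prediction_output_finished = True

        usage = None
        if index >= prediction_output_length:
            usage = {"completion_chunks": index + 1}  # placeholder accounting
        yield index, output, usage

for chunk in stream_chunks(FakePrediction(["Hel", "lo", "!"])):
    print(chunk)  # only the last tuple carries a usage dict
```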