|
@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
|
|
|
model = GoogleLargeLanguageModel()
|
|
|
|
|
|
response = model.invoke(
|
|
|
- model="gemini-pro",
|
|
|
+ model="gemini-1.5-pro",
|
|
|
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
|
|
|
prompt_messages=[
|
|
|
SystemPromptMessage(
|
|
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
|
|
|
]
|
|
|
),
|
|
|
],
|
|
|
- model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
|
|
|
+ model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
|
|
|
stop=["How"],
|
|
|
stream=False,
|
|
|
user="abc-123",
|
|
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
|
|
|
model = GoogleLargeLanguageModel()
|
|
|
|
|
|
response = model.invoke(
|
|
|
- model="gemini-pro",
|
|
|
+ model="gemini-1.5-pro",
|
|
|
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
|
|
|
prompt_messages=[
|
|
|
SystemPromptMessage(
|
|
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
|
|
|
]
|
|
|
),
|
|
|
],
|
|
|
- model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
|
|
|
+ model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
|
|
|
stream=True,
|
|
|
user="abc-123",
|
|
|
)
|
|
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
|
|
|
model = GoogleLargeLanguageModel()
|
|
|
|
|
|
result = model.invoke(
|
|
|
- model="gemini-pro-vision",
|
|
|
+ model="gemini-1.5-pro",
|
|
|
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
|
|
|
prompt_messages=[
|
|
|
SystemPromptMessage(
|
|
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
|
|
|
model = GoogleLargeLanguageModel()
|
|
|
|
|
|
result = model.invoke(
|
|
|
- model="gemini-pro-vision",
|
|
|
+ model="gemini-1.5-pro",
|
|
|
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
|
|
|
prompt_messages=[
|
|
|
SystemPromptMessage(content="You are a helpful AI assistant."),
|
|
@@ -164,7 +164,7 @@ def test_get_num_tokens():
|
|
|
model = GoogleLargeLanguageModel()
|
|
|
|
|
|
num_tokens = model.get_num_tokens(
|
|
|
- model="gemini-pro",
|
|
|
+ model="gemini-1.5-pro",
|
|
|
credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
|
|
|
prompt_messages=[
|
|
|
SystemPromptMessage(
|