浏览代码

Feature/use jwt in web (#533)

Co-authored-by: crazywoola <li.zheng@dentsplysirona.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
crazywoola 1 年之前
父节点
当前提交
d49ac1e4ac

+ 1 - 0
.gitignore

@@ -109,6 +109,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.conda/
 
 # Spyder project settings
 .spyderproject

+ 1 - 1
api/app.py

@@ -155,7 +155,7 @@ def register_blueprints(app):
          resources={
              r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
          supports_credentials=True,
-         allow_headers=['Content-Type', 'Authorization'],
+         allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
          methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
          expose_headers=['X-Version', 'X-Env']
          )

+ 1 - 1
api/controllers/web/__init__.py

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
 api = ExternalApi(bp)
 
 
-from . import completion, app, conversation, message, site, saved_message, audio
+from . import completion, app, conversation, message, site, saved_message, audio, passport

+ 64 - 0
api/controllers/web/passport.py

@@ -0,0 +1,64 @@
+# -*- coding:utf-8 -*-
+import uuid
+from controllers.web import api
+from flask_restful import Resource
+from flask import request
+from werkzeug.exceptions import Unauthorized, NotFound
+from models.model import Site, EndUser, App
+from extensions.ext_database import db
+from libs.passport import PassportService
+
+class PassportResource(Resource):
+    """Base resource for passport."""
+    def get(self):
+        app_id = request.headers.get('X-App-Code')
+        if app_id is None:
+            raise Unauthorized('X-App-Code header is missing.')
+
+        # get site from db and check if it is normal
+        site = db.session.query(Site).filter(
+            Site.code == app_id,
+            Site.status == 'normal'
+        ).first()
+        if not site:
+            raise NotFound()
+        # get app from db and check if it is normal and enable_site
+        app_model = db.session.query(App).filter(App.id == site.app_id).first()
+        if not app_model or app_model.status != 'normal' or not app_model.enable_site:
+            raise NotFound()
+        
+        end_user = EndUser(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type='browser',
+            is_anonymous=True,
+            session_id=generate_session_id(),
+        )
+        db.session.add(end_user)
+        db.session.commit()
+
+        payload = {
+            "iss": site.app_id,
+            'sub': 'Web API Passport',
+            'app_id': site.app_id,
+            'end_user_id': end_user.id,
+        }
+
+        tk = PassportService().issue(payload)
+
+        return {
+            'access_token': tk,
+        }
+
+api.add_resource(PassportResource, '/passport')
+
+def generate_session_id():
+    """
+    Generate a unique session ID.
+    """
+    while True:
+        session_id = str(uuid.uuid4())
+        existing_count = db.session.query(EndUser) \
+            .filter(EndUser.session_id == session_id).count()
+        if existing_count == 0:
+            return session_id

+ 16 - 78
api/controllers/web/wraps.py

@@ -1,110 +1,48 @@
 # -*- coding:utf-8 -*-
-import uuid
 from functools import wraps
 
-from flask import request, session
+from flask import request
 from flask_restful import Resource
 from werkzeug.exceptions import NotFound, Unauthorized
 
 from extensions.ext_database import db
-from models.model import App, Site, EndUser
+from models.model import App, EndUser
+from libs.passport import PassportService
 
-
-def validate_token(view=None):
+def validate_jwt_token(view=None):
     def decorator(view):
         @wraps(view)
         def decorated(*args, **kwargs):
-            site = validate_and_get_site()
-
-            app_model = db.session.query(App).filter(App.id == site.app_id).first()
-            if not app_model:
-                raise NotFound()
-
-            if app_model.status != 'normal':
-                raise NotFound()
-
-            if not app_model.enable_site:
-                raise NotFound()
-
-            end_user = create_or_update_end_user_for_session(app_model)
+            app_model, end_user = decode_jwt_token()
 
             return view(app_model, end_user, *args, **kwargs)
         return decorated
-
     if view:
         return decorator(view)
     return decorator
 
-
-def validate_and_get_site():
-    """
-    Validate and get API token.
-    """
+def decode_jwt_token():
     auth_header = request.headers.get('Authorization')
     if auth_header is None:
         raise Unauthorized('Authorization header is missing.')
 
     if ' ' not in auth_header:
         raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-
-    auth_scheme, auth_token = auth_header.split(None, 1)
+    
+    auth_scheme, tk = auth_header.split(None, 1)
     auth_scheme = auth_scheme.lower()
 
     if auth_scheme != 'bearer':
         raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-
-    site = db.session.query(Site).filter(
-        Site.code == auth_token,
-        Site.status == 'normal'
-    ).first()
-
-    if not site:
+    decoded = PassportService().verify(tk)
+    app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
+    if not app_model:
+        raise NotFound()
+    end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
+    if not end_user:
         raise NotFound()
 
-    return site
-
-
-def create_or_update_end_user_for_session(app_model):
-    """
-    Create or update session terminal based on session ID.
-    """
-    if 'session_id' not in session:
-        session['session_id'] = generate_session_id()
-
-    session_id = session.get('session_id')
-    end_user = db.session.query(EndUser) \
-        .filter(
-        EndUser.session_id == session_id,
-        EndUser.type == 'browser'
-    ).first()
-
-    if end_user is None:
-        end_user = EndUser(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            type='browser',
-            is_anonymous=True,
-            session_id=session_id
-        )
-        db.session.add(end_user)
-        db.session.commit()
-
-    return end_user
-
-
-def generate_session_id():
-    """
-    Generate a unique session ID.
-    """
-    count = 1
-    session_id = ''
-    while count != 0:
-        session_id = str(uuid.uuid4())
-        count = db.session.query(EndUser) \
-            .filter(EndUser.session_id == session_id).count()
-
-    return session_id
-
+    return app_model, end_user
 
 class WebApiResource(Resource):
-    method_decorators = [validate_token]
+    method_decorators = [validate_jwt_token]

+ 20 - 0
api/libs/passport.py

@@ -0,0 +1,20 @@
+# -*- coding:utf-8 -*-
+import jwt
+from werkzeug.exceptions import Unauthorized
+from flask import current_app
+class PassportService:
+    def __init__(self):
+        self.sk = current_app.config.get('SECRET_KEY')
+    
+    def issue(self, payload):
+        return jwt.encode(payload, self.sk, algorithm='HS256')
+    
+    def verify(self, token):
+        try:
+            return jwt.decode(token, self.sk, algorithms=['HS256'])
+        except jwt.exceptions.InvalidSignatureError:
+            raise Unauthorized('Invalid token signature.')
+        except jwt.exceptions.DecodeError:
+            raise Unauthorized('Invalid token.')
+        except jwt.exceptions.ExpiredSignatureError:
+            raise Unauthorized('Token has expired.')

+ 2 - 1
api/requirements.txt

@@ -32,4 +32,5 @@ redis~=4.5.4
 openpyxl==3.1.2
 chardet~=5.1.0
 docx2txt==0.8
-pypdfium2==4.16.0
+pypdfium2==4.16.0
+pyjwt~=2.6.0

+ 17 - 2
web/app/components/share/chat/index.tsx

@@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
 import produce from 'immer'
 import { useBoolean, useGetState } from 'ahooks'
 import AppUnavailable from '../../base/app-unavailable'
+import { checkOrSetAccessToken } from '../utils'
 import useConversation from './hooks/use-conversation'
 import s from './style.module.css'
 import { ToastContext } from '@/app/components/base/toast'
 import Sidebar from '@/app/components/share/chat/sidebar'
 import ConfigSence from '@/app/components/share/chat/config-scence'
 import Header from '@/app/components/share/header'
-import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
+import {
+  delConversation,
+  fetchAppInfo,
+  fetchAppParams,
+  fetchChatList,
+  fetchConversations,
+  fetchSuggestedQuestions,
+  pinConversation,
+  sendChatMessage,
+  stopChatMessageResponding,
+  unpinConversation,
+  updateFeedback,
+} from '@/service/share'
 import type { ConversationItem, SiteInfo } from '@/models/share'
 import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
 import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
@@ -296,7 +309,9 @@ const Main: FC<IMainProps> = ({
     return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
   }
 
-  const fetchInitData = () => {
+  const fetchInitData = async () => {
+    await checkOrSetAccessToken()
+
     return Promise.all([isInstalledApp
       ? {
         app_id: installedAppInfo?.id,

+ 5 - 5
web/app/components/share/text-generation/index.tsx

@@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
 import { XMarkIcon } from '@heroicons/react/24/outline'
 import TabHeader from '../../base/tab-header'
 import Button from '../../base/button'
+import { checkOrSetAccessToken } from '../utils'
 import s from './style.module.css'
 import RunBatch from './run-batch'
 import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
@@ -76,9 +77,6 @@ const TextGeneration: FC<IMainProps> = ({
     const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
     setSavedMessages(res.data)
   }
-  useEffect(() => {
-    fetchSavedMessage()
-  }, [])
   const handleSaveMessage = async (messageId: string) => {
     await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
     notify({ type: 'success', message: t('common.api.saved') })
@@ -256,7 +254,9 @@ const TextGeneration: FC<IMainProps> = ({
     setAllTaskList(newAllTaskList)
   }
 
-  const fetchInitData = () => {
+  const fetchInitData = async () => {
+    await checkOrSetAccessToken()
+
     return Promise.all([isInstalledApp
       ? {
         app_id: installedAppInfo?.id,
@@ -267,7 +267,7 @@ const TextGeneration: FC<IMainProps> = ({
         },
         plan: 'basic',
       }
-      : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
+      : fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
   }
 
   useEffect(() => {

+ 18 - 0
web/app/components/share/utils.ts

@@ -0,0 +1,18 @@
+import { fetchAccessToken } from '@/service/share'
+
+export const checkOrSetAccessToken = async () => {
+  const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
+  const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+  let accessTokenJson = { [sharedToken]: '' }
+  try {
+    accessTokenJson = JSON.parse(accessToken)
+  }
+  catch (e) {
+
+  }
+  if (!accessTokenJson[sharedToken]) {
+    const res = await fetchAccessToken(sharedToken)
+    accessTokenJson[sharedToken] = res.access_token
+    localStorage.setItem('token', JSON.stringify(accessTokenJson))
+  }
+}

+ 10 - 2
web/service/base.ts

@@ -142,7 +142,15 @@ const baseFetch = (
   const options = Object.assign({}, baseOptions, fetchOptions)
   if (isPublicAPI) {
     const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
-    options.headers.set('Authorization', `bearer ${sharedToken}`)
+    const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+    let accessTokenJson = { [sharedToken]: '' }
+    try {
+      accessTokenJson = JSON.parse(accessToken)
+    }
+    catch (e) {
+
+    }
+    options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
   }
 
   if (deleteContentType) {
@@ -194,7 +202,7 @@ const baseFetch = (
               case 401: {
                 if (isPublicAPI) {
                   Toast.notify({ type: 'error', message: 'Invalid token' })
-                  return
+                  return bodyJson.then((data: any) => Promise.reject(data))
                 }
                 const loginUrl = `${globalThis.location.origin}/signin`
                 if (IS_CE_EDITION) {

+ 6 - 0
web/service/share.ts

@@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
 export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
   return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
 }
+
+export const fetchAccessToken = async (appCode: string) => {
+  const headers = new Headers()
+  headers.append('X-App-Code', appCode)
+  return get('/passport', { headers }) as Promise<{ access_token: string }>
+}