Browse Source

Fix/remove tsne position test (#5858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Jyong cách đây 9 tháng
mục cha
commit
0944ca9d91

+ 2 - 29
api/services/hit_testing_service.py

@@ -4,10 +4,6 @@ import time
 import numpy as np
 from sklearn.manifold import TSNE
 
-from core.embedding.cached_embedding import CacheEmbedding
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
 from core.rag.retrieval.retrival_methods import RetrievalMethod
@@ -45,17 +41,6 @@ class HitTestingService:
         if not retrieval_model:
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
-        # get embedding model
-        model_manager = ModelManager()
-        embedding_model = model_manager.get_model_instance(
-            tenant_id=dataset.tenant_id,
-            model_type=ModelType.TEXT_EMBEDDING,
-            provider=dataset.embedding_model_provider,
-            model=dataset.embedding_model
-        )
-
-        embeddings = CacheEmbedding(embedding_model)
-
         all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                   dataset_id=dataset.id,
                                                   query=query,
@@ -80,20 +65,10 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
+        return cls.compact_retrieve_response(dataset, query, all_documents)
 
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
-        text_embeddings = [
-            embeddings.embed_query(query)
-        ]
-
-        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
-
-        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
-
-        query_position = tsne_position_data.pop(0)
-
+    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
         i = 0
         records = []
         for document in documents:
@@ -113,7 +88,6 @@ class HitTestingService:
             record = {
                 "segment": segment,
                 "score": document.metadata.get('score', None),
-                "tsne_position": tsne_position_data[i]
             }
 
             records.append(record)
@@ -123,7 +97,6 @@ class HitTestingService:
         return {
             "query": {
                 "content": query,
-                "tsne_position": query_position,
             },
             "records": records
         }

+ 3 - 52
web/app/components/datasets/hit-testing/hit-detail.tsx

@@ -2,51 +2,16 @@ import type { FC } from 'react'
 import React from 'react'
 import cn from 'classnames'
 import { useTranslation } from 'react-i18next'
-import ReactECharts from 'echarts-for-react'
 import { SegmentIndexTag } from '../documents/detail/completed'
 import s from '../documents/detail/completed/style.module.css'
 import type { SegmentDetailModel } from '@/models/datasets'
 import Divider from '@/app/components/base/divider'
 
-type IScatterChartProps = {
-  data: Array<number[]>
-  curr: Array<number[]>
-}
-
-const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => {
-  const option = {
-    xAxis: {},
-    yAxis: {},
-    tooltip: {
-      trigger: 'item',
-      axisPointer: {
-        type: 'cross',
-      },
-    },
-    series: [
-      {
-        type: 'effectScatter',
-        symbolSize: 5,
-        data: curr,
-      },
-      {
-        type: 'scatter',
-        symbolSize: 5,
-        data,
-      },
-    ],
-  }
-  return (
-    <ReactECharts option={option} style={{ height: 380, width: 430 }} />
-  )
-}
-
 type IHitDetailProps = {
   segInfo?: Partial<SegmentDetailModel> & { id: string }
-  vectorInfo?: { curr: Array<number[]>; points: Array<number[]> }
 }
 
-const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
+const HitDetail: FC<IHitDetailProps> = ({ segInfo }) => {
   const { t } = useTranslation()
 
   const renderContent = () => {
@@ -65,8 +30,8 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
   }
 
   return (
-    <div className='flex flex-row overflow-x-auto'>
-      <div className="flex-1 bg-gray-25 p-6 min-w-[300px]">
+    <div className='overflow-x-auto'>
+      <div className="bg-gray-25 p-6">
         <div className="flex items-center">
           <SegmentIndexTag
             positionId={segInfo?.position || ''}
@@ -94,20 +59,6 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
             })}
         </div>
       </div>
-      <div className="flex-1 bg-white p-6">
-        <div className="flex items-center">
-          <div className={cn(s.commonIcon, s.bezierCurveIcon)} />
-          <span className={s.numberInfo}>
-            {t('datasetDocuments.segment.vectorHash')}
-          </span>
-        </div>
-        <div
-          className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')}
-        >
-          {segInfo?.index_node_hash}
-        </div>
-        <ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} />
-      </div>
     </div>
   )
 }

+ 2 - 8
web/app/components/datasets/hit-testing/index.tsx

@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import React, { useEffect, useMemo, useState } from 'react'
+import React, { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import useSWR from 'swr'
 import { omit } from 'lodash-es'
@@ -62,8 +62,6 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
 
   const total = recordsRes?.total || 0
 
-  const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records])
-
   const onClickCard = (detail: HitTestingType) => {
     setCurrParagraph({ paraInfo: detail, showModal: true })
   }
@@ -194,17 +192,13 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
         </div>
       </FloatRightContainer>
       <Modal
-        className='!max-w-[960px] !p-0'
+        className='w-[520px] p-0'
         closable
         onClose={() => setCurrParagraph({ showModal: false })}
         isShow={currParagraph.showModal}
       >
         {currParagraph.showModal && <HitDetail
           segInfo={currParagraph.paraInfo?.segment}
-          vectorInfo={{
-            curr: [[currParagraph.paraInfo?.tsne_position?.x || 0, currParagraph.paraInfo?.tsne_position.y || 0]],
-            points,
-          }}
         />}
       </Modal>
       <Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>