Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c957ffe35 | |||
| 41a7a660c8 | |||
| 44c8c63899 |
@@ -40,25 +40,46 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|||||||
|
|
||||||
async def get_embedding(self, text: str) -> list[float]:
|
async def get_embedding(self, text: str) -> list[float]:
|
||||||
"""获取文本的嵌入"""
|
"""获取文本的嵌入"""
|
||||||
|
kwargs = self._embedding_kwargs()
|
||||||
embedding = await self.client.embeddings.create(
|
embedding = await self.client.embeddings.create(
|
||||||
input=text,
|
input=text,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
dimensions=self.get_dim(),
|
**kwargs,
|
||||||
)
|
)
|
||||||
return embedding.data[0].embedding
|
return embedding.data[0].embedding
|
||||||
|
|
||||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||||
"""批量获取文本的嵌入"""
|
"""批量获取文本的嵌入"""
|
||||||
|
kwargs = self._embedding_kwargs()
|
||||||
embeddings = await self.client.embeddings.create(
|
embeddings = await self.client.embeddings.create(
|
||||||
input=text,
|
input=text,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
dimensions=self.get_dim(),
|
**kwargs,
|
||||||
)
|
)
|
||||||
return [item.embedding for item in embeddings.data]
|
return [item.embedding for item in embeddings.data]
|
||||||
|
|
||||||
|
def _embedding_kwargs(self) -> dict:
|
||||||
|
"""构建嵌入请求的可选参数"""
|
||||||
|
kwargs = {}
|
||||||
|
if "embedding_dimensions" in self.provider_config:
|
||||||
|
try:
|
||||||
|
kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
|
||||||
|
)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
def get_dim(self) -> int:
|
def get_dim(self) -> int:
|
||||||
"""获取向量的维度"""
|
"""获取向量的维度"""
|
||||||
return int(self.provider_config.get("embedding_dimensions", 1024))
|
if "embedding_dimensions" in self.provider_config:
|
||||||
|
try:
|
||||||
|
return int(self.provider_config["embedding_dimensions"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
async def terminate(self):
|
async def terminate(self):
|
||||||
if self.client:
|
if self.client:
|
||||||
|
|||||||
Reference in New Issue
Block a user