More changes for Petals integration

commit 9a6004c55d
parent df222e7028

@@ -5,12 +5,12 @@
  const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
  const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
  const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
  const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
  const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'

  export const getApiBase = ():string => apiBase
  export const getEndpointCompletions = ():string => endpointCompletions
  export const getEndpointGenerations = ():string => endpointGenerations
  export const getEndpointModels = ():string => endpointModels
  export const getEndpointEmbeddings = ():string => endpointEmbeddings
  export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
  export const getPetals = ():string => endpointPetals
</script>

@@ -65,6 +65,10 @@ export class ChatCompletionResponse {
      this.promptTokenCount = tokens
    }

    getPromptTokenCount (): number {
      return this.promptTokenCount
    }

    async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) {
      this.setModel(model)
      for (let i = 0; i < response.data.length; i++) {

@@ -6,10 +6,11 @@
  import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte'
  import { scrollToBottom, scrollToMessage } from './Util.svelte'
  import { getRequestSettingList, defaultModel } from './Settings.svelte'
  import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
  import { v4 as uuidv4 } from 'uuid'
  import { get } from 'svelte/store'
  import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
  import { getEndpoint, getModelDetail } from './Models.svelte'
  import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte'
  import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte'

  export class ChatRequest {
    constructor () {

@@ -27,6 +28,14 @@ export class ChatRequest {
      this.chat = chat
    }

    getChat (): Chat {
      return this.chat
    }

    getChatSettings (): ChatSettings {
      return this.chat.settings
    }

    // Common error handler
    async handleError (response) {
      let errorResponse

@@ -258,193 +267,10 @@ export class ChatRequest {
        _this.controller = new AbortController()
        const signal = _this.controller.signal

        if (modelDetail.type === 'PetalsV2Websocket') {
          // Petals
          const ws = new WebSocket(getEndpoint(model))
          const abortListener = (e:Event) => {
            _this.updating = false
            _this.updatingMessage = ''
            chatResponse.updateFromError('User aborted request.')
            signal.removeEventListener('abort', abortListener)
            ws.close()
          }
          signal.addEventListener('abort', abortListener)
          const stopSequences = modelDetail.stop || ['###']
          const stopSequencesC = stopSequences.slice()
          const stopSequence = stopSequencesC.shift()
          chatResponse.onFinish(() => {
            _this.updating = false
            _this.updatingMessage = ''
          })
          ws.onopen = () => {
            ws.send(JSON.stringify({
              type: 'open_inference_session',
              model,
              max_length: maxTokens || opts.maxTokens
            }))
            ws.onmessage = event => {
              const response = JSON.parse(event.data)
              if (!response.ok) {
                const err = new Error('Error opening socket: ' + response.traceback)
                console.error(err)
                throw err
              }
              const rMessages = request.messages || [] as Message[]
              const inputArray = (rMessages).reduce((a, m) => {
                const c = getRoleTag(m.role, model, chatSettings) + m.content
                a.push(c)
                return a
              }, [] as string[])
              const lastMessage = rMessages[rMessages.length - 1]
              if (lastMessage && lastMessage.role !== 'assistant') {
                inputArray.push(getRoleTag('assistant', model, chatSettings))
              }
              const petalsRequest = {
                type: 'generate',
                inputs: (request.messages || [] as Message[]).reduce((a, m) => {
                  const c = getRoleTag(m.role, model, chatSettings) + m.content
                  a.push(c)
                  return a
                }, [] as string[]).join(stopSequence),
                max_new_tokens: 3, // wait for up to 3 tokens before displaying
                stop_sequence: stopSequence,
                doSample: 1,
                temperature: request.temperature || 0,
                top_p: request.top_p || 0,
                extra_stop_sequences: stopSequencesC
              }
              ws.send(JSON.stringify(petalsRequest))
              ws.onmessage = event => {
                // Remove updating indicator
                _this.updating = 1 // hide indicator, but still signal we're updating
                _this.updatingMessage = ''
                const response = JSON.parse(event.data)
                if (!response.ok) {
                  const err = new Error('Error in response: ' + response.traceback)
                  console.error(err)
                  throw err
                }
                window.setTimeout(() => {
                  chatResponse.updateFromAsyncResponse(
                    {
                      model,
                      choices: [{
                        delta: {
                          content: response.outputs,
                          role: 'assistant'
                        },
                        finish_reason: (response.stop ? 'stop' : null)
                      }]
                    } as any
                  )
                  if (response.stop) {
                    const message = chatResponse.getMessages()[0]
                    if (message) {
                      for (let i = 0, l = stopSequences.length; i < l; i++) {
                        if (message.content.endsWith(stopSequences[i])) {
                          message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
                          updateMessages(chatId)
                        }
                      }
                    }
                  }
                }, 1)
              }
            }
            ws.onclose = () => {
              _this.updating = false
              _this.updatingMessage = ''
              chatResponse.updateFromClose()
            }
            ws.onerror = err => {
              console.error(err)
              throw err
            }
          }
        if (modelDetail.type === 'Petals') {
          await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
        } else {
          // OpenAI
          const abortListener = (e:Event) => {
            _this.updating = false
            _this.updatingMessage = ''
            chatResponse.updateFromError('User aborted request.')
            signal.removeEventListener('abort', abortListener)
          }
          signal.addEventListener('abort', abortListener)
          const fetchOptions = {
            method: 'POST',
            headers: {
              Authorization: `Bearer ${getApiKey()}`,
              'Content-Type': 'application/json'
            },
            body: JSON.stringify(request),
            signal
          }

          if (opts.streaming) {
            /**
             * Streaming request/response
             * We'll get the response a token at a time, as soon as they are ready
             */
            chatResponse.onFinish(() => {
              _this.updating = false
              _this.updatingMessage = ''
            })
            fetchEventSource(getEndpoint(model), {
              ...fetchOptions,
              openWhenHidden: true,
              onmessage (ev) {
                // Remove updating indicator
                _this.updating = 1 // hide indicator, but still signal we're updating
                _this.updatingMessage = ''
                // console.log('ev.data', ev.data)
                if (!chatResponse.hasFinished()) {
                  if (ev.data === '[DONE]') {
                    // ?? anything to do when "[DONE]"?
                  } else {
                    const data = JSON.parse(ev.data)
                    // console.log('data', data)
                    window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
                  }
                }
              },
              onclose () {
                _this.updating = false
                _this.updatingMessage = ''
                chatResponse.updateFromClose()
              },
              onerror (err) {
                console.error(err)
                throw err
              },
              async onopen (response) {
                if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
                  // everything's good
                } else {
                  // client-side errors are usually non-retriable:
                  await _this.handleError(response)
                }
              }
            }).catch(err => {
              _this.updating = false
              _this.updatingMessage = ''
              chatResponse.updateFromError(err.message)
            })
          } else {
            /**
             * Non-streaming request/response
             * We'll get the response all at once, after a long delay
             */
            const response = await fetch(getEndpoint(model), fetchOptions)
            if (!response.ok) {
              await _this.handleError(response)
            } else {
              const json = await response.json()
              // Remove updating indicator
              _this.updating = false
              _this.updatingMessage = ''
              chatResponse.updateFromSyncResponse(json)
            }
          }
          await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
        }
      } catch (e) {
        // console.error(e)

@@ -456,7 +282,7 @@ export class ChatRequest {
      return chatResponse
    }

    private getModel (): Model {
    getModel (): Model {
      return this.chat.settings.model || defaultModel
    }

@@ -0,0 +1,100 @@
<script context="module" lang="ts">
  import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
  import ChatCompletionResponse from './ChatCompletionResponse.svelte'
  import ChatRequest from './ChatRequest.svelte'
  import { getEndpoint } from './Models.svelte'
  import { getApiKey } from './Storage.svelte'
  import type { ChatCompletionOpts, Request } from './Types.svelte'

  export const runOpenAiCompletionRequest = async (
    request: Request,
    chatRequest: ChatRequest,
    chatResponse: ChatCompletionResponse,
    signal: AbortSignal,
    opts: ChatCompletionOpts) => {
    // OpenAI Request
    const model = chatRequest.getModel()
    const abortListener = (e:Event) => {
      chatRequest.updating = false
      chatRequest.updatingMessage = ''
      chatResponse.updateFromError('User aborted request.')
      chatRequest.removeEventListener('abort', abortListener)
    }
    signal.addEventListener('abort', abortListener)
    const fetchOptions = {
      method: 'POST',
      headers: {
        Authorization: `Bearer ${getApiKey()}`,
        'Content-Type': 'application/json'
      },
      body: JSON.stringify(request),
      signal
    }

    if (opts.streaming) {
      /**
       * Streaming request/response
       * We'll get the response a token at a time, as soon as they are ready
       */
      chatResponse.onFinish(() => {
        chatRequest.updating = false
        chatRequest.updatingMessage = ''
      })
      fetchEventSource(getEndpoint(model), {
        ...fetchOptions,
        openWhenHidden: true,
        onmessage (ev) {
          // Remove updating indicator
          chatRequest.updating = 1 // hide indicator, but still signal we're updating
          chatRequest.updatingMessage = ''
          // console.log('ev.data', ev.data)
          if (!chatResponse.hasFinished()) {
            if (ev.data === '[DONE]') {
              // ?? anything to do when "[DONE]"?
            } else {
              const data = JSON.parse(ev.data)
              // console.log('data', data)
              window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
            }
          }
        },
        onclose () {
          chatRequest.updating = false
          chatRequest.updatingMessage = ''
          chatResponse.updateFromClose()
        },
        onerror (err) {
          console.error(err)
          throw err
        },
        async onopen (response) {
          if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
            // everything's good
          } else {
            // client-side errors are usually non-retriable:
            await chatRequest.handleError(response)
          }
        }
      }).catch(err => {
        chatRequest.updating = false
        chatRequest.updatingMessage = ''
        chatResponse.updateFromError(err.message)
      })
    } else {
      /**
       * Non-streaming request/response
       * We'll get the response all at once, after a long delay
       */
      const response = await fetch(getEndpoint(model), fetchOptions)
      if (!response.ok) {
        await chatRequest.handleError(response)
      } else {
        const json = await response.json()
        // Remove updating indicator
        chatRequest.updating = false
        chatRequest.updatingMessage = ''
        chatResponse.updateFromSyncResponse(json)
      }
    }
  }
</script>

@@ -0,0 +1,126 @@
<script context="module" lang="ts">
  import ChatCompletionResponse from './ChatCompletionResponse.svelte'
  import ChatRequest from './ChatRequest.svelte'
  import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
  import type { ChatCompletionOpts, Message, Request } from './Types.svelte'
  import { getModelMaxTokens } from './Stats.svelte'
  import { updateMessages } from './Storage.svelte'

  export const runPetalsCompletionRequest = async (
    request: Request,
    chatRequest: ChatRequest,
    chatResponse: ChatCompletionResponse,
    signal: AbortSignal,
    opts: ChatCompletionOpts) => {
    // Petals
    const model = chatRequest.getModel()
    const modelDetail = getModelDetail(model)
    const ws = new WebSocket(getEndpoint(model))
    const abortListener = (e:Event) => {
      chatRequest.updating = false
      chatRequest.updatingMessage = ''
      chatResponse.updateFromError('User aborted request.')
      signal.removeEventListener('abort', abortListener)
      ws.close()
    }
    signal.addEventListener('abort', abortListener)
    const startSequences = modelDetail.start || []
    const startSequence = startSequences[0] || ''
    const stopSequences = modelDetail.stop || ['###']
    const stopSequencesC = stopSequences.slice()
    const stopSequence = stopSequencesC.shift()
    const maxTokens = getModelMaxTokens(model)
    let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
    const promptTokenCount = chatResponse.getPromptTokenCount()
    if (promptTokenCount > maxLen) {
      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
    }
    chatResponse.onFinish(() => {
      chatRequest.updating = false
      chatRequest.updatingMessage = ''
    })
    ws.onopen = () => {
      ws.send(JSON.stringify({
        type: 'open_inference_session',
        model,
        max_length: maxLen
      }))
      ws.onmessage = event => {
        const response = JSON.parse(event.data)
        if (!response.ok) {
          const err = new Error('Error opening socket: ' + response.traceback)
          console.error(err)
          throw err
        }
        const rMessages = request.messages || [] as Message[]
        const inputArray = (rMessages).reduce((a, m) => {
          const c = getRoleTag(m.role, model, chatRequest.chat) + m.content
          a.push(c)
          return a
        }, [] as string[])
        const lastMessage = rMessages[rMessages.length - 1]
        if (lastMessage && lastMessage.role !== 'assistant') {
          inputArray.push(getRoleTag('assistant', model, chatRequest.chat))
        }
        const petalsRequest = {
          type: 'generate',
          inputs: inputArray.join(stopSequence),
          max_new_tokens: 3, // wait for up to 3 tokens before displaying
          stop_sequence: stopSequence,
          doSample: 1,
          temperature: request.temperature || 0,
          top_p: request.top_p || 0,
          extra_stop_sequences: stopSequencesC
        }
        ws.send(JSON.stringify(petalsRequest))
        ws.onmessage = event => {
          // Remove updating indicator
          chatRequest.updating = 1 // hide indicator, but still signal we're updating
          chatRequest.updatingMessage = ''
          const response = JSON.parse(event.data)
          if (!response.ok) {
            const err = new Error('Error in response: ' + response.traceback)
            console.error(err)
            throw err
          }
          window.setTimeout(() => {
            chatResponse.updateFromAsyncResponse(
              {
                model,
                choices: [{
                  delta: {
                    content: response.outputs,
                    role: 'assistant'
                  },
                  finish_reason: (response.stop ? 'stop' : null)
                }]
              } as any
            )
            if (response.stop) {
              const message = chatResponse.getMessages()[0]
              if (message) {
                for (let i = 0, l = stopSequences.length; i < l; i++) {
                  if (message.content.endsWith(stopSequences[i])) {
                    message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
                    const startS = startSequence[i] || ''
                    if (message.content.startsWith(startS)) message.content = message.content.slice(startS.length)
                    updateMessages(chatRequest.getChat().id)
                  }
                }
              }
            }
          }, 1)
        }
      }
      ws.onclose = () => {
        chatRequest.updating = false
        chatRequest.updatingMessage = ''
        chatResponse.updateFromClose()
      }
      ws.onerror = err => {
        console.error(err)
        throw err
      }
    }
  }
</script>

@@ -3,7 +3,7 @@
  import Footer from './Footer.svelte'
  import { replace } from 'svelte-spa-router'
  import { onMount } from 'svelte'
  import { getPetalsV2Websocket } from './ApiUtil.svelte'
  import { getPetals } from './ApiUtil.svelte'

  $: apiKey = $apiKeyStorage

@@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => {
  aria-label="PetalsAPI Endpoint"
  type="text"
  class="input"
  placeholder={getPetalsV2Websocket()}
  placeholder={getPetals()}
  value={$globalStorage.pedalsEndpoint || ''}
/>
</p>

@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {

</form>
<p>
  Only use <u>{getPetalsV2Websocket()}</u> for testing. You must set up your own Petals server for actual use.
  Only use <u>{getPetals()}</u> for testing. You must set up your own Petals server for actual use.
</p>
<p>
  <b>Do not send sensitive information when using Petals.</b>

@@ -1,5 +1,5 @@
<script context="module" lang="ts">
  import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
  import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetals } from './ApiUtil.svelte'
  import { apiKeyStorage, globalStorage } from './Storage.svelte'
  import { get } from 'svelte/store'
  import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'

@@ -34,9 +34,10 @@ const modelDetails : Record<string, ModelDetail> = {
      max: 16384 // 16k max token buffer
    },
    'meta-llama/Llama-2-70b-chat-hf': {
      type: 'PetalsV2Websocket',
      type: 'Petals',
      label: 'Petals - Llama-2-70b-chat',
      stop: ['###', '</s>'],
      start: [''],
      stop: ['</s>'],
      prompt: 0.000000, // $0.000 per 1000 tokens prompt
      completion: 0.000000, // $0.000 per 1000 tokens completion
      max: 4096 // 4k max token buffer

@@ -119,8 +120,8 @@ export const getEndpoint = (model: Model): string => {
    const modelDetails = getModelDetail(model)
    const gSettings = get(globalStorage)
    switch (modelDetails.type) {
      case 'PetalsV2Websocket':
        return gSettings.pedalsEndpoint || getPetalsV2Websocket()
      case 'Petals':
        return gSettings.pedalsEndpoint || getPetals()
      case 'OpenAIDall-e':
        return getApiBase() + getEndpointGenerations()
      case 'OpenAIChat':

@@ -132,12 +133,12 @@ export const getEndpoint = (model: Model): string => {
  export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
    const modelDetails = getModelDetail(model)
    switch (modelDetails.type) {
      case 'PetalsV2Websocket':
      case 'Petals':
        if (role === 'assistant') {
          return ('Assistant') +
            ': '
          if (settings.useSystemPrompt && settings.characterName) return '[' + settings.characterName + '] '
          return '[Assistant] '
        }
        if (role === 'user') return 'Human: '
        if (role === 'user') return '[user] '
        return ''
      case 'OpenAIDall-e':
        return role

@@ -150,7 +151,7 @@ export const getRoleTag = (role: string, model: Model, settings: ChatSettings):
  export const getTokens = (model: Model, value: string): number[] => {
    const modelDetails = getModelDetail(model)
    switch (modelDetails.type) {
      case 'PetalsV2Websocket':
      case 'Petals':
        return llamaTokenizer.encode(value)
      case 'OpenAIDall-e':
        return [0]

@@ -184,7 +185,7 @@ export async function getModelOptions (): Promise<SelectOption[]> {
    }
    const filteredModels = supportedModelKeys.filter((model) => {
      switch (getModelDetail(model).type) {
        case 'PetalsV2Websocket':
        case 'Petals':
          return gSettings.enablePetals
        case 'OpenAIChat':
        default:

@@ -10,19 +10,12 @@
  export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
    const detail = getModelDetail(model)
    const count = prompts.reduce((a, m) => {
      switch (detail.type) {
        case 'PetalsV2Websocket':
          a += countMessageTokens(m, model, settings)
          break
        case 'OpenAIChat':
        default:
          a += countMessageTokens(m, model, settings)
      }
      a += countMessageTokens(m, model, settings)
      return a
    }, 0)
    switch (detail.type) {
      case 'PetalsV2Websocket':
        return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
      case 'Petals':
        return count
      case 'OpenAIChat':
      default:
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.

@@ -34,9 +27,11 @@

  export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
    const detail = getModelDetail(model)
    const start = detail.start && detail.start[0]
    const stop = detail.stop && detail.stop[0]
    switch (detail.type) {
      case 'PetalsV2Websocket':
        return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
      case 'Petals':
        return countTokens(model, (start || '') + getRoleTag(message.role, model, settings) + ': ' + message.content + (stop || '###'))
      case 'OpenAIChat':
      default:
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.

@@ -7,12 +7,13 @@ export type Model = typeof supportedModelKeys[number];

export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];

export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'
export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'Petals'

export type ModelDetail = {
  type: RequestType;
  label?: string;
  stop?: string[];
  start?: string[];
  prompt: number;
  completion: number;
  max: number;

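For reference, below is a minimal sketch of the chat.petals.dev v2 websocket exchange that the new ChatRequestPetals.svelte module implements, distilled from the code in this commit. It is not part of the commit itself; the prompt text, model name, sampling values and stop sequence are illustrative only.

  // Illustrative sketch of the Petals v2 websocket flow used above (not part of the commit).
  const ws = new WebSocket('wss://chat.petals.dev/api/v2/generate')
  ws.onopen = () => {
    // 1. Open an inference session for the model, reserving a token budget.
    ws.send(JSON.stringify({
      type: 'open_inference_session',
      model: 'meta-llama/Llama-2-70b-chat-hf',
      max_length: 4096
    }))
    ws.onmessage = event => {
      const opened = JSON.parse(event.data)
      if (!opened.ok) throw new Error('Error opening socket: ' + opened.traceback)
      // 2. Ask for a completion; the prompt is the chat messages joined with the stop sequence.
      ws.send(JSON.stringify({
        type: 'generate',
        inputs: '[user] Hello</s>[Assistant] ', // illustrative prompt
        max_new_tokens: 3,
        stop_sequence: '</s>',
        doSample: 1,
        temperature: 0.7, // illustrative sampling values
        top_p: 1,
        extra_stop_sequences: ['###']
      }))
      // 3. Partial outputs stream back until the server reports a stop.
      ws.onmessage = event => {
        const chunk = JSON.parse(event.data)
        if (!chunk.ok) throw new Error('Error in response: ' + chunk.traceback)
        console.log(chunk.outputs) // partial text to append to the assistant message
        if (chunk.stop) ws.close()
      }
    }
  }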