More changes for Petals integration

parent df222e7028
commit 9a6004c55d
@@ -5,12 +5,12 @@
 const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
 const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
 const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
-const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
+const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
 
 export const getApiBase = ():string => apiBase
 export const getEndpointCompletions = ():string => endpointCompletions
 export const getEndpointGenerations = ():string => endpointGenerations
 export const getEndpointModels = ():string => endpointModels
 export const getEndpointEmbeddings = ():string => endpointEmbeddings
-export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
+export const getPetals = ():string => endpointPetals
 </script>
@@ -65,6 +65,10 @@ export class ChatCompletionResponse {
     this.promptTokenCount = tokens
   }
 
+  getPromptTokenCount (): number {
+    return this.promptTokenCount
+  }
+
   async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) {
     this.setModel(model)
     for (let i = 0; i < response.data.length; i++) {
@@ -6,10 +6,11 @@
 import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte'
 import { scrollToBottom, scrollToMessage } from './Util.svelte'
 import { getRequestSettingList, defaultModel } from './Settings.svelte'
-import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
 import { v4 as uuidv4 } from 'uuid'
 import { get } from 'svelte/store'
-import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
+import { getEndpoint, getModelDetail } from './Models.svelte'
+import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte'
+import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte'
 
 export class ChatRequest {
   constructor () {
@@ -27,6 +28,14 @@ export class ChatRequest {
     this.chat = chat
   }
 
+  getChat (): Chat {
+    return this.chat
+  }
+
+  getChatSettings (): ChatSettings {
+    return this.chat.settings
+  }
+
   // Common error handler
   async handleError (response) {
     let errorResponse
@@ -258,193 +267,10 @@ export class ChatRequest {
       _this.controller = new AbortController()
       const signal = _this.controller.signal
 
-      if (modelDetail.type === 'PetalsV2Websocket') {
-        // Petals
-        const ws = new WebSocket(getEndpoint(model))
-        const abortListener = (e:Event) => {
-          _this.updating = false
-          _this.updatingMessage = ''
-          chatResponse.updateFromError('User aborted request.')
-          signal.removeEventListener('abort', abortListener)
-          ws.close()
-        }
-        signal.addEventListener('abort', abortListener)
-        const stopSequences = modelDetail.stop || ['###']
-        const stopSequencesC = stopSequences.slice()
-        const stopSequence = stopSequencesC.shift()
-        chatResponse.onFinish(() => {
-          _this.updating = false
-          _this.updatingMessage = ''
-        })
-        ws.onopen = () => {
-          ws.send(JSON.stringify({
-            type: 'open_inference_session',
-            model,
-            max_length: maxTokens || opts.maxTokens
-          }))
-          ws.onmessage = event => {
-            const response = JSON.parse(event.data)
-            if (!response.ok) {
-              const err = new Error('Error opening socket: ' + response.traceback)
-              console.error(err)
-              throw err
-            }
-            const rMessages = request.messages || [] as Message[]
-            const inputArray = (rMessages).reduce((a, m) => {
-              const c = getRoleTag(m.role, model, chatSettings) + m.content
-              a.push(c)
-              return a
-            }, [] as string[])
-            const lastMessage = rMessages[rMessages.length - 1]
-            if (lastMessage && lastMessage.role !== 'assistant') {
-              inputArray.push(getRoleTag('assistant', model, chatSettings))
-            }
-            const petalsRequest = {
-              type: 'generate',
-              inputs: (request.messages || [] as Message[]).reduce((a, m) => {
-                const c = getRoleTag(m.role, model, chatSettings) + m.content
-                a.push(c)
-                return a
-              }, [] as string[]).join(stopSequence),
-              max_new_tokens: 3, // wait for up to 3 tokens before displaying
-              stop_sequence: stopSequence,
-              doSample: 1,
-              temperature: request.temperature || 0,
-              top_p: request.top_p || 0,
-              extra_stop_sequences: stopSequencesC
-            }
-            ws.send(JSON.stringify(petalsRequest))
-            ws.onmessage = event => {
-              // Remove updating indicator
-              _this.updating = 1 // hide indicator, but still signal we're updating
-              _this.updatingMessage = ''
-              const response = JSON.parse(event.data)
-              if (!response.ok) {
-                const err = new Error('Error in response: ' + response.traceback)
-                console.error(err)
-                throw err
-              }
-              window.setTimeout(() => {
-                chatResponse.updateFromAsyncResponse(
-                  {
-                    model,
-                    choices: [{
-                      delta: {
-                        content: response.outputs,
-                        role: 'assistant'
-                      },
-                      finish_reason: (response.stop ? 'stop' : null)
-                    }]
-                  } as any
-                )
-                if (response.stop) {
-                  const message = chatResponse.getMessages()[0]
-                  if (message) {
-                    for (let i = 0, l = stopSequences.length; i < l; i++) {
-                      if (message.content.endsWith(stopSequences[i])) {
-                        message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
-                        updateMessages(chatId)
-                      }
-                    }
-                  }
-                }
-              }, 1)
-            }
-          }
-          ws.onclose = () => {
-            _this.updating = false
-            _this.updatingMessage = ''
-            chatResponse.updateFromClose()
-          }
-          ws.onerror = err => {
-            console.error(err)
-            throw err
-          }
-        }
+      if (modelDetail.type === 'Petals') {
+        await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
       } else {
-        // OpenAI
-        const abortListener = (e:Event) => {
-          _this.updating = false
-          _this.updatingMessage = ''
-          chatResponse.updateFromError('User aborted request.')
-          signal.removeEventListener('abort', abortListener)
-        }
-        signal.addEventListener('abort', abortListener)
-        const fetchOptions = {
-          method: 'POST',
-          headers: {
-            Authorization: `Bearer ${getApiKey()}`,
-            'Content-Type': 'application/json'
-          },
-          body: JSON.stringify(request),
-          signal
-        }
-
-        if (opts.streaming) {
-          /**
-           * Streaming request/response
-           * We'll get the response a token at a time, as soon as they are ready
-           */
-          chatResponse.onFinish(() => {
-            _this.updating = false
-            _this.updatingMessage = ''
-          })
-          fetchEventSource(getEndpoint(model), {
-            ...fetchOptions,
-            openWhenHidden: true,
-            onmessage (ev) {
-              // Remove updating indicator
-              _this.updating = 1 // hide indicator, but still signal we're updating
-              _this.updatingMessage = ''
-              // console.log('ev.data', ev.data)
-              if (!chatResponse.hasFinished()) {
-                if (ev.data === '[DONE]') {
-                  // ?? anything to do when "[DONE]"?
-                } else {
-                  const data = JSON.parse(ev.data)
-                  // console.log('data', data)
-                  window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
-                }
-              }
-            },
-            onclose () {
-              _this.updating = false
-              _this.updatingMessage = ''
-              chatResponse.updateFromClose()
-            },
-            onerror (err) {
-              console.error(err)
-              throw err
-            },
-            async onopen (response) {
-              if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
-                // everything's good
-              } else {
-                // client-side errors are usually non-retriable:
-                await _this.handleError(response)
-              }
-            }
-          }).catch(err => {
-            _this.updating = false
-            _this.updatingMessage = ''
-            chatResponse.updateFromError(err.message)
-          })
-        } else {
-          /**
-           * Non-streaming request/response
-           * We'll get the response all at once, after a long delay
-           */
-          const response = await fetch(getEndpoint(model), fetchOptions)
-          if (!response.ok) {
-            await _this.handleError(response)
-          } else {
-            const json = await response.json()
-            // Remove updating indicator
-            _this.updating = false
-            _this.updatingMessage = ''
-            chatResponse.updateFromSyncResponse(json)
-          }
-        }
+        await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
       }
     } catch (e) {
       // console.error(e)
@@ -456,7 +282,7 @@ export class ChatRequest {
     return chatResponse
   }
 
-  private getModel (): Model {
+  getModel (): Model {
     return this.chat.settings.model || defaultModel
   }
 
@@ -0,0 +1,100 @@
+<script context="module" lang="ts">
+  import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
+  import ChatCompletionResponse from './ChatCompletionResponse.svelte'
+  import ChatRequest from './ChatRequest.svelte'
+  import { getEndpoint } from './Models.svelte'
+  import { getApiKey } from './Storage.svelte'
+  import type { ChatCompletionOpts, Request } from './Types.svelte'
+
+  export const runOpenAiCompletionRequest = async (
+    request: Request,
+    chatRequest: ChatRequest,
+    chatResponse: ChatCompletionResponse,
+    signal: AbortSignal,
+    opts: ChatCompletionOpts) => {
+    // OpenAI Request
+    const model = chatRequest.getModel()
+    const abortListener = (e:Event) => {
+      chatRequest.updating = false
+      chatRequest.updatingMessage = ''
+      chatResponse.updateFromError('User aborted request.')
+      chatRequest.removeEventListener('abort', abortListener)
+    }
+    signal.addEventListener('abort', abortListener)
+    const fetchOptions = {
+      method: 'POST',
+      headers: {
+        Authorization: `Bearer ${getApiKey()}`,
+        'Content-Type': 'application/json'
+      },
+      body: JSON.stringify(request),
+      signal
+    }
+
+    if (opts.streaming) {
+      /**
+       * Streaming request/response
+       * We'll get the response a token at a time, as soon as they are ready
+       */
+      chatResponse.onFinish(() => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+      })
+      fetchEventSource(getEndpoint(model), {
+        ...fetchOptions,
+        openWhenHidden: true,
+        onmessage (ev) {
+          // Remove updating indicator
+          chatRequest.updating = 1 // hide indicator, but still signal we're updating
+          chatRequest.updatingMessage = ''
+          // console.log('ev.data', ev.data)
+          if (!chatResponse.hasFinished()) {
+            if (ev.data === '[DONE]') {
+              // ?? anything to do when "[DONE]"?
+            } else {
+              const data = JSON.parse(ev.data)
+              // console.log('data', data)
+              window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
+            }
+          }
+        },
+        onclose () {
+          chatRequest.updating = false
+          chatRequest.updatingMessage = ''
+          chatResponse.updateFromClose()
+        },
+        onerror (err) {
+          console.error(err)
+          throw err
+        },
+        async onopen (response) {
+          if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
+            // everything's good
+          } else {
+            // client-side errors are usually non-retriable:
+            await chatRequest.handleError(response)
+          }
+        }
+      }).catch(err => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        chatResponse.updateFromError(err.message)
+      })
+    } else {
+      /**
+       * Non-streaming request/response
+       * We'll get the response all at once, after a long delay
+       */
+      const response = await fetch(getEndpoint(model), fetchOptions)
+      if (!response.ok) {
+        await chatRequest.handleError(response)
+      } else {
+        const json = await response.json()
+        // Remove updating indicator
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        chatResponse.updateFromSyncResponse(json)
+      }
+    }
+  }
+</script>
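The extracted OpenAI module above keeps the same fetch-event-source streaming pattern that previously lived inline in ChatRequest.svelte. As a standalone illustration of that pattern (endpoint, apiKey and onToken are placeholder names, not part of this commit), a minimal SSE consumer might look like:

import { fetchEventSource } from '@microsoft/fetch-event-source'

// Minimal sketch: stream chat-completion deltas from an SSE endpoint.
const streamCompletion = async (endpoint: string, apiKey: string, body: object, onToken: (t: string) => void) => {
  await fetchEventSource(endpoint, {
    method: 'POST',
    headers: { Authorization: `Bearer ${apiKey}`, 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
    openWhenHidden: true, // keep the stream alive when the tab is in the background
    onmessage (ev) {
      if (ev.data === '[DONE]') return // server signals end of stream
      const data = JSON.parse(ev.data)
      const token = data.choices?.[0]?.delta?.content
      if (token) onToken(token)
    },
    onerror (err) {
      throw err // rethrowing stops fetchEventSource's automatic retries
    }
  })
}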
@@ -0,0 +1,126 @@
+<script context="module" lang="ts">
+  import ChatCompletionResponse from './ChatCompletionResponse.svelte'
+  import ChatRequest from './ChatRequest.svelte'
+  import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
+  import type { ChatCompletionOpts, Message, Request } from './Types.svelte'
+  import { getModelMaxTokens } from './Stats.svelte'
+  import { updateMessages } from './Storage.svelte'
+
+  export const runPetalsCompletionRequest = async (
+    request: Request,
+    chatRequest: ChatRequest,
+    chatResponse: ChatCompletionResponse,
+    signal: AbortSignal,
+    opts: ChatCompletionOpts) => {
+    // Petals
+    const model = chatRequest.getModel()
+    const modelDetail = getModelDetail(model)
+    const ws = new WebSocket(getEndpoint(model))
+    const abortListener = (e:Event) => {
+      chatRequest.updating = false
+      chatRequest.updatingMessage = ''
+      chatResponse.updateFromError('User aborted request.')
+      signal.removeEventListener('abort', abortListener)
+      ws.close()
+    }
+    signal.addEventListener('abort', abortListener)
+    const startSequences = modelDetail.start || []
+    const startSequence = startSequences[0] || ''
+    const stopSequences = modelDetail.stop || ['###']
+    const stopSequencesC = stopSequences.slice()
+    const stopSequence = stopSequencesC.shift()
+    const maxTokens = getModelMaxTokens(model)
+    let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
+    const promptTokenCount = chatResponse.getPromptTokenCount()
+    if (promptTokenCount > maxLen) {
+      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+    }
+    chatResponse.onFinish(() => {
+      chatRequest.updating = false
+      chatRequest.updatingMessage = ''
+    })
+    ws.onopen = () => {
+      ws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: maxLen
+      }))
+      ws.onmessage = event => {
+        const response = JSON.parse(event.data)
+        if (!response.ok) {
+          const err = new Error('Error opening socket: ' + response.traceback)
+          console.error(err)
+          throw err
+        }
+        const rMessages = request.messages || [] as Message[]
+        const inputArray = (rMessages).reduce((a, m) => {
+          const c = getRoleTag(m.role, model, chatRequest.chat) + m.content
+          a.push(c)
+          return a
+        }, [] as string[])
+        const lastMessage = rMessages[rMessages.length - 1]
+        if (lastMessage && lastMessage.role !== 'assistant') {
+          inputArray.push(getRoleTag('assistant', model, chatRequest.chat))
+        }
+        const petalsRequest = {
+          type: 'generate',
+          inputs: inputArray.join(stopSequence),
+          max_new_tokens: 3, // wait for up to 3 tokens before displaying
+          stop_sequence: stopSequence,
+          doSample: 1,
+          temperature: request.temperature || 0,
+          top_p: request.top_p || 0,
+          extra_stop_sequences: stopSequencesC
+        }
+        ws.send(JSON.stringify(petalsRequest))
+        ws.onmessage = event => {
+          // Remove updating indicator
+          chatRequest.updating = 1 // hide indicator, but still signal we're updating
+          chatRequest.updatingMessage = ''
+          const response = JSON.parse(event.data)
+          if (!response.ok) {
+            const err = new Error('Error in response: ' + response.traceback)
+            console.error(err)
+            throw err
+          }
+          window.setTimeout(() => {
+            chatResponse.updateFromAsyncResponse(
+              {
+                model,
+                choices: [{
+                  delta: {
+                    content: response.outputs,
+                    role: 'assistant'
+                  },
+                  finish_reason: (response.stop ? 'stop' : null)
+                }]
+              } as any
+            )
+            if (response.stop) {
+              const message = chatResponse.getMessages()[0]
+              if (message) {
+                for (let i = 0, l = stopSequences.length; i < l; i++) {
+                  if (message.content.endsWith(stopSequences[i])) {
+                    message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
+                    const startS = startSequence[i] || ''
+                    if (message.content.startsWith(startS)) message.content = message.content.slice(startS.length)
+                    updateMessages(chatRequest.getChat().id)
+                  }
+                }
+              }
+            }
+          }, 1)
+        }
+      }
+      ws.onclose = () => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        chatResponse.updateFromClose()
+      }
+      ws.onerror = err => {
+        console.error(err)
+        throw err
+      }
+    }
+  }
+</script>
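The new ChatRequestPetals module drives a two-step protocol over the Petals public endpoint: open an inference session, then stream generate chunks. A stripped-down sketch of that exchange outside the app's classes (the prompt string, token budget and sampling values are illustrative only; the message shapes mirror the code in this commit):

// Minimal sketch of the chat.petals.dev v2 WebSocket exchange implemented above.
const ws = new WebSocket('wss://chat.petals.dev/api/v2/generate')
let text = ''

ws.onopen = () => {
  // Step 1: reserve an inference session for a model and a maximum token budget.
  ws.send(JSON.stringify({ type: 'open_inference_session', model: 'meta-llama/Llama-2-70b-chat-hf', max_length: 1024 }))
  ws.onmessage = event => {
    const opened = JSON.parse(event.data)
    if (!opened.ok) throw new Error('Error opening socket: ' + opened.traceback)
    // Step 2: once the session is confirmed, send the generate request.
    ws.send(JSON.stringify({
      type: 'generate',
      inputs: '[user] Hello</s>[Assistant] ', // role-tagged messages joined with the stop sequence
      max_new_tokens: 3,
      stop_sequence: '</s>',
      doSample: 1,
      temperature: 0.7,
      top_p: 0.9
    }))
    // Step 3: later replies stream { ok, outputs, stop } chunks until stop is true.
    ws.onmessage = chunkEvent => {
      const chunk = JSON.parse(chunkEvent.data)
      if (!chunk.ok) throw new Error('Error in response: ' + chunk.traceback)
      text += chunk.outputs
      if (chunk.stop) {
        console.log(text)
        ws.close()
      }
    }
  }
}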
@@ -3,7 +3,7 @@
 import Footer from './Footer.svelte'
 import { replace } from 'svelte-spa-router'
 import { onMount } from 'svelte'
-import { getPetalsV2Websocket } from './ApiUtil.svelte'
+import { getPetals } from './ApiUtil.svelte'
 
 $: apiKey = $apiKeyStorage
 
@@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => {
       aria-label="PetalsAPI Endpoint"
       type="text"
       class="input"
-      placeholder={getPetalsV2Websocket()}
+      placeholder={getPetals()}
       value={$globalStorage.pedalsEndpoint || ''}
     />
   </p>
@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {
 
   </form>
   <p>
-    Only use <u>{getPetalsV2Websocket()}</u> for testing. You must set up your own Petals server for actual use.
+    Only use <u>{getPetals()}</u> for testing. You must set up your own Petals server for actual use.
   </p>
   <p>
     <b>Do not send sensitive information when using Petals.</b>
@@ -1,5 +1,5 @@
 <script context="module" lang="ts">
-  import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
+  import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetals } from './ApiUtil.svelte'
   import { apiKeyStorage, globalStorage } from './Storage.svelte'
   import { get } from 'svelte/store'
   import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'
@@ -34,9 +34,10 @@ const modelDetails : Record<string, ModelDetail> = {
     max: 16384 // 16k max token buffer
   },
   'meta-llama/Llama-2-70b-chat-hf': {
-    type: 'PetalsV2Websocket',
+    type: 'Petals',
     label: 'Petals - Llama-2-70b-chat',
-    stop: ['###', '</s>'],
+    start: [''],
+    stop: ['</s>'],
     prompt: 0.000000, // $0.000 per 1000 tokens prompt
     completion: 0.000000, // $0.000 per 1000 tokens completion
     max: 4096 // 4k max token buffer
@@ -119,8 +120,8 @@ export const getEndpoint = (model: Model): string => {
   const modelDetails = getModelDetail(model)
   const gSettings = get(globalStorage)
   switch (modelDetails.type) {
-    case 'PetalsV2Websocket':
-      return gSettings.pedalsEndpoint || getPetalsV2Websocket()
+    case 'Petals':
+      return gSettings.pedalsEndpoint || getPetals()
     case 'OpenAIDall-e':
       return getApiBase() + getEndpointGenerations()
     case 'OpenAIChat':
@@ -132,12 +133,12 @@ export const getEndpoint = (model: Model): string => {
 export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
   const modelDetails = getModelDetail(model)
   switch (modelDetails.type) {
-    case 'PetalsV2Websocket':
+    case 'Petals':
       if (role === 'assistant') {
-        return ('Assistant') +
-        ': '
+        if (settings.useSystemPrompt && settings.characterName) return '[' + settings.characterName + '] '
+        return '[Assistant] '
       }
-      if (role === 'user') return 'Human: '
+      if (role === 'user') return '[user] '
       return ''
     case 'OpenAIDall-e':
       return role
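With the new bracketed role tags, a Petals prompt is built by prefixing each message with its tag and joining on the model's first stop sequence, as ChatRequestPetals.svelte does above. A small illustration with hypothetical messages (the helper and values below are illustrative, not part of this commit):

// Shows the prompt string the Petals path produces for a Llama-2 chat model
// with stop '</s>' and the default tags defined above.
const messages = [
  { role: 'user', content: 'What is Petals?' },
  { role: 'assistant', content: 'A distributed network for running large models.' },
  { role: 'user', content: 'Is it free?' }
]
const roleTag = (role: string) => role === 'user' ? '[user] ' : role === 'assistant' ? '[Assistant] ' : ''
const inputs = messages.map(m => roleTag(m.role) + m.content).concat(['[Assistant] ']).join('</s>')
// => "[user] What is Petals?</s>[Assistant] A distributed network for running large models.</s>[user] Is it free?</s>[Assistant] "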
@@ -150,7 +151,7 @@ export const getRoleTag = (role: string, model: Model, settings: ChatSettings):
 export const getTokens = (model: Model, value: string): number[] => {
   const modelDetails = getModelDetail(model)
   switch (modelDetails.type) {
-    case 'PetalsV2Websocket':
+    case 'Petals':
       return llamaTokenizer.encode(value)
     case 'OpenAIDall-e':
       return [0]
@@ -184,7 +185,7 @@ export async function getModelOptions (): Promise<SelectOption[]> {
   }
   const filteredModels = supportedModelKeys.filter((model) => {
     switch (getModelDetail(model).type) {
-      case 'PetalsV2Websocket':
+      case 'Petals':
         return gSettings.enablePetals
       case 'OpenAIChat':
       default:
@@ -10,19 +10,12 @@
 export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
   const detail = getModelDetail(model)
   const count = prompts.reduce((a, m) => {
-    switch (detail.type) {
-      case 'PetalsV2Websocket':
-        a += countMessageTokens(m, model, settings)
-        break
-      case 'OpenAIChat':
-      default:
-        a += countMessageTokens(m, model, settings)
-    }
+    a += countMessageTokens(m, model, settings)
     return a
   }, 0)
   switch (detail.type) {
-    case 'PetalsV2Websocket':
-      return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
+    case 'Petals':
+      return count
     case 'OpenAIChat':
     default:
       // Not sure how OpenAI formats it, but this seems to get close to the right counts.
@@ -34,9 +27,11 @@
 
 export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
   const detail = getModelDetail(model)
+  const start = detail.start && detail.start[0]
+  const stop = detail.stop && detail.stop[0]
   switch (detail.type) {
-    case 'PetalsV2Websocket':
-      return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
+    case 'Petals':
+      return countTokens(model, (start || '') + getRoleTag(message.role, model, settings) + ': ' + message.content + (stop || '###'))
     case 'OpenAIChat':
     default:
       // Not sure how OpenAI formats it, but this seems to get close to the right counts.
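The per-message count for Petals models now wraps the content in the model's start and stop sequences, so the prompt budget matches what is actually sent over the socket. A worked illustration of the string that gets tokenized, assuming the Llama-2 entry above (start '' and stop '</s>'):

// For { role: 'user', content: 'Hello' }, the Petals branch above counts the tokens of:
const counted = '' + '[user] ' + ': ' + 'Hello' + '</s>'
// i.e. "[user] : Hello</s>" — the literal ': ' separator is kept from the old format.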
@@ -7,12 +7,13 @@ export type Model = typeof supportedModelKeys[number];
 
 export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];
 
-export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'
+export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'Petals'
 
 export type ModelDetail = {
   type: RequestType;
   label?: string;
   stop?: string[];
+  start?: string[];
   prompt: number;
   completion: number;
   max: number;