From 9a6004c55d36bac64b393fe45ee832cc60b75910 Mon Sep 17 00:00:00 2001 From: Webifi Date: Sat, 22 Jul 2023 13:24:18 -0500 Subject: [PATCH] More changes for Petals integration --- src/lib/ApiUtil.svelte | 4 +- src/lib/ChatCompletionResponse.svelte | 4 + src/lib/ChatRequest.svelte | 204 ++------------------------ src/lib/ChatRequestOpenAi.svelte | 100 +++++++++++++ src/lib/ChatRequestPetals.svelte | 126 ++++++++++++++++ src/lib/Home.svelte | 6 +- src/lib/Models.svelte | 23 +-- src/lib/Stats.svelte | 19 +-- src/lib/Types.svelte | 3 +- 9 files changed, 271 insertions(+), 218 deletions(-) create mode 100644 src/lib/ChatRequestOpenAi.svelte create mode 100644 src/lib/ChatRequestPetals.svelte diff --git a/src/lib/ApiUtil.svelte b/src/lib/ApiUtil.svelte index ceded8b..afd2f7f 100644 --- a/src/lib/ApiUtil.svelte +++ b/src/lib/ApiUtil.svelte @@ -5,12 +5,12 @@ const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations' const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models' const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings' - const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate' + const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate' export const getApiBase = ():string => apiBase export const getEndpointCompletions = ():string => endpointCompletions export const getEndpointGenerations = ():string => endpointGenerations export const getEndpointModels = ():string => endpointModels export const getEndpointEmbeddings = ():string => endpointEmbeddings - export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket + export const getPetals = ():string => endpointPetals \ No newline at end of file diff --git a/src/lib/ChatCompletionResponse.svelte b/src/lib/ChatCompletionResponse.svelte index a6743f6..ab5fcff 100644 --- a/src/lib/ChatCompletionResponse.svelte +++ b/src/lib/ChatCompletionResponse.svelte @@ -65,6 +65,10 @@ export class ChatCompletionResponse { this.promptTokenCount = tokens } + getPromptTokenCount (): number { + return this.promptTokenCount + } + async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) { this.setModel(model) for (let i = 0; i < response.data.length; i++) { diff --git a/src/lib/ChatRequest.svelte b/src/lib/ChatRequest.svelte index 20b5626..40c966e 100644 --- a/src/lib/ChatRequest.svelte +++ b/src/lib/ChatRequest.svelte @@ -6,10 +6,11 @@ import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte' import { scrollToBottom, scrollToMessage } from './Util.svelte' import { getRequestSettingList, defaultModel } from './Settings.svelte' - import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source' import { v4 as uuidv4 } from 'uuid' import { get } from 'svelte/store' - import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte' + import { getEndpoint, getModelDetail } from './Models.svelte' + import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte' + import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte' export class ChatRequest { constructor () { @@ -27,6 +28,14 @@ export class ChatRequest { this.chat = chat } + getChat (): Chat { + return this.chat + } + + getChatSettings (): ChatSettings { + return this.chat.settings + } + // Common error handler async handleError (response) { let errorResponse @@ -258,193 +267,10 @@ export class ChatRequest { _this.controller = new AbortController() const signal = _this.controller.signal - if (modelDetail.type === 'PetalsV2Websocket') { - // Petals - const ws = new WebSocket(getEndpoint(model)) - const abortListener = (e:Event) => { - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromError('User aborted request.') - signal.removeEventListener('abort', abortListener) - ws.close() - } - signal.addEventListener('abort', abortListener) - const stopSequences = modelDetail.stop || ['###'] - const stopSequencesC = stopSequences.slice() - const stopSequence = stopSequencesC.shift() - chatResponse.onFinish(() => { - _this.updating = false - _this.updatingMessage = '' - }) - ws.onopen = () => { - ws.send(JSON.stringify({ - type: 'open_inference_session', - model, - max_length: maxTokens || opts.maxTokens - })) - ws.onmessage = event => { - const response = JSON.parse(event.data) - if (!response.ok) { - const err = new Error('Error opening socket: ' + response.traceback) - console.error(err) - throw err - } - const rMessages = request.messages || [] as Message[] - const inputArray = (rMessages).reduce((a, m) => { - const c = getRoleTag(m.role, model, chatSettings) + m.content - a.push(c) - return a - }, [] as string[]) - const lastMessage = rMessages[rMessages.length - 1] - if (lastMessage && lastMessage.role !== 'assistant') { - inputArray.push(getRoleTag('assistant', model, chatSettings)) - } - const petalsRequest = { - type: 'generate', - inputs: (request.messages || [] as Message[]).reduce((a, m) => { - const c = getRoleTag(m.role, model, chatSettings) + m.content - a.push(c) - return a - }, [] as string[]).join(stopSequence), - max_new_tokens: 3, // wait for up to 3 tokens before displaying - stop_sequence: stopSequence, - doSample: 1, - temperature: request.temperature || 0, - top_p: request.top_p || 0, - extra_stop_sequences: stopSequencesC - } - ws.send(JSON.stringify(petalsRequest)) - ws.onmessage = event => { - // Remove updating indicator - _this.updating = 1 // hide indicator, but still signal we're updating - _this.updatingMessage = '' - const response = JSON.parse(event.data) - if (!response.ok) { - const err = new Error('Error in response: ' + response.traceback) - console.error(err) - throw err - } - window.setTimeout(() => { - chatResponse.updateFromAsyncResponse( - { - model, - choices: [{ - delta: { - content: response.outputs, - role: 'assistant' - }, - finish_reason: (response.stop ? 'stop' : null) - }] - } as any - ) - if (response.stop) { - const message = chatResponse.getMessages()[0] - if (message) { - for (let i = 0, l = stopSequences.length; i < l; i++) { - if (message.content.endsWith(stopSequences[i])) { - message.content = message.content.slice(0, message.content.length - stopSequences[i].length) - updateMessages(chatId) - } - } - } - } - }, 1) - } - } - ws.onclose = () => { - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromClose() - } - ws.onerror = err => { - console.error(err) - throw err - } - } + if (modelDetail.type === 'Petals') { + await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts) } else { - // OpenAI - const abortListener = (e:Event) => { - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromError('User aborted request.') - signal.removeEventListener('abort', abortListener) - } - signal.addEventListener('abort', abortListener) - const fetchOptions = { - method: 'POST', - headers: { - Authorization: `Bearer ${getApiKey()}`, - 'Content-Type': 'application/json' - }, - body: JSON.stringify(request), - signal - } - - if (opts.streaming) { - /** - * Streaming request/response - * We'll get the response a token at a time, as soon as they are ready - */ - chatResponse.onFinish(() => { - _this.updating = false - _this.updatingMessage = '' - }) - fetchEventSource(getEndpoint(model), { - ...fetchOptions, - openWhenHidden: true, - onmessage (ev) { - // Remove updating indicator - _this.updating = 1 // hide indicator, but still signal we're updating - _this.updatingMessage = '' - // console.log('ev.data', ev.data) - if (!chatResponse.hasFinished()) { - if (ev.data === '[DONE]') { - // ?? anything to do when "[DONE]"? - } else { - const data = JSON.parse(ev.data) - // console.log('data', data) - window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1) - } - } - }, - onclose () { - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromClose() - }, - onerror (err) { - console.error(err) - throw err - }, - async onopen (response) { - if (response.ok && response.headers.get('content-type') === EventStreamContentType) { - // everything's good - } else { - // client-side errors are usually non-retriable: - await _this.handleError(response) - } - } - }).catch(err => { - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromError(err.message) - }) - } else { - /** - * Non-streaming request/response - * We'll get the response all at once, after a long delay - */ - const response = await fetch(getEndpoint(model), fetchOptions) - if (!response.ok) { - await _this.handleError(response) - } else { - const json = await response.json() - // Remove updating indicator - _this.updating = false - _this.updatingMessage = '' - chatResponse.updateFromSyncResponse(json) - } - } + await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts) } } catch (e) { // console.error(e) @@ -456,7 +282,7 @@ export class ChatRequest { return chatResponse } - private getModel (): Model { + getModel (): Model { return this.chat.settings.model || defaultModel } diff --git a/src/lib/ChatRequestOpenAi.svelte b/src/lib/ChatRequestOpenAi.svelte new file mode 100644 index 0000000..37495ef --- /dev/null +++ b/src/lib/ChatRequestOpenAi.svelte @@ -0,0 +1,100 @@ + \ No newline at end of file diff --git a/src/lib/ChatRequestPetals.svelte b/src/lib/ChatRequestPetals.svelte new file mode 100644 index 0000000..b0c1bac --- /dev/null +++ b/src/lib/ChatRequestPetals.svelte @@ -0,0 +1,126 @@ + \ No newline at end of file diff --git a/src/lib/Home.svelte b/src/lib/Home.svelte index c86a17a..a69b1c2 100644 --- a/src/lib/Home.svelte +++ b/src/lib/Home.svelte @@ -3,7 +3,7 @@ import Footer from './Footer.svelte' import { replace } from 'svelte-spa-router' import { onMount } from 'svelte' - import { getPetalsV2Websocket } from './ApiUtil.svelte' + import { getPetals } from './ApiUtil.svelte' $: apiKey = $apiKeyStorage @@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => { aria-label="PetalsAPI Endpoint" type="text" class="input" - placeholder={getPetalsV2Websocket()} + placeholder={getPetals()} value={$globalStorage.pedalsEndpoint || ''} />

@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {

- Only use {getPetalsV2Websocket()} for testing. You must set up your own Petals server for actual use. + Only use {getPetals()} for testing. You must set up your own Petals server for actual use.

Do not send sensitive information when using Petals. diff --git a/src/lib/Models.svelte b/src/lib/Models.svelte index 1289939..8f03e24 100644 --- a/src/lib/Models.svelte +++ b/src/lib/Models.svelte @@ -1,5 +1,5 @@