From 9a6004c55d36bac64b393fe45ee832cc60b75910 Mon Sep 17 00:00:00 2001
From: Webifi
Date: Sat, 22 Jul 2023 13:24:18 -0500
Subject: [PATCH] More changes for Petals integration
---
src/lib/ApiUtil.svelte | 4 +-
src/lib/ChatCompletionResponse.svelte | 4 +
src/lib/ChatRequest.svelte | 204 ++------------------------
src/lib/ChatRequestOpenAi.svelte | 100 +++++++++++++
src/lib/ChatRequestPetals.svelte | 126 ++++++++++++++++
src/lib/Home.svelte | 6 +-
src/lib/Models.svelte | 23 +--
src/lib/Stats.svelte | 19 +--
src/lib/Types.svelte | 3 +-
9 files changed, 271 insertions(+), 218 deletions(-)
create mode 100644 src/lib/ChatRequestOpenAi.svelte
create mode 100644 src/lib/ChatRequestPetals.svelte
diff --git a/src/lib/ApiUtil.svelte b/src/lib/ApiUtil.svelte
index ceded8b..afd2f7f 100644
--- a/src/lib/ApiUtil.svelte
+++ b/src/lib/ApiUtil.svelte
@@ -5,12 +5,12 @@
const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
- const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
+ const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
export const getApiBase = ():string => apiBase
export const getEndpointCompletions = ():string => endpointCompletions
export const getEndpointGenerations = ():string => endpointGenerations
export const getEndpointModels = ():string => endpointModels
export const getEndpointEmbeddings = ():string => endpointEmbeddings
- export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
+ export const getPetals = ():string => endpointPetals
\ No newline at end of file
diff --git a/src/lib/ChatCompletionResponse.svelte b/src/lib/ChatCompletionResponse.svelte
index a6743f6..ab5fcff 100644
--- a/src/lib/ChatCompletionResponse.svelte
+++ b/src/lib/ChatCompletionResponse.svelte
@@ -65,6 +65,10 @@ export class ChatCompletionResponse {
this.promptTokenCount = tokens
}
+ getPromptTokenCount (): number {
+ return this.promptTokenCount
+ }
+
async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) {
this.setModel(model)
for (let i = 0; i < response.data.length; i++) {
diff --git a/src/lib/ChatRequest.svelte b/src/lib/ChatRequest.svelte
index 20b5626..40c966e 100644
--- a/src/lib/ChatRequest.svelte
+++ b/src/lib/ChatRequest.svelte
@@ -6,10 +6,11 @@
import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte'
import { scrollToBottom, scrollToMessage } from './Util.svelte'
import { getRequestSettingList, defaultModel } from './Settings.svelte'
- import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
import { v4 as uuidv4 } from 'uuid'
import { get } from 'svelte/store'
- import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
+ import { getEndpoint, getModelDetail } from './Models.svelte'
+ import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte'
+ import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte'
export class ChatRequest {
constructor () {
@@ -27,6 +28,14 @@ export class ChatRequest {
this.chat = chat
}
+ getChat (): Chat {
+ return this.chat
+ }
+
+ getChatSettings (): ChatSettings {
+ return this.chat.settings
+ }
+
// Common error handler
async handleError (response) {
let errorResponse
@@ -258,193 +267,10 @@ export class ChatRequest {
_this.controller = new AbortController()
const signal = _this.controller.signal
- if (modelDetail.type === 'PetalsV2Websocket') {
- // Petals
- const ws = new WebSocket(getEndpoint(model))
- const abortListener = (e:Event) => {
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromError('User aborted request.')
- signal.removeEventListener('abort', abortListener)
- ws.close()
- }
- signal.addEventListener('abort', abortListener)
- const stopSequences = modelDetail.stop || ['###']
- const stopSequencesC = stopSequences.slice()
- const stopSequence = stopSequencesC.shift()
- chatResponse.onFinish(() => {
- _this.updating = false
- _this.updatingMessage = ''
- })
- ws.onopen = () => {
- ws.send(JSON.stringify({
- type: 'open_inference_session',
- model,
- max_length: maxTokens || opts.maxTokens
- }))
- ws.onmessage = event => {
- const response = JSON.parse(event.data)
- if (!response.ok) {
- const err = new Error('Error opening socket: ' + response.traceback)
- console.error(err)
- throw err
- }
- const rMessages = request.messages || [] as Message[]
- const inputArray = (rMessages).reduce((a, m) => {
- const c = getRoleTag(m.role, model, chatSettings) + m.content
- a.push(c)
- return a
- }, [] as string[])
- const lastMessage = rMessages[rMessages.length - 1]
- if (lastMessage && lastMessage.role !== 'assistant') {
- inputArray.push(getRoleTag('assistant', model, chatSettings))
- }
- const petalsRequest = {
- type: 'generate',
- inputs: (request.messages || [] as Message[]).reduce((a, m) => {
- const c = getRoleTag(m.role, model, chatSettings) + m.content
- a.push(c)
- return a
- }, [] as string[]).join(stopSequence),
- max_new_tokens: 3, // wait for up to 3 tokens before displaying
- stop_sequence: stopSequence,
- doSample: 1,
- temperature: request.temperature || 0,
- top_p: request.top_p || 0,
- extra_stop_sequences: stopSequencesC
- }
- ws.send(JSON.stringify(petalsRequest))
- ws.onmessage = event => {
- // Remove updating indicator
- _this.updating = 1 // hide indicator, but still signal we're updating
- _this.updatingMessage = ''
- const response = JSON.parse(event.data)
- if (!response.ok) {
- const err = new Error('Error in response: ' + response.traceback)
- console.error(err)
- throw err
- }
- window.setTimeout(() => {
- chatResponse.updateFromAsyncResponse(
- {
- model,
- choices: [{
- delta: {
- content: response.outputs,
- role: 'assistant'
- },
- finish_reason: (response.stop ? 'stop' : null)
- }]
- } as any
- )
- if (response.stop) {
- const message = chatResponse.getMessages()[0]
- if (message) {
- for (let i = 0, l = stopSequences.length; i < l; i++) {
- if (message.content.endsWith(stopSequences[i])) {
- message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
- updateMessages(chatId)
- }
- }
- }
- }
- }, 1)
- }
- }
- ws.onclose = () => {
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromClose()
- }
- ws.onerror = err => {
- console.error(err)
- throw err
- }
- }
+ if (modelDetail.type === 'Petals') {
+ await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
} else {
- // OpenAI
- const abortListener = (e:Event) => {
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromError('User aborted request.')
- signal.removeEventListener('abort', abortListener)
- }
- signal.addEventListener('abort', abortListener)
- const fetchOptions = {
- method: 'POST',
- headers: {
- Authorization: `Bearer ${getApiKey()}`,
- 'Content-Type': 'application/json'
- },
- body: JSON.stringify(request),
- signal
- }
-
- if (opts.streaming) {
- /**
- * Streaming request/response
- * We'll get the response a token at a time, as soon as they are ready
- */
- chatResponse.onFinish(() => {
- _this.updating = false
- _this.updatingMessage = ''
- })
- fetchEventSource(getEndpoint(model), {
- ...fetchOptions,
- openWhenHidden: true,
- onmessage (ev) {
- // Remove updating indicator
- _this.updating = 1 // hide indicator, but still signal we're updating
- _this.updatingMessage = ''
- // console.log('ev.data', ev.data)
- if (!chatResponse.hasFinished()) {
- if (ev.data === '[DONE]') {
- // ?? anything to do when "[DONE]"?
- } else {
- const data = JSON.parse(ev.data)
- // console.log('data', data)
- window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
- }
- }
- },
- onclose () {
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromClose()
- },
- onerror (err) {
- console.error(err)
- throw err
- },
- async onopen (response) {
- if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
- // everything's good
- } else {
- // client-side errors are usually non-retriable:
- await _this.handleError(response)
- }
- }
- }).catch(err => {
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromError(err.message)
- })
- } else {
- /**
- * Non-streaming request/response
- * We'll get the response all at once, after a long delay
- */
- const response = await fetch(getEndpoint(model), fetchOptions)
- if (!response.ok) {
- await _this.handleError(response)
- } else {
- const json = await response.json()
- // Remove updating indicator
- _this.updating = false
- _this.updatingMessage = ''
- chatResponse.updateFromSyncResponse(json)
- }
- }
+ await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
}
} catch (e) {
// console.error(e)
@@ -456,7 +282,7 @@ export class ChatRequest {
return chatResponse
}
- private getModel (): Model {
+ getModel (): Model {
return this.chat.settings.model || defaultModel
}
diff --git a/src/lib/ChatRequestOpenAi.svelte b/src/lib/ChatRequestOpenAi.svelte
new file mode 100644
index 0000000..37495ef
--- /dev/null
+++ b/src/lib/ChatRequestOpenAi.svelte
@@ -0,0 +1,100 @@
+
\ No newline at end of file
diff --git a/src/lib/ChatRequestPetals.svelte b/src/lib/ChatRequestPetals.svelte
new file mode 100644
index 0000000..b0c1bac
--- /dev/null
+++ b/src/lib/ChatRequestPetals.svelte
@@ -0,0 +1,126 @@
+
\ No newline at end of file
diff --git a/src/lib/Home.svelte b/src/lib/Home.svelte
index c86a17a..a69b1c2 100644
--- a/src/lib/Home.svelte
+++ b/src/lib/Home.svelte
@@ -3,7 +3,7 @@
import Footer from './Footer.svelte'
import { replace } from 'svelte-spa-router'
import { onMount } from 'svelte'
- import { getPetalsV2Websocket } from './ApiUtil.svelte'
+ import { getPetals } from './ApiUtil.svelte'
$: apiKey = $apiKeyStorage
@@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => {
aria-label="PetalsAPI Endpoint"
type="text"
class="input"
- placeholder={getPetalsV2Websocket()}
+ placeholder={getPetals()}
value={$globalStorage.pedalsEndpoint || ''}
/>
@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {
- Only use {getPetalsV2Websocket()} for testing. You must set up your own Petals server for actual use.
+ Only use {getPetals()} for testing. You must set up your own Petals server for actual use.
Do not send sensitive information when using Petals.
diff --git a/src/lib/Models.svelte b/src/lib/Models.svelte
index 1289939..8f03e24 100644
--- a/src/lib/Models.svelte
+++ b/src/lib/Models.svelte
@@ -1,5 +1,5 @@