More changes for Petals integration
This commit is contained in:
parent df222e7028
commit 9a6004c55d
@@ -5,12 +5,12 @@
  const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
  const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
  const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
- const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
+ const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'

  export const getApiBase = ():string => apiBase
  export const getEndpointCompletions = ():string => endpointCompletions
  export const getEndpointGenerations = ():string => endpointGenerations
  export const getEndpointModels = ():string => endpointModels
  export const getEndpointEmbeddings = ():string => endpointEmbeddings
- export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
+ export const getPetals = ():string => endpointPetals
</script>
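
For context: the renamed endpointPetals constant above still reads the same VITE_PEDALS_WEBSOCKET environment variable, so a deployment can point the client at its own Petals relay with an override along these lines (the host below is a hypothetical example):

VITE_PEDALS_WEBSOCKET=wss://petals.example.com/api/v2/generate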
@@ -65,6 +65,10 @@ export class ChatCompletionResponse {
    this.promptTokenCount = tokens
  }
+
+ getPromptTokenCount (): number {
+   return this.promptTokenCount
+ }

  async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) {
    this.setModel(model)
    for (let i = 0; i < response.data.length; i++) {

@@ -6,10 +6,11 @@
    import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte'
    import { scrollToBottom, scrollToMessage } from './Util.svelte'
    import { getRequestSettingList, defaultModel } from './Settings.svelte'
-   import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
    import { v4 as uuidv4 } from 'uuid'
    import { get } from 'svelte/store'
-   import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
+   import { getEndpoint, getModelDetail } from './Models.svelte'
+   import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte'
+   import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte'

export class ChatRequest {
      constructor () {

@@ -27,6 +28,14 @@ export class ChatRequest {
        this.chat = chat
      }

+     getChat (): Chat {
+       return this.chat
+     }
+
+     getChatSettings (): ChatSettings {
+       return this.chat.settings
+     }
+
      // Common error handler
      async handleError (response) {
        let errorResponse

@@ -258,193 +267,10 @@ export class ChatRequest {
          _this.controller = new AbortController()
          const signal = _this.controller.signal

-         if (modelDetail.type === 'PetalsV2Websocket') {
-           // Petals
-           const ws = new WebSocket(getEndpoint(model))
-           const abortListener = (e:Event) => {
-             _this.updating = false
-             _this.updatingMessage = ''
-             chatResponse.updateFromError('User aborted request.')
-             signal.removeEventListener('abort', abortListener)
-             ws.close()
-           }
-           signal.addEventListener('abort', abortListener)
-           const stopSequences = modelDetail.stop || ['###']
-           const stopSequencesC = stopSequences.slice()
-           const stopSequence = stopSequencesC.shift()
-           chatResponse.onFinish(() => {
-             _this.updating = false
-             _this.updatingMessage = ''
-           })
-           ws.onopen = () => {
-             ws.send(JSON.stringify({
-               type: 'open_inference_session',
-               model,
-               max_length: maxTokens || opts.maxTokens
-             }))
-             ws.onmessage = event => {
-               const response = JSON.parse(event.data)
-               if (!response.ok) {
-                 const err = new Error('Error opening socket: ' + response.traceback)
-                 console.error(err)
-                 throw err
-               }
-               const rMessages = request.messages || [] as Message[]
-               const inputArray = (rMessages).reduce((a, m) => {
-                 const c = getRoleTag(m.role, model, chatSettings) + m.content
-                 a.push(c)
-                 return a
-               }, [] as string[])
-               const lastMessage = rMessages[rMessages.length - 1]
-               if (lastMessage && lastMessage.role !== 'assistant') {
-                 inputArray.push(getRoleTag('assistant', model, chatSettings))
-               }
-               const petalsRequest = {
-                 type: 'generate',
-                 inputs: (request.messages || [] as Message[]).reduce((a, m) => {
-                   const c = getRoleTag(m.role, model, chatSettings) + m.content
-                   a.push(c)
-                   return a
-                 }, [] as string[]).join(stopSequence),
-                 max_new_tokens: 3, // wait for up to 3 tokens before displaying
-                 stop_sequence: stopSequence,
-                 doSample: 1,
-                 temperature: request.temperature || 0,
-                 top_p: request.top_p || 0,
-                 extra_stop_sequences: stopSequencesC
-               }
-               ws.send(JSON.stringify(petalsRequest))
-               ws.onmessage = event => {
-                 // Remove updating indicator
-                 _this.updating = 1 // hide indicator, but still signal we're updating
-                 _this.updatingMessage = ''
-                 const response = JSON.parse(event.data)
-                 if (!response.ok) {
-                   const err = new Error('Error in response: ' + response.traceback)
-                   console.error(err)
-                   throw err
-                 }
-                 window.setTimeout(() => {
-                   chatResponse.updateFromAsyncResponse(
-                     {
-                       model,
-                       choices: [{
-                         delta: {
-                           content: response.outputs,
-                           role: 'assistant'
-                         },
-                         finish_reason: (response.stop ? 'stop' : null)
-                       }]
-                     } as any
-                   )
-                   if (response.stop) {
-                     const message = chatResponse.getMessages()[0]
-                     if (message) {
-                       for (let i = 0, l = stopSequences.length; i < l; i++) {
-                         if (message.content.endsWith(stopSequences[i])) {
-                           message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
-                           updateMessages(chatId)
-                         }
-                       }
-                     }
-                   }
-                 }, 1)
-               }
-             }
-             ws.onclose = () => {
-               _this.updating = false
-               _this.updatingMessage = ''
-               chatResponse.updateFromClose()
-             }
-             ws.onerror = err => {
-               console.error(err)
-               throw err
-             }
-           }
+         if (modelDetail.type === 'Petals') {
+           await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
          } else {
-           // OpenAI
-           const abortListener = (e:Event) => {
-             _this.updating = false
-             _this.updatingMessage = ''
-             chatResponse.updateFromError('User aborted request.')
-             signal.removeEventListener('abort', abortListener)
-           }
-           signal.addEventListener('abort', abortListener)
-           const fetchOptions = {
-             method: 'POST',
-             headers: {
-               Authorization: `Bearer ${getApiKey()}`,
-               'Content-Type': 'application/json'
-             },
-             body: JSON.stringify(request),
-             signal
-           }
-
-           if (opts.streaming) {
-           /**
-            * Streaming request/response
-            * We'll get the response a token at a time, as soon as they are ready
-           */
-             chatResponse.onFinish(() => {
-               _this.updating = false
-               _this.updatingMessage = ''
-             })
-             fetchEventSource(getEndpoint(model), {
-               ...fetchOptions,
-               openWhenHidden: true,
-               onmessage (ev) {
-                 // Remove updating indicator
-                 _this.updating = 1 // hide indicator, but still signal we're updating
-                 _this.updatingMessage = ''
-                 // console.log('ev.data', ev.data)
-                 if (!chatResponse.hasFinished()) {
-                   if (ev.data === '[DONE]') {
-                     // ?? anything to do when "[DONE]"?
-                   } else {
-                     const data = JSON.parse(ev.data)
-                     // console.log('data', data)
-                     window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
-                   }
-                 }
-               },
-               onclose () {
-                 _this.updating = false
-                 _this.updatingMessage = ''
-                 chatResponse.updateFromClose()
-               },
-               onerror (err) {
-                 console.error(err)
-                 throw err
-               },
-               async onopen (response) {
-                 if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
-                   // everything's good
-                 } else {
-                   // client-side errors are usually non-retriable:
-                   await _this.handleError(response)
-                 }
-               }
-             }).catch(err => {
-               _this.updating = false
-               _this.updatingMessage = ''
-               chatResponse.updateFromError(err.message)
-             })
-           } else {
-           /**
-            * Non-streaming request/response
-            * We'll get the response all at once, after a long delay
-            */
-             const response = await fetch(getEndpoint(model), fetchOptions)
-             if (!response.ok) {
-               await _this.handleError(response)
-             } else {
-               const json = await response.json()
-               // Remove updating indicator
-               _this.updating = false
-               _this.updatingMessage = ''
-               chatResponse.updateFromSyncResponse(json)
-             }
-           }
+           await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
          }
        } catch (e) {
        // console.error(e)

@@ -456,7 +282,7 @@ export class ChatRequest {
        return chatResponse
      }

-     private getModel (): Model {
+     getModel (): Model {
        return this.chat.settings.model || defaultModel
      }

@@ -0,0 +1,100 @@
+<script context="module" lang="ts">
+    import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
+    import ChatCompletionResponse from './ChatCompletionResponse.svelte'
+    import ChatRequest from './ChatRequest.svelte'
+    import { getEndpoint } from './Models.svelte'
+    import { getApiKey } from './Storage.svelte'
+    import type { ChatCompletionOpts, Request } from './Types.svelte'
+
+export const runOpenAiCompletionRequest = async (
+  request: Request,
+  chatRequest: ChatRequest,
+  chatResponse: ChatCompletionResponse,
+  signal: AbortSignal,
+  opts: ChatCompletionOpts) => {
+    // OpenAI Request
+      const model = chatRequest.getModel()
+      const abortListener = (e:Event) => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        chatResponse.updateFromError('User aborted request.')
+        chatRequest.removeEventListener('abort', abortListener)
+      }
+      signal.addEventListener('abort', abortListener)
+      const fetchOptions = {
+        method: 'POST',
+        headers: {
+          Authorization: `Bearer ${getApiKey()}`,
+          'Content-Type': 'application/json'
+        },
+        body: JSON.stringify(request),
+        signal
+      }
+
+      if (opts.streaming) {
+      /**
+             * Streaming request/response
+             * We'll get the response a token at a time, as soon as they are ready
+            */
+        chatResponse.onFinish(() => {
+          chatRequest.updating = false
+          chatRequest.updatingMessage = ''
+        })
+        fetchEventSource(getEndpoint(model), {
+          ...fetchOptions,
+          openWhenHidden: true,
+          onmessage (ev) {
+          // Remove updating indicator
+            chatRequest.updating = 1 // hide indicator, but still signal we're updating
+            chatRequest.updatingMessage = ''
+            // console.log('ev.data', ev.data)
+            if (!chatResponse.hasFinished()) {
+              if (ev.data === '[DONE]') {
+              // ?? anything to do when "[DONE]"?
+              } else {
+                const data = JSON.parse(ev.data)
+                // console.log('data', data)
+                window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
+              }
+            }
+          },
+          onclose () {
+            chatRequest.updating = false
+            chatRequest.updatingMessage = ''
+            chatResponse.updateFromClose()
+          },
+          onerror (err) {
+            console.error(err)
+            throw err
+          },
+          async onopen (response) {
+            if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
+            // everything's good
+            } else {
+            // client-side errors are usually non-retriable:
+              await chatRequest.handleError(response)
+            }
+          }
+        }).catch(err => {
+          chatRequest.updating = false
+          chatRequest.updatingMessage = ''
+          chatResponse.updateFromError(err.message)
+        })
+      } else {
+      /**
+             * Non-streaming request/response
+             * We'll get the response all at once, after a long delay
+             */
+        const response = await fetch(getEndpoint(model), fetchOptions)
+        if (!response.ok) {
+          await chatRequest.handleError(response)
+        } else {
+          const json = await response.json()
+          // Remove updating indicator
+          chatRequest.updating = false
+          chatRequest.updatingMessage = ''
+          chatResponse.updateFromSyncResponse(json)
+        }
+      }
+}
+</script>

@@ -0,0 +1,126 @@
+<script context="module" lang="ts">
+    import ChatCompletionResponse from './ChatCompletionResponse.svelte'
+    import ChatRequest from './ChatRequest.svelte'
+    import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
+    import type { ChatCompletionOpts, Message, Request } from './Types.svelte'
+    import { getModelMaxTokens } from './Stats.svelte'
+    import { updateMessages } from './Storage.svelte'
+
+export const runPetalsCompletionRequest = async (
+  request: Request,
+  chatRequest: ChatRequest,
+  chatResponse: ChatCompletionResponse,
+  signal: AbortSignal,
+  opts: ChatCompletionOpts) => {
+      // Petals
+      const model = chatRequest.getModel()
+      const modelDetail = getModelDetail(model)
+      const ws = new WebSocket(getEndpoint(model))
+      const abortListener = (e:Event) => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        chatResponse.updateFromError('User aborted request.')
+        signal.removeEventListener('abort', abortListener)
+        ws.close()
+      }
+      signal.addEventListener('abort', abortListener)
+      const startSequences = modelDetail.start || []
+      const startSequence = startSequences[0] || ''
+      const stopSequences = modelDetail.stop || ['###']
+      const stopSequencesC = stopSequences.slice()
+      const stopSequence = stopSequencesC.shift()
+      const maxTokens = getModelMaxTokens(model)
+      let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
+      const promptTokenCount = chatResponse.getPromptTokenCount()
+      if (promptTokenCount > maxLen) {
+        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+      }
+      chatResponse.onFinish(() => {
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+      })
+      ws.onopen = () => {
+        ws.send(JSON.stringify({
+          type: 'open_inference_session',
+          model,
+          max_length: maxLen
+        }))
+        ws.onmessage = event => {
+          const response = JSON.parse(event.data)
+          if (!response.ok) {
+            const err = new Error('Error opening socket: ' + response.traceback)
+            console.error(err)
+            throw err
+          }
+          const rMessages = request.messages || [] as Message[]
+          const inputArray = (rMessages).reduce((a, m) => {
+            const c = getRoleTag(m.role, model, chatRequest.chat) + m.content
+            a.push(c)
+            return a
+          }, [] as string[])
+          const lastMessage = rMessages[rMessages.length - 1]
+          if (lastMessage && lastMessage.role !== 'assistant') {
+            inputArray.push(getRoleTag('assistant', model, chatRequest.chat))
+          }
+          const petalsRequest = {
+            type: 'generate',
+            inputs: inputArray.join(stopSequence),
+            max_new_tokens: 3, // wait for up to 3 tokens before displaying
+            stop_sequence: stopSequence,
+            doSample: 1,
+            temperature: request.temperature || 0,
+            top_p: request.top_p || 0,
+            extra_stop_sequences: stopSequencesC
+          }
+          ws.send(JSON.stringify(petalsRequest))
+          ws.onmessage = event => {
+            // Remove updating indicator
+            chatRequest.updating = 1 // hide indicator, but still signal we're updating
+            chatRequest.updatingMessage = ''
+            const response = JSON.parse(event.data)
+            if (!response.ok) {
+              const err = new Error('Error in response: ' + response.traceback)
+              console.error(err)
+              throw err
+            }
+            window.setTimeout(() => {
+              chatResponse.updateFromAsyncResponse(
+                      {
+                        model,
+                        choices: [{
+                          delta: {
+                            content: response.outputs,
+                            role: 'assistant'
+                          },
+                          finish_reason: (response.stop ? 'stop' : null)
+                        }]
+                      } as any
+              )
+              if (response.stop) {
+                const message = chatResponse.getMessages()[0]
+                if (message) {
+                  for (let i = 0, l = stopSequences.length; i < l; i++) {
+                    if (message.content.endsWith(stopSequences[i])) {
+                      message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
+                      const startS = startSequence[i] || ''
+                      if (message.content.startsWith(startS)) message.content = message.content.slice(startS.length)
+                      updateMessages(chatRequest.getChat().id)
+                    }
+                  }
+                }
+              }
+            }, 1)
+          }
+        }
+        ws.onclose = () => {
+          chatRequest.updating = false
+          chatRequest.updatingMessage = ''
+          chatResponse.updateFromClose()
+        }
+        ws.onerror = err => {
+          console.error(err)
+          throw err
+        }
+      }
+}
+</script>
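
For context on the new Petals flow above: after the open_inference_session handshake, the client sends a single 'generate' message whose inputs field is the whole conversation, role-tagged and joined on the first stop sequence. A minimal standalone TypeScript sketch of that assembly follows; the inline roleTag helper and the message contents are invented for illustration, while the committed code uses getRoleTag from Models.svelte and modelDetail.stop:

// Sketch only, not part of the commit: how the 'generate' payload's inputs string is built.
const messages = [
  { role: 'user', content: 'Hello' },
  { role: 'assistant', content: 'Hi there.' },
  { role: 'user', content: 'What is Petals?' }
]
// Stand-in for getRoleTag, using the Petals tags from the Models.svelte hunk below
const roleTag = (role: string) => role === 'assistant' ? '[Assistant] ' : role === 'user' ? '[user] ' : ''
const stopSequence = '###' // default when modelDetail.stop is not set
const parts = messages.map(m => roleTag(m.role) + m.content)
parts.push(roleTag('assistant')) // last message is not from the assistant, so append a bare assistant tag
const inputs = parts.join(stopSequence)
console.log(inputs) // "[user] Hello###[Assistant] Hi there.###[user] What is Petals?###[Assistant] "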
@@ -3,7 +3,7 @@
  import Footer from './Footer.svelte'
  import { replace } from 'svelte-spa-router'
  import { onMount } from 'svelte'
- import { getPetalsV2Websocket } from './ApiUtil.svelte'
+ import { getPetals } from './ApiUtil.svelte'

$: apiKey = $apiKeyStorage

@@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => {
              aria-label="PetalsAPI Endpoint"
              type="text"
              class="input"
-             placeholder={getPetalsV2Websocket()}
+             placeholder={getPetals()}
              value={$globalStorage.pedalsEndpoint || ''}
            />
          </p>

@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {

        </form>
        <p>
-         Only use <u>{getPetalsV2Websocket()}</u> for testing.  You must set up your own Petals server for actual use. 
+         Only use <u>{getPetals()}</u> for testing.  You must set up your own Petals server for actual use. 
        </p>
        <p>
          <b>Do not send sensitive information when using Petals.</b>

@@ -1,5 +1,5 @@
<script context="module" lang="ts">
-   import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
+   import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetals } from './ApiUtil.svelte'
    import { apiKeyStorage, globalStorage } from './Storage.svelte'
    import { get } from 'svelte/store'
    import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'

@@ -34,9 +34,10 @@ const modelDetails : Record<string, ModelDetail> = {
        max: 16384 // 16k max token buffer
      },
      'meta-llama/Llama-2-70b-chat-hf': {
-       type: 'PetalsV2Websocket',
+       type: 'Petals',
        label: 'Petals - Llama-2-70b-chat',
-       stop: ['###', '</s>'],
+       start: [''],
+       stop: ['</s>'],
        prompt: 0.000000, // $0.000 per 1000 tokens prompt
        completion: 0.000000, // $0.000 per 1000 tokens completion
        max: 4096 // 4k max token buffer

@@ -119,8 +120,8 @@ export const getEndpoint = (model: Model): string => {
  const modelDetails = getModelDetail(model)
  const gSettings = get(globalStorage)
  switch (modelDetails.type) {
-       case 'PetalsV2Websocket':
-         return gSettings.pedalsEndpoint || getPetalsV2Websocket()
+       case 'Petals':
+         return gSettings.pedalsEndpoint || getPetals()
        case 'OpenAIDall-e':
          return getApiBase() + getEndpointGenerations()
        case 'OpenAIChat':

@@ -132,12 +133,12 @@ export const getEndpoint = (model: Model): string => {
export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
  const modelDetails = getModelDetail(model)
  switch (modelDetails.type) {
-       case 'PetalsV2Websocket':
+       case 'Petals':
          if (role === 'assistant') {
-           return ('Assistant') +
-             ': '
+           if (settings.useSystemPrompt && settings.characterName) return '[' + settings.characterName + '] '
+           return '[Assistant] '
          }
-         if (role === 'user') return 'Human: '
+         if (role === 'user') return '[user] '
          return ''
        case 'OpenAIDall-e':
          return role

@@ -150,7 +151,7 @@ export const getRoleTag = (role: string, model: Model, settings: ChatSettings):
export const getTokens = (model: Model, value: string): number[] => {
  const modelDetails = getModelDetail(model)
  switch (modelDetails.type) {
-       case 'PetalsV2Websocket':
+       case 'Petals':
          return llamaTokenizer.encode(value)
        case 'OpenAIDall-e':
          return [0]

@@ -184,7 +185,7 @@ export async function getModelOptions (): Promise<SelectOption[]> {
  }
  const filteredModels = supportedModelKeys.filter((model) => {
        switch (getModelDetail(model).type) {
-         case 'PetalsV2Websocket':
+         case 'Petals':
            return gSettings.enablePetals
          case 'OpenAIChat':
          default:

@@ -10,19 +10,12 @@
  export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
    const detail = getModelDetail(model)
    const count = prompts.reduce((a, m) => {
-     switch (detail.type) {
-       case 'PetalsV2Websocket':
      a += countMessageTokens(m, model, settings)
-         break
-       case 'OpenAIChat':
-       default:
-         a += countMessageTokens(m, model, settings)
-     }
      return a
    }, 0)
    switch (detail.type) {
-     case 'PetalsV2Websocket':
-       return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
+     case 'Petals':
+       return count
      case 'OpenAIChat':
      default:
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.

@@ -34,9 +27,11 @@

  export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
    const detail = getModelDetail(model)
+   const start = detail.start && detail.start[0]
+   const stop = detail.stop && detail.stop[0]
    switch (detail.type) {
-     case 'PetalsV2Websocket':
-       return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
+     case 'Petals':
+       return countTokens(model, (start || '') + getRoleTag(message.role, model, settings) + ': ' + message.content + (stop || '###'))
      case 'OpenAIChat':
      default:
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.

@@ -7,12 +7,13 @@ export type Model = typeof supportedModelKeys[number];

export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];

- export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'
+ export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'Petals'

export type ModelDetail = {
    type: RequestType;
    label?: string;
    stop?: string[];
+   start?: string[];
    prompt: number;
    completion: number;
    max: number;