Add holdWebsocket option for faster petals chat
This commit is contained in:
		
							parent
							
								
									e991dfd9b7
								
							
						
					
					
						commit
						2e4181bf7e
					
				| 
						 | 
					@ -21,6 +21,7 @@ export class ChatRequest {
 | 
				
			||||||
      updating: boolean|number = false
 | 
					      updating: boolean|number = false
 | 
				
			||||||
      updatingMessage: string = ''
 | 
					      updatingMessage: string = ''
 | 
				
			||||||
      controller:AbortController
 | 
					      controller:AbortController
 | 
				
			||||||
 | 
					      providerData: Record<string, any> = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      setChat (chat: Chat) {
 | 
					      setChat (chat: Chat) {
 | 
				
			||||||
        this.chat = chat
 | 
					        this.chat = chat
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -122,6 +122,7 @@ const defaults:ChatSettings = {
 | 
				
			||||||
  systemMessageEnd: '',
 | 
					  systemMessageEnd: '',
 | 
				
			||||||
  leadPrompt: '',
 | 
					  leadPrompt: '',
 | 
				
			||||||
  repetitionPenalty: 1.1,
 | 
					  repetitionPenalty: 1.1,
 | 
				
			||||||
 | 
					  holdSocket: true,
 | 
				
			||||||
  // useResponseAlteration: false,
 | 
					  // useResponseAlteration: false,
 | 
				
			||||||
  // responseAlterations: [],
 | 
					  // responseAlterations: [],
 | 
				
			||||||
  isDirty: false
 | 
					  isDirty: false
 | 
				
			||||||
| 
						 | 
					@ -451,6 +452,13 @@ const chatSettingsList: ChatSetting[] = [
 | 
				
			||||||
        type: 'boolean',
 | 
					        type: 'boolean',
 | 
				
			||||||
        hide: hideModelSetting
 | 
					        hide: hideModelSetting
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        key: 'holdSocket',
 | 
				
			||||||
 | 
					        name: 'Continue WebSocket',
 | 
				
			||||||
 | 
					        title: 'Hold WebSocket connection open and try to re-use for each new chat message. Faster, but message delimitation could get mangled.',
 | 
				
			||||||
 | 
					        type: 'boolean',
 | 
				
			||||||
 | 
					        hide: hideModelSetting
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
        key: 'temperature',
 | 
					        key: 'temperature',
 | 
				
			||||||
        name: 'Sampling Temperature',
 | 
					        name: 'Sampling Temperature',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -96,6 +96,7 @@ export type ChatSettings = {
 | 
				
			||||||
    systemMessageStart: string;
 | 
					    systemMessageStart: string;
 | 
				
			||||||
    systemMessageEnd: string;
 | 
					    systemMessageEnd: string;
 | 
				
			||||||
    repetitionPenalty: number;
 | 
					    repetitionPenalty: number;
 | 
				
			||||||
 | 
					    holdSocket: boolean;
 | 
				
			||||||
    isDirty?: boolean;
 | 
					    isDirty?: boolean;
 | 
				
			||||||
  } & Request;
 | 
					  } & Request;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -154,4 +154,8 @@
 | 
				
			||||||
    return value
 | 
					    return value
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  export const escapeRegex = (string: string): string => {
 | 
				
			||||||
 | 
					    return string.replace(/[/\-\\^$*+?.()|[\]{}]/g, '\\$&')
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
</script> 
 | 
					</script> 
 | 
				
			||||||
| 
						 | 
					@ -20,7 +20,8 @@ const hiddenSettings = {
 | 
				
			||||||
      assistantMessageEnd: true,
 | 
					      assistantMessageEnd: true,
 | 
				
			||||||
      systemMessageStart: true,
 | 
					      systemMessageStart: true,
 | 
				
			||||||
      systemMessageEnd: true,
 | 
					      systemMessageEnd: true,
 | 
				
			||||||
      repetitionPenalty: true
 | 
					      repetitionPenalty: true,
 | 
				
			||||||
 | 
					      holdSocket: true
 | 
				
			||||||
      // leadPrompt: true
 | 
					      // leadPrompt: true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,6 +5,29 @@
 | 
				
			||||||
    import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
 | 
					    import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
 | 
				
			||||||
    import { getModelMaxTokens } from '../../Stats.svelte'
 | 
					    import { getModelMaxTokens } from '../../Stats.svelte'
 | 
				
			||||||
    import { updateMessages } from '../../Storage.svelte'
 | 
					    import { updateMessages } from '../../Storage.svelte'
 | 
				
			||||||
 | 
					    import { escapeRegex } from '../../Util.svelte'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const levenshteinDistance = (str1 = '', str2 = '') => {
 | 
				
			||||||
 | 
					  const track = Array(str2.length + 1).fill(null).map(() =>
 | 
				
			||||||
 | 
					        Array(str1.length + 1).fill(null))
 | 
				
			||||||
 | 
					  for (let i = 0; i <= str1.length; i += 1) {
 | 
				
			||||||
 | 
					        track[0][i] = i
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  for (let j = 0; j <= str2.length; j += 1) {
 | 
				
			||||||
 | 
					        track[j][0] = j
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  for (let j = 1; j <= str2.length; j += 1) {
 | 
				
			||||||
 | 
					        for (let i = 1; i <= str1.length; i += 1) {
 | 
				
			||||||
 | 
					          const indicator = str1[i - 1] === str2[j - 1] ? 0 : 1
 | 
				
			||||||
 | 
					          track[j][i] = Math.min(
 | 
				
			||||||
 | 
					            track[j][i - 1] + 1, // deletion
 | 
				
			||||||
 | 
					            track[j - 1][i] + 1, // insertion
 | 
				
			||||||
 | 
					            track[j - 1][i - 1] + indicator // substitution
 | 
				
			||||||
 | 
					          )
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return track[str2.length][str1.length]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const chatRequest = async (
 | 
					export const chatRequest = async (
 | 
				
			||||||
  request: Request,
 | 
					  request: Request,
 | 
				
			||||||
| 
						 | 
					@ -16,8 +39,10 @@ export const chatRequest = async (
 | 
				
			||||||
      const chatSettings = chat.settings
 | 
					      const chatSettings = chat.settings
 | 
				
			||||||
      const model = chatRequest.getModel()
 | 
					      const model = chatRequest.getModel()
 | 
				
			||||||
      const modelDetail = getModelDetail(model)
 | 
					      const modelDetail = getModelDetail(model)
 | 
				
			||||||
      const ws = new WebSocket(getEndpoint(model))
 | 
					 | 
				
			||||||
      const signal = chatRequest.controller.signal
 | 
					      const signal = chatRequest.controller.signal
 | 
				
			||||||
 | 
					      const providerData = chatRequest.providerData.petals || {}
 | 
				
			||||||
 | 
					      chatRequest.providerData.petals = providerData
 | 
				
			||||||
 | 
					      let ws: WebSocket = providerData.ws
 | 
				
			||||||
      const abortListener = (e:Event) => {
 | 
					      const abortListener = (e:Event) => {
 | 
				
			||||||
        chatRequest.updating = false
 | 
					        chatRequest.updating = false
 | 
				
			||||||
        chatRequest.updatingMessage = ''
 | 
					        chatRequest.updatingMessage = ''
 | 
				
			||||||
| 
						 | 
					@ -26,9 +51,17 @@ export const chatRequest = async (
 | 
				
			||||||
        ws.close()
 | 
					        ws.close()
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      signal.addEventListener('abort', abortListener)
 | 
					      signal.addEventListener('abort', abortListener)
 | 
				
			||||||
 | 
					      const startSequence = getStartSequence(chat)
 | 
				
			||||||
      let stopSequences = [...new Set(getStopSequence(chat).split(',').filter(s => s.trim()).concat((modelDetail.stop || ['###', '</s>']).slice()))]
 | 
					      let stopSequences = [...new Set(getStopSequence(chat).split(',').filter(s => s.trim()).concat((modelDetail.stop || ['###', '</s>']).slice()))]
 | 
				
			||||||
      const stopSequence = '</s>'
 | 
					      let stopSequence = stopSequences[0] || '###'
 | 
				
			||||||
 | 
					      if (startSequence.length) {
 | 
				
			||||||
 | 
					        const sld = stopSequences.slice()
 | 
				
			||||||
 | 
					          .filter(s => s === '###' || '</s>' || countTokens(model, s) === 1)
 | 
				
			||||||
 | 
					          .sort((a, b) => levenshteinDistance(a, startSequence) - levenshteinDistance(b, startSequence))
 | 
				
			||||||
 | 
					        stopSequence = sld[0] || stopSequence
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      stopSequences.push(stopSequence)
 | 
					      stopSequences.push(stopSequence)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
      const delimiter = getDelimiter(chat)
 | 
					      const delimiter = getDelimiter(chat)
 | 
				
			||||||
      const leadPromptSequence = getLeadPrompt(chat)
 | 
					      const leadPromptSequence = getLeadPrompt(chat)
 | 
				
			||||||
      if (delimiter) stopSequences.unshift(delimiter.trim())
 | 
					      if (delimiter) stopSequences.unshift(delimiter.trim())
 | 
				
			||||||
| 
						 | 
					@ -62,56 +95,55 @@ export const chatRequest = async (
 | 
				
			||||||
      const buildMessage = (m: Message): string => {
 | 
					      const buildMessage = (m: Message): string => {
 | 
				
			||||||
        return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
 | 
					        return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      const buildInputArray = (a) => {
 | 
				
			||||||
 | 
					        return a.reduce((a, m, i) => {
 | 
				
			||||||
 | 
					          let c = buildMessage(m)
 | 
				
			||||||
 | 
					          let replace = false
 | 
				
			||||||
 | 
					          const lm = a[a.length - 1]
 | 
				
			||||||
 | 
					          // Merge content if needed
 | 
				
			||||||
 | 
					          if (lm) {
 | 
				
			||||||
 | 
					            if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
 | 
				
			||||||
 | 
					              c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
 | 
				
			||||||
 | 
					              replace = true
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					              c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
 | 
				
			||||||
 | 
					              c = c.replaceAll('[[USER_PROMPT]]', lm.content)
 | 
				
			||||||
 | 
					              replace = true
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					              c = c.replaceAll('[[USER_PROMPT]]', '')
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          // Clean up merge fields on last
 | 
				
			||||||
 | 
					          if (!rMessages[i + 1]) {
 | 
				
			||||||
 | 
					            c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          const result = {
 | 
				
			||||||
 | 
					            role: m.role,
 | 
				
			||||||
 | 
					            content: c.trim()
 | 
				
			||||||
 | 
					          } as Message
 | 
				
			||||||
 | 
					          if (replace) {
 | 
				
			||||||
 | 
					            a[a.length - 1] = result
 | 
				
			||||||
 | 
					          } else {
 | 
				
			||||||
 | 
					            a.push(result)
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          return a
 | 
				
			||||||
 | 
					        }, [] as Message[])
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
      const lastMessage = rMessages[rMessages.length - 1]
 | 
					      const lastMessage = rMessages[rMessages.length - 1]
 | 
				
			||||||
      let doLead = true
 | 
					      let doLead = true
 | 
				
			||||||
      if (lastMessage && lastMessage.role === 'assistant') {
 | 
					      if (lastMessage && lastMessage.role === 'assistant') {
 | 
				
			||||||
        lastMessage.content = leadPromptSequence + lastMessage.content
 | 
					        lastMessage.content = leadPromptSequence + lastMessage.content
 | 
				
			||||||
        doLead = false
 | 
					        doLead = false
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      const inputArray = rMessages.reduce((a, m, i) => {
 | 
					      // const inputArray = buildInputArray(rMessages).map(m => m.content)
 | 
				
			||||||
        let c = buildMessage(m)
 | 
					      const lInputArray = buildInputArray(rMessages.slice(0, -1)).map(m => m.content)
 | 
				
			||||||
        let replace = false
 | 
					      const nInputArray = buildInputArray(rMessages.slice(-1)).map(m => m.content)
 | 
				
			||||||
        const lm = a[a.length - 1]
 | 
					 | 
				
			||||||
        // Merge content if needed
 | 
					 | 
				
			||||||
        if (lm) {
 | 
					 | 
				
			||||||
          if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
 | 
					 | 
				
			||||||
            c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
 | 
					 | 
				
			||||||
            replace = true
 | 
					 | 
				
			||||||
          } else {
 | 
					 | 
				
			||||||
            c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
          if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
 | 
					 | 
				
			||||||
            c = c.replaceAll('[[USER_PROMPT]]', lm.content)
 | 
					 | 
				
			||||||
            replace = true
 | 
					 | 
				
			||||||
          } else {
 | 
					 | 
				
			||||||
            c = c.replaceAll('[[USER_PROMPT]]', '')
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        // Clean up merge fields on last
 | 
					 | 
				
			||||||
        if (!rMessages[i + 1]) {
 | 
					 | 
				
			||||||
          c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        const result = {
 | 
					 | 
				
			||||||
          role: m.role,
 | 
					 | 
				
			||||||
          content: c.trim()
 | 
					 | 
				
			||||||
        } as Message
 | 
					 | 
				
			||||||
        if (replace) {
 | 
					 | 
				
			||||||
          a[a.length - 1] = result
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
          a.push(result)
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return a
 | 
					 | 
				
			||||||
      }, [] as Message[])
 | 
					 | 
				
			||||||
      const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
 | 
					      const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
 | 
				
			||||||
      const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(delimiter) + leadPrompt
 | 
					      const lastPrompt = startSequence + lInputArray.join(delimiter)
 | 
				
			||||||
 | 
					      const nextPrompt = nInputArray.slice(-1).join('') + leadPrompt
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
      let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
 | 
					 | 
				
			||||||
      const promptTokenCount = countTokens(model, fullPromptInput)
 | 
					 | 
				
			||||||
      if (promptTokenCount > maxLen) {
 | 
					 | 
				
			||||||
        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      // update with real count
 | 
					 | 
				
			||||||
      chatResponse.setPromptTokenCount(promptTokenCount)
 | 
					 | 
				
			||||||
      // set up the request
 | 
					      // set up the request
 | 
				
			||||||
      chatResponse.onFinish(() => {
 | 
					      chatResponse.onFinish(() => {
 | 
				
			||||||
        const message = chatResponse.getMessages()[0]
 | 
					        const message = chatResponse.getMessages()[0]
 | 
				
			||||||
| 
						 | 
					@ -124,51 +156,119 @@ export const chatRequest = async (
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        ws.close()
 | 
					        !chatSettings.holdSocket && ws.close()
 | 
				
			||||||
      })
 | 
					      })
 | 
				
			||||||
      ws.onopen = () => {
 | 
					
 | 
				
			||||||
        ws.send(JSON.stringify({
 | 
					      let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
 | 
				
			||||||
          type: 'open_inference_session',
 | 
					
 | 
				
			||||||
          model,
 | 
					      let inputPrompt = startSequence
 | 
				
			||||||
          max_length: maxLen
 | 
					
 | 
				
			||||||
        }))
 | 
					      const getNewWs = ():Promise<WebSocket> => new Promise<WebSocket>((resolve, reject) => {
 | 
				
			||||||
        ws.onmessage = event => {
 | 
					        // console.warn('requesting new ws')
 | 
				
			||||||
 | 
					        const nws = new WebSocket(getEndpoint(model))
 | 
				
			||||||
 | 
					        let opened = false
 | 
				
			||||||
 | 
					        let done = false
 | 
				
			||||||
 | 
					        nws.onmessage = event => {
 | 
				
			||||||
 | 
					          if (done) return
 | 
				
			||||||
 | 
					          done = true
 | 
				
			||||||
          const response = JSON.parse(event.data)
 | 
					          const response = JSON.parse(event.data)
 | 
				
			||||||
          if (!response.ok) {
 | 
					          if (!response.ok) {
 | 
				
			||||||
            const err = new Error('Error opening socket: ' + response.traceback)
 | 
					            const err = new Error('Error opening socket: ' + response.traceback)
 | 
				
			||||||
            chatResponse.updateFromError(err.message)
 | 
					            chatResponse.updateFromError(err.message)
 | 
				
			||||||
 | 
					            console.error(err)
 | 
				
			||||||
 | 
					            reject(err)
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          nws.onerror = err => {
 | 
				
			||||||
            console.error(err)
 | 
					            console.error(err)
 | 
				
			||||||
            throw err
 | 
					            throw err
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
          const petalsRequest = {
 | 
					          // console.warn('got new ws')
 | 
				
			||||||
            type: 'generate',
 | 
					          inputPrompt = lastPrompt
 | 
				
			||||||
            inputs: fullPromptInput,
 | 
					          providerData.knownBuffer = ''
 | 
				
			||||||
            max_new_tokens: 1, // wait for up to 1 tokens before displaying
 | 
					          providerData.ws = nws
 | 
				
			||||||
            stop_sequence: stopSequence,
 | 
					          resolve(nws)
 | 
				
			||||||
            do_sample: 1, // enable top p and the like
 | 
					        }
 | 
				
			||||||
            temperature,
 | 
					        nws.onclose = () => {
 | 
				
			||||||
            top_p: topP,
 | 
					          chatResponse.updateFromClose()
 | 
				
			||||||
            repetition_penalty: chatSettings.repetitionPenalty
 | 
					        }
 | 
				
			||||||
          } as any
 | 
					        nws.onerror = err => {
 | 
				
			||||||
          if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
 | 
					          if (done) return
 | 
				
			||||||
          // Update token count
 | 
					          done = true
 | 
				
			||||||
 | 
					          console.error(err)
 | 
				
			||||||
 | 
					          reject(err)
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        nws.onopen = () => {
 | 
				
			||||||
 | 
					          if (opened) return
 | 
				
			||||||
 | 
					          opened = true
 | 
				
			||||||
 | 
					          const promptTokenCount = countTokens(model, lastPrompt + delimiter + nextPrompt)
 | 
				
			||||||
 | 
					          if (promptTokenCount > maxLen) {
 | 
				
			||||||
 | 
					            maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          // update with real count
 | 
				
			||||||
          chatResponse.setPromptTokenCount(promptTokenCount)
 | 
					          chatResponse.setPromptTokenCount(promptTokenCount)
 | 
				
			||||||
          ws.send(JSON.stringify(petalsRequest))
 | 
					          nws.send(JSON.stringify({
 | 
				
			||||||
          ws.onmessage = event => {
 | 
					            type: 'open_inference_session',
 | 
				
			||||||
            // Remove updating indicator
 | 
					            model,
 | 
				
			||||||
            chatRequest.updating = 1 // hide indicator, but still signal we're updating
 | 
					            max_length: chatSettings.holdSocket ? maxTokens : maxLen
 | 
				
			||||||
            chatRequest.updatingMessage = ''
 | 
					          }))
 | 
				
			||||||
            const response = JSON.parse(event.data)
 | 
					        }
 | 
				
			||||||
            if (!response.ok) {
 | 
					      })
 | 
				
			||||||
              if (response.traceback.includes('Maximum length exceeded')) {
 | 
					
 | 
				
			||||||
                return chatResponse.finish('length')
 | 
					      const wsOpen = (ws && ws.readyState !== WebSocket.CLOSED)
 | 
				
			||||||
              }
 | 
					
 | 
				
			||||||
              const err = new Error('Error in response: ' + response.traceback)
 | 
					      if (!chatSettings.holdSocket || wsOpen) {
 | 
				
			||||||
              console.error(err)
 | 
					        const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
 | 
				
			||||||
              chatResponse.updateFromError(err.message)
 | 
					        const kb = providerData.knownBuffer.replace(rgxp, '')
 | 
				
			||||||
              throw err
 | 
					        const lp = lastPrompt.replace(rgxp, '')
 | 
				
			||||||
            }
 | 
					        const lm = kb === lp
 | 
				
			||||||
            chatResponse.updateFromAsyncResponse(
 | 
					        if (!lm || countTokens(model, providerData.knownBuffer + inputPrompt) >= maxTokens) {
 | 
				
			||||||
 | 
					          wsOpen && ws.close()
 | 
				
			||||||
 | 
					          ws = await getNewWs()
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (!ws || ws.readyState === WebSocket.CLOSED) {
 | 
				
			||||||
 | 
					        ws = await getNewWs()
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      inputPrompt += delimiter + nextPrompt
 | 
				
			||||||
 | 
					      providerData.knownBuffer += inputPrompt
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					      // console.log(
 | 
				
			||||||
 | 
					      //   '\n\n*** inputPrompt: ***\n\n',
 | 
				
			||||||
 | 
					      //   inputPrompt
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					      // )
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					      const petalsRequest = {
 | 
				
			||||||
 | 
					        type: 'generate',
 | 
				
			||||||
 | 
					        inputs: inputPrompt,
 | 
				
			||||||
 | 
					        max_new_tokens: 1, // wait for up to 1 tokens before displaying
 | 
				
			||||||
 | 
					        stop_sequence: stopSequence,
 | 
				
			||||||
 | 
					        do_sample: 1, // enable top p and the like
 | 
				
			||||||
 | 
					        temperature,
 | 
				
			||||||
 | 
					        top_p: topP,
 | 
				
			||||||
 | 
					        repetition_penalty: chatSettings.repetitionPenalty
 | 
				
			||||||
 | 
					      } as any
 | 
				
			||||||
 | 
					      if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
 | 
				
			||||||
 | 
					      // Update token count
 | 
				
			||||||
 | 
					      chatResponse.setPromptTokenCount(countTokens(model, providerData.knownBuffer))
 | 
				
			||||||
 | 
					      ws.onmessage = event => {
 | 
				
			||||||
 | 
					        // Remove updating indicator
 | 
				
			||||||
 | 
					        chatRequest.updating = 1 // hide indicator, but still signal we're updating
 | 
				
			||||||
 | 
					        chatRequest.updatingMessage = ''
 | 
				
			||||||
 | 
					        const response = JSON.parse(event.data)
 | 
				
			||||||
 | 
					        if (!response.ok) {
 | 
				
			||||||
 | 
					          if (response.traceback.includes('Maximum length exceeded')) {
 | 
				
			||||||
 | 
					            return chatResponse.finish('length')
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          const err = new Error('Error in response: ' + response.traceback)
 | 
				
			||||||
 | 
					          console.error(err)
 | 
				
			||||||
 | 
					          chatResponse.updateFromError(err.message)
 | 
				
			||||||
 | 
					          throw err
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        providerData.knownBuffer += response.outputs
 | 
				
			||||||
 | 
					        chatResponse.updateFromAsyncResponse(
 | 
				
			||||||
                {
 | 
					                {
 | 
				
			||||||
                  model,
 | 
					                  model,
 | 
				
			||||||
                  choices: [{
 | 
					                  choices: [{
 | 
				
			||||||
| 
						 | 
					@ -179,37 +279,32 @@ export const chatRequest = async (
 | 
				
			||||||
                    finish_reason: (response.stop ? 'stop' : null)
 | 
					                    finish_reason: (response.stop ? 'stop' : null)
 | 
				
			||||||
                  }]
 | 
					                  }]
 | 
				
			||||||
                } as any
 | 
					                } as any
 | 
				
			||||||
            )
 | 
					        )
 | 
				
			||||||
            if (chatSettings.aggressiveStop && !response.stop) {
 | 
					        if (chatSettings.aggressiveStop && !response.stop) {
 | 
				
			||||||
              // check if we should've stopped
 | 
					          // check if we should've stopped
 | 
				
			||||||
              const message = chatResponse.getMessages()[0]
 | 
					          const message = chatResponse.getMessages()[0]
 | 
				
			||||||
              const pad = 10 // look back 10 characters + stop sequence
 | 
					          const pad = 10 // look back 10 characters + stop sequence
 | 
				
			||||||
              if (message) {
 | 
					          if (message) {
 | 
				
			||||||
                const mc = (message.content).trim()
 | 
					            const mc = (message.content).trim()
 | 
				
			||||||
                for (let i = 0, l = stopSequences.length; i < l; i++) {
 | 
					            for (let i = 0, l = stopSequences.length; i < l; i++) {
 | 
				
			||||||
                  const ss = stopSequences[i].trim()
 | 
					              const ss = stopSequences[i].trim()
 | 
				
			||||||
                  const ind = mc.slice(0 - (ss.length + pad)).indexOf(ss)
 | 
					              const ind = mc.slice(0 - (ss.length + pad)).indexOf(ss)
 | 
				
			||||||
                  if (ind > -1) {
 | 
					              if (ind > -1) {
 | 
				
			||||||
                    const offset = (ss.length + pad) - ind
 | 
					                const offset = (ss.length + pad) - ind
 | 
				
			||||||
                    message.content = mc.slice(0, mc.length - offset)
 | 
					                message.content = mc.slice(0, mc.length - offset)
 | 
				
			||||||
                    response.stop = true
 | 
					                response.stop = true
 | 
				
			||||||
                    updateMessages(chat.id)
 | 
					                updateMessages(chat.id)
 | 
				
			||||||
                    chatResponse.finish()
 | 
					                chatResponse.finish()
 | 
				
			||||||
                    ws.close()
 | 
					                if (ss !== stopSequence) {
 | 
				
			||||||
                  }
 | 
					                  providerData.knownBuffer += stopSequence
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					                ws.close()
 | 
				
			||||||
              }
 | 
					              }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        ws.onclose = () => {
 | 
					 | 
				
			||||||
          chatResponse.updateFromClose()
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        ws.onerror = err => {
 | 
					 | 
				
			||||||
          console.error(err)
 | 
					 | 
				
			||||||
          throw err
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					      ws.send(JSON.stringify(petalsRequest))
 | 
				
			||||||
      return chatResponse
 | 
					      return chatResponse
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
</script>
 | 
					</script>
 | 
				
			||||||
		Loading…
	
		Reference in New Issue