update token count

parent 86f427f62f
commit f4d9774423
@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
     import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
     import { ChatRequest } from '../../ChatRequest.svelte'
-    import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+    import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
     import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
     import { getModelMaxTokens } from '../../Stats.svelte'
     import { updateMessages } from '../../Storage.svelte'
@@ -35,40 +35,7 @@ export const chatRequest = async (
       stopSequences = stopSequences.sort((a, b) => b.length - a.length)
       const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
       const maxTokens = getModelMaxTokens(model)
-      let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-      const promptTokenCount = chatResponse.getPromptTokenCount()
-      if (promptTokenCount > maxLen) {
-        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-      }
-      chatResponse.onFinish(() => {
-        const message = chatResponse.getMessages()[0]
-        if (message) {
-          for (let i = 0, l = stopSequences.length; i < l; i++) {
-            const ss = stopSequences[i].trim()
-            if (message.content.trim().endsWith(ss)) {
-              message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
-              updateMessages(chat.id)
-            }
-          }
-        }
-        chatRequest.updating = false
-        chatRequest.updatingMessage = ''
-        ws.close()
-      })
-      ws.onopen = () => {
-        ws.send(JSON.stringify({
-          type: 'open_inference_session',
-          model,
-          max_length: maxLen
-        }))
-        ws.onmessage = event => {
-          const response = JSON.parse(event.data)
-          if (!response.ok) {
-            const err = new Error('Error opening socket: ' + response.traceback)
-            chatResponse.updateFromError(err.message)
-            console.error(err)
-            throw err
-          }
+
       // Enforce strict order of messages
       const fMessages = (request.messages || [] as Message[])
       const rMessages = fMessages.reduce((a, m, i) => {
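
For reference, the onFinish handler removed above is moved, not deleted; it reappears in the next hunk. Its job is to strip a trailing stop sequence from the finished message before it is stored. A minimal standalone sketch of that step (the helper name trimStopSequences is illustrative, not part of the commit; stopSequences is assumed to be sorted longest-first, as the sort above guarantees):

const trimStopSequences = (content: string, stopSequences: string[]): string => {
  let trimmed = content.trim()
  for (let i = 0, l = stopSequences.length; i < l; i++) {
    const ss = stopSequences[i].trim()
    if (ss && trimmed.endsWith(ss)) {
      // Cut the stop marker off the end so it never shows in the chat.
      trimmed = trimmed.slice(0, trimmed.length - ss.length)
    }
  }
  return trimmed
}

// e.g. trimStopSequences('Hello there!###', ['###', '</s>']) === 'Hello there!'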
@@ -130,9 +97,48 @@ export const chatRequest = async (
         return a
       }, [] as Message[])
       const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+      const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
+
+      let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+      const promptTokenCount = countTokens(model, fullPromptInput)
+      if (promptTokenCount > maxLen) {
+        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+      }
+      // update with real count
+      chatResponse.setPromptTokenCount(promptTokenCount)
+      // set up the request
+      chatResponse.onFinish(() => {
+        const message = chatResponse.getMessages()[0]
+        if (message) {
+          for (let i = 0, l = stopSequences.length; i < l; i++) {
+            const ss = stopSequences[i].trim()
+            if (message.content.trim().endsWith(ss)) {
+              message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
+              updateMessages(chat.id)
+            }
+          }
+        }
+        chatRequest.updating = false
+        chatRequest.updatingMessage = ''
+        ws.close()
+      })
+      ws.onopen = () => {
+        ws.send(JSON.stringify({
+          type: 'open_inference_session',
+          model,
+          max_length: maxLen
+        }))
+        ws.onmessage = event => {
+          const response = JSON.parse(event.data)
+          if (!response.ok) {
+            const err = new Error('Error opening socket: ' + response.traceback)
+            chatResponse.updateFromError(err.message)
+            console.error(err)
+            throw err
+          }
           const petalsRequest = {
             type: 'generate',
-            inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+            inputs: fullPromptInput,
             max_new_tokens: 1, // wait for up to 1 tokens before displaying
             stop_sequence: stopSequence,
             do_sample: 1, // enable top p and the like
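
This hunk is the core of the commit: the prompt token count now comes from countTokens(model, fullPromptInput) on the assembled prompt (countTokens is the import added in the first hunk) instead of chatResponse.getPromptTokenCount(), so max_length can be sized before the session is opened. A minimal sketch of the sizing logic under the same signatures, pulled out into a hypothetical helper (planSessionLength and the string type for model are illustrative):

import { countTokens } from '../../Models.svelte'
import { getModelMaxTokens } from '../../Stats.svelte'

const planSessionLength = (model: string, fullPromptInput: string, requestedMax: number): number => {
  const maxTokens = getModelMaxTokens(model)
  // Start from the requested completion budget, capped at the model limit.
  let maxLen = Math.min(requestedMax || maxTokens, maxTokens)
  const promptTokenCount = countTokens(model, fullPromptInput)
  if (promptTokenCount > maxLen) {
    // Grow the session so the prompt itself fits, but never past the model cap.
    maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
  }
  return maxLen
}

For example, with a 2048-token model cap, a 300-token requested budget, and a 500-token prompt, maxLen becomes Math.min(300 + 500, 2048) = 800.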
@@ -141,6 +147,8 @@ export const chatRequest = async (
             // repitition_penalty: chatSettings.repititionPenalty
           } as any
           if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+          // Update token count
+          chatResponse.setPromptTokenCount(promptTokenCount)
           ws.send(JSON.stringify(petalsRequest))
           ws.onmessage = event => {
             // Remove updating indicator
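
Taken together, the hunks keep the existing two-step Petals exchange: open a session sized by the clamped maxLen, then send the generate request once the open acknowledgement arrives. A condensed sketch of that flow, using only the message shapes shown in this diff (the endpoint URL is a placeholder, and the declared variables stand in for the surrounding function's state):

declare const model: string, maxLen: number, fullPromptInput: string
declare const stopSequence: string, stopSequencesC: string[]

const ws = new WebSocket('wss://petals.example/api/v2/generate') // placeholder URL
ws.onopen = () => {
  // Step 1: open an inference session with the precomputed length budget.
  ws.send(JSON.stringify({ type: 'open_inference_session', model, max_length: maxLen }))
  ws.onmessage = event => {
    const response = JSON.parse(event.data)
    if (!response.ok) throw new Error('Error opening socket: ' + response.traceback)
    // Step 2: send the generation request itself.
    ws.send(JSON.stringify({
      type: 'generate',
      inputs: fullPromptInput,
      max_new_tokens: 1, // stream roughly token by token
      stop_sequence: stopSequence,
      extra_stop_sequences: stopSequencesC,
      do_sample: 1
    }))
  }
}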