update token count
This commit is contained in:
parent 86f427f62f
commit f4d9774423
@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
     import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
     import { ChatRequest } from '../../ChatRequest.svelte'
-    import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+    import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
     import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
     import { getModelMaxTokens } from '../../Stats.svelte'
     import { updateMessages } from '../../Storage.svelte'
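Note: the only change in this hunk is the added countTokens import. As a rough sketch of what such a helper could look like — an assumption for illustration, not the actual Models.svelte code — it would wrap a BPE tokenizer such as gpt-tokenizer:

import { encode } from 'gpt-tokenizer'

// Hypothetical sketch only: count how many tokens a prompt string encodes
// to. The real helper in Models.svelte may dispatch per model family.
export const countTokens = (model: string, value: string): number => {
  return encode(value).length
}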
@@ -35,11 +35,78 @@ export const chatRequest = async (
       stopSequences = stopSequences.sort((a, b) => b.length - a.length)
       const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
       const maxTokens = getModelMaxTokens(model)
+
+      // Enforce strict order of messages
+      const fMessages = (request.messages || [] as Message[])
+      const rMessages = fMessages.reduce((a, m, i) => {
+        a.push(m)
+        const nm = fMessages[i + 1]
+        if (m.role === 'system' && (!nm || nm.role !== 'user')) {
+          const nc = {
+            role: 'user',
+            content: ''
+          } as Message
+          a.push(nc)
+        }
+        return a
+      },
+          [] as Message[])
+      // make sure top_p and temperature are set the way we need
+      let temperature = request.temperature
+      if (temperature === undefined || isNaN(temperature as any)) temperature = 1
+      if (!temperature || temperature <= 0) temperature = 0.01
+      let topP = request.top_p
+      if (topP === undefined || isNaN(topP as any)) topP = 1
+      if (!topP || topP <= 0) topP = 0.01
+      // build the message array
+      const buildMessage = (m: Message): string => {
+        return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
+      }
+      const inputArray = rMessages.reduce((a, m, i) => {
+        let c = buildMessage(m)
+        let replace = false
+        const lm = a[a.length - 1]
+        // Merge content if needed
+        if (lm) {
+          if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
+            c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
+            replace = true
+          } else {
+            c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
+          }
+          if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
+            c = c.replaceAll('[[USER_PROMPT]]', lm.content)
+            replace = true
+          } else {
+            c = c.replaceAll('[[USER_PROMPT]]', '')
+          }
+        }
+        // Clean up merge fields on last
+        if (!rMessages[i + 1]) {
+          c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
+        }
+        const result = {
+          role: m.role,
+          content: c.trim()
+        } as Message
+        if (replace) {
+          a[a.length - 1] = result
+        } else {
+          a.push(result)
+        }
+        return a
+      }, [] as Message[])
+      const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+      const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
+
       let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-      const promptTokenCount = chatResponse.getPromptTokenCount()
+      const promptTokenCount = countTokens(model, fullPromptInput)
       if (promptTokenCount > maxLen) {
         maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
       }
+      // update with real count
+      chatResponse.setPromptTokenCount(promptTokenCount)
+      // set up the request
       chatResponse.onFinish(() => {
         const message = chatResponse.getMessages()[0]
         if (message) {
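This hunk hoists the prompt assembly ahead of the request so the prompt size can be measured locally: promptTokenCount now comes from countTokens(model, fullPromptInput) rather than from chatResponse.getPromptTokenCount(). The lines around maxLen clamp the completion budget to the model's context window; a standalone sketch of that arithmetic (function and parameter names are hypothetical):

// Standalone sketch of the clamp above; names are made up for illustration.
const clampMaxLen = (requested: number, promptTokens: number, modelMax: number): number => {
  // Never request more than the model's context allows.
  let maxLen = Math.min(requested, modelMax)
  // If the prompt alone exceeds the budget, grow it to fit, still capped.
  if (promptTokens > maxLen) {
    maxLen = Math.min(maxLen + promptTokens, modelMax)
  }
  return maxLen
}

// e.g. clampMaxLen(1024, 1500, 2048) => 2048, clampMaxLen(1024, 500, 2048) => 1024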
@@ -69,70 +136,9 @@ export const chatRequest = async (
             console.error(err)
             throw err
           }
-          // Enforce strict order of messages
-          const fMessages = (request.messages || [] as Message[])
-          const rMessages = fMessages.reduce((a, m, i) => {
-            a.push(m)
-            const nm = fMessages[i + 1]
-            if (m.role === 'system' && (!nm || nm.role !== 'user')) {
-              const nc = {
-                role: 'user',
-                content: ''
-              } as Message
-              a.push(nc)
-            }
-            return a
-          },
-          [] as Message[])
-          // make sure top_p and temperature are set the way we need
-          let temperature = request.temperature
-          if (temperature === undefined || isNaN(temperature as any)) temperature = 1
-          if (!temperature || temperature <= 0) temperature = 0.01
-          let topP = request.top_p
-          if (topP === undefined || isNaN(topP as any)) topP = 1
-          if (!topP || topP <= 0) topP = 0.01
-          // build the message array
-          const buildMessage = (m: Message): string => {
-            return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
-          }
-          const inputArray = rMessages.reduce((a, m, i) => {
-            let c = buildMessage(m)
-            let replace = false
-            const lm = a[a.length - 1]
-            // Merge content if needed
-            if (lm) {
-              if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
-                c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
-                replace = true
-              } else {
-                c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
-              }
-              if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
-                c = c.replaceAll('[[USER_PROMPT]]', lm.content)
-                replace = true
-              } else {
-                c = c.replaceAll('[[USER_PROMPT]]', '')
-              }
-            }
-            // Clean up merge fields on last
-            if (!rMessages[i + 1]) {
-              c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
-            }
-            const result = {
-              role: m.role,
-              content: c.trim()
-            } as Message
-            if (replace) {
-              a[a.length - 1] = result
-            } else {
-              a.push(result)
-            }
-            return a
-          }, [] as Message[])
-          const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
           const petalsRequest = {
             type: 'generate',
-            inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+            inputs: fullPromptInput,
             max_new_tokens: 1, // wait for up to 1 tokens before displaying
             stop_sequence: stopSequence,
             do_sample: 1, // enable top p and the like
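Everything removed here is the prompt-building block relocated into the earlier hunk; the substantive change is that inputs now reuses the precomputed fullPromptInput. For reference, a condensed illustration of the [[SYSTEM_PROMPT]] merge-field folding that block performs (the role tag below is fabricated; real tags come from getRoleTag/getRoleEnd for the model's template):

// Condensed illustration of merge-field folding, with a made-up tag.
const systemContent = 'You are helpful.'
const userTagged = '[[SYSTEM_PROMPT]] USER: Hello'

// A user turn whose template carries the merge field absorbs the preceding
// system message, replacing it in the output array.
const folded = userTagged.replaceAll('[[SYSTEM_PROMPT]]', systemContent).trim()
console.log(folded) // "You are helpful. USER: Hello"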
@@ -141,6 +147,8 @@ export const chatRequest = async (
             // repitition_penalty: chatSettings.repititionPenalty
           } as any
           if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+          // Update token count
+          chatResponse.setPromptTokenCount(promptTokenCount)
           ws.send(JSON.stringify(petalsRequest))
           ws.onmessage = event => {
             // Remove updating indicator
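The count computed up front is applied to the response object once more just before the request goes out, so usage stats presumably reflect the local count even if the stream fails early. A minimal shape for that API, inferred only from the calls visible in this diff (the real class in ChatCompletionResponse.svelte does much more):

// Inferred sketch only: just enough surface for the calls shown above.
class ChatCompletionResponseSketch {
  private promptTokenCount = 0

  setPromptTokenCount (count: number): void {
    this.promptTokenCount = count
  }

  getPromptTokenCount (): number {
    return this.promptTokenCount
  }
}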