@@ -94,6 +94,22 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
94
94
modelMessage2
95
95
]
96
96
} ] ;
97
+ const chatHistoryOnlyCall : ChatHistoryItem [ ] = [ ...baseChatHistory , {
98
+ type : "model" ,
99
+ response : [
100
+ {
101
+ type : "functionCall" ,
102
+ name : func1name ,
103
+
104
+ // convert to number since this will go through JSON.stringify,
105
+ // and we want to avoid escaping characters in the rendered output
106
+ params : Number ( func1params ) ,
107
+ result : Number ( func1result ) ,
108
+ startsNewChunk : true
109
+ } ,
110
+ modelMessage2
111
+ ]
112
+ } ] ;
97
113
const chatHistory2Calls : ChatHistoryItem [ ] = [ ...baseChatHistory , {
98
114
type : "model" ,
99
115
response : [
@@ -257,6 +273,17 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
257
273
stringifyFunctionResults : stringifyResult ,
258
274
combineModelMessageAndToolCalls
259
275
} ) ;
276
+ const renderedOnlyCall = getFirstValidResult ( [
277
+ ( ) => renderTemplate ( {
278
+ chatHistory : chatHistoryOnlyCall ,
279
+ functions : functions1 ,
280
+ additionalParams,
281
+ stringifyFunctionParams : stringifyParams ,
282
+ stringifyFunctionResults : stringifyResult ,
283
+ combineModelMessageAndToolCalls
284
+ } ) ,
285
+ ( ) => undefined
286
+ ] ) ;
260
287
const rendered2Calls = getFirstValidResult ( [
261
288
( ) => renderTemplate ( {
262
289
chatHistory : chatHistory2Calls ,
@@ -411,14 +438,46 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
411
438
parallelismResultPrefix
412
439
} = resolveParallelismBetweenSectionsParts ( func2ParamsToFunc1Result . text . slice ( callSuffixLength , - resultPrefixLength ) ) ;
413
440
441
+ let revivedCallPrefix = reviveSeparatorText ( callPrefixText , idToStaticContent , contentIds ) ;
442
+ const revivedParallelismCallSectionPrefix = removeCommonRevivedPrefix (
443
+ reviveSeparatorText ( parallelismCallPrefix , idToStaticContent , contentIds ) ,
444
+ ! combineModelMessageAndToolCalls
445
+ ? textBetween2TextualModelResponses
446
+ : LlamaText ( )
447
+ ) ;
448
+ let revivedParallelismCallBetweenCalls = reviveSeparatorText ( parallelismBetweenCallsText , idToStaticContent , contentIds ) ;
449
+
450
+ if ( revivedParallelismCallSectionPrefix . values . length === 0 && renderedOnlyCall != null ) {
451
+ const userMessage1ToModelMessage1Start = getTextBetweenIds ( rendered1Call , userMessage1 , modelMessage1 ) ;
452
+ const onlyCallUserMessage1ToFunc1Name = getTextBetweenIds ( renderedOnlyCall , userMessage1 , func1name ) ;
453
+
454
+ if ( userMessage1ToModelMessage1Start . text != null && onlyCallUserMessage1ToFunc1Name . text != null ) {
455
+ const onlyCallModelMessagePrefixLength = findCommandStartLength (
456
+ userMessage1ToModelMessage1Start . text ,
457
+ onlyCallUserMessage1ToFunc1Name . text
458
+ ) ;
459
+ const onlyCallCallPrefixText = onlyCallUserMessage1ToFunc1Name . text . slice ( onlyCallModelMessagePrefixLength ) ;
460
+ const revivedOnlyCallCallPrefixText = reviveSeparatorText ( onlyCallCallPrefixText , idToStaticContent , contentIds ) ;
461
+
462
+ const optionalCallPrefix = removeCommonRevivedSuffix ( revivedCallPrefix , revivedOnlyCallCallPrefixText ) ;
463
+ if ( optionalCallPrefix . values . length > 0 ) {
464
+ revivedCallPrefix = removeCommonRevivedPrefix ( revivedCallPrefix , optionalCallPrefix ) ;
465
+ revivedParallelismCallBetweenCalls = LlamaText ( [
466
+ optionalCallPrefix ,
467
+ revivedParallelismCallBetweenCalls
468
+ ] ) ;
469
+ }
470
+ }
471
+ }
472
+
414
473
return {
415
474
stringifyParams,
416
475
stringifyResult,
417
476
combineModelMessageAndToolCalls,
418
477
settings : {
419
478
call : {
420
479
optionalPrefixSpace : true ,
421
- prefix : reviveSeparatorText ( callPrefixText , idToStaticContent , contentIds ) ,
480
+ prefix : revivedCallPrefix ,
422
481
paramsPrefix : reviveSeparatorText ( callParamsPrefixText , idToStaticContent , contentIds ) ,
423
482
suffix : reviveSeparatorText ( callSuffixText , idToStaticContent , contentIds ) ,
424
483
emptyCallParamsPlaceholder : { }
@@ -445,13 +504,8 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
445
504
} ,
446
505
parallelism : {
447
506
call : {
448
- sectionPrefix : removeCommonRevivedPrefix (
449
- reviveSeparatorText ( parallelismCallPrefix , idToStaticContent , contentIds ) ,
450
- ! combineModelMessageAndToolCalls
451
- ? textBetween2TextualModelResponses
452
- : LlamaText ( )
453
- ) ,
454
- betweenCalls : reviveSeparatorText ( parallelismBetweenCallsText , idToStaticContent , contentIds ) ,
507
+ sectionPrefix : revivedParallelismCallSectionPrefix ,
508
+ betweenCalls : revivedParallelismCallBetweenCalls ,
455
509
sectionSuffix : reviveSeparatorText ( parallelismCallSuffixText , idToStaticContent , contentIds )
456
510
} ,
457
511
result : {
@@ -524,14 +578,48 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) {
524
578
} else if ( targetValue instanceof SpecialToken && matchStartValue instanceof SpecialToken ) {
525
579
if ( targetValue . value === matchStartValue . value )
526
580
continue ;
527
- }
581
+ } else if ( LlamaText ( targetValue ?? "" ) . compare ( LlamaText ( matchStartValue ?? "" ) ) )
582
+ continue ;
528
583
529
584
return LlamaText ( target . values . slice ( commonStartLength ) ) ;
530
585
}
531
586
532
587
return LlamaText ( target . values . slice ( matchStart . values . length ) ) ;
533
588
}
534
589
590
+ function removeCommonRevivedSuffix ( target : LlamaText , matchEnd : LlamaText ) {
591
+ for (
592
+ let commonEndLength = 0 ;
593
+ commonEndLength < target . values . length && commonEndLength < matchEnd . values . length ;
594
+ commonEndLength ++
595
+ ) {
596
+ const targetValue = target . values [ target . values . length - commonEndLength - 1 ] ;
597
+ const matchEndValue = matchEnd . values [ matchEnd . values . length - commonEndLength - 1 ] ;
598
+
599
+ if ( typeof targetValue === "string" && typeof matchEndValue === "string" ) {
600
+ if ( targetValue === matchEndValue )
601
+ continue ;
602
+ } else if ( targetValue instanceof SpecialTokensText && matchEndValue instanceof SpecialTokensText ) {
603
+ const commonLength = findCommonEndLength ( targetValue . value , matchEndValue . value ) ;
604
+ if ( commonLength === targetValue . value . length && commonLength === matchEndValue . value . length )
605
+ continue ;
606
+
607
+ return LlamaText ( [
608
+ ...target . values . slice ( 0 , target . values . length - commonEndLength - 1 ) ,
609
+ new SpecialTokensText ( targetValue . value . slice ( 0 , targetValue . value . length - commonLength ) )
610
+ ] ) ;
611
+ } else if ( targetValue instanceof SpecialToken && matchEndValue instanceof SpecialToken ) {
612
+ if ( targetValue . value === matchEndValue . value )
613
+ continue ;
614
+ } else if ( LlamaText ( targetValue ?? "" ) . compare ( LlamaText ( matchEndValue ?? "" ) ) )
615
+ continue ;
616
+
617
+ return LlamaText ( target . values . slice ( 0 , target . values . length - commonEndLength - 1 ) ) ;
618
+ }
619
+
620
+ return LlamaText ( target . values . slice ( 0 , target . values . length - matchEnd . values . length ) ) ;
621
+ }
622
+
535
623
function findCommandStartLength ( text1 : string , text2 : string ) {
536
624
let commonStartLength = 0 ;
537
625
while ( commonStartLength < text1 . length && commonStartLength < text2 . length ) {
0 commit comments