diff --git a/src/FSharpy.TaskSeq.Test/FSharpy.TaskSeq.Test.fsproj b/src/FSharpy.TaskSeq.Test/FSharpy.TaskSeq.Test.fsproj
index 0b1ca03c..d99de137 100644
--- a/src/FSharpy.TaskSeq.Test/FSharpy.TaskSeq.Test.fsproj
+++ b/src/FSharpy.TaskSeq.Test/FSharpy.TaskSeq.Test.fsproj
@@ -14,6 +14,7 @@
+
diff --git a/src/FSharpy.TaskSeq.Test/TaskSeq.Fold.Tests.fs b/src/FSharpy.TaskSeq.Test/TaskSeq.Fold.Tests.fs
new file mode 100644
index 00000000..b21a8075
--- /dev/null
+++ b/src/FSharpy.TaskSeq.Test/TaskSeq.Fold.Tests.fs
@@ -0,0 +1,49 @@
+module FSharpy.TaskSeq.Tests.Fold
+
+open System.Text
+open Xunit
+open FsUnit.Xunit
+open FsToolkit.ErrorHandling
+
+open FSharpy
+
+
+[]
+let ``TaskSeq-fold folds with every item`` () = task {
+ let! alphabet =
+ createDummyTaskSeqWith 50L<µs> 1000L<µs> 26
+ |> TaskSeq.fold (fun (state: StringBuilder) item -> state.Append(char item + '@')) (StringBuilder())
+
+ alphabet.ToString()
+ |> should equal "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+}
+
+[]
+let ``TaskSeq-foldAsync folds with every item`` () = task {
+ let! alphabet =
+ createDummyTaskSeqWith 50L<µs> 1000L<µs> 26
+ |> TaskSeq.foldAsync
+ (fun (state: StringBuilder) item -> task { return state.Append(char item + '@') })
+ (StringBuilder())
+
+ alphabet.ToString()
+ |> should equal "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+}
+
+[]
+let ``TaskSeq-fold takes state on empty IAsyncEnumberable`` () = task {
+ let! empty =
+ TaskSeq.empty
+ |> TaskSeq.fold (fun _ item -> char (item + 64)) '_'
+
+ empty |> should equal '_'
+}
+
+[]
+let ``TaskSeq-foldAsync takes state on empty IAsyncEnumerable`` () = task {
+ let! alphabet =
+ TaskSeq.empty
+ |> TaskSeq.foldAsync (fun _ item -> task { return char (item + 64) }) '_'
+
+ alphabet |> should equal '_'
+}
diff --git a/src/FSharpy.TaskSeq.Test/TestUtils.fs b/src/FSharpy.TaskSeq.Test/TestUtils.fs
index 47cc08d4..4eef93d7 100644
--- a/src/FSharpy.TaskSeq.Test/TestUtils.fs
+++ b/src/FSharpy.TaskSeq.Test/TestUtils.fs
@@ -162,4 +162,4 @@ module TestUtils =
}
/// Create a bunch of dummy tasks, each lasting between 10-30ms with spin-wait delays.
- let createDummyTaskSeq = createDummyTaskSeqWith 10_0000L<µs> 30_0000L<µs>
+ let createDummyTaskSeq = createDummyTaskSeqWith 10_000L<µs> 30_000L<µs>
diff --git a/src/FSharpy.TaskSeq/TaskSeq.fs b/src/FSharpy.TaskSeq/TaskSeq.fs
index a9448d91..ea814d81 100644
--- a/src/FSharpy.TaskSeq/TaskSeq.fs
+++ b/src/FSharpy.TaskSeq/TaskSeq.fs
@@ -197,3 +197,9 @@ module TaskSeq =
/// Zips two task sequences, returning a taskSeq of the tuples of each sequence, in order. May raise ArgumentException
/// if the sequences are or unequal length.
let zip taskSeq1 taskSeq2 = Internal.zip taskSeq1 taskSeq2
+
+ /// Applies a function to each element of the task sequence, threading an accumulator argument through the computation.
+ let fold folder state taskSeq = Internal.fold (FolderAction folder) state taskSeq
+
+ /// Applies an async function to each element of the task sequence, threading an accumulator argument through the computation.
+ let foldAsync folder state taskSeq = Internal.fold (AsyncFolderAction folder) state taskSeq
diff --git a/src/FSharpy.TaskSeq/TaskSeqInternal.fs b/src/FSharpy.TaskSeq/TaskSeqInternal.fs
index 0c98527f..462b0a19 100644
--- a/src/FSharpy.TaskSeq/TaskSeqInternal.fs
+++ b/src/FSharpy.TaskSeq/TaskSeqInternal.fs
@@ -12,10 +12,15 @@ module ExtraTaskSeqOperators =
[]
type Action<'T, 'U, 'TaskU when 'TaskU :> Task<'U>> =
- | CountableAction of c_action: (int -> 'T -> 'U)
- | SimpleAction of s_action: ('T -> 'U)
- | AsyncCountableAction of ac_action: (int -> 'T -> 'TaskU)
- | AsyncSimpleAction of as_action: ('T -> 'TaskU)
+ | CountableAction of countable_action: (int -> 'T -> 'U)
+ | SimpleAction of simple_action: ('T -> 'U)
+ | AsyncCountableAction of async_countable_action: (int -> 'T -> 'TaskU)
+ | AsyncSimpleAction of async_simple_action: ('T -> 'TaskU)
+
+[]
+type FolderAction<'T, 'State, 'TaskState when 'TaskState :> Task<'State>> =
+ | FolderAction of state_action: ('State -> 'T -> 'State)
+ | AsyncFolderAction of async_state_action: ('State -> 'T -> 'TaskState)
module internal TaskSeqInternal =
let iter action (taskSeq: taskSeq<_>) = task {
@@ -59,17 +64,26 @@ module internal TaskSeqInternal =
go <- step
}
- let fold (action: 'State -> 'T -> 'State) initial (taskSeq: taskSeq<_>) = task {
+ let fold folder initial (taskSeq: taskSeq<_>) = task {
let e = taskSeq.GetAsyncEnumerator(CancellationToken())
let mutable go = true
let mutable result = initial
let! step = e.MoveNextAsync()
go <- step
- while go do
- result <- action result e.Current
- let! step = e.MoveNextAsync()
- go <- step
+ match folder with
+ | FolderAction folder ->
+ while go do
+ result <- folder result e.Current
+ let! step = e.MoveNextAsync()
+ go <- step
+
+ | AsyncFolderAction folder ->
+ while go do
+ let! tempResult = folder result e.Current
+ result <- tempResult
+ let! step = e.MoveNextAsync()
+ go <- step
return result
}