@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
1657
1657
}
1658
1658
}
1659
1659
1660
- size_t llama_context::state_seq_get_size (llama_seq_id seq_id) {
1660
+ size_t llama_context::state_seq_get_size (llama_seq_id seq_id, llama_state_seq_flags flags ) {
1661
1661
llama_io_write_dummy io;
1662
1662
try {
1663
- return state_seq_write_data (io, seq_id);
1663
+ return state_seq_write_data (io, seq_id, flags );
1664
1664
} catch (const std::exception & err) {
1665
1665
LLAMA_LOG_ERROR (" %s: error getting state size: %s\n " , __func__, err.what ());
1666
1666
return 0 ;
1667
1667
}
1668
1668
}
1669
1669
1670
- size_t llama_context::state_seq_get_data (llama_seq_id seq_id, uint8_t * dst, size_t size) {
1670
+ size_t llama_context::state_seq_get_data (llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags ) {
1671
1671
llama_io_write_buffer io (dst, size);
1672
1672
try {
1673
- return state_seq_write_data (io, seq_id);
1673
+ return state_seq_write_data (io, seq_id, flags );
1674
1674
} catch (const std::exception & err) {
1675
1675
LLAMA_LOG_ERROR (" %s: error saving state: %s\n " , __func__, err.what ());
1676
1676
return 0 ;
1677
1677
}
1678
1678
}
1679
1679
1680
- size_t llama_context::state_seq_set_data (llama_seq_id seq_id, const uint8_t * src, size_t size) {
1680
+ size_t llama_context::state_seq_set_data (llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags ) {
1681
1681
llama_io_read_buffer io (src, size);
1682
1682
try {
1683
- return state_seq_read_data (io, seq_id);
1683
+ return state_seq_read_data (io, seq_id, flags );
1684
1684
} catch (const std::exception & err) {
1685
1685
LLAMA_LOG_ERROR (" %s: error loading state: %s\n " , __func__, err.what ());
1686
1686
return 0 ;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
1778
1778
{
1779
1779
const size_t state_size = file.size () - file.tell ();
1780
1780
llama_io_read_file io (&file);
1781
- const size_t nread = state_seq_read_data (io, seq_id);
1781
+ const size_t nread = state_seq_read_data (io, seq_id, 0 );
1782
1782
if (!nread) {
1783
1783
LLAMA_LOG_ERROR (" %s: failed to restore sequence state\n " , __func__);
1784
1784
return 0 ;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
1802
1802
1803
1803
// save the context state using stream saving
1804
1804
llama_io_write_file io (&file);
1805
- state_seq_write_data (io, seq_id);
1805
+ state_seq_write_data (io, seq_id, 0 );
1806
1806
1807
1807
const size_t res = file.tell ();
1808
1808
GGML_ASSERT (res == sizeof (uint32_t ) * 3 + sizeof (llama_token) * n_token_count + io.n_bytes ());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
1971
1971
return io.n_bytes ();
1972
1972
}
1973
1973
1974
- size_t llama_context::state_seq_write_data (llama_io_write_i & io, llama_seq_id seq_id) {
1974
+ size_t llama_context::state_seq_write_data (llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags ) {
1975
1975
GGML_UNUSED (seq_id);
1976
1976
1977
1977
if (memory) {
1978
- memory->state_write (io, seq_id);
1978
+ memory->state_write (io, seq_id, flags );
1979
1979
}
1980
1980
1981
1981
return io.n_bytes ();
1982
1982
}
1983
1983
1984
- size_t llama_context::state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) {
1984
+ size_t llama_context::state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags ) {
1985
1985
GGML_UNUSED (seq_id);
1986
1986
1987
1987
if (memory) {
1988
- memory->state_read (io, seq_id);
1988
+ memory->state_read (io, seq_id, flags );
1989
1989
}
1990
1990
1991
1991
return io.n_bytes ();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
2801
2801
}
2802
2802
2803
2803
size_t llama_state_seq_get_size (llama_context * ctx, llama_seq_id seq_id) {
2804
- return ctx-> state_seq_get_size ( seq_id);
2804
+ return llama_state_seq_get_size_ext (ctx, seq_id, 0 );
2805
2805
}
2806
2806
2807
2807
size_t llama_state_seq_get_data (llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2808
+ return llama_state_seq_get_data_ext (ctx, dst, size, seq_id, 0 );
2809
+ }
2810
+
2811
+ size_t llama_state_seq_set_data (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2812
+ return llama_state_seq_set_data_ext (ctx, src, size, seq_id, 0 );
2813
+ }
2814
+
2815
+ size_t llama_state_seq_get_size_ext (llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
2816
+ return ctx->state_seq_get_size (seq_id, flags);
2817
+ }
2818
+
2819
+ size_t llama_state_seq_get_data_ext (llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
2808
2820
ctx->synchronize ();
2809
2821
2810
- return ctx->state_seq_get_data (seq_id, dst, size);
2822
+ return ctx->state_seq_get_data (seq_id, dst, size, flags );
2811
2823
}
2812
2824
2813
- size_t llama_state_seq_set_data (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2825
+ size_t llama_state_seq_set_data_ext (llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags ) {
2814
2826
ctx->synchronize ();
2815
2827
2816
- return ctx->state_seq_set_data (seq_id, src, size);
2828
+ return ctx->state_seq_set_data (seq_id, src, size, flags );
2817
2829
}
2818
2830
2819
2831
size_t llama_state_seq_save_file (llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
0 commit comments