4 changes: 2 additions & 2 deletions cacheflow/master/server.py
@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
1 change: 1 addition & 0 deletions cacheflow/models/utils.py
@@ -17,6 +17,7 @@
     'float': torch.float,
     'float16': torch.float16,
     'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
 }


4 changes: 3 additions & 1 deletion csrc/activation_kernels.cu
@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
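For context on the dispatch change above: AT_DISPATCH_FLOATING_TYPES_AND_HALF covers float, double, and half, while AT_DISPATCH_FLOATING_TYPES_AND2 lets the caller add two named scalar types, here Half and BFloat16. Roughly, the call expands to a type switch like the sketch below (simplified: double is omitted, and the kernel launch inside each case is an assumption about the lambda body, which this hunk does not show).

// Simplified sketch of the dispatch above; kernel arguments elided.
switch (input.scalar_type()) {
  case at::ScalarType::Float: {
    using scalar_t = float;
    silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(/* ... */);
    break;
  }
  case at::ScalarType::Half: {
    using scalar_t = at::Half;
    silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(/* ... */);
    break;
  }
  case at::ScalarType::BFloat16: {
    using scalar_t = at::BFloat16;
    silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(/* ... */);
    break;
  }
  default:
    TORCH_CHECK(false, "silu_and_mul_kernel not implemented for ", input.scalar_type());
}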
csrc/attention/attention_dtypes.h
@@ -3,3 +3,7 @@
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"

#ifdef ENABLE_BF16
#include "dtype_bfloat16.cuh"
#endif // ENABLE_BF16
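A note on the new guard: ENABLE_BF16 is assumed here to be a compile-time switch supplied by the build (for example, passing -DENABLE_BF16 to nvcc only when the CUDA toolkit provides bfloat16 support), so that dtype_bfloat16.cuh and the device types it relies on are compiled out on older toolchains. A minimal sketch of the same pattern:

// Sketch only, not part of this PR: bf16-specific code stays behind the switch.
#ifdef ENABLE_BF16
#include <cuda_bf16.h>   // provides __nv_bfloat16 and __nv_bfloat162
#endif  // ENABLE_BF16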
8 changes: 6 additions & 2 deletions csrc/attention/attention_kernels.cu
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 
 #include <algorithm>
@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
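The new branch mirrors the FP16 path: the kernel launcher is instantiated with a raw device-side type, uint16_t for half storage and __nv_bfloat16 for bfloat16, and the whole branch is compiled out when ENABLE_BF16 is not defined. Inside such a bf16 path, arithmetic is typically done by widening to float and narrowing on store; the helpers below are an illustrative sketch using standard CUDA conversion intrinsics, not code from this PR.

#include <cuda_bf16.h>

// Illustrative sketch: load bf16, accumulate in fp32, store bf16.
__device__ inline float bf16_to_float(__nv_bfloat16 x) {
  return __bfloat162float(x);
}

__device__ inline __nv_bfloat16 float_to_bf16(float x) {
  return __float2bfloat16(x);
}

__device__ inline float bf16_fma(__nv_bfloat16 a, __nv_bfloat16 b, float acc) {
  // widen both operands; keep the accumulator in fp32 for accuracy
  return __bfloat162float(a) * __bfloat162float(b) + acc;
}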
2 changes: 1 addition & 1 deletion csrc/attention/attention_utils.cuh
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 
 #include <float.h>
 #include <type_traits>