Commit e070829

Support bfloat16 data type (#54)
1 parent 436e523 commit e070829

12 files changed, +455 -53 lines changed

cacheflow/master/server.py

Lines changed: 2 additions & 2 deletions
@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')

cacheflow/models/utils.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
     'float': torch.float,
     'float16': torch.float16,
     'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
 }

csrc/activation_kernels.cu

Lines changed: 3 additions & 1 deletion
@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
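
Switching from AT_DISPATCH_FLOATING_TYPES_AND_HALF to AT_DISPATCH_FLOATING_TYPES_AND2 is what adds bfloat16 to the set of types the templated kernel is instantiated for: the macro switches on the tensor's scalar type over float and double plus the two extra types named first, and binds scalar_t inside the lambda. A minimal self-contained sketch of the same pattern, using a hypothetical elementwise scale_kernel rather than the real silu_and_mul_kernel:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>

// Hypothetical elementwise kernel; scalar_t is supplied by the dispatch macro.
template<typename scalar_t>
__global__ void scale_kernel(
  scalar_t* __restrict__ out,
  const scalar_t* __restrict__ in,
  const float factor,
  const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = (scalar_t) ((float) in[i] * factor);
  }
}

void scale(torch::Tensor& out, torch::Tensor& in, float factor) {
  const int n = in.numel();
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Dispatch covers float and double plus the two extra types listed first
  // (Half and BFloat16); inside the lambda, scalar_t is the matching C++ type
  // (at::Half / at::BFloat16 for the half-precision cases).
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    in.scalar_type(),
    "scale_kernel",
    [&] {
      scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        in.data_ptr<scalar_t>(),
        factor,
        n);
    });
}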

csrc/attention/attention_dtypes.cuh renamed to csrc/attention/attention_dtypes.h

Lines changed: 4 additions & 0 deletions
@@ -3,3 +3,7 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
+
+#ifdef ENABLE_BF16
+#include "dtype_bfloat16.cuh"
+#endif // ENABLE_BF16
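
ENABLE_BF16 acts as a build-time switch: when the extension is compiled without it, no bfloat16 code is pulled in, presumably so the attention headers still build on toolchains or architectures without __nv_bfloat16 support. A rough illustration of the guard pattern (the to_float helpers below are illustrative, not the actual contents of dtype_bfloat16.cuh):

#ifdef ENABLE_BF16
#include <cuda_bf16.h>

// The bf16 overload only exists when the build defines ENABLE_BF16.
__device__ __forceinline__ float to_float(__nv_bfloat16 x) {
  return __bfloat162float(x);
}
#endif  // ENABLE_BF16

// An always-available overload keeps the rest of the code compiling either way.
__device__ __forceinline__ float to_float(float x) {
  return x;
}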

csrc/attention/attention_kernels.cu

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 
 #include <algorithm>

@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens, // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
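
On the host side, the new branch maps each torch scalar type to the raw CUDA element type the attention kernels are templated on: at::Half tensors are launched as uint16_t (same bit pattern) and at::BFloat16 tensors as __nv_bfloat16, with the bf16 branch compiled only under ENABLE_BF16. A hedged sketch of that dispatch shape, with a stand-in launch_attention template in place of the real CALL_KERNEL_LAUNCHER_BLOCK_SIZE macro:

#include <stdint.h>
#include <torch/extension.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

// Stand-in for the real templated kernel launcher.
template<typename T>
void launch_attention(torch::Tensor& query) {
  // The kernels take raw device pointers, so the tensor storage is
  // reinterpreted as the template type: half bits as uint16_t,
  // bfloat16 bits as __nv_bfloat16.
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  // ... launch the templated kernel with query_ptr ...
  (void) query_ptr;
}

void dispatch_attention(torch::Tensor& query) {
  if (query.dtype() == at::ScalarType::Half) {
    launch_attention<uint16_t>(query);
#ifdef ENABLE_BF16
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    launch_attention<__nv_bfloat16>(query);
#endif
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}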

csrc/attention/attention_utils.cuh

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 
 #include <float.h>
 #include <type_traits>
