Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 54 additions & 43 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2857,20 +2857,44 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
}

#define SIN_COS_N_COUNT WHISPER_N_FFT
static float sin_vals[SIN_COS_N_COUNT];
static float cos_vals[SIN_COS_N_COUNT];
namespace {
struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];

// Hann window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
float hann_window[WHISPER_N_FFT];
float hann_window2x[WHISPER_N_FFT * 2];

whisper_global_cache() {
fill_sin_cos_table();
#define FILL_HANN_WINDOW(arr) fill_hann_window(sizeof(arr) / sizeof(arr[0]), true, arr)
FILL_HANN_WINDOW(hann_window);
FILL_HANN_WINDOW(hann_window2x);
}

void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}

// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
static void fill_sin_cos_table() {
static bool is_filled = false;
if (is_filled) return;
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
void fill_hann_window(int length, bool periodic, float* output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
is_filled = true;
} global_cache;
}

// naive Discrete Fourier Transform
Expand All @@ -2888,8 +2912,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {

for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*cos_vals[idx]; // cos(t)
im -= in[n]*sin_vals[idx]; // sin(t)
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
}

out[k*2 + 0] = re;
Expand Down Expand Up @@ -2940,8 +2964,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cos_vals[idx]; // cos(t)
float im = -sin_vals[idx]; // sin(t)
float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // sin(t)

float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
Expand All @@ -2954,22 +2978,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
}
}

static bool hann_window(int length, bool periodic, std::vector<float> & output) {
if (output.size() < static_cast<size_t>(length)) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}

return true;
}

static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size, 0.0);
Expand All @@ -2984,7 +2993,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;

// apply Hanning window (~10% faster)
// apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
Expand Down Expand Up @@ -3051,12 +3060,16 @@ static bool log_mel_spectrogram(
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();

// Hanning window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
std::vector<float> hann;
hann_window(frame_size, true, hann);

// Hann window
const float * hann = nullptr;
if (frame_size == WHISPER_N_FFT) {
hann = global_cache.hann_window;
} else if (frame_size == 2 * WHISPER_N_FFT) {
hann = global_cache.hann_window2x;
} else {
WHISPER_ASSERT(false && "Unsupported frame_size");
return false;
}

// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
Expand Down Expand Up @@ -3086,7 +3099,7 @@ static bool log_mel_spectrogram(
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
Expand Down Expand Up @@ -3246,8 +3259,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
#endif

struct whisper_state * whisper_init_state(whisper_context * ctx) {
fill_sin_cos_table();

whisper_state * state = new whisper_state;

state->backend = whisper_backend_init(ctx->params);
Expand Down Expand Up @@ -7235,7 +7246,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
// operation (after median filter)
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_norm(gctx, w, 1e-9);
w = ggml_norm(gctx, w, 1e-9f);
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);

// Pass median filter - this is done over AUDIO_TOKENS dimension.
Expand Down