Skip to content

Commit 950e509

Browse files
authored
Enable optimized Jamba (#3406)
1 parent 1be1dc8 commit 950e509

File tree

28 files changed

+3023
-16
lines changed

28 files changed

+3023
-16
lines changed

csrc/cpu/aten/Conv.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
namespace torch_ipex {
99
namespace cpu {
1010

11+
IPEX_DEFINE_DISPATCH(causal_conv1d_update_kernel_stub);
1112
std::vector<int64_t> calc_conv_output_size(
1213
at::IntArrayRef input_size,
1314
at::IntArrayRef kernel_size,
@@ -505,6 +506,32 @@ at::Tensor convolution_forward(
505506
weight_channels_last);
506507
}
507508

509+
/**
510+
* Official Python implementation: causal_conv1d_update_ref:
511+
* https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py#L206
512+
* @param hidden_states (batch, dim) or (batch, dim, seqlen)
513+
* @param conv_states (batch, dim, state_len), where state_len >= width - 1
514+
* @param conv_weights (dim, width)
515+
* @param conv_bias (dim,)
516+
* @param silu_activation If true, apply the SiLU activation function.
517+
* @return (hidden_states, conv_states)
518+
*/
519+
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
520+
const at::Tensor& hidden_states,
521+
const at::Tensor& conv_states,
522+
const at::Tensor& conv_weights,
523+
const c10::optional<at::Tensor>& conv_bias,
524+
bool silu_activation) {
525+
RECORD_FUNCTION("causal_conv1d_update", c10::ArrayRef<c10::IValue>({}));
526+
return causal_conv1d_update_kernel_stub(
527+
kCPU,
528+
hidden_states,
529+
conv_states,
530+
conv_weights,
531+
conv_bias,
532+
silu_activation);
533+
}
534+
508535
} // namespace cpu
509536
} // namespace torch_ipex
510537

@@ -561,6 +588,12 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
561588
"convolution_forward",
562589
c10::DispatchKey::CPU,
563590
torch_ipex::cpu::convolution_forward_impl);
591+
m.def(
592+
"causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation) -> (Tensor, Tensor)");
593+
m.impl(
594+
"causal_conv1d_update",
595+
c10::DispatchKey::CPU,
596+
torch_ipex::cpu::causal_conv1d_update);
564597
// bw
565598
m.def(
566599
"convolution_backward(Tensor input, Tensor weight, Tensor? bias, Tensor grad_output, bool[3] out_mask, "

csrc/cpu/aten/Conv.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <ATen/Tensor.h>
4+
#include <dyndisp/DispatchStub.h>
45
#include <torch/csrc/autograd/custom_function.h>
56

67
#include <ideep.hpp>
@@ -51,6 +52,13 @@ std::vector<int64_t> calc_conv_output_size(
5152
at::IntArrayRef stride,
5253
at::IntArrayRef dilation);
5354

55+
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
56+
const at::Tensor& hidden_states,
57+
const at::Tensor& conv_states,
58+
const at::Tensor& conv_weights,
59+
const c10::optional<at::Tensor>& conv_bias,
60+
bool silu_activation);
61+
5462
// IPEX customized convolution OP with n-D packed weight
5563
class IPEXConvolutionOp : public torch::autograd::Function<IPEXConvolutionOp> {
5664
public:
@@ -95,5 +103,14 @@ at::Tensor convolution_forward(
95103
c10::optional<at::IntArrayRef> dilation,
96104
c10::optional<bool> weight_channels_last);
97105

106+
using causal_conv1d_update_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
107+
const at::Tensor& hidden_states,
108+
const at::Tensor& conv_states,
109+
const at::Tensor& conv_weights,
110+
const c10::optional<at::Tensor>& conv_bias,
111+
bool silu_activation);
112+
IPEX_DECLARE_DISPATCH(
113+
causal_conv1d_update_kernel_fn,
114+
causal_conv1d_update_kernel_stub);
98115
} // namespace cpu
99116
} // namespace torch_ipex

csrc/cpu/aten/SelectiveScan.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#include <ATen/ATen.h>
2+
3+
#include <ATen/NativeFunctions.h>
4+
#include <ATen/Parallel.h>
5+
#include <ATen/native/ReduceOpsUtils.h>
6+
#include <ATen/native/cpu/utils.h>
7+
#include <ATen/record_function.h>
8+
#include <c10/util/irange.h>
9+
10+
#include "SelectiveScan.h"
11+
#include "utils/library.h"
12+
13+
namespace torch_ipex {
14+
namespace cpu {
15+
16+
IPEX_DEFINE_DISPATCH(selective_scan_kernel_stub);
17+
IPEX_DEFINE_DISPATCH(selective_state_update_kernel_stub);
18+
19+
/**
20+
* Does selective scan algorithm in Mamba Paper.
21+
* Paper: https://arxiv.org/abs/2312.00752
22+
* Official Python Implementation:
23+
* selective_scan_ref:
24+
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L113
25+
* @param u: (batch, dim, len) or (batch, len, dim)
26+
* @param delta: same shape as u
27+
* @param A: (dim, dstate) or (dstate, dim)
28+
* @param B: (batch, dstate, len) or (batch, dstate, 2len) or (battch, ngroups,
29+
* dstate, len)
30+
* @param C: (batch, dstate, len) or (batch, dstate, 2len) or (battch, ngroups,
31+
* dstate, len)
32+
* @param D: (dim,) or None
33+
* @param z: (batch, dim, len) or None
34+
* @param delta_bias: (dim,) or None
35+
* @param delta_softplus: bool
36+
* @param return_last_state: bool
37+
* @return: out: (batch, dim, len), last_state: (batch, dim, dstate)
38+
*/
39+
std::tuple<at::Tensor, at::Tensor> selective_scan(
40+
const at::Tensor& u,
41+
const at::Tensor& delta,
42+
const at::Tensor& A,
43+
const at::Tensor& B,
44+
const at::Tensor& C,
45+
const c10::optional<at::Tensor>& D,
46+
const c10::optional<at::Tensor>& z,
47+
const c10::optional<at::Tensor>& delta_bias,
48+
bool delta_softplus,
49+
bool return_last_state) {
50+
RECORD_FUNCTION("selective_scan_fn", c10::ArrayRef<c10::IValue>({}));
51+
return selective_scan_kernel_stub(
52+
kCPU,
53+
u,
54+
delta,
55+
A,
56+
B,
57+
C,
58+
D,
59+
z,
60+
delta_bias,
61+
delta_softplus,
62+
return_last_state);
63+
}
64+
65+
/**
66+
* Official Python Implementation:
67+
* selective_state_update_ref:
68+
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py#L219
69+
* @param state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
70+
* @param x: (batch, dim) or (batch, nheads, dim)
71+
* @param dt: (batch, dim) or (batch, nheads, dim)
72+
* @param A: (dim, dstate) or (nheads, dim, dstate) or (dstate, dim) or (nheads,
73+
* dstate, dim)
74+
* @param B: (batch, dstate) or (batch, ngroups, dstate)
75+
* @param C: (batch, dstate) or (batch, ngroups, dstate)
76+
* @param D: (dim,) or (nheads, dim) or None
77+
* @param z: (batch, dim) or (batch, nheads, dim) or None
78+
* @param dt_bias: (dim,) or (nheads, dim) or None
79+
* @param dt_softplus: bool
80+
* @return: out: (batch, dim) or (batch, nheads, dim)
81+
*/
82+
at::Tensor selective_state_update(
83+
const at::Tensor& state,
84+
const at::Tensor& x,
85+
const at::Tensor& dt,
86+
const at::Tensor& A,
87+
const at::Tensor& B,
88+
const at::Tensor& C,
89+
const c10::optional<at::Tensor>& D,
90+
const c10::optional<at::Tensor>& z,
91+
const c10::optional<at::Tensor>& dt_bias,
92+
bool dt_softplus) {
93+
RECORD_FUNCTION("selective_state_update", c10::ArrayRef<c10::IValue>({}));
94+
return selective_state_update_kernel_stub(
95+
kCPU, state, x, dt, A, B, C, D, z, dt_bias, dt_softplus);
96+
}
97+
98+
} // namespace cpu
99+
} // namespace torch_ipex
100+
101+
namespace {
102+
103+
IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
104+
m.def(
105+
"selective_scan_fn(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor)");
106+
m.impl(
107+
"selective_scan_fn",
108+
c10::DispatchKey::CPU,
109+
torch_ipex::cpu::selective_scan);
110+
m.def(
111+
"selective_state_update(Tensor state, Tensor x, Tensor dt, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? dt_bias, bool dt_softplus) -> (Tensor)");
112+
m.impl(
113+
"selective_state_update",
114+
c10::DispatchKey::CPU,
115+
torch_ipex::cpu::selective_state_update);
116+
}
117+
118+
} // namespace

csrc/cpu/aten/SelectiveScan.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#pragma once
2+
3+
#include <ATen/Tensor.h>
4+
#include <dyndisp/DispatchStub.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
std::tuple<at::Tensor, at::Tensor> selective_scan(
10+
const at::Tensor& u,
11+
const at::Tensor& delta,
12+
const at::Tensor& A,
13+
const at::Tensor& B,
14+
const at::Tensor& C,
15+
const c10::optional<at::Tensor>& D,
16+
const c10::optional<at::Tensor>& z,
17+
const c10::optional<at::Tensor>& delta_bias,
18+
bool delta_softplus,
19+
bool return_last_state);
20+
at::Tensor selective_state_update(
21+
const at::Tensor& state,
22+
const at::Tensor& x,
23+
const at::Tensor& dt,
24+
const at::Tensor& A,
25+
const at::Tensor& B,
26+
const at::Tensor& C,
27+
const c10::optional<at::Tensor>& D,
28+
const c10::optional<at::Tensor>& z,
29+
const c10::optional<at::Tensor>& dt_bias,
30+
bool dt_softplus);
31+
32+
using selective_scan_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
33+
const at::Tensor& u,
34+
const at::Tensor& delta,
35+
const at::Tensor& A,
36+
const at::Tensor& B,
37+
const at::Tensor& C,
38+
const c10::optional<at::Tensor>& D,
39+
const c10::optional<at::Tensor>& z,
40+
const c10::optional<at::Tensor>& delta_bias,
41+
bool delta_softplus,
42+
bool return_last_state);
43+
using selective_state_update_fn = at::Tensor (*)(
44+
const at::Tensor& state,
45+
const at::Tensor& x,
46+
const at::Tensor& dt,
47+
const at::Tensor& A,
48+
const at::Tensor& B,
49+
const at::Tensor& C,
50+
const c10::optional<at::Tensor>& D,
51+
const c10::optional<at::Tensor>& z,
52+
const c10::optional<at::Tensor>& dt_bias,
53+
bool dt_softplus);
54+
IPEX_DECLARE_DISPATCH(selective_scan_kernel_fn, selective_scan_kernel_stub);
55+
IPEX_DECLARE_DISPATCH(
56+
selective_state_update_fn,
57+
selective_state_update_kernel_stub);
58+
59+
} // namespace cpu
60+
} // namespace torch_ipex
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include <aten/Conv.h>
2+
#include "mkl.h"
3+
#include "vec/vec.h"
4+
5+
namespace torch_ipex {
6+
namespace cpu {
7+
namespace {
8+
template <typename T>
9+
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_inner(
10+
const at::Tensor& hidden_states,
11+
const at::Tensor& conv_states,
12+
const at::Tensor& conv_weights,
13+
const c10::optional<at::Tensor>& conv_bias,
14+
bool silu_activation) {
15+
auto bs = conv_states.size(0);
16+
auto channels = conv_states.size(1);
17+
auto kernel_size = conv_states.size(2);
18+
auto has_bias = conv_bias.has_value();
19+
auto bias_ptr = has_bias ? conv_bias.value().data_ptr<T>() : nullptr;
20+
auto conv_states_ptr = conv_states.data_ptr<T>();
21+
auto conv_weights_ptr = conv_weights.data_ptr<T>();
22+
auto hidden_states_ptr = hidden_states.data_ptr<T>();
23+
auto hidden_states_strideB = hidden_states.stride(0);
24+
auto hidden_states_strideC = hidden_states.stride(1);
25+
auto conv_states_strideB = conv_states.stride(0);
26+
auto conv_states_strideC = conv_states.stride(1);
27+
auto conv_states_strideK = conv_states.stride(2);
28+
auto conv_weights_strideC = conv_weights.stride(0);
29+
#pragma omp parallel for collapse(2)
30+
for (auto bi = 0; bi < bs; bi++) {
31+
for (auto ci = 0; ci < channels; ci++) {
32+
auto conv_weights_start = ci * conv_weights_strideC;
33+
float out = 0.0f;
34+
auto conv_states_start =
35+
bi * conv_states_strideB + ci * conv_states_strideC;
36+
for (auto k = 1; k < kernel_size; k++) {
37+
auto conv_states_idx = conv_states_start + k * conv_states_strideK;
38+
out += conv_weights_ptr[conv_weights_start + k - 1] *
39+
conv_states_ptr[conv_states_idx];
40+
conv_states_ptr[conv_states_idx - conv_states_strideK] =
41+
conv_states_ptr[conv_states_idx];
42+
}
43+
auto hidden_states_idx =
44+
bi * hidden_states_strideB + ci * hidden_states_strideC;
45+
out += hidden_states_ptr[hidden_states_idx] *
46+
conv_weights_ptr[conv_weights_start + kernel_size - 1];
47+
conv_states_ptr
48+
[conv_states_start + (kernel_size - 1) * conv_states_strideK] =
49+
hidden_states_ptr[hidden_states_idx];
50+
if (has_bias) {
51+
out += bias_ptr[ci];
52+
}
53+
if (silu_activation) {
54+
out = out / (1 + expf(-out));
55+
}
56+
hidden_states_ptr[hidden_states_idx] = out;
57+
}
58+
}
59+
return std::make_tuple(std::move(hidden_states), std::move(conv_states));
60+
}
61+
62+
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_impl(
63+
const at::Tensor& hidden_states,
64+
const at::Tensor& conv_states,
65+
const at::Tensor& conv_weights,
66+
const c10::optional<at::Tensor>& conv_bias,
67+
bool silu_activation) {
68+
if (hidden_states.scalar_type() == at::ScalarType::Float) {
69+
return causal_conv1d_update_kernel_inner<float>(
70+
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
71+
} else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
72+
return causal_conv1d_update_kernel_inner<at::BFloat16>(
73+
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
74+
} else if (hidden_states.scalar_type() == at::ScalarType::Half) {
75+
return causal_conv1d_update_kernel_inner<at::Half>(
76+
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
77+
} else {
78+
TORCH_CHECK(
79+
false,
80+
"Only support bfloat16, float16 and float for causal_conv1d_update");
81+
}
82+
}
83+
} // anonymous namespace
84+
IPEX_REGISTER_DISPATCH(
85+
causal_conv1d_update_kernel_stub,
86+
&causal_conv1d_update_kernel_impl);
87+
88+
} // namespace cpu
89+
} // namespace torch_ipex

0 commit comments

Comments
 (0)