Skip to content

Commit 1e48c10

Browse files
committed
nf2go: convert nftables rules to golang code
One of the biggest barriers to adopt the netlink format for nftables is the complexity of writing bytecode. This commits adds a tool that allows to take an nftables dump and generate the corresponding golang code and validating that the generated code produces the exact same output. Change-Id: I491b35e0d8062de33c67091dd4126d843b231838 Signed-off-by: Antonio Ojea <[email protected]>
1 parent 69f487d commit 1e48c10

File tree

3 files changed

+866
-0
lines changed

3 files changed

+866
-0
lines changed

internal/nf2go/main.go

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
"log"
10+
"os"
11+
"os/exec"
12+
"path/filepath"
13+
"regexp"
14+
"runtime"
15+
"strings"
16+
"time"
17+
18+
"github.com/google/go-cmp/cmp"
19+
"github.com/google/nftables"
20+
"github.com/vishvananda/netns"
21+
)
22+
23+
func main() {
24+
args := os.Args[1:]
25+
if len(args) != 1 {
26+
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
27+
}
28+
29+
filename := args[0]
30+
31+
runtime.LockOSThread()
32+
defer runtime.UnlockOSThread()
33+
34+
// Create a new network namespace
35+
ns, err := netns.New()
36+
if err != nil {
37+
log.Fatalf("netns.New() failed: %v", err)
38+
}
39+
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
40+
if err != nil {
41+
log.Fatalf("nftables.New() failed: %v", err)
42+
}
43+
44+
scriptOutput, err := applyNFTRuleset(filename)
45+
if err != nil {
46+
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
47+
}
48+
49+
var buf bytes.Buffer
50+
// Helper function to print to the file
51+
pf := func(format string, a ...interface{}) {
52+
_, err := fmt.Fprintf(&buf, format, a...)
53+
if err != nil {
54+
log.Fatal(err)
55+
}
56+
}
57+
58+
pf("// Code generated by nft2go. DO NOT EDIT.\n")
59+
pf("package main\n\n")
60+
pf("import (\n")
61+
pf("\t\"fmt\"\n")
62+
pf("\t\"log\"\n")
63+
pf("\t\"github.com/google/nftables\"\n")
64+
pf("\t\"github.com/google/nftables/expr\"\n")
65+
pf(")\n\n")
66+
pf("func main() {\n")
67+
pf("\tn, err:= nftables.New()\n")
68+
pf("\tif err!= nil {\n")
69+
pf("\t\tlog.Fatal(err)\n")
70+
pf("\t}\n\n")
71+
pf("\n")
72+
pf("\tvar expressions []expr.Any\n")
73+
pf("\tvar chain *nftables.Chain\n")
74+
pf("\tvar table *nftables.Table\n")
75+
76+
tables, err := n.ListTables()
77+
if err != nil {
78+
log.Fatalf("ListTables failed: %v", err)
79+
}
80+
81+
chains, err := n.ListChains()
82+
if err != nil {
83+
log.Fatal(err)
84+
}
85+
86+
for _, table := range tables {
87+
log.Printf("processing table: %s", table.Name)
88+
89+
pf("\ttable = n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
90+
for _, chain := range chains {
91+
if chain.Table.Name != table.Name {
92+
continue
93+
}
94+
95+
sets, err := n.GetSets(table)
96+
if err != nil {
97+
log.Fatal(err)
98+
}
99+
for _, set := range sets {
100+
// TODO datatype and the other options
101+
pf("\tn.AddSet(&nftables.Set{\n")
102+
pf("\t\tTable: table,\n")
103+
pf("\t\tName: \"%s\",\n", set.Name)
104+
pf("\t}, nil)\n")
105+
}
106+
107+
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
108+
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
109+
110+
rules, err := n.GetRules(table, chain)
111+
if err != nil {
112+
log.Fatal(err)
113+
}
114+
115+
for _, rule := range rules {
116+
pf("\texpressions = []expr.Any{\n")
117+
for _, exp := range rule.Exprs {
118+
pf("\t\t%#v,\n", exp)
119+
}
120+
pf("\t\t}\n")
121+
pf("\tn.AddRule(&nftables.Rule{\n")
122+
pf("\t\tTable: table,\n")
123+
pf("\t\tChain: chain,\n")
124+
pf("\t\tExprs: expressions,\n")
125+
pf("\t})\n")
126+
}
127+
}
128+
}
129+
130+
pf("\n\tif err:= n.Flush(); err!= nil {\n")
131+
pf("\t\tlog.Fatal(err)\n")
132+
pf("\t}\n\n")
133+
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
134+
pf("}\n")
135+
136+
// Program nftables using your Go code
137+
if err := flushNFTRuleset(); err != nil {
138+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
139+
}
140+
141+
// Create the output file
142+
// Create a temporary directory
143+
tempDir, err := ioutil.TempDir("", "nftables_gen")
144+
if err != nil {
145+
log.Fatal(err)
146+
}
147+
defer os.RemoveAll(tempDir) // Clean up the temporary directory
148+
149+
// Create the temporary Go file
150+
tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
151+
f, err := os.Create(tempGoFile)
152+
if err != nil {
153+
log.Fatal(err)
154+
}
155+
defer f.Close()
156+
157+
mw := io.MultiWriter(f, os.Stdout)
158+
buf.WriteTo(mw)
159+
160+
// Format the generated code
161+
log.Printf("formating file: %s", tempGoFile)
162+
cmd := exec.Command("gofmt", "-w", "-s", tempGoFile)
163+
output, err := cmd.CombinedOutput()
164+
if err != nil {
165+
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
166+
}
167+
168+
// Run the generated code
169+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
170+
defer cancel()
171+
172+
log.Printf("executing file: %s", tempGoFile)
173+
cmd = exec.CommandContext(ctx, "go", "run", tempGoFile)
174+
output, err = cmd.CombinedOutput()
175+
if err != nil {
176+
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
177+
}
178+
179+
// Retrieve nftables state using nft
180+
log.Printf("obtain current ruleset: %s", tempGoFile)
181+
actualOutput, err := listNFTRuleset()
182+
if err != nil {
183+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
184+
}
185+
186+
expectedOutput, err := os.ReadFile(filename)
187+
if err != nil {
188+
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
189+
}
190+
191+
if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
192+
log.Printf("Expected output:\n%s", string(expectedOutput))
193+
log.Printf("Actual output:\n%s", actualOutput)
194+
195+
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
196+
}
197+
198+
if err := flushNFTRuleset(); err != nil {
199+
log.Fatalf("Failed to flush nftables ruleset: %v", err)
200+
}
201+
}
202+
203+
func applyNFTRuleset(scriptPath string) (string, error) {
204+
cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
205+
out, err := cmd.CombinedOutput()
206+
if err != nil {
207+
return string(out), err
208+
}
209+
return strings.TrimSpace(string(out)), nil
210+
}
211+
212+
func listNFTRuleset() (string, error) {
213+
cmd := exec.Command("nft", "list", "ruleset")
214+
out, err := cmd.CombinedOutput()
215+
if err != nil {
216+
return string(out), err
217+
}
218+
return strings.TrimSpace(string(out)), nil
219+
}
220+
221+
func flushNFTRuleset() error {
222+
cmd := exec.Command("nft", "flush", "ruleset")
223+
return cmd.Run()
224+
}
225+
226+
func ChainHookRef(hookNum *nftables.ChainHook) string {
227+
i := uint32(0)
228+
if hookNum != nil {
229+
i = uint32(*hookNum)
230+
}
231+
switch i {
232+
case 0:
233+
return "nftables.ChainHookPrerouting"
234+
case 1:
235+
return "nftables.ChainHookInput"
236+
case 2:
237+
return "nftables.ChainHookForward"
238+
case 3:
239+
return "nftables.ChainHookOutput"
240+
case 4:
241+
return "nftables.ChainHookPostrouting"
242+
case 5:
243+
return "nftables.ChainHookIngress"
244+
case 6:
245+
return "nftables.ChainHookEgress"
246+
}
247+
return ""
248+
}
249+
250+
func ChainPrioRef(priority *nftables.ChainPriority) string {
251+
i := int32(0)
252+
if priority != nil {
253+
i = int32(*priority)
254+
}
255+
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
256+
}
257+
258+
func ChainTypeString(chaintype nftables.ChainType) string {
259+
switch chaintype {
260+
case nftables.ChainTypeFilter:
261+
return "nftables.ChainTypeFilter"
262+
case nftables.ChainTypeRoute:
263+
return "nftables.ChainTypeRoute"
264+
case nftables.ChainTypeNAT:
265+
return "nftables.ChainTypeNAT"
266+
default:
267+
return "nftables.ChainTypeFilter"
268+
}
269+
}
270+
271+
func TableFamilyString(family nftables.TableFamily) string {
272+
switch family {
273+
case nftables.TableFamilyUnspecified:
274+
return "nftables.TableFamilyUnspecified"
275+
case nftables.TableFamilyINet:
276+
return "nftables.TableFamilyINet"
277+
case nftables.TableFamilyIPv4:
278+
return "nftables.TableFamilyIPv4"
279+
case nftables.TableFamilyIPv6:
280+
return "nftables.TableFamilyIPv6"
281+
case nftables.TableFamilyARP:
282+
return "nftables.TableFamilyARP"
283+
case nftables.TableFamilyNetdev:
284+
return "nftables.TableFamilyNetdev"
285+
case nftables.TableFamilyBridge:
286+
return "nftables.TableFamilyBridge"
287+
default:
288+
return "nftables.TableFamilyIPv4"
289+
}
290+
}
291+
292+
func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
293+
// Remove all indentation from both strings
294+
re := regexp.MustCompile(`(?m)^\s+`)
295+
str1 = re.ReplaceAllString(str1, "")
296+
str2 = re.ReplaceAllString(str2, "")
297+
298+
return str1 == str2
299+
}

0 commit comments

Comments
 (0)