Skip to content

Commit 4223771

Browse files
committed
Pick up vector length from 'zvlXXXb' (RVV) mattr for riscv
1 parent 4ac03b3 commit 4223771

File tree

8 files changed

+223
-35
lines changed

8 files changed

+223
-35
lines changed

python/tvm/target/codegen.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,23 @@ def llvm_cpu_has_features(cpu_features, target=None):
211211
return has_feats
212212

213213

214+
def llvm_get_vector_width(target=None):
215+
"""Get vector width from LLVM target's `-mtriple` and `-mcpu` and considering `-mattr`.
216+
217+
Parameters
218+
----------
219+
target : Target
220+
The TVM target.
221+
222+
Returns
223+
-------
224+
vector_width : int
225+
Vector with of target in number of bits, -1 on error.
226+
"""
227+
assert isinstance(target, Target) or target is None
228+
return _ffi_api.llvm_get_vector_width(target)
229+
230+
214231
def llvm_version_major(allow_none=False):
215232
"""Get the major LLVM version.
216233

python/tvm/testing/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,41 @@ def _multi_gpu_exists():
848848
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
849849
)
850850

851+
852+
# Mark a test as requiring minimum llvm version
853+
def requires_llvm_minimum_version(major_version):
854+
"""Mark a test as requiring at least a specific version of LLVM.
855+
856+
Unit test marked with this decorator will run only if the
857+
installed version of LLVM is at least `major_version`.
858+
859+
This also marks the test as requiring LLVM backend support.
860+
861+
Parameters
862+
----------
863+
major_version: int
864+
865+
866+
"""
867+
868+
try:
869+
llvm_version = tvm.target.codegen.llvm_version_major()
870+
except RuntimeError:
871+
llvm_version = 0
872+
873+
requires = [
874+
pytest.mark.skipif(
875+
llvm_version < major_version, reason=f"Requires LLVM >= {major_version}"
876+
),
877+
*requires_llvm.marks(),
878+
]
879+
880+
def inner(func):
881+
return _compose([func], requires)
882+
883+
return inner
884+
885+
851886
# Mark a test as requiring a GPU to run.
852887
requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists)
853888

src/target/llvm/codegen_llvm.cc

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -174,32 +174,8 @@ void CodeGenLLVM::InitTarget() {
174174
data_layout_.reset(new llvm::DataLayout(module_.get()));
175175
#endif
176176
if (native_vector_bits_ == 0) {
177-
const int vwidth = llvm_target_->GetVectorWidth();
178-
const auto& arch = tm->getTargetTriple().getArch();
179-
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
180-
if (vwidth > 0) {
181-
// override from target options
182-
// e.g. llvm -vector-width=xxx
183-
native_vector_bits_ = vwidth;
184-
} else if (arch == llvm::Triple::x86_64) {
185-
// for avx512
186-
native_vector_bits_ = 512;
187-
} else if (arch == llvm::Triple::x86) {
188-
native_vector_bits_ = 256;
189-
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
190-
native_vector_bits_ = 128;
191-
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
192-
native_vector_bits_ = 256;
193-
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
194-
<< "using 256 bits, set -vector-width=XXX to override";
195-
// fallback default
196-
} else {
197-
native_vector_bits_ = 128;
198-
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
199-
<< "`, use -vector-width=XXX to override.";
200-
}
177+
native_vector_bits_ = llvm_target_->GetVectorWidth();
201178
}
202-
203179
#if TVM_LLVM_VERSION >= 60
204180
bool use_float16_abi = false;
205181
#if TVM_LLVM_VERSION >= 150

src/target/llvm/llvm_instance.cc

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
288288
// TVM & LLVM vector width options
289289
if (const auto& w = Downcast<Optional<runtime::Int>>(target.Get("vector-width"))) {
290290
vector_width_ = w.value();
291-
if ((vector_width_ <= 0) || (vector_width_ > 65535)) {
291+
if ((vector_width_ <= 0) || (vector_width_ > 65536)) {
292292
LOG(FATAL) << "Invalid -vector-width value: " << vector_width_;
293293
}
294294
}
@@ -300,26 +300,32 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
300300
code_model_ = llvm::CodeModel::Medium;
301301
#if TVM_LLVM_VERSION >= 140
302302
// VLEN inference
303-
const auto* mci = GetOrCreateTargetMachine(false)->getMCSubtargetInfo();
304-
const auto cpu_name = mci->getCPU();
305-
const auto m_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
303+
const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU();
304+
const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
306305
auto ISAInfo =
307-
llvm::RISCVISAInfo::parseArchString(m_arch, /*EnableExperimentalExtensions=*/true);
308-
// infer VLEN from LLVM or via options
306+
llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true);
307+
// infer VLEN from LLVM RISCVInfo parser
309308
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
310309
vector_width_ = (*ISAInfo)->getMinVLen();
311310
}
311+
// infer VLEN from LLVM options (zvlXXXb override)
312+
for (const auto& attr : attrs_) {
313+
if (attr.find("zvl") != std::string::npos) {
314+
std::string vec;
315+
for (char c : attr) {
316+
if (std::isdigit(c)) vec += c;
317+
}
318+
vector_width_ = std::stoi(vec);
319+
}
320+
}
312321
#endif
313322
if (vector_width_ > 0) {
314323
// push cl-opt to LLVM
315324
llvm_options_.push_back(
316325
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
317-
llvm_options_.push_back(
318-
ParseOptionString("-riscv-v-vector-bits-max:int=" + std::to_string(vector_width_)));
319326
} else {
320327
// fallback default (codegen will warn)
321328
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
322-
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-max:int=256"));
323329
}
324330
}
325331

@@ -924,6 +930,32 @@ const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const
924930
return has_feature;
925931
}
926932

933+
const int LLVMTargetInfo::GetVectorWidth() {
934+
auto* tm = GetOrCreateTargetMachine(false);
935+
const auto& arch = tm->getTargetTriple().getArch();
936+
const std::string arch_name = std::string(tm->getTargetTriple().getArchName());
937+
if (vector_width_ == 0) {
938+
if (arch == llvm::Triple::x86_64) {
939+
// for avx512
940+
vector_width_ = 512;
941+
} else if (arch == llvm::Triple::x86) {
942+
vector_width_ = 256;
943+
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
944+
vector_width_ = 128;
945+
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
946+
vector_width_ = 256;
947+
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
948+
<< "using 256 bits, set -vector-width=XXX to override";
949+
// fallback default
950+
} else {
951+
vector_width_ = 128;
952+
LOG(WARNING) << "Set native vector bits to be 128 for `" << arch_name
953+
<< "`, use -vector-width=XXX to override.";
954+
}
955+
}
956+
return vector_width_;
957+
}
958+
927959
// LLVMTarget
928960

929961
bool LLVMTarget::modified_llvm_state_ = false;

src/target/llvm/llvm_instance.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class LLVMTargetInfo {
246246
* \brief Get the TVM & LLVM vector_width
247247
* \return number of bits for vector width
248248
*/
249-
const int GetVectorWidth() const { return vector_width_; }
249+
const int GetVectorWidth();
250250
/*!
251251
* \brief Get the LLVM optimization level
252252
* \return optimization level for this target

src/target/llvm/llvm_module.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,19 @@ TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() ->
690690
return "unimplemented";
691691
});
692692

693+
TVM_REGISTER_GLOBAL("target.llvm_get_vector_width").set_body_typed([](const Target& target) -> int {
694+
auto use_target = target.defined() ? target : Target::Current(false);
695+
// ignore non "llvm" target
696+
if (target.defined()) {
697+
if (target->kind->name != "llvm") {
698+
return -1;
699+
}
700+
}
701+
auto llvm_instance = std::make_unique<LLVMInstance>();
702+
LLVMTargetInfo llvm_backend(*llvm_instance, use_target);
703+
return llvm_backend.GetVectorWidth();
704+
});
705+
693706
TVM_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String {
694707
return llvm::sys::getDefaultTargetTriple();
695708
});
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
from tvm.script import tir as T
19+
from tvm.target.codegen import target_has_features
20+
21+
22+
@tvm.testing.requires_llvm_minimum_version(14)
23+
@tvm.testing.parametrize_targets(
24+
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m",
25+
"llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v",
26+
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m",
27+
"llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
28+
)
29+
def test_rvv(target):
30+
def check_rvv_presence(N, extent):
31+
@T.prim_func
32+
def load_vec(A: T.Buffer((N,), "int8")):
33+
for j in T.vectorized(0, extent):
34+
A[j] = 1
35+
36+
f = tvm.build(load_vec, target)
37+
# Check RVV `vsetvli` prensence
38+
assembly = f.get_source("asm")
39+
if target_has_features("v"):
40+
assert "vsetvli" in assembly
41+
else:
42+
assert "vsetvli" not in assembly
43+
44+
with tvm.target.Target(target):
45+
check_rvv_presence(16, 32)
46+
47+
48+
if __name__ == "__main__":
49+
test_rvv()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import pytest
18+
19+
import tvm
20+
from tvm.target import _ffi_api, codegen, Target
21+
from tvm.target.codegen import target_has_features, llvm_get_vector_width
22+
23+
LLVM_VERSION = codegen.llvm_version_major()
24+
25+
# fmt: off
26+
min_llvm_version, tvm_target, vec_width = tvm.testing.parameters(
27+
# generic, no-vec -> (default 256)
28+
(-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+i,+m", 256),
29+
(-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+64bit,+a,+c,+d,+f,+m", 256),
30+
# generic, with-vec -> (default 256)
31+
(-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 256),
32+
(-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 256),
33+
# explicit -vector-width
34+
(-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 128),
35+
(-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 128),
36+
(-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 512),
37+
(-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 512),
38+
# explicit +zvlXXXb
39+
(14, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v,+zvl64b", 64),
40+
(14, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v,+zvl64b", 64),
41+
# vendor CPU
42+
(17, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-x280", 512),
43+
(18, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-p670", 128),
44+
(19, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=spacemit-x60", 256),
45+
) # fmt: on
46+
47+
48+
def test_riscv_rvv_features(min_llvm_version, tvm_target, vec_width):
49+
"""Test RVV features support for different targets.
50+
51+
Parameters
52+
----------
53+
min_llvm_version : int
54+
Minimal LLVM version.
55+
tvm_target : str
56+
TVM target.
57+
vec_width : bool
58+
Expected vector width.
59+
"""
60+
61+
# skip test on llvm_version
62+
if LLVM_VERSION < min_llvm_version:
63+
return
64+
65+
with Target(tvm_target):
66+
assert llvm_get_vector_width() == vec_width

0 commit comments

Comments
 (0)