Skip to content

Commit c0a4104

Browse files
thaisacsShiboXing
authored andcommitted
Fix g.costs (apache#17972)
1 parent 03e032c commit c0a4104

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

python/tvm/meta_schedule/cost_model/xgb_model.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -537,13 +537,17 @@ def _mean_cost(x: RunnerResult) -> float:
537537
self.last_train_size = self.data_size
538538

539539
# Step 5. Re-train the model
540-
self._train(
541-
xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])),
542-
ys=np.concatenate(
543-
[g.min_cost / g.costs for g in self.data.values()],
544-
axis=0,
545-
),
546-
)
540+
with np.errstate(divide="ignore", invalid="ignore"):
541+
feature_list = list(
542+
itertools_chain.from_iterable([g.features for g in self.data.values()])
543+
)
544+
cost_ratio_list = [
545+
np.divide(g.min_cost, g.costs, out=np.zeros_like(g.costs), where=g.costs != 0)
546+
for g in self.data.values()
547+
]
548+
cost_ratios = np.concatenate(cost_ratio_list, axis=0)
549+
550+
self._train(xs=feature_list, ys=cost_ratios)
547551

548552
def predict(
549553
self,

0 commit comments

Comments
 (0)