|
| 1 | +# ramalama/generate/compose.py |
| 2 | + |
| 3 | +import os |
| 4 | +import shlex |
| 5 | +from typing import Optional, Tuple |
| 6 | + |
| 7 | +from ramalama.common import RAG_DIR, get_accel_env_vars |
| 8 | +from ramalama.file import PlainFile |
| 9 | +from ramalama.version import version |
| 10 | + |
| 11 | + |
class Compose:
    """Builds a docker-compose.yaml definition for serving a model container.

    Each ``_gen_*`` helper returns one ready-to-splice YAML fragment (or an
    empty string when the section does not apply); ``generate()`` assembles
    them into a single compose file wrapped in a PlainFile.
    """

    def __init__(
        self,
        model_name: str,
        model_paths: Tuple[str, str],
        chat_template_paths: Optional[Tuple[str, str]],
        mmproj_paths: Optional[Tuple[str, str]],
        args,
        exec_args,
    ):
        # Every *_paths argument is a (source, destination) pair. Optional
        # pairs collapse to ("", "") so later code can rely on truthiness
        # instead of None checks.
        self.src_model_path, self.dest_model_path = model_paths
        self.src_chat_template_path, self.dest_chat_template_path = (
            chat_template_paths if chat_template_paths is not None else ("", "")
        )
        self.src_mmproj_path, self.dest_mmproj_path = mmproj_paths if mmproj_paths is not None else ("", "")
        # OCI-sourced models carry a scheme prefix that is not a valid host path.
        self.src_model_path = self.src_model_path.removeprefix("oci://")

        self.model_name = model_name
        custom_name = getattr(args, "name", None)
        self.name = custom_name if custom_name else f"ramalama-{model_name}"
        self.args = args
        self.exec_args = exec_args
        self.image = args.image

    def _gen_volumes(self) -> str:
        """Return the ``volumes:`` section, including only mounts that exist."""
        volumes = "    volumes:"

        # Model volume is always present.
        volumes += self._gen_model_volume()

        # RAG volume (OCI image or host directory).
        if getattr(self.args, "rag", None):
            volumes += self._gen_rag_volume()

        # Chat template volume, only when the source file is actually on disk.
        if self.src_chat_template_path and os.path.exists(self.src_chat_template_path):
            volumes += self._gen_chat_template_volume()

        # Multimodal projector volume, same existence guard.
        if self.src_mmproj_path and os.path.exists(self.src_mmproj_path):
            volumes += self._gen_mmproj_volume()

        return volumes

    def _gen_model_volume(self) -> str:
        """Read-only bind mount of the model file into the container."""
        return f'\n      - "{self.src_model_path}:{self.dest_model_path}:ro"'

    def _gen_rag_volume(self) -> str:
        """Return a volume entry for the RAG data source.

        Supports two source forms: an OCI image reference (``oci:...`` or
        ``oci://...``) mounted via the long-form image-volume syntax, or an
        existing host path mounted read-only. Anything else yields "".
        """
        rag_source = self.args.rag

        # "oci:" also matches the "oci://" spelling, so one check suffices.
        if rag_source.startswith("oci:"):
            if rag_source.startswith("oci://"):
                oci_image = rag_source.removeprefix("oci://")
            else:
                oci_image = rag_source.removeprefix("oci:")
            # Long-form image-volume syntax, now supported by Docker Compose.
            return f"""
      - type: image
        source: {oci_image}
        target: {RAG_DIR}
        image:
          readonly: true"""

        if os.path.exists(rag_source):
            # Standard host path bind mount.
            return f'\n      - "{rag_source}:{RAG_DIR}:ro"'

        return ""

    def _gen_chat_template_volume(self) -> str:
        """Read-only bind mount of the chat template file."""
        return f'\n      - "{self.src_chat_template_path}:{self.dest_chat_template_path}:ro"'

    def _gen_mmproj_volume(self) -> str:
        """Read-only bind mount of the multimodal projector file."""
        return f'\n      - "{self.src_mmproj_path}:{self.dest_mmproj_path}:ro"'

    def _gen_devices(self) -> str:
        """Pass through accelerator device nodes that exist on this host."""
        present = [dev for dev in ("/dev/dri", "/dev/kfd", "/dev/accel") if os.path.exists(dev)]
        if not present:
            return ""

        devices_str = "    devices:"
        for dev in present:
            devices_str += f'\n      - "{dev}:{dev}"'
        return devices_str

    def _gen_ports(self) -> str:
        """Return the ``ports:`` section, defaulting to 8080:8080.

        NOTE(review): args.port appears to use ramalama's
        "containerPort[:hostPort]" ordering (compose wants host:container,
        hence the swap) — confirm against the CLI's --port documentation.
        """
        port_arg = getattr(self.args, "port", None)
        if not port_arg:
            # Default to 8080 if no port is specified.
            return '    ports:\n      - "8080:8080"'

        parts = port_arg.split(":", 2)
        host_port = parts[1] if len(parts) > 1 else parts[0]
        container_port = parts[0]
        return f'    ports:\n      - "{host_port}:{container_port}"'

    def _gen_environment(self) -> str:
        """Accelerator env vars plus user-supplied --env overrides."""
        env_vars = get_accel_env_vars()
        # Allow the user to override or extend with --env entries.
        if getattr(self.args, "env", None):
            for entry in self.args.env:
                # partition tolerates entries without "=" (value becomes ""),
                # where split("=", 1) would raise on unpacking.
                key, _, val = entry.partition("=")
                env_vars[key] = val

        if not env_vars:
            return ""

        env_spec = "    environment:"
        for key, val in env_vars.items():
            env_spec += f"\n      - {key}={val}"
        return env_spec

    def _gen_gpu_deployment(self) -> str:
        """Emit a GPU device reservation when the image looks GPU-enabled.

        NOTE(review): only the nvidia driver is requested even when "rocm"
        matched — confirm this is intentional (compose GPU reservations are
        nvidia-centric).
        """
        gpu_keywords = ["cuda", "rocm", "gpu"]
        if not any(keyword in self.image.lower() for keyword in gpu_keywords):
            return ""

        return """\
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]"""

    def _gen_command(self) -> str:
        """Shell-quote exec_args into a single ``command:`` line."""
        if not self.exec_args:
            return ""
        # shlex.join builds a safely quoted command string from the list.
        cmd = shlex.join(self.exec_args)
        return f"    command: {cmd}"

    def generate(self) -> PlainFile:
        """Assemble all sections into the final compose file content."""
        _version = version()

        # Generate all the dynamic sections of the YAML file.
        volumes_string = self._gen_volumes()
        ports_string = self._gen_ports()
        environment_string = self._gen_environment()
        devices_string = self._gen_devices()
        gpu_deploy_string = self._gen_gpu_deployment()
        command_string = self._gen_command()

        # Assemble the final file content.
        content = f"""\
# Save this output to a 'docker-compose.yaml' file and run 'docker compose up'.
#
# Created with ramalama-{_version}

services:
  {self.model_name}:
    container_name: {self.name}
    image: {self.image}
{volumes_string}
{ports_string}
{environment_string}
{devices_string}
{gpu_deploy_string}
{command_string}
    restart: unless-stopped
"""
        # Drop the blank lines left behind by empty optional sections.
        content = "\n".join(line for line in content.splitlines() if line.strip())

        return genfile(self.name, content)
| 183 | + |
| 184 | + |
def genfile(name: str, content: str) -> PlainFile:
    """Wrap *content* in a PlainFile for the compose output.

    NOTE(review): the *name* parameter is currently unused — the output
    file name is fixed to docker-compose.yaml; confirm whether callers
    expect *name* to influence the file name.
    """
    file_name = "docker-compose.yaml"
    print(f"Generating Docker Compose file: {file_name}")

    result = PlainFile(file_name)
    result.content = content
    return result
0 commit comments