@@ -7,7 +7,6 @@
 	"log"
 	"time"
 
-	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
 	"github.com/pulumi/pulumi-go-provider/infer"
 
 	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
@@ -218,110 +217,125 @@ func (v VertexModelDeployment) Read(
 	state := req.State
 
 	if req.State.ModelName != "" {
-		// Lookup the model
+		// Read the model from the registry
+
 		modelClientFactory := v.getModelClientFactory()
 		modelClient, err := modelClientFactory(ctx, req.State.Region)
 		if err != nil {
-			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
+			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to create model client: %w", err)
 		}
 		defer func() {
 			if closeErr := modelClient.Close(); closeErr != nil {
 				log.Printf("failed to close model client: %v", closeErr)
 			}
 		}()
 
-		modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)
-		model, err := modelGetter.Get(ctx, req.State.ModelName)
+		err = readRegistryModel(ctx, modelClient, req, &state)
 		if err != nil {
-			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
-		}
-
-		// Update state with current model values
-		state.ModelName = model.Name
-		state.ModelArtifactsBucketURI = model.ArtifactUri
-		state.Labels = model.Labels
-
-		// Safely access ContainerSpec fields
-		if model.ContainerSpec != nil {
-			state.ModelImageURL = model.ContainerSpec.ImageUri
-			state.PredictRoute = model.ContainerSpec.PredictRoute
-			state.HealthRoute = model.ContainerSpec.HealthRoute
-		}
-
-		// Safely access PredictSchemata fields
-		if model.PredictSchemata != nil {
-			state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
-			state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
+			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to read model from registry: %w", err)
 		}
 	}
 
 	if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
-		// Lookup the endpoint if model is deployed to an endpoint
+		// Read the endpoint if model is deployed to an endpoint
 
 		endpointClientFactory := v.getEndpointClientFactory()
 		endpointClient, err := endpointClientFactory(ctx, req.State.Region)
 		if err != nil {
-			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
+			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to create endpoint client: %w", err)
 		}
 		defer func() {
 			if closeErr := endpointClient.Close(); closeErr != nil {
 				log.Printf("failed to close endpoint client: %v", closeErr)
 			}
 		}()
 
-		getReq := &aiplatformpb.GetEndpointRequest{
-			Name: fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
-				req.State.ProjectID, req.State.Region, req.State.EndpointName),
-		}
-
-		endpoint, err := endpointClient.GetEndpoint(ctx, getReq)
+		err = readEndpointModel(ctx, endpointClient, req, &state)
 		if err != nil {
-			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
+			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to read model endpoint: %w", err)
 		}
+	}
 
-		// Verify the deployed model still exists and update its properties
-		var foundDeployedModel *aiplatformpb.DeployedModel
-		for _, deployedModel := range endpoint.DeployedModels {
-			if deployedModel.Id == req.State.DeployedModelID {
-				foundDeployedModel = deployedModel
+	return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
+		Inputs: req.Inputs,
+		State:  state,
+	}, nil
+}
 
-				break
-			}
+func readEndpointModel(ctx context.Context,
+	endpointClient services.VertexEndpointClient,
+	req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
+	state *VertexModelDeploymentState) error {
+
+	endpointGetter := services.NewVertexEndpointModelGetter(endpointClient, req.State.ProjectID, req.State.Region)
+	endpoint, foundDeployedModel, err := endpointGetter.Get(ctx, req.State.EndpointName, req.State.DeployedModelID)
+	if err != nil {
+		return fmt.Errorf("failed to get endpoint: %w", err)
+	}
+
+	if foundDeployedModel == nil {
+		// Model is no longer deployed to the endpoint; leave the deployment state unchanged
+		return nil
+	}
+
+	// Update state with current endpoint and deployed model information
+	state.EndpointName = endpoint.Name
+	state.DeployedModelID = foundDeployedModel.Id
+
+	// Update endpoint deployment configuration with current values if available
+	if state.EndpointModelDeployment == nil {
+		return nil
+	}
+
+	// Extract current deployment configuration from the deployed model
+	if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
+		if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
+			state.EndpointModelDeployment.MachineType = machineSpec.MachineType
 		}
+		state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
+		state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
+	}
 
-		if foundDeployedModel == nil {
-			// Model is no longer deployed - return empty response to indicate resource doesn't exist
-			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, nil
+	// Update traffic percentage from endpoint's traffic split if available
+	if endpoint.TrafficSplit != nil {
+		if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
+			state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
 		}
+	}
 
-		// Update state with current endpoint and deployed model information
-		state.EndpointName = endpoint.Name
-		state.DeployedModelID = foundDeployedModel.Id
-
-		// Update endpoint deployment configuration with current values if available
-		if state.EndpointModelDeployment != nil && foundDeployedModel.PredictionResources != nil {
-			// Extract current deployment configuration from the deployed model
-			if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
-				if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
-					state.EndpointModelDeployment.MachineType = machineSpec.MachineType
-				}
-				state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
-				state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
-			}
+	return nil
+}
 
-			// Update traffic percentage from endpoint's traffic split if available
-			if endpoint.TrafficSplit != nil {
-				if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
-					state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
-				}
-			}
-		}
+func readRegistryModel(ctx context.Context,
+	modelClient services.VertexModelClient,
+	req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
+	state *VertexModelDeploymentState) error {
+
+	modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)
+	model, err := modelGetter.Get(ctx, req.State.ModelName)
+	if err != nil {
+		return fmt.Errorf("failed to get model: %w", err)
 	}
 
-	return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
-		Inputs: req.Inputs,
-		State:  state,
-	}, nil
+	// Update state with current model values
+	state.ModelName = model.Name
+	state.ModelArtifactsBucketURI = model.ArtifactUri
+	state.Labels = model.Labels
+
+	// Safely access ContainerSpec fields
+	if model.ContainerSpec != nil {
+		state.ModelImageURL = model.ContainerSpec.ImageUri
+		state.PredictRoute = model.ContainerSpec.PredictRoute
+		state.HealthRoute = model.ContainerSpec.HealthRoute
+	}
+
+	// Safely access PredictSchemata fields
+	if model.PredictSchemata != nil {
+		state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
+		state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
+	}
+
+	return nil
 }
 
 // testFactoryRegistry holds test factories for dependency injection during testing
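A note on the error handling in this hunk: every bare `err` return is replaced with `fmt.Errorf("...: %w", err)`. The `%w` verb keeps the underlying error in the chain, so callers can still match it with `errors.Is`/`errors.As` after the extra context is attached. A minimal, self-contained illustration (the sentinel error and function here are hypothetical, not from this repo):

```go
package main

import (
	"errors"
	"fmt"
)

// errNotFound stands in for an underlying Vertex API error; hypothetical.
var errNotFound = errors.New("model not found")

func getModel() error {
	// Wrapping with %w preserves errNotFound in the error chain.
	return fmt.Errorf("failed to get model: %w", errNotFound)
}

func main() {
	err := getModel()
	fmt.Println(err)                         // failed to get model: model not found
	fmt.Println(errors.Is(err, errNotFound)) // true
}
```

Plain `%v` formatting would flatten the chain and break that `errors.Is` check.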
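The `getModelClientFactory`/`getEndpointClientFactory` indirection, together with the `testFactoryRegistry` comment at the end of the hunk, suggests factory-based dependency injection: production code falls back to a default factory that builds real regional clients, while tests register fakes. A rough sketch of that pattern, with all type and field names invented for illustration rather than taken from this repo:

```go
package main

import (
	"context"
	"fmt"
)

// modelClient is a stand-in for the provider's Vertex client interface; hypothetical.
type modelClient interface {
	Endpoint() string
}

type realClient struct{ region string }

func (r realClient) Endpoint() string { return r.region + "-aiplatform.googleapis.com:443" }

// clientFactory builds a regional client; Read calls it once per request.
type clientFactory func(ctx context.Context, region string) (modelClient, error)

// deployment mirrors the resource struct: a nil testFactory means "use the real client".
type deployment struct {
	testFactory clientFactory
}

func (d deployment) getClientFactory() clientFactory {
	if d.testFactory != nil {
		return d.testFactory // injected by tests
	}
	return func(_ context.Context, region string) (modelClient, error) {
		return realClient{region: region}, nil
	}
}

func main() {
	// Production path: the default factory builds a real regional client.
	c, _ := deployment{}.getClientFactory()(context.Background(), "us-central1")
	fmt.Println(c.Endpoint())

	// Test path: inject a fake that never touches the network.
	d := deployment{testFactory: func(context.Context, string) (modelClient, error) {
		return realClient{region: "fake"}, nil
	}}
	f, _ := d.getClientFactory()(context.Background(), "ignored")
	fmt.Println(f.Endpoint())
}
```

The benefit is that `Read` stays oblivious to which factory it received; only the wiring differs between test and production.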