Skip to content

Commit aa4702c

Browse files
separate model from endpoint lookup
1 parent a25aced commit aa4702c

File tree

8 files changed

+157
-81
lines changed

8 files changed

+157
-81
lines changed

pkg/resources/provider.go

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"log"
88
"time"
99

10-
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
1110
"github.com/pulumi/pulumi-go-provider/infer"
1211

1312
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
@@ -218,110 +217,125 @@ func (v VertexModelDeployment) Read(
218217
state := req.State
219218

220219
if req.State.ModelName != "" {
221-
// Lookup the model
220+
// Read the model from the registry
221+
222222
modelClientFactory := v.getModelClientFactory()
223223
modelClient, err := modelClientFactory(ctx, req.State.Region)
224224
if err != nil {
225-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
225+
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to create model client: %w", err)
226226
}
227227
defer func() {
228228
if closeErr := modelClient.Close(); closeErr != nil {
229229
log.Printf("failed to close model client: %v", closeErr)
230230
}
231231
}()
232232

233-
modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)
234-
model, err := modelGetter.Get(ctx, req.State.ModelName)
233+
err = readRegistryModel(ctx, modelClient, req, &state)
235234
if err != nil {
236-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
237-
}
238-
239-
// Update state with current model values
240-
state.ModelName = model.Name
241-
state.ModelArtifactsBucketURI = model.ArtifactUri
242-
state.Labels = model.Labels
243-
244-
// Safely access ContainerSpec fields
245-
if model.ContainerSpec != nil {
246-
state.ModelImageURL = model.ContainerSpec.ImageUri
247-
state.PredictRoute = model.ContainerSpec.PredictRoute
248-
state.HealthRoute = model.ContainerSpec.HealthRoute
249-
}
250-
251-
// Safely access PredictSchemata fields
252-
if model.PredictSchemata != nil {
253-
state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
254-
state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
235+
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to read model from registry: %w", err)
255236
}
256237
}
257238

258239
if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
259-
// Lookup the endpoint if model is deployed to an endpoint
240+
// Read the endpoint if model is deployed to an endpoint
260241

261242
endpointClientFactory := v.getEndpointClientFactory()
262243
endpointClient, err := endpointClientFactory(ctx, req.State.Region)
263244
if err != nil {
264-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
245+
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to create endpoint client: %w", err)
265246
}
266247
defer func() {
267248
if closeErr := endpointClient.Close(); closeErr != nil {
268249
log.Printf("failed to close endpoint client: %v", closeErr)
269250
}
270251
}()
271252

272-
getReq := &aiplatformpb.GetEndpointRequest{
273-
Name: fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
274-
req.State.ProjectID, req.State.Region, req.State.EndpointName),
275-
}
276-
277-
endpoint, err := endpointClient.GetEndpoint(ctx, getReq)
253+
err = readEndpointModel(ctx, endpointClient, req, &state)
278254
if err != nil {
279-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, err
255+
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to read model endpoint: %w", err)
280256
}
257+
}
281258

282-
// Verify the deployed model still exists and update its properties
283-
var foundDeployedModel *aiplatformpb.DeployedModel
284-
for _, deployedModel := range endpoint.DeployedModels {
285-
if deployedModel.Id == req.State.DeployedModelID {
286-
foundDeployedModel = deployedModel
259+
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
260+
Inputs: req.Inputs,
261+
State: state,
262+
}, nil
263+
}
287264

288-
break
289-
}
265+
func readEndpointModel(ctx context.Context,
266+
endpointClient services.VertexEndpointClient,
267+
req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
268+
state *VertexModelDeploymentState) error {
269+
270+
endpointGetter := services.NewVertexEndpointModelGetter(endpointClient, req.State.ProjectID, req.State.Region)
271+
endpoint, foundDeployedModel, err := endpointGetter.Get(ctx, req.State.EndpointName, req.State.DeployedModelID)
272+
if err != nil {
273+
return err
274+
}
275+
276+
if foundDeployedModel == nil {
277+
// Model is no longer deployed - return empty response to indicate resource doesn't exist
278+
return nil
279+
}
280+
281+
// Update state with current endpoint and deployed model information
282+
state.EndpointName = endpoint.Name
283+
state.DeployedModelID = foundDeployedModel.Id
284+
285+
// Update endpoint deployment configuration with current values if available
286+
if state.EndpointModelDeployment == nil {
287+
return nil
288+
}
289+
290+
// Extract current deployment configuration from the deployed model
291+
if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
292+
if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
293+
state.EndpointModelDeployment.MachineType = machineSpec.MachineType
290294
}
295+
state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
296+
state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
297+
}
291298

292-
if foundDeployedModel == nil {
293-
// Model is no longer deployed - return empty response to indicate resource doesn't exist
294-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, nil
299+
// Update traffic percentage from endpoint's traffic split if available
300+
if endpoint.TrafficSplit != nil {
301+
if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
302+
state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
295303
}
304+
}
296305

297-
// Update state with current endpoint and deployed model information
298-
state.EndpointName = endpoint.Name
299-
state.DeployedModelID = foundDeployedModel.Id
300-
301-
// Update endpoint deployment configuration with current values if available
302-
if state.EndpointModelDeployment != nil && foundDeployedModel.PredictionResources != nil {
303-
// Extract current deployment configuration from the deployed model
304-
if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
305-
if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
306-
state.EndpointModelDeployment.MachineType = machineSpec.MachineType
307-
}
308-
state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
309-
state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
310-
}
306+
return nil
307+
}
311308

312-
// Update traffic percentage from endpoint's traffic split if available
313-
if endpoint.TrafficSplit != nil {
314-
if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
315-
state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
316-
}
317-
}
318-
}
309+
func readRegistryModel(ctx context.Context,
310+
modelClient services.VertexModelClient,
311+
req infer.ReadRequest[VertexModelDeploymentArgs,
312+
VertexModelDeploymentState], state *VertexModelDeploymentState) error {
313+
314+
modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)
315+
model, err := modelGetter.Get(ctx, req.State.ModelName)
316+
if err != nil {
317+
return fmt.Errorf("failed to get model: %w", err)
319318
}
320319

321-
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
322-
Inputs: req.Inputs,
323-
State: state,
324-
}, nil
320+
// Update state with current model values
321+
state.ModelName = model.Name
322+
state.ModelArtifactsBucketURI = model.ArtifactUri
323+
state.Labels = model.Labels
324+
325+
// Safely access ContainerSpec fields
326+
if model.ContainerSpec != nil {
327+
state.ModelImageURL = model.ContainerSpec.ImageUri
328+
state.PredictRoute = model.ContainerSpec.PredictRoute
329+
state.HealthRoute = model.ContainerSpec.HealthRoute
330+
}
331+
332+
// Safely access PredictSchemata fields
333+
if model.PredictSchemata != nil {
334+
state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
335+
state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
336+
}
337+
338+
return nil
325339
}
326340

327341
// testFactoryRegistry holds test factories for dependency injection during testing

pkg/resources/provider_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ import (
1414
)
1515

1616
const (
17-
testProjectID = "test-project"
18-
testRegion = "us-central1"
19-
testEndpointID = "test-endpoint"
20-
testModelImageURL = "gcr.io/test-project/custom-model:latest"
21-
testModelArtifactsBucketURI = "gs://test-bucket/model-artifacts/"
22-
testModelPredictionInputSchemaURI = "gs://test-bucket/schemas/input_schema.json"
17+
testProjectID = "test-project"
18+
testRegion = "us-central1"
19+
testEndpointID = "test-endpoint"
20+
testModelImageURL = "gcr.io/test-project/custom-model:latest"
21+
testModelArtifactsBucketURI = "gs://test-bucket/model-artifacts/"
22+
testModelPredictionInputSchemaURI = "gs://test-bucket/schemas/input_schema.json"
2323
testModelPredictionOutputSchemaURI = "gs://test-bucket/schemas/output_schema.json"
24-
testEndpointPath = "projects/test-project/locations/us-central1/endpoints/test-endpoint"
25-
testModelName = "projects/test-project/locations/us-central1/models/1234567890"
26-
testCreateTime = "2023-10-15T10:30:00Z"
24+
testEndpointPath = "projects/test-project/locations/us-central1/endpoints/test-endpoint"
25+
testModelName = "projects/test-project/locations/us-central1/models/1234567890"
26+
testCreateTime = "2023-10-15T10:30:00Z"
2727
)
2828

2929
//nolint:paralleltest,tparallel // Cannot run in parallel due to shared testFactoryRegistry
File renamed without changes.

pkg/services/modelendpointget.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package services
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
8+
)
9+
10+
// EndpointModelGetter allows getting endpoints and their deployed models from the registry.
11+
type EndpointModelGetter interface {
12+
Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error)
13+
Close() error
14+
}
15+
16+
// VertexEndpointModelGetter implements the EndpointModelGetter interface for Vertex AI.
17+
type VertexEndpointModelGetter struct {
18+
endpointClient VertexEndpointClient
19+
projectID string
20+
region string
21+
}
22+
23+
// NewVertexEndpointModelGetter creates a new VertexEndpointModelGetter with the provided endpoint client.
24+
func NewVertexEndpointModelGetter(endpointClient VertexEndpointClient, projectID, region string) *VertexEndpointModelGetter {
25+
return &VertexEndpointModelGetter{
26+
endpointClient: endpointClient,
27+
projectID: projectID,
28+
region: region,
29+
}
30+
}
31+
32+
// Get retrieves an endpoint and finds the specified deployed model within it.
33+
// Returns the endpoint, the deployed model (if found), and any error.
34+
func (g *VertexEndpointModelGetter) Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error) {
35+
getReq := &aiplatformpb.GetEndpointRequest{
36+
Name: fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
37+
g.projectID, g.region, endpointName),
38+
}
39+
40+
endpoint, err := g.endpointClient.GetEndpoint(ctx, getReq)
41+
if err != nil {
42+
return nil, nil, fmt.Errorf("failed to get endpoint: %w", err)
43+
}
44+
45+
// Verify the deployed model still exists and update its properties
46+
var foundDeployedModel *aiplatformpb.DeployedModel
47+
for _, deployedModel := range endpoint.DeployedModels {
48+
if deployedModel.Id == deployedModelID {
49+
foundDeployedModel = deployedModel
50+
51+
break
52+
}
53+
}
54+
55+
return endpoint, foundDeployedModel, nil
56+
}
57+
58+
// Close closes the endpoint client.
59+
func (g *VertexEndpointModelGetter) Close() error {
60+
return g.endpointClient.Close()
61+
}
File renamed without changes.

pkg/services/modelget.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package services
22

33
import (
44
"context"
5+
"fmt"
56

67
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
78
)
@@ -34,7 +35,7 @@ func (g *VertexModelGet) Get(ctx context.Context, modelName string) (*aiplatform
3435

3536
model, err := g.modelClient.GetModel(ctx, getReq)
3637
if err != nil {
37-
return nil, err
38+
return nil, fmt.Errorf("failed to get model: %w", err)
3839
}
3940

4041
return model, nil

sdk/go/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.24.5
44

55
require (
66
github.com/blang/semver v3.5.1+incompatible
7-
github.com/pulumi/pulumi/sdk/v3 v3.196.0
7+
github.com/pulumi/pulumi/sdk/v3 v3.197.0
88
)
99

1010
require (

sdk/go/go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ github.com/pulumi/appdash v0.0.0-20231130102222-75f619a67231 h1:vkHw5I/plNdTr435
155155
github.com/pulumi/appdash v0.0.0-20231130102222-75f619a67231/go.mod h1:murToZ2N9hNJzewjHBgfFdXhZKjY3z5cYC1VXk+lbFE=
156156
github.com/pulumi/esc v0.17.0 h1:oaVOIyFTENlYDuqc3pW75lQT9jb2cd6ie/4/Twxn66w=
157157
github.com/pulumi/esc v0.17.0/go.mod h1:XnSxlt5NkmuAj304l/gK4pRErFbtqq6XpfX1tYT9Jbc=
158-
github.com/pulumi/pulumi/sdk/v3 v3.196.0 h1:OwD+S4udFwxrdfw9n4dHv6gToF+SQNtggQJIfacSBYQ=
159-
github.com/pulumi/pulumi/sdk/v3 v3.196.0/go.mod h1:aV0+c5xpSYccWKmOjTZS9liYCqh7+peu3cQgSXu7CJw=
158+
github.com/pulumi/pulumi/sdk/v3 v3.197.0 h1:ZNKda7CQpfVbRS2r/7U5F+s4iejfL9HK39bXl5CCTpY=
159+
github.com/pulumi/pulumi/sdk/v3 v3.197.0/go.mod h1:aV0+c5xpSYccWKmOjTZS9liYCqh7+peu3cQgSXu7CJw=
160160
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
161161
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
162162
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=

0 commit comments

Comments
 (0)