Conversation

codeflash-ai[bot]

@codeflash-ai codeflash-ai bot commented Oct 10, 2025

📄 456% (4.56x) speedup for _VertexRayClientContext._context_table_template in google/cloud/aiplatform/vertex_ray/client_builder.py

⏱️ Runtime : 8.74 milliseconds → 1.57 milliseconds (best of 377 runs)

📝 Explanation and details

The optimization introduces template caching to eliminate repeated template instantiation overhead. The key changes are:

What was optimized:

  • Added class-level caching of VertexRayTemplate instances using hasattr() checks
  • Templates are now created once per class and reused across all method calls
  • Replaced direct template instantiation with cached template access via cls._shell_uri_template and cls._table_template

Why this leads to speedup:

  • The original code creates two new VertexRayTemplate objects on every call to _context_table_template()
  • Template instantiation involves file I/O, string parsing, and object initialization overhead
  • By caching templates at the class level, this expensive initialization only happens once
  • Subsequent calls simply reuse the pre-initialized template objects and call their render() methods
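The caching pattern described above can be sketched as follows. This is a minimal illustration, not the PR's actual code: `VertexRayTemplate` is stubbed here (the real class does the file I/O being amortized), and the template name and render arguments are placeholders.

```python
class VertexRayTemplate:
    """Illustrative stub; the real class reads and parses a template file."""
    constructions = 0  # track how many times the expensive __init__ runs

    def __init__(self, name):
        VertexRayTemplate.constructions += 1
        self.name = name

    def render(self, **kwargs):
        pairs = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
        return f"{self.name}:{pairs}"


class _VertexRayClientContext:
    """Sketch showing only the class-level caching logic."""

    def _context_table_template(self):
        cls = type(self)
        # hasattr() guard: build the template once per class, then reuse it.
        if not hasattr(cls, "_table_template"):
            cls._table_template = VertexRayTemplate("context_table")
        return cls._table_template.render(persistent_resource_id="pr-123")


ctx = _VertexRayClientContext()
for _ in range(100):
    ctx._context_table_template()
print(VertexRayTemplate.constructions)  # 1 — instantiated once across 100 calls
```

Because the cache lives on the class rather than the instance, every `_VertexRayClientContext` object shares the same template objects, which is why the speedup shows up even across many distinct context instances.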

Performance impact from profiler data:

  • Template creation time dropped from ~25.8ms (97.9% of total time) to ~0.13ms (2.5% of total time)
  • Overall method execution time improved from 26.6ms to 5.4ms (456% speedup)
  • The render() calls now dominate execution time (83.4%), which is the actual useful work

Test case benefits:
This optimization is particularly effective for:

  • Repeated calls with the same context instances (524% faster for 100 repeated calls)
  • Multiple unique contexts where template reuse across instances provides benefits (509-526% faster)
  • Basic single calls still see significant improvement (305-561% faster) due to reduced template initialization overhead

The optimization maintains identical functionality while dramatically reducing computational overhead through intelligent caching.
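For comparison, the same amortization can be achieved with `functools.cache` on a module-level loader. This is a design alternative, not what the PR implements, and the stub class and template name below are illustrative.

```python
from functools import cache


class VertexRayTemplate:
    """Illustrative stub; real construction would read and parse a file."""
    def __init__(self, name):
        self.name = name


@cache
def load_template(name: str) -> VertexRayTemplate:
    # Constructed once per unique name; subsequent calls return the cached instance.
    return VertexRayTemplate(name)


a = load_template("context_table")
b = load_template("context_table")
print(a is b)  # True — both names bind the same cached object
```

The `hasattr()` approach used in the PR keeps the cache on the class itself, which avoids a separate module-level registry; the `functools.cache` variant trades that for less boilerplate per template.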

Correctness verification report:

Test | Status
⚙️ Existing Unit Tests | 🔘 None Found
🌀 Generated Regression Tests | 554 Passed
⏪ Replay Tests | 🔘 None Found
🔎 Concolic Coverage Tests | 🔘 None Found
📊 Tests Coverage | 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
from google.cloud.aiplatform.vertex_ray.client_builder import _VertexRayClientContext

# function to test
# -*- coding: utf-8 -*-

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Minimal stub implementations for external dependencies
class VertexRayTemplate:
    def __init__(self, template_name):
        self.template_name = template_name

    def render(self, **kwargs):
        # Return a string that includes the template name and all key-value pairs
        # This makes it easy for tests to assert correct values
        items = [f"{k}={repr(v)}" for k, v in sorted(kwargs.items())]
        return f"Template:{self.template_name}|" + "|".join(items)

class DummyClientContext:
    def __init__(
        self,
        python_version="3.8.10",
        ray_version="2.47.1",
        ray_commit="abcdefg",
        protocol_version="1.0",
        _num_clients=1,
        _context_to_restore=None,
    ):
        self.python_version = python_version
        self.ray_version = ray_version
        self.ray_commit = ray_commit
        self.protocol_version = protocol_version
        self._num_clients = _num_clients
        self._context_to_restore = _context_to_restore

# Simulate the aiplatform version
class aiplatform:
    __version__ = "1.38.0"

# Simulate ray version
class ray:
    __version__ = "2.47.1"

VERTEX_SDK_VERSION = aiplatform.__version__
from google.cloud.aiplatform.vertex_ray.client_builder import _VertexRayClientContext

# ------------------- UNIT TESTS -------------------

# ----------- BASIC TEST CASES -----------

def test_basic_all_fields_present():
    """Test basic case where all fields are present, including shell uri."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": "http://shell.example.com"
    }
    client_ctx = DummyClientContext()
    ctx = _VertexRayClientContext("pr-123", ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 57.9μs -> 9.06μs (539% faster)

def test_basic_shell_uri_missing():
    """Test case where shell uri is missing."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com"
    }
    client_ctx = DummyClientContext()
    ctx = _VertexRayClientContext("pr-456", ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 36.2μs -> 8.04μs (350% faster)

def test_basic_different_python_ray_versions():
    """Test case where python and ray versions are different."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": "http://shell.example.com"
    }
    client_ctx = DummyClientContext(python_version="3.9.1", ray_version="2.42.0")
    ctx = _VertexRayClientContext("pr-789", ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 55.4μs -> 8.39μs (561% faster)

# ----------- EDGE TEST CASES -----------

def test_edge_missing_dashboard_uri():
    """Test case where dashboard URI is missing (should raise ValueError)."""
    ray_head_uris = {
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": "http://shell.example.com"
    }
    client_ctx = DummyClientContext()
    with pytest.raises(ValueError):
        _VertexRayClientContext("pr-000", ray_head_uris, client_ctx)

def test_edge_empty_strings():
    """Test case with empty strings for all fields."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": ""
    }
    client_ctx = DummyClientContext(python_version="", ray_version="", ray_commit="", protocol_version="")
    ctx = _VertexRayClientContext("", ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 63.6μs -> 11.1μs (475% faster)

def test_edge_none_shell_uri():
    """Test case where shell uri is explicitly set to None."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": None
    }
    client_ctx = DummyClientContext()
    ctx = _VertexRayClientContext("pr-abc", ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 37.3μs -> 9.21μs (305% faster)


def test_edge_supported_ray_version_2_9_3():
    """Test case for supported ray version 2.9.3 with protocol_version."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com"
    }
    client_ctx = DummyClientContext(ray_version="2.9.3", protocol_version="2.0")
    orig_version = ray.__version__
    ray.__version__ = "2.9.3"
    try:
        ctx = _VertexRayClientContext("pr-proto", ray_head_uris, client_ctx)
        codeflash_output = ctx._context_table_template(); result = codeflash_output
    finally:
        ray.__version__ = orig_version

def test_edge_long_strings():
    """Test case with very long strings for fields."""
    long_str = "x" * 500
    ray_head_uris = {
        "RAY_DASHBOARD_URI": long_str,
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": long_str
    }
    client_ctx = DummyClientContext(
        python_version=long_str,
        ray_version=long_str,
        ray_commit=long_str,
        protocol_version=long_str,
    )
    ctx = _VertexRayClientContext(long_str, ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 71.3μs -> 14.4μs (395% faster)

# ----------- LARGE SCALE TEST CASES -----------

def test_large_scale_many_unique_contexts():
    """Test large scale: create many unique contexts and ensure output is correct."""
    for i in range(100):  # 100 contexts, each with unique values
        ray_head_uris = {
            "RAY_DASHBOARD_URI": f"http://dashboard{i}.example.com",
            "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": f"http://shell{i}.example.com"
        }
        client_ctx = DummyClientContext(
            python_version=f"3.8.{i}",
            ray_version="2.47.1",
            ray_commit=f"commit{i}",
            protocol_version=f"proto{i}"
        )
        ctx = _VertexRayClientContext(f"pr-{i}", ray_head_uris, client_ctx)
        codeflash_output = ctx._context_table_template(); result = codeflash_output # 2.97ms -> 475μs (524% faster)

def test_large_scale_max_length_strings():
    """Test large scale: fields with maximum allowed length (999 chars)."""
    max_str = "y" * 999
    ray_head_uris = {
        "RAY_DASHBOARD_URI": max_str,
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": max_str
    }
    client_ctx = DummyClientContext(
        python_version=max_str,
        ray_version=max_str,
        ray_commit=max_str,
        protocol_version=max_str,
    )
    ctx = _VertexRayClientContext(max_str, ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 61.7μs -> 12.4μs (398% faster)

def test_large_scale_no_shell_uri_many_contexts():
    """Test large scale: many contexts, no shell uri."""
    for i in range(100):  # 100 contexts
        ray_head_uris = {
            "RAY_DASHBOARD_URI": f"http://dashboard{i}.example.com"
        }
        client_ctx = DummyClientContext(
            python_version=f"3.9.{i}",
            ray_version="2.42.0",
            ray_commit=f"commit{i}",
            protocol_version=f"proto{i}"
        )
        ctx = _VertexRayClientContext(f"pr-{i}", ray_head_uris, client_ctx)
        codeflash_output = ctx._context_table_template(); result = codeflash_output # 1.65ms -> 403μs (310% faster)

def test_large_scale_all_none_fields():
    """Test large scale: all fields None except dashboard URI."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard.example.com"
    }
    client_ctx = DummyClientContext(
        python_version=None,
        ray_version=None,
        ray_commit=None,
        protocol_version=None,
        _num_clients=None,
        _context_to_restore=None
    )
    ctx = _VertexRayClientContext(None, ray_head_uris, client_ctx)
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 40.5μs -> 8.62μs (370% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest  # used for our unit tests
from google.cloud.aiplatform.vertex_ray.client_builder import _VertexRayClientContext

# function to test
# -*- coding: utf-8 -*-

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# --- Minimal stub/mock implementations for testability ---

class VertexRayTemplate:
    """Stub class to simulate template rendering."""
    def __init__(self, template_name):
        self.template_name = template_name
        self.render_calls = []

    def render(self, **kwargs):
        # For testability, return a string that encodes the template name and the sorted kwargs.
        self.render_calls.append((self.template_name, kwargs))
        keys = sorted(kwargs)
        items = [f"{k}={repr(kwargs[k])}" for k in keys]
        return f"Rendered({self.template_name}):" + ",".join(items)

class DummyClientContext:
    """Stub for ray.client_builder.ClientContext"""
    def __init__(
        self,
        python_version="3.9.0",
        ray_version="2.47.1",
        ray_commit="abcdefg",
        protocol_version="1.2.3",
        _num_clients=1,
        _context_to_restore=None,
    ):
        self.python_version = python_version
        self.ray_version = ray_version
        self.ray_commit = ray_commit
        self.protocol_version = protocol_version
        self._num_clients = _num_clients
        self._context_to_restore = _context_to_restore

# Simulate aiplatform.__version__ and ray.__version__
class aiplatform:
    __version__ = "1.43.0"
from google.cloud.aiplatform.vertex_ray.client_builder import _VertexRayClientContext

# unit tests

# ---- Basic Test Cases ----

def test_context_table_template_basic_all_fields():
    """Basic: All fields present including shell_uri."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": "http://shell:8888"
    }
    ray_client_context = DummyClientContext(
        python_version="3.8.10",
        ray_version="2.47.1",
        ray_commit="commit123"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="resource-abc",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    # Patch template factory to track calls
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 58.2μs -> 9.30μs (526% faster)

def test_context_table_template_basic_no_shell_uri():
    """Basic: No shell_uri present."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265"
    }
    ray_client_context = DummyClientContext(
        python_version="3.7.9",
        ray_version="2.47.1",
        ray_commit="commit456"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="resource-def",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 38.3μs -> 8.28μs (362% faster)

def test_context_table_template_basic_empty_strings():
    """Basic: All fields present but as empty strings."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": ""
    }
    ray_client_context = DummyClientContext(
        python_version="",
        ray_version="",
        ray_commit=""
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 58.6μs -> 10.8μs (445% faster)

# ---- Edge Test Cases ----


def test_context_table_template_shell_uri_is_none():
    """Edge: shell_uri explicitly set to None."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": None
    }
    ray_client_context = DummyClientContext(
        python_version="3.9.1",
        ray_version="2.47.1",
        ray_commit="commit000"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="resource-none",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 58.7μs -> 12.2μs (380% faster)

def test_context_table_template_shell_uri_row_special_chars():
    """Edge: shell_uri contains special characters."""
    special_uri = "http://shell:8888/?token=abc!@#$%^&*()"
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": special_uri
    }
    ray_client_context = DummyClientContext(
        python_version="3.9.2",
        ray_version="2.47.1",
        ray_commit="commit-special"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="resource-special",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 68.9μs -> 11.5μs (497% faster)

def test_context_table_template_long_resource_id():
    """Edge: Very long persistent_resource_id."""
    long_id = "r" * 500
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265"
    }
    ray_client_context = DummyClientContext(
        python_version="3.9.2",
        ray_version="2.47.1",
        ray_commit="commit-long"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id=long_id,
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 41.9μs -> 10.0μs (317% faster)


def test_context_table_template_large_strings():
    """Large: Very large strings for fields."""
    big_str = "x" * 1000
    ray_head_uris = {
        "RAY_DASHBOARD_URI": big_str,
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": big_str
    }
    ray_client_context = DummyClientContext(
        python_version=big_str,
        ray_version=big_str,
        ray_commit=big_str
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id=big_str,
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    codeflash_output = ctx._context_table_template(); result = codeflash_output # 85.5μs -> 18.1μs (372% faster)

def test_context_table_template_many_unique_calls():
    """Large: Generate many unique invocations with different data."""
    for i in range(10):  # 10 is enough for coverage without excessive runtime
        ray_head_uris = {
            "RAY_DASHBOARD_URI": f"http://dashboard{i}:8265",
            "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": f"http://shell{i}:8888"
        }
        ray_client_context = DummyClientContext(
            python_version=f"3.9.{i}",
            ray_version="2.47.1",
            ray_commit=f"commit{i}"
        )
        ctx = _VertexRayClientContext(
            persistent_resource_id=f"resource-{i}",
            ray_head_uris=ray_head_uris,
            ray_client_context=ray_client_context
        )
        ctx._template_factory = VertexRayTemplate
        codeflash_output = ctx._context_table_template(); result = codeflash_output # 343μs -> 56.4μs (509% faster)

def test_context_table_template_performance_many_calls():
    """Large: Performance with many repeated calls (simulate scale)."""
    ray_head_uris = {
        "RAY_DASHBOARD_URI": "http://dashboard:8265",
        "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI": "http://shell:8888"
    }
    ray_client_context = DummyClientContext(
        python_version="3.9.0",
        ray_version="2.47.1",
        ray_commit="commit-perf"
    )
    ctx = _VertexRayClientContext(
        persistent_resource_id="resource-perf",
        ray_head_uris=ray_head_uris,
        ray_client_context=ray_client_context
    )
    ctx._template_factory = VertexRayTemplate
    # Call the function 100 times in a loop to check for leaks/performance
    results = set()
    for _ in range(100):
        codeflash_output = ctx._context_table_template(); result = codeflash_output # 2.89ms -> 463μs (524% faster)
        results.add(result)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-_VertexRayClientContext._context_table_template-mgkjcdxk` and push.

Codeflash

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 10, 2025 07:39
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Oct 10, 2025