Skip to content

Commit 2138d3e

Browse files
authored
Merge pull request #9 from paywithextend/receipt-uploads-and-sort-params
Receipt uploads and sort params
2 parents 015acae + 02af3fd commit 2138d3e

File tree

7 files changed

+283
-3
lines changed

7 files changed

+283
-3
lines changed

extend/client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,41 @@ async def patch(self, url: str, data: Dict) -> Any:
148148
response.raise_for_status()
149149
return response.json()
150150

151+
async def post_multipart(
152+
self,
153+
url: str,
154+
data: Optional[Dict[str, Any]] = None,
155+
files: Optional[Dict[str, Any]] = None,
156+
) -> Any:
157+
"""Make a POST request with multipart/form-data payload.
158+
159+
This method is designed to support file uploads along with optional form data.
160+
161+
Args:
162+
url (str): The API endpoint path (e.g., "/receiptattachments")
163+
data (Optional[Dict[str, Any]]): Optional form fields to include in the request.
164+
files (Optional[Dict[str, Any]]): Files to be uploaded. For example,
165+
{"file": file_obj} where file_obj is an open file in binary mode.
166+
167+
Returns:
168+
The JSON response from the API.
169+
170+
Raises:
171+
httpx.HTTPError: If the request fails.
172+
ValueError: If the response is not valid JSON.
173+
"""
174+
# When sending multipart data, we pass `data` (for non-file fields)
175+
# and `files` (for file uploads) separately.
176+
async with httpx.AsyncClient() as client:
177+
response = await client.post(
178+
self.build_full_url(url),
179+
headers=self.headers,
180+
data=data,
181+
files=files,
182+
timeout=httpx.Timeout(30)
183+
)
184+
response.raise_for_status()
185+
return response.json()
186+
151187
def build_full_url(self, url: Optional[str]):
152188
return f"https://{API_HOST}{url or ''}"

extend/extend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .client import APIClient
33
from .resources.credit_cards import CreditCards
44
from .resources.expense_data import ExpenseData
5+
from .resources.receipt_attachments import ReceiptAttachments
56
from .resources.transactions import Transactions
67

78

@@ -31,3 +32,4 @@ def __init__(self, api_key: str, api_secret: str):
3132
self.virtual_cards = VirtualCards(self._api_client)
3233
self.transactions = Transactions(self._api_client)
3334
self.expense_data = ExpenseData(self._api_client)
35+
self.receipt_attachments = ReceiptAttachments(self._api_client)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Dict, IO
2+
3+
from extend.client import APIClient
4+
from .resource import Resource
5+
6+
7+
class ReceiptAttachments(Resource):
8+
@property
9+
def _base_url(self) -> str:
10+
return "/receiptattachments"
11+
12+
def __init__(self, api_client: APIClient):
13+
super().__init__(api_client)
14+
15+
async def create_receipt_attachment(
16+
self,
17+
transaction_id: str,
18+
file: IO,
19+
) -> Dict:
20+
"""Create a receipt attachment for a transaction by uploading a file using multipart form data.
21+
22+
Args:
23+
transaction_id (str): The unique identifier of the transaction to attach the receipt to
24+
file (IO): A file-like object opened in binary mode that contains the data
25+
to be uploaded
26+
27+
Returns:
28+
Dict: A dictionary representing the receipt attachment, including:
29+
- id: Unique identifier of the receipt attachment.
30+
- transactionId: The associated transaction ID.
31+
- contentType: The MIME type of the uploaded file.
32+
- urls: A dictionary with URLs for the original image, main image, and thumbnail.
33+
- createdAt: Timestamp when the receipt attachment was created.
34+
- uploadType: A string describing the type of upload (e.g., "TRANSACTION", "VIRTUAL_CARD").
35+
36+
Raises:
37+
httpx.HTTPError: If the request fails
38+
"""
39+
40+
return await self._request(
41+
method="post_multipart",
42+
data={"transaction_id": transaction_id},
43+
files={"file": file})

extend/resources/resource.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ async def _request(
2020
self,
2121
method: str,
2222
path: str = None,
23-
params: Optional[Dict] = None
23+
params: Optional[Dict] = None,
24+
data: Optional[Dict[str, Any]] = None,
25+
files: Optional[Dict[str, Any]] = None,
2426
) -> Any:
2527
if params is not None:
2628
params = {k: v for k, v in params.items() if v is not None}
@@ -33,6 +35,12 @@ async def _request(
3335
return await self._api_client.put(self.build_full_path(path), params)
3436
case "patch":
3537
return await self._api_client.patch(self.build_full_path(path), params)
38+
case "post_multipart":
39+
return await self._api_client.post_multipart(
40+
self.build_full_path(path),
41+
data=data,
42+
files=files
43+
)
3644
case _:
3745
raise ValueError(f"Unsupported HTTP method: {method}")
3846

extend/resources/transactions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ async def get_transactions(
2222
min_amount_cents: Optional[int] = None,
2323
max_amount_cents: Optional[int] = None,
2424
search_term: Optional[str] = None,
25+
sort_field: Optional[str] = None,
2526
) -> Dict:
2627
"""Get a list of transactions with optional filtering and pagination.
2728
@@ -34,6 +35,9 @@ async def get_transactions(
3435
min_amount_cents (int): Minimum clearing amount in cents
3536
max_amount_cents (int): Maximum clearing amount in cents
3637
search_term (Optional[str]): Filter cards by search term (e.g., "Marketing")
38+
sort_field (Optional[str]): Field to sort by, with optional direction
39+
Use "recipientName", "merchantName", "amount", "date" for ASC
40+
Use "-recipientName", "-merchantName", "-amount", "-date" for DESC
3741
3842
Returns:
3943
Dict: A dictionary containing:
@@ -57,8 +61,8 @@ async def get_transactions(
5761
"minClearingBillingCents": min_amount_cents,
5862
"maxClearingBillingCents": max_amount_cents,
5963
"search": search_term,
64+
"sort": sort_field,
6065
}
61-
params = {k: v for k, v in params.items() if v is not None}
6266

6367
return await self._request(method="get", params=params)
6468

extend/resources/virtual_cards.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ async def get_virtual_cards(
2121
status: Optional[str] = None,
2222
recipient: Optional[str] = None,
2323
search_term: Optional[str] = None,
24+
sort_field: Optional[str] = None,
25+
sort_direction: Optional[str] = None,
2426
) -> Dict:
2527
"""Get a list of virtual cards with optional filtering and pagination.
2628
@@ -30,6 +32,8 @@ async def get_virtual_cards(
3032
status (Optional[str]): Filter cards by status (e.g., "ACTIVE", "CANCELLED")
3133
recipient (Optional[str]): Filter cards by recipient id (e.g., "u_1234")
3234
search_term (Optional[str]): Filter cards by search term (e.g., "Marketing")
35+
sort_field (Optional[str]): Field to sort by "createdAt", "updatedAt", "balanceCents", "displayName", "type", or "status"
36+
sort_direction (Optional[str]): Direction to sort (ASC or DESC)
3337
3438
Returns:
3539
Dict: A dictionary containing:
@@ -49,8 +53,9 @@ async def get_virtual_cards(
4953
"statuses": status,
5054
"recipient": recipient,
5155
"search": search_term,
56+
"sortField": sort_field,
57+
"sortDirection": sort_direction,
5258
}
53-
params = {k: v for k, v in params.items() if v is not None}
5459

5560
return await self._request(method="get", params=params)
5661

tests/test_integration.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import uuid
33
from datetime import datetime, timedelta
4+
from io import BytesIO
45

56
import pytest
67
from dotenv import load_dotenv
@@ -129,6 +130,47 @@ async def test_list_virtual_cards(self, extend):
129130
for card in response["virtualCards"]:
130131
assert card["status"] == "CLOSED"
131132

133+
@pytest.mark.asyncio
134+
async def test_list_virtual_cards_with_sorting(self, extend):
135+
"""Test listing virtual cards with various sorting options"""
136+
137+
# Test sorting by display name ascending
138+
asc_response = await extend.virtual_cards.get_virtual_cards(
139+
sort_field="displayName",
140+
sort_direction="ASC",
141+
per_page=50 # Ensure we get enough cards to compare
142+
)
143+
144+
# Test sorting by display name descending
145+
desc_response = await extend.virtual_cards.get_virtual_cards(
146+
sort_field="displayName",
147+
sort_direction="DESC",
148+
per_page=50 # Ensure we get enough cards to compare
149+
)
150+
151+
# Verify responses contain cards
152+
assert "virtualCards" in asc_response
153+
assert "virtualCards" in desc_response
154+
155+
# If sufficient cards exist, just verify the orders are different
156+
# rather than trying to implement our own sorting logic
157+
if len(asc_response["virtualCards"]) > 1 and len(desc_response["virtualCards"]) > 1:
158+
asc_ids = [card["id"] for card in asc_response["virtualCards"]]
159+
desc_ids = [card["id"] for card in desc_response["virtualCards"]]
160+
161+
# Verify the orders are different for different sort directions
162+
assert asc_ids != desc_ids, "ASC and DESC sorting should produce different results"
163+
164+
# Test other sort fields
165+
for field in ["createdAt", "updatedAt", "balanceCents", "status", "type"]:
166+
# Test both directions for each field
167+
for direction in ["ASC", "DESC"]:
168+
response = await extend.virtual_cards.get_virtual_cards(
169+
sort_field=field,
170+
sort_direction=direction
171+
)
172+
assert "virtualCards" in response, f"Sorting by {field} {direction} should return virtual cards"
173+
132174

133175
@pytest.mark.integration
134176
class TestTransactions:
@@ -152,6 +194,54 @@ async def test_list_transactions(self, extend):
152194
for field in required_fields:
153195
assert field in transaction, f"Transaction should contain '{field}' field"
154196

197+
@pytest.mark.asyncio
198+
async def test_list_transactions_with_sorting(self, extend):
199+
"""Test listing transactions with various sorting options"""
200+
201+
# Define sort fields - positive for ASC, negative (prefixed with -) for DESC
202+
sort_fields = [
203+
"recipientName", "-recipientName",
204+
"merchantName", "-merchantName",
205+
"amount", "-amount",
206+
"date", "-date"
207+
]
208+
209+
# Test each sort field
210+
for sort_field in sort_fields:
211+
# Get transactions with this sort
212+
response = await extend.transactions.get_transactions(
213+
sort_field=sort_field,
214+
per_page=10
215+
)
216+
217+
# Verify response contains transactions and basic structure
218+
assert isinstance(response, dict), f"Response for sort {sort_field} should be a dictionary"
219+
assert "transactions" in response, f"Response for sort {sort_field} should contain 'transactions' key"
220+
221+
# If we have enough data, test opposite sort direction for comparison
222+
if len(response["transactions"]) > 1:
223+
# Determine the field name and opposite sort field
224+
is_desc = sort_field.startswith("-")
225+
field_name = sort_field[1:] if is_desc else sort_field
226+
opposite_sort = field_name if is_desc else f"-{field_name}"
227+
228+
# Get transactions with opposite sort
229+
opposite_response = await extend.transactions.get_transactions(
230+
sort_field=opposite_sort,
231+
per_page=10
232+
)
233+
234+
# Get IDs in both sort orders for comparison
235+
sorted_ids = [tx["id"] for tx in response["transactions"]]
236+
opposite_sorted_ids = [tx["id"] for tx in opposite_response["transactions"]]
237+
238+
# If we have the same set of transactions in both responses,
239+
# verify that different sort directions produce different orders
240+
if set(sorted_ids) == set(opposite_sorted_ids) and len(sorted_ids) > 1:
241+
assert sorted_ids != opposite_sorted_ids, (
242+
f"Different sort directions for {field_name} should produce different results"
243+
)
244+
155245

156246
@pytest.mark.integration
157247
class TestRecurringCards:
@@ -303,6 +393,98 @@ async def test_get_expense_categories_and_labels(self, extend):
303393
assert "expenseLabels" in labels
304394

305395

396+
@pytest.mark.integration
397+
class TestTransactionExpenseData:
398+
"""Integration tests for updating transaction expense data using a specific expense category and label"""
399+
400+
@pytest.mark.asyncio
401+
async def test_update_transaction_expense_data_with_specific_category_and_label(self, extend):
402+
"""Test updating the expense data for a transaction using a specific expense category and label."""
403+
# Retrieve available expense categories (active ones)
404+
categories_response = await extend.expense_data.get_expense_categories(active=True)
405+
assert "expenseCategories" in categories_response, "Response should include 'expenseCategories'"
406+
expense_categories = categories_response["expenseCategories"]
407+
assert expense_categories, "No expense categories available for testing"
408+
409+
# For this test, pick the first expense category
410+
category = expense_categories[0]
411+
category_id = category["id"]
412+
413+
# Retrieve the labels for the chosen expense category
414+
labels_response = await extend.expense_data.get_expense_category_labels(
415+
category_id=category_id,
416+
page=0,
417+
per_page=10
418+
)
419+
assert "expenseLabels" in labels_response, "Response should include 'expenseLabels'"
420+
expense_labels = labels_response["expenseLabels"]
421+
assert expense_labels, "No expense labels available for the selected category"
422+
423+
# Pick the first label from the list
424+
label = expense_labels[0]
425+
label_id = label["id"]
426+
427+
# Retrieve at least one transaction to update expense data
428+
transactions_response = await extend.transactions.get_transactions(per_page=1)
429+
assert transactions_response.get("transactions"), "No transactions available for testing expense data update"
430+
transaction = transactions_response["transactions"][0]
431+
transaction_id = transaction["id"]
432+
433+
# Prepare the expense data payload with the specific category and label
434+
update_payload = {
435+
"expenseDetails": [
436+
{
437+
"categoryId": category_id,
438+
"labelId": label_id
439+
}
440+
]
441+
}
442+
443+
# Call the update_transaction_expense_data method
444+
response = await extend.transactions.update_transaction_expense_data(transaction_id, update_payload)
445+
446+
# Verify the response contains the transaction id and expected expense details
447+
assert "id" in response, "Response should include the transaction id"
448+
if "expenseDetails" in response:
449+
# Depending on the API response, the structure might vary; adjust assertions accordingly
450+
assert response["expenseDetails"] == update_payload["expenseDetails"], (
451+
"Expense details in the response should match the update payload"
452+
)
453+
454+
455+
@pytest.mark.integration
456+
class TestReceiptAttachments:
457+
"""Integration tests for receipt attachment operations"""
458+
459+
@pytest.mark.asyncio
460+
async def test_create_receipt_attachment(self, extend):
461+
"""Test creating a receipt attachment via multipart upload."""
462+
# Create a dummy PNG file in memory
463+
# This is a minimal PNG header plus extra bytes to simulate file content.
464+
png_header = b'\x89PNG\r\n\x1a\n'
465+
dummy_content = png_header + b'\x00' * 100
466+
file_obj = BytesIO(dummy_content)
467+
# Optionally set a name attribute for file identification in the upload
468+
file_obj.name = f"test_receipt_{uuid.uuid4()}.png"
469+
470+
# Retrieve a valid transaction id from existing transactions
471+
transactions_response = await extend.transactions.get_transactions(page=0, per_page=1)
472+
assert transactions_response.get("transactions"), "No transactions available for testing receipt attachment"
473+
transaction_id = transactions_response["transactions"][0]["id"]
474+
475+
# Call the receipt attachment upload method
476+
response = await extend.receipt_attachments.create_receipt_attachment(
477+
transaction_id=transaction_id,
478+
file=file_obj
479+
)
480+
481+
# Assert that the response contains expected keys
482+
assert "id" in response, "Receipt attachment should have an id"
483+
assert "urls" in response, "Receipt attachment should include urls"
484+
assert "contentType" in response, "Receipt attachment should include a content type"
485+
assert response["contentType"] == "image/png", "Content type should be 'image/png'"
486+
487+
306488
def test_environment_variables():
307489
"""Test that required environment variables are set"""
308490
assert os.getenv("EXTEND_API_KEY"), "EXTEND_API_KEY environment variable is required"

0 commit comments

Comments
 (0)