aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- src/extract_lambda.py 131
-rw-r--r-- tests/test_extract_lambda.py 21
2 files changed, 89 insertions, 63 deletions
diff --git a/src/extract_lambda.py b/src/extract_lambda.py
index 15fe785..7efaac0 100644
--- a/src/extract_lambda.py
+++ b/src/extract_lambda.py
@@ -1,14 +1,15 @@
-from pg8000.native import Connection, InterfaceError, identifier
-import boto3
import csv
-from botocore.exceptions import ClientError
-import logging
import json
-from datetime import datetime
+import logging
import re
+from datetime import datetime
+from io import StringIO
+import boto3
+from botocore.exceptions import ClientError
+from pg8000.native import Connection, InterfaceError, identifier
-logger = logging.getLogger()
+logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# DB Exception class
@@ -28,6 +29,7 @@ def lambda_handler(event, context):
and converts all tables to CSV and if any of those tables do not exist in, or are different to the ones in s3, it uploads them
it uses 3 helper functions to achieve these 3 functionalities
"""
+ db = None
try:
db = connect_to_database()
existing_files = list_existing_s3_files()
@@ -39,14 +41,13 @@ def lambda_handler(event, context):
"statusCode": 200,
"body": json.dumps("No changes detected, no CSV files were uploaded."),
}
- else:
- return {
- "statusCode": 200,
- "body": json.dumps(
- f"""CSV files processed for {', '.join(any_changes['updated'])} and uploaded successfully.{
- 'The following tables were not updated: '+', '.join(any_changes['no change']) if any_changes['no change'] else ''}"""
- ),
- }
+ return {
+ "statusCode": 200,
+ "body": json.dumps(
+ f"""CSV files processed for {', '.join(any_changes['updated'])} and uploaded successfully.{
+ 'The following tables were not updated: '+', '.join(any_changes['no change']) if any_changes['no change'] else ''}"""
+ ),
+ }
except Exception as e:
logger.error(f"Error: {e}")
return {"statusCode": 500, "body": json.dumps("Internal server error.")}
@@ -55,17 +56,24 @@ def lambda_handler(event, context):
db.close()
def retrieve_secrets(secret_name="bentley-secrets", region_name="eu-west-2"):
    """Fetch the string payload of a secret from AWS Secrets Manager.

    Args:
        secret_name: Identifier of the secret to read. Defaults to the
            name that was previously hard-coded in this function.
        region_name: AWS region used for the Secrets Manager client.
            Defaults to the region that was previously hard-coded.

    Returns:
        The raw ``SecretString`` value from the ``get_secret_value``
        response. No parsing is performed here.

    Raises:
        ClientError: If the Secrets Manager call itself fails; the
            original exception is re-raised after being logged.
        ValueError: If the response carries no ``SecretString`` entry.
    """
    # Create a Secrets Manager client
    session = boto3.session.Session()
    client = session.client(service_name="secretsmanager", region_name=region_name)

    try:
        get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    except ClientError as e:
        logger.error(f"Failed to retrieve secret {secret_name}: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise

    # The key lookup must sit inside its own try block: previously the
    # KeyError handler was attached to the API call above, so a missing
    # SecretString escaped as an unhandled KeyError from the return line.
    try:
        return get_secret_value_response["SecretString"]
    except KeyError as e:
        logger.error(f"Secret {secret_name} does not contain a SecretString")
        raise ValueError(
            f"Secret {secret_name} does not contain a SecretString"
        ) from e
def connect_to_database() -> Connection:
@@ -123,6 +131,30 @@ def list_existing_s3_files(bucket_name=extract_bucket(), client=boto3.client("s3
return existing_files
def get_latest_timestamp(existing_files):
    """Return the newest timestamp encoded in a collection of S3 keys.

    Keys are matched against the pattern
    ``<prefix>/<Y>/<m>/<d>/<name>_<H:M:S>.csv``; the date path and the time
    suffix are joined and parsed with ``%Y/%m/%d/%H:%M:%S``.

    Args:
        existing_files: Any iterable of key strings. A dict is iterated by
            its keys, so both a mapping and a plain list are accepted.

    Returns:
        The maximum parsed ``datetime``, or ``datetime.min`` when no key
        yields a usable timestamp.
    """
    key_pattern = re.compile(r"\/(.+/).+_(.+)\.csv")
    all_datetimes = []
    for file_name in existing_files:
        match = key_pattern.search(file_name)
        if not match:
            continue
        datetime_str = "".join(match.group(1, 2))
        try:
            all_datetimes.append(
                datetime.strptime(datetime_str, "%Y/%m/%d/%H:%M:%S")
            )
        except ValueError:
            # A key can satisfy the regex without holding a real date path;
            # skip it like any other non-matching key instead of failing
            # the whole run on one stray object.
            logging.getLogger(__name__).warning(
                "Ignoring key with unparseable timestamp: %s", file_name
            )
    return max(all_datetimes) if all_datetimes else datetime.min
+
+
def stream_to_s3(table_name, rows, column_names, s3_client, bucket_name, s3_key):
    """Build a CSV in memory and upload it to S3 without touching disk.

    Args:
        table_name: Name of the source table. Not used by the upload
            itself; kept so existing callers' positional arguments line up.
        rows: Iterable of row sequences written after the header.
        column_names: Sequence of column names written as the header row.
        s3_client: Object exposing ``upload_fileobj(fileobj, bucket, key)``.
        bucket_name: Destination bucket.
        s3_key: Destination object key.
    """
    # Local import so this block does not depend on a change to the
    # file-level import list.
    from io import BytesIO

    csv_buffer = StringIO()
    csv_writer = csv.writer(csv_buffer)
    csv_writer.writerow(column_names)
    csv_writer.writerows(rows)

    # upload_fileobj expects a binary file-like object whose read() returns
    # bytes; handing it the text buffer directly fails inside the transfer
    # layer, so encode the CSV text first.
    payload = BytesIO(csv_buffer.getvalue().encode("utf-8"))
    s3_client.upload_fileobj(payload, bucket_name, s3_key)
+
+
def process_and_upload_tables(db, existing_files, client=boto3.client("s3")):
"""Creates a list of the tables from a database query and
then selects everything from each table in individual queries
@@ -131,53 +163,42 @@ def process_and_upload_tables(db, existing_files, client=boto3.client("s3")):
to files, or new tables/files it uploads them to the s3 bucket
"""
load_status = {"updated": [], "no change": []}
- # Retrieving the latest file timestamp from S3 extract bucket
- all_datetimes = []
- for file_names in existing_files.keys():
- datetime_str_on_s3 = "".join(
- re.search(r"\/(.+/).+_(.+)\.csv", file_names).group(1, 2)
- )
- all_datetimes.append(datetime.strptime(datetime_str_on_s3, "%Y/%m/%d/%H:%M:%S"))
- latest_timestamp = max(all_datetimes)
+ latest_timestamp = get_latest_timestamp(existing_files)
- # Iterating through tables on the database and retrieving only latest changes vs previous file load
tables = db.run(
"""
- SELECT table_name
- FROM information_schema.tables
- WHERE table_schema='public' AND table_type='BASE TABLE';"""
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema='public'
+ AND table_type='BASE TABLE';
+ """
)
+
for table in tables:
- print(tables)
table_name = table[0]
rows = db.run(
f"SELECT * FROM {identifier(table_name)} WHERE last_updated >= :latest;",
latest={datetime.strftime(latest_timestamp, "%Y-%m-%d %H:%M:%S")},
)
- print("rows", rows)
- # Creating a temporary file path and writing the column name to it followed by each row of data
if rows:
- csv_file_path = f"/tmp/{table_name}.csv"
- with open(csv_file_path, "w", newline="") as file:
- writer = csv.writer(file)
- # column_names = [desc["name"] for desc in db.columns(f"SELECT * FROM {table_name};")]
- column_names = [
- col_name[0]
- for col_name in db.run(
- """SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS
- WHERE table_name = :table ;""",
- table=table_name,
- )
- ]
- writer.writerow(column_names)
- writer.writerows(rows)
- s3_key = datetime.strftime(
- datetime.today(), f"{table_name}/%Y/%m/%d/{table_name}_%H:%M:%S.csv"
+ column_names = [
+ col_name[0]
+ for col_name in db.run(
+ """SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS
+ WHERE table_name = :table ;""",
+ table=table_name,
+ )
+ ]
+
+ s3_key = (
+ f"{table_name}/{datetime.now().strftime('%Y/%m/%d')}/"
+ f"{table_name}_{datetime.now().strftime('%H:%M:%S')}.csv"
)
- # Writing the new file to S3 extract bucket:
try:
- client.upload_file(csv_file_path, extract_bucket(), s3_key)
+ stream_to_s3(
+ table_name, rows, column_names, client, extract_bucket(), s3_key
+ )
load_status["updated"].append(table_name)
logger.info(f"Uploaded {s3_key} to S3.")
except ClientError as e:
diff --git a/tests/test_extract_lambda.py b/tests/test_extract_lambda.py
index 3931cfc..3d15927 100644
--- a/tests/test_extract_lambda.py
+++ b/tests/test_extract_lambda.py
@@ -58,7 +58,7 @@ def s3_mock_bucket(s3_client):
class TestLambdaHandler:
- def test_lambda_handler_files_processed_and_uploaded_successfully(self, mocker):
+ def test_files_processed_and_uploaded_successfully(self, mocker):
mock_db = MagicMock()
mock_db.run.side_effect = [
[["Fruits"]],
@@ -72,7 +72,11 @@ class TestLambdaHandler:
]
with patch("src.extract_lambda.connect_to_database", return_value=mock_db):
mock_process_and_upload_tables = mocker.patch(
- "src.extract_lambda.process_and_upload_tables", return_value=mock_db
+ "src.extract_lambda.process_and_upload_tables",
+ return_value={
+ "updated": ["Fruits"],
+ "no change": ["Vegetable", "Berry"],
+ },
)
mock_list_existing_s3_files = mocker.patch(
"src.extract_lambda.list_existing_s3_files", return_value={}
@@ -81,15 +85,15 @@ class TestLambdaHandler:
context = {}
response = lambda_handler(event, context)
assert response["statusCode"] == 200
- assert (
- json.loads(response["body"])
- == "CSV files processed and uploaded successfully."
+ assert json.loads(response["body"]) == (
+ "CSV files processed for Fruits and uploaded successfully."
+ "The following tables were not updated: Vegetable, Berry"
)
mock_list_existing_s3_files.assert_called_once()
mock_process_and_upload_tables.assert_called_once_with(mock_db, {})
mock_db.close.assert_called_once()
- def test_lambda_handler_no_changes_detected_no_files_uploaded(self, mocker):
+ def test_no_changes_detected_no_files_uploaded(self, mocker):
mock_db = MagicMock()
mock_db.run.side_effect = [
[["Fruits"]],
@@ -104,7 +108,8 @@ class TestLambdaHandler:
with patch("src.extract_lambda.connect_to_database", return_value=mock_db):
mock_process_and_upload_tables = mocker.patch(
- "src.extract_lambda.process_and_upload_tables", return_value=False
+ "src.extract_lambda.process_and_upload_tables",
+ return_value={"updated": [], "no change": ["Fruits"]},
)
mock_list_existing_s3_files = mocker.patch(
"src.extract_lambda.list_existing_s3_files", return_value={}
@@ -121,7 +126,7 @@ class TestLambdaHandler:
mock_process_and_upload_tables.assert_called_once_with(mock_db, {})
mock_db.close.assert_called_once()
- def test_lambda_handler_exception_error(self, mocker):
+ def test_exception_error(self, mocker):
with patch(
"src.extract_lambda.connect_to_database",
side_effect=Exception("Database connection error"),
git.ajschof.me — hosted by ajschofield — powered by cgit