aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/load_lambda.py75
-rw-r--r--tests/test_load_lambda.py44
2 files changed, 80 insertions, 39 deletions
diff --git a/src/load_lambda.py b/src/load_lambda.py
index 8eaea32..6e6bc80 100644
--- a/src/load_lambda.py
+++ b/src/load_lambda.py
@@ -40,6 +40,7 @@ def lambda_handler(event, context):
logger.error(f"Error: {e}", exc_info=True)
return {"statusCode": 500, "body": json.dumps("Internal server error.")}
+
def retrieve_secrets():
secret_name = "bentley-RDS-credentials"
region_name = "eu-west-2"
@@ -59,7 +60,10 @@ def retrieve_secrets():
return get_secret_value_response["SecretString"]
+
# connect to database, slightly different way of doing it, to allow manipulation through pandas
+
+
def connect_to_db_and_return_engine():
try:
secrets = json.loads(retrieve_secrets())
@@ -68,13 +72,14 @@ def connect_to_db_and_return_engine():
user = secrets["user"]
password = secrets["password"]
database = secrets["database"]
- conn_str = f'postgresql+pg8000://{user}:{password}@{host}:{port}/{database}'
- engine = create_engine(conn_str) #interface between python (pandas) and SQL
+ conn_str = f"postgresql+pg8000://{user}:{password}@{host}:{port}/{database}"
+ # interface between python (pandas) and SQL
+ engine = create_engine(conn_str)
return engine
except Exception as e:
logger.error(f"Interface error: {e}")
raise RuntimeError("Failed to create database engine")
-
+
# get transform bucket
def get_transform_bucket(client=None):
@@ -85,9 +90,11 @@ def get_transform_bucket(client=None):
except ClientError as e:
logger.error(f"Error listing S3 buckets: {e}")
raise RuntimeError("Error listing S3 buckets")
-
+
transform_bucket_filter = [
- bucket["Name"] for bucket in response["Buckets"] if "transform" in bucket["Name"]
+ bucket["Name"]
+ for bucket in response["Buckets"]
+ if "transform" in bucket["Name"]
]
if not transform_bucket_filter:
@@ -96,9 +103,12 @@ def get_transform_bucket(client=None):
return transform_bucket_filter[0]
+
# list and then retrieve parquet files from S3 bucket
# convert parquet files into dataframes
-# return a dictionary of dataframes with name as key, and dataframe object as value
+# return a dictionary of dataframes with name as key, and dataframe object as value
+
+
def convert_parquet_files_to_dfs(bucket_name=None, client=None):
try:
if client is None:
@@ -110,10 +120,10 @@ def convert_parquet_files_to_dfs(bucket_name=None, client=None):
dfs = {}
if "Contents" in files:
for file in files["Contents"]:
- file_key = file['Key']
+ file_key = file["Key"]
try:
file_obj = client.get_object(Bucket=bucket_name, Key=file_key)
- parquet_file = pq.ParquetFile(BytesIO(file_obj['Body'].read()))
+ parquet_file = pq.ParquetFile(BytesIO(file_obj["Body"].read()))
df = parquet_file.read().to_pandas()
dfs[file_key] = df
except ClientError as e:
@@ -132,34 +142,51 @@ def convert_parquet_files_to_dfs(bucket_name=None, client=None):
return dfs
+
def upload_dfs_to_database():
upload_status = {"uploaded": [], "not_uploaded": []}
dict_of_dfs = convert_parquet_files_to_dfs()
db_engine = connect_to_db_and_return_engine()
- immutable_df_dict = ["dim_counterparty.parquet",
- "dim_date.parquet", #this needs to be mutable
- "dim_location.parquet",
- "dim_staff.parquet",
- "dim_design.parquet"]
- mutable_df_dict = ["fact_sales_order",
- "fact_purchase_order",
- "fact_payment",
- "dim_currency"]
-
+ immutable_df_dict = [
+ "dim_counterparty.parquet",
+ "dim_date.parquet", # this needs to be mutable
+ "dim_location.parquet",
+ "dim_staff.parquet",
+ "dim_design.parquet",
+ ]
+ mutable_df_dict = [
+ "fact_sales_order",
+ "fact_purchase_order",
+ "fact_payment",
+ "dim_currency",
+ ]
+
for file_name, df in dict_of_dfs.items():
if file_name in immutable_df_dict:
table_name = file_name.split(".")[0]
try:
- df.to_sql(table_name, con=db_engine, schema="project_team_2", if_exists="overwrite", index=False)
+ df.to_sql(
+ table_name,
+ con=db_engine,
+ schema="project_team_2",
+ if_exists="overwrite",
+ index=False,
+ )
upload_status["uploaded"].append(table_name)
except Exception as e:
logger.error(f"Error uploading dataframe {file_name} to database: {e}")
raise
- elif file_name.rsplit('_', 1)[0] in mutable_df_dict:
- table_name = file_name.rsplit('_', 1)[0]
+ elif file_name.rsplit("_", 1)[0] in mutable_df_dict:
+ table_name = file_name.rsplit("_", 1)[0]
try:
- df.to_sql(table_name, con=db_engine, schema="project_team_2", if_exists="overwrite", index=False)
- upload_status["uploaded"].append(table_name)
+ df.to_sql(
+ table_name,
+ con=db_engine,
+ schema="project_team_2",
+ if_exists="overwrite",
+ index=False,
+ )
+ upload_status["uploaded"].append(table_name)
except Exception as e:
logger.error(f"Error uploading dataframe {file_name} to database: {e}")
raise
@@ -167,4 +194,4 @@ def upload_dfs_to_database():
upload_status["not_uploaded"].append(file_name)
logger.error(f"{file_name} does not correspond with table in database")
db_engine.dispose()
- return upload_status \ No newline at end of file
+ return upload_status
diff --git a/tests/test_load_lambda.py b/tests/test_load_lambda.py
index e04ccec..88c71e4 100644
--- a/tests/test_load_lambda.py
+++ b/tests/test_load_lambda.py
@@ -5,7 +5,14 @@ from moto import mock_aws
import boto3
import os
import pytest
-from src.load_lambda import lambda_handler, connect_to_db_and_return_engine, get_transform_bucket, convert_parquet_files_to_dfs, upload_dfs_to_database
+from src.load_lambda import (
+ lambda_handler,
+ connect_to_db_and_return_engine,
+ get_transform_bucket,
+ convert_parquet_files_to_dfs,
+ upload_dfs_to_database,
+)
+
@pytest.fixture(scope="class")
def aws_credentials():
@@ -25,12 +32,15 @@ def mock_s3_client(aws_credentials):
class TestLambdaHandler:
pass
+
class TestRetrieveSecrets:
pass
+
class TestConnectToDBAndReturnEngine:
pass
+
class TestGetTransformBucket:
def test_raises_value_error_if_no_buckets(self, mock_s3_client):
with pytest.raises(ValueError, match="No transform bucket found"):
@@ -38,35 +48,38 @@ class TestGetTransformBucket:
def test_raises_value_error_if_no_transform_bucket(self, mock_s3_client):
mock_s3_client.create_bucket(
- Bucket="extract_bucket",
- CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
- )
+ Bucket="extract_bucket",
+ CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
+ )
with pytest.raises(ValueError, match="No transform bucket found"):
get_transform_bucket(mock_s3_client)
def test_returns_transform_bucket_if_one_bucket(self, mock_s3_client):
mock_s3_client.create_bucket(
- Bucket="transform_bucket",
- CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
- )
+ Bucket="transform_bucket",
+ CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
+ )
result = get_transform_bucket(mock_s3_client)
assert result == "transform_bucket"
def test_only_returns_transform_bucket_if_several_buckets(self, mock_s3_client):
mock_s3_client.create_bucket(
- Bucket="another_test_bucket",
- CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
- )
+ Bucket="another_test_bucket",
+ CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
+ )
result = get_transform_bucket(mock_s3_client)
assert result == "transform_bucket"
+
class TestConvertParquetToDfs:
def test_function_returns_empty_dictionary_if_no_files(self, mock_s3_client):
mock_s3_client.create_bucket(
- Bucket="transform_bucket",
- CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
- )
- result = convert_parquet_files_to_dfs(bucket_name="transform_bucket", client=mock_s3_client)
+ Bucket="transform_bucket",
+ CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
+ )
+ result = convert_parquet_files_to_dfs(
+ bucket_name="transform_bucket", client=mock_s3_client
+ )
assert result == {}
# def test_function_returns_dictionary_with_table_with_file_key():
@@ -74,5 +87,6 @@ class TestConvertParquetToDfs:
# result = convert_parquet_files_to_dfs(bucket_name="transform_bucket", client=mock_s3_client)
# assert "dim_staff" in result
+
class TestUploadDfsToDatabase:
- pass \ No newline at end of file
+ pass
git.ajschof.me — hosted by ajschofield — powered by cgit