diff options
| author | lian-manonog <lian.manonog@gmail.com> | 2024-08-28 11:34:10 +0100 |
|---|---|---|
| committer | lian-manonog <lian.manonog@gmail.com> | 2024-08-28 11:34:10 +0100 |
| commit | 459702f2bdd3070923187ec0d4c76c85dbe0d235 (patch) | |
| tree | 42c734652bcaa140e771261bf9ef2080de5cb271 | |
| parent | ad357ff34202827720dc216562dfbb0fbd65c297 (diff) | |
| parent | 04bbcc289f31c219d811850de86937bd6ba7796c (diff) | |
| download | de-project-bentley-459702f2bdd3070923187ec0d4c76c85dbe0d235.tar.gz de-project-bentley-459702f2bdd3070923187ec0d4c76c85dbe0d235.zip | |
Merge branch 'development' of https://github.com/ajschofield/de-project-bentley into test/transform-lambda
| -rw-r--r-- | requirements.txt | 3 | ||||
| -rw-r--r-- | src/load_lambda.py | 213 | ||||
| -rw-r--r-- | src/transform_lambda/dataframes.py (renamed from src/dataframes.py) | 150 | ||||
| -rw-r--r-- | src/transform_lambda/transform_lambda.py (renamed from src/transform_lambda.py) | 63 | ||||
| -rw-r--r-- | terraform/lambda.tf | 6 | ||||
| -rw-r--r-- | tests/test_dataframes.py (renamed from tests/test_fact_sales_order.py) | 73 | ||||
| -rw-r--r-- | tests/test_load_lambda.py | 196 |
7 files changed, 593 insertions, 111 deletions
diff --git a/requirements.txt b/requirements.txt index 0c81216..763b95a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,5 +30,6 @@ Werkzeug==3.0.3 xmltodict==0.13.0 s3fs pandas +pyarrow +SQLAlchemy bs4 -pyarrow
\ No newline at end of file diff --git a/src/load_lambda.py b/src/load_lambda.py index c6a8e60..7339ab9 100644 --- a/src/load_lambda.py +++ b/src/load_lambda.py @@ -1,2 +1,211 @@ -def lambda_handler(): - pass +import boto3 +from botocore.exceptions import ClientError +import pandas as pd +import pyarrow.parquet as pq +from io import BytesIO +import logging +import json +import traceback +from sqlalchemy import create_engine + + +logger = logging.getLogger(__name__) + +logging.basicConfig( + format="{asctime} - {levelname} - {message}", + style="{", + datefmt="%Y-%m-%d %H:%M", + level=logging.DEBUG, +) + +logging.getLogger("botocore").setLevel(logging.INFO) + + +def lambda_handler(event, context): + try: + uploaded_tables = upload_dfs_to_database() + if uploaded_tables["not_uploaded"]: + return { + "statusCode": 200, + "body": json.dumps("No dataframes were uploaded."), + } + elif uploaded_tables["uploaded"]: + return { + "statusCode": 200, + "body": json.dumps( + f"""The following dataframes were uploaded successfully: + {uploaded_tables["uploaded"]} .""" + ), + } + else: + logger.error(f"error") + return {"error"} + except Exception as e: + logger.error({e}) + return {"statusCode": 500, "body": {e}} + + +def retrieve_secrets(client=None, secret_name=None): + session = boto3.session.Session() + region_name = "eu-west-2" + + if secret_name == None: + secret_name = "bentley-RDS-credentials" + if client == None: + client = session.client(service_name="secretsmanager", region_name=region_name) + + try: + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + print(get_secret_value_response) + except ClientError as e: + logger.error(f"Failed to retrieve secret {secret_name}: {str(e)}") + raise e + except KeyError: + logger.error(f"Secret {secret_name} does not contain a SecretString") + raise ValueError(f"Secret {secret_name} does not contain a SecretString") + + return get_secret_value_response["SecretString"] + + +# connect to database, slightly different way of doing it, to allow manipulation through pandas + + +def connect_to_db_and_return_engine(sm_secret=None): + if sm_secret is None: + sm_secret = json.loads(retrieve_secrets()) + + try: + secrets = sm_secret + host = secrets["host"] + port = secrets["port"] + user = secrets["user"] + password = secrets["password"] + database = secrets["database"] + conn_str = f"postgresql+pg8000://{user}:{password}@{host}:{port}/{database}" + # interface between python (pandas) and SQL + engine = create_engine(conn_str) + return engine + except Exception as e: + logger.error(f"Interface error: {e}") + raise RuntimeError("Failed to create database engine") + + +# get transform bucket +def get_transform_bucket(client=None): + if client is None: + client = boto3.client("s3") + try: + response = client.list_buckets() + except ClientError as e: + logger.error(f"Error listing S3 buckets: {e}") + raise RuntimeError("Error listing S3 buckets") + + transform_bucket_filter = [ + bucket["Name"] + for bucket in response["Buckets"] + if "transform" in bucket["Name"] + ] + + if not transform_bucket_filter: + logger.error("No transform bucket found") + raise ValueError("No transform bucket found") + + return transform_bucket_filter[0] + + +# list and then retrieve parquet files from S3 bucket +# convert parquet files into dataframes +# return a dictionary of dataframes with name as key, and dataframe object as value + + +def convert_parquet_files_to_dfs(bucket_name=None, client=None): + try: + if client is None: + client = boto3.client("s3") + if bucket_name is None: + bucket_name = get_transform_bucket() + files = client.list_objects_v2(Bucket=bucket_name) + + dfs = {} + if "Contents" in files: + for file in files["Contents"]: + file_key = file["Key"] + try: + file_obj = client.get_object(Bucket=bucket_name, Key=file_key) + parquet_file = pq.ParquetFile(BytesIO(file_obj["Body"].read())) + df = parquet_file.read().to_pandas() + dfs[file_key] = df + except ClientError as e: + logger.error(f"Unable to retrieve S3 object {file_key}: {e}") + except Exception as e: + logger.error(f"Unable to process file {file_key}: {e}") + else: + logger.error(f"No files found in {bucket_name}.") + return {} + except ValueError as value_error: + logger.error(f"Unable to list objects: {value_error}") + raise + except ClientError as client_error: + logger.error(f"Unable to list objects: {client_error}") + raise + + return dfs + + +def upload_dfs_to_database(): + upload_status = {"uploaded": [], "not_uploaded": []} + dict_of_dfs = convert_parquet_files_to_dfs() + db_engine = connect_to_db_and_return_engine() + immutable_df_dict = [ + "dim_counterparty.parquet", + "dim_date.parquet", # this needs to be mutable + "dim_location.parquet", + "dim_staff.parquet", + "dim_design.parquet", + ] + mutable_df_dict = [ + "fact_sales_order", + "fact_purchase_order", + "fact_payment", + "dim_currency", + ] + + for file_name, df in dict_of_dfs.items(): + print(df) + if file_name in immutable_df_dict: + table_name = file_name.split(".")[0] + print(table_name, "<<<<<") + try: + df.to_sql( + table_name, + con=db_engine, + if_exists="append", + index=False, + ) + upload_status["uploaded"].append(table_name) + except Exception as e: + logger.error(f"Error uploading dataframe {file_name} to database: {e}") + raise + elif file_name.rsplit("_", 1)[0] in mutable_df_dict: + table_name = file_name.rsplit("_", 1)[0] + try: + df.to_sql( + table_name, + con=db_engine, + schema="project_team_2", + if_exists="append", + index=False, + ) + upload_status["uploaded"].append(table_name) + except Exception as e: + logger.error(f"Error uploading dataframe {file_name} to database: {e}") + raise + else: + upload_status["not_uploaded"].append(file_name) + logger.error(f"{file_name} does not correspond with table in database") + db_engine.dispose() + return upload_status + + +if __name__ == "__main__": + lambda_handler(None, None) diff --git a/src/dataframes.py b/src/transform_lambda/dataframes.py index ab53063..2a46bd6 100644 --- a/src/dataframes.py +++ b/src/transform_lambda/dataframes.py @@ -16,85 +16,103 @@ import requests # dim_counterparty +# no test, same as fact_payment def create_fact_sales_order(dict_of_df): df_sales = dict_of_df["sales_order"] df_sales.index.name = "sales_record_id" - df_sales["created_date"] = pd.to_datetime(df_sales["created_at"]).dt.date - df_sales["created_time"] = pd.to_datetime(df_sales["created_at"]).dt.time - df_sales["last_updated_date"] = pd.to_datetime(df_sales["last_updated"]).dt.date - df_sales["last_updated_time"] = pd.to_datetime(df_sales["last_updated"]).dt.time - fact_sales_order = df_sales.loc[ - :, - [ - "sales_record_id", - "sales_order_id", - "created_date", - "created_time", - "last_updated_date", - "last_updated_time", - "sales_staff_id", - "counterparty_id", - "units_sold", - "unit_price", - "currency_id", - "design_id", - "agreed_payment_date", - "agreed_delivery_date", - "agreed_delivery_location_id", - ], - ] - return fact_sales_order + df_sales["created_date"] = df_sales["created_at"].astype("datetime64[ns]").dt.date + df_sales["created_time"] = ( + df_sales["created_at"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_sales["last_updated_date"] = ( + df_sales["last_updated"].astype("datetime64[ns]").dt.date + ) + df_sales["last_updated_time"] = ( + df_sales["last_updated"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_sales["agreed_delivery_date"] = pd.to_datetime( + df_sales["agreed_delivery_date"], format="%Y-%m-%d" + ) + df_sales["agreed_payment_date"] = pd.to_datetime( + df_sales["agreed_payment_date"], format="%Y-%m-%d" + ) + df_sales = df_sales.drop(labels=["created_at", "last_updated"], axis=1) + + df_sales.reset_index(inplace=True) + return df_sales -# fact_purchase_order from purchase_order + df_sales["created_date"] = df_sales["created_at"].astype("datetime64[ns]").dt.date + df_sales["created_time"] = ( + df_sales["created_at"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_sales["last_updated_date"] = ( + df_sales["last_updated"].astype("datetime64[ns]").dt.date + ) + df_sales["last_updated_time"] = ( + df_sales["last_updated"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_sales["agreed_delivery_date"] = pd.to_datetime( + df_sales["agreed_delivery_date"], format="%Y-%m-%d" + ) + df_sales["agreed_payment_date"] = pd.to_datetime( + df_sales["agreed_payment_date"], format="%Y-%m-%d" + ) + df_sales = df_sales.drop(labels=["created_at", "last_updated"], axis=1) + df_sales.reset_index(inplace=True) + return df_sales + + +# no test, same as fact_payment def create_fact_purchase_orders(dict_of_df): df_po = dict_of_df["purchase_order"] df_po.index.name = "purchase_record_id" - df_po["created_date"] = df_po["created_at"].date() - df_po["created_time"] = df_po["created_at"].dt.time - df_po["last_updated_date"] = df_po["last_updated_at"].date() - df_po["last_updated_time"] = df_po["last_updated_at"].dt.time + df_po["created_date"] = df_po["created_at"].astype("datetime64[ns]").dt.date + df_po["created_time"] = ( + df_po["created_at"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_po["last_updated_date"] = df_po["last_updated"].astype("datetime64[ns]").dt.date + df_po["last_updated_time"] = ( + df_po["last_updated"].astype("datetime64[ns]").dt.floor("s").dt.time + ) df_po["agreed_delivery_date"] = pd.to_datetime( df_po["agreed_delivery_date"], format="%Y-%m-%d" ) df_po["agreed_payment_date"] = pd.to_datetime( df_po["agreed_payment_date"], format="%Y-%m-%d" ) - df_po.drop(labels=["created_at", "last_updated_at"], axis=1, inplace=True) + df_po = df_po.drop(labels=["created_at", "last_updated"], axis=1) + df_po.reset_index(inplace=True) return df_po +# test passed + + def create_fact_payment(dict_of_df): df_payment = dict_of_df["payment"] df_payment.index.name = "payment_record_id" - df_payment["created_date"] = df_payment["created_at"].date() - df_payment["created_time"] = df_payment["created_at"].time - df_payment["last_updated_date"] = df_payment["last_updated"].date() - df_payment["last_updated_time"] = df_payment["last_updated"].time + df_payment["created_date"] = ( + df_payment["created_at"].astype("datetime64[ns]").dt.date + ) + df_payment["created_time"] = ( + df_payment["created_at"].astype("datetime64[ns]").dt.floor("s").dt.time + ) + df_payment["last_updated_date"] = ( + df_payment["last_updated"].astype("datetime64[ns]").dt.date + ) + df_payment["last_updated_time"] = ( + df_payment["last_updated"].astype("datetime64[ns]").dt.floor("s").dt.time + ) df_payment["payment_date"] = pd.to_datetime( df_payment["payment_date"], format="%Y-%m-%d" ) - fact_payment = df_payment.loc[ - :, - [ - "payment_record_id", - "payment_id", - "created_date", - "created_time", - "last_updated_date", - "last_updated_time", - "transaction_id", - "counterparty_id", - "payment_amount", - "currency_id", - "payment_type_id", - "paid", - "payment_date", - ], - ] - return fact_payment + df_payment = df_payment.drop(labels=["created_at", "last_updated"], axis=1) + + df_payment.reset_index(inplace=True) + return df_payment # test passed @@ -108,6 +126,8 @@ def create_dim_transaction(dict_of_df): # test passed + + def create_dim_location(dict_of_df): df_loc = ( dict_of_df["address"] @@ -118,18 +138,24 @@ def create_dim_location(dict_of_df): def create_dim_counterparty(dict_of_df): - df_prefixed_address = dict_of_df["address"].add_prefix( - "counterparty_legal_", axis=1 + df_prefixed_address = ( + dict_of_df["address"] + .drop(labels=["created_at", "last_updated"], axis=1) + .add_prefix("counterparty_legal_", axis=1) ) df_cp = pd.merge( dict_of_df["counterparty"], df_prefixed_address, left_on="legal_address_id", right_on="counterparty_legal_address_id", - how="outer", + how="inner", ) df_cp.drop( - columns=["legal_address_id", "counterparty_legal_address_id"], inplace=True + columns=[ + "legal_address_id", + "counterparty_legal_address_id", + ], + inplace=True, ) return df_cp @@ -143,11 +169,11 @@ def create_dim_date(dict_of_df): create_fact_purchase_orders(dict_of_df), create_fact_sales_order(dict_of_df), ] - date_col_names = [ - col_name for col_name in list(fact_dfs[0].columns) if "date" in col_name - ] list_of_date_columns = [] for df in fact_dfs: + date_col_names = [ + col_name for col_name in list(df.columns) if "_date" in col_name + ] for col in date_col_names: list_of_date_columns.append(df[col]) sr_date = pd.array(pd.concat(list_of_date_columns), dtype="datetime64[ns]") @@ -164,6 +190,8 @@ def create_dim_date(dict_of_df): # tests passed + + def scrape_currency_names(): response = requests.get("https://www.xe.com/currency/").content soup = BeautifulSoup(response, "html.parser") diff --git a/src/transform_lambda.py b/src/transform_lambda/transform_lambda.py index 9830e0f..93b2284 100644 --- a/src/transform_lambda.py +++ b/src/transform_lambda/transform_lambda.py @@ -1,4 +1,3 @@ -from src.dataframes import * import json import boto3 import re @@ -6,10 +5,11 @@ import logging import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from dataframes import * from botocore.exceptions import ClientError from pg8000.native import Connection, InterfaceError from datetime import datetime -import io + class DBConnectionException(Exception): """Wraps pg8000.native Error or DatabaseError.""" @@ -59,8 +59,6 @@ def lambda_handler(event, context): TABLES, bucket_name("extract"), client=boto3.client("s3") ) - print(dict_of_df) - immutable_df_dict = { "dim_counterparty": create_dim_counterparty(dict_of_df), "dim_date": create_dim_date(dict_of_df), @@ -108,7 +106,7 @@ def process_to_parquet_and_upload_to_s3( immutable_df_dict, mutable_df_dict, bucket, - client=boto3.client("s3") + client=boto3.client("s3"), ): status = {"uploaded": [], "not_uploaded": []} @@ -116,25 +114,22 @@ def process_to_parquet_and_upload_to_s3( if table_name in existing_s3_files: status["not_uploaded"].append(table_name) else: - parquet_buffer = io.BytesIO() - - df.to_parquet(parquet_buffer, engine="pyarrow") # or engine="fastparquet" - - parquet_buffer.seek(0) - - client.upload_fileobj(parquet_buffer, bucket, f"{table_name}.parquet") - + parquet_file = df.to_parquet( + f"{table_name}.parquet", engine="pyarrow" + ) # or fastparquet + # changed parquet_file variable to the file name + client.upload_file(f"{table_name}.parquet", bucket, f"{table_name}.parquet") status["uploaded"].append(table_name) - # for table_name, df in mutable_df_dict.items(): - # s3_key = datetime.strftime( - # datetime.today(), f"{table_name}/%Y/%m/%d/{table_name}_%H:%M:%S.parquet" - # ) - # parquet_file = df.to_parquet( - # f"{table_name}.parquet", engine="pyarrow" - # ) # or fastparquet - # client.upload_file(parquet_file, bucket, s3_key) - # status["uploaded"].append(table_name) + for table_name, df in mutable_df_dict.items(): + s3_key = datetime.strftime( + datetime.today(), f"{table_name}/%Y/%m/%d/{table_name}_%H:%M:%S.parquet" + ) + parquet_file = df.to_parquet( + f"{table_name}.parquet", engine="pyarrow" + ) # or fastparquet + client.upload_file(f"{table_name}.parquet", bucket, s3_key) + status["uploaded"].append(table_name) return status @@ -188,23 +183,15 @@ def read_from_s3_subfolder_to_df(tables, bucket, client=boto3.client("s3")): return table_dfs - - def bucket_name(bucket_prefix, client=boto3.client("s3")): + response = client.list_buckets() + bucket_filter = [ + bucket["Name"] + for bucket in response["Buckets"] + if bucket_prefix in bucket["Name"] + ] - response = client.list_buckets() - bucket_filter = [ - bucket["Name"] - for bucket in response["Buckets"] - if bucket_prefix in bucket["Name"] - ] - if not bucket_filter: - raise ValueError(f"No bucket found with prefix: {bucket_prefix}") - - return bucket_filter[0] - - - + return bucket_filter[0] def list_existing_s3_files(bucket_name, client=boto3.client("s3")): @@ -217,7 +204,7 @@ def list_existing_s3_files(bucket_name, client=boto3.client("s3")): existing_files = [obj["Key"] for obj in response["Contents"]] else: logger.error("The bucket is empty") - return None + return [] # changed from None to [] so it is an iterable except ClientError as e: logger.error(f"Error listing S3 objects: {e}") diff --git a/terraform/lambda.tf b/terraform/lambda.tf index d33a6c9..5f4a58e 100644 --- a/terraform/lambda.tf +++ b/terraform/lambda.tf @@ -83,11 +83,13 @@ resource "aws_lambda_function" "extract_lambda" { # Transform Lambda Function # ############################# + data "archive_file" "transform_lambda_zip" { type = "zip" - source_file = "${path.module}/../src/transform_lambda.py" - output_path = "${path.module}/../transform_function.zip" + source_dir = "${path.module}../src/transform_lambda" + output_path = "${path.module}../transform_lambda.zip" } + resource "aws_s3_object" "transform_lambda_code" { bucket = aws_s3_bucket.lambda_code_bucket.bucket key = "${var.transform_lambda_name}/transform_function.zip" diff --git a/tests/test_fact_sales_order.py b/tests/test_dataframes.py index a245379..ea7bad1 100644 --- a/tests/test_fact_sales_order.py +++ b/tests/test_dataframes.py @@ -54,7 +54,8 @@ class TestCreateDimStaff: "email_address": ["Hello", "Bye"], "department_id": ["Hello", "Bye"], } - test_df = {"staff": pd.DataFrame(data=d), "department": pd.DataFrame(data=d2)} + test_df = {"staff": pd.DataFrame( + data=d), "department": pd.DataFrame(data=d2)} result = create_dim_staff(test_df) assert isinstance(result, pd.DataFrame) @@ -71,7 +72,10 @@ class TestCreateDimStaff: "email_address": ["Hello", "Bye"], "department_id": ["Hello", "Bye"], } - test_df = {"staff": pd.DataFrame(data=d), "department": pd.DataFrame(data=d2)} + + test_df = {"staff": pd.DataFrame( + data=d), "department": pd.DataFrame(data=d2)} + result = create_dim_staff(test_df) expected_d = { "staff_id": ["Hello", "Bye"], @@ -88,7 +92,9 @@ class TestCreateDimStaff: class TestCreatePaymentType: def test_create_dim_payment_type_returns_correct_columns_and_values(self): - d = {"payment_type_id": ["Hello", "Bye"], "payment_type_name": ["Hello", "Bye"]} + d = {"payment_type_id": ["Hello", "Bye"], + "payment_type_name": ["Hello", "Bye"]} + test_df = {"payment_type": pd.DataFrame(data=d)} result = create_dim_payment_type(test_df) expected_columns = ["payment_type_id", "payment_type_name"] @@ -180,11 +186,14 @@ class TestCreateDimDate: index=[0], ) df_two = pd.DataFrame( - data={"updated_date": dt(2020, 5, 17), "created_date": dt(2021, 9, 13)}, + data={"updated_date": dt(2020, 5, 17), + "created_date": dt(2021, 9, 13)}, index=[0], ) df_three = pd.DataFrame( - data={"updated_date": dt(2022, 5, 17), "created_date": dt(2023, 5, 13)}, + data={"updated_date": dt(2022, 5, 17), + "created_date": dt(2023, 5, 13)}, + index=[0], ) expected_df = pd.DataFrame( @@ -214,7 +223,8 @@ class TestCreateDimDate: mock_fso.return_value = df_three result = create_dim_date({"dum": 0}) result.reset_index(inplace=True, drop=True) - assert result.eq(expected_df, axis="columns").all(axis=None) + assert result.eq( + expected_df, axis="columns").all(axis=None) class TestCreateDimLocation: @@ -222,7 +232,9 @@ class TestCreateDimLocation: dict_df = { "address": pd.DataFrame( data=[["some_time", "some_other_time", 1, "SE18 9QO"]], - columns=["created_at", "last_updated", "address_id", "postal_code"], + columns=["created_at", "last_updated", + "address_id", "postal_code"], + ) } result = create_dim_location(dict_df) @@ -244,3 +256,50 @@ class TestCreateDimTransaction: } result = create_dim_transaction(dict_df) assert list(result.columns) == ["transaction_id", "some_other_id"] + + +class TestCreateFactPayment: + def test_returns_correct_columns_payment(self): + dict_df = { + "payment": pd.DataFrame( + data=[ + [ + dt.strptime( + "2022-11-03 14:20:49.962846", "%Y-%m-%d %H:%M:%S.%f" + ), + dt.strptime( + "2022-12-14 16:20:49.962194", "%Y-%m-%d %H:%M:%S.%f" + ), + 1, + "SE18 9QO", + "2020-07-16", + ] + ], + columns=[ + "created_at", + "last_updated", + "payment_id", + "some_other_id", + "payment_date", + ], + ) + } + expected_cols = [ + "payment_record_id", + "created_date", + "created_time", + "last_updated_date", + "last_updated_time", + "payment_date", + "payment_id", + "some_other_id", + ] + result = create_fact_payment(dict_df) + assert isinstance(result, pd.DataFrame) + for col in list(result.columns): + assert col in expected_cols + for col in expected_cols: + + +if "_date" or "_time" in col: + assert result[col].dtype == "O" diff --git a/tests/test_load_lambda.py b/tests/test_load_lambda.py new file mode 100644 index 0000000..65106f7 --- /dev/null +++ b/tests/test_load_lambda.py @@ -0,0 +1,196 @@ +import pandas as pd +import pyarrow.parquet as pq +from io import BytesIO +from moto import mock_aws +import boto3 +import botocore.exceptions +import os +import pytest +from src.load_lambda import * +import tempfile + + +@pytest.fixture(scope="class") +def aws_credentials(): + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = "eu-west-2" + + +@pytest.fixture(scope="class") +def mock_s3_client(aws_credentials): + with mock_aws(): + yield boto3.client("s3") + + +@pytest.fixture(scope="class") +def mock_sm_client(aws_credentials): + with mock_aws(): + yield boto3.client("secretsmanager") + + +class TestLambdaHandler: + def test_lambda_handler_returns_200_and_table_name_if_uploaded(self, mocker): + mocker.patch( + "src.load_lambda.upload_dfs_to_database", + return_value={"uploaded": ["table_one", "table_two"], "not_uploaded": []}, + ) + result = lambda_handler(None, None) + assert result["statusCode"] == 200 + assert "table_one" in result["body"] + assert "table_two" in result["body"] + + def test_lambda_handler_returns_200_and_table_name_if_not_uploaded(self, mocker): + mocker.patch( + "src.load_lambda.upload_dfs_to_database", + return_value={"uploaded": [], "not_uploaded": ["table_one"]}, + ) + result = lambda_handler(None, None) + assert result["statusCode"] == 200 + assert "No dataframes were uploaded" in result["body"] + + def test_lambda_handler_returns_error_if_both_lists_empty(self, mocker): + mocker.patch( + "src.load_lambda.upload_dfs_to_database", + return_value={"uploaded": [], "not_uploaded": []}, + ) + + result = lambda_handler(None, None) + + assert result == {"error"} + + +class TestRetrieveSecrets: + def test_retrieve_secrets_returns_dictionary(self, mock_sm_client): + secret = { + "cohort_id": "test_cohort_id", + "user": "test_user_id", + "password": "test_password", + "host": "test_host", + "database": "test_database", + "port": "test_port", + } + + secret_name = "test_secret" + + mock_sm_client.create_secret(Name=secret_name, SecretString=json.dumps(secret)) + + result = json.loads(retrieve_secrets(mock_sm_client, secret_name)) + + assert isinstance(result, dict) + + def test_retrieve_secrets_returns_correct_keys_and_values(self, mock_sm_client): + secret_name = "test_secret" + + result = json.loads(retrieve_secrets(mock_sm_client, secret_name)) + + assert result["user"] == "test_user_id" + assert result["password"] == "test_password" + + def test_retrieve_secrets_returns_client_error_if_no_secret(self, mock_sm_client): + secret_name = "another_test_secret" + + with pytest.raises(botocore.exceptions.ClientError) as error: + retrieve_secrets(mock_sm_client, secret_name) + + +class TestConnectToDBAndReturnEngine: + def test_returns_unsuccessful_connection_when_wrong_credentials(self): + sm_secret = { + "host": "host", + "port": "port", + "user": "user", + "password": "password", + "database": "database", + } + + with pytest.raises(Exception): + connect_to_db_and_return_engine(json.dumps(sm_secret)) + + +class TestGetTransformBucket: + def test_raises_value_error_if_no_buckets(self, mock_s3_client): + with pytest.raises(ValueError, match="No transform bucket found"): + get_transform_bucket(mock_s3_client) + + def test_raises_value_error_if_no_transform_bucket(self, mock_s3_client): + mock_s3_client.create_bucket( + Bucket="extract_bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + with pytest.raises(ValueError, match="No transform bucket found"): + get_transform_bucket(mock_s3_client) + + def test_returns_transform_bucket_if_one_bucket(self, mock_s3_client): + mock_s3_client.create_bucket( + Bucket="transform_bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + result = get_transform_bucket(mock_s3_client) + assert result == "transform_bucket" + + def test_only_returns_transform_bucket_if_several_buckets(self, mock_s3_client): + mock_s3_client.create_bucket( + Bucket="another_test_bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + result = get_transform_bucket(mock_s3_client) + assert result == "transform_bucket" + + +class TestConvertParquetToDfs: + def test_function_returns_empty_dictionary_if_no_files(self, mock_s3_client): + mock_s3_client.create_bucket( + Bucket="transform_bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + result = convert_parquet_files_to_dfs( + bucket_name="transform_bucket", client=mock_s3_client + ) + assert result == {} + + # def test_function_returns_dictionary_with_table_with_file_key(): + # # need to mock parquet file and upload to mock bucket + # result = convert_parquet_files_to_dfs(bucket_name="transform_bucket", client=mock_s3_client) + # assert "dim_staff" in result + + def test_function_returns_dictionary_with_file_key_and_dataframe( + self, mock_s3_client + ): + with tempfile.TemporaryDirectory() as tmp: + d = { + "test": ["Hello", "Bye"], + "design_id": ["Hello", "Bye"], + "design_name": ["Hello", "Bye"], + "file_name": ["Hello", "Bye"], + "file_location": ["Hello", "Bye"], + "Hello": ["Hello", "Bye"], + } + + test_df = pd.DataFrame(data=d) + + path = os.path.join(tmp, "test_parquet.parquet") + + test_df.to_parquet(path, engine="pyarrow") + + with open(path, "rb") as p: + mock_s3_client.put_object( + Bucket="transform_bucket", Key="test_parquet.parquet", Body=p.read() + ) + + result = convert_parquet_files_to_dfs( + bucket_name="transform_bucket", client=mock_s3_client + ) + + assert "test_parquet.parquet" in result + + pd.testing.assert_frame_equal(result["test_parquet.parquet"], test_df) + + +class TestUploadDfsToDatabase: + # Full success test + # Partial success test + # Failure test + pass |
