diff options
Diffstat (limited to 'tests/test_load_lambda.py')
| -rw-r--r-- | tests/test_load_lambda.py | 93 |
1 files changed, 64 insertions, 29 deletions
diff --git a/tests/test_load_lambda.py b/tests/test_load_lambda.py index 65106f7..b284588 100644 --- a/tests/test_load_lambda.py +++ b/tests/test_load_lambda.py @@ -1,13 +1,20 @@ import pandas as pd -import pyarrow.parquet as pq -from io import BytesIO from moto import mock_aws import boto3 import botocore.exceptions import os import pytest -from src.load_lambda import * +from src.load_lambda import ( + lambda_handler, + retrieve_secrets, + connect_to_db_and_return_engine, + convert_parquet_files_to_dfs, + get_transform_bucket, + upload_dfs_to_database, +) import tempfile +import json +from unittest.mock import MagicMock, patch @pytest.fixture(scope="class") @@ -20,19 +27,20 @@ def aws_credentials(): @pytest.fixture(scope="class") -def mock_s3_client(aws_credentials): +def mock_s3_client(): with mock_aws(): yield boto3.client("s3") @pytest.fixture(scope="class") -def mock_sm_client(aws_credentials): +def mock_sm_client(): with mock_aws(): yield boto3.client("secretsmanager") class TestLambdaHandler: - def test_lambda_handler_returns_200_and_table_name_if_uploaded(self, mocker): + @staticmethod + def test_lambda_handler_returns_200_and_table_name_if_uploaded(mocker): mocker.patch( "src.load_lambda.upload_dfs_to_database", return_value={"uploaded": ["table_one", "table_two"], "not_uploaded": []}, @@ -42,7 +50,8 @@ class TestLambdaHandler: assert "table_one" in result["body"] assert "table_two" in result["body"] - def test_lambda_handler_returns_200_and_table_name_if_not_uploaded(self, mocker): + @staticmethod + def test_lambda_handler_returns_200_and_table_name_if_not_uploaded(mocker): mocker.patch( "src.load_lambda.upload_dfs_to_database", return_value={"uploaded": [], "not_uploaded": ["table_one"]}, @@ -51,7 +60,8 @@ class TestLambdaHandler: assert result["statusCode"] == 200 assert "No dataframes were uploaded" in result["body"] - def test_lambda_handler_returns_error_if_both_lists_empty(self, mocker): + @staticmethod + def test_lambda_handler_returns_error_if_both_lists_empty(mocker): mocker.patch( "src.load_lambda.upload_dfs_to_database", return_value={"uploaded": [], "not_uploaded": []}, @@ -63,7 +73,8 @@ class TestLambdaHandler: class TestRetrieveSecrets: - def test_retrieve_secrets_returns_dictionary(self, mock_sm_client): + @staticmethod + def test_retrieve_secrets_returns_dictionary(mock_sm_client): secret = { "cohort_id": "test_cohort_id", "user": "test_user_id", @@ -81,7 +92,8 @@ class TestRetrieveSecrets: assert isinstance(result, dict) - def test_retrieve_secrets_returns_correct_keys_and_values(self, mock_sm_client): + @staticmethod + def test_retrieve_secrets_returns_correct_keys_and_values(mock_sm_client): secret_name = "test_secret" result = json.loads(retrieve_secrets(mock_sm_client, secret_name)) @@ -89,7 +101,8 @@ class TestRetrieveSecrets: assert result["user"] == "test_user_id" assert result["password"] == "test_password" - def test_retrieve_secrets_returns_client_error_if_no_secret(self, mock_sm_client): + @staticmethod + def test_retrieve_secrets_returns_client_error_if_no_secret(mock_sm_client): secret_name = "another_test_secret" with pytest.raises(botocore.exceptions.ClientError) as error: @@ -97,7 +110,8 @@ class TestRetrieveSecrets: class TestConnectToDBAndReturnEngine: - def test_returns_unsuccessful_connection_when_wrong_credentials(self): + @staticmethod + def test_returns_unsuccessful_connection_when_wrong_credentials(): sm_secret = { "host": "host", "port": "port", @@ -111,11 +125,13 @@ class TestConnectToDBAndReturnEngine: class TestGetTransformBucket: - def test_raises_value_error_if_no_buckets(self, mock_s3_client): + @staticmethod + def test_raises_value_error_if_no_buckets(mock_s3_client): with pytest.raises(ValueError, match="No transform bucket found"): get_transform_bucket(mock_s3_client) - def test_raises_value_error_if_no_transform_bucket(self, mock_s3_client): + @staticmethod + def test_raises_value_error_if_no_transform_bucket(mock_s3_client): mock_s3_client.create_bucket( Bucket="extract_bucket", CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, @@ -123,7 +139,8 @@ class TestGetTransformBucket: with pytest.raises(ValueError, match="No transform bucket found"): get_transform_bucket(mock_s3_client) - def test_returns_transform_bucket_if_one_bucket(self, mock_s3_client): + @staticmethod + def test_returns_transform_bucket_if_one_bucket(mock_s3_client): mock_s3_client.create_bucket( Bucket="transform_bucket", CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, @@ -131,7 +148,8 @@ class TestGetTransformBucket: result = get_transform_bucket(mock_s3_client) assert result == "transform_bucket" - def test_only_returns_transform_bucket_if_several_buckets(self, mock_s3_client): + @staticmethod + def test_only_returns_transform_bucket_if_several_buckets(mock_s3_client): mock_s3_client.create_bucket( Bucket="another_test_bucket", CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, @@ -141,7 +159,8 @@ class TestGetTransformBucket: class TestConvertParquetToDfs: - def test_function_returns_empty_dictionary_if_no_files(self, mock_s3_client): + @staticmethod + def test_function_returns_empty_dictionary_if_no_files(mock_s3_client): mock_s3_client.create_bucket( Bucket="transform_bucket", CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, @@ -151,14 +170,8 @@ class TestConvertParquetToDfs: ) assert result == {} - # def test_function_returns_dictionary_with_table_with_file_key(): - # # need to mock parquet file and upload to mock bucket - # result = convert_parquet_files_to_dfs(bucket_name="transform_bucket", client=mock_s3_client) - # assert "dim_staff" in result - - def test_function_returns_dictionary_with_file_key_and_dataframe( - self, mock_s3_client - ): + @staticmethod + def test_function_returns_dictionary_with_file_key_and_dataframe(mock_s3_client): with tempfile.TemporaryDirectory() as tmp: d = { "test": ["Hello", "Bye"], @@ -190,7 +203,29 @@ class TestConvertParquetToDfs: class TestUploadDfsToDatabase: - # Full success test - # Partial success test - # Failure test - pass + @pytest.fixture + def mock_engine(self): + engine = MagicMock() + engine.dispose = MagicMock() + return engine + + @pytest.fixture + def mock_df(self): + df = MagicMock(spec=pd.DataFrame) + df.to_sql = MagicMock() + return df + + @staticmethod + def test_function_returns_dictionary_with_uploaded_and_not_uploaded_keys( + mock_engine, mock_df + ): + with patch( + "src.load_lambda.convert_parquet_files_to_dfs", + return_value={"dim_counterparty.parquet": mock_df}, + ), patch( + "src.load_lambda.connect_to_db_and_return_engine", return_value=mock_engine + ): + result = upload_dfs_to_database() + + assert "uploaded" in result + assert "not_uploaded" in result |
