aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--terraform/lambda.tf4
-rw-r--r--tests/test_load_lambda.py93
2 files changed, 66 insertions, 31 deletions
diff --git a/terraform/lambda.tf b/terraform/lambda.tf
index 5f4a58e..b6f36fb 100644
--- a/terraform/lambda.tf
+++ b/terraform/lambda.tf
@@ -86,8 +86,8 @@ resource "aws_lambda_function" "extract_lambda" {
data "archive_file" "transform_lambda_zip" {
type = "zip"
- source_dir = "${path.module}../src/transform_lambda"
- output_path = "${path.module}../transform_lambda.zip"
+ source_dir = "${path.module}/../src/transform_lambda"
+ output_path = "${path.module}/../transform_lambda.zip"
}
resource "aws_s3_object" "transform_lambda_code" {
diff --git a/tests/test_load_lambda.py b/tests/test_load_lambda.py
index 65106f7..b284588 100644
--- a/tests/test_load_lambda.py
+++ b/tests/test_load_lambda.py
@@ -1,13 +1,20 @@
import pandas as pd
-import pyarrow.parquet as pq
-from io import BytesIO
from moto import mock_aws
import boto3
import botocore.exceptions
import os
import pytest
-from src.load_lambda import *
+from src.load_lambda import (
+ lambda_handler,
+ retrieve_secrets,
+ connect_to_db_and_return_engine,
+ convert_parquet_files_to_dfs,
+ get_transform_bucket,
+ upload_dfs_to_database,
+)
import tempfile
+import json
+from unittest.mock import MagicMock, patch
@pytest.fixture(scope="class")
@@ -20,19 +27,20 @@ def aws_credentials():
@pytest.fixture(scope="class")
-def mock_s3_client(aws_credentials):
+def mock_s3_client():
with mock_aws():
yield boto3.client("s3")
@pytest.fixture(scope="class")
-def mock_sm_client(aws_credentials):
+def mock_sm_client():
with mock_aws():
yield boto3.client("secretsmanager")
class TestLambdaHandler:
- def test_lambda_handler_returns_200_and_table_name_if_uploaded(self, mocker):
+ @staticmethod
+ def test_lambda_handler_returns_200_and_table_name_if_uploaded(mocker):
mocker.patch(
"src.load_lambda.upload_dfs_to_database",
return_value={"uploaded": ["table_one", "table_two"], "not_uploaded": []},
@@ -42,7 +50,8 @@ class TestLambdaHandler:
assert "table_one" in result["body"]
assert "table_two" in result["body"]
- def test_lambda_handler_returns_200_and_table_name_if_not_uploaded(self, mocker):
+ @staticmethod
+ def test_lambda_handler_returns_200_and_table_name_if_not_uploaded(mocker):
mocker.patch(
"src.load_lambda.upload_dfs_to_database",
return_value={"uploaded": [], "not_uploaded": ["table_one"]},
@@ -51,7 +60,8 @@ class TestLambdaHandler:
assert result["statusCode"] == 200
assert "No dataframes were uploaded" in result["body"]
- def test_lambda_handler_returns_error_if_both_lists_empty(self, mocker):
+ @staticmethod
+ def test_lambda_handler_returns_error_if_both_lists_empty(mocker):
mocker.patch(
"src.load_lambda.upload_dfs_to_database",
return_value={"uploaded": [], "not_uploaded": []},
@@ -63,7 +73,8 @@ class TestLambdaHandler:
class TestRetrieveSecrets:
- def test_retrieve_secrets_returns_dictionary(self, mock_sm_client):
+ @staticmethod
+ def test_retrieve_secrets_returns_dictionary(mock_sm_client):
secret = {
"cohort_id": "test_cohort_id",
"user": "test_user_id",
@@ -81,7 +92,8 @@ class TestRetrieveSecrets:
assert isinstance(result, dict)
- def test_retrieve_secrets_returns_correct_keys_and_values(self, mock_sm_client):
+ @staticmethod
+ def test_retrieve_secrets_returns_correct_keys_and_values(mock_sm_client):
secret_name = "test_secret"
result = json.loads(retrieve_secrets(mock_sm_client, secret_name))
@@ -89,7 +101,8 @@ class TestRetrieveSecrets:
assert result["user"] == "test_user_id"
assert result["password"] == "test_password"
- def test_retrieve_secrets_returns_client_error_if_no_secret(self, mock_sm_client):
+ @staticmethod
+ def test_retrieve_secrets_returns_client_error_if_no_secret(mock_sm_client):
secret_name = "another_test_secret"
with pytest.raises(botocore.exceptions.ClientError) as error:
@@ -97,7 +110,8 @@ class TestRetrieveSecrets:
class TestConnectToDBAndReturnEngine:
- def test_returns_unsuccessful_connection_when_wrong_credentials(self):
+ @staticmethod
+ def test_returns_unsuccessful_connection_when_wrong_credentials():
sm_secret = {
"host": "host",
"port": "port",
@@ -111,11 +125,13 @@ class TestConnectToDBAndReturnEngine:
class TestGetTransformBucket:
- def test_raises_value_error_if_no_buckets(self, mock_s3_client):
+ @staticmethod
+ def test_raises_value_error_if_no_buckets(mock_s3_client):
with pytest.raises(ValueError, match="No transform bucket found"):
get_transform_bucket(mock_s3_client)
- def test_raises_value_error_if_no_transform_bucket(self, mock_s3_client):
+ @staticmethod
+ def test_raises_value_error_if_no_transform_bucket(mock_s3_client):
mock_s3_client.create_bucket(
Bucket="extract_bucket",
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
@@ -123,7 +139,8 @@ class TestGetTransformBucket:
with pytest.raises(ValueError, match="No transform bucket found"):
get_transform_bucket(mock_s3_client)
- def test_returns_transform_bucket_if_one_bucket(self, mock_s3_client):
+ @staticmethod
+ def test_returns_transform_bucket_if_one_bucket(mock_s3_client):
mock_s3_client.create_bucket(
Bucket="transform_bucket",
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
@@ -131,7 +148,8 @@ class TestGetTransformBucket:
result = get_transform_bucket(mock_s3_client)
assert result == "transform_bucket"
- def test_only_returns_transform_bucket_if_several_buckets(self, mock_s3_client):
+ @staticmethod
+ def test_only_returns_transform_bucket_if_several_buckets(mock_s3_client):
mock_s3_client.create_bucket(
Bucket="another_test_bucket",
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
@@ -141,7 +159,8 @@ class TestGetTransformBucket:
class TestConvertParquetToDfs:
- def test_function_returns_empty_dictionary_if_no_files(self, mock_s3_client):
+ @staticmethod
+ def test_function_returns_empty_dictionary_if_no_files(mock_s3_client):
mock_s3_client.create_bucket(
Bucket="transform_bucket",
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"},
@@ -151,14 +170,8 @@ class TestConvertParquetToDfs:
)
assert result == {}
- # def test_function_returns_dictionary_with_table_with_file_key():
- # # need to mock parquet file and upload to mock bucket
- # result = convert_parquet_files_to_dfs(bucket_name="transform_bucket", client=mock_s3_client)
- # assert "dim_staff" in result
-
- def test_function_returns_dictionary_with_file_key_and_dataframe(
- self, mock_s3_client
- ):
+ @staticmethod
+ def test_function_returns_dictionary_with_file_key_and_dataframe(mock_s3_client):
with tempfile.TemporaryDirectory() as tmp:
d = {
"test": ["Hello", "Bye"],
@@ -190,7 +203,29 @@ class TestConvertParquetToDfs:
class TestUploadDfsToDatabase:
- # Full success test
- # Partial success test
- # Failure test
- pass
+ @pytest.fixture
+ def mock_engine(self):
+ engine = MagicMock()
+ engine.dispose = MagicMock()
+ return engine
+
+ @pytest.fixture
+ def mock_df(self):
+ df = MagicMock(spec=pd.DataFrame)
+ df.to_sql = MagicMock()
+ return df
+
+ @staticmethod
+ def test_function_returns_dictionary_with_uploaded_and_not_uploaded_keys(
+ mock_engine, mock_df
+ ):
+ with patch(
+ "src.load_lambda.convert_parquet_files_to_dfs",
+ return_value={"dim_counterparty.parquet": mock_df},
+ ), patch(
+ "src.load_lambda.connect_to_db_and_return_engine", return_value=mock_engine
+ ):
+ result = upload_dfs_to_database()
+
+ assert "uploaded" in result
+ assert "not_uploaded" in result
git.ajschof.me — hosted by ajschofield — powered by cgit