aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/extract_lambda.py40
-rw-r--r--tests/test_secrets_manager.py73
2 files changed, 96 insertions, 17 deletions
diff --git a/src/extract_lambda.py b/src/extract_lambda.py
index fb2d7e8..f4c0c1d 100644
--- a/src/extract_lambda.py
+++ b/src/extract_lambda.py
@@ -1,5 +1,4 @@
-from pg8000.native import Connection, DatabaseError, InterfaceError
-from dotenv import dotenv_values
+from pg8000.native import Connection, InterfaceError
import boto3
import csv
from botocore.exceptions import ClientError
@@ -42,31 +41,35 @@ def lambda_handler(event, context):
'statusCode': 200,
'body': json.dumps('CSV files processed and uploaded successfully.')
}
-
except Exception as e:
logger.error(f'Error: {e}')
return {
'statusCode': 500,
'body': json.dumps('Internal server error.')
}
-
finally:
-
if db:
db.close()
-def get_config(path: str = ".env") -> dict:
- return dotenv_values(path)
+def retrieve_secrets(sm_client=boto3.client('secretsmanager'), secret_name='bentley-secrets'):
+ try:
+ response = sm_client.get_secret_value(SecretId=secret_name)
+ if 'SecretString' in response:
+ secret = json.loads(response['SecretString'])
+ return secret
+ except ClientError as e:
+ logger.error(f'Could not retrieve secrets: {e}')
+ raise e
def connect_to_database() -> Connection:
try:
- config = get_config()
- host = config["host"]
- port = config["port"]
- user = config["user"]
- password = config["password"]
- database = config["database"]
+ secrets = retrieve_secrets()
+ host = secrets["host"]
+ port = secrets["port"]
+ user = secrets["user"]
+ password = secrets["password"]
+ database = secrets["database"]
return Connection(
database=database,
@@ -79,9 +82,12 @@ def connect_to_database() -> Connection:
logger.error(f'Interface error: {i}')
raise DBConnectionException("Failed to connect to database")
+def extract_bucket(client=boto3.client('s3')):
+ response = client.list_buckets()
+ extract_bucket_filter = [bucket['Name'] for bucket in response['Buckets'] if 'extract' in bucket['Name']]
+ return extract_bucket_filter[0]
-
-def list_existing_s3_files(bucket_name='extract_bucket', client=boto3.client('s3')):
+def list_existing_s3_files(bucket_name=extract_bucket(), client=boto3.client('s3')):
"""Creates a dictionary and populates it with the
results of listing the contents of the s3 bucket, then
returns the populated dictionary
@@ -90,7 +96,7 @@ def list_existing_s3_files(bucket_name='extract_bucket', client=boto3.client('s3
existing_files = {}
try:
- response = client.list_objects_v2(Bucket='extract_bucket')
+ response = client.list_objects_v2(Bucket=bucket_name)
if 'Contents' in response:
for obj in response['Contents']:
@@ -147,7 +153,7 @@ def process_and_upload_tables(db, existing_files, client=boto3.client('s3')):
## END OF NEW CODE
if existing_files[latest_s3_object_key] != new_csv_content:
try:
- client.upload_file(csv_file_path, 'extract_bucket', s3_key)
+ client.upload_file(csv_file_path, extract_bucket(), s3_key)
logger.info(f"Uploaded {s3_key} to S3.")
except ClientError as e:
logger.error(f'Error uploading to S3: {e}')
diff --git a/tests/test_secrets_manager.py b/tests/test_secrets_manager.py
new file mode 100644
index 0000000..a30be86
--- /dev/null
+++ b/tests/test_secrets_manager.py
@@ -0,0 +1,73 @@
+from src.secrets_manager import sm_client, retrieve_secrets
+import boto3
+import botocore.exceptions
+from moto import mock_aws
+import json
+import pytest
+import os
+
+@pytest.fixture(scope='function')
+def aws_credentials():
+ """Mocked AWS Credentials for moto."""
+ os.environ["AWS_ACCESS_KEY_ID"] = "testing"
+ os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
+ os.environ["AWS_SECURITY_TOKEN"] = "testing"
+ os.environ["AWS_SESSION_TOKEN"] = "testing"
+ os.environ["AWS_DEFAULT_REGION"] = "eu-west-2"
+
+@pytest.fixture(scope='function')
+def mock_sm_client(aws_credentials):
+ with mock_aws():
+ yield boto3.client("secretsmanager")
+
+@pytest.fixture(scope='function')
+def mock_store_secret(mock_sm_client):
+ secret = {
+ "cohort_id": "test_cohort_id",
+ "user": "test_user_id",
+ "password": "test_password",
+ "host": "test_host",
+ "database": "test_database",
+ "port": "test_port"
+ }
+
+ secret_name = "test_secret"
+
+ response = mock_sm_client.create_secret(Name=secret_name, SecretString=json.dumps(secret))
+
+ return response
+
+def test_retrieves_secrets_returns_dictionary(mock_sm_client, mock_store_secret):
+ secret_name = "test_secret"
+
+ result = retrieve_secrets(mock_sm_client, secret_name)
+
+ assert isinstance(result, dict)
+
+def test_retrieves_secrets_returns_correct_keys_and_values(mock_sm_client, mock_store_secret):
+
+ secret_name = "test_secret"
+
+ result = retrieve_secrets(mock_sm_client, secret_name)
+
+ assert result["cohort_id"] == "test_cohort_id"
+ assert result["user"] == "test_user_id"
+ assert result["password"] == "test_password"
+ assert result["host"] == "test_host"
+ assert result["database"] == "test_database"
+ assert result["port"] == "test_port"
+
+def test_retrieves_secrets_raises_error_if_secret_name_incorrect_data_type(mock_sm_client):
+ secret_name = [1, 2, 3]
+
+
+ with pytest.raises(botocore.exceptions.ParamValidationError) as error:
+ retrieve_secrets(mock_sm_client, secret_name)
+
+
+def test_retrieves_secrets_raises_error_if_secret_name_does_not_exist(mock_sm_client, mock_store_secret):
+ secret_name = 'test_secret_2'
+
+
+ with pytest.raises(botocore.exceptions.ClientError) as error:
+ retrieve_secrets(mock_sm_client, secret_name) \ No newline at end of file
git.ajschof.me — hosted by ajschofield — powered by cgit