aboutsummaryrefslogtreecommitdiffstats
path: root/src/load_lambda.py
blob: 11d1d7071e016ff4d33836586b248036a2c7ee9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import json
import logging
from io import BytesIO
from urllib.parse import quote

import boto3
import pandas as pd
import pyarrow.parquet as pq
from botocore.exceptions import ClientError
from sqlalchemy import create_engine


# Root logging configuration: brace-style records with minute-resolution
# timestamps at DEBUG verbosity. Note that logging.basicConfig() does nothing
# if the root logger already has handlers — NOTE(review): confirm this
# configuration actually takes effect in the deployed runtime.
logging.basicConfig(
    level=logging.DEBUG,
    style="{",
    format="{asctime} - {levelname} - {message}",
    datefmt="%Y-%m-%d %H:%M",
)

# botocore is very verbose at DEBUG; cap it at INFO so our own output stays readable.
logging.getLogger("botocore").setLevel(logging.INFO)

# Module-level logger used by every function below.
logger = logging.getLogger(__name__)


def lambda_handler(event, context):
    """AWS Lambda entry point: upload the transform bucket's DataFrames to the DB.

    Args:
        event: Lambda invocation event (not used).
        context: Lambda context object (not used).

    Returns:
        ``{"statusCode": int, "body": str}`` where ``body`` is a JSON-encoded
        message: 200 with the list of uploaded tables (or a "nothing uploaded"
        message), or 500 with a generic message if any exception is raised.
    """
    try:
        uploaded = upload_dfs_to_database()["uploaded"]
    except Exception as e:
        # Top-level boundary: keep the full traceback in the logs but return
        # only a generic message to the caller.
        logger.exception(f"Error: {e}")
        return {"statusCode": 500, "body": json.dumps("Internal server error.")}

    if not uploaded:
        return {
            "statusCode": 200,
            "body": json.dumps("No dataframes were uploaded."),
        }

    # Build the message on a single line: a triple-quoted f-string here would
    # embed a newline plus the source indentation in the response body.
    return {
        "statusCode": 200,
        "body": json.dumps(
            "The following dataframes were uploaded successfully: "
            f"{', '.join(uploaded)}."
        ),
    }


def retrieve_secrets(client=None, secret_name=None):
    """Fetch a secret from AWS Secrets Manager and return its decoded JSON.

    Args:
        client: Secrets Manager client to use. When omitted, one is created
            from a new boto3 session for region ``eu-west-2``.
        secret_name: Secret ID to read; defaults to ``"bentley-RDS-credentials"``.

    Returns:
        The response's ``SecretString`` parsed with ``json.loads`` (a ``dict``
        when the secret holds a JSON object).

    Raises:
        ClientError: if ``get_secret_value`` fails (logged, then re-raised).
        ValueError: if the response contains no ``SecretString`` key.
    """
    if secret_name is None:
        secret_name = "bentley-RDS-credentials"
    if client is None:
        # Only build a session when the caller did not inject a client.
        session = boto3.session.Session()
        client = session.client(service_name="secretsmanager", region_name="eu-west-2")

    try:
        get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    except ClientError as e:
        logger.error(f"Failed to retrieve secret {secret_name}: {str(e)}")
        raise
    # The raw response is deliberately never printed or logged: it carries the
    # plaintext credentials.

    # The KeyError can only come from this lookup, so it is handled here rather
    # than around the API call.
    try:
        secret_string = get_secret_value_response["SecretString"]
    except KeyError:
        logger.error(f"Secret {secret_name} does not contain a SecretString")
        raise ValueError(
            f"Secret {secret_name} does not contain a SecretString"
        ) from None

    return json.loads(secret_string)


# Connect to the database through a SQLAlchemy engine (rather than a plain driver connection) so that pandas can write DataFrames to it directly.


def connect_to_db_and_return_engine(sm_secret=None):
    """Create a SQLAlchemy engine (pg8000 driver) from database credentials.

    Args:
        sm_secret: Credentials containing ``host``, ``port``, ``user``,
            ``password`` and ``database``. Either JSON text (str/bytes) or an
            already-decoded mapping. When omitted, the credentials are fetched
            with ``retrieve_secrets()``.

    Returns:
        A SQLAlchemy engine, used by pandas ``to_sql`` elsewhere in this module.

    Raises:
        RuntimeError: if the credentials cannot be decoded, a field is missing,
            or ``create_engine`` fails. Errors from ``retrieve_secrets`` itself
            propagate unchanged.
    """
    if sm_secret is None:
        sm_secret = retrieve_secrets()

    try:
        # retrieve_secrets() returns an already-decoded dict, while callers may
        # also pass raw JSON text; only decode text, because json.loads()
        # rejects a dict with TypeError.
        if isinstance(sm_secret, (str, bytes, bytearray)):
            secrets = json.loads(sm_secret)
        else:
            secrets = sm_secret
        host = secrets["host"]
        port = secrets["port"]
        user = secrets["user"]
        password = secrets["password"]
        database = secrets["database"]
        # Percent-encode the credentials so characters such as "@", ":", "/"
        # or "%" in a password cannot corrupt the URL; SQLAlchemy decodes them
        # again when it parses the URL.
        conn_str = (
            f"postgresql+pg8000://{quote(str(user), safe='')}:"
            f"{quote(str(password), safe='')}@{host}:{port}/{database}"
        )
        # interface between python (pandas) and SQL
        engine = create_engine(conn_str)
        return engine
    except Exception as e:
        logger.error(f"Interface error: {e}")
        raise RuntimeError("Failed to create database engine") from e


# get transform bucket
def get_transform_bucket(client=None):
    """Return the name of the transform bucket.

    Lists every bucket visible to ``client`` (a fresh default S3 client when
    omitted) and returns the first name containing "transform", in the order
    ``list_buckets`` reports them.

    Raises:
        RuntimeError: if ``list_buckets`` fails with a ``ClientError``.
        ValueError: if no bucket name contains "transform".
    """
    s3 = client if client is not None else boto3.client("s3")
    try:
        listing = s3.list_buckets()
    except ClientError as err:
        logger.error(f"Error listing S3 buckets: {err}")
        raise RuntimeError("Error listing S3 buckets")

    candidates = [
        entry["Name"]
        for entry in listing["Buckets"]
        if "transform" in entry["Name"]
    ]
    if candidates:
        return candidates[0]

    logger.error("No transform bucket found")
    raise ValueError("No transform bucket found")


# list and then retrieve parquet files from S3 bucket
# convert parquet files into dataframes
# return a dictionary of dataframes with name as key, and dataframe object as value


def _list_object_keys(client, bucket_name):
    """Return every object key in ``bucket_name``, following list pagination."""
    keys = []
    request = {"Bucket": bucket_name}
    while True:
        page = client.list_objects_v2(**request)
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
        # One list_objects_v2 call returns at most 1,000 keys; S3 sets
        # IsTruncated (with NextContinuationToken) when more remain.
        if not page.get("IsTruncated"):
            return keys
        request["ContinuationToken"] = page["NextContinuationToken"]


def convert_parquet_files_to_dfs(bucket_name=None, client=None):
    """Read every object in the bucket as a parquet file into a DataFrame.

    Args:
        bucket_name: Bucket to read; defaults to ``get_transform_bucket()``
            resolved with the same client.
        client: boto3 S3 client; a default one is created when omitted.

    Returns:
        ``{object_key: DataFrame}`` for every object that was downloaded and
        parsed. Objects that fail either step are logged and skipped. An empty
        dict is returned (and an error logged) when the bucket lists no objects.

    Raises:
        ValueError, ClientError: from resolving the bucket or listing its
            objects (logged, then re-raised).
    """
    try:
        if client is None:
            client = boto3.client("s3")
        if bucket_name is None:
            bucket_name = get_transform_bucket(client=client)
        file_keys = _list_object_keys(client, bucket_name)
    except ValueError as value_error:
        logger.error(f"Unable to list objects: {value_error}")
        raise
    except ClientError as client_error:
        logger.error(f"Unable to list objects: {client_error}")
        raise

    if not file_keys:
        logger.error(f"No files found in {bucket_name}.")
        return {}

    dfs = {}
    for file_key in file_keys:
        try:
            file_obj = client.get_object(Bucket=bucket_name, Key=file_key)
            parquet_file = pq.ParquetFile(BytesIO(file_obj["Body"].read()))
            dfs[file_key] = parquet_file.read().to_pandas()
        except ClientError as e:
            logger.error(f"Unable to retrieve S3 object {file_key}: {e}")
        except Exception as e:
            # Best-effort: one unreadable object must not abort the whole load.
            logger.error(f"Unable to process file {file_key}: {e}")

    return dfs


def upload_dfs_to_database():
    """Append each recognised DataFrame from the transform bucket to its table.

    DataFrames come from ``convert_parquet_files_to_dfs()`` keyed by S3 object
    key and are written with ``DataFrame.to_sql`` (schema ``project_team_2``,
    ``if_exists="append"``, ``index=False``) through the engine returned by
    ``connect_to_db_and_return_engine()``. The engine is disposed of on exit,
    including when an upload fails.

    Returns:
        ``{"uploaded": [...], "not_uploaded": [...]}`` — the table name written
        for each uploaded file (one entry per file), and the object keys that
        matched no known table.

    Raises:
        Whatever ``to_sql`` raises for the first failing upload (logged first).
    """
    # Keys matched exactly; the table name is the key up to its first ".".
    immutable_files = {
        "dim_counterparty.parquet",
        "dim_date.parquet",  # this needs to be mutable
        "dim_location.parquet",
        "dim_staff.parquet",
        "dim_design.parquet",
    }
    # Keys matched on everything before their last "_"; presumably keys look
    # like "<table>_<suffix>.parquet" — TODO confirm with the transform stage.
    mutable_tables = {
        "fact_sales_order",
        "fact_purchase_order",
        "fact_payment",
        "dim_currency",
    }

    upload_status = {"uploaded": [], "not_uploaded": []}
    dict_of_dfs = convert_parquet_files_to_dfs()
    db_engine = connect_to_db_and_return_engine()
    try:
        for file_name, df in dict_of_dfs.items():
            prefix = file_name.rsplit("_", 1)[0]
            if file_name in immutable_files:
                table_name = file_name.split(".")[0]
            elif prefix in mutable_tables:
                table_name = prefix
            else:
                upload_status["not_uploaded"].append(file_name)
                logger.error(f"{file_name} does not correspond with table in database")
                continue

            try:
                df.to_sql(
                    table_name,
                    con=db_engine,
                    schema="project_team_2",
                    if_exists="append",
                    index=False,
                )
            except Exception as e:
                logger.error(f"Error uploading dataframe {file_name} to database: {e}")
                raise
            upload_status["uploaded"].append(table_name)
    finally:
        # Release pooled connections even when an upload raises.
        db_engine.dispose()
    return upload_status


if __name__ == "__main__":
    # Manual run outside Lambda: there is no event payload and no context object.
    lambda_handler(event=None, context=None)
git.ajschof.me — hosted by ajschofield — powered by cgit